/*******************************************************************************
*
*  McStas, neutron ray-tracing package
*  Copyright(C) 2000 Risoe National Laboratory.
*
* %I
* Written by: Kim Lefmann
* Date: 04.02.04
* Origin: Risoe
* Modified by: MB,    15.01.24   (removed extra K2V factor in weight, reduced scattered intensity significantly)
*
* A sample for phonon scattering based on cross section expressions from Squires, Ch.3.
* Possibility for adding an (unphysical) bandgap.
*
* %D
* Single-cylinder shape.
* Absorption included.
* No multiple scattering.
* No incoherent scattering emitted.
* No attenuation from coherent scattering. No Bragg scattering.
* fcc crystal n.n. interactions only
* One phonon branch only -> phonon polarization not accounted for.
* Bravais lattice only. (i.e. just one atom per unit cell)
*
* Algorithm:
* 0. Always perform the scattering if possible (otherwise ABSORB)
* 1. Choose direction within a focusing solid angle
* 2. Calculate the zeros of (E_i-E_f-hbar omega(kappa)) as a function of k_f
* 3. Choose one value of k_f (always at least one is possible!)
* 4. Perform the correct weight transformation
*
* %P
* INPUT PARAMETERS:
* radius: [m]         Outer radius of sample in (x,z) plane
* yheight: [m]        Height of sample in y direction
* sigma_abs: [barns]  Absorption cross section at 2200 m/s per atom
* sigma_inc: [barns]  Incoherent scattering cross section per atom
* a: [AA]             fcc Lattice constant
* b: [fm]             Scattering length
* M: [a.u.]           Atomic mass
* c: [meV/AA^(-1)]    Velocity of sound
* DW: [1]             Debye-Waller factor
* T: [K]              Temperature
* focus_r: [m]        Radius of sphere containing target.
* focus_xw: [m]       horiz. dimension of a rectangular area
* focus_yh: [m]       vert.  dimension of a rectangular area
* focus_aw: [deg]     horiz. angular dimension of a rectangular area
* focus_ah: [deg]     vert.  angular dimension of a rectangular area
* target_x: [m]       position of target to focus at. Transverse coordinate
* target_y: [m]       position of target to focus at. Vertical coordinate
* target_z: [m]       position of target to focus at. Straight ahead.
* target_index: [1]   relative index of component to focus at, e.g. next is +1 
* gap: [meV]          Bandgap energy (unphysical)
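* e_steps_low: [1]    Number of sub-intervals searched for final velocities below the elastic line (v_f < v_i); at most one root is kept per sub-interval
* e_steps_high: [1]   Number of sub-intervals searched for final velocities above the elastic line (v_f > v_i); at most one root is kept per sub-interval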
*
* CALCULATED PARAMETERS:
* V_rho: [AA^-3]      Atomic density
* V_my_s: [m^-1]      Attenuation factor due to incoherent scattering
* V_my_a_v: [m^-1]    Attenuation factor due to absorption
*
* %L
* The test/example instrument <a href="../examples/Test_Phonon.instr">Test_Phonon.instr</a>.
*
* %E
******************************************************************************/

/* Parameters listed without a default (radius ... T) have no fallback value here;
 * the target_*, focus_*, gap and e_steps_* parameters default to the values shown. */
DEFINE COMPONENT Phonon_simple

SETTING PARAMETERS (radius,yheight,sigma_abs,sigma_inc,a,b,M,c,DW,T,
              target_x=0, target_y=0, target_z=0, int target_index=0,
              focus_r=0,focus_xw=0,focus_yh=0,focus_aw=0,focus_ah=0, 
              gap=0, int e_steps_low=50, int e_steps_high=50)

/* Neutron parameters: (x,y,z,vx,vy,vz,t,sx,sy,sz,p) */
SHARE
%{
  #ifndef PHONON_SIMPLE
  #define PHONON_SIMPLE $Revision$
  #define T2E (1/11.605)   /* Kelvin to meV */

  // Constant sample/dispersion parameters, filled once in INITIALIZE and only read afterwards.
  struct phonon_params {
    double a_;   // Lattice constant (copy of the component parameter a)
    double c_;   // Velocity of sound (copy of the component parameter c)
    double gap_; // Optional band gap energy (copy of the component parameter gap); 0 disables it
    double ah;   // Half of a, used as the argument scale of the cosines in the dispersion
    int e_steps_high_; // Number of root-search sub-intervals above the initial speed
    int e_steps_low_;  // Number of root-search sub-intervals below the initial speed
  };
  // Per-event kinematic state handed to omega_q(); the root finders overwrite vf while searching.
  struct neutron_params {
    // Vectors are stored as three explicit components, so no dynamic allocation is needed
    double vf;   // Final speed (magnitude); the trial value during root finding
    double vi;   // Initial speed (magnitude)
    double vv_x; // vv is the unit vector along the final velocity
    double vv_y;
    double vv_z;
    double vi_x; // vi_x, vi_y, vi_z form the initial velocity vector
    double vi_y;
    double vi_z;
  };

  /* Bose factor for an energy transfer omega at temperature T.
   * The temperature is converted to an energy with T2E. For omega > 0 the
   * result is n + 1, otherwise it is n evaluated at -omega, where
   * n(E) = 1 / (exp(E / kT) - 1). */
  #pragma acc routine
  double
  nbose (double omega, double T) /* Other name ?? */
  {
    const double thermal_energy = T * T2E;

    if (omega > 0) {
      return 1 + 1 / (exp (omega / thermal_energy) - 1);
    }
    return 1 / (exp (-omega / thermal_energy) - 1);
  }
  #undef T2E
  /* Routine types from the Numerical Recipes book */
  #define UNUSED (-1.11e30)
  #define MAXRIDD 60

  /* Host-side fatal error: write the message s to stderr, then terminate
   * the program with exit status 1. Never returns. */
  void
  fatalerror_cpu (char* s)
  {
    fprintf (stderr, "%s \n", s);
    exit (1);
  }

  /* Error hook callable from accelerator routines.
   * Without OPENACC it forwards to fatalerror_cpu() and therefore does not
   * return; with OPENACC defined the body is empty and the call returns
   * normally, so callers must cope with execution continuing. */
  #pragma acc routine
  void
  fatalerror (char* s)
  {
  #if !defined(OPENACC)
    fatalerror_cpu (s);
  #endif
  }

  /* Energy mismatch between the phonon dispersion and the neutron energy transfer.
   *
   * The scattering vector q is built from the initial velocity vector
   * (vi_x, vi_y, vi_z) and the final velocity vf * (vv_x, vv_y, vv_z).
   * The phonon energy comes from the nearest-neighbour cosine sum Jq; when
   * gap > 0 the gap is added in quadrature.
   *
   * Returns: phonon energy minus |energy change of the neutron| (dispersion in
   * units of meV). A zero of this function in vf is an allowed final speed.
   * Neither argument is modified. */
  #pragma acc routine
  double
  omega_q (struct neutron_params* neutron, struct phonon_params* phonon) {
    const double vf = neutron->vf;
    const double vi = neutron->vi;
    const double a = phonon->a_;
    const double c = phonon->c_;
    const double gap = phonon->gap_;
    const double ah = phonon->ah;
    double qx, qy, qz, Jq, res_phonon, res_neutron;

    /* Scattering vector components (initial minus final wave vector) */
    qx = V2K * (neutron->vi_x - vf * neutron->vv_x);
    qy = V2K * (neutron->vi_y - vf * neutron->vv_y);
    qz = V2K * (neutron->vi_z - vf * neutron->vv_z);

    /* Only the components enter the dispersion; |q| itself is not needed. */
    Jq = 2 * (cos (ah * (qx + qy)) + cos (ah * (qx - qy)) + cos (ah * (qx + qz)) + cos (ah * (qx - qz)) + cos (ah * (qy + qz)) + cos (ah * (qy - qz)));
    if (gap > 0) {
      res_phonon = sqrt (gap * gap + (12 - Jq) * (c * c) / (a * a));
    } else {
      res_phonon = c / a * sqrt (12 - Jq);
    }
    res_neutron = fabs (VS2E * (vi * vi - vf * vf));

    return (res_phonon - res_neutron);
  }

  /* Bracketing root finder for func as a function of neutron->vf.
   *
   * func:     function whose zero is sought; it is evaluated by writing the
   *           trial value into neutron->vf and calling func(neutron, phonon)
   * x1, x2:   ends of the search interval
   * xacc:     absolute accuracy requested for the root
   *
   * Returns x1 or x2 if func is exactly zero there, UNUSED if the end values
   * do not have opposite signs, otherwise the refined root estimate.
   * neutron->vf is left at the last trial value. Failure to converge within
   * MAXRIDD - 1 iterations is reported through fatalerror(). */
  double
  zridd (double (*func) (struct neutron_params*, struct phonon_params*), double x1, double x2, struct neutron_params* neutron, struct phonon_params* phonon,
         double xacc) {
    double lo, hi, mid, cand, best;
    double f_lo, f_hi, f_mid, f_cand;
    double disc, dir;
    int iter;

    neutron->vf = x1;
    f_lo = func (neutron, phonon);
    neutron->vf = x2;
    f_hi = func (neutron, phonon);

    /* No sign change: either an end point is an exact root or there is nothing to refine */
    if (f_lo * f_hi >= 0) {
      if (f_lo == 0)
        return x1;
      if (f_hi == 0)
        return x2;
      return UNUSED;
    }

    lo = x1;
    hi = x2;
    best = UNUSED;
    for (iter = 1; iter < MAXRIDD; iter++) {
      mid = 0.5 * (lo + hi);
      neutron->vf = mid;
      f_mid = func (neutron, phonon);

      disc = sqrt (f_mid * f_mid - f_lo * f_hi);
      if (disc == 0.0)
        return best;

      /* Updated estimate from the mid point and both bracket values */
      if (f_lo >= f_hi)
        dir = 1.0;
      else
        dir = -1.0;
      cand = mid + (mid - lo) * (dir * f_mid / disc);
      if (fabs (cand - best) <= xacc)
        return best;

      best = cand;
      neutron->vf = best;
      f_cand = func (neutron, phonon);
      if (f_cand == 0.0)
        return best;

      /* Re-bracket so that the root stays between lo and hi */
      if (fabs (f_mid) * SIGN (f_cand) != f_mid) {
        lo = mid;
        f_lo = f_mid;
        hi = best;
        f_hi = f_cand;
      } else if (fabs (f_lo) * SIGN (f_cand) != f_lo) {
        hi = best;
        f_hi = f_cand;
      } else if (fabs (f_hi) * SIGN (f_cand) != f_hi) {
        lo = best;
        f_lo = f_cand;
      } else {
        fatalerror ("never get here in zridd");
      }

      if (fabs (hi - lo) <= xacc)
        return best;
    }

    fatalerror ("zridd exceeded maximum iterations");
    return 0.0; /* Never get here */
  }

  /* Accelerator-callable variant of zridd() with omega_q() hard-wired as the
   * function whose zero in neutron->vf is sought.
   *
   * x1, x2:   ends of the search interval
   * xacc:     absolute accuracy requested for the root
   *
   * Returns x1 or x2 if omega_q is exactly zero there, UNUSED if the end values
   * do not have opposite signs, otherwise the refined root estimate.
   * neutron->vf is left at the last trial value.
   *
   * When OPENACC is defined fatalerror() has an empty body and returns, so the
   * failure paths below must hand back UNUSED themselves: any other value
   * (the former 0.0) is stored by findroots_gpu() as a genuine root. */
  #pragma acc routine
  double
  zridd_gpu (double x1, double x2, struct neutron_params* neutron, struct phonon_params* phonon, double xacc) {
    int j;
    double ans, fh, fl, fm, fnew, s, xh, xl, xm, xnew;

    neutron->vf = x1;
    fl = omega_q (neutron, phonon);
    neutron->vf = x2;
    fh = omega_q (neutron, phonon);
    if (fl * fh >= 0) {
      /* No sign change: accept an exact root at an end point, otherwise report "no root" */
      if (fl == 0)
        return x1;
      if (fh == 0)
        return x2;
      return UNUSED;
    }

    xl = x1;
    xh = x2;
    ans = UNUSED;
    for (j = 1; j < MAXRIDD; j++) {
      xm = 0.5 * (xl + xh);
      neutron->vf = xm;
      fm = omega_q (neutron, phonon);
      s = sqrt (fm * fm - fl * fh);
      if (s == 0.0)
        return ans;
      /* Updated estimate from the mid point and both bracket values */
      xnew = xm + (xm - xl) * ((fl >= fh ? 1.0 : -1.0) * fm / s);
      if (fabs (xnew - ans) <= xacc)
        return ans;
      ans = xnew;
      neutron->vf = ans;
      fnew = omega_q (neutron, phonon);
      if (fnew == 0.0)
        return ans;
      /* Re-bracket so that the root stays between xl and xh */
      if (fabs (fm) * SIGN (fnew) != fm) {
        xl = xm;
        fl = fm;
        xh = ans;
        fh = fnew;
      } else if (fabs (fl) * SIGN (fnew) != fl) {
        xh = ans;
        fh = fnew;
      } else if (fabs (fh) * SIGN (fnew) != fh) {
        xl = ans;
        fl = fnew;
      } else {
        fatalerror ("never get here in zridd");
        return UNUSED; /* Only reached when fatalerror() returns */
      }
      if (fabs (xh - xl) <= xacc)
        return ans;
    }
    fatalerror ("zridd exceeded maximum iterations");
    return UNUSED; /* Only reached when fatalerror() returns: report "no root" */
  }

  #define ROOTACC 1e-8

  /* Collect roots of f (as a function of neutron->vf) between brack_low and brack_high.
   *
   * [brack_low, brack_mid] is cut into phonon->e_steps_low_ equal sub-intervals
   * and [brack_mid, brack_high] into phonon->e_steps_high_ equal sub-intervals;
   * zridd() is run on each one with accuracy ROOTACC. Every result other than
   * UNUSED is appended to list at position *index, which is then advanced.
   * The caller must provide room for one entry per sub-interval. */
  void
  findroots (double brack_low, double brack_mid, double brack_high, double* list, int* index, double (*f) (struct neutron_params*, struct phonon_params*),
             struct neutron_params* neutron, struct phonon_params* phonon) {
    /* The two sides differ in width, so each gets its own uniform grid */
    const double span_below = brack_mid - brack_low;
    const double span_above = brack_high - brack_mid;
    int k;

    /* Final speeds below brack_mid: energy loss for the neutron */
    for (k = 0; k < phonon->e_steps_low_; k++) {
      const double left = brack_low + span_below * k / phonon->e_steps_low_;
      const double right = brack_low + span_below * (k + 1) / phonon->e_steps_low_;
      const double found = zridd (f, left, right, neutron, phonon, ROOTACC);

      if (found == UNUSED)
        continue;
      list[*index] = found;
      *index += 1;
    }

    /* Final speeds above brack_mid: energy gain for the neutron */
    for (k = 0; k < phonon->e_steps_high_; k++) {
      const double left = brack_mid + span_above * k / phonon->e_steps_high_;
      const double right = brack_mid + span_above * (k + 1) / phonon->e_steps_high_;
      const double found = zridd (f, left, right, neutron, phonon, ROOTACC);

      if (found == UNUSED)
        continue;
      list[*index] = found;
      *index += 1;
    }
  }

  /* Accelerator-callable variant of findroots(): same sub-interval scan, but
   * each sub-interval is handed to zridd_gpu() instead of going through a
   * function pointer.
   *
   * [brack_low, brack_mid] is cut into phonon->e_steps_low_ equal parts and
   * [brack_mid, brack_high] into phonon->e_steps_high_ equal parts. Every
   * result other than UNUSED is appended to list at position *index, which is
   * then advanced. The caller must provide room for one entry per sub-interval. */
  #pragma acc routine
  void
  findroots_gpu (double brack_low, double brack_mid, double brack_high, double* list, int* index, struct neutron_params* neutron, struct phonon_params* phonon) {
    /* The two sides differ in width, so each gets its own uniform grid */
    const double width_loss = brack_mid - brack_low;
    const double width_gain = brack_high - brack_mid;
    double lower, upper, candidate;
    int step;

    /* Final speeds below brack_mid: energy loss for the neutron */
    step = 0;
    while (step < phonon->e_steps_low_) {
      lower = brack_low + width_loss * step / phonon->e_steps_low_;
      upper = brack_low + width_loss * (step + 1) / phonon->e_steps_low_;
      candidate = zridd_gpu (lower, upper, neutron, phonon, ROOTACC);
      if (candidate != UNUSED) {
        list[*index] = candidate;
        *index += 1;
      }
      step++;
    }

    /* Final speeds above brack_mid: energy gain for the neutron */
    step = 0;
    while (step < phonon->e_steps_high_) {
      lower = brack_mid + width_gain * step / phonon->e_steps_high_;
      upper = brack_mid + width_gain * (step + 1) / phonon->e_steps_high_;
      candidate = zridd_gpu (lower, upper, neutron, phonon, ROOTACC);
      if (candidate != UNUSED) {
        list[*index] = candidate;
        *index += 1;
      }
      step++;
    }
  }

  #undef UNUSED
  #undef MAXRIDD
  #endif
%}

DECLARE
%{
  double V_rho;    /* Atomic density, set to 4 / a^3 in INITIALIZE */
  double V_my_s;   /* Attenuation factor due to incoherent scattering: V_rho * 100 * sigma_inc */
  double V_my_a_v; /* Absorption attenuation factor times velocity: V_rho * 100 * sigma_abs * 2200 */
  double DV;       /* Velocity change used for the numerical derivative in TRACE */
  struct phonon_params phonon; /* Constant dispersion parameters handed to the root finder */
%}
INITIALIZE
%{
  /* Constant dispersion parameters used by omega_q() and the root search */
  phonon.a_ = a;
  phonon.ah = a / 2.0;
  phonon.c_ = c;
  phonon.gap_ = gap;
  phonon.e_steps_low_ = e_steps_low;
  phonon.e_steps_high_ = e_steps_high;

  /* Derived sample constants */
  V_rho = 4 / (a * a * a);
  V_my_s = (V_rho * 100 * sigma_inc);
  V_my_a_v = (V_rho * 100 * sigma_abs * 2200);
  DV = 0.001; /* Velocity change used for numerical derivative */

  /* Angular focusing sizes are given in degrees; work in radians from here on */
  if (focus_aw) {
    focus_aw *= DEG2RAD;
  }
  if (focus_ah) {
    focus_ah *= DEG2RAD;
  }

  /* With no target given at all, aim at the next component */
  if (!(target_index || target_x || target_y || target_z)) {
    target_index = 1;
  }

  /* A component index takes precedence: convert it to coordinates in the local frame */
  if (target_index) {
    Coords to_target = coords_sub (POS_A_COMP_INDEX (INDEX_CURRENT_COMP + target_index), POS_A_CURRENT_COMP);
    to_target = rot_apply (ROT_A_CURRENT_COMP, to_target);
    coords_get (to_target, &target_x, &target_y, &target_z);
  }

  /* A target at the origin gives no direction; fall back to the Z-axis */
  if (!target_x && !target_y && !target_z) {
    printf ("Phonon_simple: %s: The target is not defined. Using direct beam (Z-axis).\n", NAME_CURRENT_COMP);
    target_z = 1;
  }
%}
TRACE
%{
  /* One-phonon scattering of a neutron that crosses the sample cylinder.
   *
   * A scattering point is drawn uniformly in time along the path through the
   * sample, a final direction is drawn inside the focusing area, and the final
   * speeds allowed by omega_q() == 0 are searched along that direction. One of
   * them is picked at random and the weight p is multiplied by the attenuation,
   * focusing, cross-section and Jacobian factors. Neutrons that miss the
   * cylinder are left untouched; neutrons for which no final speed exists are
   * absorbed. */
  struct neutron_params neutron;
  double* vf_list = NULL;                 /* List of allowed final velocities, one slot per search sub-interval */
  double t0, t1;                          /* Entry/exit time for cylinder */
  double v_i, v_f;                        /* Neutron velocities: initial, final */
  double vx_i, vy_i, vz_i;                /* Neutron initial velocity vector */
  double dt0, dt;                         /* Flight times through sample */
  double l_full;                          /* Flight path length for non-scattered neutron */
  double l_i, l_o;                        /* Flight path length in/out for scattered neutron */
  double my_a_i;                          /* Initial attenuation factor */
  double my_a_f;                          /* Final attenuation factor */
  double solid_angle;                     /* Solid angle of target as seen from scattering point */
  double aim_x = 0, aim_y = 0, aim_z = 1; /* Position of target relative to scattering point */
  double kappa_x, kappa_y, kappa_z;       /* Scattering vector */
  double kappa2;                          /* Square of the scattering vector */
  double bose_factor;                     /* Calculated value of the Bose factor */
  double omega;                           /* energy transfer */
  int nf, index;                          /* Number of allowed final velocities, and the one chosen */
  double J_factor;                        /* Jacobian from delta fnc.s in cross section */
  double f1, f2;                          /* probed values of omega_q minus omega */
  double p1, p2, p3, p4, p5;              /* temporary multipliers */

  if (cylinder_intersect (&t0, &t1, x, y, z, vx, vy, vz, radius, yheight)) {
    if (t0 < 0)
      ABSORB; /* Neutron came from the sample or begins inside */

    /* The buffer is allocated only now: ABSORB leaves TRACE immediately, so
     * anything allocated before the check above would never be freed. */
    #ifdef OPENACC
    vf_list = (double*)malloc ((e_steps_low + e_steps_high) * sizeof (double));
    #else
    vf_list = (double*)calloc (e_steps_low + e_steps_high, sizeof (double));
    #endif
    if (!vf_list) {
      printf ("Memory allocation failed, fatal error!\n");
      exit (-1);
    }
    #ifdef OPENACC
    for (int ii = 0; ii < e_steps_low + e_steps_high; ii++) {
      vf_list[ii] = 0;
    }
    #endif

    /* Neutron enters at t=t0. */
    dt0 = t1 - t0; /* Time in sample */
    v_i = sqrt (vx * vx + vy * vy + vz * vz);
    l_full = v_i * dt0;   /* Length of path through sample if not scattered */
    dt = rand01 () * dt0; /* Time of scattering (relative to t0) */
    l_i = v_i * dt;       /* Penetration in sample at scattering */
    vx_i = vx;
    vy_i = vy;
    vz_i = vz;
    PROP_DT (dt + t0); /* Point of scattering */

    aim_x = target_x - x; /* Vector pointing at target (e.g. analyzer) */
    aim_y = target_y - y;
    aim_z = target_z - z;

    if (focus_aw && focus_ah) {
      randvec_target_rect_angular (&vx, &vy, &vz, &solid_angle, aim_x, aim_y, aim_z, focus_aw, focus_ah, ROT_A_CURRENT_COMP);
    } else if (focus_xw && focus_yh) {
      randvec_target_rect (&vx, &vy, &vz, &solid_angle, aim_x, aim_y, aim_z, focus_xw, focus_yh, ROT_A_CURRENT_COMP);
    } else {
      randvec_target_sphere (&vx, &vy, &vz, &solid_angle, aim_x, aim_y, aim_z, focus_r);
    }
    NORM (vx, vy, vz);
    nf = 0;
    neutron.vf = -1;
    neutron.vi = v_i;
    neutron.vv_x = vx;
    neutron.vv_y = vy;
    neutron.vv_z = vz;
    neutron.vi_x = vx_i;
    neutron.vi_y = vy_i;
    neutron.vi_z = vz_i;

    #ifndef OPENACC
    findroots (0, v_i, v_i + 2 * c * V2K / VS2E, vf_list, &nf, omega_q, &neutron, &phonon);
    #else
    findroots_gpu (0, v_i, v_i + 2 * c * V2K / VS2E, vf_list, &nf, &neutron, &phonon);
    #endif

    /* No final speed satisfies the dispersion along this direction: scattering
     * is impossible, so absorb instead of reading an empty list. */
    if (nf < 1) {
      free (vf_list);
      ABSORB;
    }

    index = (int)floor (rand01 () * nf);
    if (index >= nf)
      index = nf - 1; /* Guard against a random number of exactly 1 */

    v_f = vf_list[index];
    /* The list is not needed past this point; release it before any later exit */
    free (vf_list);
    vf_list = NULL;

    /* A final speed that is not strictly positive gives no outgoing direction
     * and an infinite absorption factor, so it cannot be scattered into. */
    if (!(v_f > 0))
      ABSORB;

    /* Numerical derivative of omega_q with respect to the final speed */
    neutron.vf = v_f - DV;
    f1 = omega_q (&neutron, &phonon);
    neutron.vf = v_f + DV;
    f2 = omega_q (&neutron, &phonon);
    J_factor = fabs (f2 - f1) / (2 * DV);
    omega = VS2E * (v_i * v_i - v_f * v_f);
    vx *= v_f;
    vy *= v_f;
    vz *= v_f;
    kappa_x = V2K * (vx_i - vx);
    kappa_y = V2K * (vy_i - vy);
    kappa_z = V2K * (vz_i - vz);
    kappa2 = kappa_z * kappa_z + kappa_y * kappa_y + kappa_x * kappa_x;

    if (!cylinder_intersect (&t0, &t1, x, y, z, vx, vy, vz, radius, yheight)) {
      /* ??? did not hit cylinder */
      printf ("FATAL ERROR: Did not hit cylinder from inside.\n");
      exit (1);
    }
    dt = t1;
    l_o = v_f * dt;

    my_a_i = V_my_a_v / v_i;
    my_a_f = V_my_a_v / v_f;
    bose_factor = nbose (omega, T);
    p1 = exp (-(V_my_s * (l_i + l_o) + my_a_i * l_i + my_a_f * l_o));                 /* Absorption factor */
    p2 = nf * solid_angle * l_full * V_rho / (4 * PI);                                /* Focusing factors; assume random choice of n_f possibilities */
    p3 = (v_f / v_i) * DW * (kappa2 * K2V * K2V * VS2E) / fabs (omega) * bose_factor; /* Cross section factor 1 */
    p4 = 2 * VS2E * v_f / J_factor;                                                   /* Jacobian of delta functions in cross section */
    p5 = b * b / M;                                                                   /* Cross section factor 2 */
    p *= p1 * p2 * p3 * p4 * p5;
  } /* else transmit: Neutron did not hit the sample */
%}

FINALLY 
%{
  /* Intentionally empty: the DECLARE section holds only plain values and
   * INITIALIZE allocates nothing, so there is nothing to release here. */


%}

MCDISPLAY
%{
  /* Wireframe of the sample: a circle of the given radius in the xz plane at
   * each end of the cylinder, joined by four lines parallel to the y axis. */
  const double y_top = yheight / 2.0;
  const double y_bottom = -yheight / 2.0;

  /* End caps */
  circle ("xz", 0, y_top, 0, radius);
  circle ("xz", 0, y_bottom, 0, radius);

  /* Side lines at -x, +x, -z and +z */
  line (-radius, y_bottom, 0, -radius, y_top, 0);
  line (radius, y_bottom, 0, radius, y_top, 0);
  line (0, y_bottom, -radius, 0, y_top, -radius);
  line (0, y_bottom, radius, 0, y_top, radius);
%}

END
