/****************************************************************************
*
* McStas, neutron ray-tracing package
*         Copyright 1997-2003, All rights reserved
*         Risoe National Laboratory, Roskilde, Denmark
*         Institut Laue Langevin, Grenoble, France
*
* Component: Pol_guide_vmirror
*
* %I
* Written by: Peter Willendrup
* Date: June 2022
* Origin: DTU
*
* Polarising guide with nvs x 2 supermirrors sitting in v-shapes inside, 
* upgraded from original single V-cavity.
*
* %D
* Models a rectangular guide with entrance centered on the Z axis and
* with nvs x two supermirrors sitting in v-shapes inside.
* The entrance lies in the X-Y plane.  Draws a true depiction
* of the guide with mirrors, and trajectories.
* The polarisation is handled similarly to Monochromator_pol.
* The reflectivity functions are handled similarly to Pol_mirror.
* The up direction is hardcoded to be along the y-axis (0, 1, 0)
* 
* The reflectivity parameters can be
* 1. double pointer initializations with 5 parameters (e.g. {R0, Qc, alpha, m, W} AND inputType=0), or
* 2. double pointer initializations with 6 parameters (e.g. {R0, Qc, alpha, m, W, beta} AND inputType=0), or
* 3. table names (e.g."supermirror_m2.rfl" AND inputType=1), or
* 4. individually assigned values (e.g. rR0=1.0, rQc=0.0217, ... etc AND inputType=2). 
* NB! This might cause warnings by the compiler that can be ignored.
* Note that the reflectivity functions that work with parameter values and those that work with tables are different.
* 
* GRAVITY: YES
* 
* %BUGS
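* No absorption by mirror.
* Table input (inputType=1) is not functional: INITIALIZE reads the files and then exits with an error.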
* 
* %P
* INPUT PARAMETERS:
*
* xwidth: [m]      Width at the guide entry
* yheight: [m]     Height at the guide entry
* length: [m]      length of guide
* 
* rFunc: [1]       Cavity wall Reflection function
* rPar: [1]        Cavity wall parameters array for rFunc, use with inputType = 0
* rParFile: []     Cavity wall parameters filename for rFunc, use with inputType = 1
* rR0: [1]         Cavity wall reflectivity below critical scattering vector, use with inputType = 2
* rQc: [AA-1]      Cavity wall reflectivity critical scattering vector, use natural Ni (m=1 definition) value of 0.0217, use with inputType = 2
* ralpha: [AA]     Cavity wall slope of reflectivity, use with inputType = 2   
* rmSM: [1]        Cavity wall m-value of material. Zero means completely absorbing. Use with inputType = 2 
* rW: [AA-1]       Cavity wall width of reflectivity cut-off, use with inputType = 2
* rbeta: [AA2]     Cavity wall curvature of reflectivity, use with inputType = 2
* 
* rUpFunc: [1]     Mirror Reflection function for spin up
* rUpPar: [1]      Mirror Parameters array for rUpFunc, use with inputType = 0
* rUpParFile: []   Mirror Parameters filename for rUpFunc, use with inputType = 1
* rUpR0: [1]       Mirror spin up reflectivity below critical scattering vector, use with inputType = 2
* rUpQc: [AA-1]    Mirror spin up reflectivity critical scattering vector, use natural Ni (m=1 definition) value of 0.0217, use with inputType = 2
* rUpalpha: [AA]   Mirror spin up slope of reflectivity, use with inputType = 2   
* rUpmSM: [1]      Mirror spin up m-value of material. Zero means completely absorbing. Use with inputType = 2 
* rUpW: [AA-1]     Mirror spin up width of reflectivity cut-off, use with inputType = 2
* rUpbeta: [AA2]   Mirror spin up curvature of reflectivity, use with inputType = 2
* 
* rDownFunc: [1]   Mirror Reflection function for spin down
* rDownPar: [1]    Mirror Parameters for rDownFunc, use with inputType = 0
* rDownParFile: [] Mirror Parameters filename for rDownFunc, use with inputType = 1
* rDownR0: [1]     Mirror spin down reflectivity below critical scattering vector, use with inputType = 2
* rDownQc: [AA-1]  Mirror spin down reflectivity critical scattering vector, use natural Ni (m=1 definition) value of 0.0217, use with inputType = 2
* rDownalpha: [AA] Mirror spin down slope of reflectivity, use with inputType = 2   
* rDownmSM: [1]    Mirror spin down m-value of material. Zero means completely absorbing. Use with inputType = 2 
* rDownW: [AA-1]   Mirror spin down width of reflectivity cut-off, use with inputType = 2
* rDownbeta: [AA2] Mirror spin down curvature of reflectivity, use with inputType = 2
* 
* inputType: [1]   Reflectivity parameters are 0: Values in array, 1: Table names, 2: Individual named values
* debug: [1]       if debug > 0 print out some internal runtime parameters
* nvs: [1]         Number of V-cavities across width of guide
* allow_inside_start: [1] Allow neutron to start inside cavity (no propagation to entry plane)
*
* CALCULATED PARAMETERS:
* 
* localG: [m/s/s]  Gravity vector in guide reference system 
* normalTop: [1]   One of several normal vectors used for defining the geometry 
* pointTop: [1]    One of several points used for defining the geometry 
* rParPtr: []      One of several pointers to reflection parameters used with the ref. functions.
* SCATTERED: []    is 1 for reflected, and 2 for transmitted neutrons
*
*
* %L
*
* %E
*******************************************************************************/

DEFINE COMPONENT Pol_guide_vmirror

SETTING PARAMETERS (int nvs=1,xwidth=0.1, yheight=0.1, length=0.5,
	rR0=1.0, rQc=0.0217, ralpha=6.5, rmSM=2, rW=0.00157, rbeta=80,
	rUpR0=1.0, rUpQc=0.0217, rUpalpha=2.47, rUpmSM=4, rUpW=0.0014, rUpbeta=0,
	rDownR0=1.0, rDownQc=0.0217, rDownalpha=1, rDownmSM=0.65, rDownW=0.003, rDownbeta=0, 
        int debug=0, allow_inside_start=0,
	vector rPar={1.0, 0.0217, 6.5, 2, 0.00157, 80},
	vector rUpPar={1.0, 0.0217, 2.47, 4, 0.0014},
	vector rDownPar={1.0, 0.0217, 1, 0.65, 0.003}, 
	string rParFile = "",
	string rUpParFile = "",
	string rDownParFile = "",
	int inputType=0)

/*
 * Parameter groups, as consumed by INITIALIZE:
 *   r*      (rR0 ... rbeta, rPar, rParFile)             : guide (cavity) walls
 *   rUp*    (rUpR0 ... rUpbeta, rUpPar, rUpParFile)     : mirrors, spin up
 *   rDown*  (rDownR0 ... rDownbeta, rDownPar, rDownParFile) : mirrors, spin down
 * inputType selects which representation is used:
 *   0 -> the vector parameters, 1 -> the file names, 2 -> the individual values.
 * In every group the value order is {R0, Qc, alpha, m, W, beta}.
 */

/* Neutron parameters: (x,y,z,vx,vy,vz,t,sx,sy,sz,p) */

SHARE
%{
%include "pol-lib"
%include "ref-lib"
%}

DECLARE
%{
  /* Gravity vector expressed in the component frame (zero when gravity is off). */
  Coords localG;

  /* Normal vectors of the guide walls and of the entrance/exit planes. */
  Coords normalTop;
  Coords normalBot;
  Coords normalLeft;
  Coords normalRight;
  Coords normalInOut;

  /* One point on each of the planes above; together with the normal it defines the plane. */
  Coords pointTop;
  Coords pointBot;
  Coords pointLeft;
  Coords pointRight;
  Coords pointIn;
  Coords pointOut;

  /* Allocated in INITIALIZE (2*nvs entries each) and released in FINALLY. */
  /* The code that would fill them is currently commented out in INITIALIZE. */
  Coords* mirrorNormals;
  Coords* mirrorPoints;

  /* xwhalf = xwidth/(2*nvs), xwfull = xwidth/nvs, norm = 1/sqrt(xwhalf^2 + length^2). */
  double xwhalf;
  double xwfull;
  double norm;

  /* Scratch counters used by INITIALIZE while copying parameter arrays. */
  int n_index;
  int i_index;

  /* Reflectivity parameters {R0, Qc, alpha, m, W, beta} passed to the reflectivity functions. */
  double rParToFunc[6];
  double rUpParToFunc[6];
  double rDownParToFunc[6];

  /* Tables filled by Table_Read when inputType == 1. */
  t_Table rParPtr;
  t_Table rUpParPtr;
  t_Table rDownParPtr;
%}

INITIALIZE
%{
  /*
   * Prepare everything TRACE needs:
   *  1. build the reflectivity parameter arrays r..ParToFunc according to
   *     inputType (0: vector parameters, 1: table files, 2: individual values),
   *  2. validate the geometry parameters,
   *  3. set up gravity and the planes (normal + point) that bound the guide,
   *  4. compute the per-cavity widths used to locate the V-mirrors.
   * Any invalid input terminates the simulation with exit(1).
   */
  if (inputType == 0) {
    // copy the SM reflectivity input parameter array r..Par to array r..ParToFunc
    // the input array can be 5 parameters, i.e. without beta, then beta=0
    // or it can be 6 parameters, including the value for beta.
    // r..ParToFunc is sent to the reflectivity functions

    for (i_index = 0; i_index < 6; i_index++) {
      rParToFunc[i_index] = 0;
      rUpParToFunc[i_index] = 0;
      rDownParToFunc[i_index] = 0;
    }

    /* NOTE(review): sizeof only yields the element count if r..Par is a true
       array in this scope. If the code generator hands it over as a pointer,
       this evaluates to sizeof(pointer)/sizeof(double) and only the first
       value(s) are copied -- TODO confirm against the generated code. */
    n_index = sizeof (rPar) / sizeof (rPar[0]);
    if (n_index > 6)
      n_index = 6; /* never write past the 6-element destination */
    for (i_index = 0; i_index < n_index; i_index++) {
      rParToFunc[i_index] = rPar[i_index];
    }

    n_index = sizeof (rUpPar) / sizeof (rUpPar[0]);
    if (n_index > 6)
      n_index = 6;
    for (i_index = 0; i_index < n_index; i_index++) {
      rUpParToFunc[i_index] = rUpPar[i_index];
    }

    n_index = sizeof (rDownPar) / sizeof (rDownPar[0]);
    if (n_index > 6)
      n_index = 6;
    for (i_index = 0; i_index < n_index; i_index++) {
      rDownParToFunc[i_index] = rDownPar[i_index];
    }
  } else if (inputType == 1) {
    /* Report the file NAME on failure (the numeric r..Par arrays are not strings). */
    if (Table_Read (&rParPtr, rParFile, 1) <= 0) {
      fprintf (stderr, "Pol_guide_vmirror: %s: can not read file %s\n", NAME_CURRENT_COMP, rParFile);
      exit (1);
    }
    if (Table_Read (&rUpParPtr, rUpParFile, 1) <= 0) {
      fprintf (stderr, "Pol_guide_vmirror: %s: can not read file %s\n", NAME_CURRENT_COMP, rUpParFile);
      exit (1);
    }
    if (Table_Read (&rDownParPtr, rDownParFile, 1) <= 0) {
      fprintf (stderr, "Pol_guide_vmirror: %s: can not read file %s\n", NAME_CURRENT_COMP, rDownParFile);
      exit (1);
    }
    /* TRACE only uses the r..ParToFunc arrays, so table input is rejected
       even when all three files were read successfully. */
    fprintf (stderr, "Pol_guide_vmirror: %s: Reading files is not possible!\n", NAME_CURRENT_COMP);
    exit (1);
  } else if (inputType == 2) {
    // Assemble the parameter array r..ParToFunc from the individual parameters
    // r..ParToFunc is sent to the reflectivity functions

    for (i_index = 0; i_index < 6; i_index++) {
      rParToFunc[i_index] = 0;
      rUpParToFunc[i_index] = 0;
      rDownParToFunc[i_index] = 0;
    }

    rParToFunc[0] = rR0;
    rParToFunc[1] = rQc;
    rParToFunc[2] = ralpha;
    rParToFunc[3] = rmSM;
    rParToFunc[4] = rW;
    rParToFunc[5] = rbeta;

    rUpParToFunc[0] = rUpR0;
    rUpParToFunc[1] = rUpQc;
    rUpParToFunc[2] = rUpalpha;
    rUpParToFunc[3] = rUpmSM;
    rUpParToFunc[4] = rUpW;
    rUpParToFunc[5] = rUpbeta;

    rDownParToFunc[0] = rDownR0;
    rDownParToFunc[1] = rDownQc;
    rDownParToFunc[2] = rDownalpha;
    rDownParToFunc[3] = rDownmSM;
    rDownParToFunc[4] = rDownW;
    rDownParToFunc[5] = rDownbeta;
  } else {
    /* Without this the reflectivity arrays would be used without ever being set. */
    fprintf (stderr, "Pol_guide_vmirror: %s: unknown inputType %i (must be 0, 1 or 2). Exiting\n", NAME_CURRENT_COMP, inputType);
    exit (1);
  }

  if ((xwidth <= 0) || (yheight <= 0) || (length <= 0)) {
    fprintf (stderr,
             "Pol_guide_vmirror: %s: NULL or negative length scale!\n"
             "ERROR      (xwidth,yheight,length). Exiting\n",
             NAME_CURRENT_COMP);
    exit (1);
  }

  /* nvs divides xwidth below and sizes the mirror arrays, so it must be positive. */
  if (nvs < 1) {
    fprintf (stderr, "Pol_guide_vmirror: %s: nvs must be >= 1 (got %i). Exiting\n", NAME_CURRENT_COMP, nvs);
    exit (1);
  }

  if (mcgravitation) {
    /* Rotate the lab-frame gravity vector into the component frame. */
    localG = rot_apply (ROT_A_CURRENT_COMP, coords_set (0, -GRAVITY, 0));
    fprintf (stdout, "Pol_guide_vmirror: %s: Gravity is on!\n", NAME_CURRENT_COMP);
  } else {
    localG = coords_set (0, 0, 0);
  }

  // To be able to handle the situation properly where a component of
  // the gravity is along the z-axis we also define entrance (in) and
  // exit (out) planes
  // The entrance and exit plane are defined by the normal vector
  // (0, 0, 1)
  // and the two points pointIn=(0, 0, 0) and pointOut=(0, 0, length)

  normalInOut = coords_set (0, 0, 1);
  pointIn = coords_set (0, 0, 0);
  pointOut = coords_set (0, 0, length);

  // Top plane (+y dir) can be spanned by (1, 0, 0) & (0, 0, 1)
  // and the point (0, yheight/2, 0)
  // A normal vector is (0, 1, 0)
  normalTop = coords_set (0, 1, 0);
  pointTop = coords_set (0, yheight / 2, 0);

  // Bottom plane (-y dir) can be spanned by (1, 0, 0) & (0, 0, 1)
  // and the point (0, -yheight/2, 0)
  // A normal vector is (0, 1, 0)
  normalBot = coords_set (0, 1, 0);
  pointBot = coords_set (0, -yheight / 2, 0);

  // Left plane (+x dir) can be spanned by (0, 1, 0) & (0, 0, 1)
  // and the point (xwidth/2, 0, 0)
  // A normal vector is (1, 0, 0)
  normalLeft = coords_set (1, 0, 0);
  pointLeft = coords_set (xwidth / 2, 0, 0);

  // Right plane (-x dir) can be spanned by (0, 1, 0) & (0, 0, 1)
  // and the point (-xwidth/2, 0, 0)
  // A normal vector is (1, 0, 0)
  normalRight = coords_set (1, 0, 0);
  pointRight = coords_set (-xwidth / 2, 0, 0);

  /* Storage for the (not yet implemented) "all-mirror" solution below; freed in FINALLY. */
  mirrorNormals = malloc (2 * (size_t)nvs * sizeof (Coords));
  mirrorPoints = malloc (2 * (size_t)nvs * sizeof (Coords));
  if (mirrorNormals == NULL || mirrorPoints == NULL) {
    fprintf (stderr, "Pol_guide_vmirror: %s: out of memory allocating mirror arrays. Exiting\n", NAME_CURRENT_COMP);
    exit (1);
  }

  /* Split in odd and even nvs cases */
  xwhalf = xwidth / (2 * nvs);
  xwfull = xwidth / nvs;
  norm = 1.0 / sqrt (xwhalf * xwhalf + length * length);

  /* For "all-mirror" solution: */
  /* double dx; */
  /* if (nvs % 2) { */
  /*   /\* Odd number of cavities *\/ */
  /*   dx=0; */
  /* } else { */
  /*   /\* Even number of cavities *\/ */
  /*   dx=xwhalf; */
  /* } */

  /* int counter; */
  /* for (counter=1; counter <= nvs; counter++)  { */
  /*   /\* Put in the mirror point locations *\/ */
  /*   mirrorPoints[2*counter-2] = coords_set(-(dx+counter*xwfull), 0, 0); */
  /*   mirrorPoints[2*counter-1] = coords_set((dx+counter*xwfull), 0, 0); */
  /*   normalMirror1  = coords_set(norm*length, 0, -norm*xwhalf); */

  /* } */
%}

TRACE
%{
  /*
   * Transport one neutron through the guide.
   *
   * Unless allow_inside_start is set and the neutron already has z >= 0, it is
   * first propagated to the entrance plane; neutrons outside the entrance
   * rectangle are absorbed. The loop then repeatedly
   *   - determines the two mirror planes associated with the current x,
   *   - solves for the intersection times with the four walls, the
   *     entrance/exit planes and the two mirror planes (gravity included),
   *   - stops if the entrance/exit plane is reached before any wall/mirror,
   *   - otherwise propagates to the hit and either reflects off a wall
   *     (weight from StdReflecFunc) or treats a mirror hit, where the spin
   *     dependent reflectivities decide between reflection and transmission
   *     by Monte Carlo choice.
   */
  int sgn;
  /* time threshold */
  double tThreshold = 1e-10 / sqrt (vx * vx + vy * vy + vz * vz);
  Coords normalMirror1, pointMirror1;
  Coords normalMirror2, pointMirror2;

  // Pol variables
  double FN, FM, Rup, Rdown, refWeight;

  if (!allow_inside_start || z < 0) {
    /* Propagate neutron to guide entrance. */
    PROP_Z0;
  }

  if (!inside_rectangle (x, y, xwidth, yheight))
    ABSORB;

  if (debug)
    printf ("-........-\n");

  for (;;) {
    double tLeft, tRight, tTop, tBot, tIn, tOut, tMirror1, tMirror2;
    double tUp, tSide, time, endtime;
    double Q;
    Coords vVec, xVec;
    /* Reset every iteration so a stale normal from a previous hit is never reused. */
    Coords* normalPointer = NULL;
    int mirrorReflect;

    /* Hal parametrization */
    int N, m, Total;
    double downmirror, upmirror;

    Total = nvs;

    /* Index of the half-cavity (width xwhalf) the neutron is in, counted from x=0. */
    N = floor (fabs (x) / xwhalf);
    /* if N == total, edge case ... */

    /* m is 1 for an even number of cavities and 0 for an odd number. */
    m = 1 - (Total - floor (Total / 2.0) * 2);
    if (debug)
      printf ("Total=%i m=%i", Total, m);
    if (debug)
      printf ("neutron is in N=%i and m=%i\n", N, m);

    /* Sign of x. The former x/fabs(x) is 0/0 (NaN) for x == 0, and converting
       NaN to int is undefined; x == 0 is treated as the positive side. */
    sgn = (x < 0) ? -1 : 1;

    if (sgn > 0) {
      downmirror = (floor ((N + m) / 2.0) * xwfull - m * xwhalf) * sgn;
      upmirror = ((floor ((N + m) / 2.0) + 1) * xwfull - m * xwhalf) * sgn;
    } else {
      upmirror = (floor ((N + m) / 2.0) * xwfull - m * xwhalf) * sgn;
      downmirror = ((floor ((N + m) / 2.0) + 1) * xwfull - m * xwhalf) * sgn;
    }

    if (debug)
      printf ("Found downmirror=%g , x=%g and upmirror=%g\n", downmirror, x, upmirror);

    /* The two mirror planes pass through (upmirror,0,0) and (downmirror,0,0)
       at the entrance and are tilted in opposite directions in the x-z plane. */
    normalMirror1 = coords_set (norm * length, 0, norm * xwhalf);
    pointMirror1 = coords_set (upmirror, 0, 0);
    normalMirror2 = coords_set (norm * length, 0, -norm * xwhalf);
    pointMirror2 = coords_set (downmirror, 0, 0);

    if (debug)
      printf ("Normal down=%g ,and up=%g with sign=%i\n", norm * xwhalf, -norm * xwhalf, sgn);

    /* ivs = floor(fabs(x)/(2*xwhalf));

    normalMirror1  = coords_set(norm*length, 0, -norm*xwhalf);
    pointMirror1 = coords_set(sgn*ivs*xwhalf, 0, 0);
    normalMirror2  = coords_set(norm*length, 0, norm*xwhalf);
    pointMirror2 = coords_set(sgn*ivs*xwhalf, 0, 0);*/

    mirrorReflect = 0;
    xVec = coords_set (x, y, z);
    vVec = coords_set (vx, vy, vz);

    /* For each plane solve 0.5*(n.g)*t^2 + (n.v)*t + n.(x - point) = 0. */
    solve_2nd_order (&tTop, NULL, 0.5 * coords_sp (normalTop, localG), coords_sp (normalTop, vVec), coords_sp (normalTop, coords_sub (xVec, pointTop)));

    solve_2nd_order (&tBot, NULL, 0.5 * coords_sp (normalBot, localG), coords_sp (normalBot, vVec), coords_sp (normalBot, coords_sub (xVec, pointBot)));

    solve_2nd_order (&tRight, NULL, 0.5 * coords_sp (normalRight, localG), coords_sp (normalRight, vVec), coords_sp (normalRight, coords_sub (xVec, pointRight)));

    solve_2nd_order (&tLeft, NULL, 0.5 * coords_sp (normalLeft, localG), coords_sp (normalLeft, vVec), coords_sp (normalLeft, coords_sub (xVec, pointLeft)));

    solve_2nd_order (&tIn, NULL, 0.5 * coords_sp (normalInOut, localG), coords_sp (normalInOut, vVec), coords_sp (normalInOut, coords_sub (xVec, pointIn)));

    solve_2nd_order (&tOut, NULL, 0.5 * coords_sp (normalInOut, localG), coords_sp (normalInOut, vVec), coords_sp (normalInOut, coords_sub (xVec, pointOut)));

    solve_2nd_order (&tMirror1, NULL, 0.5 * coords_sp (normalMirror1, localG), coords_sp (normalMirror1, vVec),
                     coords_sp (normalMirror1, coords_sub (xVec, pointMirror1)));

    solve_2nd_order (&tMirror2, NULL, 0.5 * coords_sp (normalMirror2, localG), coords_sp (normalMirror2, vVec),
                     coords_sp (normalMirror2, coords_sub (xVec, pointMirror2)));

    /* Choose appropriate reflection time */
    if (tTop > tThreshold && (tTop < tBot || tBot <= tThreshold))
      tUp = tTop;
    else
      tUp = tBot;

    if (tLeft > tThreshold && (tLeft < tRight || tRight <= tThreshold))
      tSide = tLeft;
    else
      tSide = tRight;

    if (tUp > tThreshold && (tUp < tSide || tSide <= tThreshold))
      time = tUp;
    else
      time = tSide;

    /* A mirror hit takes precedence only when it happens before the wall hit. */
    if (tMirror1 > tThreshold && tMirror1 < time) {
      time = tMirror1;
      mirrorReflect = 1; // flag to show which reflection function to use
    }
    if (tMirror2 > tThreshold && tMirror2 < time) {
      time = tMirror2;
      mirrorReflect = 2; // flag to show which reflection function to use
    }

    if (time <= tThreshold) {
      fprintf (stdout,
               "Pol_guide_vmirror: %s: tTop: %f, tBot:%f, tRight: %f, tLeft: %f\n"
               "tUp: %f, tSide: %f, time: %f\n",
               NAME_CURRENT_COMP, tTop, tBot, tRight, tLeft, tUp, tSide, time);
    }

    /* Has neutron left the guide? */
    if (tOut > tThreshold && (tOut < tIn || tIn <= tThreshold))
      endtime = tOut;
    else
      endtime = tIn;

    if (time > endtime)
      break;

    if (time <= tThreshold) {

      printf ("Time below threshold!\n");
      fprintf (stdout,
               "Pol_guide_vmirror: %s: tTop: %f, tBot:%f, tRight: %f, tLeft: %f\n"
               "tUp: %f, tSide: %f, time: %f\n",
               NAME_CURRENT_COMP, tTop, tBot, tRight, tLeft, tUp, tSide, time);
      break;
    }

    if (debug > 0 && time == tLeft) {

      fprintf (stdout, "\nPol_guide_vmirror: %s: Left side hit: x, v, normal, point, gravity\n", NAME_CURRENT_COMP);
      coords_print (xVec);
      coords_print (vVec);
      coords_print (normalLeft);
      coords_print (pointLeft);
      coords_print (localG);

      fprintf (stdout, "\nA: %f, B: %f, C: %f, tLeft: %f\n", 0.5 * coords_sp (normalLeft, localG), coords_sp (normalLeft, vVec),
               coords_sp (normalLeft, coords_sub (xVec, pointLeft)), tLeft);
    }

    if (debug > 0)
      fprintf (stdout,
               "Pol_guide_vmirror: %s: tTop: %f, tBot:%f, tRight: %f, tLeft: %f\n"
               "tUp: %f, tSide: %f, time: %f\n",
               NAME_CURRENT_COMP, tTop, tBot, tRight, tLeft, tUp, tSide, time);

    if (debug > 0)
      fprintf (stdout, "Pol_guide_vmirror: %s: Start v: (%f, %f, %f)\n", NAME_CURRENT_COMP, vx, vy, vz);

    PROP_DT (time);
    /* With gravity the velocity has changed during propagation. */
    if (mcgravitation)
      vVec = coords_set (vx, vy, vz);
    SCATTER;

    /* Identify the surface that was hit by matching the chosen time. */
    if (time == tTop)
      normalPointer = &normalTop;
    else if (time == tBot)
      normalPointer = &normalBot;
    else if (time == tRight)
      normalPointer = &normalRight;
    else if (time == tLeft)
      normalPointer = &normalLeft;
    else if (time == tMirror1)
      normalPointer = &normalMirror1;
    else if (time == tMirror2)
      normalPointer = &normalMirror2;
    else {
      fprintf (stderr, "Pol_guide_vmirror: %s: This should never happen!!!!\n", NAME_CURRENT_COMP);
      /* No surface identified: drop the neutron instead of dereferencing NULL. */
      ABSORB;
    }

    /* Momentum transfer for specular reflection off the hit surface. */
    Q = 2 * coords_sp (vVec, *normalPointer) * V2K;

    if (!mirrorReflect) {
      // we have hit one of the sides. Always reflect.
      vVec = coords_add (vVec, coords_scale (*normalPointer, -Q * K2V));
      StdReflecFunc (fabs (Q), rParToFunc, &refWeight);
      p *= refWeight;

    } else {
      // we have hit one of the mirrors
      StdDoubleReflecFunc (fabs (Q), rUpParToFunc, &Rup);
      StdDoubleReflecFunc (fabs (Q), rDownParToFunc, &Rdown);
      if (Rup < 0)
        ABSORB;
      if (Rup > 1)
        Rup = 1;
      if (Rdown < 0)
        ABSORB;
      if (Rdown > 1)
        Rdown = 1;

      GetMonoPolFNFM (Rup, Rdown, &FN, &FM);
      GetMonoPolRefProb (FN, FM, sy, &refWeight);
      /* Output of PW discussions with Hal Lee 2024/03/08
         We have now done our QM "measurement", thus
         forcing the spin to assume up/down: */
      sx = 0;
      sz = 0;
      // check that refWeight is meaningfull
      if (refWeight < 0)
        ABSORB;
      if (refWeight > 1)
        refWeight = 1;

      if (rand01 () < refWeight) {
        // reflect: SCATTERED==1 for reflection

        vVec = coords_add (vVec, coords_scale (*normalPointer, -Q * K2V));
        SetMonoPolRefOut (FN, FM, refWeight, &sx, &sy, &sz);
      } else {
        // transmit: SCATTERED==2 for transmission
        SCATTER;
        SetMonoPolTransOut (FN, FM, refWeight, &sx, &sy, &sz);
      }

      if (sx * sx + sy * sy + sz * sz > 1.000001) { // check that polarisation is meaningful
        fprintf (stderr, "Pol_guide_vmirror: %s: polarisation |s|=%g > 1 s=[%g,%g,%g]\n", NAME_CURRENT_COMP, sx * sx + sy * sy + sz * sz, sx, sy, sz);
      }
    }

    if (p == 0) {
      ABSORB;
      break;
    }

    // set new velocity vector
    coords_get (vVec, &vx, &vy, &vz);

    if (debug > 0)
      fprintf (stdout, "Pol_guide_vmirror: %s: End v: (%f, %f, %f)\n", NAME_CURRENT_COMP, vx, vy, vz);
  }
%}

FINALLY
%{
  /* Release the mirror arrays allocated in INITIALIZE. */
  free (mirrorPoints);
  free (mirrorNormals);
%}


MCDISPLAY
%{
  /*
   * Draw the guide housing as a box, then one line per mirror arm.
   * Arms are drawn along the top (side = +1) and bottom (side = -1) faces,
   * running from the entrance (z = 0) to the exit (z = length).
   */
  const double halfCavity = xwidth / (2 * nvs); /* half the width of one V-cavity */
  const double xMin = -xwidth / 2.0;
  const double xMax = xwidth / 2.0;
  int cavity, side;

  // draw box
  box (0, 0, length / 2.0, xwidth, yheight, length, 0, 0, 1, 0);

  for (cavity = -nvs; cavity <= nvs; cavity += 2) {
    for (side = -1; side <= 1; side += 2) {
      /* Entrance-side x of the arm, clamped so it never leaves the guide. */
      double xEntry = (cavity + side) * halfCavity;
      if (xEntry < xMin) {
        xEntry = xMin;
      }
      if (xEntry > xMax) {
        xEntry = xMax;
      }
      line (xEntry, side * yheight / 2, 0, cavity * halfCavity, side * yheight / 2, length);
    }
  }
%}

END
