/*******************************************************************************
* Component: MCPL_input_once
*
* %I
* Written by: Gregory S Tucker
* Date: Sep 2024
* Origin: European Spallation Source ERIC
*
* Source-like component that reads neutron state parameters from a MCPL-file one time.
* %D
* Source-like component that reads neutron state parameters from a MCPL-file one time.
*
* MCPL is short for Monte Carlo Particle List, and is a format for sharing events
* between e.g. MCNP(X), Geant4 and McStas.
*
* When used with MPI, the file contents are shared between workers, with each worker
* accessing approximately (#events in the file) / (#MPI nodes) events.
*
* %BUGS
*
* %P
* INPUT PARAMETERS
*
* filename: [str]       Name of neutron mcpl file to read.
* preload: [ ]          Load particles during INITIALIZE. On GPU preload is forced.
* polarisationuse: [ ]  If !=0 read polarisation vectors from file.
* Emin: [meV]           Lower energy bound. Particles found in the MCPL-file below the limit are skipped.
* Emax: [meV]           Upper energy bound. Particles found in the MCPL-file above the limit are skipped.
* v_smear: [1]          Relative fraction for randomness in replayed particle velocity v *= (1 + v_smear * randpm1())
* pos_smear: [m]        Maximum extent of random position for replayed particles
* dir_smear: [deg]      Maximum deviation of random direction for replayed particles
* always_smear: [ ]     Non-zero values force particle property smearing for all particles (otherwise only replayed particles are smeared)
* verbose: [ ]          Non-zero values may cause extra output (at some point in the future)
*
* %E
*******************************************************************************/

DEFINE COMPONENT MCPL_input_once

SETTING PARAMETERS (string filename=0, polarisationuse=1, Emin=0, Emax=FLT_MAX, v_smear=0, pos_smear=0, dir_smear=0, int always_smear=0, int preload=0, int verbose=0)

DEPENDENCY "@MCPLFLAGS@"

SHARE
%{
  #include <mcpl.h>
  #include <sys/stat.h>

  /* Plain record of one neutron state as kept by this component.
   * All members are doubles; every name starts with an underscore so that it
   * does not collide with the _particle accessor macros. */
  typedef struct {
    // position
    double _x;
    double _y;
    double _z;
    // velocity
    double _vx;
    double _vy;
    double _vz;
    // polarisation
    double _sx;
    double _sy;
    double _sz;
    // NOTE: members of type [int, void*, int, double, double, double, unsigned long]
    // would have to be inserted at this point for the memory layout to match
    // _struct_particle, e.g., _class_particle
    // time and weight
    double _t;
    double _p;
  } mcpl_input_once_particle_t;

  /* Fill a component particle record from one MCPL particle record.
   *
   * use_polarisation: zero stores (0, 0, 0) as polarisation, anything else copies the MCPL vector
   * weight_scale:     factor applied to the MCPL weight
   * input:            MCPL particle that is read
   * output:           record that is overwritten; every member is assigned
   */
  void
  mcpl_input_once_translator (const int use_polarisation, double weight_scale, const mcpl_particle_t* input, mcpl_input_once_particle_t* output) {
    // speed from the kinetic energy: MeV -> meV (factor 1e9), then energy -> velocity through VS2E
    const double speed = sqrt (input->ekin * 1e9 / VS2E);

    // position: the divisor 100 converts the MCPL length unit (cm) to m
    output->_x = input->position[0] / 100;
    output->_y = input->position[1] / 100;
    output->_z = input->position[2] / 100;

    // velocity: MCPL direction vector scaled by the speed
    output->_vx = input->direction[0] * speed;
    output->_vy = input->direction[1] * speed;
    output->_vz = input->direction[2] * speed;

    // polarisation: copied on request, zeroed otherwise
    if (use_polarisation) {
      output->_sx = input->polarisation[0];
      output->_sy = input->polarisation[1];
      output->_sz = input->polarisation[2];
    } else {
      output->_sx = 0;
      output->_sy = 0;
      output->_sz = 0;
    }

    // time: msec -> sec
    output->_t = input->time * 1e-3;

    // statistical weight (unitless), rescaled
    output->_p = input->weight * weight_scale;
  }

  /* Tell whether stat() succeeds for the path fn.
   * Returns 1 on success and 0 on any failure. */
  int
  mcplinputonce_file_exist (char* fn) {
    struct stat info;
    const int status = stat (fn, &info);
    return status == 0;
  }
%}

DECLARE
%{
  // Handle returned by mcpl_open_file for the resolved input file
  mcpl_file_t inputfile;
  // Number of particles reported by the MCPL header (mcpl_hdr_nparticles)
  uint64_t nparticles;
  // Number of particles with pdgcode 2112 read by this node
  uint64_t read_neutrons;
  // Number of read neutrons that also passed the [Emin, Emax] selection
  uint64_t used_neutrons;
  // Number of rays this component has emitted in TRACE
  uint64_t emitted_neutrons;
  // Number of particles accessed since the last rewind to the start of this node's slice
  uint64_t current_index;
  // Upper limit for current_index: slice length, or the number of kept neutrons after preloading
  uint64_t maximum_index;
  // File index of the first particle belonging to this node
  uint64_t first_particle;
  // Number of times this node went back to the start of its particles
  int times_replayed;
  // Preloaded particle records; NULL when preload is not used
  mcpl_input_once_particle_t* particles;
  // Factor applied to every MCPL weight: 1/initial_ray_count, or 1 for files without that statistic
  double weight_scale;
  // Name of the file that is opened; obtained from mcpl_name_helper and freed in FINALLY
  char* resolved_filename;
%}

INITIALIZE
%{
  // Resolve which file to open. Two candidate names are requested from mcpl_name_helper
  // ('M' and 'G'); presumably the plain and the gzipped spelling -- confirm against the helper.
  {
    if (!filename || !filename[0]) {
      fprintf (stderr, "ERROR(%s): Requires filename parameter.\n", NAME_CURRENT_COMP);
      exit (-1);
    }
    char* fn_mcpl = mcpl_name_helper (filename, 'M');
    char* fn_mcplgz = mcpl_name_helper (filename, 'G');
    if (mcplinputonce_file_exist (fn_mcpl)) {
      if (mcplinputonce_file_exist (fn_mcplgz)) {
        fprintf (stderr,
                 "ERROR(%s): Can not resolve input file unambiguously"
                 " since both %s and %s exist.\n",
                 NAME_CURRENT_COMP, fn_mcpl, fn_mcplgz);
        exit (-1);
      }
      resolved_filename = fn_mcpl;
      free (fn_mcplgz);
    } else {
      // the second candidate is used without an existence check; it is handed to mcpl_open_file as-is
      resolved_filename = fn_mcplgz;
      free (fn_mcpl);
    }
  }

  uint64_t particles_per_node;
  uint64_t last_particle;
  if (Emax < Emin) {
    fprintf (stderr, "Error(%s): Nonsensical energy interval: E=[%g,%g]. Aborting.\n", NAME_CURRENT_COMP, Emin, Emax);
    exit (-1);
  }
  inputfile = mcpl_open_file (resolved_filename);

  // Weights are normalised by the number of rays recorded in the file header
  double mcpl_ray_count = mcpl_hdr_stat_sum (inputfile, "initial_ray_count");
  if (mcpl_ray_count == -2.0) {
    // legacy format without ray count:
    weight_scale = 1.0;
  } else if (!(mcpl_ray_count > 0.0)) {
    fprintf (stderr,
             "ERROR: Input MCPL file has invalid initial_ray_count"
             " (%g). Unable to determine weight scale.\n",
             mcpl_ray_count);
    exit (1);
  } else {
    weight_scale = 1.0 / mcpl_ray_count;
  }

  nparticles = mcpl_hdr_nparticles (inputfile);
  if (!nparticles) {
    fprintf (stderr, "Error(%s): MCPL-file reports no present particles. Aborting.\n", NAME_CURRENT_COMP);
    exit (-1);
  } else {
    MPI_MASTER (printf ("Message(%s): MCPL file (%s) produced with %s.\n", NAME_CURRENT_COMP, resolved_filename, mcpl_hdr_srcname (inputfile));
                printf ("Message(%s): MCPL file (%s) contains %lu particles.\n", NAME_CURRENT_COMP, resolved_filename, (long unsigned)nparticles););
  }

  // This node's slice of the file is [first_particle, last_particle)
  first_particle = 0;
  last_particle = nparticles;
  #if defined (USE_MPI)
  // divy up the available particles between nodes
  particles_per_node = last_particle / mpi_node_count;
  // ensuring at least 1 particle per node (e.g., protecting against division by negative node count)
  if (particles_per_node < 1)
    particles_per_node = 1;
  // each node has first index given by how many particles each should do
  first_particle = particles_per_node * mpi_node_rank;
  // With more nodes than particles the computed start lies beyond the end of the file.
  // Clamp it, so that such a node gets an empty slice instead of an unsigned wrap-around
  // in (last_particle - first_particle) below.
  if (first_particle > nparticles)
    first_particle = nparticles;
  // the last worker keeps 'nparticles' as its last particle index, to ensure the full range is covered
  if (mpi_node_rank != mpi_node_count - 1) {
    last_particle = first_particle + particles_per_node;
    // never let a slice extend past the end of the file
    if (last_particle > nparticles)
      last_particle = nparticles;
  }
  #endif
  read_neutrons = 0;
  used_neutrons = 0;
  #ifdef OPENACC
  preload = 1;
  printf ("OpenACC, preload implicit:\n");
  #endif
  // Move this node's pointer into the file to its first particle:
  // in preparation for pre-loading or first pass through TRACE
  mcpl_seek (inputfile, first_particle);
  // Index will track how many particles this component has accessed
  current_index = 0;
  // Which we want to check against how large it can grow, to know if we've exhausted the available particles
  maximum_index = last_particle - first_particle;
  particles = NULL;
  if (preload) {
    const uint64_t slice_length = last_particle - first_particle;
    printf ("Preload requested, loading MCPLfile particles (%lu, %lu) in INITIALIZE\n", (long unsigned)first_particle, (long unsigned)last_particle);
    // Allocate at least one record so that a NULL result always means allocation failure
    particles = (mcpl_input_once_particle_t*)calloc (slice_length ? slice_length : 1, sizeof (mcpl_input_once_particle_t));
    if (particles == NULL) {
      fprintf (stderr, "ERROR(%s): Unable to allocate memory for %lu preloaded particles.\n", NAME_CURRENT_COMP, (long unsigned)slice_length);
      exit (-1);
    }
    for (uint64_t loop = first_particle; loop < last_particle; loop++) {
      const mcpl_particle_t* particle = mcpl_read (inputfile);
      // only neutrons (pdgcode 2112) are of interest
      if (particle && particle->pdgcode == 2112) {
        // Same selection as the non-preload path in TRACE: only energies strictly
        // outside [Emin, Emax] are skipped (ekin is compared after scaling the bounds by 1e-9)
        if (!(particle->ekin < Emin * 1e-9 || particle->ekin > Emax * 1e-9)) {
          mcpl_input_once_translator (polarisationuse, weight_scale, particle, particles + used_neutrons++);
        }
        read_neutrons++;
      }
    }
    // keep track of how many particles are available to us in the `particles` array
    maximum_index = used_neutrons;
    printf ("Done reading MCPL file (%lu, %lu), found %lu neutrons (and kept %lu)\n", (long unsigned)first_particle, (long unsigned)last_particle,
            (long unsigned)read_neutrons, (long unsigned)used_neutrons);
    mcpl_close_file (inputfile);
  }
  // keep track of how many times we had to replay the same MCPL input
  times_replayed = 0;
  // and the total number of neutrons emitted by this component
  emitted_neutrons = 0;
  // Determine the total number of read (or to be read) neutrons
  uint64_t total_index = maximum_index;
  #if defined (USE_MPI)
  // add up the read neutrons (or to be read neutrons) from each node
  MPI_Reduce (&maximum_index, &total_index, 1, MPI_UINT64_T, MPI_SUM, 0, MPI_COMM_WORLD);
  // tell all nodes the value, so they can set ncount for themselves
  MPI_Bcast (&total_index, 1, MPI_UINT64_T, 0, MPI_COMM_WORLD);
  #endif
  mcset_ncount (total_index); // will be divided by mpi_node_count before starting the raytrace
%}

TRACE
%{
  #ifndef OPENACC
  // Storage for the particle translated from the file when preload is off.
  // Automatic storage is used on purpose: ABSORB leaves this section immediately,
  // so a heap allocation made here would be leaked on every absorbed ray.
  mcpl_input_once_particle_t from_file;
  #endif
  mcpl_input_once_particle_t* ptr = NULL;

  if (current_index >= maximum_index) {
    // go back to the start of this worker's particles, rather than accessing out-of-range data
    #pragma acc atomic write
    current_index = 0;
    #ifndef OPENACC
    if (!preload) {
      // mcpl_seek is not available under openACC, and not used anyway
      mcpl_seek (inputfile, first_particle);
    }
    #endif
    times_replayed++;
  }

  #ifndef OPENACC
  // preload can only ever be false if OPENACC is not defined (in which case it is the likely code path)
  if (!preload) {
    const mcpl_particle_t* particle = mcpl_read (inputfile);
    current_index++; // track how many particles have been accessed, to ensure we don't exceed our slice of the file
    // nothing could be read, or the particle is not a neutron (pdgcode 2112)
    if (!particle || particle->pdgcode != 2112) {
      ABSORB;
    }
    // keep track of how many neutrons have been read; but only on the first pass
    if (!times_replayed)
      read_neutrons++;
    // energy selection; the bounds are scaled by 1e-9 to match the unit of ekin
    if (particle->ekin < Emin * 1e-9 || particle->ekin > Emax * 1e-9) {
      ABSORB;
    }
    mcpl_input_once_translator (polarisationuse, weight_scale, particle, &from_file);
    ptr = &from_file;
    // And how many of the read neutrons actually get used
    if (!times_replayed)
      used_neutrons++;
  } else
  #endif
  {
    // Preloaded particles: nothing to emit when no neutron was kept;
    // this also protects the modulo below against division by zero.
    if (maximum_index == 0) {
      ABSORB;
    }
    ptr = particles + (_particle->_uid % maximum_index);
    // Count the access in every build, so that running past the end of the list is
    // detected (times_replayed) with and without OpenACC.
    #pragma acc atomic
    current_index++;
  }

  // copy from the component particle struct to the particle ray struct:
  if (ptr == NULL) {
    fprintf (stderr, "ERROR (%s): component particle struct pointer not set! Crash before out-of-bounds memory access.\n", NAME_CURRENT_COMP);
    exit (-1);
  }
  #pragma acc atomic
  emitted_neutrons++;
  // this could be done via memcpy if we ensure equal memory layout with the ray's struct
  x = ptr->_x;
  y = ptr->_y;
  z = ptr->_z;
  vx = ptr->_vx;
  vy = ptr->_vy;
  vz = ptr->_vz;
  sx = ptr->_sx;
  sy = ptr->_sy;
  sz = ptr->_sz;
  t = ptr->_t;
  p = ptr->_p;

  // Smearing is applied to replayed particles, or to all particles when always_smear is set
  if (always_smear || times_replayed) {
    // fuzz the input
    if (pos_smear) {
      // random direction, each component scaled by its own random fraction of pos_smear
      double tmpx, tmpy, tmpz;
      randvec_target_circle (&tmpx, &tmpy, &tmpz, NULL, 0, 0, 1, 0);
      NORM (tmpx, tmpy, tmpz);
      x += tmpx * pos_smear * rand01 ();
      y += tmpy * pos_smear * rand01 ();
      z += tmpz * pos_smear * rand01 ();
    }
    if (v_smear) {
      // scale the speed by a common random factor, keeping the direction
      double fraction;
      fraction = 1.0 + v_smear * randpm1 ();
      vx *= fraction;
      vy *= fraction;
      vz *= fraction;
    }
    if (dir_smear) {
      // pick a new direction around the current one, keeping the speed
      double vv, dx, dy, dz;
      vv = sqrt (vx * vx + vy * vy + vz * vz);
      dx = vx / vv;
      dy = vy / vv;
      dz = vz / vv;
      randvec_target_circle (&dx, &dy, &dz, NULL, dx, dy, dz, sin (dir_smear * DEG2RAD));
      NORM (dx, dy, dz);
      vx = dx * vv;
      vy = dy * vv;
      vz = dz * vv;
    }
  }
  SCATTER;
%}

SAVE
%{
  // Release what INITIALIZE acquired: the open file (only when particles were not
  // preloaded; the preload path closes the file itself) and the preloaded particle list.
  // NOTE(review): if this section can run more than once per simulation, this cleanup
  // would be repeated -- confirm, and consider performing it in FINALLY instead.
  #ifndef OPENACC
  if (!preload) {
    mcpl_close_file (inputfile);
  }
  #endif
  // free() accepts NULL, so the not-preloaded case needs no separate check
  free (particles);
%}

FINALLY
%{
  // Replaying without any smearing emits identical rays more than once; tell the user
  if (times_replayed && v_smear == 0 && pos_smear == 0 && dir_smear == 0) {
    printf ("Warning (%s): Forced to replay particle list %d time(s) without smearing\n", NAME_CURRENT_COMP, times_replayed);
  }
  // Optional fragment spliced into the final warning ("T" + fragment + "his component"),
  // left empty without MPI
  char mpi_nodes_message[256];
  mpi_nodes_message[0] = '\0';
  uint64_t requested_neutrons = (uint64_t)mcget_ncount ();

  #if defined (USE_MPI)
  // Sum the per-node counters on rank 0, which is the only rank that reports
  uint64_t accumulated[4] = { 0, 0, 0, 0 };
  uint64_t distributed[4];
  distributed[0] = used_neutrons;
  distributed[1] = read_neutrons;
  distributed[2] = emitted_neutrons;
  distributed[3] = requested_neutrons;
  MPI_Reduce (distributed, accumulated, 4, MPI_UINT64_T, MPI_SUM, 0, MPI_COMM_WORLD);
  if (mpi_node_rank == 0) {
    used_neutrons = accumulated[0];
    read_neutrons = accumulated[1];
    emitted_neutrons = accumulated[2];
    requested_neutrons = accumulated[3];
    snprintf (mpi_nodes_message, sizeof (mpi_nodes_message), "he %d MPI node copies of t", mpi_node_count);
    #endif
    // uint64_t is not necessarily unsigned long long, so every value printed with %llu is cast explicitly
    if (used_neutrons != read_neutrons) {
      fprintf (stdout, "Message(%s): You have used %llu of %llu neutrons available in the MCPL file.\n", NAME_CURRENT_COMP,
               (unsigned long long)used_neutrons, (unsigned long long)read_neutrons);
    }
    if (requested_neutrons != used_neutrons) {
      char bad_particle_message[512];
      bad_particle_message[0] = '\0';
      if (used_neutrons != nparticles) {
        snprintf (bad_particle_message, sizeof (bad_particle_message), " (of which %llu are neutrons within the requested energy interval)",
                  (unsigned long long)used_neutrons);
      }
      fprintf (stderr,
               "Warning (%s): You requested %llu neutrons from a file which contains %llu particles in general%s."
               " T%shis component emitted %llu neutrons in total. Please examine the recorded intensities carefully.\n",
               NAME_CURRENT_COMP, (unsigned long long)requested_neutrons, (unsigned long long)nparticles, bad_particle_message, mpi_nodes_message,
               (unsigned long long)emitted_neutrons);
    }
    #if defined (USE_MPI)
  }
  #endif
  // the file name was allocated in INITIALIZE
  free (resolved_filename);
%}

MCDISPLAY
%{
  // Everything is drawn in the z = 0 plane.
  // Closed square outline with corners at (+-0.2, +-0.2)
  multiline (5, 0.2, 0.2, 0.0, -0.2, 0.2, 0.0, -0.2, -0.2, 0.0, 0.2, -0.2, 0.0, 0.2, 0.2, 0.0);
  /* Letter M: one connected path of four segments in the left half of the square */
  multiline (5, -0.085, -0.085, 0.0, -0.085, 0.085, 0.0, -0.045, -0.085, 0.0, -0.005, 0.085, 0.0, -0.005, -0.085, 0.0);
  /* Letter I: a vertical stem at x = 0.045 with a horizontal bar at the top and at the bottom */
  line (0.045, -0.085, 0, 0.045, 0.085, 0);
  line (0.005, 0.085, 0, 0.085, 0.085, 0);
  line (0.005, -0.085, 0, 0.085, -0.085, 0);
%}

END
