#!/usr/bin/env ruby

# Usage: <script> NVAR NREP
#   NVAR - number of __m256d variables a[0..NVAR-1] used by the generated kernel
#   NREP - number of times the mul+add sequence is unrolled per loop iteration
nvar = ARGV[0].to_i
nrep = ARGV[1].to_i

# String#to_i maps missing/non-numeric input to 0. NVAR == 0 would emit the
# invalid C++ declaration `__m256d a[0];`, and NREP == 0 an empty kernel, so
# reject non-positive values before generating anything.
unless nvar > 0 && nrep > 0
  abort("usage: #{$PROGRAM_NAME} NVAR NREP  (both must be positive integers)")
end

# Emit test_utils_omp.hpp, which defines
#   double test_simd(int nloop, long long& ops, double& dt, double& gflops)
# The generated function:
#   * queries the OpenMP team size,
#   * allocates one PAPI counter slot per thread,
#   * times an nloop-iteration `omp for` whose body is the mul+add sequence
#     over a[0..nvar-1], unrolled nrep times,
#   * sums the per-thread PAPI_VEC_DP counts into `ops`, stores the elapsed
#     etime() difference in `dt` and 2.0*ops/dt*1.0e-9 in `gflops`,
#   * returns a value computed from a[] (presumably so the compiler cannot
#     drop the kernel as dead code — TODO confirm).
#
# NOTE(review): the generated code includes no headers. It uses malloc/free,
# memset, posix_memalign, std::cerr, std::exit, AVX intrinsics, OpenMP, PAPI
# and an etime() helper, so the including translation unit must provide them.
# NOTE(review): a[] is declared outside the parallel region and is therefore
# shared. Every thread reads and writes it inside the `omp for`, which is a
# data race and may add memory traffic to the measurement. Consider a
# per-thread copy if the numbers look wrong.
File.open("test_utils_omp.hpp", "w") do |fout|

  # Team size: thread 0 records how many threads the parallel region gets.
  fout.printf("double test_simd(int nloop, long long& ops, double& dt, double& gflops) {\n")
  fout.printf("\n")
  fout.printf("  int nthreads;\n")
  fout.printf("#pragma omp parallel\n")
  fout.printf("  {\n")
  fout.printf("    if (omp_get_thread_num() == 0)\n")
  fout.printf("      nthreads = omp_get_num_threads();\n")
  fout.printf("  }\n")
  fout.printf("\n")

  # One zeroed long long per thread. Only a single PAPI event is added below;
  # enlarge each slot if PAPI_DP_OPS is re-enabled as a second event.
  fout.printf("  long long **values;\n")
  fout.printf("  values = (long long **) malloc((size_t) sizeof(long long *)*nthreads);\n")
  fout.printf("  memset(values, 0x0, (size_t) sizeof(long long *)*nthreads);\n")
  fout.printf("  for (int i = 0; i < nthreads; ++i) {\n")
  fout.printf("    values[i] = (long long *) malloc((size_t) sizeof(long long));\n")
  fout.printf("    memset(values[i], 0x0, (size_t) sizeof(long long));\n")
  fout.printf("  }\n")
  fout.printf("\n")

  fout.printf("  __m256d a[%d];\n", nvar)
  fout.printf("  for (int i = 0; i < %d; ++i) {\n", nvar)
  fout.printf("    a[i] = _mm256_set_pd(0.5, 2.2, 1.6, 3.1);\n")
  fout.printf("  }\n")
  fout.printf("\n")

  # Timed region: each thread builds its own PAPI event set around its
  # share of the `omp for`, then releases the event set after reading it.
  fout.printf("  double t0 = etime();\n")
  fout.printf("#pragma omp parallel\n")
  fout.printf("  {\n")
  fout.printf("    int EventSet = PAPI_NULL;\n")
  fout.printf("    PAPI_create_eventset( &EventSet );\n")
  fout.printf("    PAPI_add_event(EventSet, PAPI_VEC_DP);\n")
  fout.printf("    //PAPI_add_event(EventSet, PAPI_DP_OPS);\n")
  fout.printf("    PAPI_start( EventSet );\n")
  fout.printf("#pragma omp for\n")
  fout.printf("    for (int i = 0; i < nloop; ++i) {\n")

  # Unrolled kernel: nrep copies of (nvar multiplies, then nvar adds).
  for i in 0...nrep
    fout.printf("      // %d\n", i+1)
    for j in 0...nvar
      fout.printf("      a[%2d] = _mm256_mul_pd(a[%2d], a[%2d]);\n", j, j, j)
    end
    for j in 0...nvar
      fout.printf("      a[%2d] = _mm256_add_pd(a[%2d], a[%2d]);\n", j, j, j)
    end
    fout.printf("\n")
  end

  fout.printf("    }\n")
  fout.printf("    PAPI_stop( EventSet, values[omp_get_thread_num()]);\n")
  # Release the per-thread event set (previously leaked on every call).
  fout.printf("    PAPI_cleanup_eventset( EventSet );\n")
  fout.printf("    PAPI_destroy_eventset( &EventSet );\n")
  fout.printf("  }\n")
  fout.printf("  double t1 = etime();\n")
  fout.printf("  dt = t1 - t0;\n")
  fout.printf("\n")

  # Aggregate the counters, then free the per-thread buffers.
  fout.printf("  ops = 0;\n")
  fout.printf("  for (int i = 0; i < nthreads; ++i) {\n")
  fout.printf("    ops += values[i][0];\n")
  fout.printf("  }\n")
  fout.printf("  gflops = 2.0*ops/dt*1.0e-9;\n")
  fout.printf("  //gflops = 1.0*ops/dt*1.0e-9;\n")
  fout.printf("\n")
  fout.printf("  for (int i = 0; i < nthreads; ++i) {\n")
  fout.printf("    free(values[i]);\n")
  fout.printf("  }\n")
  fout.printf("  free(values);\n")
  fout.printf("\n")

  # Reduce over the nvar elements of a[]. The original bound was nrep,
  # which read past the end of a[] whenever nrep > nvar.
  fout.printf("  __m256d res = _mm256_setzero_pd();\n")
  fout.printf("  for (int i = 0; i < %d; ++i) {\n", nvar)
  fout.printf("    res = _mm256_add_pd(res, a[i]);\n")
  fout.printf("  }\n")
  fout.printf("\n")

  # _mm256_store_pd writes four doubles to a 32-byte-aligned address, so the
  # buffer must hold 4 doubles (it previously held 2, overflowing the heap).
  fout.printf("  double *val;\n")
  fout.printf("  if (posix_memalign((void **) &(val), 32, sizeof(double)*4) != 0) {\n")
  fout.printf("    std::cerr << \"memory allocation error.\" << std::endl;\n")
  fout.printf("    std::exit(EXIT_FAILURE);\n")
  fout.printf("  }\n")
  fout.printf("  _mm256_store_pd(val, res);\n")
  fout.printf("\n")
  # Keep the original return value (lanes 0 and 1), but free val first.
  fout.printf("  double sum = val[0] + val[1];\n")
  fout.printf("  free(val);\n")
  fout.printf("  return sum;\n")
  fout.printf("}\n")

end

