/*
 *  Copyright (C) 2004-2024 Edward F. Valeev
 *
 *  This file is part of Libint library.
 *
 *  Libint library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  Libint library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with Libint library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

#include <libint2/boys.h>
#include <libint2/config.h>

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <iomanip>
#include <iostream>
#include <limits>
#include <type_traits>
#include <utility>
#include <vector>

#include "catch.hpp"

#ifdef LIBINT_HAS_MPFR
TEST_CASE("Boys reference values", "[core-ints][precision]") {
  using scalar_type = libint2::scalar_type;

  // highest Boys function order F_m(T) that is compared
  const int mmax = 40;

  // T grid: 0 plus 12 points 10^{-3}, 10^{-2.5}, ..., 10^{2.5}; built both in
  // the reference (high) precision type and in double
  std::vector<LIBINT2_REF_REALTYPE> T_ref_values;
  std::vector<double> T_values;
  T_ref_values.push_back(LIBINT2_REF_REALTYPE(0));
  T_values.push_back(0);
  for (int i = 0; i != 12; ++i) {
    using std::pow;
    using std::sqrt;
    T_ref_values.push_back(pow(sqrt(LIBINT2_REF_REALTYPE(10)), i) / 1000);
    T_values.push_back(pow(sqrt(10.), i) / 1000.);
  }

  std::cout << std::setprecision(30);
  // the comparison is only meaningful when the reference type differs from
  // the type under test
  if (!std::is_same<LIBINT2_REF_REALTYPE, scalar_type>::value) {
    // owned buffers (previously raw new[] arrays that were never freed)
    std::vector<LIBINT2_REF_REALTYPE> Fm_ref_values(mmax + 1);
    std::vector<scalar_type> Fm_values(mmax + 1);
    // engine_type 0 -> Chebyshev interpolation, 1 -> Taylor interpolation
    const auto chebyshev_eval =
        libint2::FmEval_Chebyshev7<scalar_type>::instance(mmax);
    const auto taylor_eval =
        libint2::FmEval_Taylor<scalar_type, 7>::instance(mmax);
    for (std::size_t t = 0; t != T_ref_values.size(); ++t) {
      const auto& Tref = T_ref_values[t];
      const auto T = T_values[t];
      libint2::FmEval_Reference<LIBINT2_REF_REALTYPE>::eval(
          Fm_ref_values.data(), Tref, mmax);
      for (auto engine_type : {0, 1}) {
        if (engine_type == 0) {
          chebyshev_eval->eval(Fm_values.data(), T, mmax);
        } else {
          taylor_eval->eval(Fm_values.data(), T, mmax);
        }
        for (int m = 0; m <= mmax; ++m) {
          using libint2::sstream_convert;
          // std::abs for builtin types, ADL finds abs for multiprecision types
          using std::abs;
          const LIBINT2_REF_REALTYPE value =
              sstream_convert<LIBINT2_REF_REALTYPE>(Fm_values[m]);
          const LIBINT2_REF_REALTYPE abs_error = abs(Fm_ref_values[m] - value);
          const LIBINT2_REF_REALTYPE relabs_error =
              abs(abs_error / Fm_ref_values[m]);
          // every %e argument must be a double for printf to be well-defined
          printf(
              "m=%d T=%e ref_value=%e value=%e abs_error=%e relabs_error=%e\n",
              m, sstream_convert<double>(Tref),
              sstream_convert<double>(Fm_ref_values[m]),
              static_cast<double>(Fm_values[m]),
              sstream_convert<double>(abs_error),
              sstream_convert<double>(relabs_error));
          // can only get about 14 digits of precision for extreme cases, but
          // can guarantee absolute precision to epsilon
          REQUIRE(
              sstream_convert<scalar_type>(abs_error) ==
              Approx(0.0).margin(std::numeric_limits<scalar_type>::epsilon()));
          REQUIRE(
              sstream_convert<scalar_type>(relabs_error) ==
              Approx(0.0).margin((engine_type == 0 ? 125 : 1e5) *
                                 std::numeric_limits<scalar_type>::epsilon()));
        }
      }
    }
  }
}

#endif  // LIBINT_HAS_MPFR

TEST_CASE("Slater/Yukawa core integral values", "[core-ints]") {
  using scalar_type = libint2::scalar_type;

  const auto zeta = 1.0;
  const int mmax = 36;
  const auto T = 0.3;
  const auto rho = 0.5;
  const auto one_over_rho = 1.0 / rho;

  // owned buffers for G_m, m = 0..mmax (previously leaked raw new[] arrays)
  std::vector<scalar_type> yukawa(mmax + 1);
  std::vector<scalar_type> stg(mmax + 1);
  auto tenno_eval = libint2::TennoGmEval<scalar_type>::instance(mmax);

  // absolute tolerance for comparisons against the tabulated reference values
  const double ref_tolerance = 1.4e-14;
  // checks yukawa[0..mmax] against ref_values; reports m on failure
  const auto check_yukawa = [&](const std::vector<scalar_type>& ref_values) {
    for (int m = 0; m <= mmax; ++m) {
      INFO("m = " << m);
      CHECK(std::abs(yukawa[m] - ref_values.at(m)) <= ref_tolerance);
    }
  };

  // eval at T=0, U
  {
    const std::vector<scalar_type> ref_values{
        0.3443204575812015284561287692691887166007,
        0.2185598474729328238479570769102704277998,
        0.15628803050541343523040858461794591444,
        0.1205302813563695092527987736260077265086,
        0.09771885762707005452746680293044358594349,
        0.08202555839753908595204847246086876491423,
        0.07061341858480468569599627134916394116044,
        0.0619591054276796876202669152433890705893,
        0.05517887615131295955174900498568299584769,
        0.04972742757098352844464478921127984232381,
        0.04525107487757221293120739098993905512744,
        0.04151082283140990378559967865261134542924,
        0.03833956708674360384857601285389554618283,
        0.03561705307086134800560829582022609088212,
        0.03325458437686685006877212773033703134889,
        0.03118533598784300483649122168611816027907,
        0.02935802012158051500495481146405702544609,
        0.02773262799652627099985843395816979927297,
        0.02627749654063442510811193421734676218181,
        0.02496724367844527115107405296878598045688,
        0.02378128673955011533777868163490765901325,
        0.02270276077349883452702840275267656606946,
        0.02171771642725558145495492438327385408735,
        0.02081451667176051954351159735354736480665,
        0.01998337721078039756033649801319291092231,
        0.01921601221155332553803261768601582527603,
        0.01850535826015937121626353551535819197592,
        0.01784535712254255688697702662699348741862,
        0.01723078320837644637040391181356151776459,
        0.01665710536934955175643383200315997427518,
        0.01612037532181394177448469127863672173319,
        0.01561713689965374695596056045589465521058,
        0.01514435174000532696990829906990931299676,
        0.01469933803373126377656853284970284607468,
        0.01427971973864157588729610821956952396993,
        0.01388338422903321724102399847578071093,
        0.0135084467913831066131366575551262916311};
    tenno_eval->eval_yukawa(yukawa.data(), one_over_rho, 0, mmax, zeta);
    check_yukawa(ref_values);
  }
  // eval at T=0, U=10^3
  {
    const std::vector<scalar_type> ref_values{
        0.0004992518684668698760157919116180007421591,
        0.0004987543554200826561387255879995052272495,
        0.0004982578319669375445097648001979091002148,
        0.0004977622951607015686386285148831142243323,
        0.0004972677420663180803047744704190612594879,
        0.0004967741697603490354955508328979528203838,
        0.000496281575330917616069102631084181479419,
        0.0004957899558776511907863158554424694107981,
        0.000495299308511624613374605242062422259045,
        0.000494809630355303855304711361850288521582,
        0.000494320918542489970979870299972521754099,
        0.0004938331702182633930547565241285431218218,
        0.0004933463825389285556194780697165502542594,
        0.0004928605526719588430016244654407219067124,
        0.0004923756777959418619569334178812478129392,
        0.0004918917551005250350365536850807862619847,
        0.0004914087817863615129361402981341659403214,
        0.0004909267550650564036491258209048034102075,
        0.000490445672159113316263469140280896745543,
        0.000489965530301881217257992806107859202924,
        0.0004894863267375015971710826288849169305349,
        0.0004890080587208559455310405169806078820973,
        0.0004885307235175135319537548008618719067855,
        0.0004880543184036794913295829420480039665733,
        0.0004875788406661432110374309368161646296632,
        0.0004871042876022270181399632621111909935993,
        0.0004866306565197351645296882222192077887056,
        0.000486157944736903108011337373846989501614,
        0.0004856861495823470873214956544915964345953,
        0.0004852152683950139891018422206238496747354,
        0.0004847452985241315048576321106934532873623,
        0.0004842762373291585759481869621125940519902,
        0.0004838080821797361246711703965355676310694,
        0.0004833408304556380695173015959532050427031,
        0.0004828744795467226226869102622259407912157,
        0.0004824090268528838679743588105368791206835,
        0.0004819444697840036171408545058389282004521};
    tenno_eval->eval_yukawa(yukawa.data(), 4e3, 0, mmax, zeta);
    check_yukawa(ref_values);
  }
  // eval at T=1023, U=10^-3 => barely inside the interpolation range
  {
    const std::vector<scalar_type> ref_values{
        0.003668769827513450049304087206872065121708,
        5.420435725593595129051675074500856589402e-6,
        1.153413823646514442119413472006192571438e-8,
        3.348561223534355443014370662234180926232e-11,
        1.258394731770944134180812393922368114687e-13,
        5.862788284772907280659914798508381826701e-16,
        3.275046949953268248075302422846380074442e-18,
        2.138229130319993581676975291139033105234e-20,
        1.59963080864078970551171504651294098697e-22,
        1.350018063194399985827690657328827145627e-24,
        1.269319128173106435573238782920347212332e-26,
        1.316016877041745608221092501114030150933e-28,
        1.491801172902970199227750502551650191297e-30,
        1.835690747701984560231572877029035230332e-32,
        2.437048407748493507629424133934226845417e-34,
        3.472216127773543921471144401735013957144e-36,
        5.284756092010242437298698009803141644211e-38,
        8.557741655028969169850990381419876688679e-40,
        1.469100240197536507488975576498808956171e-41,
        2.665095978514634905204746634314972087591e-43,
        5.094455679868537077442957666353063206612e-45,
        1.02348827991355301640372538439559426143e-46,
        2.15600611680743648323113639698619816743e-48,
        4.751953710279440486250336409641155248614e-50,
        1.093709802573290274439941537304650715144e-51,
        2.623989218000821713454650149370513594673e-53,
        6.55142678780823878011119970309663360075e-55,
        1.699659910055898434450559698406626019348e-56,
        4.575382384595824959451759892787212303821e-58,
        1.276328093372303041914176079318438057291e-59,
        3.684988521830781816819424431947389522298e-61,
        1.099900124390772978662598854177196717848e-62,
        3.390391378643954706357776945347132991928e-64,
        1.078178981606721458988211247401755052878e-65,
        3.534007800353040901580513641975783538743e-67,
        1.192874750834218800101921229451658920742e-68,
        4.142951489014496364871359607229095749925e-70};
    tenno_eval->eval_yukawa(yukawa.data(), 4e-3, 1023., mmax, zeta);
    check_yukawa(ref_values);
  }
  // eval at T=1024, U=10^-5 => corner case: exactly at the edge of the T
  // interpolation range
  {
    const std::vector<scalar_type> ref_values{
        0.02262060843208498576572581793629832479017,
        0.00001328061410798521010109505192413038336352,
        1.967492895146354004814437115777202958319e-8,
        4.816418800755732638805847501335095487461e-11,
        1.648158274569973415612169006428770140205e-13,
        7.247586576528941506731998414483203834384e-16,
        3.894356225942859181756065719699405155324e-18,
        2.472711235859874377238394450826232100922e-20,
        1.811448107927246160983370551456683609577e-22,
        1.503884924669673070579676570491468448584e-24,
        1.395382639889630997721475663107461615738e-26,
        1.430959140992753432533796302656136565791e-28,
        1.607170459380522861829724507442253707834e-30,
        1.962017744277102395146883802669026121532e-32,
        2.58680143732424441966257432231829051703e-34,
        3.663142857318927850342038530997839992486e-36,
        5.545048934558662687961060527811409977934e-38,
        8.935251031550317703047640697990162351163e-40,
        1.527074590830714047519742903079645044486e-41,
        2.758962251510875479578666842875315658127e-43,
        5.254032322173135016194484685446458493996e-45,
        1.051859585613288854565102682797805414478e-46,
        2.208545556046672625649543664904429790736e-48,
        4.852863952210522203707939215139889462594e-50,
        1.113715931860383969333615440424338933601e-51,
        2.664699767282678545009174715265411900603e-53,
        6.635835720863297195478836802217619902765e-55,
        1.717307727564889241784513519858224539406e-56,
        4.611975204258451218632194304666152011294e-58,
        1.28362315067474047480721994332524600143e-59,
        3.697982826408717758978012846438010122791e-61,
        1.10146249832633365614554881871389291007e-62,
        3.388324071832241312535223202223854537911e-64,
        1.075406580049278994053231441612245960694e-65,
        3.518208912505666945792432306282271686395e-67,
        1.185344559453374598719071951154101420182e-68,
        4.109383172020121955197425147382932484728e-70};
    tenno_eval->eval_yukawa(yukawa.data(), 4e-5, 1024., mmax, zeta);
    check_yukawa(ref_values);
  }
  // eval at T=1025, U=10^-3 => outside the interpolation range, use upward
  // recursion
  {
    const std::vector<scalar_type> ref_values{
        0.003657951979774989740547145195094265598956,
        5.397434252949152394251250558614863241027e-6,
        1.146741791141338373846245954440640044927e-8,
        3.323510149412937730771453601915206279433e-11,
        1.246734372106011749419154493770132880287e-13,
        5.797712867725216239476429836250847404132e-16,
        3.232600501911678152833123765185345717367e-18,
        2.106504834068139474864677800715105551541e-20,
        1.572882566409972085538380086037251337395e-22,
        1.32489290711137333410367790355006311444e-24,
        1.243294719434053061711100301690665826431e-26,
        1.286544765343287266475797166644585120875e-28,
        1.455570173013776475184661214300662972947e-30,
        1.787637233093320455263984867846121529164e-32,
        2.368649723323044272544610910961239701561e-34,
        3.368213142584729170480317726411959787859e-36,
        5.116504261697132348263354717120146407997e-38,
        8.269184549488893333013988741963279191292e-40,
        1.416803707729027771342157448351896590131e-41,
        2.565225408632380751527810017749240947252e-43,
        4.894007398937007554334293764970281768822e-45,
        9.813041387226526140390810066699092717618e-47,
        2.063119907953795246574008251810024601031e-48,
        4.538373496861722761884789635745614092719e-50,
        1.042517943093130384672682509904418864771e-51,
        2.496299838121981624595079940391777397835e-53,
        6.220477821114131167379688461559099076222e-55,
        1.610656509691089503395329290003466373265e-56,
        4.327342321435717332199546571679113781029e-58,
        1.204783627433044418712454733871756969723e-59,
        3.471647847072113904911161962432548462893e-61,
        1.034202319009197839430845710512974671368e-62,
        3.181667270952239324987390390320295378802e-64,
        1.00983030768632875947349533908742338808e-65,
        3.303525071239174227864582220812866971167e-67,
        1.112903395009952621827077425872708808235e-68,
        3.857668855957630226062428725920906474347e-70};
    tenno_eval->eval_yukawa(yukawa.data(), 4e-3, 1025., mmax, zeta);
    check_yukawa(ref_values);
  }
  // eval at T, U
  {
    const std::vector<scalar_type> ref_values{
        0.2852744664375203210290618028227940433907,
        0.1766838802450110751782118990187381894651,
        0.1241798108180594674947062009353195660062,
        0.09460785608931757764144854062919857885627,
        0.07602763793594107486328701003648790969636,
        0.06339729471844897557359641939962148990216,
        0.0542994319286032335099564068575123773964,
        0.04745281518095524189359321538244253978858,
        0.0421223982686899930783014304605560034047,
        0.03785894177827876359640625649012954247718,
        0.0343734522904477255719108740920007304814,
        0.03147203199327189089943472184054668400725,
        0.02901994575663891698672616184459598410994,
        0.02692075871254491583452498106271569095571,
        0.02510368385605629742004478536683794621981,
        0.02351561642743279158158329563866543191339,
        0.02211595404125828397042195141274910558768,
        0.02087313184539716089772318823594840687453,
        0.01976224658073508220643293392187745668993,
        0.01876339108479556078144660671266238703311,
        0.01786046367674014435996136066315613133189,
        0.01704030191570602245498102430707816576412,
        0.01629204228396873976211937758283397834483,
        0.01560664002096907613913206036131719849835,
        0.01497650431299575372408739207820905930807,
        0.0143952177934052375875674881262403873524,
        0.01385731849157500770525917866441990349708,
        0.01335812860860463316571696337113066752536,
        0.0128936188051866095880306412646495753497,
        0.01246029970420585602764956023056598379362,
        0.01205513445269041525413485925037624498466,
        0.0116754677276722007716703253095000927931,
        0.01131896769053532967081929071844203127929,
        0.01098357821791793884675073781735875627457,
        0.01066747934886561056040824193943638248817,
        0.01036905434654533574674275386775379296019,
        0.01008686211977763752044997872023135079946};
    tenno_eval->eval_yukawa(yukawa.data(), one_over_rho, T, mmax, zeta);
    check_yukawa(ref_values);
  }

  // compare to ints of STG approximated by Gaussians
  {
    tenno_eval->eval_yukawa(yukawa.data(), one_over_rho, T, mmax, zeta);
    tenno_eval->eval_slater(stg.data(), one_over_rho, T, mmax, zeta);
    std::vector<scalar_type> cgtg(mmax + 1);
    std::vector<scalar_type> cgtg_x_coulomb(mmax + 1);
    // precise fit of exp(-r12) on [0,10)
    std::vector<std::pair<double, double>> cgtg_params = {
        {0.10535330565471572, 0.08616353459042002},
        {0.22084823136587992, 0.08653979627551414},
        {0.3543431104992702, 0.08803697599356214},
        {0.48305514665749105, 0.09192519612306953},
        {0.6550035700167584, 0.10079776426873248},
        {1.1960917050801643, 0.11666110644901695},
        {2.269278814810891, 0.14081371547404428},
        {5.953990617813977, 0.13410255216448014},
        {18.31911063199608, 0.0772095196191394},
        {66.98443868169818, 0.049343985939540556},
        {367.24137290439205, 0.03090625839896873},
        {5.655142311118115, -0.017659052507938647}};
    auto cgtg_eval = libint2::GaussianGmEval<scalar_type, 0>::instance(mmax);
    auto cgtg_x_coulomb_eval =
        libint2::GaussianGmEval<scalar_type, -1>::instance(mmax);
    cgtg_eval->eval(cgtg.data(), rho, T, mmax, cgtg_params);
    cgtg_x_coulomb_eval->eval(cgtg_x_coulomb.data(), rho, T, mmax,
                              cgtg_params);

    for (int m = 0; m <= mmax; ++m) {
      INFO("m = " << m);
      REQUIRE(std::abs(yukawa[m] - cgtg_x_coulomb[m]) <= 5e-4);
      REQUIRE(std::abs(stg[m] - cgtg[m]) <= 2e-4);
    }
  }
}
