/*  kinematics.cpp
 *
 *  Copyright (C) 2010-2012 Andreas von Manteuffel
 *  Copyright (C) 2010-2012 Cedric Studerus
 *
 *  This file is part of the package Reduze 2.
 *  It is distributed under the GNU General Public License version 3
 *  (see the file GPL-3.0.txt or http://www.gnu.org/licenses/gpl-3.0.txt).
 */

#include "kinematics.h"
#include "propagator.h" // ScalarProduct
#include "functions.h"
#include "ginacutils.h"
#include "yamlutils.h"
#include "files.h"
#include "globalsymbols.h"
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>

using namespace GiNaC;
using namespace std;

namespace Reduze {

// register types
// File-local static object: constructing YAMLProxy<Kinematics> during static
// initialization registers the Kinematics type with the YAML support code
// (NOTE(review): exact registration semantics live in yamlutils.h — confirm).
// The object itself is never referenced; only its construction matters.
namespace {
YAMLProxy<Kinematics> dummy1;
}

/// Default constructor: id 0, the placeholder symbol "undefined" as
/// symbol_to_replace_by_one_, and "k" as the loop momentum prefix
/// (read() may override the latter two).
Kinematics::Kinematics()
		: symbol_to_replace_by_one_("undefined"),
		  id_(0) {
	const char* const default_loop_prefix = "k";
	loop_momentum_prefix_ = default_loop_prefix;
}

/// Destructor: nothing is released explicitly here.
Kinematics::~Kinematics() {
	// intentionally empty
}

/// Tries to build the kinematics `name` as a dimension-shifted copy of an
/// existing kinematics, where `name` is "<basename>diminc<N>" or
/// "<basename>dimdec<N>". The copy gets the new name and the base dimension
/// shifted by the suffix value.
///
/// @return a newly allocated Kinematics owned by the caller, or 0 if `name`
///         has no shift suffix or Files knows no kinematics `basename`.
Kinematics* Kinematics::create_derived_kinematics(const string& name) {
	LOG("checking for derived kinematics '" << name << "'");
	string basename;
	string suffix;
	int shift;
	Kinematics* derived = 0;
	if (search_dimension_shift_suffix(name, basename, suffix, shift)) {
		Kinematics* base = Files::instance()->kinematics(basename);
		if (base != 0) {
			LOG("creating kinematics from '" << basename << "' with dim shift " << shift);
			derived = new Kinematics(*base);
			derived->name_ = name;
			derived->dimension_ = base->dimension_ + shift;
		}
	}
	return derived;
}

namespace {
// Markers introducing a dimension shift suffix in kinematics labels:
// "<base>diminc<N>" raises the dimension of <base> by N,
// "<base>dimdec<N>" lowers it by N.
// Both pointer and pointee are const so the markers cannot be reseated.
const char* const diminc = "diminc";
const char* const dimdec = "dimdec";
}

// check for shifted dimension suffix: d2, d1000 or similar
bool Kinematics::search_dimension_shift_suffix(const string& name,
		 string& basename,  string& suffix, int& shift) {
	size_t i, xlen;
	int sign = 0;
	if ((i = name.find(diminc)) != string::npos) {
		sign = 1;
		xlen = strlen(diminc);
	} else if ((i = name.find(dimdec)) != string::npos) {
		xlen = strlen(dimdec);
		sign = -1;
	} else {
		return false;
	}
	basename = name.substr(0, i);
	shift = atoi(name.substr(i + xlen, name.length() - i - xlen).c_str());
	shift = (sign < 0 ? -shift : shift);
	suffix = name.substr(i, name.size() - i);
	return true;
}

/// Returns the label of `base` with its dimension shifted by `shift`.
/// Any dimension shift suffixes already present on `base` are stripped and
/// their shifts added to `shift`, so e.g. ("boxdiminc2", -3) -> "boxdimdec1".
/// A total shift of zero yields the bare base label.
string Kinematics::dimension_shifted_label(const string& base, int shift) {
	string label = base;
	int total = shift;
	string stripped, suffix;
	int extra;
	while (search_dimension_shift_suffix(label, stripped, suffix, extra)) {
		label = stripped;
		total += extra;
	}
	if (total == 0)
		return label;
	const bool increase = total > 0;
	return label + (increase ? diminc : dimdec)
			+ to_string(increase ? total : -total);
}

/// Space-time dimension of this kinematics: the global symbol d unless the
/// YAML input specifies otherwise, shifted for derived kinematics.
const GiNaC::ex& Kinematics::dimension() const {
	const GiNaC::ex& dim = this->dimension_;
	return dim;
}

/// Combines the external momenta and the kinematic invariants into one list
/// via add_lst (momenta passed first).
GiNaC::lst Kinematics::get_external_momenta_and_kinematics() const {
	const GiNaC::lst& momenta = external_momenta();
	const GiNaC::lst& invariants = kinematic_invariants();
	return add_lst(momenta, invariants);
}

const std::string& Kinematics::loop_momentum_prefix() const {
	return loop_momentum_prefix_;
}

/// Returns the mass dimension of `expr`, an expression in the kinematic
/// invariants whose dimensions were registered by read().
/// Numbers are dimensionless, products add dimensions, integer powers
/// multiply them, and all terms of a sum must share one dimension.
///
/// @throws std::runtime_error "can't determine mass dimension of <expr>" if
///   `expr` (after expansion) contains an unregistered symbol, a power with
///   non-integer exponent, a sum of terms with different dimensions, or any
///   other kind of subexpression.
int Kinematics::find_mass_dimension(const GiNaC::ex& expr) const {
	try {
		const ex tmp = expr.expand();
		if (is_a<numeric>(tmp)) {
			return 0;
		} else if (is_a<symbol>(tmp)) {
			const symbol& invariant = ex_to<symbol>(tmp);
			// spell the member's full map type including its ex_is_less
			// comparator (as find_rules_invariants_to_sp does); the iterator
			// type of a map with a different comparator is a different type
			map<symbol, int, ex_is_less>::const_iterator it =
					mass_dimension_.find(invariant);
			if (it == mass_dimension_.end())
				throw runtime_error("failed");
			return it->second;
		} else if (is_a<mul>(tmp)) {
			int d = 0;
			for (size_t k = 0; k < tmp.nops(); ++k)
				d += find_mass_dimension(tmp.op(k));
			return d;
		} else if (is_a<power>(tmp)) {
			const ex exponent = tmp.op(1);
			if (!is_a<numeric>(exponent)
					|| !ex_to<numeric>(exponent).is_integer())
				throw runtime_error("failed");
			return find_mass_dimension(tmp.op(0))
					* ex_to<numeric>(exponent).to_int();
		} else if (is_a<add>(tmp)) {
			VERIFY(tmp.nops() > 0);
			// a sum is only meaningful if all terms have the same dimension
			const int d = find_mass_dimension(tmp.op(0));
			for (size_t k = 1; k < tmp.nops(); ++k)
				if (find_mass_dimension(tmp.op(k)) != d)
					throw runtime_error("failed");
			return d;
		} else {
			throw runtime_error("failed");
		}
	} catch (const exception&) {
		// report the expression as passed in by this call (inner failures
		// are rethrown with the outermost expression)
		stringstream ss;
		ss << "can't determine mass dimension of " << expr;
		throw runtime_error(ss.str());
	}
}

void Kinematics::read(const YAML::Node& node) {
	verify_yaml_spec(node);
	string s1, s2;
	int i;

	if (node.FindValue("name")) {
		node["name"] >> name_;
	}

	map<string, symbol>::const_iterator shared;

	const YAML::Node& inmom_node = node["incoming_momenta"];
	for (YAML::Iterator it = inmom_node.begin(); it != inmom_node.end(); ++it) {
		(*it) >> s1;
		symbol momentum;
		shared = shared_symbols_.find(s1);
		if (shared != shared_symbols_.end()) {
			momentum = shared->second;
		} else {
			momentum = symbol(s1);
			shared_symbols_.insert(make_pair(s1, momentum));
		}
		incoming_momenta_.append(momentum);
		external_momenta_.append(momentum);
	}
	assert_unique_symbols_and_valid_names(incoming_momenta_);

	const YAML::Node& outmom_node = node["outgoing_momenta"];
	for (YAML::Iterator it = outmom_node.begin(); it != outmom_node.end();
			++it) {
		(*it) >> s1;
		symbol momentum;
		shared = shared_symbols_.find(s1);
		if (shared != shared_symbols_.end()) {
			momentum = shared->second;
		} else {
			momentum = symbol(s1);
			shared_symbols_.insert(make_pair(s1, momentum));
		}
		outgoing_momenta_.append(momentum);
		external_momenta_.append(momentum);
	}
	assert_unique_symbols_and_valid_names(outgoing_momenta_);
	assert_unique_symbols_and_valid_names(external_momenta_);

	if (external_momenta_.nops() > 0) {
		node["momentum_conservation"][0] >> s1;
		node["momentum_conservation"][1] >> s2;
		ex lhs(s1, external_momenta_);
		ex rhs(s2, external_momenta_);
		if (!is_a<symbol>(lhs)) {
			stringstream ss;
			ss << lhs << " is not one of " << external_momenta_;
			throw runtime_error(ss.str());
		}
		if (lhs.has(rhs))
			throw runtime_error("Momentum on left hand side appears"
					" on right hand side in momentum conservation.");
		rule_momentum_conservation_[lhs] = rhs;
		lst::const_iterator p;
		for (p = external_momenta_.begin(); p != external_momenta_.end(); ++p)
			if (*p != lhs)
				independent_external_momenta_.append(*p);

		// check momentum conservation
		ex sum = 0;
		for (p = incoming_momenta_.begin(); p != incoming_momenta_.end(); ++p)
			sum += *p;
		for (p = outgoing_momenta_.begin(); p != outgoing_momenta_.end(); ++p)
			sum -= *p;
		sum = sum.subs(rule_momentum_conservation_);
		if (!sum.expand().is_zero())
			WARNING("Found violation of momentum conservation for external momenta.");
	}

	if (node.FindValue("loop_momentum_prefix") != 0) {
		node["loop_momentum_prefix"] >> loop_momentum_prefix_;
	}

	const YAML::Node& inv_node = node["kinematic_invariants"];
	for (YAML::Iterator it = inv_node.begin(); it != inv_node.end(); ++it) {
		(*it)[0] >> s1;
		(*it)[1] >> i;
		symbol invariant;
		shared = shared_symbols_.find(s1);
		if (shared != shared_symbols_.end()) {
			invariant = shared->second;
		} else {
			invariant = symbol(s1);
			shared_symbols_.insert(make_pair(s1, invariant));
		}
		kinematic_invariants_.append(invariant);
		mass_dimension_[invariant] = i;
	}
	try {
		const YAML::Node& spr_node = node["scalarproduct_rules"];
		lst eqns;
		for (YAML::Iterator it = spr_node.begin(); it != spr_node.end(); ++it) {
			//  it -> [[p1+p2, p1+p2], s]
			const YAML::Node& sp = (*it)[0];
			if (sp.size() != 2)
				throw runtime_error("Scalar product is no list of 2 elements");
			sp[0] >> s1;
			sp[1] >> s2;
			ex pl(s1, external_momenta_);
			ex pr(s2, external_momenta_);
			(*it)[1] >> s1;
			ex inv(s1, kinematic_invariants_);
			const exmap& r = rule_momentum_conservation_;
			eqns.append(ScalarProduct(pl, pr).subs(r).expand() == inv);
		}
		rules_sp_to_invariants_ = //
				lsolve_pattern(eqns, ScalarProduct(wild(1), wild(2)));
		//k.rules_invariants_to_sp_ = //
		//		lsolve_unique(eqns, k.kinematic_invariants_);
		// does not work for nonlinear relations, e.g. m instead of m^2
		for (exmap::const_iterator mit = rules_sp_to_invariants_.begin();
				mit != rules_sp_to_invariants_.end(); ++mit) {
			momenta_to_invariants_.add(mit->first.op(0), mit->first.op(1),
					mit->second);
			LOGX(
					"  added rule: " << mit->first.op(0) << " * " << mit->first.op(1) << " --> " << mit->second);
		}
	} catch (exception& e) {
		throw runtime_error(string("scalarproduct_rules: ") + e.what());
	}

	if (node.FindValue("dimension") != 0) {
		node["dimension"] >> s1;
#ifdef NEW_GINAC
        lst syms({Files::instance()->globalsymbols()->d()});
#else
        lst syms(Files::instance()->globalsymbols()->d());
#endif
		dimension_ = ex(s1, syms);
	} else {
		dimension_ = Files::instance()->globalsymbols()->d();
	}

	if (node.FindValue("symbol_to_replace_by_one") != 0) {
		symbol sym;
		node["symbol_to_replace_by_one"] >> s1;
		if (s1 == "undefined") {
			symbol_to_replace_by_one_ = GiNaC::symbol("undefined");
		} else {
			ex expr(s1, kinematic_invariants_);
			if (!is_a<symbol>(expr))
				ABORT(
						"in symbol_to_replace_by_one: " << expr << " is not a plain symbol");
			symbol_to_replace_by_one_ = ex_to<symbol>(expr);
		}
	}

	all_symbols_ = add_lst(kinematic_invariants_, external_momenta_);
	all_symbols_.append(Files::instance()->globalsymbols()->d());
	assert_unique_symbols_and_valid_names(all_symbols_);
	kinematic_invariants_and_dimension_ = kinematic_invariants_;
	kinematic_invariants_and_dimension_.append(Files::instance()->globalsymbols()->d());

	invariants_dimension_feynman_rules_globals_ =
			kinematic_invariants_and_dimension_;
	invariants_dimension_feynman_rules_globals_ = add_lst(
			invariants_dimension_feynman_rules_globals_,
			Files::instance()->globalsymbols()->coupling_constants());
	invariants_dimension_feynman_rules_globals_.append(
			Files::instance()->globalsymbols()->Nc());

	if (node.FindValue("generate_crossings") != 0)
		ERROR("depreciated member \"generate_crossings\"");
}

/// Emits this kinematics as a YAML map with the keys name, incoming_momenta,
/// outgoing_momenta, momentum_conservation, kinematic_invariants (each as a
/// flow sequence [symbol, mass dimension]), scalarproduct_rules (each as
/// [[p, q], value]), dimension, symbol_to_replace_by_one and
/// loop_momentum_prefix.
void Kinematics::print(YAML::Emitter& os) const {
	using namespace YAML;
	os << BeginMap;
	os << Key << "name";
	os << Value << name_;
	os << Key << "incoming_momenta";
	os << Flow << Value << incoming_momenta_;
	os << Key << "outgoing_momenta";
	os << Flow << Value << outgoing_momenta_;
	os << Key << "momentum_conservation";
	os << Value << Flow << rule_momentum_conservation_;

	os << Key << "kinematic_invariants";
	os << Value;
	os << BeginSeq;
	for (lst::const_iterator inv = kinematic_invariants_.begin();
			inv != kinematic_invariants_.end(); ++inv) {
		ASSERT(is_a<symbol>(*inv));
		ASSERT(
				mass_dimension_.find(ex_to<symbol>(*inv)) != mass_dimension_.end());
		const int massdim = mass_dimension_.find(ex_to<symbol>(*inv))->second;
		os << Flow << BeginSeq;
		os << *inv << massdim;
		os << EndSeq;
	}
	os << EndSeq;

	os << Key << "scalarproduct_rules";
	os << Value << BeginSeq;
	for (GiNaC::exmap::const_iterator rule = rules_sp_to_invariants_.begin();
			rule != rules_sp_to_invariants_.end(); ++rule) {
		os << Flow << BeginSeq;
		os << BeginSeq << rule->first.op(0) << rule->first.op(1) << EndSeq;
		os << rule->second;
		os << EndSeq;
	}
	os << EndSeq;

	os << Key << "dimension";
	os << Value << dimension_;
	os << Key << "symbol_to_replace_by_one";
	os << Value << symbol_to_replace_by_one_;
	os << Key << "loop_momentum_prefix";
	os << Value << loop_momentum_prefix_;
	os << EndMap;
}

// Inverts the scalar product rules (ScalarProduct -> invariant expression):
// the rules are turned into equations and solved for the independent
// invariants appearing on their right hand sides.
//
// Returns a map from invariants to their solutions. An invariant of mass
// dimension 2 is solved for directly. For an invariant m of mass dimension 1
// the system is solved for m^2 (via an auxiliary symbol), and the result
// contains both m^2 -> sol and m^(-2) -> 1/sol.
//
// Throws std::runtime_error ("can't invert scalar product rules: ...") if an
// invariant has no registered mass dimension or one other than 1 or 2, or if
// the number of solved rules differs from the number of unknowns.
// NOTE(review): the precise semantics of to_equations, lsolve_overdetermined
// and lsolve_result_to_rules are defined in ginacutils — confirm there.
GiNaC::exmap Kinematics::find_rules_invariants_to_sp() const {
	const exmap& sp2inv = rules_sp_to_invariants_;
	// find d.o.f. for scalar products (some invariants might be masses
	// on which the scalar products don't depend)
	lst rhs;
	for (exmap::const_iterator e = sp2inv.begin(); e != sp2inv.end(); ++e) {
		// the commented-out code below would check that each rule's right
		// hand side has mass dimension 2 and warn otherwise
		// (inactive to allow numerical rules or intentional "misuses")
		/*
		 if (!e->second.is_zero()) {
		 try {
		 int d = find_mass_dimension(e->second);
		 if (d != 2)
		 throw runtime_error(
		 "invalid mass dimension " + to_string(d));
		 } catch (exception& exc) {
		 WARNING(
		 "encountered potential problem with mass dimension of rule:\n"
		 + to_string(e->first) + " -> "
		 + to_string(e->second) + "\n" + exc.what());
		 }
		 }
		 */
		rhs.append(e->second);
	}
	exset indep; // independent d.o.f. for scalar products
	gather_objects<symbol>(rhs, indep);
	// hack to treat mass dim 1 symbols: each such symbol m is replaced in the
	// equations by an auxiliary unknown "<m>squared" standing for m^2, so the
	// system stays linear in the unknowns; elim_helpers maps back afterwards
	exmap elim_dimone;
	exmap elim_helpers;
	lst indets; // unknowns to solve for
	for (exset::const_iterator i = indep.begin(); i != indep.end(); ++i) {
		try {
			if (!is_a<symbol>(*i))
				throw string("can't find mass dimension of ") + to_string(*i);
			map<symbol, int, ex_is_less>::const_iterator d = //
					mass_dimension_.find(ex_to<symbol>(*i));
			if (d == mass_dimension_.end())
				throw string("can't find mass dimension of ") + to_string(*i);
			if (d->second == 2) {
				indets.append(d->first);
			} else if (d->second == 1) {
				symbol squaredsymbol(d->first.get_name() + "squared");
				elim_helpers[squaredsymbol] = pow(d->first, 2);
				elim_dimone[pow(d->first, 2)] = squaredsymbol;
				indets.append(squaredsymbol);
			} else {
				throw "unsupported mass dimension " + to_string(d->second)
						+ " of invariant " + to_string(d->first);
			}
		} catch (string& s) {
			throw runtime_error("can't invert scalar product rules: " + s);
		}
	}
	exmap rulesraw;
	ex eqns = to_equations(sp2inv).subs(elim_dimone);
	ex sols = lsolve_overdetermined(eqns, indets);
	rulesraw = lsolve_result_to_rules(sols, indets);
	// every unknown must have been solved for
	if (rulesraw.size() != indets.nops())
		throw runtime_error("can't invert scalar products rules: num mismatch");
	exmap rules;
	for (exmap::const_iterator s = rulesraw.begin(); s != rulesraw.end(); ++s) {
		// map auxiliary "<m>squared" symbols back to m^2 on both sides
		ex inv = s->first.subs(elim_helpers);
		rules[inv] = s->second.subs(elim_helpers);
		if (!(inv - s->first).is_zero()) // cover mt^(-2) case
			rules[pow(inv, -1)] = pow(s->second.subs(elim_helpers), -1);
	}
	//LOGX("found inverted rules: " << rules);
	return rules;
}

int Kinematics::id() const {
	return id_;
}

void Kinematics::set_id(int id) {
	id_ = id;
}

/// Strict weak ordering: by id first, then by name.
bool Kinematics::operator<(const Kinematics& other) const {
	if (id_ < other.id_)
		return true;
	if (other.id_ < id_)
		return false;
	return name_ < other.name_;
}

/// Equivalence induced by operator<: equal iff neither orders before the
/// other, i.e. iff id and name coincide.
bool Kinematics::operator==(const Kinematics& other) const {
	const bool before = (*this) < other;
	const bool after = other < (*this);
	return !before && !after;
}

/// Negation of operator==: true iff one kinematics orders before the other.
bool Kinematics::operator!=(const Kinematics& other) const {
	return (*this) < other || other < (*this);
}

void Kinematics::print_info() const {
    LOGX("Kinematics '" << name_ << "':");
	LOGX("Rules scalar products to invariants:\n  " << rules_sp_to_invariants());
	try {
		LOGX("Rules invariants to scalar products:\n  " << find_rules_invariants_to_sp());
	} catch (exception& e) {
		WARNING(string("failed to invert scalar product rules,") //
				+ " some features might not be available");
	}
    LOGX("");
}

} // namespace Reduze


