// The two lines below give us M_PI for 3.14159...
#define _USE_MATH_DEFINES
#include <math.h>

#include "Acrobot.hpp"

using namespace Eigen;
using namespace std;

// Constructs the acrobot with the parameter values used in the RL book and
// then puts it into its start-of-episode state via newEpisode().
Acrobot::Acrobot(mt19937_64 & generator) {
	// Per-link constants, identical for both links (presumably masses,
	// centre-of-mass offsets and moments of inertia, following the book's
	// notation -- confirm against Acrobot.hpp). l1 and l2 are used as the
	// link lengths by terminate().
	m2 = 1;
	m1 = 1;
	l2 = 1;
	l1 = 1;
	lc2 = .5;
	lc1 = .5;
	i2 = 1;
	i1 = 1;

	g = 9.8;
	fmax = 1;	// update() scales the chosen action into a torque with this
	dt = .2;	// Amount of simulated time advanced by one call to update()

	// Each update() is integrated as numEulerSteps sub-steps of length
	// dt/numEulerSteps. More sub-steps cost more wall-clock time but track
	// the dynamics more accurately.
	numEulerSteps = 10;

	newEpisode(generator);
}

// Nothing to clean up explicitly; the compiler-generated destructor suffices.
Acrobot::~Acrobot() = default;

// Number of discrete actions. update() turns action a into the torque
// (a-1)*fmax, so the three actions are presumably 0, 1 and 2.
int Acrobot::getNumActions() const {
	const int numActions = 3;
	return numActions;
}

// Length of the vector returned by getState(): two angles followed by their
// two time derivatives.
int Acrobot::getStateDim() const {
	const int stateDim = 4;
	return stateDim;
}

// Returns the current state (theta1, theta2, theta1Dot, theta2Dot), with each
// entry rescaled by the bounds from getStateRange() and clipped to [0,1].
VectorXd Acrobot::getState() const {
	const pair<VectorXd,VectorXd> bounds = getStateRange();
	const VectorXd & lower = bounds.first;
	const VectorXd & upper = bounds.second;

	VectorXd raw(4);
	raw[0] = theta1;
	raw[1] = theta2;
	raw[2] = theta1Dot;
	raw[3] = theta2Dot;

	// Map [lower, upper] onto [0, 1], entry by entry.
	VectorXd scaled = (raw - lower).array() / (upper - lower).array();

	// Clip so the result really stays inside [0, 1]. Anything that is not
	// strictly positive (including NaN) becomes 0.
	for (int k = 0; k < 4; k++) {
		const double v = scaled[k];
		if (!(0.0 < v))
			scaled[k] = 0.0;
		else if (v < 1.0)
			scaled[k] = v;
		else
			scaled[k] = 1.0;
	}
	return scaled;
}

void Acrobot::setState(const VectorXd & state, mt19937_64 & generator) {
	// The state is normalized outside of the Environment objects, so we need to unnormalize it
	// before directly taking its values.
    pair<VectorXd,VectorXd> stateRange = getStateRange();
    VectorXd unnormalizedState = state.array() * (stateRange.second - stateRange.first).array() + stateRange.first.array();
    theta1 = unnormalizedState[0];
    theta2 = unnormalizedState[1];
    theta1Dot = unnormalizedState[2];
    theta2Dot = unnormalizedState[3];
}

// Starts a fresh episode: both angles, both angular velocities and the
// elapsed time are reset to zero. The generator argument is not used here.
void Acrobot::newEpisode(mt19937_64 & generator) {
	theta2Dot = 0;
	theta1Dot = 0;
	theta2 = 0;
	theta1 = 0;
	t = 0;
}

// Reports whether the episode is over. Treating the acrobot as an arm, the
// episode ends once the "hand" (the tip of the second link) is higher than
// the threshold height l1.
bool Acrobot::terminate() const {
	const double elbowHeight = -l1*cos(theta1);
	const double handHeight = elbowHeight - l2*cos(theta1+theta2);
	return handHeight > l1;
}

double Acrobot::update(int action, mt19937_64 & generator) {
	t += dt;

	double theta1DotDot, theta2DotDot, d1, d2, phi1, phi2, tau;	// These terms are defined in the RL book's definition of the acrobot domain.
	// Get the torque from the integer action
	tau = (action-1.0)*fmax;
	// Split into several forward Eiuler steps
	for (int i = 0; i < numEulerSteps; i++) {
		d1 = m1*lc1*lc1+m2*(l1*l1+lc2*lc2+2*l1*lc2*cos(theta2))+i1+i2;
		d2 = m2*(lc2*lc2+l1*lc2*cos(theta2))+i2;
		phi2 = m2*lc2*g*cos(theta1+theta2-M_PI/2.0);
		phi1 = -m2*l1*lc2*theta2Dot*theta2Dot*sin(theta2)-2*m2*l1*lc2*theta2Dot*theta1Dot*sin(theta2)+(m1*lc1+m2*l1)*g*cos(theta1-M_PI/2.0)+phi2;
		theta2DotDot = (tau+d2/d1*phi1-phi2) / (m2*lc2*lc2+i2-d2*d2/d1);
		theta1DotDot = -1.0/d1 * (d2*theta2DotDot+phi1);
		// Do the actual updates
		theta1Dot += (dt/numEulerSteps)*theta1DotDot;
		theta2Dot += (dt/numEulerSteps)*theta2DotDot;
		theta1 += (dt/numEulerSteps)*theta1Dot;
		theta2 += (dt/numEulerSteps)*theta2Dot;
	}

    // Enforce joint angle derivative constraints
    if (theta1Dot < -4.0*M_PI)
        theta1Dot = -4.0*M_PI;
    else if (theta1Dot > 4*M_PI)
        theta1Dot = 4*M_PI;

    if (theta2Dot < -9.0*M_PI)
        theta2Dot = -9.0*M_PI;
    else if (theta2Dot > 9.0*M_PI)
        theta2Dot = 9.0*M_PI;

	// Make sure the angles stay in a nice range---[-pi,pi]
    theta1 = mod2pi(theta1);
    theta2 = mod2pi(theta2);

	// The reward is always -1
    return -1;
}

// Initial value reported for this environment: zero.
double Acrobot::getInitialValue() const {
	return 0.0;
}

// Gamma reported for this environment: 1.
double Acrobot::getGamma() const {
	const double gamma = 1.0;
	return gamma;
}

// Transitions are deterministic, so a single sample per state is enough.
int Acrobot::getNumSamplesPerState() const {
	const int samplesPerState = 1;
	return samplesPerState;
}

// Statistic to plot: the time accumulated in t, i.e. how long it has taken
// so far for the "hand" to get high enough.
double Acrobot::getPlottableStatistic() const {
	return static_cast<double>(t);
}

// Human-readable label describing what the plottable statistic encodes.
string Acrobot::getPlottableStatisticName() const {
	const string label = "Episode duration";
	return label;
}

/////////////////////////
///// Private functions
/////////////////////////

pair<VectorXd,VectorXd> Acrobot::getStateRange() {
    pair<VectorXd,VectorXd> result;
    result.first.resize(4);
    result.second.resize(4);
    result.first[0] = -M_PI;
	result.second[0] = M_PI;
	result.first[1] = -M_PI;
	result.second[1] = M_PI;
	result.first[2] = -4.0*M_PI;
	result.second[2] = 4.0*M_PI;
	result.first[3] = -9.0*M_PI;
	result.second[3] = 9.0*M_PI;
	return result;
}

// Despite the name, this wraps the angle into [-pi, pi], not [0, 2*pi]
double Acrobot::mod2pi(double x) {
	// Wrap x into [-pi, pi] by removing the nearest whole multiple of 2*pi.
	//
	// remainder() does this in a single step. Values already inside
	// [-pi, pi] (both endpoints included) are returned unchanged. Unlike
	// repeatedly adding or subtracting 2*pi, this cannot spin forever on an
	// infinite input, or on one so large that adding 2*pi no longer changes
	// it: such inputs now yield NaN or a wrapped value, respectively. NaN
	// inputs are returned as NaN.
	return remainder(x, 2*M_PI);
}

// One Monte Carlo sample suffices for policy evaluation: the policy is greedy
// and deterministic, and so is the environment.
int Acrobot::getNumMCSamplesForPolicyEvaluation() const {
	const int numSamples = 1;
	return numSamples;
}

// Maximum T used for policy evaluation.
int Acrobot::getMaxTForPolicyEvaluation() const {
	const int maxT = 1000;
	return maxT;
}
