#include "CartPole.hpp"

using namespace Eigen;
using namespace std;

// Builds a cart-pole environment with fixed constants and puts it in its
// start state. Per the original author, the constants follow Barto, Sutton &
// Anderson, "Neuronlike Adaptive Elements That Can Solve Difficult Learning
// Control Problems".
//
// generator: random engine, only forwarded to newEpisode().
CartPole::CartPole(mt19937_64 & generator) {
    // Integration settings: update() advances the simulation by dt, split
    // into simSteps forward-Euler sub-steps.
    simSteps = 1;
    dt = 0.02;

    // Magnitude of the force applied to the cart by either action.
    uMax = 10.0;

    // Constants used by the dynamics in update().
    g = 9.8;        // Positive on purpose: the paper prints -9.8, which is a typo.
    l = 0.5;
    m = 0.1;
    mc = 1;
    muc = 0.0;
    mup = 0.0;

    // NOTE(review): xMax is not read anywhere in this file; terminate() and
    // getStateRange() hard-code 2.4 instead. Confirm which bound is intended.
    xMax = 3.0;

    // Start from a freshly reset state.
    newEpisode(generator);
}

// Nothing to release explicitly; the compiler-generated destructor suffices.
CartPole::~CartPole() = default;

// Number of discrete actions. update() turns action 0 into a force of -uMax
// and action 1 into +uMax.
int CartPole::getNumActions() const {
    const int numActions = 2;
    return numActions;
}

// Length of the vector returned by getState(): x, v, theta, omega.
int CartPole::getStateDim() const {
    const int stateDim = 4;
    return stateDim;
}

// Returns the current state (x, v, theta, omega) rescaled component-wise with
// the bounds from getStateRange() and then clipped so that every entry lies
// in [0, 1].
Eigen::VectorXd CartPole::getState() const {
    const pair<VectorXd,VectorXd> bounds = getStateRange();
    const VectorXd & lower = bounds.first;
    const VectorXd width = bounds.second - bounds.first;

    // Raw state in the order: x, x_dot, theta, theta_dot.
    VectorXd raw(4);
    raw << x, v, theta, omega;

    VectorXd normalized(4);
    for (int k = 0; k < 4; k++) {
        // Map [lower, upper] onto [0, 1] ...
        const double scaled = (raw[k] - lower[k]) / width[k];
        // ... and clip anything that falls outside that interval.
        normalized[k] = min(1.0, max(0.0, scaled));
    }
    return normalized;
}

// Overwrites the internal state from a normalized state vector, i.e. the
// inverse of the rescaling done in getState() (without the clipping).
//
// state:     4 entries (x, v, theta, omega), each scaled so that 0 maps to
//            the lower bound and 1 to the upper bound of getStateRange().
// generator: unused here.
void CartPole::setState(const Eigen::VectorXd & state, mt19937_64 & generator) {
    const pair<VectorXd,VectorXd> bounds = getStateRange();
    const VectorXd width = bounds.second - bounds.first;

    // raw = state * (upper - lower) + lower, component-wise.
    const VectorXd raw = state.array() * width.array() + bounds.first.array();

    x     = raw[0];
    v     = raw[1];
    theta = raw[2];
    omega = raw[3];
}

// Resets the environment for a new episode: the time counter and all four
// state variables are set to zero.
//
// generator: unused; the start state is deterministic.
//
// NOTE(review): the previous version first assigned theta = 0.05 and then
// immediately overwrote it with 0 in a chained assignment, so the 0.05 never
// took effect. The dead store is removed and the effective behavior (theta
// starts at 0) is kept. Confirm whether a non-zero start angle was intended.
void CartPole::newEpisode(mt19937_64 & generator) {
    t = 0;      // elapsed episode time
    x = 0;      // cart position
    v = 0;      // cart velocity
    theta = 0;  // pole angle
    omega = 0;  // pole angular velocity
}

// True once the episode is over: either the pole angle magnitude exceeds
// pi/15 (12 degrees) or the cart position magnitude exceeds 2.4.
bool CartPole::terminate() const {
    const bool poleFell = fabs(theta) > M_PI/15.0;
    const bool cartHitEnd = fabs(x) > 2.4;
    return poleFell || cartHitEnd;
}

// Applies one action and advances the simulation by dt.
//
// action:    0 applies a force of -uMax to the cart, 1 applies +uMax.
// generator: unused here.
// Returns -1 if the resulting state is terminal (see terminate()), else 1.
double CartPole::update(int action, mt19937_64 & generator) {
    // Force on the cart: action 0 -> -uMax, action 1 -> +uMax.
    const double F = action*uMax + (action-1)*uMax;

    // dt is split into simSteps equal sub-steps, each integrated with
    // forward Euler.
    const double subDt = dt/static_cast<double>(simSteps);
    for (int step = 0; step < simSteps; step++) {
        // Time derivatives at the current state. omegaDot has to be computed
        // first because vDot depends on it.
        const double omegaDot =
            (g*sin(theta) + cos(theta)*(muc*sign(v) - F - m*l*omega*omega*sin(theta))/(m+mc) - mup*omega/(m*l))
            / (l*(4.0/3.0 - m/(m+mc)*cos(theta)*cos(theta)));
        const double vDot = (F+m*l*(omega*omega*sin(theta) - omegaDot*cos(theta)) - muc*sign(v)) / (m+mc);

        // Euler step: positions advance with the old velocities, then the
        // velocities advance with the derivatives computed above.
        theta += subDt*omega;
        omega += subDt*omegaDot;
        x += subDt*v;
        v += subDt*vDot;

        // Keep the angle within -pi..pi so that terminate() can compare its
        // magnitude directly.
        theta = WrapPosNegPI(theta);
    }

    // Advance the episode clock by the full step.
    t += dt;

    return terminate() ? -1 : 1;
}

// Returns the constant 0. Presumably the value estimate learners should start
// from; confirm against callers.
double CartPole::getInitialValue() const {
    const double initialValue = 0;
    return initialValue;
}

// Returns the constant 0.99. Presumably the reward discount factor for this
// environment; confirm against callers.
double CartPole::getGamma() const {
    const double gamma = 0.99;
    return gamma;
}

// Returns the (lower, upper) bounds used by getState()/setState() to rescale
// the state. Component order: x, v, theta, omega.
//   x:     [-2.4, 2.4]
//   v:     [-4, 4]
//   theta: [-12, 12] degrees, expressed in radians
//   omega: [-4, 4]
pair<VectorXd,VectorXd> CartPole::getStateRange() {
    pair<VectorXd,VectorXd> range;
    VectorXd & lower = range.first;
    VectorXd & upper = range.second;
    lower.resize(4);
    upper.resize(4);

    lower << -2.4, -4.0, (-12*M_PI/180.0), -4.0;
    upper <<  2.4,  4.0, (12*M_PI/180.0),   4.0;

    return range;
}

// Wraps an angle into the interval [-pi, pi] by repeatedly adding or
// subtracting 2*pi.
//
// theta: the angle to wrap; it is not modified.
// Returns the wrapped angle.
//
// The previous body ignored its argument and instead read, modified and
// returned `x`, which is not the angle passed in. The wrap now works on a
// local copy of the argument, so no member state is touched.
//
// Note: an infinite input never leaves the loops; callers are expected to
// pass finite angles.
double CartPole::WrapPosNegPI(const double & theta) {
    double wrapped = theta;
    while (wrapped < -M_PI)
        wrapped += 2*M_PI;
    while (wrapped > M_PI)
        wrapped -= 2*M_PI;
    return wrapped;
}

// Signum: -1 for negative x, 1 for positive x, 0 otherwise (this includes
// zero and NaN, for which both comparisons are false).
double CartPole::sign(const double & x) {
    return (x < 0) ? -1 : ((x > 0) ? 1 : 0);
}

// Number of Monte Carlo samples to use for policy evaluation. One is enough:
// the policy is greedy and deterministic, and so is the environment.
int CartPole::getNumMCSamplesForPolicyEvaluation() const {
    const int numSamples = 1;
    return numSamples;
}

// Returns the constant 1000. Presumably the cap on episode length used
// during policy evaluation; confirm units (steps vs. time) against callers.
int CartPole::getMaxTForPolicyEvaluation() const {
    const int maxT = 1000;
    return maxT;
}

// Returns the constant 1 (samples per state, as the name says; how callers
// use it is not visible in this file).
int CartPole::getNumSamplesPerState() const {
    const int samplesPerState = 1;
    return samplesPerState;
}

double CartPole::getPlottableStatistic() const {
	return t;
}
// Human-readable label for the value returned by getPlottableStatistic().
string CartPole::getPlottableStatisticName() const {
    string label("Episode duration");
    return label;
}
