#include "Pendulum.hpp"

#include <cmath>

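/*
Doya gives a reward of -1 if the pendulum is over-rotated (has spun too far in one
direction). This indirectly punishes high velocities, but it makes the problem
non-Markovian: the penalty depends on the accumulated rotation, which is not part of
the (angle-wrapped) state the agent observes. Outside of his specific agent/application
the penalty is not necessary, so we don't use it. Otherwise, this implementation is
based on code from Kenji Doya's 2000 paper on continuous time/space RL
("Reinforcement Learning in Continuous Time and Space").

We also discretize the actions into a small set of fixed torques (see update()).
*/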

using namespace Eigen;
using namespace std;

// Construct the pendulum: fill in the physical and simulation constants (see
// Pendulum.hpp for their descriptions), then reset to the start-of-episode state.
Pendulum::Pendulum(mt19937_64 & generator) {
	// Physical parameters, as they appear in the dynamics in update()
	m = 1.0;		// mass
	l = 1.0;		// length
	g = 9.8;		// gravitational acceleration
	mu = 0.1;		// viscous friction coefficient; 0.01 in the paper

	// Control and integration parameters
	uMax = 5.0;		// largest torque magnitude the agent can apply
	dt = 0.01;		// simulated time covered by one update() call
	simSteps = 10;	// forward-Euler sub-steps per update() call

	newEpisode(generator);
}

// Nothing to release: the pendulum holds only plain numeric state.
Pendulum::~Pendulum() = default;

// Five discrete torques, mapped by update() to -uMax, -uMax/2, 0, +uMax/2, +uMax.
int Pendulum::getNumActions() const { return 5; }

// Two state variables: the pendulum angle theta and its angular velocity omega.
int Pendulum::getStateDim() const { return 2; }

// Returns the observed state (theta, omega), each rescaled into [0, 1] using the
// bounds from getStateRange(). Values outside those bounds are clipped to 0 or 1.
Eigen::VectorXd Pendulum::getState() const {
	const pair<VectorXd,VectorXd> bounds = getStateRange();

	// Wrap the angle: pi and 3*pi are the same physical configuration, so the
	// outside agent need not differentiate between them.
	VectorXd raw(2);
	raw << angleWrap(theta), omega;

	// Linear map of [min, max] onto [0, 1], component by component.
	VectorXd scaled = ((raw - bounds.first).array() / (bounds.second - bounds.first).array()).matrix();

	// omega is not bounded by the dynamics, so clip to really stay in [0, 1].
	for (int k = 0; k < 2; ++k)
		scaled[k] = min(1.0, max(0.0, scaled[k]));
	return scaled;
}

// Inverse of getState()'s normalization. Each component of 'state' is mapped linearly
// from [0, 1] back onto its [min, max] range from getStateRange(). Component 0 becomes
// theta and component 1 becomes omega.
// Values outside [0, 1] are extrapolated, not clipped. timeUp is left untouched.
// 'generator' is not used.
void Pendulum::setState(const Eigen::VectorXd & state, mt19937_64 & generator) {
	const pair<VectorXd,VectorXd> bounds = getStateRange();
	const VectorXd width = bounds.second - bounds.first;
	const VectorXd physical = (state.array() * width.array() + bounds.first.array()).matrix();
	theta = physical[0];
	omega = physical[1];
}

// Start a new episode with the pendulum motionless at theta = 0 (hanging DOWN) and the
// time-near-vertical statistic cleared. The reset is deterministic; 'generator' is not used.
void Pendulum::newEpisode(mt19937_64 & generator) {
	timeUp = 0;
	omega = 0;
	theta = 0;
}

// The pendulum has no terminal states. Episode length is presumably bounded by the
// caller (cf. getMaxTForPolicyEvaluation) — confirm.
bool Pendulum::terminate() const { return false; }

double Pendulum::update(int action, mt19937_64 & generator) {
	// Convert the integer action into a double containing the torque applied to the pendulum
    double u = uMax*((double)action-2.0)/2.0;
    // Run simSteps short steps, each using a forward Euler approximation of the dynamics.
	double thetaDot, omegaDot, subDt = dt/(double)simSteps;
	for (int i = 0; i < simSteps; i++) {
		// First compute the time derivatives
		thetaDot = omega;
		omegaDot = (-mu*omega - m*g*l*sin(theta)+u)/(m*l*l);	// From code, theta = 0 is DOWN. The equation in the paper has a different sign on one term because of theta difference
        // Update by adding the derivatives
		theta += subDt*thetaDot;
		omega += subDt*omegaDot;
	}

	if (fabs(angleWrap2Pi(theta) - M_PI) <= 0.2)
		timeUp += dt;

	// Compute the reward
	return -cos(theta);
}

// Initial value reported to the agent. Presumably used to initialize value estimates —
// confirm against the interface in Pendulum.hpp.
double Pendulum::getInitialValue() const { return 0.0; }

// Discount factor for this task.
double Pendulum::getGamma() const { return 0.99; }

// Returns (lower, upper) bounds for the un-normalized state, component by component:
//   [0] theta in [-pi, pi]
//   [1] omega in [-3*pi, 3*pi]   (the paper uses -5/4 * pi .. +5/4 * pi)
// getState() normalizes with these bounds and setState() inverts that normalization.
pair<VectorXd,VectorXd> Pendulum::getStateRange() {
	VectorXd lower(2);
	VectorXd upper(2);
	lower << -M_PI, -3.0 * M_PI;
	upper <<  M_PI,  3.0 * M_PI;
	return pair<VectorXd,VectorXd>(lower, upper);
}

// Wraps the angle 'x' (radians) into [-pi, pi]. Inputs already in that interval are
// returned unchanged.
// std::remainder removes the nearest integer multiple of 2*pi in one exact step, so the
// cost does not depend on |x|. This matters because update() integrates theta without
// ever wrapping it, so |theta| can grow without bound while the pendulum keeps spinning.
// A repeated-subtraction loop would need O(|x|) iterations there and would never
// terminate for x = +/-infinity.
// Non-finite input (infinity or NaN) yields NaN.
double Pendulum::angleWrap(double x) {
	return std::remainder(x, 2.0*M_PI);
}

// Wraps the angle 'x' (radians) into [0, 2*pi]. update() uses it to test whether the
// pendulum is within 0.2 rad of upright (theta = pi).
// std::fmod is exact and O(1), unlike repeated subtraction. Repeated subtraction needs
// more iterations as |x| grows, and theta is never wrapped by update(). It also never
// terminates for x = +/-infinity.
// Non-finite input (infinity or NaN) yields NaN.
double Pendulum::angleWrap2Pi(double x) {
	const double twoPi = 2.0*M_PI;
	double wrapped = std::fmod(x, twoPi);	// exact; in (-2*pi, 2*pi) with the sign of x
	if (wrapped < 0)
		wrapped += twoPi;	// may round to exactly 2*pi for tiny negative values
	return wrapped;
}

// A single Monte Carlo rollout suffices: the policy is greedy and deterministic, and so
// is the environment (neither update() nor newEpisode() draws from the generator).
int Pendulum::getNumMCSamplesForPolicyEvaluation() const { return 1; }

// Horizon used when evaluating a policy. Presumably counted in update() calls, i.e.
// 10 s of simulated time at dt = 0.01 — confirm with the caller.
int Pendulum::getMaxTForPolicyEvaluation() const { return 1000; }

// One sample per state. Presumably sufficient because transitions are deterministic —
// confirm with the caller.
int Pendulum::getNumSamplesPerState() const { return 1; }

// Simulated time (seconds) this episode spent within 0.2 rad of upright, as
// accumulated by update().
double Pendulum::getPlottableStatistic() const { return timeUp; }

// Label for the statistic returned by getPlottableStatistic().
string Pendulum::getPlottableStatisticName() const { return string("Time near-vertical"); }
