#include <algorithm>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#ifdef _MSC_VER
#include <ppl.h>		// For parallel_for (http://msdn.microsoft.com/en-us/library/dd728073.aspx)
using namespace concurrency;
#endif

#include "ValueIteration.hpp"
#include "PolicyEvaluation.hpp"

using namespace Eigen;
using namespace std;

double Q(const MatrixXd & W, const VectorXd & state, int a, Multilinear & d) {
    // Get the weight vector for the state
    pair<VectorXi, VectorXd> weights;
    d.getWeights(state, weights);
    // Try each action and see how good it is
    double result = 0;
    for (int i = 0; i < (int)weights.first.size(); i++)
        result += W(a,weights.first[i]) * weights.second[i];
    return result;
}

int greedyAction(const MatrixXd & W, const VectorXd & state, Multilinear & d) {
    int bestAction = -1;
    double bestActionQValue = 0;
    for (int a = 0; a < (int)W.rows(); a++) {   // W.rows == numActions
        double curQValue = Q(W, state, a, d);
        if ((a == 0) || (curQValue > bestActionQValue)) {
            bestAction = a;
            bestActionQValue = curQValue;
        }
    }
    return bestAction;
}

// This is the regular Bellman operator, with Q-value interpolation.
double maxQ(const MatrixXd & W, const VectorXd & state, Multilinear & d) {
    // Get the weight vector for the state
    pair<VectorXi, VectorXd> weights;
    d.getWeights(state, weights);

    // Try each action and see how good it is
    double result = -1;
    for (int a = 0; a < (int)W.rows(); a++) {   // W.rows() == numActions
        double Qsa = 0;
        for (int i = 0; i < (int)weights.first.size(); i++)
            Qsa += W(a,weights.first[i]) * weights.second[i];
        if ((a == 0) || (Qsa > result))
            result = Qsa;
    }
    return result;
}

// This is the consistent Bellman operator, with Q-value interpolation.
double maxQ_Consistent(const MatrixXd & W, const VectorXd & state, int prevStateIndex, int prevAction, Multilinear & d) {
   // Get the weight vector for the state
    pair<VectorXi, VectorXd> weights;
    d.getWeights(state, weights);

    // Determine the weight assigned to ourselves (i.e. the self-transition "probability")
    double selfWeight = 0.0;
    for (int i = 0; i < (int)weights.first.size(); i++) {
        if (weights.first[i] == prevStateIndex) {
            selfWeight = weights.second[i];
            break;
        }
    }

    // Try each action and see how good it is
    double result = -1, result2 = -1;
    for (int a = 0; a < (int)W.rows(); a++) {   // W.rows() == numActions
        double Qsa = 0;
        for (int i = 0; i < (int)weights.first.size(); i++)
            Qsa += W(a,weights.first[i]) * weights.second[i];
        if ((a == 0) || (Qsa > result2))
            result2 = Qsa;
        Qsa -= W(a,prevStateIndex) * selfWeight;
        Qsa += W(prevAction, prevStateIndex) * selfWeight;
        if ((a == 0) || (Qsa > result))
            result = Qsa;
    }
    return min(result,result2);
}

// Value Iteration with multilinear Q-value interpolation over the grid points of `d`.
//
// W holds one row per action and one column per grid point. Every iteration performs a
// sampled Bellman backup for every (action, point) pair, writing into a second table that
// is then swapped with W. W is scored with evaluatePolicy before training and every
// `evaluateFreq` iterations; the resulting curve is optionally written to `fileName` as
// tab-separated text.
void runValueIteration(Environment & env,       // One copy of the environment to use
                       const ENVTYPE & type,    // The environment type - needed for threading so that we can make more environment objects---one per thread
                       Multilinear & d,         // The multilinear object to use
                       MaxQVariantType maxQVariant,// Which backup to use: MQV_Ordinary (maxQ) or MQV_Consistent (maxQ_Consistent)
                       bool valueAveraging,     // If true, averages each iteration instead of fully updating Q
                       mt19937_64 & generator,  // For random-ness fun-times.
                       int numIterations,       // How many iterations to run?
                       const int evaluateFreq,  // How often should the policy be evaluated? Values <= 0 disable evaluation after the initial one
                       int numThreads_VI, // Number of threads for value iteration
                       int numThreads_PE, // Number of threads for policy evaluations
                       const string & fileName) { // File name to dump the result. If empty, nothing is written to a file
    // Get some values once that we may use often
    const int numActions = env.getNumActions(), numPoints = d.getNumPoints();
    const double gamma = env.getGamma();

    // Use at most one thread per point, but never fewer than one thread: a zero thread count
    // (numPoints == 0 or numThreads_VI <= 0) would otherwise divide by zero below.
    numThreads_VI = max(1, min(numThreads_VI, numPoints));

    // Current and next Q tables, held by value (no manual new/delete).
    // MatrixXd::swap exchanges the buffers of dynamic-size matrices without copying.
    MatrixXd W(numActions, numPoints);
    W.setConstant(env.getInitialValue());
    MatrixXd newW(numActions, numPoints);

    // Store the policy performance after each step
    VectorXd policyPerformances(numIterations+1); // +1 because we include initial performance prior to training

    // Get the performance of the initial policy
    policyPerformances[0] = evaluatePolicy(W, d, env, type, generator, numThreads_PE);
    cout << "Initially: " << env.getPlottableStatisticName() << " = " << policyPerformances[0] << endl;

    // Each thread needs its own generator - create and seed them from the caller's generator
    vector<mt19937_64> generators(numThreads_VI);
    uniform_int_distribution<unsigned long> distribution(std::numeric_limits<unsigned long>::lowest(), std::numeric_limits<unsigned long>::max());
    for (int i = 0; i < numThreads_VI; i++)
        generators[i].seed(distribution(generator));
    // Points handled by each thread (the last thread also takes the remainder)
    const int pointsPerThread = numPoints / numThreads_VI;

    // Get the number of samples to take per state (due to stochasticity in the environment).
    const int numSamples = valueAveraging ? 1 : env.getNumSamplesPerState();

    // Run all of the iterations
    for (int iteration = 0; iteration < numIterations; iteration++) {
        // Each thread owns a contiguous range of point indexes and writes only those columns of
        // newW, while W is only read. Every object from outside this loop that is used inside it
        // must be thread safe (Multilinear is assumed to be).
        // If you're not familiar with threading, pretend the loop over pointIndex goes from 0 to numPoints.
#ifdef _MSC_VER
        parallel_for(0, numThreads_VI, [&](int threadIndex) {
#else
        // Exactly one OpenMP thread per chunk of points
        #pragma omp parallel for num_threads(numThreads_VI)
        for (int threadIndex = 0; threadIndex < numThreads_VI; threadIndex++) {
#endif
            // Figure out which point indexes this thread will use
            const int startPointIndex = threadIndex * pointsPerThread;
            const int endPointIndexPlusOne = (threadIndex == numThreads_VI - 1) ? numPoints : startPointIndex + pointsPerThread;

            // Per-thread environment and generator; unique_ptr frees the environment even if an exception escapes
            mt19937_64 & threadGenerator = generators[threadIndex];
            unique_ptr<Environment> threadEnv(createEnvironment(type, threadGenerator));

            // Reused across points to avoid reallocating
            VectorXd curPoint, newPoint;

            for (int pointIndex = startPointIndex; pointIndex < endPointIndexPlusOne; pointIndex++) {
                // Do a Bellman update from this point. We do one update per action
                curPoint = d.getPoint(pointIndex);
                for (int a = 0; a < numActions; a++) {
                    double newValue = 0;
                    for (int i = 0; i < numSamples; i++) {
                        // Place the environment in this state and take action a
                        threadEnv->setState(curPoint, threadGenerator);
                        double reward = threadEnv->update(a, threadGenerator);
                        newPoint = threadEnv->getState();
                        // Terminal successors contribute no future value
                        const bool newStateTerminal = threadEnv->terminate();
                        if (maxQVariant == MQV_Ordinary)
                            newValue += reward + gamma*(newStateTerminal ? 0 : maxQ(W, newPoint, d));
                        else if (maxQVariant == MQV_Consistent)
                            newValue += reward + gamma*(newStateTerminal ? 0 : maxQ_Consistent(W, newPoint, pointIndex, a, d));
                    }
                    newValue /= (double)numSamples;

                    // If the transitions are stochastic and running numSamples samples per update is too slow,
                    // take a single sample and move the value *towards* the new value with a fixed step size
                    // of 0.1 (tuned for bicycle) instead of replacing it.
                    if (valueAveraging)
                        newW(a, pointIndex) = W(a, pointIndex) * 0.9 + newValue * 0.1;
                    else
                        newW(a, pointIndex) = newValue;
                }
            }
#ifdef _MSC_VER
        });	// End parallelized loop
#else
        }	// End parallelized loop
#endif

        // Actually update W (buffer swap, no copy)
        W.swap(newW);

        // Store the policy's performance; non-positive evaluateFreq would otherwise be a modulo by zero
        if ((evaluateFreq > 0) && ((iteration+1) % evaluateFreq == 0)) {
            policyPerformances[iteration+1] = evaluatePolicy(W, d, env, type, generator, numThreads_PE);
            cout << "Iteration: " << iteration+1 << "\t" << env.getPlottableStatisticName() << " = " << policyPerformances[iteration+1] << endl;
        }
        else
            policyPerformances[iteration+1] = policyPerformances[iteration];
    }

    // Print the performance curve, if the user wants it
    if (!fileName.empty()) {
        cout << "Printing results to " << fileName << endl;
        ofstream out(fileName);
        if (!out) {
            cerr << "Error: could not open " << fileName << " for writing" << endl;
        }
        else {
            out << "Iteration\t" << env.getPlottableStatisticName() << '\n';
            for (int i = 0; i < numIterations+1; i++)
                out << i << '\t' << policyPerformances[i] << '\n';
        }
        // `out` is flushed and closed by its destructor
    }
}
