#include "PolicyEvaluation.hpp"
#include "ValueIteration.hpp"       // For greedyAction(...)
#include "Bicycle.hpp"

#include <algorithm>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <vector>

using namespace std;
using namespace Eigen;

/// Estimates the performance of the greedy policy defined by W using Monte Carlo rollouts.
///
/// Runs env.getNumMCSamplesForPolicyEvaluation() episodes. Each episode lasts at most
/// env.getMaxTForPolicyEvaluation() steps, or until the environment reports terminate().
/// Every action is chosen with greedyAction(W, state, d). The trials are split across up to
/// numThreads OpenMP threads. Each thread builds its own environment via
/// createEnvironment(type, ...) and its own RNG, seeded from `generator`.
///
/// @param W          Weights passed to greedyAction.
/// @param d          Passed to greedyAction and shared by all threads.
///                   NOTE(review): assumes concurrent greedyAction calls on the same `d` are safe — confirm.
/// @param env        Only used to read the trial count and the per-episode step cap.
/// @param type       Environment type used to construct one environment per thread.
/// @param generator  Source of the per-thread seeds (one seed is drawn per thread).
/// @param numThreads Requested thread count; clamped to the range [1, numTrials].
/// @return Mean of getPlottableStatistic() over all trials, or NaN if the environment requests no trials.
double evaluatePolicy(const Eigen::MatrixXd & W,
                      Multilinear & d,
                      Environment & env,
                      const ENVTYPE & type,
                      mt19937_64 & generator,
                      int numThreads) {
    const int numTrials = env.getNumMCSamplesForPolicyEvaluation();
    const int maxT = env.getMaxTForPolicyEvaluation();

    // The mean over zero trials is undefined. Bail out before the integer division by zero below.
    if (numTrials <= 0)
        return std::numeric_limits<double>::quiet_NaN();

    // Use at most one thread per trial, and at least one thread (numThreads <= 0 would divide by zero).
    numThreads = std::max(1, std::min(numThreads, numTrials));

    // Each thread needs its own RNG. Seed them all from the one that is provided.
    // NOTE(review): unsigned long is 32 bits on LLP64 platforms (e.g. Windows), which limits seeds to 32 bits there.
    std::vector<std::mt19937_64> generators(numThreads);
    std::uniform_int_distribution<unsigned long> distribution(std::numeric_limits<unsigned long>::lowest(), std::numeric_limits<unsigned long>::max());
    for (auto & threadGen : generators)
        threadGen.seed(distribution(generator));

    Eigen::VectorXd results = Eigen::VectorXd::Zero(numThreads);
    const int trialsPerThread = numTrials / numThreads;

    // Exactly one loop iteration per thread: request numThreads threads rather than a hard-coded count.
    #pragma omp parallel for num_threads(numThreads) schedule(static)
    for (int threadIndex = 0; threadIndex < numThreads; threadIndex++) {
        const int startTrial = threadIndex * trialsPerThread;
        // The last thread also picks up the remainder of numTrials / numThreads.
        const int endTrialPlusOne = (threadIndex == numThreads - 1) ? numTrials : startTrial + trialsPerThread;

        std::mt19937_64 & threadGenerator = generators[threadIndex];
        // RAII ownership replaces the manual delete of the per-thread environment.
        const std::unique_ptr<Environment> threadEnv(createEnvironment(type, threadGenerator));

        // Accumulate locally and write once, so threads don't repeatedly write adjacent elements of `results`.
        double threadSum = 0.0;
        for (int trial = startTrial; trial < endTrialPlusOne; trial++) {
            threadEnv->newEpisode(threadGenerator);
            for (int t = 0; (t < maxT) && (!threadEnv->terminate()); t++)
                threadEnv->update(greedyAction(W, threadEnv->getState(), d), threadGenerator);
            // Now that the episode is over, add the plottable statistic from this episode.
            threadSum += threadEnv->getPlottableStatistic();
        }
        results[threadIndex] = threadSum;
    }
    // Divide by the total trial count instead of averaging per thread: the last thread may have run more trials.
    return results.sum() / static_cast<double>(numTrials);
}
