def keep(list,num2keep):
    """Thin out a sequence to roughly ``num2keep`` evenly spaced elements.

    The sequence is sub-sampled with a constant stride of
    ``len(list) // num2keep``.  Because the stride is rounded down, the
    result can contain more than ``num2keep`` elements.

    Parameters
    ----------
    list : sequence supporting ``len`` and extended slicing
        The data to thin out (the name shadows the builtin, but is kept
        because callers may pass it by keyword).
    num2keep : int
        Approximate number of elements wanted.  Must be non-zero.

    Returns
    -------
    The slice ``list[::stride]``; the whole sequence is returned when it
    has fewer than ``num2keep`` elements.
    """
    # Clamp the stride to 1: when the sequence is shorter than num2keep the
    # integer division yields 0 and ``list[::0]`` raises ValueError.
    rez = max(1, len(list) // num2keep)
    print(len(list),num2keep,'rez',rez)
    return list[::rez]

def get_ranges(mu, Sigma):
    """Return x/y grids of 100 points centred on the first two entries of ``mu``.

    Both axes use the same half-width: three times the square root of the
    spectral norm of ``Sigma`` (``np.linalg.norm(Sigma, 2)``).  Each grid
    starts at ``mu[i] - buf`` and stops one step short of ``mu[i] + buf``.

    Parameters
    ----------
    mu : indexable with at least two numeric entries
        Centre of the grid; only ``mu[0]`` and ``mu[1]`` are read.
    Sigma : 2-D array-like
        Matrix whose 2-norm sets the extent of the grid.

    Returns
    -------
    (xrange, yrange) : tuple of two 1-D arrays, each with exactly 100 points.
    """
    # One isotropic buffer for both axes.  (The per-axis buffers that used
    # to be computed from the diagonal of Sigma were always overwritten.)
    buf = 3 * np.sqrt(np.linalg.norm(Sigma, 2))
    # linspace with endpoint=False gives the same half-open grid as
    # arange(lo, hi, (hi - lo) / 100), but always exactly 100 points;
    # arange with a float step may return 100 or 101 depending on rounding.
    x_grid = np.linspace(mu[0] - buf, mu[0] + buf, 100, endpoint=False)
    y_grid = np.linspace(mu[1] - buf, mu[1] + buf, 100, endpoint=False)
    return (x_grid, y_grid)

def get_model(modelnum):
    """Load a model via ``ezstan.ezload`` and wrap its log-density for batches.

    Parameters
    ----------
    modelnum
        Identifier forwarded unchanged to ``ezstan.ezload``.

    Returns
    -------
    (logp, model_name)
        ``logp`` accepts either a 1-D array (evaluated directly by the loaded
        object) or a higher-dimensional array, in which case the loaded object
        is called once per row of ``zs.T`` and the results are collected in an
        array.  ``model_name`` is the ``model_name`` attribute of the loaded
        object.
    """
    # Imports are function-local; several are not referenced below but are
    # kept because importing them may have side effects.
    import os
    import clutter
    import dirichlet
    import dill as pickle
    import ezstan
    from mynumpy import np

    logp0 = ezstan.ezload(modelnum)

    def logp(zs):
        # Anything that is not a single 1-D point is evaluated row by row
        # over the transpose.
        if zs.ndim != 1:
            return np.array([logp0(z) for z in zs.T])
        return logp0(zs)

    return logp, logp0.model_name

# The result of this function can be saved and re-used (it is run and fetched through ezcache below).
def stan_sampling_info(modelnum):
    """Draw reference samples for a model and summarise them.

    The object returned by ``ezstan.ezload(modelnum)`` is asked for samples
    with ``iter=100000`` and ``n_jobs=1``.  The result is treated as a 2-D
    array with one column per dimension (assumption based on the
    ``axis=0`` mean and the use of ``shape[1]`` below -- confirm against
    ezstan).

    Returns
    -------
    (mu, Sigma, pca, Y)
        ``mu``    -- mean of the samples along axis 0
        ``Sigma`` -- ``np.cov`` of the transposed samples
        ``pca``   -- fitted ``sklearn`` PCA with ``min(2, n_columns)`` components
        ``Y``     -- the samples projected by that PCA
    """
    import ezstan
    from sklearn.decomposition import PCA
    import numpy as np

    n_iterations = 100000
    stan_model = ezstan.ezload(modelnum)
    draws = stan_model.sampling(iter=n_iterations, n_jobs=1)

    # Moments first, then the projection (same order as the statistics are
    # returned in).
    sample_mean = np.mean(draws, axis=0)
    sample_cov = np.cov(draws.T)

    # At most two principal components; fewer if the samples have one column.
    n_components = min(2, draws.shape[1])
    projector = PCA(n_components)
    projected = projector.fit_transform(draws)

    return sample_mean, sample_cov, projector, projected

def experiment(directory, modelnum, sampler, M, init, nsamps):
    """Fit a Gaussian to one model with one sampling scheme and score the fit.

    Parameters
    ----------
    directory
        Passed to ``ezcache`` to locate the cache that holds the
        ``stan_sampling_info`` result for ``modelnum``.
    modelnum
        Model identifier, forwarded to ``get_model`` and to the cache lookup.
    sampler : str
        One of ``'iid'``, ``'qmc'``, ``'qmc-cart'``, ``'anti'``,
        ``'anti-qmc'``, ``'latin-cart'``; selects the function in ``sample``
        that generates the base draws.
    M, nsamps
        Forwarded to the sampler as ``sampler(nsamps, M, ndims)``.
    init : int
        Initialisation of the Gaussian parameters ``(m, S)``:
        0 -- zero mean, identity;
        1 -- ``infer.laplace`` started from a zero vector;
        2 -- the cached sample mean and a Cholesky factor of the cached
             covariance (elementwise square root if it has at most one row).

    Returns
    -------
    dict with keys ``'elbo_fit'``, ``'elbo_eval'``, ``'err_mu'``,
    ``'err_Sigma'``, ``'y'``, ``'model_name'``, ``'Sigma_eval'``,
    ``'mu_eval'``.

    Raises
    ------
    KeyError
        If ``sampler`` is not one of the known names.
    ValueError
        If ``init`` is not 0, 1 or 2.
    """
    import warnings
    warnings.filterwarnings("ignore", message="numpy.dtype size changed")
    warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
    # os / dill are not referenced below; the imports are kept in case they
    # are relied on for their side effects.
    import os, dill as pickle
    from mynumpy import seed, np
    import infer
    import sample
    from standemo import get_model, stan_sampling_info
    from ezcache import ezcache
    ez = ezcache(directory)

    sampler_map = {'iid':sample.iid, 'qmc':sample.qmc_cartp1, 'qmc-cart':sample.qmc_cart, 'anti':sample.antithetic,'anti-qmc':sample.anti_qmc_cartp1,'latin-cart':sample.latin_cart}
    # Resolve the sampler up front so an unknown name fails (KeyError) before
    # any expensive work is done.
    draw = sampler_map[sampler]

    logp,model_name = get_model(modelnum)

    mu,Sigma,pca,Y = ez.get(stan_sampling_info, modelnum)
    ndims = len(mu)

    print('running experiment for',modelnum,'sampler',sampler,'M',M,'init',init,'ndims',ndims)

    if init==0:
        m = np.zeros(len(mu))
        S = np.eye(len(mu))
    elif init==1:
        # Start the Laplace approximation from the origin.
        m = 0*mu
        m, S = infer.laplace(logp,m)
    elif init==2:
        m = mu
        if Sigma.shape[0]>1:
            S = np.linalg.cholesky(Sigma)
        else:
            # ``sqrt`` is not imported in this function; use the numpy one.
            S = np.sqrt(Sigma)
    else:
        # Previously an unknown init left m and S unbound and surfaced later
        # as an UnboundLocalError.
        raise ValueError('unknown init %r (expected 0, 1 or 2)' % (init,))

    seed(0) # always use same seed!
    z0         = draw(nsamps, M, ndims)
    m, S, fval = infer.fitgauss(lambda m, S: infer.elbo(z0, m, np.tril(S), logp), m, S, maxiter=1000)

    # NOTE(review): the objective above is evaluated at np.tril(S), but the
    # evaluations below use S as returned by fitgauss.  The two agree only if
    # the returned S is already lower triangular -- confirm.
    elbo_fit   = infer.elbo(z0, m, S, logp)
    # Fresh, ten times larger draw for evaluation.
    z0         = draw(nsamps * 10, M, ndims)
    elbo_eval  = infer.elbo(z0, m, S, logp)
    z          = infer.conditional_sample(z0, m, S, logp)
    mu_eval    = np.mean(z, axis=0)
    Sigma_eval = np.cov(z.T)
    # Squared errors against the cached reference moments.
    err_mu     = np.sum((mu - mu_eval) ** 2)
    err_Sigma  = np.sum((Sigma - Sigma_eval).ravel() ** 2)
    y          = pca.transform(z)
    rez = {'elbo_fit':elbo_fit,  'elbo_eval':elbo_eval,  'err_mu':err_mu,  'err_Sigma':err_Sigma, 'y': y, 'model_name':model_name, 'Sigma_eval':Sigma_eval, 'mu_eval':mu_eval}
    print('rez',rez)
    return rez

def get_modelnums():
    """Return the first field of every "cheap" entry in ``model_timings.pk``.

    The file is loaded with ``dill`` and iterated; each item ``k`` is kept
    when all of the following hold:

    * ``k[1] * 100000 <= 600 * 60``
    * ``k[2] * 10000 < 2.0``
    * ``k[0]`` is neither 73 nor 44

    Presumably ``k[0]`` is a model number and ``k[1]``/``k[2]`` are per-unit
    timings -- confirm against whatever writes ``model_timings.pk``.
    """
    import dill

    with open('model_timings.pk','rb') as fid:
        timings = dill.load(fid)

    excluded = [73,44]

    def is_cheap(k):
        # Same three checks, in the same order, as a single predicate.
        return k[1]*100000 <= 600*60 and k[2]*10000 < 2.0 and (k[0] not in excluded)

    return [k[0] for k in timings if is_cheap(k)]

def dostuff():
    """Run the whole pipeline: reference sampling, inference, (disabled) plotting.

    As written, the function
      1. selects the 'Agg' matplotlib backend and silences FutureWarning /
         UserWarning,
      2. opens the cache in ``./ezcache``,
      3. passes ``stan_sampling_info`` for every model from ``get_modelnums()``
         to ``ez.run`` and waits for completion,
      4. passes ``experiment`` to ``ez.run`` for every combination of model,
         sampler, M (largest first) and init, waits for completion,
      5. returns.

    Everything after the first ``return`` (contour plots, comparison
    statistics, ELBO/error plots, scatter and density plots) is unreachable
    in the current state of the code; the successive ``return`` statements
    look like manual switches for choosing how far the script runs.
    """
    import matplotlib
    matplotlib.use('Agg')
    import warnings
    warnings.simplefilter('ignore', FutureWarning)
    warnings.simplefilter('ignore', UserWarning)
    import plot, clutter, dirichlet, sample, pandas as pd, plotnine as p9
    import infer
    import os.path
    import dill as pickle
    import time
    import stanplot

    from ezcache import ezcache
    directory = './ezcache'
    ez = ezcache(directory)

    modelnums = get_modelnums()

    nsamps = 10000*5
    Ms = [1, 2, 4, 6, 8]
    samplers = ['iid','anti','qmc','qmc-cart','anti-qmc'] #,'latin-cart'
    # NOTE: the range(3) assignment is immediately overridden; only init 1 is used.
    inits = range(3)
    inits = [1]

    # first off generate all the model info (we will re-use it)
    for modelnum in modelnums:
        ez.run(stan_sampling_info, modelnum, mem=10000)
    ez.wait_for_complete()
    time.sleep(10) # super conservative wait time...

    # next, do the inference!
    print('doing inference...')
    for modelnum in modelnums:
        for sampler in samplers:
            # largest M first
            for M in sorted(Ms,reverse=True):
                for init in inits:
                    #experiment(directory, modelnum, sampler, M, init, nsamps)
                    ez.run(experiment, directory, modelnum, sampler, M, init, nsamps, mem=10*1000, hours=24, ncores=1)
    ez.wait_for_complete()
    # NOTE: early exit -- nothing below this line is executed at present.
    return
    time.sleep(10) # super conservative wait time...

    # (unreachable) contour plots, one job per model and orientation
    for modelnum in modelnums:
        for orientation in ['hor','ver']:
            ez.run(stanplot.makecontour,modelnum,inits,samplers,[1,2,4,8],directory,nsamps,orientation,mem=5000,force=True)
            #stanplot.makecontour(modelnum,inits,samplers,[1,2,4,8],directory,nsamps,orientation)
    return

    # (unreachable) summary statistics across models
    # NOTE(review): the loop variable ``init`` is not passed to either call;
    # both receive the full ``inits`` list on every iteration.
    for init in inits:
       stanplot.compare_stats(modelnums,inits,samplers,[2,4,8],directory,nsamps)
       stanplot.cross_stats(modelnums,inits,samplers,[2,4,8],directory,nsamps)
    return

    # (unreachable) per-model ELBO and error curves
    print('making plots...')
    # now, make the plots!
    for modelnum in modelnums:
        df = pd.DataFrame()
        for init in inits:
            for sampler in samplers:
                for M in Ms:
                    try:
                        rez = ez.get(experiment, directory, modelnum, sampler, M, init, nsamps)
                        #print(rez)
                        rez['M']       = M + 0.0
                        rez['sampler'] = sampler
                        rez['init']    = init
                        # one-row frame holding this result
                        row = pd.DataFrame(columns=list(rez.keys()))
                        row.loc[0] = list(rez.values())
                        # NOTE(review): DataFrame.append was removed in pandas 2.0;
                        # pd.concat is the replacement if this code is revived.
                        df = df.append(row)
                    except Exception as exc:
                        print('skipping...',exc)

        if len(df.index)==0:
            continue
        # keep the samplers in a consistent order
        from pandas.api.types import CategoricalDtype
        mycat = CategoricalDtype(categories= ['iid','anti','qmc','qmc-cart','anti-qmc'], ordered=True)
        df['sampling method'] = df['sampler'].astype(str).astype(mycat)

        print(df)
        for init in inits:
            try:
                mydf = df[df.init==init]
                model_name = list(mydf.model_name)[0] # should all be same!
                aes = p9.aes(x='M', color='sampling method', group='sampling method')
                # NOTE(review): the label strings below contain backslashes in
                # non-raw literals ('\m', '\ ', '\l', '\S'); newer Python versions
                # warn about such escapes.
                p = (p9.ggplot(mydf,aes)
                     + p9.geom_path(p9.aes(y='elbo_eval'), alpha=0.75, linetype = "solid")
                     + p9.scale_x_log10()
                     + p9.theme_classic()
                     + p9.labs(y='$\mathbb{E}\ \log R$')
                     + p9.theme(legend_position=(.78,.35),legend_background=p9.element_rect(fill=(0,0,0,0)))
                     + p9.scale_color_brewer('qual',palette=2)
                     + p9.guides(color=False) # no color in ELBO plot
                     + p9.ggtitle(model_name)
                )
                p9.ggsave(p, 'stan'+str(modelnum) + '_' + str(init) + '_elbos' + '.png', width=4, height=2)
                p9.ggsave(p, 'stan'+str(modelnum) + '_' + str(init) + '_elbos' + '.pdf', width=4, height=2)

                for (metric,ylabel) in (('err_mu','error ($\mu$)'),('err_Sigma','error ($\Sigma$)')):
                    p = (p9.ggplot(mydf,aes)
                         + p9.geom_path(p9.aes(y=metric), alpha=0.75, linetype = "solid")
                         + p9.expand_limits(y=0) # forces 0 to be included?
                         + p9.expand_limits(y=.01) # forces 0.01 to be included as well?
                         + p9.scale_x_log10()
                         + p9.theme_classic()
                         + p9.labs(y=ylabel)
                         + p9.theme(legend_position=(.78,.6),legend_background=p9.element_rect(fill=(0,0,0,0)))
                         + p9.scale_color_brewer('qual',palette=2)
                         + p9.ggtitle(model_name)
                     )
                    p9.ggsave(p, 'stan'+str(modelnum) + '_' + str(init) + '_' + metric + '.png', width=4, height=2)
                    p9.ggsave(p, 'stan'+str(modelnum) + '_' + str(init) + '_' + metric + '.pdf', width=4, height=2)

            except Exception as exc:
                print('skipping...',exc)

    # (unreachable) scatter of fitted samples (blue) over reference samples (red)
    print('making comparison scatter plots...')
    #for modelnum in modelnums:
    for modelnum in modelnums:
        n2plot = 10000
        mu,Sigma,pca,Ystan = ez.get(stan_sampling_info, modelnum)
        # NOTE(review): column 1 exists only when the PCA kept two components,
        # i.e. when the model has at least two dimensions.
        Y1stan = keep(Ystan[:,0],n2plot)
        Y2stan = keep(Ystan[:,1],n2plot)

        for init in inits:
            df = pd.DataFrame()

            for sampler in samplers:
                for M in Ms:
                    try:
                        rez = ez.get(experiment, directory, modelnum, sampler, M, init, nsamps)
                        Y = rez['y']
                        row = {'Y1':keep(Y[:,0],n2plot), 'Y2':keep(Y[:,1],n2plot),'Y1stan':Y1stan,'Y2stan':Y2stan,'M':M, 'sampler':sampler}
                        mydf = pd.DataFrame.from_dict(row)
                        df = df.append(mydf)
                    except Exception:
                        print('skipping')

            aes      = p9.aes(x='Y1', y='Y2')
            aes_stan = p9.aes(x='Y1stan', y='Y2stan')
            p = (p9.ggplot(df,aes)
                 + p9.geom_point(aes_stan,alpha=0.05,shape='.',size=.25,color='red')
                 + p9.geom_point(alpha=0.05,shape='.',size=.25,color='blue')
                 + p9.theme_classic()
                 + p9.facet_grid('sampler~M')
             )
            p9.ggsave(p, 'comp'+str(modelnum) + '_' + str(init) + '.png', width=3*len(Ms), height=3*len(samplers))

    # (unreachable) scatter grid with the reference samples as an extra 'stan' column
    print('making scatter plots...')
    for modelnum in modelnums:
        mu,Sigma,pca,Y = ez.get(stan_sampling_info, modelnum)
        for init in inits:
            n2plot = 10000
            df = pd.DataFrame()
            # reference samples are filed under the largest M with sampler 'stan'
            row = {'Y1':keep(Y[:,0],n2plot), 'Y2':keep(Y[:,1],n2plot), 'M':max(Ms), 'sampler':'stan'}
            mydf = pd.DataFrame.from_dict(row)
            df = df.append(mydf)

            for sampler in samplers:
                for M in Ms:
                    try:
                        rez = ez.get(experiment, directory, modelnum, sampler, M, init, nsamps)
                        # NOTE: rebinds Y, which until here held the reference projection
                        Y = rez['y']
                        row = {'Y1':keep(Y[:,0],n2plot), 'Y2':keep(Y[:,1],n2plot), 'M':M, 'sampler':sampler}
                        mydf = pd.DataFrame.from_dict(row)
                        df = df.append(mydf)
                    except Exception:
                        print('skipping')

            aes = p9.aes(x='Y1', y='Y2')
            p = (p9.ggplot(df,aes)
                 + p9.geom_point(alpha=0.1,shape='.',size=.25)
                 + p9.theme_classic()
                 + p9.facet_grid('M~sampler')
             )
            p9.ggsave(p, 'samps'+str(modelnum) + '_' + str(init) + '.png', height=4*len(Ms), width=4*(len(samplers)+1))                

    # (unreachable) 2-D density contours, fitted vs. reference, coloured by 'stan'
    for modelnum in modelnums:
        n2plot = 10000*5
        mu,Sigma,pca,Ystan = ez.get(stan_sampling_info, modelnum)
        Y1stan = keep(Ystan[:,0],n2plot)
        Y2stan = keep(Ystan[:,1],n2plot)

        for init in inits:
            df = pd.DataFrame()

            for sampler in samplers:
                for M in Ms:
                    try:
                        rez = ez.get(experiment, directory, modelnum, sampler, M, init, nsamps)
                        Y = rez['y']
                        # fitted samples for this facet ...
                        row = {'Y1':keep(Y[:,0],n2plot), 'Y2':keep(Y[:,1],n2plot),'M':M, 'sampler':sampler,'stan':False}
                        mydf = pd.DataFrame.from_dict(row)
                        df = df.append(mydf)
                        # ... and the reference samples repeated in the same facet
                        row = {'Y1':Y1stan             , 'Y2':Y2stan             ,'M':M, 'sampler':sampler,'stan':True}
                        mydf = pd.DataFrame.from_dict(row)
                        df = df.append(mydf)
                    except Exception:
                        print('skipping')

            aes = p9.aes(x='Y1', y='Y2', color='stan')
            p = (p9.ggplot(df,aes)
                 + p9.geom_density_2d()
                 + p9.theme_classic()
                 + p9.facet_grid('sampler~M')
             )
            p9.ggsave(p, 'dense'+str(modelnum) + '_' + str(init) + '.png', width=2*len(Ms), height=2*(len(samplers)))                

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    dostuff()
