
def compare_stats(modelnums,inits,samplers,Ms,directory,nsamps):
    """Scatter-plot iid vs. antithetic improvement over the naive (M=1) run.

    For each ``init`` the cached ``standemo.experiment`` results are collected
    into one table with a row per (model, metric, M), and one figure per
    metric is saved as ``all_experiments_<metric>_<init>.pdf``.

    Args:
        modelnums: iterable of model identifiers passed to the experiment.
        inits: iterable of initialisation settings; one set of figures each.
        samplers: sampler names; each contributes a ``<metric>_<sampler>``
            column.  The figures read the ``_iid`` and ``_anti`` columns, so
            both ``'iid'`` and ``'anti'`` are expected to be present.
        Ms: iterable of M values that are compared against the M=1 run.
        directory: cache directory handed to ``ezcache`` and the experiment.
        nsamps: sample count forwarded to the experiment.

    Combinations whose results cannot be fetched are reported and skipped;
    an ``init`` for which nothing could be collected produces no figures.
    """
    import matplotlib
    matplotlib.use('Agg')
    from ezcache import ezcache
    import standemo
    import pandas as pd
    import plotnine as p9
    import ezstan
    ez = ezcache(directory)

    # (key in the experiment result, column prefix, axis label).
    # Raw strings: the LaTeX backslashes are not Python escape sequences.
    metric_list = [('elbo_eval','elbo',r'$\mathrm{ELBO}$'),
                   ('err_mu','err_mu',r'error ($\mu$)'),
                   ('err_Sigma','err_Sigma',r'error ($\Sigma$)')]

    for init in inits:
        # One single-row frame per (model, metric, M); concatenated once below
        # (DataFrame.append no longer exists in pandas >= 2.0).
        frames = []
        for modelnum in modelnums:
            # ndims is unused while normalizer is fixed at 1.0; the disabled
            # alternative divided err_mu by ndims and err_Sigma by ndims**2.
            ndims = ezstan.ezload(modelnum).zlen
            for (metric,shortmetric,axlab) in metric_list:
                normalizer = 1.0
                for M in Ms:
                    try:
                        row = {'modelnum':[modelnum],'M':M,'metric':shortmetric}
                        for sampler in samplers:
                            rez  = ez.get(standemo.experiment, directory, modelnum, sampler, M, init, nsamps)
                            rez1 = ez.get(standemo.experiment, directory, modelnum, sampler, 1, init, nsamps)
                            row[shortmetric+'_'+sampler] = rez[metric]/normalizer
                            # Overwritten for every sampler, so the naive column
                            # ends up holding the M=1 value of the last sampler.
                            row[shortmetric+'_naive']    = rez1[metric]/normalizer
                        frames.append(pd.DataFrame.from_dict(row))
                    except Exception as err:
                        # Best effort: a missing or broken result only drops this row.
                        print('skipping', modelnum, shortmetric, M, repr(err))

        if not frames:
            print('no results collected for init', init, '- nothing to plot')
            continue
        df = pd.concat(frames)

        for (metric,shortmetric,axlab) in metric_list:
            mydf = df[df.metric==shortmetric]
            metric_iid   = shortmetric+'_iid'
            metric_anti  = shortmetric+'_anti'
            metric_naive = shortmetric+'_naive'
            # The difference is oriented so that an improvement over naive is
            # positive: ELBO is plotted as (sampler - naive), the error
            # metrics as (naive - sampler).
            if shortmetric=='elbo':
                x0 = 'iid - naive'
                y0 = 'antithetic - naive'
                x = metric_iid+'-'+metric_naive
                y = metric_anti+'-'+metric_naive
            else:
                x0 = 'naive - iid'
                y0 = 'naive - antithetic'
                x = metric_naive+'-'+metric_iid
                y = metric_naive+'-'+metric_anti
            aes = p9.aes(x=x, y=y)
            p = (p9.ggplot(mydf,aes)
                 + p9.geom_abline(intercept=0,slope=1,size=.25,linetype='-',alpha=0.25)
                 + p9.geom_path(p9.aes(group='modelnum'),size=.125,alpha=0.15)
                 + p9.geom_point(p9.aes(color='factor(M)'),alpha=0.5,size=4.5,shape='.',stroke=0)
                 + p9.coord_cartesian(xlim=[1e-5,1],ylim=[1e-5,1])
                 + p9.scale_x_log10()
                 + p9.scale_y_log10()
                 + p9.xlab(axlab+', '+x0)
                 + p9.ylab(axlab+', '+y0)
                 + p9.theme_classic()
                 + p9.theme(legend_position=(.78,.27),legend_background=p9.element_rect(fill=(0,0,0,0)))
                 + p9.labs(color='M')
            )
            p9.ggsave(p, 'all_experiments_'+shortmetric+'_' + str(init) + '.pdf', width=4, height=4)

def cross_stats(modelnums,inits,samplers,Ms,directory,nsamps):
    """Plot ELBO improvement against Sigma-error improvement, faceted by M.

    For each ``init`` the cached ``standemo.experiment`` results are
    reorganised into one row per (model, sampler, M).  For every sampler
    other than ``'iid'`` a figure comparing that sampler with ``'iid'`` is
    saved as ``all_experiments_crossmetric_<init>_<sampler>.pdf``; x is
    ``elbo - elbo_naive`` and y is ``err_Sigma_naive - err_Sigma``.

    Args:
        modelnums: iterable of model identifiers passed to the experiment.
        inits: iterable of initialisation settings; one set of figures each.
        samplers: sampler names to fetch results for.
        Ms: iterable of M values; one facet column per value.
        directory: cache directory handed to ``ezcache`` and the experiment.
        nsamps: sample count forwarded to the experiment.

    A model is only included when every (sampler, M) result is available;
    otherwise it is reported and left out.  An ``init`` with no complete
    model produces no figures.
    """
    import matplotlib
    matplotlib.use('Agg')
    from ezcache import ezcache
    from numpy import mean
    import standemo
    import pandas as pd
    import plotnine as p9
    from pandas.api.types import CategoricalDtype
    ez = ezcache(directory)

    # (key in the experiment result, column name)
    metric_list = [('elbo_eval','elbo'),('err_mu','err_mu'),('err_Sigma','err_Sigma')]

    # keep the samplers in a consistent order
    mycat = CategoricalDtype(categories= ['iid','anti','qmc','qmc-cart','anti-qmc'], ordered=True)

    # re-organize the data with one row per experiment / M / sampler pair
    for init in inits:
        model_frames = []
        seen_models = set()
        for modelnum in modelnums:
            # only take the model if we have ALL the data
            try:
                rows = []
                for sampler in samplers:
                    for M in Ms:
                        # The results do not depend on the metric, so fetch once.
                        rez  = ez.get(standemo.experiment, directory, modelnum, sampler, M, init, nsamps)
                        rez1 = ez.get(standemo.experiment, directory, modelnum, sampler, 1, init, nsamps)
                        row = {'modelnum':[modelnum],
                               # spreads the model numbers over 13 buckets; used as point shape
                               'modelnum10':((modelnum*7) % 13),
                               'M':M,
                               'sampler':sampler,
                               'modelnum_sampler':str(modelnum)+str(sampler)}
                        for (metric,shortmetric) in metric_list:
                            row[shortmetric]          = rez[metric]
                            row[shortmetric+'_naive'] = rez1[metric]
                        rows.append(pd.DataFrame.from_dict(row))

                # pd.concat replaces DataFrame.append (removed in pandas 2.0).
                df_model = pd.concat(rows)

                # Use one naive baseline per model: the mean over all of its rows.
                df_model['elbo_naive']      = mean(df_model.elbo_naive)
                df_model['err_mu_naive']    = mean(df_model.err_mu_naive)
                df_model['err_Sigma_naive'] = mean(df_model.err_Sigma_naive)
            except Exception as err:
                print('skipping')
                print('failed for',modelnum,repr(err))
                continue

            model_frames.append(df_model)
            seen_models.add(modelnum)
            print('made it for', modelnum)
            print('num unique',len(seen_models))

        if not model_frames:
            print('no complete models for init', init, '- nothing to plot')
            continue
        df = pd.concat(model_frames)
        df['sampling method'] = df['sampler'].astype(str).astype(mycat)

        for sampler in samplers:
            if sampler == 'iid':
                continue
            # Each figure shows the sampler next to the iid reference.
            mydf = df[(df.sampler==sampler) | (df.sampler=='iid')]

            x = 'elbo-elbo_naive'
            y = 'err_Sigma_naive-err_Sigma'
            aes = p9.aes(x=x,y=y)
            print(mydf)
            # color based plots
            p = (p9.ggplot(mydf,aes)
                 + p9.geom_path(p9.aes(group='modelnum'),size=.125,alpha=0.5)
                 + p9.geom_point(p9.aes(shape='factor(modelnum10)',color='sampling method'),alpha=0.65,stroke=0,size=3.0)
                 + p9.coord_cartesian(xlim=[1e-4,1],ylim=[1e-11,1000])
                 + p9.scale_x_log10()
                 + p9.scale_y_log10()
                 + p9.xlab(r'$\mathbb{E} \log R$ (improvement over naive)')
                 + p9.ylab(r'error ($\Sigma$) (improv. over naive)')
                 + p9.facet_grid("~M")
                 + p9.theme_classic()
                 + p9.theme(legend_position=(.2,.65),legend_background=p9.element_rect(fill=(0,0,0,0)))
                 + p9.scales.scale_size_manual((1.5,2.5,3))
                 + p9.scale_color_brewer('qual',palette=6)
                 + p9.guides(shape=False)
             )
            p9.ggsave(p, 'all_experiments_crossmetric_' + str(init) + '_' + str(sampler) + '.pdf', width=3*len(Ms), height=3)
    
    
        # for sampler in samplers:
        #     x = 'elbo-elbo_naive'
        #     y = 'err_Sigma_naive-err_Sigma'
        #     aes = p9.aes(x=x,y=y)
        #     print(df)
        #     mydf = df[df.sampler==sampler]
        #     p = (p9.ggplot(mydf,aes)
        #          + p9.geom_abline(intercept=0,slope=1,size=.25,linetype='-',alpha=0.25)
        #          + p9.geom_path(p9.aes(group='modelnum'),size=.125,alpha=0.15)
        #          + p9.geom_point(p9.aes(color='factor(M)'),alpha=0.5,size=4.5,shape='.',stroke=0)
        #          + p9.coord_cartesian(xlim=[1e-5,1],ylim=[1e-11,1000])
        #          + p9.scale_x_log10()
        #          + p9.scale_y_log10()
        #          #+ p9.xlab(axlab+', '+x0)
        #          #+ p9.ylab(axlab+', '+y0)
        #          + p9.theme_classic()
        #          + p9.theme(legend_position=(.78,.27),legend_background=p9.element_rect(fill=(0,0,0,0)))
        #          + p9.labs(color='M')
        #     )
        #     p9.ggsave(p, 'all_experiments_crossmetric_' + sampler + '_' + str(init) + '.pdf', width=4, height=4)


def makecontour(modelnum,inits,samplers,Ms,directory,nsamps,orientation='hor'):
    """Draw 2-D density contours of the Stan draws next to the VI draws.

    For each ``init`` the first two columns of the Stan draws (labelled
    ``'HMC'``) and of the ``'y'`` entry of every cached
    ``standemo.experiment`` result (labelled ``'VI'``) are plotted as density
    contours, one facet per (sampler, M).  The figure is saved both as
    ``dense<modelnum>_<init>_<orientation>.png`` and ``.pdf``.

    Args:
        modelnum: model identifier passed to the experiment.
        inits: iterable of initialisation settings; one figure each.
        samplers: sampler names, one facet row/column each.
        Ms: iterable of M values, one facet row/column each.
        directory: cache directory handed to ``ezcache`` and the experiment.
        nsamps: sample count forwarded to the experiment.
        orientation: ``'hor'`` facets as ``'sampler~M'``; any other value
            facets as ``'M~sampler'``.  The figure size follows the layout.

    (sampler, M) combinations whose results cannot be fetched are reported
    and skipped; an ``init`` with no usable result produces no figure.
    """
    import matplotlib
    matplotlib.use('Agg')
    from ezcache import ezcache
    import standemo
    import pandas as pd
    import plotnine as p9
    ez = ezcache(directory)

    def keep(values,num2keep):
        """Thin ``values`` with a constant stride to roughly ``num2keep`` items."""
        # The stride must be at least 1: inputs shorter than num2keep used to
        # give a stride of 0, which is an invalid slice step.
        rez = max(1, len(values)//num2keep)
        print(len(values),num2keep,'rez',rez)
        return values[::rez]

    n2plot = 10000*5
    # Only the draws are needed here; the other three entries are unused.
    _, _, _, Ystan = ez.get(standemo.stan_sampling_info, modelnum)
    Y1stan = keep(Ystan[:,0],n2plot)
    Y2stan = keep(Ystan[:,1],n2plot)

    for init in inits:
        # Frames are gathered in a list and concatenated once
        # (DataFrame.append no longer exists in pandas >= 2.0).
        frames = []

        for sampler in samplers:
            for M in Ms:
                try:
                    rez = ez.get(standemo.experiment, directory, modelnum, sampler, M, init, nsamps)
                    # The Stan draws are repeated for every facet so that each
                    # panel shows both contour sets.
                    row = {'Y1':Y1stan             , 'Y2':Y2stan             ,'M':M, 'sampler':sampler,'method':'HMC'}
                    frames.append(pd.DataFrame.from_dict(row))
                    Y = rez['y']
                    row = {'Y1':keep(Y[:,0],n2plot), 'Y2':keep(Y[:,1],n2plot),'M':M, 'sampler':sampler,'method':'VI'}
                    frames.append(pd.DataFrame.from_dict(row))
                except Exception as err:
                    # Best effort: a missing or broken result only drops this panel's data.
                    print('skipping', sampler, M, repr(err))

        if not frames:
            print('no results collected for init', init, '- nothing to plot')
            continue
        df = pd.concat(frames)

        # Clip the view to the central 96% of the pooled draws on each axis.
        xlo = df.Y1.quantile(.02)
        xhi = df.Y1.quantile(.98)
        ylo = df.Y2.quantile(.02)
        yhi = df.Y2.quantile(.98)

        if orientation=='hor':
            facet_grid = p9.facet_grid('sampler~M')
            width      = 2*len(Ms)
            height     = 2*len(samplers)
        else:
            facet_grid = p9.facet_grid('M~sampler')
            height     = 2*len(Ms)
            width      = 2*len(samplers)

        aes = p9.aes(x='Y1', y='Y2', color='method')
        p = (p9.ggplot(df,aes)
             + p9.geom_density_2d(levels=4)
             + p9.theme_void()
             # Use the layout chosen above; it was previously hard-coded to
             # 'sampler~M', so the orientation only changed the figure size.
             + facet_grid
             + p9.coord_cartesian(xlim=[xlo,xhi],ylim=[ylo,yhi])
             + p9.scale_color_manual(values=['#000000', '#8b5757'])
             + p9.theme(legend_position=(.9,.4))
        )
        p9.ggsave(p, 'dense'+str(modelnum) + '_' + str(init) + '_' + orientation + '.png', width=width, height=height)
        p9.ggsave(p, 'dense'+str(modelnum) + '_' + str(init) + '_' + orientation + '.pdf', width=width, height=height)
