from mynumpy import *
from matplotlib import pyplot as plt
import seaborn as sns
from matplotlib import cm

def set_lims_to_samp(zs,buf=1.0):
    """Set the current axes' limits to the 1st-99th percentile range of ``zs``.

    Args:
        zs: 2-D array indexed as ``zs[0, :]`` (x values) and ``zs[1, :]``
            (y values).
        buf: unused; kept so existing callers that pass it keep working.
    """
    # Clip the view to the central 98% of the samples on each axis.
    x_lo, x_hi = np.percentile(zs[0,:], [1, 99])
    y_lo, y_hi = np.percentile(zs[1,:], [1, 99])
    plt.xlim(x_lo,x_hi)
    plt.ylim(y_lo,y_hi)

def mesh_from_samps(samps,x_lo,x_hi,y_lo,y_hi,rez=50):
    """Build a ``rez`` x ``rez`` mesh over a rectangle covering the samples.

    Any bound passed as ``None`` is taken from the samples: row 0 supplies x
    and row 1 supplies y. Each axis holds the points
    ``lo + k*(hi-lo)/rez`` for ``k = 0 .. rez-1``; ``hi`` itself is excluded.

    Args:
        samps: 2-D array; ``samps[0, :]`` are x values, ``samps[1, :]`` y values.
        x_lo, x_hi, y_lo, y_hi: explicit bounds, or ``None`` to use the
            sample min/max on that axis.
        rez: number of points per axis.

    Returns:
        ``(Z1, Z2)`` coordinate arrays from ``meshgrid``, each ``(rez, rez)``.
    """
    if x_lo is None:
        x_lo = np.min(samps[0,:])
    if x_hi is None:
        x_hi = np.max(samps[0,:])
    if y_lo is None:
        y_lo = np.min(samps[1,:])
    if y_hi is None:
        y_hi = np.max(samps[1,:])
    n = int(rez)
    # linspace(endpoint=False) produces the points arange(lo, hi, (hi-lo)/rez)
    # was meant to, but always exactly `rez` of them. arange's float step can
    # add an extra point, and it raises when hi == lo (zero step).
    xs = np.linspace(x_lo, x_hi, n, endpoint=False)
    ys = np.linspace(y_lo, y_hi, n, endpoint=False)
    return np.meshgrid(xs, ys)

def mapdist_over_samps(samps,dist,x_lo=None,x_hi=None,y_lo=None,y_hi=None,rez=100):
    """Evaluate ``dist`` on a mesh spanning ``samps`` and return it with the mesh.

    The mesh comes from ``mesh_from_samps``; bounds left as ``None`` are taken
    from the samples.

    Returns:
        ``(P, Z1, Z2)``: ``P`` holds ``dist`` at each mesh point, reshaped to
        match the coordinate arrays ``Z1`` and ``Z2``.
    """
    grid_x, grid_y = mesh_from_samps(samps, x_lo, x_hi, y_lo, y_hi, rez=rez)
    density = compute_dist_on_mesh(dist, grid_x, grid_y)
    return density, grid_x, grid_y

def getmesh(size=1,rez=20,xrange=None,yrange=None):
    """Return a ``meshgrid`` over ``[-size, size]`` on each axis unless coordinates are given.

    Args:
        size: half-width of the default square domain.
        rez: number of intervals per axis for the default domain. This gives
            ``rez + 1`` evenly spaced points that include both ``-size`` and
            ``size``.
        xrange, yrange: explicit 1-D coordinate arrays. Each one that is given
            replaces the default for that axis.

    Returns:
        ``(Z1, Z2)`` from ``meshgrid(xrange, yrange)``, shaped ``(len(y), len(x))``.
    """
    # The former arange(-size, size + 1e-20, 2*size/rez) intended to include
    # +size. For ordinary sizes, though, 1e-20 vanishes in float addition, so
    # the endpoint was usually dropped (and size == 0 gave a zero step).
    # linspace includes both ends exactly.
    n_points = int(rez) + 1
    if xrange is None:
        xrange = np.linspace(-size, size, n_points)
    if yrange is None:
        yrange = np.linspace(-size, size, n_points)
    return np.meshgrid(xrange, yrange)

def compute_dist_on_mesh(dist,Z1,Z2,vectorized=True):
    """Evaluate ``dist`` at every point of the mesh ``(Z1, Z2)``.

    Args:
        dist: callable. With ``vectorized=True`` it receives one ``(2, K)``
            array whose columns are the K mesh points, and must return K
            values. With ``vectorized=False`` it is called once per point
            with a length-2 array.
        Z1, Z2: equally shaped coordinate arrays (e.g. from ``meshgrid``).
        vectorized: selects the calling convention described for ``dist``.

    Returns:
        An array with the shape of ``Z1`` holding ``dist`` at each mesh point.
    """
    Z = np.vstack([Z1.ravel(),Z2.ravel()])
    if vectorized:
        P = dist(Z)
    else:
        # Iterating Z.T yields one length-2 column (a single point) at a time.
        P = np.array([dist(z) for z in Z.T])
    return P.reshape(Z1.shape)

def plot_dist_on_mesh(dist,levels=None,mesh_args={},contour_args={},Z1=None,Z2=None,vectorized=True):
    """Draw contour lines of ``dist`` and return the ``plt.contour`` result.

    If ``Z1`` or ``Z2`` is missing, the mesh is built with
    ``getmesh(**mesh_args)``. ``contour_args`` is forwarded to ``plt.contour``
    as keyword arguments. ``vectorized`` is accepted but not used here.
    """
    has_mesh = Z1 is not None and Z2 is not None
    if not has_mesh:
        Z1, Z2 = getmesh(**mesh_args)
    density = compute_dist_on_mesh(dist, Z1, Z2)
    # `levels` is passed positionally: per the original author, matplotlib
    # won't accept an integer number of levels as a keyword argument.
    positional = (Z1, Z2, density) if levels is None else (Z1, Z2, density, levels)
    return plt.contour(*positional, **contour_args)

def compare_samps_to_dist1d(samps,dist):
    """Overlay a KDE of 1-D samples ``samps`` with ``dist`` evaluated over their range.

    Args:
        samps: 1-D array-like of samples.
        dist: callable evaluated on a 1-D array of grid points; it must
            return one value per point.
    """
    lo = np.min(samps)
    hi = np.max(samps)
    # Named `span` (formerly `range`) so it no longer shadows the builtin.
    span = hi - lo

    # The KDE bandwidth is a fixed fraction of the sample span.
    # NOTE(review): `bw=` is the pre-0.11 seaborn kdeplot API (later replaced
    # by bw_method/bw_adjust). Confirm against the seaborn version in use.
    sns.kdeplot(np.array(samps), bw=span/1000, label='sample KDE')

    # About 100 grid points over [lo, hi); min/max are reused, not recomputed.
    zs = arange(lo, hi, span/100)
    plt.plot(zs, dist(zs), label='dist eval')
    plt.legend()

def _white_style():
    """Return the matplotlib style name for the bundled seaborn 'white' style."""
    # matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*',
    # and 3.8 removed the old names (plt.style.use('seaborn-white') raises).
    if 'seaborn-white' in plt.style.available:
        return 'seaborn-white'
    return 'seaborn-v0_8-white'


def compare_samps_to_dist2d(samps,dist):
    """Plot three panels comparing 2-D samples ``samps`` with the density ``dist``.

    Panels, left to right:
      1. a filled seaborn KDE of the samples;
      2. a filled contour of ``dist`` on a mesh spanning the samples;
      3. contour lines of ``dist`` with the samples scattered on top.

    All panels are drawn with the same ``levels`` array.

    Args:
        samps: 2-D array; row 0 holds x and row 1 holds y, and
            ``samps.shape[1]`` is the number of samples.
        dist: callable evaluated on the mesh via ``mapdist_over_samps``.
    """
    P, Z1, Z2 = mapdist_over_samps(samps,dist)
    style = _white_style()

    # The step is max(P)/10, but levels run up to 2*max(P) (20 levels),
    # presumably leaving room for a sample KDE that peaks above dist.
    nlevels = 10
    maxP = np.max(P.ravel())
    stepP = maxP/nlevels
    levels = arange(0,2*maxP,stepP)

    plt.subplot(1,3,2)
    plt.style.use(style)
    plt.contourf(Z1,Z2,P,cmap='Reds',levels=levels)
    plt.colorbar()
    plt.title('dist eval')

    plt.subplot(1,3,1)
    plt.style.use(style)
    # NOTE(review): positional x/y, `shade=`, and density-valued `levels`
    # follow the pre-0.11 seaborn kdeplot API. Newer seaborn deprecates or
    # reinterprets these; confirm the seaborn version in use.
    sns.kdeplot(samps[0,:],samps[1,:], cmap='Reds', levels=levels, shade=True)
    # Match the KDE panel's view to the mesh used for dist.
    plt.xlim([np.min(Z1.ravel()), np.max(Z1.ravel())])
    plt.ylim([np.min(Z2.ravel()), np.max(Z2.ravel())])
    plt.colorbar()
    plt.title('sample KDE')

    plt.subplot(1,3,3)
    plt.style.use(style)
    plt.contour(Z1,Z2,P,cmap='RdGy',levels=levels)
    # Points become more transparent once there are more than 500 samples.
    plt.plot(samps[0,:],samps[1,:],'b.',alpha=min(1.0,500/samps.shape[1]),markersize=3)

    plt.title('samples and dist eval')
    plt.tight_layout()

def compare_dists_miniplots(p,q,ws,mesh_args=None,num_per_row=8):
    """Draw one small panel per ``w`` in ``ws``, overlaying fixed ``p(z)`` and ``q(z, w)``.

    Args:
        p: callable evaluated on the mesh points; drawn with the ``winter``
            colormap in every panel.
        q: callable ``q(z, w)``; drawn with the ``spring`` colormap.
        ws: sequence of second arguments for ``q``, one panel each.
        mesh_args: keyword arguments for ``getmesh`` (``None`` means defaults).
        num_per_row: panels per row of the figure (default 8, as before).
    """
    Z1,Z2 = getmesh(**(mesh_args or {}))
    P = compute_dist_on_mesh(p,Z1=Z1,Z2=Z2)

    # Integer ceiling division: just enough rows to hold every panel.
    num_rows = -(-len(ws) // num_per_row)

    plt.figure(figsize=(3*num_per_row,3*num_rows))

    for k, w in enumerate(ws, start=1):
        plt.subplot(num_rows,num_per_row,k)
        # The lambda is consumed immediately, so binding `w` late is harmless.
        Q = compute_dist_on_mesh(lambda z: q(z, w), Z1=Z1, Z2=Z2)
        plt.contour(Z1,Z2,P,cmap=cm.winter)
        plt.contour(Z1,Z2,Q,cmap=cm.spring)
        plt.axis('off')


def plot_dist(dist,col,label,size=100,levels=25,xrange=(-1,1),yrange=(-1,1)):
    """Contour-plot ``dist`` on a mesh, title it ``label``, and return the levels used.

    Args:
        dist: callable evaluated on the mesh points (via ``plot_dist_on_mesh``).
        col: colormap passed as ``cmap`` to ``plt.contour``.
        label: plot title.
        size: half-width of the ``[-size, size]`` domain. It is used only for
            an axis whose range is ``None``.
        levels: contour levels (a count or explicit values) for ``plt.contour``.
        xrange, yrange: one of:
            - a ``(lo, hi)`` pair, sampled at ``rez + 1`` (201) evenly spaced
              points;
            - an explicit coordinate array (longer than 2), used as-is;
            - ``None``.

    Returns:
        The contour levels actually drawn, so a follow-up plot can reuse them.
    """
    alpha = .65
    rez = 200

    def _axis(bounds):
        # getmesh uses xrange/yrange as the coordinate arrays themselves, so the
        # old default [-1, 1] produced a 2x2 mesh and `rez` had no effect.
        # Expand (lo, hi) pairs into an actual grid.
        if bounds is not None and len(bounds) == 2:
            return np.linspace(bounds[0], bounds[1], rez + 1)
        return bounds

    mesh_args = {'size': size, 'rez': rez, 'xrange': _axis(xrange), 'yrange': _axis(yrange)}
    contour_args = {'alpha': alpha, 'cmap': col}

    sns.set(font_scale=1.5)
    sns.set_style('white')
    CS = plot_dist_on_mesh(dist, mesh_args=mesh_args, contour_args=contour_args, levels=levels)
    plt.title(label)
    return CS.levels

def compare_dists(dist1,dist2,col2,label,size=1,xrange=None,yrange=None,rez=25):
    """Overlay contours of ``dist1`` (greys) and ``dist2`` (colormap ``col2``) on one mesh.

    ``dist2`` is drawn with the contour levels matplotlib chose for ``dist1``,
    so the two sets of contours are directly comparable.

    Args:
        dist1, dist2: callables evaluated on the mesh points.
        col2: colormap for ``dist2``.
        label: plot title.
        size, xrange, yrange, rez: forwarded to ``getmesh`` to build the mesh.
            These were previously read from undefined names, so every call
            raised NameError.

    Returns:
        The contour levels shared by both plots.
    """
    alpha = .65
    mesh_args = {'size': size, 'rez': rez, 'xrange': xrange, 'yrange': yrange}

    def con_args(color):
        return {'alpha': alpha, 'cmap': color}

    sns.set(font_scale=1.5)
    sns.set_style('white')
    # Call the helper in this module directly; the former `plot.` prefix
    # referred to a module that is never imported here.
    CS = plot_dist_on_mesh(dist1, mesh_args=mesh_args, contour_args=con_args('Greys'), levels=25)
    levels = CS.levels  # force the second plot onto the same levels
    plot_dist_on_mesh(dist2, levels=levels, mesh_args=mesh_args, contour_args=con_args(col2))
    plt.title(label)
    return levels