import warnings
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")

from mynumpy import *
import util

def mvnlogpdf(z,u,L):
    """Gaussian log-density evaluated from a factor ``L`` of the covariance.

    Hand-written because, per the original author's note, autograd's own
    mvn does not backprop correctly.

    z : a single point (1-D) or, otherwise, points stacked as columns
        (presumably shape (ndims, npoints) -- confirm against callers).
    u : mean vector; for non-1-D ``z`` it is given a trailing axis so that it
        is subtracted from every column.
    L : matrix whose product ``L @ L.T`` is used as the covariance.

    Returns ``-.5*sum(y*y, axis=0) - .5*logdet(2*pi*L@L.T)`` where
    ``y = solve(L, z - u)``; ``logdet`` is the project's helper.
    """
    is_single_point = len(z.shape)==1
    centered = z-u if is_single_point else z-expand_dims(u,axis=1)
    cov = L @ L.T
    whitened = solve(L,centered)
    quad_form = np.sum(whitened*whitened,axis=0)
    log_norm = logdet(2*pi*cov)
    return -.5*quad_form - .5*log_norm

def gaussent(s):
    """Return ``.5*log(2*pi*e*s**2)``: the entropy of a 1-D Gaussian with std ``s``."""
    two_pi_e = 2*pi*e
    return .5*log(two_pi_e*s**2)
def logq(z,m,s):
    """Log-density of the Gaussian with mean ``m`` and factor ``s`` at ``z`` (delegates to mvnlogpdf)."""
    return mvnlogpdf(z, u=m, L=s)
def q(z,m,s):
    """Density of the Gaussian at ``z``: the exponential of ``logq(z, m, s)``."""
    log_density = logq(z,m,s)
    return exp(log_density)

def log_ave_naive(a,b):
    """Reference implementation of ``log((exp(a) + exp(b)) / 2)`` with no overflow protection."""
    half_exp_a = .5*exp(a)
    half_exp_b = .5*exp(b)
    return log(half_exp_a + half_exp_b)
def log_ave(a,b):
    """Return ``log((exp(a) + exp(b)) / 2)`` using ``logaddexp`` instead of explicit exponentials."""
    log_half = log(.5)
    return log_half + np.logaddexp(a,b)
# Import-time sanity check: log_ave and the naive formula must agree on a benign input.
assert(np.allclose(log_ave(.1,.2),log_ave_naive(.1,.2)))

def map(z0,m,S):
    """Apply the affine transform ``x -> S @ x + m`` along the last axis of ``z0``.

    ``z0`` must be 3-D, unpacked as (nsamps, M, ndims); any other rank raises
    ``ValueError`` from the shape unpacking. For every ``k, j`` the result
    satisfies ``out[k, j, :] == S @ z0[k, j, :] + m``.

    Note: this name shadows the builtin ``map`` within this module.
    """
    _nsamps, _M, _ndims = z0.shape  # kept only as the 3-D shape check
    transformed = z0 @ S.T
    return transformed + m

def logR1(z,m,S,logp):
    """Per-point log ratio ``logp(z) - log q(z)`` for a 3-D batch of points.

    z    : array unpacked as (nsamps, M, ndims).
    m, S : mean and factor handed to ``mvnlogpdf`` for the q term.
    logp : callable evaluated on the points stacked as columns, i.e. on an
           (ndims, nsamps*M) array; its result is reshaped to (nsamps, M).

    Returns an (nsamps, M) array of log ratios.
    """
    nsamps,M,ndims = z.shape
    points_as_columns = z.reshape([nsamps*M,ndims]).T
    # Target first, then q -- same evaluation order as before.
    log_target = logp(points_as_columns).reshape([nsamps,M])
    log_proposal = mvnlogpdf(points_as_columns,m,S).reshape([nsamps,M])
    return log_target-log_proposal

def elbo(z0,m,S,logp):
    """Importance-weighted bound estimate from base draws ``z0``.

    ``z0`` is unpacked as (nsamps, M, ndims). The draws are transformed with
    ``map``, the log ratios from ``logR1`` are combined over the M axis as
    ``logsumexp(...) - log(M)`` (the log of their average ratio), and the mean
    over the remaining axis is returned.
    """
    _, M, _ = z0.shape
    log_ratios = logR1(map(z0,m,S),m,S,logp)
    log_mean_ratio = logsumexp(log_ratios.T,axis=0) - log(M)
    return np.mean(log_mean_ratio)

def conditional_sample(z0,m,S,logp):
    """Pick one of the M transformed candidates per sample, weighted by its ratio.

    ``z0`` is unpacked as (nsamps, M, ndims). The log ratios are normalised
    over the M axis into weights (each column sums to one), and
    ``util.sample_each_column`` chooses an index per column (presumably one
    draw from each column's weights -- confirm in util).

    Returns an (nsamps, ndims) array holding the chosen candidate for each
    sample.
    """
    nsamps, _, ndims = z0.shape
    candidates = map(z0,m,S)
    log_ratios = logR1(candidates,m,S,logp).T
    weights = exp(log_ratios - logsumexp(log_ratios,axis=0))
    chosen = util.sample_each_column(weights)
    selected = zeros((nsamps,ndims))
    for row,col in enumerate(chosen):
        selected[row,:] = candidates[row,col,:]
    return selected

# def elbo(z0,m,S,logp):
#     nsamps,M,ndims = z0.shape
#     z_flat = z0.reshape([nsamps*M,ndims]) @ S.T + m
#     logp_flat = logp(z_flat)
#     logq_flat = mvnlogpdf(z_flat.T,m,S)
#     logp = logp_flat.reshape([nsamps,M])
#     logq = logq_flat.reshape([nsamps,M])
#     #logR0 = log(np.mean(exp(logp)/exp(logq),axis=1))
#     logR  = logsumexp((logp-logq).T,axis=0) - log(M)
#     return np.mean(logR)

# # what is the effective antithetic density?
# def q_anti(z,m,s,logp):
#     z1 = z
#     z2 = expand_dims(2*m,axis=1)-z
#     assert(np.allclose(logq(z,m,s),logq(z2,m,s)))
#     p1 = exp(logp(z1))
#     p2 = exp(logp(z2))
#     return 2*q(z,m,s)*p1/(p1+p2)

# # function to compute elbo with a given mean and std
# def elbo_naive(u,m,s,logp):
#     z = expand_dims(m,axis=1) + s @ u
#     return mean(logp(z) - logq(z,m,s))

# def elbo_anti(u,m,s,logp):
#     z1 = expand_dims(m,axis=1) + s @ u
#     z2 = expand_dims(m,axis=1) - s @ u
#     return mean(log_ave(logp(z1)-logq(z1,m,s),logp(z2)-logq(z2,m,s)))

def fitgauss(f,m0,s0,maxiter=100):
    """Maximise ``f(m, s)`` with BFGS, starting from ``(m0, s0)``.

    The two parameters are flattened into a single vector, ``-f`` is minimised
    with ``scipy.optimize.minimize`` (gradient supplied via ``value_and_grad``),
    and progress is printed on the iterations that
    ``util.good_iter_to_print`` selects. The optimiser's success flag is
    deliberately not checked.

    Returns ``(m, s, value)`` where ``value`` is ``f`` at the returned
    parameters (the negated minimum).
    """
    import time

    w_start,unpack = flatten([m0,s0])
    n_evals = 0
    started = time.time()

    def negated(w):
        nonlocal n_evals
        m,s = unpack(w)
        if util.good_iter_to_print(n_evals):
            # Extra evaluation on unboxed values, purely for the progress line.
            print('iters',n_evals,'time',time.time()-started,'f',f(getval(m),getval(s)))
        n_evals += 1
        return -f(m,s)

    result = scipy.optimize.minimize(
        value_and_grad(negated),
        w_start,
        jac=True,
        method='BFGS',
        options={'maxiter':maxiter},
    )
    print('not checking success...')
    m_fit,s_fit = unpack(result.x)
    return m_fit,s_fit,-result.fun

def fitgauss_stoch(f,m0,s0,maxiter=10000):
    """Maximise a (possibly noisy) ``f(m, s)`` by ascent along an averaged gradient.

    Parameters
    ----------
    f : callable ``f(m, s)`` returning the scalar to maximise.
    m0, s0 : starting values; flattened together into one parameter vector.
    maxiter : number of ascent steps. Defaults to 10000, the count that used
        to be hard-coded. The step size is .1 during the first half of the run
        and .01 afterwards.

    Returns
    -------
    (m, s, f_ave) : the final parameters and the running average of ``f``
        (plain average over the first 1000 steps, then an exponential moving
        average with weight .001).
    """
    import time
    w0,unflatten = flatten([m0,s0])

    def f2(w):
        m,s = unflatten(w)
        return f(m,s)
    obj = value_and_grad(f2)

    w = w0 + 0.0  # fresh array, so the in-place update below leaves w0 alone
    t0 = time.time()
    g_ave = 0*w0
    f_ave = 0.0
    f2_ave = 0.0
    for it in range(maxiter):
        ff,g = obj(w)
        # Averaging weights start at 1 (so the first value replaces the zero
        # initialisation) and decay as 1/(1+it) down to a floor.
        alpha_f  = max(.001, 1/(1+it))
        alpha_g  = max(.1  , 1/(1+it))
        alpha_f2 = max(.001, 1/(1+it))
        f_ave  = alpha_f *ff    + (1-alpha_f) *f_ave
        # f2_ave is only consumed by the disabled normalisation on the update line.
        f2_ave = alpha_f2*ff**2 + (1-alpha_f2)*f2_ave
        g_ave  = alpha_g *g     + (1-alpha_g) *g_ave
        if it < maxiter//2:
            step = .1
        else:
            step = .01
        w += step*g_ave #/ sqrt(f2_ave+1e-10)

        if util.good_iter_to_print(it):
            print('iter',it,'time',time.time()-t0,'f',f_ave)

    m,s = unflatten(w)
    return m,s,f_ave

def laplace(f,z0):
    """Maximise ``f`` from ``z0`` and factor the curvature at the optimum.

    Steps:
      1. Minimise ``-f`` with BFGS; if the optimiser does not report success,
         retry from ``z0`` with Nelder-Mead.
      2. Compute the Hessian ``H`` of ``-f`` at the solution with autograd,
         falling back to symmetrised central differences of the gradient if
         that fails.
      3. Return ``inv(cholesky(H))``, or the identity if the factorisation
         raises ``LinAlgError``.

    Parameters
    ----------
    f : callable taking a parameter vector and returning a scalar.
    z0 : initial parameter vector.

    Returns
    -------
    (z, C) : the optimiser's solution and the matrix described in step 3.

    Raises
    ------
    AssertionError : if neither optimiser reports success (same exception type
        as the previous ``assert``, but no longer disabled by ``python -O``).

    NOTE(review): with ``H = Lh @ Lh.T`` and ``C = inv(Lh)``, it is
    ``C.T @ C`` (not ``C @ C.T``) that equals ``inv(H)`` -- confirm that
    callers use ``C`` accordingly.
    """
    def f2(z):
        return -f(z)
    obj = value_and_grad(f2)

    # Explicit success checks instead of assert/except AssertionError, so the
    # fallback still happens when assertions are stripped.
    rez = scipy.optimize.minimize(obj,z0,jac=True,method='BFGS')
    if not rez.success:
        options = {'maxiter':1e6}
        # jac=True is kept because obj returns a (value, gradient) pair.
        rez = scipy.optimize.minimize(obj,z0,jac=True,method='Nelder-Mead',options=options)
        if not rez.success:
            raise AssertionError('laplace: neither BFGS nor Nelder-Mead reported success')
    z = rez.x

    try:
        H = np.squeeze(autograd.hessian(f2)(z))
    except Exception:
        # Central differences of the gradient, one row of H per coordinate.
        g = autograd.grad(f2)
        eps = 1e-5
        H = zeros((len(z),len(z)))
        for i in range(len(z)):
            z_pos = z + 0.0
            z_neg = z + 0.0
            z_pos[i] += eps
            z_neg[i] -= eps
            H[i,:] = (g(z_pos)-g(z_neg))/(2*eps)
        # Symmetrise once, after every row is filled. Doing this inside the
        # loop repeatedly halved the off-diagonal entries of earlier rows.
        H = .5*H + .5*H.T
    try:
        C = inv(cholesky(H))
    except np.linalg.LinAlgError:
        C = eye(len(z))
    return z,C

def maximize_elbo(elbo,u,logp,m0=None,s0=None):
    """Fit ``(m, s)`` by maximising ``elbo(u, m, s, logp)`` with ``fitgauss``.

    elbo : callable with signature ``elbo(u, m, s, logp)``.
    u    : base draws passed through to ``elbo`` unchanged.
    logp : target log-density, also passed through.
    m0   : initial mean; defaults to ``zeros(len(u))``.
    s0   : initial factor; defaults to ``eye(len(u))``.

    Returns whatever ``fitgauss`` returns, i.e. ``(m, s, value)``.

    NOTE(review): the defaults take ``len(u)`` as the dimension. That matches a
    ``u`` whose first axis is ndims; for a (nsamps, M, ndims) array the first
    axis is nsamps -- confirm that such callers pass m0/s0 explicitly.
    """
    m_init = zeros(len(u)) if m0 is None else m0
    s_init = eye(len(u)) if s0 is None else s0

    def objective(m,s):
        return elbo(u,m,s,logp)

    return fitgauss(objective,m_init,s_init)
