from mynumpy import *
#from autograd.scipy.stats import *
import scipy
#from dirichlet import *
import dirichlet

def sigmoid(a):
    """Logistic function 1 / (1 + exp(-a)), applied elementwise.

    Delegates to autograd's wrapped scipy ``expit`` (instead of a hand-written
    exp or tanh formula) so the result can be differentiated by autograd.
    """
    expit = autograd.scipy.special.expit
    return expit(a)

def logsubexp(a,b):
    """log(exp(a) - exp(b)), shifted by max(a, b) so the exponentials cannot overflow.

    Deliberately disabled: the leading assert makes every call raise
    AssertionError because the function is not needed. If asserts are
    stripped (``python -O``), the formula below runs. It yields -inf where
    a == b and NaN where b > a, since the difference of exponentials is
    then <= 0.
    """
    assert False  # not needed; see docstring
    shift = np.maximum(a, b)
    return shift + log(exp(a - shift) - exp(b - shift))

def fw(z):
    """Stick-breaking map from unconstrained space onto the probability simplex.

    Parameters
    ----------
    z : array, shape (K-1,) or (K-1, n)
        Unconstrained coordinates. A 2-D input is treated as n independent
        columns.

    Returns
    -------
    x : array, shape (K,) or (K, n)
        Simplex point(s): x[k] = prob_left_k * sigmoid(z[k] - log(K-k-1)).
        The last entry takes whatever stick is left.
    logx : array, same shape as x
        log(x), accumulated directly in log space so it stays finite even
        when the sigmoid saturates and an entry of x rounds to 0.
    logJ : scalar (1-D input) or array of shape (n,) (2-D input)
        Log-Jacobian of the map: the sum over k of
        log(zz_k) + log(1 - zz_k) + log(prob_left_k).

    Both input ranks share one code path and use the stable logaddexp forms.
    The naive log(zz) / log(1 - zz) / log(x) become -inf once the sigmoid
    rounds to 0 or 1 (e.g. z = [40.]), which made 1-D and 2-D results
    disagree.
    """
    K = z.shape[0] + 1
    if z.ndim == 1:
        logJ = 0.0
        prob_left = 1.0
        log_prob_left = 0.0
    else:
        nz = z.shape[1]
        logJ = zeros(nz)
        prob_left = ones(nz)
        log_prob_left = zeros(nz)
    x = []
    logx = []
    for k in range(K - 1):
        # subtracting log(K-k-1) centres the map: z = 0 gives the uniform point
        tmp = z[k] - log(K - k - 1)
        log_zz = -np.logaddexp(-tmp, 0.0)    # log(sigmoid(tmp)), never -inf for finite tmp
        log_1m_zz = -np.logaddexp(tmp, 0.0)  # log(1 - sigmoid(tmp)), likewise
        x.append(prob_left * sigmoid(tmp))
        logx.append(log_prob_left + log_zz)
        logJ = logJ + log_zz + log_1m_zz + log_prob_left
        prob_left = prob_left - x[-1]
        # remaining stick after this break: prob_left * (1 - zz), in log space
        log_prob_left = log_prob_left + log_1m_zz
    x.append(prob_left)
    logx.append(log_prob_left)
    return array(x), array(logx), logJ

# NOTE: bw is not numerically robust -- 1 + x - cumsum(x) loses relative precision when little stick mass remains.
def bw(x):
    """Inverse of fw: map simplex point(s) back to unconstrained space.

    Parameters
    ----------
    x : array, shape (K,) or (K, n, ...)
        Points on the simplex along axis 0 (each column is expected to sum
        to 1).

    Returns
    -------
    z : array, shape (K-1,) or (K-1, n, ...)
        z[k] = logit(x[k] / (1 - sum_{j<k} x[j])) + log(K-k-1). The added
        offset undoes the log(K-k-1) that fw subtracts before its sigmoid.
    """
    # `import scipy` alone does not load scipy.special on every SciPy version
    import scipy.special

    K = x.shape[0]
    head = x[:-1]  # slice along axis 0 only, so 1-D and N-D inputs both work
    # stick mass still unbroken before step k: 1 - sum_{j<k} x[j]
    remaining = 1 + head - np.cumsum(head, axis=0)
    frac = head / remaining
    # log(K-k-1) for k = 0..K-2, shaped to broadcast over any trailing axes
    offset = log(K - arange(1, K)).reshape((K - 1,) + (1,) * (x.ndim - 1))
    return scipy.special.logit(frac) + offset

def test_fw_bw():
    """Round-trip check on random 2-D input: bw(fw(z)) == z and fw(bw(x)) == x."""
    n_cat, n_cols = 3, 10

    z = randn(n_cat - 1, n_cols)
    x_of_z = fw(z)[0]
    assert np.allclose(z, bw(x_of_z))

    x = rand(n_cat, n_cols)
    x /= np.sum(x, axis=0)  # normalise each column onto the simplex
    assert np.allclose(x, fw(bw(x))[0])

# Returns a new log-density function that can be evaluated in the unconstrained (z) space.
def transform_dist(logp):
    """Pull a log-density on the simplex back to unconstrained space.

    The returned function maps z to logp(fw(z)[0]) plus the log-Jacobian
    of fw. It accepts whatever z shapes fw accepts ((K-1,) or (K-1, n)),
    provided logp handles the corresponding x.
    """
    def logp2(z):
        simplex_x, _unused_logx, log_jac = fw(z)
        return logp(simplex_x) + log_jac

    return logp2

def test_vectorization():
    """Check that fw on a (K-1, n) matrix matches fw applied column by column."""
    K = 4
    n_cols = 5
    z = randn(K - 1, n_cols)
    x, logx, logJ = fw(z)

    x_cols = 0 * x
    logx_cols = 0 * x
    logJ_cols = 0 * logJ
    for col in range(n_cols):
        x_cols[:, col], logx_cols[:, col], logJ_cols[col] = fw(z[:, col])

    for batched, looped in ((x, x_cols), (logx, logx_cols), (logJ, logJ_cols)):
        assert np.allclose(batched, looped)

    # TODO: also evaluate transform_dist(lambda x: dirichlet.logpdf(x, a)) on
    # the batched z (a = rand(K)); previously left as commented-out prints.

def _numerical_grad(f, z, eps=1e-6):
    """Central-difference gradient of scalar-valued f at z (any array shape).

    Perturbs each entry of a fresh copy of z by +/-eps in turn and stores
    (f(z + eps*e_i) - f(z - eps*e_i)) / (2*eps) at the same position.
    """
    g_num = 0 * z
    for i in range(z.size):
        z_pos = z + 0.0
        z_pos.flat[i] += eps
        z_neg = z + 0.0
        z_neg.flat[i] -= eps
        g_num.flat[i] = (f(z_pos) - f(z_neg)) / (2 * eps)
    return g_num

def test_grad():
    """Compare autograd gradients with central finite differences.

    Checks the transformed Dirichlet log-density for a single z of shape
    (K-1,), and a scalar reduction of it for a batch z of shape (K-1, nz).
    Both comparisons are asserted (the single-point case used to be printed
    only).
    """
    eps = 1e-6

    # single point: gradient of the transformed log-density itself
    K = 7
    a = rand(K)
    z = randn(K - 1)
    # a=a binds the current alpha now instead of looking it up at call time
    logp = transform_dist(lambda t, a=a: dirichlet.logpdf(t, a))
    g_ana = grad(logp)(z)
    g_num = _numerical_grad(logp, z, eps)
    print('g(z)', g_ana, 'g_num', g_num)
    assert np.allclose(g_ana, g_num)

    # batch of nz points: reduce to a scalar so that grad applies
    nz = 2
    a = rand(K)
    z = randn(K - 1, nz)
    logp = transform_dist(lambda t, a=a: dirichlet.logpdf(t, a))

    def f(t):
        return np.sum(logp(t) ** 2)

    g_ana = grad(f)(z)
    g_num = _numerical_grad(f, z, eps)
    print('g(z)', g_ana, 'g_num', g_num)
    assert np.allclose(g_ana, g_num)

def test_transform_integration():
    """Numerically integrate Dirichlet densities before and after the transform.

    For K = 2 and K = 3 this prints two integrals. The first is the original
    density over the simplex. The second is exp(transform_dist(logpdf)) over
    a finite box in z-space. Both should be close to 1 when the transform and
    its Jacobian are right, up to the box truncating the z-range. Nothing is
    asserted.
    """
    # `import scipy` alone does not load scipy.integrate on every SciPy version
    import scipy.integrate

    print('testing a distribution on 2 outcomes')

    a = rand(2)
    p = lambda x : dirichlet.pdf(array([x,1-x]),a)

    rez = scipy.integrate.quad(p,0,1,epsabs=1e-5)
    print('does original dist integrate?   ',rez[0])

    logp2 = transform_dist(lambda t : dirichlet.logpdf(t,a))
    p2 = lambda x : exp(logp2(array([x])))
    # [-30, 30] stands in for the whole real line
    rez = scipy.integrate.quad(p2,-30,30)
    print('does transformed dist integrate?',rez[0])

    print('testing a distribution on 3 outcomes')

    K = 3
    a = rand(K)
    # dblquad evaluates func(inner, outer): x (outer) runs over [0, 1] and
    # y (inner) over [0, 1 - x], so (x, y, 1-x-y) sweeps the whole simplex.
    p = lambda y, x : dirichlet.pdf(array([x,y,1-x-y]),a)

    rez = scipy.integrate.dblquad(p,0.0,1.0,lambda x : 0.0, lambda x : 1.0-x, epsabs=1e-5)
    print('does original dist integrate?   ',rez[0])

    logp2 = transform_dist(lambda t : dirichlet.logpdf(t,a))
    p2 = lambda z1, z0 : exp(logp2(array([z0,z1])))
    # the box [-10, 10]^2 stands in for the whole plane
    rez = scipy.integrate.dblquad(p2,-10,10,lambda x : -10, lambda x : 10)
    print('does transformed dist integrate?',rez[0])

def test():
    """Run the enabled self-tests in order; the integration check is not run here."""
    for check in (test_fw_bw, test_grad, test_vectorization):
        check()
    # test_transform_integration()

if __name__ == "__main__":
    # Run the self-tests only when executed as a script, not on import.
    test()
