from mynumpy import *
import warnings
import simplex_transform
#import conjugate_inference
import scipy.stats
from autograd.scipy.special import gammaln

def dirichlet_logpdf(z, logz, alpha):
    """Evaluate the Dirichlet log-density from log-coordinates.

    Args:
        z: Not read by this function; only ``logz`` enters the computation.
        logz: Logarithm of the points to evaluate. It is multiplied by
            ``alpha - 1`` expanded along axis 1 and then reduced over
            axis 0 (presumably components on axis 0, samples on axis 1 --
            confirm against callers).
        alpha: Concentration parameters.

    Returns:
        ``sum_k (alpha_k - 1) * logz_k - log B(alpha)``, reduced over axis 0.
    """
    # Log of the multivariate Beta function, i.e. the log normaliser.
    log_normaliser = sum(gammaln(alpha)) - gammaln(sum(alpha))

    # Insert an axis at position 1 so the exponents broadcast against logz.
    exponents = expand_dims(alpha - 1, axis=1)
    weighted_logs = exponents * logz

    return np.sum(weighted_logs, axis=0) - log_normaliser

class dirichlet_multinomial:
    """Dirichlet prior combined with categorical observations.

    The posterior concentration ``alpha_post`` is the prior ``alpha`` plus
    the per-category counts of ``data``. Log-densities and samples are
    expressed in the (K-1)-row coordinates used by ``simplex_transform``.
    """

    def __init__(self, data, alpha):
        """Store the model and compute the posterior concentration.

        Args:
            data: Non-negative integer category labels (passed to
                ``np.bincount``), or ``None`` for no observations.
            alpha: Array of K strictly positive concentration parameters.
        """
        n_categories = len(alpha)
        assert np.all(alpha > 0)

        self.K = n_categories
        self.ndims = n_categories - 1
        self.data = data
        self.alpha = alpha
        self.dataset = 'dirichlet_multinomial'
        self.prior_sigma = ''

        if data is None:
            # No observations: the posterior is the prior (same object).
            self.alpha_post = alpha
        else:
            counts = np.bincount(data, minlength=n_categories)
            # A label >= K makes bincount return more than K bins.
            assert len(counts) == n_categories
            self.alpha_post = alpha + counts

    def logp(self, z):
        """Return the posterior log-density at transformed coordinates.

        A 1-D ``z`` is treated as a single column and a scalar entry is
        returned; otherwise ``z`` must be 2-D. The result is the Dirichlet
        log-density at the transformed point plus the third value returned
        by ``simplex_transform.fw`` (presumably the log-Jacobian of the
        transform -- confirm in ``simplex_transform``).
        """
        if z.ndim == 1:
            as_column = z.reshape(len(z), 1)
            return self.logp(as_column)[0]

        assert z.ndim == 2
        simplex_pt, log_simplex_pt, log_jacobian = simplex_transform.fw(z)
        log_density = dirichlet_logpdf(simplex_pt, log_simplex_pt, self.alpha_post)
        return log_density + log_jacobian

    def unconstrain(self, theta):
        """Apply ``simplex_transform.bw`` to a 2-D array with K rows."""
        assert theta.ndim == 2
        assert theta.shape[0] == self.K
        return simplex_transform.bw(theta)

    def constrain(self, z):
        """Apply ``simplex_transform.fw`` to a 2-D array with K-1 rows.

        Only the first element of the transform's return value is kept.
        """
        assert z.ndim == 2
        assert z.shape[0] == self.K - 1
        transformed = simplex_transform.fw(z)
        return transformed[0]

    def rvs(self, nsamps):
        """Draw ``nsamps`` posterior samples in unconstrained coordinates.

        Samples come from ``scipy.stats.dirichlet`` with ``alpha_post`` and
        are transposed so each column is one draw before ``unconstrain``.
        """
        draws = scipy.stats.dirichlet.rvs(self.alpha_post, size=nsamps)
        return self.unconstrain(draws.T)


def getD(K, ndata, nsamps=10000000):
    """Generate a random Dirichlet-multinomial problem with sampled moments.

    Repeatedly draws a random prior until a usable problem is produced:
    a prior ``alpha`` is sampled from a Gamma distribution, a true
    ``theta`` is sampled from that prior, ``ndata`` category labels are
    sampled from ``theta``, and the resulting ``dirichlet_multinomial``
    model is sampled ``nsamps`` times to estimate its moments. There is no
    retry limit; the loop runs until a draw passes both checks below.

    Args:
        K: Number of categories.
        ndata: Number of observations to simulate.
        nsamps: Number of posterior samples used for the moment estimates.
            Defaults to the previously hard-coded 10,000,000.

    Returns:
        Tuple ``(logp, mu, Sigma)``: the ``dirichlet_multinomial`` model,
        the per-row mean of its samples, and their covariance matrix.
    """
    import numpy as np
    import scipy.stats

    b = 1  # larger b - easier distribution

    while True:
        # Random prior.
        alpha = scipy.stats.gamma.rvs(b, scale=b, size=K)

        # Reject priors where some component has near-zero variance.
        if np.min(scipy.stats.dirichlet.var(alpha)) < 1e-5:
            print('repeating... (alpha)')
            continue

        # Sample theta from the prior (true params), then a dataset from theta.
        theta = scipy.stats.dirichlet.rvs(alpha).ravel()
        data = np.random.choice(K, size=ndata, p=theta)

        logp = dirichlet_multinomial(data, alpha)

        X = logp.rvs(nsamps)

        # Validate the samples before paying for the moment estimates, so
        # rejected draws do not compute a mean/covariance that is discarded.
        if not np.all(np.isfinite(X)):
            print('repeating (X)')
            continue

        mu = np.mean(X, axis=1)
        Sigma = np.cov(X)
        return logp, mu, Sigma