from mynumpy import *
import scipy

def logsumexp(A,axis=0):
    """Numerically stable ``log(sum(exp(A), axis=axis))``.

    The maximum along ``axis`` is subtracted before exponentiating to avoid
    overflow. The maximum is kept with ``keepdims`` so the subtraction
    broadcasts along the reduced axis for any ``axis`` (not only axis=0).
    Non-finite maxima (a slice that is all ``-inf`` or contains ``+inf``) are
    replaced by 0, so the shift never produces ``inf - inf = nan``.

    Args:
        A: array of log-values.
        axis: axis to reduce over (default 0); ``None`` reduces over all axes.

    Returns:
        Array with ``axis`` removed (0-d when the result is a single value).
    """
    damax = np.max(A, axis=axis, keepdims=True)
    # An all -inf slice would otherwise give (-inf) - (-inf) = nan.
    damax = np.where(np.isfinite(damax), damax, 0)
    total = np.sum(exp(A - damax), axis=axis)
    return log(total) + np.squeeze(damax, axis=axis)

def mvnlogpdf(x, d, var, mu=None, keepdims=False):
    """Log-density of the isotropic Gaussian N(mu, var * I) in d dimensions.

    Every column of ``x`` is paired with every column of ``mu``: for ``x`` of
    shape (d, n) and ``mu`` of shape (d, k) the result is an (n, k) matrix
    whose entry (i, j) is log N(x[:, i]; mu[:, j], var * I).

    Args:
        x: array reshapeable to (d, n).
        d: dimension.
        var: scalar variance of every coordinate.
        mu: array reshapeable to (d, k); ``None`` means the zero mean.
        keepdims: if False, singleton dimensions of the (n, k) result are
            squeezed away.

    Returns:
        Array of shape (n, k), or its squeezed form.
    """
    pts = x.reshape(d, -1, 1)                           # d x n x 1
    centers = np.zeros((d, 1, 1)) if mu is None else mu
    centers = centers.reshape(d, 1, -1)                 # d x 1 x k

    diff = pts - centers                                # d x n x k
    sq_dist = np.sum(diff * diff, axis=0)               # n x k
    log_norm = 0.5 * d * np.log(2 * np.pi * var)
    result = -0.5 / var * sq_dist - log_norm

    return result if keepdims else np.squeeze(result)

def test_mvnlogpdf(d=2,var=2.0,n=5,k=2):
    """Compare mvnlogpdf against scipy.stats.multivariate_normal.

    Covers three broadcasting patterns: one point with one mean, n points
    with one mean, and n points with k means. Prints a status line for each
    and raises AssertionError on any mismatch.
    """
    from scipy.stats import multivariate_normal

    ref_logpdf = multivariate_normal.logpdf

    # Case 1: one point, one mean.
    print("single x, single mu: ", end='')
    pt = randn(d)
    center = randn(d)
    assert(np.allclose(mvnlogpdf(pt, d, var, mu=center),
                       ref_logpdf(pt, mean=center, cov=var*np.eye(d))))
    print("passed")

    # Case 2: n points (columns of pts), one mean; scipy wants points as rows.
    print("many x, single mu: ", end='')
    pts = randn(d,n)
    center = randn(d)
    assert(np.allclose(mvnlogpdf(pts, d, var, mu=center),
                       ref_logpdf(pts.T, mean=center, cov=var*np.eye(d))))
    print("passed")

    # Case 3: n points, k means; the reference is built one mean (column) at a time.
    print("many x, many mu: ", end='')
    pts = randn(d, n)
    centers = randn(d, k)
    ours = mvnlogpdf(pts, d, var, mu=centers)
    ref = np.zeros((n,k))
    for j, col in enumerate(centers.T):
        ref[:,j] = ref_logpdf(pts.T, mean=col, cov=var*np.eye(d))
    assert(np.allclose(ours, ref))
    print("passed")


def log_normalizer(d, prec, h):
    """Row-wise log of the integral over R^d of exp(-0.5*prec*x'x - h'x).

    The integral is (2*pi/prec)^(d/2) * exp(h'h / (2*prec)).

    Args:
        d: dimension.
        prec: precisions, broadcastable against an (m, 1) column.
        h: (m, d) linear coefficients, one row per component.

    Returns:
        (m, 1) array of log-normalizers.
    """
    quad_term = np.sum(h*h, axis=1, keepdims=True) / (2*prec)
    volume_term = 0.5*d*np.log(2*np.pi/prec)
    return quad_term + volume_term

# Class for Gaussian potentials
class Pot:
    """A stack of ``m`` isotropic Gaussian potentials over R^d.

    Component j represents the log-potential

        -0.5 * prec[j] * x'x - h[j]'x - a[j] - b[j]

    ``a`` defaults to the log-normalizer of the Gaussian part, i.e. the log of
    the integral of exp(-0.5*prec*x'x - h'x). With that default, component j
    integrates to exp(-b[j]), so ``b`` acts as a per-component negative
    log-weight.

    Attributes:
        d:    dimension of x.
        m:    number of components.
        prec: (m, 1) precisions.
        h:    (m, d) linear coefficients.
        a:    (m, 1) log-normalizer offsets.
        b:    (m, 1) log-weight offsets.

    Inputs are copied with ``np.array``, so a Pot never aliases its arguments.
    """

    def __init__(self, d, prec, h=None, a=None, b=None):
        self.d = d
        self.prec = np.array(prec).reshape(-1,1)
        m = self.prec.shape[0]
        self.m = m
        self.h = np.zeros((m,d)) if h is None else np.array(h).reshape((m,d))
        self.b = np.zeros((m,1)) if b is None else np.array(b).reshape((m,1))
        self.a = log_normalizer(d, self.prec, self.h) if a is None else np.array(a).reshape((m,1))

    @classmethod
    def density2pot(cls, d, var, mu=None, b=None):
        """Build a single-component potential equal to N(x; mu, var*I) * exp(-b).

        Args:
            d: dimension.
            var: scalar variance.
            mu: length-d mean (array or sequence); ``None`` means zero mean.
            b: optional log-weight offset (defaults to 0).
        """
        # asarray so a plain list/tuple mean works with the division.
        h = None if mu is None else -np.asarray(mu)/var
        return cls(d, 1/var, h=h, b=b)

    def val(self, x, as_log=True, tot=True, keepdims=False):
        """Evaluate the potential at the columns of ``x``.

        Args:
            x: array reshapeable to (d, n); one evaluation point per column.
            as_log: return log-values (default) or exponentiated values.
            tot: if True, combine the m components with logsumexp over axis 0
                (result shape (n,)); otherwise keep per-component values
                (shape (m, n)).
            keepdims: if False, squeeze singleton dimensions.
        """
        # Coerce x into shape (d, n)
        x = np.array(x).reshape(self.d,-1)

        xtx = np.sum(x*x, axis=0, keepdims=True)                        # shape (1, n)
        val = -0.5*self.prec*xtx - np.dot(self.h, x) - self.a - self.b  # shape (m, n)

        if tot:
            val = logsumexp(val, axis=0)

        if not keepdims:
            val = np.squeeze(val)

        return val if as_log else np.exp(val)

    @classmethod
    def mult(cls, pot1, pot2):
        """Pointwise product of two potentials (sum of log-potentials).

        Component counts must match, or one side must be a single component,
        which is broadcast against the other.

        Raises:
            ValueError: if the dimensions differ or the component counts are
                incompatible.
        """
        d = pot1.d
        if pot2.d != d:
            raise ValueError("Potentials must have same d")
        if pot1.m > 1 and pot2.m > 1 and pot1.m != pot2.m:
            raise ValueError("Potentials must have same m or one is a singleton")

        prec = pot1.prec + pot2.prec
        h = pot1.h + pot2.h
        a = log_normalizer(d, prec, h)
        # Renormalize with the new a; the leftover constant moves into b so
        # the product's values are unchanged.
        b = pot1.b + pot2.b + pot1.a + pot2.a - a

        return cls(d, prec, h, a, b)

    @classmethod
    def stack(cls, pot1, pot2):
        """Concatenate the components of two potentials (pot1's first).

        Raises:
            ValueError: if the dimensions differ.
        """
        d = pot1.d
        if pot2.d != d:
            raise ValueError("Potentials must have same d")

        prec = np.concatenate((pot1.prec, pot2.prec))
        h = np.concatenate((pot1.h, pot2.h))
        a = np.concatenate((pot1.a, pot2.a))
        b = np.concatenate((pot1.b, pot2.b))

        return cls(d, prec, h, a, b)


def test_pot(d=10, n=50, var=2):
    """Check that a one-component Pot built from N(mu, var*I) reproduces
    mvnlogpdf at n random points. Prints "passed" on success."""
    center = randn(d)
    gauss = Pot.density2pot(d, var, mu=center)

    pts = randn(d, n)
    assert(np.allclose(gauss.val(pts), mvnlogpdf(pts, d, var, mu=center)))
    print("passed")


# Class for Minka's clutter model
class Clutter:
    """Minka's clutter problem in d dimensions, with synthetic data.

    Generative model, as implemented by ``generate_data`` and ``logp``::

        x   ~ N(0, var_prior * I)
        y_i ~ alpha * N(x, var_obs * I) + (1 - alpha) * N(0, var_noise * I),
              i = 1..n

    The constructor draws a latent ``x`` (shape (d,)) and observations ``y``
    (shape (n, d)), then calls ``compute_zstar`` to locate a posterior mode.
    Exact inference (``do_inference``) represents the posterior as a stack
    of 2**n isotropic Gaussian components. It runs lazily the first time a
    posterior accessor is called.
    """

    def __init__(self, d, n, alpha=0.25, var_prior=100, var_obs=1, var_noise=10):
        self.d = d
        self.n = n
        self.alpha = alpha
        self.var_prior = var_prior
        self.var_obs = var_obs
        self.var_noise = var_noise
        self.posterior = None
        self.generate_data()
        self.ndims = d
        self.compute_zstar()

    def compute_zstar(self, starts=100000):
        """Locate a mode of ``logp`` and the Hessian of ``logp`` there.

        ``logp`` is evaluated at ``starts`` random points in [-20, 20]^d.
        This assumes ``rand`` samples U[0, 1); TODO confirm. BFGS, with
        autodiff gradients, then refines the best of those points.

        Args:
            starts: number of random points for the initial search. ``logp``
                on all of them at once builds a (d, starts, n) temporary.

        Sets:
            zstar: (d,) local maximizer of ``logp`` returned by BFGS.
            Hstar: Hessian of ``logp`` at ``zstar``, i.e. the negated Hessian
                of the minimized objective ``-logp``.
        """
        # `import scipy` alone does not load the optimize submodule on SciPy
        # releases without lazy submodule loading, so import it explicitly.
        import scipy.optimize

        z0 = -20 + 40*rand(self.ndims, starts)
        start_logp = self.logp(z0)
        z0 = z0[:, np.argmax(start_logp)]

        # Scalar objective of a single point z of shape (d,).
        f = lambda z : -self.logp(np.expand_dims(z,axis=1))[0]
        obj = value_and_grad(f)
        rez = scipy.optimize.minimize(obj, z0, method='BFGS', jac=True)
        self.zstar = rez['x']
        self.Hstar = -hessian(f)(self.zstar)

    def generate_data(self):
        """Sample the latent ``self.x`` (d,) and observations ``self.y`` (n, d).

        Each row of y is either x plus N(0, var_obs) noise (selected when
        ``rand`` < alpha) or pure N(0, var_noise) clutter.
        """
        n, d = self.n, self.d
        x = np.sqrt(self.var_prior) * randn(d)
        obs = x + np.sqrt(self.var_obs) * randn(n,d)
        clutter = np.sqrt(self.var_noise) * randn(n,d)

        is_obs = rand(n,1) < self.alpha
        is_clutter = np.logical_not(is_obs)

        y = is_obs*obs + is_clutter*clutter

        self.x = x
        self.y = y

    def logp(self,x):
        """Log joint density log p(x) + sum_i log p(y_i | x) (unnormalized posterior).

        Args:
            x: array reshapeable to (d, N); one query point per column.

        Returns:
            Array of shape (N,).
        """
        x = x.reshape(self.d,-1)                        # shape d x N
        logp_x = mvnlogpdf(x, self.d, self.var_prior)   # shape N

        logp_y_obs   = np.log(  self.alpha) + mvnlogpdf(x, self.d, self.var_obs, mu=self.y.T, keepdims=True)  # shape N x n
        logp_y_noise = np.log(1-self.alpha) + mvnlogpdf(self.y.T, self.d, self.var_noise)      # shape n (does not depend on x)
        logp_y = np.logaddexp(logp_y_obs, logp_y_noise) # broadcasts

        logp = logp_x + np.sum(logp_y, axis=1)
        return logp

    def __call__(self,x):
        return self.logp(x)

    def do_inference(self):
        """Compute the exact posterior over x as a mixture of Gaussians.

        Each observation doubles the number of components. One copy is
        multiplied by the "observed" likelihood alpha*N(y_i; x, var_obs). The
        other has its log-weight shifted by the x-independent clutter
        likelihood (1-alpha)*N(y_i; 0, var_noise).

        Sets:
            posterior: the final Pot with 2**n components.
            logZ: log marginal likelihood, logsumexp of -posterior.b.
            weights: (m, 1) normalized component weights.
            mus: (m, d) component means.
            vars: (m, 1) component variances.
        """
        d = self.d
        posterior = Pot.density2pot(d, self.var_prior)
        noise = Pot.density2pot(d, self.var_noise, b=-np.log(1-self.alpha))

        for i in range(self.n):
            y_i = self.y[i,:]

            # Branch 1: y_i was generated from x.
            obs = Pot.density2pot(d,
                                  self.var_obs,
                                  mu=y_i,
                                  b=-np.log(self.alpha))
            pots1 = Pot.mult(posterior, obs)

            # Branch 2: y_i is clutter. Its likelihood is constant in x, so it
            # only shifts b. Build a new Pot instead of mutating posterior.b.
            pots2 = Pot(d, posterior.prec, posterior.h, posterior.a,
                        posterior.b - noise.val(y_i, as_log=True))

            posterior = Pot.stack(pots1, pots2)

        # Save potentials as the "posterior"
        self.posterior = posterior
        self.logZ = logsumexp(-self.posterior.b)

        # Compute component weights, means, variance
        self.weights = np.exp(-self.posterior.b - self.logZ)
        self.mus = -self.posterior.h/self.posterior.prec
        self.vars = 1/self.posterior.prec

        assert(np.isclose(self.weights.sum(), 1))

    def loglik(self):
        """Log marginal likelihood log p(y) (runs inference if needed)."""
        if self.posterior is None: self.do_inference()
        return self.logZ

    def logposterior(self, x):
        """Normalized log posterior density at the columns of ``x``."""
        if self.posterior is None: self.do_inference()
        return self.posterior.val(x) - self.logZ

    def posterior_mean(self):
        """Posterior mean of x, shape (d,)."""
        if self.posterior is None: self.do_inference()
        return np.sum(self.weights * self.mus, axis=0)

    def posterior_cov(self):
        """Posterior covariance of x, shape (d, d), from the mixture moments."""
        if self.posterior is None: self.do_inference()

        d = self.d
        m = len(self.weights)
        weights = self.weights.reshape(m,1,1)

        # Reshape for correct broadcasting
        mu  = self.mus.reshape(m,d,1)
        muT = self.mus.reshape(m,1,d)
        sigma2 = self.vars.reshape(m,1,1)
        I = np.eye(d).reshape(1,d,d)

        # E[xx^T] = sigma^2 I + mu mu^T
        component_E_xx = sigma2*I + mu*muT

        E_xx = np.sum(weights * component_E_xx, axis=0)   # d x d
        mean = np.sum(weights * mu, axis=0)               # d x 1

        Sigma = E_xx - mean * mean.T
        return Sigma

    def samples(self, k):
        """Draw k samples from the posterior mixture, shape (k, d)."""
        if self.posterior is None: self.do_inference()
        m = self.weights.size
        inds = np.random.choice(m, k, p=self.weights.ravel())
        u = randn(k, self.d)
        return np.sqrt(self.vars[inds])*u + self.mus[inds,:]


if __name__ == "__main__":
    # Run the self-checks in order when this file is executed as a script.
    for _check in (test_mvnlogpdf, test_pot):
        _check()
