# -*- coding: utf-8 -*-
import numpy as np

import scipy.stats as ss
import scipy.optimize as so
from scipy.special import gammaln, psi
from scipy import linalg

#import pandas as pd

import multiprocessing as mp
from itertools import repeat

import h5py as h5
import tempfile
import logging
import time


def main():
    log_level = logging.DEBUG # default logging level
    logging.basicConfig(level=log_level,
        format='%(levelname)s:%(module)s:%(message)s')
    
    (y, x, theta) = gen1()
    
    pool = None
    #pool = mp.Pool(processes=2)
    
    (phi, q) = varlap_opt(y, K=2, seed=10241978, pool=pool)

    save_model('model.hdf5', y, phi, q)

def save_model(fname, y, phi, q):
    """ Write the data y, model parameters phi, and variational parameters q to an HDF5 file. """
    f = h5.File(fname, 'w')
    
    f.create_dataset('y', data=y)
    
    f.create_dataset('alpha', data=phi['alpha'])
    f.create_dataset('lam', data=phi['lam'])
    f.create_dataset('s2', data=phi['s2'])
    
    f.create_dataset('x', data=q['x'])
    f.create_dataset('sig2x', data=q['sig2x'])
    f.create_dataset('gam', data=q['gam'])
    
    f.create_dataset('EqTheta', data=EqTheta(q['gam']))
    
    f.close()

def load_model(fname):
    """ Read back the data, model parameters, and variational parameters written by save_model. """
    f = h5.File(fname, 'r')
    
    y = f['y'][...]
    
    phi = {}    
    phi['alpha'] = f['alpha'][...]
    phi['lam'] = f['lam'][...]
    phi['s2'] = f['s2'][...]

    q = {}
    q['x'] = f['x'][...]
    q['sig2x'] = f['sig2x'][...]
    q['gam'] = f['gam'][...]
    
    f.close()

    return (y, phi, q)
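
# Example (illustrative sketch): round-trip a fitted model through HDF5.
# 'example_model.hdf5' is an arbitrary filename; any writable path works.
#   (y, x, theta) = gen1(seed=0)
#   (phi, q) = varlap_opt(y, K=3)
#   save_model('example_model.hdf5', y, phi, q)
#   (y2, phi2, q2) = load_model('example_model.hdf5')
#   assert np.allclose(y, y2)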

def gen1(seed=None):
    """ Generate a synthetic dataset with J=15 features, N=75 samples, and K=3 latent factors. """
    if seed is not None: np.random.seed(seed=seed)
    
    s2 = pow(.25,2)
    (J, N, K) = (15, 75, 3)

    x = np.zeros([J, K])
    x[0:5,  0] = +1
    x[5:10, 0] = -1
    x[0:5,  1] = -1
    x[5:10, 1] = +1
    x[10:15,2] = +1

    theta = np.zeros([K, N])
    theta[0, 0:25] = 0.9
    theta[1, 0:25] = 1 - theta[0, 0:25]
    theta[0, 25:50] = 0.1
    theta[1, 25:50] = 1 - theta[0, 25:50]
    theta[2, 50:75] = 0.9
    theta[0, 50:75] = 1 - theta[2, 50:75]

    y = np.zeros([J,N])
    for i in xrange(N):
        y[:,i] = ss.norm.rvs(loc=np.dot(x, theta[:,i]), scale=np.sqrt(s2))
    
    return (y, x, theta)    
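
# Shapes produced by gen1 (for reference): y is J x N = 15 x 75, x is J x K = 15 x 3,
# and theta is K x N = 3 x 75 with every column on the probability simplex.
# Minimal check sketch:
#   (y, x, theta) = gen1(seed=0)
#   assert y.shape == (15, 75) and np.allclose(theta.sum(axis=0), 1.0)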

def f_x_j(x_j, Etheta_hat, Etheta2_hat, y_j, s2, lam):
    """ Return the value of the exponential kernel involving x_j. """
    
    N = y_j.shape[0]
    f = 0
    for i in xrange(N):
        Etheta_i = Etheta_hat[:,i]
        Etheta2_i = Etheta2_hat[:,:,i]
        f += np.dot(x_j, Etheta_i)*(y_j[i]/s2) - \
             np.trace(np.dot(np.outer(x_j, x_j), Etheta2_i))/(2*s2)
    f -= np.sum(np.abs(x_j))/lam
    
    return f

def neg_f_x_j(x_j, Etheta_hat, Etheta2_hat, y_j, s2, lam):
    return -f_x_j(x_j, Etheta_hat, Etheta2_hat, y_j, s2, lam)

def del_f_x_j(x_j, Etheta_hat, Etheta2_hat, y_j, s2, lam):
    """ Return the gradient of the exponential kernel involving x_j. """
    
    N = y_j.shape[0]
    delf = 0
    for i in xrange(N):
        delf += Etheta_hat[:,i]*(y_j[i]/s2) - np.dot(Etheta2_hat[:,:,i], x_j)/(s2)
    delf -= np.sign(x_j)/lam
    
    return delf
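
# A minimal sanity-check sketch (not part of the fitting loop; intended to be called
# manually during development) that compares del_f_x_j against a finite-difference
# gradient at a point away from zero, where the |x|/lam term is differentiable.
def _check_del_f_x_j(Etheta_hat, Etheta2_hat, y_j, s2, lam, seed=0):
    """ Return the finite-difference error of del_f_x_j at a random positive point. """
    np.random.seed(seed)
    K = Etheta_hat.shape[0]
    x0 = 1.0 + np.random.rand(K)  # keep every component away from the kink at 0
    err = so.check_grad(f_x_j, del_f_x_j, x0, Etheta_hat, Etheta2_hat, y_j, s2, lam)
    logging.debug('Finite-difference gradient error for f_x_j: %0.2e' % err)
    return err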

def del2_f_x_j(x_j, Etheta_hat, Etheta2_hat, y_j, s2, lam):
    """ Return the hessian of the exponential kernel involving x_j. """
    K = x_j.shape[0]
    N = y_j.shape[0]
    
    del2f = np.zeros([K, K])
    for i in xrange(N):
        del2f += -Etheta2_hat[:,:,i]/s2 
    #del2f += -np.diag([1/lam] * K) # dirac delta function can safely be ignored for this optimization
    
    return del2f

def EqTheta(gam):
    """ Return Kx1 first moment of variational distribution q(\theta). """
    if np.ndim(gam) == 1: return gam/np.sum(gam)
    
    (K,N) = gam.shape
    
    gam0 = np.sum(gam,0)
    
    EqTheta = np.zeros([K,N])
    for i in xrange(N):
        EqTheta[:,i] = gam[:,i]/gam0[i]
    return EqTheta


def EqTheta2(gam):
    """ Return KxK second moment matrix of variational distribution q(\theta). """
    gam0 = np.sum(gam,0)
    
    if np.ndim(gam) == 1: return (np.outer(gam, gam) + np.diag(gam)) / (gam0 * (1 + gam0))
    
    (K,N) = gam.shape
    
    EqTheta2 = np.zeros([K,K,N])
    for i in xrange(N):
        EqTheta2[:,:,i] = np.outer(gam[:,i], gam[:,i]) + np.diag(gam[:,i])
        EqTheta2[:,:,i] /= gam0[i] * (1 + gam0[i])
    
    return EqTheta2
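
# Minimal Monte Carlo sketch (illustrative only, not run by the fitting code) of the
# Dirichlet moment formulas used by EqTheta and EqTheta2: for gamma = (2, 3, 5), samples
# from np.random.dirichlet should match the analytic first and second moments.
#   gam_demo = np.array([2.0, 3.0, 5.0])
#   draws = np.random.dirichlet(gam_demo, size=100000)                   # 100000 x 3
#   print np.allclose(draws.mean(axis=0), EqTheta(gam_demo), atol=1e-2)
#   print np.allclose(np.dot(draws.T, draws)/draws.shape[0], EqTheta2(gam_demo), atol=1e-2)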

def EqlogTheta(gam):
    """ Return Kx1 expected value of log(\theta) under variational distribution q(\theta). """
    gam0 = np.sum(gam,0)
    
    if gam.ndim == 1: return psi(gam) - psi( np.sum(gam) )
    
    (K,N) = gam.shape
    
    EqlogTheta = np.zeros([K,N])
    for i in xrange(N):
        EqlogTheta[:,i] = psi(gam[:,i]) - psi(gam0[i])
    
    return EqlogTheta
    
def EqX2(x, sig2x):
    """ Return the KxKxJ array of second moments E_q[x_j x_j^T] = sig2x[:,:,j] + outer(x_j, x_j). """
    (J,K) = x.shape
    
    EqX2 = np.zeros([K,K,J])
    for j in xrange(J):
        EqX2[:,:,j] = sig2x[:,:,j] + np.outer(x[j,:], x[j,:])
    return EqX2

def EqabsX(x, sig2x):
    """ Return the JxK matrix of elementwise expectations E_q|x_jk| under the Gaussian q(x_j). """
    (J,K) = x.shape
    
    EqabsX = np.zeros([J,K])
    for j in xrange(J):
        for k in xrange(K):
            EqabsX[j,k] = np.sqrt(sig2x[k,k,j])*np.sqrt(2.0/np.pi)*np.exp(-0.5*pow(x[j,k],2)/sig2x[k,k,j])
            EqabsX[j,k] += x[j,k]*(1-2*ss.norm.cdf(-x[j,k]/np.sqrt(sig2x[k,k,j]),loc=0,scale=1))
    return EqabsX
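
# EqabsX is the mean of a folded normal: if X ~ N(mu, v) then
#   E|X| = sqrt(v)*sqrt(2/pi)*exp(-mu^2/(2*v)) + mu*(1 - 2*Phi(-mu/sqrt(v))).
# Minimal Monte Carlo sketch (illustrative only) for a single entry with mu=0.7, v=0.3:
#   mu, v = 0.7, 0.3
#   analytic = np.sqrt(v)*np.sqrt(2.0/np.pi)*np.exp(-0.5*mu**2/v) \
#              + mu*(1 - 2*ss.norm.cdf(-mu/np.sqrt(v)))
#   sample = np.mean(np.abs(ss.norm.rvs(loc=mu, scale=np.sqrt(v), size=100000)))
#   print analytic, sample   # should agree to about two decimal places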

def varlap_elbo(y, x, sig2x, gam, alpha, lam, s2):
    (J,N) = y.shape
    K = x.shape[1]
    
    alpha0 = np.sum(alpha)
    gam0 = np.sum(gam,0)
    
    x2 = EqX2(x, sig2x)
    absX = EqabsX(x, sig2x)
    
    theta = EqTheta(gam)
    theta2 = EqTheta2(gam)
    logTheta = EqlogTheta(gam)
    
    EqlogPy = -0.5*N*J*np.log(2*np.pi*s2)       
    for i in xrange(N):
        for j in xrange(J):
            EqlogPy += -(0.5/s2)*(pow(y[j,i],2) -2*y[j,i]*np.dot(x[j,:],theta[:,i]) +np.trace(np.dot(x2[:,:,j], theta2[:,:,i])))
    
    EqlogPx = -J*K*np.log(2*lam)
    for j in xrange(J):
        for k in xrange(K):
            EqlogPx += -absX[j,k]/lam
    
    EqlogPtheta = N*gammaln(alpha0) - N*np.sum(gammaln(alpha))
    for i in xrange(N):
        EqlogPtheta += np.dot(alpha-1, logTheta[:,i])
    
    EqlogQx = 0.0
    for j in xrange(J):
        EqlogQx += -0.5*np.log( linalg.det(sig2x[:,:,j]) )
    
    EqlogQtheta = 0.0
    for i in xrange(N):
        EqlogQtheta += gammaln(gam0[i]) - np.sum(gammaln(gam[:,i])) + np.dot(gam[:,i]-1, logTheta[:,i])
    
    #logging.debug( (EqlogPy, EqlogPx, EqlogPtheta, -EqlogQx, -EqlogQtheta) )
    return EqlogPy + EqlogPx + EqlogPtheta - EqlogQx -EqlogQtheta
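
# The value returned above is the usual ELBO decomposition,
#   ELBO = E_q[log p(y|x,theta)] + E_q[log p(x)] + E_q[log p(theta)]
#          - E_q[log q(x)] - E_q[log q(theta)],
# where the additive constant (K/2)*log(2*pi*e) of the Gaussian entropy of q(x) is
# omitted; it does not depend on any optimized quantity.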

def neg_elbo_alpha(log_alpha, y, x, sig2x, gam, lam, s2):
    alpha = np.exp(log_alpha)
    return -varlap_elbo(y, x, sig2x, gam, alpha, lam, s2)

def varlap_elbo_i(y_i, x, sig2x, gam_i, alpha, lam, s2):
    J = y_i.shape[0]
    K = x.shape[1]
    
    alpha0 = np.sum(alpha)
    gam0 = np.sum(gam_i)
    
    x2 = EqX2(x, sig2x)
    absX = EqabsX(x, sig2x)
    
    theta = EqTheta(gam_i)
    theta2 = EqTheta2(gam_i)
    logTheta = EqlogTheta(gam_i)
    
    EqlogPy = -0.5*J*np.log(2*np.pi*s2)       
    for j in xrange(J):
        EqlogPy += -(0.5/s2)*(pow(y_i[j],2) -2*y_i[j]*np.dot(x[j,:],theta) +np.trace(np.dot(x2[:,:,j], theta2)))
    
    EqlogPx = -J*K*np.log(2*lam)
    for j in xrange(J):
        for k in xrange(K):
            EqlogPx += -absX[j,k]/lam
    
    EqlogPtheta = gammaln(alpha0) - np.sum(gammaln(alpha)) + np.dot(alpha-1, logTheta)
    
    EqlogQx = 0.0
    for j in xrange(J):
        EqlogQx += -0.5*np.log( linalg.det(sig2x[:,:,j]) )
    
    EqlogQtheta = gammaln(gam0) - np.sum(gammaln(gam_i)) + np.dot(gam_i-1, logTheta)
    
    #logging.debug( (EqlogPy, EqlogPx, EqlogPtheta, -EqlogQx, -EqlogQtheta) )
    return EqlogPy + EqlogPx + EqlogPtheta - EqlogQx -EqlogQtheta

def neg_varlap_elbo_i(loggam_i, y_i, x, sig2x, alpha, lam, s2):
    gam_i = np.exp(loggam_i)
    return -varlap_elbo_i(y_i, x, sig2x, gam_i, alpha, lam, s2)

def opt_gam_i(args):
    gam_i, y_i, x, sig2x, alpha, lam, s2 = args
    K = gam_i.shape[0]

    bnds = [(-7, 7)] * K  # bounds on log(gamma); keeps gamma roughly in [1e-3, 1e3]

    res = so.minimize(neg_varlap_elbo_i, np.log(gam_i), \
        args=(y_i, x, sig2x, alpha, lam, s2), bounds=bnds, \
        method='L-BFGS-B')
    

    #res = so.minimize(neg_varlap_elbo_i, np.log(gam_i), \
    #    args=(y_i, x, sig2x, alpha, lam, s2), method='Nelder-Mead')
    #if res.success == False:
    #    res = so.minimize(neg_varlap_elbo_i, np.log(gam_i), \
    #        args=(y_i, x, sig2x, alpha, lam, s2), method='BFGS')
    if not res.success or np.any(np.isnan(res.x)):
        logging.warning("Could not optimize gamma or gamma is NaN.")
        gam_i = np.random.uniform(low=0.1, high=100, size=K)
        return gam_i

    gam_i = np.exp(res.x)
    return gam_i
        
    
def opt_gam(y, x, sig2x, gam, alpha, lam, s2, pool=None):
    """ Return the optimized Dirichlet variational parameter. """
    (J,N) = np.shape(y)
    
    st = time.time() # time.clock() does not work in multi-threading
    if pool is not None:
        args = zip( gam.T, y.T, repeat(x, N), repeat(sig2x, N), repeat(alpha, N), repeat(lam, N), repeat(s2, N) )
        b = pool.map(opt_gam_i, args)
        gam = np.array(b).T
    else:
        for i in xrange(N):
            args = (gam[:,i], y[:,i], x, sig2x, alpha, lam, s2)
            gam[:,i] = opt_gam_i(args)
        
    logging.debug('Gamma update elapsed time is %0.3f sec for %d samples.' % (time.time() - st, N))
    
    return gam

def opt_x_j(args):
    x_j, y_j, theta, theta2, lam, s2, var = args
    K = x_j.shape[0]
    
    res = so.minimize(neg_f_x_j, x_j, args=(theta, theta2, y_j, s2, lam), method='Nelder-Mead')
    if res.success == False:
        res = so.minimize(neg_f_x_j, x_j, args=(theta, theta2, y_j, s2, lam), method='BFGS')
    if res.success == False:
        logging.warning('Could not optimize objective for feature.' )
        x_j = np.sqrt(var)*np.random.randn(1,K)        
        return x_j[0,:] 

    return res.x

def opt_x(y, x, sig2x, gam, alpha, lam, s2, pool=None):
    (J,N) = np.shape(y)
    
    theta = EqTheta(gam)
    theta2 = EqTheta2(gam)
    #logTheta = EqlogTheta(gam)
    
    # Pass the variance of y to the feature optimizer; it is only used as the scale
    # of the random restart when optimization of a feature fails.
    var = np.var(y)
    
    st = time.time()
    if pool is not None:
        args = zip( x, y, repeat(theta,J), repeat(theta2,J), repeat(lam,J), repeat(s2,J), repeat(var,J) )
        b = pool.map(opt_x_j, args)
        x = np.array(b)
    else:
        for j in xrange(J):
            args = ( x[j,:], y[j,:], theta, theta2, lam, s2, var )
            b = opt_x_j(args)
            x[j,:] = np.array(b)
    
    # Compute the covariance for x
    # sig2x is the same for all j since d2f = sum_i Etheta2_i / s2
    for j in xrange(J):
        del2fxj = del2_f_x_j(x[j,:], theta, theta2, y[j,:], s2, lam)
        sig2x[:,:,j] = -linalg.inv(del2fxj)
            
    logging.debug('x update elapsed time is %0.3f sec for %d features.' % (time.time() - st, J))
    
    return (x, sig2x)

def opt_sigma2(y, x, sig2x, gam):
    """ Maximize the Expected complete log-likelihood with respect to lambda """
    
    (J, N) = y.shape
    
    x2 = EqX2(x, sig2x)
    
    theta = EqTheta(gam)
    theta2 = EqTheta2(gam)
    
    s2 = 0.0
    for i in xrange(N):
        for j in xrange(J):
            s2 += pow(y[j,i],2) -2*y[j,i]*np.dot(x[j,:],theta[:,i]) +np.trace(np.dot(x2[:,:,j], theta2[:,:,i]))
    s2 *= (1.0/(N*J))

    if s2 < 1e-6:
        s2 = 1e-6
        logging.warning("s2 optimized to less than 1e-6. s2 set to 1e-6.")
    return s2
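
# Worked derivation (sketch): the terms of the expected complete log-likelihood that
# involve s2 are  -0.5*N*J*log(2*pi*s2) - (0.5/s2)*SSE,  where
#   SSE = sum_{i,j} E_q[ (y[j,i] - x[j,:] . theta[:,i])^2 ].
# Setting the derivative with respect to s2 to zero gives the closed form used above:
#   s2 = SSE / (N*J).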

def opt_lambda(x, sig2x):
    """ Maximize ELBO w.r.t. lambda """
    (K, J) = x.shape
    return (1.0/(K*J)) * np.sum( np.sum( EqabsX(x, sig2x) ) )
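
# Worked derivation (sketch): the ELBO terms involving lam are
#   -J*K*log(2*lam) - (1/lam) * sum_{j,k} E_q|x[j,k]|.
# Setting the derivative with respect to lam to zero gives the closed form above:
#   lam = (1/(J*K)) * sum_{j,k} E_q|x[j,k]|.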

def varlap_opt(y, K, phi=None, q=None, seed=None, pool=None):
    if seed is not None: np.random.seed(seed=seed)
    
    h5file = tempfile.NamedTemporaryFile(suffix='.hdf5')
    logging.info('Storing model updates in %s' % h5file.name)
    
    MAXVARITER = 10   # max inner (variational E-step) iterations
    NORMTOL = 0.1     # tolerance on ||x - x_prev|| and ||gam - gam_prev||
    MAXITER = 20      # max outer (EM) iterations
    ELBOTOLPCT = 0.1  # tolerance on the percent change in the ELBO
    
    (J, N) = y.shape
    
    # Initialize model parameters
    if phi is None:
        alpha = np.array( [10] * K )
        lam = 1.0
        s2 = pow(0.25, 2)
    else:
        alpha = phi['alpha']
        lam = phi['lam']
        s2 = phi['s2']
    
    # Initialize variational parameters
    if q is None:
        var = np.var(y)
        x = np.sqrt(var) * np.random.randn(J,K)
        gam = np.random.uniform(low=0.1, high=100, size=(K,N))
        sig2x = np.zeros([K,K,J])
        for j in xrange(J): sig2x[:,:,j] = np.identity(K)
    else:
        gam = q['gam']
        x = q['x']
        sig2x = q['sig2x']
    
    phi = {'alpha':alpha, 'lam':lam, 's2':s2}
    q = {'x':x, 'sig2x':sig2x, 'gam':gam}
    #save_model('initial_value.hdf5', y, phi, q)    
    
    # Initial ELBO
    elbo = [varlap_elbo(y, x, sig2x, gam, alpha, lam, s2)]
    logging.info("Initial ELBO: %0.2f" % elbo[-1])
    
    moditer, delta_elbo_pct = (0, np.inf)
    while moditer < MAXITER and delta_elbo_pct > ELBOTOLPCT:
        # E-step: Update the variational distribution
        variter = 0
        var_elbo = [ elbo[-1] ]
        delta_varelbo_pct = np.inf
        (delta_norm_x, delta_norm_gam) = (np.inf, np.inf)
        while variter < MAXVARITER \
                and (delta_norm_x > NORMTOL or delta_norm_gam > NORMTOL) \
                and delta_varelbo_pct > ELBOTOLPCT:
                
            # Store the previous parameter values
            (gam_prev, x_prev) = (np.copy(gam), np.copy(x))
        
            # Update the variational distribution
            (x, sig2x) = opt_x(y, x, sig2x, gam, alpha, lam, s2, pool=pool)
            gam = opt_gam(y, x, sig2x, gam, alpha, lam, s2, pool=pool)

            # Test for convergence
            var_elbo.append(varlap_elbo(y, x, sig2x, gam, alpha, lam, s2))
            delta_varelbo_pct = 100*(var_elbo[-1] - var_elbo[-2])/abs(var_elbo[-2])
            logging.info("Variational Step ELBO: %0.2f; Percent Change: %0.3f%%" % (var_elbo[-1], delta_varelbo_pct))
        
            delta_norm_gam = linalg.norm(gam - gam_prev)
            delta_norm_x = linalg.norm(x - x_prev)
            logging.debug("||x - x_prev|| = %0.2f; ||gam - gam_prev|| = %0.2f" \
                        % (delta_norm_x, delta_norm_gam))
                    
            variter += 1 # increment the variational iteration count
        
        # M-step: Update model parameters
        s2 = opt_sigma2(y, x, sig2x, gam)
        
        lam = opt_lambda(x, sig2x)
        
        res = so.minimize(neg_elbo_alpha, np.log(alpha), args=(y, x, sig2x, gam, lam, s2), method='Nelder-Mead')
        alpha = np.exp( res.x )
        
        #ibic = bic(y, x, sig2x, gam, alpha, lam, s2)
        iL2 = withinL2(gam)
        elbo.append(varlap_elbo(y, x, sig2x, gam, alpha, lam, s2))
        delta_elbo_pct = 100*(elbo[-1] - elbo[-2])/abs(elbo[-2])
        
        moditer += 1
        
        # Display results for debugging
        logging.info("Iteration %d of %d." % (moditer, MAXITER))
        logging.info("ELBO: %0.2f; L2: %0.2f; Percent Change: %0.2f%%" \
                    % (elbo[-1], iL2, delta_elbo_pct))
        
        logging.info("s2 = %0.2e" % s2)
        logging.info("lam = %0.2f" % lam)
        logging.info("alpha = %s" % alpha)

        # Store the model for viewing
        phi = {'alpha':alpha, 'lam':lam, 's2':s2}
        q = {'x':x, 'sig2x':sig2x, 'gam':gam}
        save_model(h5file.name, y, phi, q)
        
    phi = {'alpha':alpha, 'lam':lam, 's2':s2}
    q = {'x':x, 'sig2x':sig2x, 'gam':gam}
    
    return (phi, q)
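
# Example (illustrative sketch): fitting with a multiprocessing pool. The pool is used
# only inside opt_x and opt_gam via pool.map; processes=4 is an arbitrary choice.
#   pool = mp.Pool(processes=4)
#   (phi, q) = varlap_opt(y, K=3, seed=1, pool=pool)
#   pool.close()
#   pool.join()
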
def withinL2(gam):
    """ Return the summed L2 distance of each E_q[theta_i] from the uniform distribution. """
    (K,N) = gam.shape
    
    theta = EqTheta(gam)
    u = np.array([1.0/K]*K)
    s = 0.0
    for i in xrange(N):
        s += linalg.norm(theta[:,i]-u, ord=2)

    return s

def bic(y, phi, q):
    """ Return the Bayesian Information Criterion for the model."""
    
    (J,N) = y.shape
    K = q['x'].shape[1]
    df = K*J + K*N + K + 1 + 1  # df(x) + df(gam) + df(alpha) + df(lam) + df(s2)
    return -2 * varlap_elbo(y, q['x'], q['sig2x'], q['gam'], 
                    phi['alpha'], phi['lam'], phi['s2']) + df*np.log(N)
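
# Example (illustrative sketch): comparing candidate values of K by BIC after fitting
# each model; lower is better under the usual BIC convention.
#   scores = {}
#   for k in (2, 3, 4):
#       (phi_k, q_k) = varlap_opt(y, K=k, seed=1)
#       scores[k] = bic(y, phi_k, q_k)
#   best_K = min(scores, key=scores.get)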
    
    
if __name__ == "__main__":
    main()
