'''
Author: Vegard G. Jervell
Date: December 2020
Purpose: Implementation of L. J. T. M. Kempers' model for prediction of Soret coefficients in multicomponent
        mixtures, derived in "A comprehensive thermodynamic theory of the Soret effect in a multicomponent gas,
        liquid, or solid", 2001, Journal of Chemical Physics
        doi : http://dx.doi.org/10.1063/1.1398315
Requires : numpy, scipy, ThermoPack, KineticGas
Note : KineticGas is only 2-component compatible; it would have to be replaced by a module capable of predicting
        thermal diffusion coefficients for multicomponent ideal-gas mixtures before this model can be used
        beyond binary mixtures.
'''

import numpy as np
from scipy.constants import gas_constant
from pycThermopack.pyctp import cubic
from pycThermopack.pyctp import cpa
from scipy.optimize import root
from models.alpha_t0_empirical import Alpha_T0_empirical
from models.kineticgas import KineticGas

class Kempers01:
    def __init__(self, comps='NC10,NC24', x=(0.5, 0.5), eos_key='SRK',
                 temp=300, pres=1e5, phase=1, min_temp=10, total_moles=1, cpa_flag=False,
                 alpha_t0_empirical=False, alpha_t0_method='wheights', alpha_t0_N=7):
        '''
        Kempers 2001 model for Soret coefficients.

        Sets up the equation of state, the provider of ideal-gas thermal diffusion factors (alpha_T0)
        and binds self.get_soret to the solver used by the get_soret_* scan methods.

        :param comps (str): comma separated list of components
        :param x (array-like): mole fractions, one per component
        :param eos_key (str): key indicating what equation of state to use
        :param temp (float > 0): Temperature [K]
        :param pres (float > 0): Pressure [Pa]
        :param phase: Phase of mixture, used for calculating dmudn_TP, see thermo.thermopack for phase identifiers
        :param min_temp (float > 0): minimum temp for thermopack numerical solver
        :param total_moles (float): Total number of moles in system (should not effect results)
        :param cpa_flag (bool) : True if using cpa equation of state
        :param alpha_t0_empirical (bool): True if using empirical alpha_T0 values
        :param alpha_t0_method (str): what method to use for empirical alpha_T0 values (see Alpha_T0_empirical)
        :param alpha_t0_N (int > 0): Order of approximation when using analytical alpha_T0 values
        '''
        self.comps = comps
        # Always keep an independent float array so that caller-owned sequences are never shared.
        self.x = np.array(x, dtype=float)
        self.temp = temp
        self.pres = pres
        self.phase = phase
        self.min_temp = min_temp
        self.total_moles = total_moles

        # Truthiness test so that numpy booleans and 0/1 flags behave like plain bools.
        self.eos = cpa.cpa() if cpa_flag else cubic.cubic()
        self.eos.init(comps, eos_key)
        self.eos.set_tmin(min_temp)

        # Store every alpha_T0 setting unconditionally; reset_alpha_t0 reads whichever one it needs.
        self.alpha_T0_empirical = alpha_t0_empirical
        self.alpha_T0_method = alpha_t0_method
        self.alpha_T0_N = alpha_t0_N

        # Build the alpha_T0 provider through the same helper that the composition scan uses,
        # so that the initial provider is created with the current mole fractions as well.
        self.reset_alpha_t0()

        # The multicomponent solver is used for every system, binary mixtures included;
        # get_soret_binary_cov stays available for explicit calls.
        self.get_soret = self.get_soret_multi_cov


    def reset_alpha_t0(self):
        '''
        Rebuild self.kinetic_gas (the alpha_T0 provider) for the current mole fractions self.x.

        Uses Alpha_T0_empirical when self.alpha_T0_empirical is truthy, otherwise KineticGas.
        '''
        if not self.alpha_T0_empirical:
            # Analytical route: needs the equation of state and the approximation order.
            self.kinetic_gas = KineticGas(self.comps, self.eos, mole_fracs=self.x, N=self.alpha_T0_N)
            return

        # Empirical route: only components, composition and the chosen method are required.
        self.kinetic_gas = Alpha_T0_empirical(self.comps, mole_fracs=self.x, method=self.alpha_T0_method)

    def set_eos(self, eos_key, cpa_flag=False):
        '''
        Change equation of state

        A fresh equation of state object is created, initialised for self.comps and given the
        minimum solver temperature self.min_temp.

        :param eos_key (str): new equation of state key
        :param cpa_flag (bool): True if using cpa equation of state
        '''
        # Truthiness test (not `is True`) so that numpy booleans and 0/1 flags select cpa as well,
        # matching how the constructor interprets the same flag.
        self.eos = cpa.cpa() if cpa_flag else cubic.cubic()

        self.eos.init(self.comps, eos_key)
        self.eos.set_tmin(self.min_temp)
        # NOTE(review): self.kinetic_gas is not rebuilt here; when it is a KineticGas instance it was
        # constructed with the previous eos object. Confirm whether reset_alpha_t0() should follow.

    def get_soret_binary_cov(self):
        '''
        Get Soret Coefficient at current settings for binary system
        :return: (float) Soret Coefficient
        '''
        _, v = self.eos.specific_volume(self.temp, self.pres, self.x, self.phase, dvdn=True)
        _, h = self.eos.enthalpy(self.temp, self.pres, self.x, self.phase, dhdn=True)
        _, h0 = self.eos.enthalpy(self.temp, 1e-5, self.x, 2, dhdn=True)
        v1, v2 = v
        h1, h2 = h
        h10, h20 = h0
        x1, x2 = self.x
        R = gas_constant

        dmudx = self.dmudx_TP()
        alpha_T0 = self.kinetic_gas.alpha_T0(self.temp)[0]

        alpha_T1 = (v1 * (h2 - h20) - v2 * (h1 - h10))/((v1 * x1 + v2 * x2) * x1 * dmudx[0,0]) \
                   + R * self.temp * alpha_T0 / (x1 * dmudx[0,0])

        return np.array([alpha_T1 /self.temp, - alpha_T1 / self.temp])

    def get_soret_multi_cov(self):
        '''
        Get soret coefficient at current settings for multicomponent system, not implemented
        :return: (ndarray) soret coefficients
        '''

        R = gas_constant
        v, dvdn = self.eos.specific_volume(self.temp, self.pres, self.x, self.phase, dvdn=True)
        h, dhdn = self.eos.enthalpy(self.temp, self.pres, self.x, self.phase, dhdn=True)
        h0, dh0dn = self.eos.enthalpy(self.temp, 1e-5, self.x, 2, dhdn=True)

        dmudx = self.dmudx_TP()
        alpha_T0 = self.kinetic_gas.alpha_T0(self.temp)

        initial_guess = alpha_T0 * self.temp

        N = len(self.x)
        def eq_set(alpha):
            eqs = np.zeros(N)
            for i in range(N-1):
                eqs[i] = ((dhdn[-1] - dh0dn[-1])/dvdn[-1]) - ((dhdn[i] - dh0dn[i])/dvdn[i])\
                         + R * self.temp * ((alpha_T0[i] * (1 - self.x[i]) / dvdn[i])
                                            - (alpha_T0[-1] * (1 - self.x[-1])/dvdn[-1]))\
                         - sum((dmudx[i, j]/dvdn[i] - dmudx[-1, j]/dvdn[-1]) * self.x[j] * (1 - self.x[j]) * alpha[j]
                               for j in range(N - 1))

            eqs[N-1] = sum(self.x * (1 - self.x) * alpha)
            return eqs


        solved = root(eq_set, initial_guess)
        if solved.success is False:
            print('Solution did not converge for composition :', self.x, ', Temperature :', self.temp)
        alpha = solved.x

        soret = alpha / self.temp

        return soret

    def get_soret_comp(self, x):
        '''
        Args:
            x : array-like
                mole fractions of components, if N - 1 components are given, N is calculated implicitly
                rows are compositions, columns are components
                [[x1, x2, ..., xN],
                 [x1, x2, ..., xN]]

        return:
            tuple of floats or ndarrays, matching shape of input : soret coefficient(s) at given composition(s)
        '''
        old_mole_fracs = [frac for frac in self.x]  # take care of the values from initialization

        x = np.array(x)

        if len(x.shape) == 1:
            if x.shape[0] != len(self.x):
                xN = 1 - x
                x = np.concatenate((np.vstack(x), np.vstack(xN)), axis=1)
            elif x.shape[0] == len(self.x):
                x = [x]

        elif x.shape[1] == len(self.x) - 1:
            xN = 1 - np.sum(x, axis=1)
            x = np.concatenate((x, np.vstack(xN)), axis=1)
        elif x.shape[1] == len(self.x):
            pass
        else:
            raise IndexError('x must contain N-1 or N mole fractions')

        # Allocate some memory
        soret = np.empty(x.shape, float)

        for i, fracs in enumerate(x):
            self.x = fracs
            self.reset_alpha_t0()
            soret[i] = self.get_soret()

        self.x = np.array(old_mole_fracs)  # reset values from initialization

        return soret.transpose()

    def get_soret_temp(self, temps):
        '''
        Args:
            temps (int, float or array-like) : temperature(s) [K] to get soret-coefficients for

        return:
            tuple of floats or arrays, matching shape of input : Soret-coefficient(s) at given temperature(s)
        '''
        old_temp = self.temp

        if type(temps) in (list, np.ndarray):
            # Allocate some memory
            soret = np.empty((len(temps), len(self.x)))

            for i, T in enumerate(temps):
                self.temp = T
                soret[i] = self.get_soret()
        else:
            self.temp = temps
            soret = self.get_soret()

        self.temp = old_temp

        return soret.transpose()

    def get_soret_pres(self, pressures):
        '''
            Args:
                pressures (float or array-like) : pressure(s) [Pa] to get soret-coefficients for

            return:
                tuple of floats or arrays, matching shape of input : Soret-coefficient(s) at given pressure(s)
        '''
        old_pres = self.pres

        if type(pressures) in (list, np.ndarray):
            # Allocate some memory
            soret = np.empty((len(pressures), len(self.x)))

            for i, pres in enumerate(pressures):
                self.pres = pres
                soret[i] = self.get_soret()
        else:
            self.pres = pressures
            soret = self.get_soret()

        self.pres = old_pres

        return soret

    def dmudn_TP(self):
        #calculate dmudn at constant temperature and pressure

        v, dvdn = self.eos.specific_volume(self.temp, self.pres, self.x, self.phase, dvdn=True)
        mu, dmudn_TV = self.eos.chemical_potential_tv(self.temp, v * self.total_moles,
                                                      self.x * self.total_moles, dmudn=True)
        pres, dpdn = self.eos.pressure_tv(self.temp, v * self.total_moles, self.x * self.total_moles, dpdn=True)

        return dmudn_TV - np.tensordot(dpdn, dvdn, axes=0)

    def dmudx_TP(self):
        #calculate dmudx at constant temperature and pressure

        dmudn = self.dmudn_TP()

        M1 = (np.tensordot(np.ones(len(self.x)), -1 / self.x, axes=0) +
              (np.identity(len(self.x)) * (1 / self.x + 1 / (1 - self.x)))) * self.total_moles #* (1 / len(self.x))

        return np.dot(dmudn, M1)