'''
Author: Vegard G. Jervell
Date: December 2020
Purpose: Parent class containing some general procedures to be used in both Kempers89 and Kempers01
Requires: numpy, ThermoPack
Note: This is a virtual class, and will not do anything exciting if initialized on its own.
'''

import numpy as np
from pycThermopack.pyctp import cubic
from models.kineticgas import KineticGas
from scipy.constants import gas_constant
import warnings

class Kempers(object):
    def __init__(self, comps, eos, x=[0.5, 0.5], temp=300, pres=1e5, phase=1):
        '''
        KempersXX parent class, contains interface for retrieving soret-coefficient for spans of temperatures, pressures or compositions
        and some general initialization procedures that are common for the two KempersXX' models
        :param comps (str): comma separated list of components
        :param x (1darray): list of mole fractions
        :param eos (ThermoPack): Initialized Equation of State object, initialized with components 'comp'
        :param temp (float > 0): Temperature [K]
        :param pres (float > 0): Pressure [Pa]
        :param phase: Phase of mixture, used for calculating dmudn_TP, see thermo.thermopack for phase identifiers
        '''

        self.comps = comps
        self.x = np.array(x)
        self.temp = temp
        self.pres = pres
        self.phase = phase
        self.total_moles = 1 #Dummy value, because some ThermoPack methods require it. Does not effect output.

        eoscomp_inds = [eos.getcompindex(comp) for comp in self.comps.split(',')]
        if -1 in eoscomp_inds:
            warnings.warn('Equation of state and KempersXX must be initialized with same components.\n'
                          "I'm initializing using SRK with "+comps+" now to avoid crashing")
            self.eos = cubic.cubic()
            self.eos.init(self.comps, 'SRK')

        elif any(np.array(eoscomp_inds) != sorted(np.array(eoscomp_inds))):
            eoscomps = ','.join(eos.get_comp_name(i) for i in sorted(eoscomp_inds))
            warnings.warn('Equation of state and KempersXX must be initialized with same components in the same order\n'
                            'but are initialized with ' + eoscomps + ' and ' + self.comps+'.\n'
                            "I'm initializing using SRK with "+comps+" now to avoid crashing")
            self.eos = cubic.cubic()
            self.eos.init(self.comps, 'SRK')

        else:
            self.eos = eos

    #Some standard set-methods (use these! they handle some important stuff!)
    def set_min_temp(self, temp):
        self.min_temp = temp
        self.eos.set_tmin(temp)

    def set_temp(self,temp):
        self.temp = temp

    def set_pres(self, pres):
        self.pres = pres

    def set_mole_fracs(self, x):
        self.x = np.array(x)
        self.reset_alpha_t0()

    def set_eos(self, eos):
        '''
        Change equation of state
        :param eos_key (str): new equation of state key
        '''
        eoscomp_inds = [eos.getcompindex(comp) for comp in self.comps.split(',')]
        if -1 in eoscomp_inds:
            warnings.warn('Equation of state and KempersXX must have the same components.'
                          "I'm just not going to change anything!")
            return 1

        elif any(np.array(eoscomp_inds) != sorted(np.array(eoscomp_inds))):
            eoscomps = ','.join(eos.get_comp_name(i) for i in sorted(eoscomp_inds))
            warnings.warn('Equation of state and KempersXX must have the same components in the same order'
                          'but are given ' + eoscomps + ' and ' + self.comps + '.'
                            "I'm not changing anything!")
            return 1
        else:
            self.eos = eos
            return 0

    def set_comps(self, comps, eos):
        '''
        Change components
        :param comps: Comma separated list of components
        :param eos: Initialized equation of state, with same components as 'comp'
        '''
        #Use the check in self.set_eos() to determine if input is valid, only change if it is
        old_comps = self.comps
        self.comps = comps
        if self.set_eos(eos) == 0:
            pass
        else:
            self.comps = old_comps

    def reset_alpha_t0(self):
        #Overridden in Kempers01
        pass

    def get_soret_cov(self, kin=False):
        # Get soret coefficient at current settings, center of volume frame of reference
        # Overridden in Kempers01 and Kempers89
        pass

    def get_soret_com(self, kin=False):
        # Get soret coefficient at current settings, center of mass frame of reference
        # Overridden in Kempers01 and Kempers89
        pass

    def get_soret_avg(self, kin=False):
        # Get average soret coefficient computed with 'cov' and 'com' frame of reference
        return 0.5 * (self.get_soret_cov(kin=kin) + self.get_soret_com(kin=kin))

    def get_soret_comp(self, x, mode='cov', kin=False):
        '''
        Get soret-coefficients for a range of compositions
        Args:
            x : array-like
                mole fractions of components, if N - 1 components are given, N is calculated implicitly
                row i is composition i, column j is mole fraction of component j
                [[x1, x2, ..., xN], #composition 1
                 [x1, x2, ..., xN]] #composition 2

        return:
            tuple of floats or ndarrays, matching shape of input : soret coefficients at given composition(s)
        '''
        if mode == 'cov':
            get_soret = self.get_soret_cov
        elif mode == 'com':
            get_soret = self.get_soret_com
        elif mode == 'avg':
            get_soret = self.get_soret_avg
        else:
            warnings.warn("mode must be either 'cov', 'com' or 'avg', not "+str(mode)+" defaulting back to 'cov'")
            get_soret = self.get_soret_cov

        old_mole_fracs = [frac for frac in self.x]  # take care of the values from initialization

        x = np.array(x)

        if len(x.shape) == 1:
            if x.shape[0] != len(self.x):
                xN = 1 - x
                x = np.concatenate((np.vstack(x), np.vstack(xN)), axis=1)
            elif x.shape[0] == len(self.x):
                x = [x]

        elif x.shape[1] == len(self.x) - 1:
            xN = 1 - np.sum(x, axis=1)
            x = np.concatenate((x, np.vstack(xN)), axis=1)
        elif x.shape[1] == len(self.x):
            pass
        else:
            raise IndexError('x must contain N-1 or N mole fractions')

        if kin is True:
            soret = np.empty(x.shape, float)
            kin_contrib = np.empty(x.shape, float)
            for i, fracs in enumerate(x):
                self.x = fracs
                self.reset_alpha_t0()
                soret[i], kin_contrib[i] = get_soret(kin=kin)

            self.x = np.array(old_mole_fracs)  # reset values from initialization

            return soret.transpose(), kin_contrib.transpose()
        else:
            soret = np.empty(x.shape, float)
            for i, fracs in enumerate(x):
                self.x = fracs
                self.reset_alpha_t0()
                soret[i] = get_soret()

            self.x = np.array(old_mole_fracs)  # reset values from initialization

            return soret.transpose()

    def get_soret_temp(self, temps, mode='cov', kin=False):
        '''
        Get soret coefficients for a range of temperatures
        Args:
            temps (int, float or array-like) : temperature(s) [K] to get soret-coefficients for

        return:
            tuple of floats or arrays, matching shape of input : Soret-coefficients at given temperature(s)
        '''
        if mode == 'cov':
            get_soret = self.get_soret_cov
        elif mode == 'com':
            get_soret = self.get_soret_com
        elif mode == 'avg':
            get_soret = self.get_soret_avg
        else:
            warnings.warn("mode must be either 'cov', 'com' or 'avg', not "+str(mode)+" defaulting back to 'cov'")
            get_soret = self.get_soret_cov

        old_temp = self.temp

        if type(temps) in (list, np.ndarray):
            if kin is True:
                soret = np.empty((len(temps), len(self.x)))
                kin_contrib = np.empty((len(temps), len(self.x)))
                for i, T in enumerate(temps):
                    self.temp = T
                    soret[i], kin_contrib[i] = get_soret(kin=kin)
            else:
                soret = np.empty((len(temps), len(self.x)))
                for i, T in enumerate(temps):
                    self.temp = T
                    soret[i] = get_soret()

        else:
            self.temp = temps
            if kin is True:
                soret, kin_contrib = get_soret(kin=kin)
            else:
                soret = get_soret()

        self.temp = old_temp

        if kin is True:
            return soret.transpose(), kin_contrib.transpose()

        else:
            return soret.transpose()

    def get_soret_pres(self, pressures, mode='cov'):
        '''
        Get soret coefficients for a range of pressures
            Args:
                pressures (float or array-like) : pressure(s) [Pa] to get soret-coefficients for

            return:
                tuple of floats or arrays, matching shape of input : Soret-coefficients at given pressure(s)
        '''

        if mode == 'cov':
            get_soret = self.get_soret_cov
        elif mode == 'com':
            get_soret = self.get_soret_com
        elif mode == 'avg':
            get_soret = self.get_soret_avg
        else:
            warnings.warn("mode must be either 'cov', 'com' or 'avg', not "+str(mode)+" defaulting back to 'cov'")
            get_soret = self.get_soret_cov

        old_pres = self.pres #take care of initialization-value

        if type(pressures) in (list, np.ndarray):

            soret = np.empty((len(pressures), len(self.x)))
            for i, pres in enumerate(pressures):
                self.pres = pres
                soret[i] = get_soret()

        else:
            self.pres = pressures
            soret = get_soret()

        self.pres = old_pres #reset to initialization-value

        return soret

    def dmudn_TP(self):
        '''
        Calculate chemical potential derivative with respect to number of moles at constant temperature and pressure
        :return: ndarray, dmudn[i,j] = dmu_idn_j
        '''

        v, dvdn = self.eos.specific_volume(self.temp, self.pres, self.x, self.phase, dvdn=True)
        mu, dmudn_TV = self.eos.chemical_potential_tv(self.temp, v * self.total_moles,
                                                      self.x * self.total_moles, dmudn=True)
        pres, dpdn = self.eos.pressure_tv(self.temp, v * self.total_moles, self.x * self.total_moles, dpdn=True)

        return dmudn_TV - np.tensordot(dpdn, dvdn, axes=0)

    def dmudx_TP(self):
        '''
        Calculate chemical potential derivative with respect to mole fraction of components
        at constant temperature and pressure
        :return: ndarray, dmudx[i,j] = dmu_idn_j
        '''
        dmudn = self.dmudn_TP()

        M1 = (np.tensordot(np.ones(len(self.x)), -1 / self.x, axes=0) +
              (np.identity(len(self.x)) * (1 / self.x + 1 / (1 - self.x)))) * self.total_moles

        return np.dot(dmudn, M1)