from pyctp import saftvrmie
from cpp.hs_eos.carnahanstarling import CarnahanStarling
from pykingas import KineticGas
from integrator import eos_differentials as ed
import numpy as np
from scipy.optimize import root
from scipy.constants import Boltzmann, Avogadro, gas_constant as R
import warnings

class Kempers_HS:
    """Soret coefficients from a Kempers-type equation set with a reduced reference EoS.

    The thermodynamic terms of the full SAFT-VR Mie equation of state are
    balanced against the same terms evaluated with the chain, a1, a2 and a3
    contributions switched off (called the "HS" EoS in this file), combined
    with kinetic-gas values that serve as reference coefficients.

    Both EoS variants share the single ``self.eos`` object; ``real_eos`` and
    ``HS_eos`` toggle its model contributions in place.
    """

    def __init__(self, comps, sigma=None):
        """Set up the EoS and the kinetic-gas model for the components ``comps``.

        ``sigma`` is forwarded unchanged to ``KineticGas``.
        """
        self.comps = comps
        self.eos = saftvrmie.saftvrmie()
        self.eos.init(comps)
        self.kin = KineticGas(comps, sigma=sigma, BH=True)

    @staticmethod
    def _kempers_terms(T, x, dvdn, S, mu, dmudx_Tp, ST_list):
        """Return the Kempers residual terms for components 0..N-2 of one EoS.

        For each independent component ``i`` (all but the last) this evaluates

            dvdn[i] * (S[-1] + mu[-1] / T) - dvdn[-1] * (S[i] + mu[i] / T)
            + sum_j (dvdn[i] * dmudx_Tp[-1, j] - dvdn[-1] * dmudx_Tp[i, j])
                    * x[j] * (1 - x[j]) * ST_list[j]

        Args:
            T: Temperature.
            x: Composition vector, length N.
            dvdn, S, mu: Per-component vectors, length N.
            dmudx_Tp: Matrix indexed ``[i, j]``; assumed to be N x N.
            ST_list: Soret coefficients, length N.

        Returns:
            ndarray of length N - 1, one entry per independent component.
        """
        x = np.asarray(x, dtype=float)
        dvdn = np.asarray(dvdn, dtype=float)
        S = np.asarray(S, dtype=float)
        mu = np.asarray(mu, dtype=float)
        dmudx_Tp = np.asarray(dmudx_Tp, dtype=float)
        ST_list = np.asarray(ST_list, dtype=float)

        reduced = S + mu / T
        static = dvdn[:-1] * reduced[-1] - dvdn[-1] * reduced[:-1]
        # Row i holds the j-dependent coefficients of component i. Each row
        # contributes only to its own equation (not to every equation).
        coupling = np.outer(dvdn[:-1], dmudx_Tp[-1]) - dvdn[-1] * dmudx_Tp[:-1]
        weights = x * (1 - x) * ST_list
        return static + coupling.dot(weights)

    def get_Soret(self, T, rho, x):
        """Solve the Kempers equation set for the Soret coefficients.

        Args:
            T: Temperature.
            rho: Density; the volume used is ``1 / rho`` for one unit of
                total amount. Units are whatever the EoS and kinetic-gas
                model expect (not visible here).
            x: Composition vector (array-like).

        Returns:
            ndarray with one coefficient per component. All entries are NaN
            if the UV-flash reports phase flag 0 or if the solver does not
            converge (a warning is issued in the latter case).

        Note:
            ``self.eos`` is left in the reduced ("HS") configuration on return.
        """
        # Float array so that residuals are never truncated to integers and
        # list input works with the element-wise arithmetic below.
        x = np.asarray(x, dtype=float)
        ST_list_HS = np.asarray(self.kin.alpha_T0(T, 1 / rho, x, N=4) / T, dtype=float)
        nT = 1
        V = nT / rho
        n = nT * x

        # --- full EoS ---
        eos = self.real_eos()
        p, = eos.pressure_tv(T, V, n)
        u, = eos.internal_energy_tv(T, V, n)
        phase = eos.two_phase_uvflash(x, u, V)[-1]
        if phase == 0:
            # presumably flag 0 means a two-phase state - confirm against the EoS library
            return np.full(np.shape(ST_list_HS), np.nan)
        mu, = eos.chemical_potential_tv(T, V, n)
        dmudx_Tp = ed.dmudx_TP(eos, T, V, n, phase)
        _, S = eos.entropy(T, p, x, phase, dsdn=True)
        _, dvdn = eos.specific_volume(T, p, x, phase, dvdn=True)

        # --- reduced ("HS") EoS; same object, contributions switched off ---
        # NOTE(review): the calls below mix the full-EoS `p`/`phase` with
        # `p_HS`/`phase_HS`. Kept exactly as in the original - confirm which
        # pressure and phase flag are intended for each call.
        eos = self.HS_eos()
        p_HS, = eos.pressure_tv(T, V, n)
        phase_HS = eos.guess_phase(T, p, x)
        mu_HS, = eos.chemical_potential_tv(T, V, n)
        dmudx_Tp_HS = ed.dmudx_TP(eos, T, V, n, phase)
        _, S_HS = eos.entropy(T, p, x, phase, dsdn=True)
        _, dvdn_HS = eos.specific_volume(T, p_HS, x, phase_HS, dvdn=True)

        # The reference side depends only on ST_list_HS, so evaluate it once
        # instead of on every solver iteration.
        hs_terms = self._kempers_terms(T, x, dvdn_HS, S_HS, mu_HS, dmudx_Tp_HS, ST_list_HS)
        closure_weights = x * (1 - x)

        def eq_set(ST_list):
            eqs = np.zeros_like(x)
            # Closure relation for the last equation.
            eqs[-1] = np.sum(ST_list * closure_weights)
            eqs[:-1] = self._kempers_terms(T, x, dvdn, S, mu, dmudx_Tp, ST_list) - hs_terms
            return eqs

        sol = root(eq_set, ST_list_HS)
        if not sol.success:
            msg = '\n'.join([
                'Kempers eq. set did not converge for',
                'T = ' + str(T),
                'x = ' + str(x),
                'rho = ' + str(rho),
            ])
            warnings.warn(msg)
            # Same container type as the converged result.
            return np.full(np.shape(sol.x), np.nan)

        return sol.x

    def HS_eos(self):
        """Switch off the chain, a3, a2 and a1 contributions and return ``self.eos``."""
        self.eos.model_control_chain(False)
        self.eos.model_control_a3(False)
        self.eos.model_control_a2(False)
        # NOTE(review): critical parameters are redefined before a1 is turned
        # off; call order preserved from the original - confirm it is intended.
        self.eos.redefine_critical_parameters(silent=False)
        self.eos.model_control_a1(False)
        return self.eos

    def real_eos(self):
        """Switch on the a1, a2, a3 and chain contributions and return ``self.eos``."""
        self.eos.model_control_a1(True)
        self.eos.model_control_a2(True)
        self.eos.model_control_a3(True)
        self.eos.model_control_chain(True)
        self.eos.redefine_critical_parameters(silent=False)
        return self.eos

    def get_HS_values(self, T, rho, x):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def get_real_values(self, T, rho, x):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

class Kempers89:
    """Soret coefficients from a Kempers-type equation set using a single EoS.

    The equation set is built from SAFT-VR Mie properties only; the
    kinetic-gas Soret coefficients are used as the initial guess for the
    solver. ``set_HS`` toggles the chain, a1, a2 and a3 contributions of the
    EoS in place.
    """

    def __init__(self, comps):
        """Set up the EoS and the kinetic-gas model for the components ``comps``."""
        self.comps = comps
        self.eos = saftvrmie.saftvrmie()
        self.eos.init(comps)
        self.kin = KineticGas(comps)

    def set_HS(self, val):
        """Select the EoS variant.

        ``val=True`` switches the chain, a3, a2 and a1 contributions off (the
        "HS" EoS); ``val=False`` switches them on (full SAFT-VR Mie).
        """
        enabled = not val
        self.eos.model_control_chain(enabled)
        self.eos.model_control_a3(enabled)
        self.eos.model_control_a2(enabled)
        self.eos.model_control_a1(enabled)

    @staticmethod
    def _kempers_terms(T, x, dvdn, S, mu, dmudx_Tp, ST_list):
        """Return the Kempers residual terms for components 0..N-2.

        For each independent component ``i`` (all but the last) this evaluates

            dvdn[i] * (S[-1] + mu[-1] / T) - dvdn[-1] * (S[i] + mu[i] / T)
            + sum_j (dvdn[i] * dmudx_Tp[-1, j] - dvdn[-1] * dmudx_Tp[i, j])
                    * x[j] * (1 - x[j]) * ST_list[j]

        Args:
            T: Temperature.
            x: Composition vector, length N.
            dvdn, S, mu: Per-component vectors, length N.
            dmudx_Tp: Matrix indexed ``[i, j]``; assumed to be N x N.
            ST_list: Soret coefficients, length N.

        Returns:
            ndarray of length N - 1, one entry per independent component.
        """
        x = np.asarray(x, dtype=float)
        dvdn = np.asarray(dvdn, dtype=float)
        S = np.asarray(S, dtype=float)
        mu = np.asarray(mu, dtype=float)
        dmudx_Tp = np.asarray(dmudx_Tp, dtype=float)
        ST_list = np.asarray(ST_list, dtype=float)

        reduced = S + mu / T
        static = dvdn[:-1] * reduced[-1] - dvdn[-1] * reduced[:-1]
        # Row i holds the j-dependent coefficients of component i. Each row
        # contributes only to its own equation (not to every equation).
        coupling = np.outer(dvdn[:-1], dmudx_Tp[-1]) - dvdn[-1] * dmudx_Tp[:-1]
        weights = x * (1 - x) * ST_list
        return static + coupling.dot(weights)

    def get_Soret(self, T, rho, x):
        """Solve the Kempers equation set for the Soret coefficients.

        Args:
            T: Temperature.
            rho: Density; the volume used is ``1 / rho`` for one unit of
                total amount. Units are whatever the EoS and kinetic-gas
                model expect (not visible here).
            x: Composition vector (array-like).

        Returns:
            ndarray with one coefficient per component. All entries are NaN
            if the solver does not converge (a warning is issued as well).
        """
        # Float array so that residuals are never truncated to integers and
        # list input works with the element-wise arithmetic below.
        x = np.asarray(x, dtype=float)
        # Kinetic-gas values; used only as the starting point for the solver.
        ST_list_HS = self.kin.soret(T, 1 / rho, x, N=4)
        nT = 1
        V = nT / rho
        n = nT * x
        eos = self.eos

        p, = eos.pressure_tv(T, V, n)
        phase = eos.guess_phase(T, p, x)
        mu, = eos.chemical_potential_tv(T, V, n)
        dmudx_Tp = ed.dmudx_TP(eos, T, V, n, phase)
        _, S = eos.entropy(T, p, x, phase, dsdn=True)
        _, dvdn = eos.specific_volume(T, p, x, phase, dvdn=True)

        closure_weights = x * (1 - x)

        def eq_set(ST_list):
            eqs = np.zeros_like(x)
            # Closure relation for the last equation.
            eqs[-1] = np.sum(ST_list * closure_weights)
            eqs[:-1] = self._kempers_terms(T, x, dvdn, S, mu, dmudx_Tp, ST_list)
            return eqs

        sol = root(eq_set, ST_list_HS)
        if not sol.success:
            msg = '\n'.join([
                'Kempers eq. set did not converge for',
                'T = ' + str(T),
                'x = ' + str(x),
                'rho = ' + str(rho),
            ])
            warnings.warn(msg)
            # Same container type as the converged result.
            return np.full(np.shape(sol.x), np.nan)

        return sol.x
