from pyctp import saftvrmie
from scipy.optimize import minimize
import numpy as np
import warnings

class NumericSoret:
    '''
    Numerical estimate of Soret coefficients using a two-bulb model.

    A total mole number ``nT`` with overall composition ``x`` and density
    ``rho`` (so the total volume is ``nT / rho``) is split between bulb A at
    ``T - dT/2`` and bulb B at ``T + dT/2``. The bulb volumes satisfy
    ``V_A / V_B = bulb_ratio``. The split of mole numbers is found by
    minimising the summed Helmholtz energy of the two bulbs, optionally
    subject to an equality constraint.
    '''

    def __init__(self, comps, dT=0.1, nT=1e3, bulb_ratio=1):
        '''
        :param comps: Component identifier(s) passed on to ``saftvrmie.init``.
        :param dT: Temperature difference between bulb B and bulb A.
        :param nT: Total mole number in the two bulbs.
        :param bulb_ratio: r = V_A / V_B
        '''
        self.eos = saftvrmie.saftvrmie()
        self.eos.init(comps)
        self.dT = dT  # Can safely be changed without using set-function
        self.nT = nT  # Can safely be changed without using set-function
        self.bulb_ratio = bulb_ratio  # Can safely be changed without using set-function

    def set_HS(self, val):
        '''
        Toggle the chain, a1, a2 and a3 contributions of the EOS.

        The EOS model controls receive ``not val``. Presumably ``val=True``
        therefore leaves only the hard-sphere part active; confirm against
        the EOS documentation.

        :param val: Truthy to switch the contributions off, falsy to switch them on.
        '''
        self.eos.model_control_chain(not val)
        self.eos.model_control_a3(not val)
        self.eos.model_control_a2(not val)
        self.eos.model_control_a1(not val)

    def set_bulb_values(self, T, rho):
        '''
        Compute and store the total volume and the bulb temperatures and volumes.

        Sets ``V = nT / rho``, ``T_A = T - dT/2``, ``T_B = T + dT/2`` and
        splits ``V`` so that ``V_A / V_B = bulb_ratio`` and ``V_A + V_B = V``.
        '''
        self.V = self.nT / rho
        self.T_A = T - 0.5 * self.dT
        self.T_B = T + 0.5 * self.dT
        self.V_A = self.V * self.bulb_ratio / (self.bulb_ratio + 1)
        self.V_B = self.V / (self.bulb_ratio + 1)

    @staticmethod
    def _split_moles(n, dn):
        '''Return (nA, nB) = ((n - dn) / 2, (n + dn) / 2).'''
        return (n - dn) / 2, (n + dn) / 2

    def _bulb_volume(self, T_bulb, p, n_bulb):
        '''Total volume of a bulb at (T_bulb, p), computed as sum_i n_i * dV/dn_i.'''
        # Euler's theorem: V = sum_i n_i * (partial molar volume)_i
        _, dvdn = self.eos.specific_volume(T_bulb, p, n_bulb / sum(n_bulb), 1, dvdn=True)
        return sum(dvdn * n_bulb)

    def _bulb_mole_numbers(self, T, rho, x, constraint='p'):
        '''Solve for the split with ``get_dn`` and return the bulb mole numbers (nA, nB).'''
        n = self.nT * np.asarray(x, dtype=float)
        dn = self.get_dn(T, rho, x, constraint=constraint)
        return self._split_moles(n, dn)

    def get_dn(self, T, rho, x, constraint='p'):
        '''
        Find the split of mole numbers between the bulbs that minimises the total
        Helmholtz energy at the fixed bulb temperatures and volumes.

        :param T: Mean temperature. Bulb A is at T - dT/2 and bulb B at T + dT/2.
        :param rho: Overall density (nT / V).
        :param x: Overall composition. The total mole numbers are n = nT * x.
        :param constraint: Can be 'p' for equal pressure, 'V' for equal volume, 'V_bulb' for both volumes equal to V_bulb,
            or None for an unconstrained minimisation.
        :return: dn, where nA = (n - dn) / 2 and nB = (n + dn) / 2.
        :raises KeyError: If ``constraint`` is not one of the accepted values.

        Emits a ``UserWarning`` if the minimiser reports failure.
        '''
        self.set_bulb_values(T, rho)
        n = self.nT * np.asarray(x, dtype=float)
        split = self._split_moles

        def helmholtz(dn):
            nA, nB = split(n, dn)
            A_A, = self.eos.helmholtz_tv(self.T_A, self.V_A, nA)
            A_B, = self.eos.helmholtz_tv(self.T_B, self.V_B, nB)
            return A_A + A_B

        def helmholtz_jacobian(dn):
            # dnA/d(dn) = -1/2 and dnB/d(dn) = +1/2, hence the factor 0.5.
            nA, nB = split(n, dn)
            mu_A, = self.eos.chemical_potential_tv(self.T_A, self.V_A, nA)
            mu_B, = self.eos.chemical_potential_tv(self.T_B, self.V_B, nB)
            return 0.5 * (mu_B - mu_A)

        def equal_pressure_condition(dn):
            nA, nB = split(n, dn)
            pA, = self.eos.pressure_tv(self.T_A, self.V_A, nA)
            pB, = self.eos.pressure_tv(self.T_B, self.V_B, nB)
            return pA - pB

        if constraint == 'p':
            constraints = {'type': 'eq', 'fun': equal_pressure_condition}
        elif constraint == 'V' or constraint == 'V_bulb':
            # NOTE(review): the original code used an undefined pressure ``p``.
            # The pressure of the undivided system at (T, V, n) is used as the
            # common pressure at which the bulb volumes are evaluated. Confirm
            # that this is the intended reference state.
            p_ref, = self.eos.pressure_tv(T, self.V, n)

            if constraint == 'V':
                def equal_volume_condition(dn):
                    nA, nB = split(n, dn)
                    return (self._bulb_volume(self.T_A, p_ref, nA)
                            - self._bulb_volume(self.T_B, p_ref, nB))

                constraints = {'type': 'eq', 'fun': equal_volume_condition}
            else:
                def bulb_volume_condition_A(dn):
                    nA, _ = split(n, dn)
                    return self._bulb_volume(self.T_A, p_ref, nA) - self.V_A

                def bulb_volume_condition_B(dn):
                    _, nB = split(n, dn)
                    return self._bulb_volume(self.T_B, p_ref, nB) - self.V_B

                constraints = ({'type': 'eq', 'fun': bulb_volume_condition_A},
                               {'type': 'eq', 'fun': bulb_volume_condition_B})
        elif constraint is None:
            constraints = ()  # scipy's documented "no constraints" default
        else:
            raise KeyError("Constraint must be either 'p', 'V', 'V_bulb' or None but was "+str(constraint))

        # Initial guess: equal density in both bulbs (nA / V_A == nB / V_B == n / V).
        init_guess = n * (1 - self.bulb_ratio) / (1 + self.bulb_ratio)
        # The bounds keep both nA and nB non-negative.
        sol = minimize(helmholtz, x0=init_guess, jac=helmholtz_jacobian,
                       bounds=tuple((-ni, ni) for ni in n), constraints=constraints)

        # ``sol.success`` may be a numpy.bool_ (e.g. from SLSQP), so an identity
        # check against False would never trigger. Use truthiness instead.
        if not sol.success:
            warnings.warn('Numeric minimizer did not converge for (T, rho, x) = '+str(T)+', '+str(rho)+', '+str(x),
                          stacklevel=2)
        return sol.x

    def get_dx(self, T, rho, x, constraint='p'):
        '''
        Return the composition difference xB - xA between the hot and cold bulbs.

        Parameters are as for ``get_dn``.
        '''
        nA, nB = self._bulb_mole_numbers(T, rho, x, constraint=constraint)
        return nB / sum(nB) - nA / sum(nA)

    def get_Soret(self, T, rho, x, constraint='p'):
        '''
        Return -dx / (x * (1 - x) * dT), where dx = xB - xA from ``get_dx``.

        Parameters are as for ``get_dn``.
        '''
        x = np.asarray(x, dtype=float)
        dx = self.get_dx(T, rho, x, constraint=constraint)
        return - dx / (x * (1 - x) * self.dT)

    def get_pressure(self, T, rho, x, constraint='p'):
        '''
        Return the mean of the two bulb pressures after the split has been solved.

        ``constraint`` is now honoured. Previously the split was re-solved with
        the default 'p' constraint.
        '''
        nA, nB = self._bulb_mole_numbers(T, rho, x, constraint=constraint)
        pA = self.eos.pressure_tv(self.T_A, self.V_A, nA)[0]
        pB = self.eos.pressure_tv(self.T_B, self.V_B, nB)[0]
        return 0.5 * (pA + pB)

    def kempers89_condition(self, T, rho, x, constraint='p'):
        '''
        Evaluate a residual at the numerically solved two-bulb split.

        For every component i except the last (N), the residual is
            mu_A,i/T_A - mu_B,i/T_B
            - ((vA_i + vB_i) / (vA_N + vB_N)) * (mu_A,N/T_A - mu_B,N/T_B),
        where v are partial molar volumes evaluated at the mean bulb pressure.
        Presumably this is the condition from Kempers (1989), which is ~0 when
        the split satisfies it. Confirm against the reference.

        :param constraint: Constraint passed to ``get_dn`` (default 'p', as before).
        :return: Array of length N - 1.
        '''
        nA, nB = self._bulb_mole_numbers(T, rho, x, constraint=constraint)
        xA = nA / sum(nA)
        xB = nB / sum(nB)

        mu_A, = self.eos.chemical_potential_tv(self.T_A, self.V_A, nA)
        mu_B, = self.eos.chemical_potential_tv(self.T_B, self.V_B, nB)

        pA, = self.eos.pressure_tv(self.T_A, self.V_A, nA)
        pB, = self.eos.pressure_tv(self.T_B, self.V_B, nB)
        p = 0.5 * (pA + pB)  # To minimize the effects of numerical error

        _, vA = self.eos.specific_volume(self.T_A, p, xA, 1, dvdn=True)
        _, vB = self.eos.specific_volume(self.T_B, p, xB, 1, dvdn=True)

        eq_set = (mu_A[:-1] / self.T_A) - (mu_B[:-1] / self.T_B) - ((vA[:-1] + vB[:-1]) / (vA[-1] + vB[-1])) * ((mu_A[-1] / self.T_A) - (mu_B[-1] / self.T_B))
        return eq_set
        