from pycThermopack.pyctp import cubic
from models.kempers89 import Kempers89 as Kempers89
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import quad

def moles_to_fracs(moles):
    """Convert mole numbers to mole fractions.

    Args:
        moles: Mole numbers of each component (any array_like: list, tuple
            or numpy array).

    Returns:
        numpy.ndarray: ``moles`` divided by their total.
    """
    # asarray lets plain lists/tuples work; `list / float` would raise TypeError.
    moles = np.asarray(moles)
    return moles / moles.sum()

def check_vs_kempers():
    """Print temp * S_T,1 for a CYCLOHEX/C3 mixture with several cubic EoS.

    For each EoS key a Kempers89 model is built at temp=293, x=[0.55, 0.45],
    phase=1.  The first value returned by get_soret_cov() is then multiplied
    by the model temperature and printed.
    """
    comps = 'CYCLOHEX,C3'
    print('Thermal diffusion factors for', comps)

    for eos_key in ['SRK', 'PR', 'PT', 'VdW', 'SW']:
        kempers = Kempers89(comps=comps, temp=293, x=[0.55, 0.45], eos_key=eos_key, phase=1)
        soret_1, _ = kempers.get_soret_cov()
        print(eos_key, ' : ', soret_1 * kempers.temp)

def check_flash():
    """Print the phase type thermopack guesses for several C1/C3 mixtures (PR EoS).

    The state is temp=346.08 and press=5.6*1e7 (presumably K and Pa).  The C1
    mole fraction takes the values 0.34, 0.35, 0.42, 0.49, 0.51, 0.58 and 0.63.
    """
    eos = cubic.cubic()
    eos.init('C1,C3', 'PR')
    eos.set_tmin(50)

    temp = 346.08
    press = 5.6*1e7
    c1_fractions = np.array([34, 35, 42, 49, 51, 58, 63]) * 0.01

    for x_c1 in c1_fractions:
        guessed = eos.guess_phase(temp, press, [x_c1, 1 - x_c1])
        print(eos.get_phase_type(guessed))

def check_enthalpies(comp):
    """Print an SRK enthalpy, its mole-number derivatives, and their weighted sum.

    The call is enthalpy(298, 1, x, 2, dhdn=True) with x = [0.9, 0.1].  It
    prints dhdn, H and sum(x * dhdn) so that H can be compared with the
    composition-weighted sum of its derivatives.

    Args:
        comp: Component string passed to cubic.init.  It presumably must name
            two components, since x has two entries.
    """
    eos = cubic.cubic()
    eos.init(comp, 'SRK')

    z = np.array([0.9, 0.1])
    h_total, dhdn = eos.enthalpy(298, 1, z, 2, dhdn=True)

    print(dhdn)
    print(h_total)
    print(sum(z * dhdn))

def check_dmudx(x):
    """Compare analytical and finite-difference composition derivatives of mu.

    The mixture is NC24/NC10 (SRK) at model.temp and pressure 1e5, starting
    from x0 = [x, 1 - x] with nT = 1.  Two perturbations are used:

    * adding ``dn`` of component 1 (states ``n1`` and ``n1_x1``);
    * shifting the base composition by ``[+dx, -dx]`` (states ``*_x1``).

    Every state is evaluated at its own (T, P) molar volume for its own
    composition.  That molar volume is multiplied by the state's total moles
    before being passed to the ``*_tv`` routines, as the base-state calls
    already did.  The finite differences are therefore taken at constant T
    and P, matching the ``_TP`` quantities they are compared with.

    Args:
        x: Mole fraction of the first component (NC24).

    Returns:
        tuple: (dmu1/dx1 assembled from model.dmudn_TP(),
                finite-difference dmu1/dx1 from adding dn of component 1).
    """
    x0 = np.array([x, 1 - x])
    nT = 1
    dn = 1e-3
    dx = 1e-10
    press = 1e5
    add_n1 = np.array([dn, 0])

    n0 = x0 * nT
    n1 = n0 + add_n1
    # Adding dn also changes the mole fractions; these are the real compositions.
    x_n1 = n1 / sum(n1)

    x1 = x0 + np.array([dx, -dx])
    n0_x1 = x1 * nT
    n1_x1 = n0_x1 + add_n1
    x_n1_x1 = n1_x1 / sum(n1_x1)

    eos = cubic.cubic()
    eos.init('NC24,NC10','SRK')
    model = Kempers89(comps='NC24,NC10', eos_key='SRK', x=x0, total_moles=nT)
    T = model.temp

    def total_volume(comp, moles):
        """Return the phase-flag-1 molar volume at (T, press) for `comp`, times sum(moles)."""
        return eos.specific_volume(T, press, comp, 1)[0] * sum(moles)

    V0 = total_volume(x0, n0)
    V0_x1 = total_volume(x1, n0_x1)
    mu0 = eos.chemical_potential_tv(T, V0, n0)[0]
    A0_x0 = eos.helmholtz_tv(T, V0, n0)[0]
    A0_x1 = eos.helmholtz_tv(T, V0_x1, n0_x1)[0]

    dmudn = model.dmudn_TP()
    # NOTE(review): if n1*dmu1/dn1 + n2*dmu1/dn2 = 0 holds (Gibbs-Duhem at
    # constant T, P), this expression equals 2*nT*(dmu1/dn1)/x2.  That is twice
    # nT*(dmu1/dn1 - dmu1/dn2).  Confirm whether a factor 0.5 is intended.
    dmu1dx1 = (dmudn[0, 0] * (1/x0[1]) - (dmudn[0,1]*(1/x0[0]))) * nT
    dmu2dx2 = (dmudn[1, 1] * (1 / x0[0]) - (dmudn[1, 0] * (1 / x0[1]))) * nT

    dmudx_TP = model.dmudx_TP()

    # Perturbed states, each at its own composition's volume (previously these
    # used a molar volume for x1 as if it were the total volume, and an extra
    # specific_volume call without the required phase argument crashed here).
    V1 = total_volume(x_n1, n1)
    V1_x1 = total_volume(x_n1_x1, n1_x1)
    mu1 = eos.chemical_potential_tv(T, V1, n1)[0]
    A1_x0 = eos.helmholtz_tv(T, V1, n1)[0]
    A1_x1 = eos.helmholtz_tv(T, V1_x1, n1_x1)[0]

    dAdn_x0 = (A1_x0 - A0_x0)/(n1[0] - n0[0])
    dAdn_x1 = (A1_x1 - A0_x1)/(n1_x1[0] - n0_x1[0])
    dAdndx = (dAdn_x1 - dAdn_x0)/(x1 - x0)
    print('nT : ', sum(n0), sum(n0_x1), sum(n1), sum(n1_x1))
    print('dAdn1(x0) = ', dAdn_x0)
    print('dAdn1(x1) = ', dAdn_x1)
    print('Delta( dAdn )', dAdn_x1 - dAdn_x0)
    print('dx = ', x1 - x0)
    print('dAdn1dx = ', dAdndx)

    dmu1, dmu2 = mu1 - mu0
    # The mu change comes from adding dn, so divide by the composition change it
    # causes (about dn*x2/nT), not by the unrelated 1e-10 shift dx.
    dx_n1 = x_n1 - x0

    # Presumably expected to be ~0 at constant T, P (Gibbs-Duhem) -- confirm.
    print('x1dmu1 + x2dmu2 =', x0[0] * dmu1 + x0[1] * dmu2)

    print('dmu1dx1_TP :', dmu1dx1)
    print('dmu2dx2_TP :', dmu2dx2)
    print('dmudx_? : ', (mu1 - mu0) / dx_n1)
    print('dmudx (model): ', dmudx_TP)

    return dmu1dx1, dmu1 / dx_n1[0]

def check_dmudn():
    """Print model.dmudn_TP() next to a forward difference in the second mole number.

    The model is a 3-mole NC24/NC10 (SRK) mixture with x = [0.1, 0.9].  It
    receives 1e-8 extra moles of the second component.  The change in the
    first value returned by get_mu(), divided by that increment, is printed
    as 'dmudn1'.
    """
    kempers = Kempers89(comps='NC24,NC10', eos_key='SRK', x=[0.1, 0.9], total_moles=3)

    step = 1e-8
    increment = np.array([0, step])

    mu_ref, _ = kempers.get_mu()
    print('dmudn_TP :', kempers.dmudn_TP())

    # Add `step` moles of component 2, then renormalise the mole fractions.
    kempers.x = (kempers.x * kempers.total_moles + increment) / (kempers.total_moles + step)
    kempers.total_moles += step
    mu_new, _ = kempers.get_mu()

    print('dmudn1 : ', (mu_new - mu_ref) / increment[1])

def go_to_bed(now):
    """Print 'YES' if get_soret_binary() is the same for 10 and 1 total moles, else 'NO'.

    The mixture is NC10/NC24 with x = [0.1, 0.9]; all other Kempers89 settings
    are defaults.  Both returned Soret values must agree to within 1e-5.

    Args:
        now: Not used.
    """
    comps = 'NC10,NC24'
    fracs = np.array([0.1, 0.9])
    tol = 1e-5

    big = Kempers89(comps=comps, x=fracs, total_moles=10)
    small = Kempers89(comps=comps, x=fracs, total_moles=1)

    big_1, big_2 = big.get_soret_binary()
    small_1, small_2 = small.get_soret_binary()

    agree = abs(small_1 - big_1) < tol and abs(big_2 - small_2) < tol
    print('YES' if agree else 'NO')

def dmudx(N, axs, plot_analytical = False, plot_integral = False):
    """Plot dmu1/dx1 along a linear composition path of an NC24/NC10/NC6 mixture.

    The path starts at x = [0.05, 0.8, 0.15] (SRK, total_moles=5.8).  At each
    of N points, x1 is raised and x2 lowered by 0.75/N while x3 stays fixed.
    A forward finite difference of mu1 along the path is plotted on axs[0].

    Args:
        N: Number of composition points along the path.
        axs: At least two matplotlib axes.  axs[1] is only used when
            plot_analytical is True.
        plot_analytical: Also plot model.dmudx_TP()[0, 0] on axs[0] and mu1
            itself on axs[1].
        plot_integral: Used only together with plot_analytical.  Also plots
            mu1 rebuilt from mu1[0] plus a left Riemann sum of the analytical
            dmu1/dx1.
    """
    span = 0.75
    step = span / N
    fracs = [0.05, 0.8, 0.15]
    model = Kempers89(comps='NC24,NC10,NC6', eos_key='SRK', x=fracs, total_moles=5.8)
    dx = [step, -step, 0]

    x1 = np.zeros(N)
    mu1 = np.zeros(N)
    if plot_analytical:
        dmu1dx1 = np.zeros(N)

    for i in range(N):
        x1[i] = model.x[0]
        if plot_analytical:
            dmu1dx1[i] = model.dmudx_TP().flatten()[0]
        mu1[i], _, _ = model.get_mu(dmudn=False)
        model.x += dx

    # Forward difference of mu1, plotted at the right end of each step.
    axs[0].plot(x1[1:], np.diff(mu1) / step, linestyle=':', color='b', alpha=0.5)

    if plot_analytical:
        axs[0].plot(x1, dmu1dx1, linestyle=':', color='g')
        axs[1].plot(x1, mu1, color='g', linestyle='--')

        if plot_integral:
            # mu1[0] + step * sum(dmu1dx1[:k]) for k = 0..N-1, computed in O(N).
            increments = np.concatenate(([0.0], np.cumsum(dmu1dx1[:-1] * step)))
            axs[1].plot(x1, mu1[0] + increments, color='r', alpha=0.5)

def integrate_dmudx():
    """Check that integrating dmu1/dx1 reproduces mu1 along a binary composition path.

    The mixture is NC24/NC10 (SRK, total_moles=1).  For x1 in
    linspace(fracs[0], fracs[1], N), two curves are plotted:

    * blue: mu1 evaluated directly at [x1, 1 - x1];
    * red (dotted): mu1 at x1 = fracs[0] plus the quad integral of
      model.dmudx_TP()[0, 0] from fracs[0] to x1.

    The figure is saved as 'dmudx_integrert'.
    """
    N = 100
    fracs = [0.1, 0.9]

    model = Kempers89(comps='NC24,NC10', eos_key='SRK', x=fracs, total_moles=1)

    def set_x1(x1):
        """Set the model to the binary composition [x1, 1 - x1]."""
        model.x = np.array([x1, 1 - x1])

    def dmudx_1var(x1):
        """Return dmu1/dx1 at composition [x1, 1 - x1] (mutates model.x)."""
        set_x1(x1)
        return model.dmudx_TP()[0, 0]

    # Reference mu1 at the start of the path (model.x == fracs at this point).
    mu1_start = model.get_mu(dmudn=False)[0]
    x1_axis = np.linspace(fracs[0], fracs[1], N)
    # Integrate from the reference composition, not from x1 = 0: the reference
    # mu1 belongs to x1 = fracs[0].
    delta_mu = np.array([quad(dmudx_1var, fracs[0], x)[0] for x in x1_axis])

    # quad leaves model.x at its last sample point, so set each composition explicitly.
    mu_faktisk = np.zeros(N)
    for i, x in enumerate(x1_axis):
        set_x1(x)
        mu_faktisk[i] = model.get_mu(dmudn=False)[0]

    plt.plot(x1_axis, mu_faktisk, color = 'b')
    plt.plot(x1_axis, mu1_start + delta_mu, linestyle = ':', color = 'r')

    plt.savefig('dmudx_integrert')

def check_dmudx_helmholtz(x):
    """Check mu1 against a finite-difference dA/dn1 at constant T, V.

    The mixture is NC24/NC10 (SRK).  Two base compositions are used:
    x0 = [x, 1 - x] and x0 shifted by [+dx, -dx].  For each, the Helmholtz
    energy is evaluated before and after adding dn of component 1, keeping the
    total volume of that base state (from pressure 1e5 and model.temp).  The
    resulting differences are printed next to eos.chemical_potential_tv and
    model.dmudx_TP().

    Args:
        x: Mole fraction of the first component (NC24).
    """
    x0 = np.array([x, 1 - x])
    nT0 = 1
    # NOTE(review): 1e-10 steps on A and mu may leave few significant digits
    # in the finite differences below; confirm the step sizes are intended.
    dn = 1e-10
    dx = 1e-10
    add_n1 = np.array([dn, 0])

    n0_x0 = x0 * nT0
    n1_x0 = n0_x0 + add_n1

    x1 = x0 + np.array([dx, -dx])
    n0_x1 = x1 * nT0
    n1_x1 = n0_x1 + add_n1

    eos = cubic.cubic()
    eos.init('NC24,NC10','SRK')
    eos.set_tmin(50)
    model = Kempers89(comps='NC24,NC10', eos_key='SRK', x=x0, total_moles=nT0)
    T = model.temp

    # Total volumes (molar volume * moles) of the two base states.  The
    # dn-perturbed states reuse them, so the differences are at constant T, V.
    V_x0 = eos.specific_volume(T, 1e5, x0, 1)[0] * nT0
    V_x1 = eos.specific_volume(T, 1e5, x1, 1)[0] * nT0
    A0_x0 = eos.helmholtz_tv(T, V_x0, n0_x0)[0]
    A0_x1 = eos.helmholtz_tv(T, V_x1, n0_x1)[0]

    dmudx_TP = model.dmudx_TP()

    mu_x0 = eos.chemical_potential_tv(T, V_x0, n0_x0)[0]
    mu_x1 = eos.chemical_potential_tv(T, V_x1, n0_x1)[0]

    A1_x0 = eos.helmholtz_tv(T, V_x0, n1_x0)[0]
    A1_x1 = eos.helmholtz_tv(T, V_x1, n1_x1)[0]

    dAdn_x0 = (A1_x0 - A0_x0)/dn
    dAdn_x1 = (A1_x1 - A0_x1)/dn
    dAdndx = (dAdn_x1 - dAdn_x0)/(x1 - x0)
    dmudx_fd = (mu_x1[0] - mu_x0[0])/(x1 - x0)

    print('dAdn1(x0) = ', dAdn_x0, mu_x0[0], 'diff :', dAdn_x0 - mu_x0[0])
    print('dAdn1(x1) = ', dAdn_x1, mu_x1[0], 'diff :', dAdn_x1 - mu_x1[0])
    print('Delta( dAdn )', dAdn_x1 - dAdn_x0, mu_x1[0] - mu_x0[0] , 'diff :', dAdn_x1 - dAdn_x0 -(mu_x1[0] - mu_x0[0]))
    # As in the lines above, the diff column is the difference of the two values shown.
    print('dAdn1dx = ', 2 * dAdndx, 2 * dmudx_fd, 'diff :', 2 * (dAdndx - dmudx_fd))

    print('dmudx (model): ', dmudx_TP)
