from kempers_V2.kempers import Kempers_HS
#from kempers01 import Kempers
from pykingas import KineticGas
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.colors import Normalize
plt.style.use('default')
from pyctp import saftvrmie, cubic
# Comma-separated component identifiers handed to the EoS and transport models.
comps = 'AR,C1'

# Flip to True to quickly confirm that results do not depend on component order.
swap = False
x1 = 0.1
# Two-entry composition vector; the entries sum to one.
x = np.array((x1, 1 - x1))
if swap is True:
    print('Swapping components')
    comps = ','.join(reversed(comps.split(',')))
    x = x[::-1]

def phase_cmap(phase):
    """Map a phase key to a matplotlib colour name.

    Keys 0, 1 and 2 give 'r', 'b' and 'g' respectively; anything else
    gives 'black'.
    """
    known_phases = ((0, 'r'), (1, 'b'), (2, 'g'))
    for key, colour in known_phases:
        # Equality (not hashing) is used so any value comparable with == works.
        if phase == key:
            return colour
    return 'black'

def _uv_flash_phase(eos, T, rho):
    """Return the phase key of a UV-flash at temperature T and molar density rho."""
    v = 1 / rho
    u = eos.internal_energy_tv(T, v, x)[0]
    # The flash is solved once per state point; the phase key is its last element.
    return eos.two_phase_uvflash(x, u, v)[-1]


def _plot_isochores(eos):
    """Draw one vertical line per density in rho_list, coloured segment-wise by phase key."""
    flash_phase_cmap = get_cmap('tab10')
    flash_phase_norm = Normalize(vmin=0, vmax=4)

    for rho in rho_list:
        phases = [_uv_flash_phase(eos, T, rho) for T in T_list]
        for i in range(len(T_list) - 1):
            # Two temperatures per segment so the x and y data have equal length.
            # The colour follows the phase key: the colormap and norm are the
            # phase ones, so feeding them the gas fraction mixed two scales.
            plt.plot([rho, rho], T_list[i:i + 2],
                     color=flash_phase_cmap(flash_phase_norm(phases[i])))

    drho = (max(rho_list) - min(rho_list)) * 0.1
    plt.xlim(min(rho_list) - drho, max(rho_list) + drho)
    plt.ylim(min(T_list), max(T_list))


def _plot_isotherms(eos, color_by_phase):
    """Draw one horizontal line per temperature in T_list, optionally coloured by phase key."""
    leg_hands = []
    labelled_phases = []

    for T in T_list:
        if not color_by_phase:
            plt.plot(rho_list, np.ones_like(rho_list) * T, color=T_cmap(T_norm(T)))
            continue

        phases = [_uv_flash_phase(eos, T, rho) for rho in rho_list]
        for i in range(len(rho_list) - 1):
            phase = phases[i]
            line, = plt.plot(rho_list[i:i + 2], [T, T], color=phase_cmap(phase))
            # Only the first segment of each phase gets a legend entry.
            if phase not in labelled_phases:
                labelled_phases.append(phase)
                line.set_label(phase)
                leg_hands.append(line)

    dT = (max(T_list) - min(T_list)) * 0.1
    plt.xlim(min(rho_list), max(rho_list))
    plt.ylim(min(T_list) - dT, max(T_list) + dT)
    plt.ylabel(r'$T$ [K]')
    plt.xlabel(r'$\rho$ [mol m$^{-3}$]')
    if leg_hands:
        # Skip the legend when there is nothing to show in it.
        plt.legend(handles=leg_hands)


def _plot_pxy(eos):
    """Draw the L1VE curves from get_binary_pxy and the pressure path at composition x[0]."""
    for T in T_list:
        _, L1VE, _ = eos.get_binary_pxy(T)
        if L1VE[0] is None:
            continue

        colour = T_cmap(T_norm(T))
        # NOTE(review): presumably L1VE holds (liquid composition, vapour
        # composition, pressure) -- confirm against the thermopack docs.
        plt.plot(L1VE[0], L1VE[2], color=colour)
        plt.plot(L1VE[1], L1VE[2], color=colour)

        # Pressures along the isotherm, drawn as a vertical line at x[0].
        p_list = [eos.pressure_tv(T, 1 / rho, x)[0] for rho in rho_list]
        plt.plot(np.full(len(p_list), x[0]), p_list, color=colour)


def phase_envelope(x_axis='T', color_by_phase=False):
    """Plot the SRK phase envelope of the mixture with an overlay, then show the figure.

    The mixture is given by the module-level ``comps`` and ``x``. The overlays
    also read the module-level ``rho_list``, ``T_list``, ``T_cmap`` and
    ``T_norm``, which are assigned in the ``__main__`` section of this file, so
    they must exist before this function is called.

    Unless ``x_axis == 'p'``, the bubble-point curve is traced from 120 K to
    5 K below the critical temperature and drawn as density against
    temperature for the two phase roots.

    Args:
        x_axis: Selects the overlay.
            'T'   -- vertical lines at each density in ``rho_list``, coloured
                     by the phase key of a UV-flash at each temperature.
            'rho' -- horizontal lines at each temperature in ``T_list``.
            'p'   -- L1VE curves from ``get_binary_pxy`` for each temperature,
                     plus the pressures along ``rho_list`` at composition x[0].
            Any other value draws only the envelope.
        color_by_phase: Only used when ``x_axis == 'rho'``. If False (default)
            each isotherm is drawn in its temperature colour. If True a
            UV-flash is run at every grid point, the segments are coloured with
            ``phase_cmap`` and a legend of the phase keys is added.
    """
    eos = cubic.cubic()  # saftvrmie.saftvrmie() is the alternative EoS
    eos.init(comps, 'SRK')
    Tc, _, _ = eos.critical(x)
    print(Tc)

    if x_axis != 'p':
        phenv_T_list = np.linspace(120, Tc - 5, 50)
        p_list = [eos.bubble_pressure(T, x)[0] for T in phenv_T_list]
        # Phase flags 1 and 2 select the two volume roots.
        # NOTE(review): presumably liquid and vapour respectively -- confirm.
        for phase_flag in (1, 2):
            V_list = np.array([eos.specific_volume(T, p, x, phase_flag)[0]
                               for T, p in zip(phenv_T_list, p_list)])
            plt.plot(1 / V_list, phenv_T_list)

    if x_axis == 'T':
        _plot_isochores(eos)
    elif x_axis == 'rho':
        _plot_isotherms(eos, color_by_phase)
    elif x_axis == 'p':
        _plot_pxy(eos)

    plt.show()

def plot_ST_T():
    """Plot the Soret coefficient against temperature for a few fixed densities.

    For each density the value from ``model_HS.get_Soret`` is drawn as a solid
    line and the value from ``kin.alpha_T0(..., BH=True)`` divided by T as a
    dashed line of the same colour. Both are scaled by 1e3.

    Relies on the module-level ``model_HS``, ``kin`` and ``x``; the first two
    are assigned in the ``__main__`` section of this file.
    """
    rho_list = np.linspace(1, 400, 5)
    T_list = np.linspace(200, 400, 50)

    rho_cmap = get_cmap('cividis')
    rho_norm = Normalize(vmin=min(rho_list), vmax=max(rho_list))

    # NOTE: phase_envelope reads the module-level rho_list / T_list, not the
    # local grids defined above, and shows its own figure before returning.
    phase_envelope(x_axis='T')

    for rho in rho_list:
        print(rho)
        colour = rho_cmap(rho_norm(rho))
        real = np.array([model_HS.get_Soret(T, rho, x)[0] * 1e3 for T in T_list])
        hs = np.array([(kin.alpha_T0(T, 1 / rho, x, BH=True)[0] / T) * 1e3 for T in T_list])
        plt.plot(T_list, real, color=colour, label=round(rho, 1))
        # Same colour as the solid curve so the pair reads as one density.
        plt.plot(T_list, hs, color=colour, linestyle='--')
    plt.xlabel(r'T [K]')
    plt.ylabel(r'$\Delta_{HS}S_{T,1}$ [mK$^{-1}$]')
    plt.legend(title=r'$\rho$ [mol m$^{-3}$]')
    plt.show()

if __name__ == '__main__':
    # NOTE: the names assigned here are deliberately module-level globals.
    # phase_envelope() and plot_ST_T() read them, so this section must not be
    # moved into a function without passing them along.

    model_HS = Kempers_HS(comps)
    #model_01 = Kempers(comps)
    kin = KineticGas(comps)

    rho_list = np.linspace(100, 20000, 50)
    T_list = np.linspace(160, 200, 5)

    T_cmap = get_cmap('cool')
    T_norm = Normalize(vmin=min(T_list), vmax=max(T_list))

    # Set to False to continue to the Soret-coefficient isotherms after the
    # phase envelope has been shown.
    envelope_only = True

    phase_envelope(x_axis='rho')

    if not envelope_only:
        for T in T_list:
            print(T)
            colour = T_cmap(T_norm(T))
            real_HS = np.array([model_HS.get_Soret(T, rho, x)[0] * 1e3 for rho in rho_list])
            #real_01 = np.array([model_01.get_Soret(T, rho, x)[0] * 1e3 for rho in rho_list])
            hs = np.array([(kin.alpha_T0(T, 1 / rho, x, BH=True)[0] / T) * 1e3 for rho in rho_list])
            plt.plot(rho_list, real_HS, color=colour, label=round(T, 1))
            plt.plot(rho_list, hs, color=colour, linestyle='--')
        plt.xlabel(r'$\rho$ [mol m$^{-3}$]')
        plt.ylabel(r'$S_{T,1}$ [mK$^{-1}$]')
        plt.legend(title=r'T [K]')
        plt.show()

