# Driver script: predicts the Soret coefficients of a binary mixture with the
# Kempers01 model on top of a SAFT-VR Mie equation of state, then integrates the
# resulting composition profile over a temperature span (entry point: solver()).
print('Starter opp...')
from pyctp import saftvrmie
from models.kempers01 import Kempers01
import numpy as np
import matplotlib.pyplot as plt
from scipy.constants import Avogadro, Boltzmann
from scipy.optimize import root
from copy import deepcopy
import pandas as pd
print('Nu kjør vi!')

# Argon Mie parameters from the local parameter table. Argon's epsilon is the
# energy scale the other components in `params` are expressed relative to.
mie_df = pd.read_excel('models/mie.xlsx')
# Scaled by 1e-10, so the table presumably stores sigma in Angstrom (TODO confirm).
# NOTE(review): sig_Ar is not used anywhere else in this file as shown.
sig_Ar = np.array(mie_df.loc[mie_df['comp'] == 'AR']['sigma'].iloc[0]) * 1e-10
# NOTE(review): presumably epsilon / k_B in K, given how it feeds eps_*_div_k below - confirm.
eps_Ar = np.array(mie_df.loc[mie_df['comp'] == 'AR']['epsilon'].iloc[0])

# Comma-separated component pair; each name must be a key of `params` below.
comps = 'XE,KR'

# Reduced state: T_red = T / (eps_12 / k) and rho_red = rho * sigma_12**3
# (inverted further down to obtain the dimensional T and rho).
T_red = 0.85
rho_red = 0.81
# Mole fractions of the two components, in the order given by `comps`.
x = np.array([0.5, 0.5])

# SAFT-VR Mie equation of state for the selected components.
eos = saftvrmie.saftvrmie()
eos.init(comps)

# Per-component (sigma, eps/k, third value). sigma is scaled by 1e-10 below
# (presumably Angstrom -> m); eps/k is given as a multiple of argon's.
# NOTE(review): the third value looks like the molar mass in g/mol; it is unused here.
params = {'XE' : (3.975, 1.72 * eps_Ar, 131.29), 'AR' : (3.405, eps_Ar, 39.95), 'KR' : (3.633, 1.39 * eps_Ar, 83.80), 'C1' : (3.74, 1.27 * eps_Ar, 16.04)}

c1, c2 = comps.split(',')
sig_1, eps_1_div_k, _ = params[c1]
sig_2, eps_2_div_k, _ = params[c2]

# Lorentz-Berthelot combining rules: arithmetic-mean sigma (converted to m),
# geometric-mean epsilon.
sig_12 = 0.5 * (sig_1 + sig_2) * 1e-10
eps_12_div_k = np.sqrt(eps_1_div_k * eps_2_div_k)

# Convert the reduced state to dimensional units: T from T_red, and rho as a
# number density (per m^3, since sig_12 is in m).
T = T_red * eps_12_div_k
rho = rho_red / (sig_12 ** 3)

# Pressure at molar volume Avogadro / rho (m^3 per mol) for the composition x.
pres = eos.pressure_tv(T, Avogadro / rho, x)[0]

# Soret model evaluated at the reference state (T, pres, x); the integration
# functions below move this shared instance to other states.
model = Kempers01(comps, eos, x=x, temp=T, pres=pres)

# Density is printed as rho / (Avogadro * 1e3), i.e. mol/L for a per-m^3 number density.
print('Predicted soret for', comps, '(COV) at (T, p, rho) (', round(T, 1), round(pres/1e5, 1), round(rho/(Avogadro * 1e3), 1), ') :', model.get_soret(mode='cov') * 1e3)


def integrate_x(delta_T, N, mode='cov'):
    """Integrate the Soret composition profile outwards from the reference state and plot it.

    The reference state (T, x, rho_red) sits at index N of 2 N + 1 points. Each
    iteration extends the profile by one step to higher temperature (fixed step
    delta_T / (2 N)) and one step to lower temperature. The lower step is rescaled
    from the densities and compositions at the two current end points.
    Compositions follow dx/dT = -x (1 - x) S_T.

    At every point the following are stored:
      - the Soret coefficients, plus the second value returned with ``kin=True``
        (plotted as 'kin.');
      - the reduced density;
      - the diagonal entries of ``model.dmudx_TP()``.

    The function then prints the overall composition recovered from the profile
    and writes a four-panel figure to disk.

    Args:
        delta_T: nominal temperature span [K]. The high-temperature half spans
            exactly delta_T / 2; the low-temperature half depends on the rescaling.
        N: number of steps on each side of the reference state.
        mode: passed through to ``model.get_soret``.
    """
    dT_plus = delta_T / (2 * N)
    x_list = np.empty((2 * N + 1, len(x)))
    ST_list = np.empty_like(x_list)
    ST0_list = np.empty_like(x_list)
    T_list = np.empty(2 * N + 1)
    rho_list = np.empty_like(T_list)
    dmu1dx_list = np.empty_like(T_list)
    dmu2dx_list = np.empty_like(T_list)

    # The reference state occupies the middle slot of every array.
    x_list[N] = deepcopy(x)
    T_list[N] = T
    rho_list[N] = rho_red
    model.set_temp(T)
    model.set_mole_fracs(deepcopy(x))
    ST_list[N], ST0_list[N] = model.get_soret(mode=mode, kin=True)
    dmudx = model.dmudx_TP()
    dmu1dx_list[N] = dmudx[0, 0]
    dmu2dx_list[N] = dmudx[1, 1]

    for n in range(N):
        hi, lo = N + n, N - n  # current outermost points on the hot / cold side
        if x[0] != x_list[lo][0]:
            # Rescale the cold-side step from the hot-side deviation from x,
            # weighted by the densities at the two current end points.
            dT_minus = dT_plus * (rho_list[hi] * (x_list[hi][0] - x[0])) / (rho_list[lo] * (x[0] - x_list[lo][0]))
        else:
            # First step: both end points still equal the reference state.
            dT_minus = dT_plus

        x_list[hi + 1] = x_list[hi] + (- x_list[hi] * (1 - x_list[hi]) * ST_list[hi]) * dT_plus
        x_list[lo - 1] = x_list[lo] - (- x_list[lo] * (1 - x_list[lo]) * ST_list[lo]) * dT_minus

        T_list[hi + 1] = T_list[hi] + dT_plus
        T_list[lo - 1] = T_list[lo] - dT_minus

        for i in (hi + 1, lo - 1):
            # Guess the phase at the same (T, x) the volume is evaluated at;
            # previously the cold point reused the hot point's phase guess.
            phase = eos.guess_phase(T_list[i], pres, x_list[i])
            r = 1 / eos.specific_volume(T_list[i], pres, x_list[i], phase)[0]
            rho_list[i] = r * (sig_12 ** 3) * Avogadro  # reduced density

        for i in (hi + 1, lo - 1):
            model.set_mole_fracs(deepcopy(x_list[i]))
            model.set_temp(T_list[i])
            ST_list[i], ST0_list[i] = model.get_soret(mode=mode, kin=True)
            dmudx = model.dmudx_TP()
            dmu1dx_list[i] = dmudx[0, 0]
            dmu2dx_list[i] = dmudx[1, 1]

    # Density-weighted amount of each component. Each interval is weighted by the
    # end point closer to the reference state: right ends on the cold half,
    # left ends on the hot half.
    dT = np.diff(T_list)
    moles = (np.sum(np.vstack(dT[:N]) * np.vstack(rho_list[1:N + 1]) * x_list[1:N + 1], axis=0)
             + np.sum(np.vstack(dT[N:]) * np.vstack(rho_list[N:-1]) * x_list[N:-1], axis=0))
    x_tot = moles / np.sum(moles)

    print('Total composition before integration :', x)
    print('Total composition after integration :', x_tot)

    # Effective Soret coefficient between the two ends. Use the span actually
    # covered: the cold-side steps are rescaled, so it generally differs from delta_T.
    avg_ST = - (x_list[-1] - x_list[0]) / (x * (1 - x) * (T_list[-1] - T_list[0]))
    avg_T = 0.5 * (max(T_list) + min(T_list))

    # Switch to per-component rows for plotting.
    x_list = x_list.transpose()
    ST_list = ST_list.transpose()
    ST0_list = ST0_list.transpose()

    c1, c2 = comps.split(',')
    fig, axs = plt.subplots(4, 1, sharex='all', figsize = (8, 8))
    ax1, ax2, ax3, ax4 = axs
    ax1.plot(T_list, x_list[0], color='r')
    ax1.plot(T_list, x_list[1], color='b')
    ax1.set_ylabel('x [-]')

    tw2 = ax2.twinx()
    ax2.plot(T_list, ST_list[0] * 1e3, color='r', label=comps.split(',')[0])
    tw2.plot(T_list, ST_list[1] * 1e3, color='b', label=comps.split(',')[1])
    ax2.plot(T_list, ST0_list[0] * 1e3, color='r', linestyle='--', label='kin.')
    tw2.plot(T_list, ST0_list[1] * 1e3, color='b', linestyle='--')

    ax2.plot(T_list[N], ST_list[0][N] * 1e3, marker='x', color='r', linestyle='', label='pred.')
    tw2.plot(T_list[N], ST_list[1][N] * 1e3, marker='x', color='b', linestyle='')
    ax2.plot(avg_T, avg_ST[0] * 1e3, marker='v', color='r', linestyle='', label='avg.')
    tw2.plot(avg_T, avg_ST[1] * 1e3, marker='v', color='b', linestyle='')
    ax2.set_ylim(15, 70)
    tw2.set_ylim(-70, -15)

    ax2.set_ylabel(r'$S_{T,'+comps.split(',')[0]+'}$ [mK$^{-1}$]', color='r')
    tw2.set_ylabel(r'$S_{T,' + comps.split(',')[1] + '}$ [mK$^{-1}$]', color='b')

    ax3.plot(T_list, rho_list)
    ax3.set_ylabel(r'$\rho^*$ [-]')

    twn4 = ax4.twinx()
    ax4.plot(T_list, dmu1dx_list, color='r')
    twn4.plot(T_list, dmu2dx_list, color='b')
    ax4.set_ylabel(r'$d\mu_{'+c1+'}/ dx_{'+c1+'}$', color='r')
    twn4.set_ylabel(r'$d\mu_{' + c2 + '}/ dx_{' + c2 + '}$', color='b')
    ax4.set_ylim(-10, 10)
    twn4.set_ylim(-10, 10)

    ax4.set_xlabel('T [K]')
    plt.figlegend()
    # NOTE(review): this extra save writes a file named only after N; it looks
    # like a leftover - confirm before removing.
    plt.savefig(str(N))
    plt.savefig('Integrated_BH_'+str(N)+'_dT_'+str(round(dT_plus, 2)).replace('.', '_'))
    plt.close(fig)
    print('Finished run for N = ', N, ', dT = ', dT_plus, sep='')

def integrate_gstart(x0, delta_T):
    """Return the density-weighted mean mole fraction of component 1 for a one-way
    Soret integration started at the cold end with mole fraction ``x0``.

    x_1 is marched from T - delta_T / 2 towards higher temperature in fixed
    0.75 K steps using dx_1/dT = -x_1 (1 - x_1) S_T,1. It is then averaged over
    the same temperature grid, weighted by the molar density from the EOS.
    The grid has round(delta_T / 0.75) points (the same grid integrate_oneway
    uses), so the span covered is (N - 1) * 0.75 K, about one step less than
    delta_T. Used as the objective for ``root`` in ``solver``.

    Args:
        x0: cold-end mole fraction of component 1. May be a scalar or a length-1
            array, which is what ``scipy.optimize.root`` passes.
        delta_T: nominal temperature span [K].

    Returns:
        The density-weighted mean mole fraction of component 1. If ``x0`` lies
        outside [0, 1], returns the penalty value ``x0 * 1e6`` instead.

    Raises:
        ValueError: if delta_T is too small for at least two grid points.
    """
    # root() hands over a length-1 array; work with a plain float so it is not
    # implicitly converted when stored into scalar array slots.
    x0 = float(np.ravel(x0)[0])

    if x0 < 0 or x0 > 1:
        print('Nope!')
        # Large penalty steers the root finder back into the physical range.
        return x0*1e6

    dT = 0.75
    N = int((delta_T / dT) + 0.5) #Round to nearest int
    if N < 2:
        raise ValueError('delta_T must span at least two integration points of 0.75 K')

    x_list = np.empty(N)
    rho_list = np.empty_like(x_list)
    T_list = np.empty_like(x_list)

    x_list[0] = x0
    T_list[0] = T - 0.5 * delta_T
    # NOTE(review): the phase flag 1 is hard-coded here, whereas integrate_oneway
    # uses eos.guess_phase - confirm which phase this is meant to select.
    rho_list[0] = 1 / eos.specific_volume(T_list[0], pres, [x0, 1 - x0], 1)[0]

    for n in range(N - 1):
        model.set_temp(T_list[n])
        model.set_mole_fracs([x_list[n], 1 - x_list[n]])

        ST = model.get_soret(mode='cov')
        x_list[n + 1] = x_list[n] - x_list[n] * (1 - x_list[n]) * ST[0] * dT

        # Record the temperature actually reached, so the quadrature below runs
        # over the grid the profile was computed on. The old linspace grid did
        # not match the 0.75 K steps.
        T_list[n + 1] = T_list[n] + dT

        rho_list[n + 1] = 1 / eos.specific_volume(T_list[n + 1], pres, [x_list[n + 1], 1 - x_list[n + 1]], 1)[0]

    x_tot = np.trapz(rho_list * x_list, T_list) / np.trapz(rho_list, T_list)

    print(x0, ':', x_tot - x[0])

    return x_tot

def integrate_oneway(x0, delta_T, mode='cov'):
    """Integrate the Soret profile one way from the cold end and plot it.

    Starting at T - delta_T / 2 with composition x0, the profile is marched in
    fixed 0.75 K steps using dx/dT = -x (1 - x) S_T. At each point it stores:
      - the Soret coefficients;
      - the reduced densities;
      - the diagonal entries of ``model.dmudx_TP()``.

    The function then prints the density-weighted mean composition next to the
    global composition x, plus the model prediction at the reference state, and
    writes a four-panel figure to disk.

    Args:
        x0: mole fractions at the cold end, one per component.
        delta_T: nominal temperature span [K]. The grid has round(delta_T / 0.75)
            points 0.75 K apart, so the actual span is (N - 1) * 0.75 K.
        mode: passed through to ``model.get_soret``.

    Raises:
        ValueError: if delta_T is too small for at least two grid points.
    """
    dT = 0.75
    N = int((delta_T / dT) + 0.5)  # Round to nearest int
    if N < 2:
        raise ValueError('delta_T must span at least two integration points of 0.75 K')
    x_list = np.empty((N, len(x)))
    ST_list = np.empty_like(x_list)
    ST0_list = np.empty_like(x_list)
    T_list = np.empty(N)
    rho_list = np.empty_like(T_list)
    dmu1dx_list = np.empty_like(T_list)
    dmu2dx_list = np.empty_like(T_list)

    x_list[0] = deepcopy(x0)
    T_list[0] = T - 0.5 * delta_T
    # The start state (cold end, composition x0) is not the reference state, so
    # its density must come from the EOS; rho_red only applies at (T, x).
    phase = eos.guess_phase(T_list[0], pres, x_list[0])
    r = 1 / eos.specific_volume(T_list[0], pres, x_list[0], phase)[0]
    rho_list[0] = r * (sig_12 ** 3) * Avogadro  # reduced density
    model.set_temp(T_list[0])
    model.set_mole_fracs(deepcopy(x0))
    ST_list[0], ST0_list[0] = model.get_soret(mode=mode, kin=True)
    dmudx = model.dmudx_TP()
    dmu1dx_list[0] = dmudx[0, 0]
    dmu2dx_list[0] = dmudx[1, 1]

    for n in range(N - 1):
        x_list[n + 1] = x_list[n] + (- x_list[n] * (1 - x_list[n]) * ST_list[n]) * dT
        T_list[n + 1] = T_list[n] + dT

        # Guess the phase at the state whose volume is evaluated, not at the
        # global reference temperature (consistent with integrate_x).
        phase = eos.guess_phase(T_list[n + 1], pres, x_list[n + 1])
        print(phase)
        r = 1 / eos.specific_volume(T_list[n + 1], pres, x_list[n + 1], phase)[0]
        rho_list[n + 1] = r * (sig_12 ** 3) * Avogadro  # reduced density

        model.set_mole_fracs(deepcopy(x_list[n + 1]))
        model.set_temp(T_list[n + 1])
        ST_list[n + 1], ST0_list[n + 1] = model.get_soret(mode=mode, kin=True)
        dmudx = model.dmudx_TP()
        dmu1dx_list[n+1] = dmudx[0, 0]
        dmu2dx_list[n+1] = dmudx[1, 1]

    # Density-weighted mean composition over the temperature grid.
    rho_integral = np.trapz(rho_list, T_list)
    x1_tot = np.trapz(rho_list * x_list[:, 0], T_list) / rho_integral
    x2_tot = np.trapz(rho_list * x_list[:, 1], T_list) / rho_integral

    x_tot = [x1_tot, x2_tot]

    print('Total composition before integration :', x)
    print('Total composition after integration :', x_tot)

    model2 = Kempers01(comps, eos, x=x, temp=T, pres=pres)
    print('Predicted soret for', comps, '(COV) at (T, p, rho) (', round(T, 1), round(pres / 1e5, 1),
          round(rho / (Avogadro * 1e3), 1), ') :', model2.get_soret(mode='cov') * 1e3)

    # Effective Soret coefficient between the two ends, using the span actually
    # covered ((N - 1) * dT) rather than the nominal delta_T.
    avg_ST = - (x_list[-1] - x_list[0]) / (x * (1 - x) * (T_list[-1] - T_list[0]))
    avg_T = 0.5 * (max(T_list) + min(T_list))

    # Switch to per-component rows for plotting.
    x_list = x_list.transpose()
    ST_list = ST_list.transpose()
    ST0_list = ST0_list.transpose()

    reported = {'XE,KR': 5.1, 'KR,AR': 14.4, 'AR,C1': 9.3, 'XE,AR': 18.6, 'KR,C1': 22.3, 'XE,C1': 23.1}

    fig, axs = plt.subplots(4, 1, sharex='all', figsize=(8, 8))
    ax1, ax2, ax3, ax4 = axs
    ax1.plot(T_list, x_list[0], color='r')
    ax1.plot(T_list, x_list[1], color='b')
    ax1.set_ylabel('x [-]')

    tw2 = ax2.twinx()
    ax2.plot(T_list, ST_list[0] * 1e3, color='r', label=comps.split(',')[0])
    tw2.plot(T_list, ST_list[1] * 1e3, color='b', label=comps.split(',')[1])
    #ax2.plot(T_list, ST0_list[0] * 1e3, color='r', linestyle='--', label='kin.')
    #tw2.plot(T_list, ST0_list[1] * 1e3, color='b', linestyle='--')

    # Plot the mid-grid prediction at the temperature where it was evaluated.
    ax2.plot(T_list[N // 2], ST_list[0][N//2] * 1e3, marker='x', color='r', linestyle='', label='pred.')
    tw2.plot(T_list[N // 2], ST_list[1][N//2] * 1e3, marker='x', color='b', linestyle='')
    ax2.plot(avg_T, avg_ST[0] * 1e3, marker='v', color='r', linestyle='', label='avg.')
    tw2.plot(avg_T, avg_ST[1] * 1e3, marker='v', color='b', linestyle='')
    #ax2.plot(avg_T, reported[comps], marker='o', color='r', linestyle='', label='rep.')
    #tw2.plot(avg_T, -reported[comps], marker='o', color='b', linestyle='', label='rep.')

    ax2.set_ylabel(r'$S_{T,'+comps.split(',')[0]+'}$ [mK$^{-1}$]', color='r')
    tw2.set_ylabel(r'$S_{T,' + comps.split(',')[1] + '}$ [mK$^{-1}$]', color='b')
    ax2.set_ylim(5, 15)
    tw2.set_ylim(-15, -5)

    ax3.plot(T_list, rho_list)
    ax3.set_ylabel(r'$\rho^*$ [-]')

    twn4 = ax4.twinx()
    ax4.plot(T_list, dmu1dx_list, color='r')
    twn4.plot(T_list, dmu2dx_list, color='b')
    ax4.set_ylabel(r'$d\mu_{'+c1+'}/ dx_{'+c1+'}$', color='r')
    twn4.set_ylabel(r'$d\mu_{' + c2 + '}/ dx_{' + c2 + '}$', color='b')

    ax4.set_xlabel('T [K]')
    plt.figlegend()
    plt.savefig('Integrated_1way_'+str(N)+'_DT_'+str(round(delta_T, 2)).replace('.', '_'))
    plt.close(fig)
    print('Finished one-way run for N = ', N, ', ∆T = ', delta_T, sep='')

def solver():
    """Find the cold-end composition that reproduces the global composition and
    run the one-way integration from it.

    The temperature span is fixed at 30 % of the reference temperature T.
    ``root`` adjusts the cold-end mole fraction of component 1 until the
    density-weighted mean returned by ``integrate_gstart`` equals x[0]. The
    profile is then recomputed and plotted by ``integrate_oneway``.
    """
    span = 0.3 * T

    def residual(x_start):
        # Mismatch between the target and the integrated mean fraction of component 1.
        return x[0] - integrate_gstart(x_start, span)

    # NOTE(review): hard-coded initial guess, presumably tuned for the default pair - confirm.
    result = root(residual, x0=[0.53849907])
    print('Success :', result.success)

    x_start = result.x[0]
    integrate_oneway([x_start, 1 - x_start], span)


solver()

