# Import progress bar: print a ruler of n_imports + 1 dashes, then one '#'
# (without a newline) after each of the marked imports below. The bar is
# finished in the __main__ block, which imports pyctp and prints '#|'. Hence
# n_imports = 6 counts the five marked imports plus pyctp, and the extra ruler
# column is for the closing '|'.
print('Starting')
n_imports = 6
print('-'*(n_imports + 1))
import pandas as pd
print('#', end='')
import numpy as np
print('#', end='')
import os
print('#', end='')
from datetime import datetime
print('#', end='')
from scipy.constants import Avogadro, Boltzmann
print('#', end='')
# The matplotlib imports are not counted in n_imports.
# NOTE(review): matplotlib.cm.get_cmap was removed in Matplotlib 3.9;
# plt.get_cmap / matplotlib.colormaps[...] are the replacements.
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.colors import Normalize

def get_file_path(name, comps, e, s):
    """Build the CSV output path for a validation run.

    The file lives in ``output/saft_validation/`` next to this script.

    * If the run uses scaled parameters (``e != 1`` or ``s != 1``), the file
      name holds the first component and both scale factors, with '.'
      replaced by '_' (e.g. ``rho_p_AR_e_1_5_s_1.csv``).
    * Otherwise it holds all components joined by '_'
      (e.g. ``rho_p_AR_KR.csv``).

    Args:
        name: File-name prefix (callers usually end it with '_').
        comps: Comma-separated component string, e.g. 'AR' or 'AR,KR'.
        e: Epsilon scale factor; None means 1.
        s: Sigma scale factor; None means 1.

    Returns:
        The output path as a string.
    """
    if e is None:
        e = 1
    if s is None:
        s = 1

    base_dir = os.path.dirname(os.path.abspath(__file__)) + '/' + 'output/saft_validation/'

    # Any scaling away from 1, including factors below 1, must be encoded in
    # the name. Otherwise e.g. e=0.5 would silently overwrite the unscaled file.
    if e != 1 or s != 1:
        tag = (comps.split(',')[0]
               + '_e_' + str(e).replace('.', '_')
               + '_s_' + str(s).replace('.', '_'))
    else:
        tag = comps.replace(',', '_')

    return base_dir + name + tag + '.csv'

def check_overwrite(out_path):
    """Ask for confirmation before an existing file at ``out_path`` is overwritten.

    Does nothing if ``out_path`` is not an existing file. Otherwise it prints
    a warning and reads a line from stdin. Any answer other than exactly 'Y'
    terminates the process with exit status 0.

    Args:
        out_path: Path of the file that is about to be written.
    """
    import sys  # local import: the module does not import sys at top level

    if not os.path.isfile(out_path):
        return

    print('This operation will overwrite ', out_path)
    confirm = input('Y to confirm')
    if confirm != 'Y':
        # sys.exit does not depend on the `site` module, which provides the
        # interactive `exit` builtin (absent e.g. under `python -S`).
        sys.exit(0)

def redefine_critical_parameters(eos, start_params, e_in=None, s_in=None, level=0, max_level=30):
    """Set scaled parameters on component 1 and redefine its critical point.

    Component 1 of ``eos`` gets ``start_params`` = [m, sigma, eps/k, la, lr],
    with sigma scaled by ``s_in`` and eps/k scaled by ``e_in`` (None means 1).
    Then ``eos.redefine_critical_parameters`` and ``eos.critical([1])`` are
    called. If that fails, the function recurses with both scale factors moved
    halfway towards 1, then retries the requested scaling. Presumably this
    lets the solver start from a nearby, already-redefined parameter set when
    e or s is far from 1.

    Args:
        eos: EoS object providing set_pure_fluid_param,
            redefine_critical_parameters and critical.
        start_params: Unscaled [m, sigma, eps_div_k, la, lr] for component 1.
        e_in: eps/k scale factor, or None for 1.
        s_in: sigma scale factor, or None for 1.
        level: Current recursion depth (internal).
        max_level: Maximum recursion depth before giving up.

    Raises:
        ValueError: If the parameters could not be applied, even after
            stepping towards the unscaled parameters. The underlying EoS
            error is chained as ``__cause__``.
    """
    e = 1 if e_in is None else e_in
    s = 1 if s_in is None else s_in

    def _apply():
        # Scale sigma (index 1) and eps/k (index 2); keep m, la and lr.
        eos.set_pure_fluid_param(1, start_params[0], start_params[1] * s,
                                 start_params[2] * e, start_params[3], start_params[4])
        eos.redefine_critical_parameters(silent=False)
        eos.critical([1])  # expected to raise if no critical point is found

    try:
        _apply()
        return
    except Exception as err:
        first_error = err

    # Nothing left to step towards (already unscaled), or too deep: further
    # recursion would only repeat the same failing call.
    if (e == 1 and s == 1) or level >= max_level:
        raise ValueError('Could not redefine') from first_error

    try:
        redefine_critical_parameters(eos, start_params,
                                     e_in=0.5 * (1 + e), s_in=0.5 * (1 + s),
                                     level=level + 1, max_level=max_level)
        _apply()
    except Exception as err:
        raise ValueError('Could not redefine') from err

def phase_envelope(comps, e=None, s=None):
    """Compute the reduced two-phase envelope of a (scaled) pure fluid and save it.

    Steps:

    1. Component 1 of ``comps`` gets sigma scaled by ``s`` and eps/k scaled
       by ``e`` (None means 1) via redefine_critical_parameters.
    2. ``eos.get_envelope_twophase`` is called with ``0.01 * p_c`` and
       ``minimum_temperature=0.01 * T_c``.
    3. Each envelope point gets a phase from ``eos.guess_phase`` and a
       density ``Avogadro / V`` from ``eos.specific_volume``.
    4. Results are reduced with the scaled parameters (rho* = rho sigma^3,
       p* = p sigma^3 / eps, T* = T k / eps) and written as CSV to
       ``get_file_path('phase_envelope_', comps, e, s)``.

    Side effect: the critical pressure is also stored in the module-level
    global ``p_c``.

    Args:
        comps: Component string understood by ``saftvrmie.init``.
        e: eps/k scale factor, or None for 1.
        s: sigma scale factor, or None for 1.
    """
    global p_c  # kept: publishes the last critical pressure at module level

    if e is None:
        e = 1
    if s is None:
        s = 1

    # Guard the file that is actually written at the end of this function.
    out_path = get_file_path('phase_envelope_', comps, e, s)
    check_overwrite(out_path)

    eos = saftvrmie.saftvrmie()
    eos.init(comps)
    m, sigma, eps_div_k, la, lr = eos.get_pure_fluid_param(1)
    eps = eps_div_k * Boltzmann

    # Must redefine many times if s >> 1 or e >> 1 for the EoS to manage to
    # compute the new critical point.
    redefine_critical_parameters(eos, [m, sigma, eps_div_k, la, lr], e_in=e, s_in=s)

    sigma *= s
    eps_div_k *= e
    eps *= e

    # Let a failure propagate: without T_c and p_c the envelope cannot be traced.
    T_c, _, p_c = eos.critical([1])
    T, p = eos.get_envelope_twophase(0.01 * p_c, [1], minimum_temperature=0.01 * T_c)

    rho = np.zeros_like(T)
    phase = np.zeros_like(T, dtype=int)
    for i, (T_i, p_i) in enumerate(zip(T, p)):
        phase[i] = eos.guess_phase(T_i, p_i, [1])
        Vm = eos.specific_volume(T_i, p_i, [1], phase[i])[0]
        rho[i] = Avogadro / Vm

    red_T = T / eps_div_k
    red_p = p * (sigma**3) / eps
    red_rho = rho * (sigma**3)

    out_df = pd.DataFrame({'red_rho': pd.Series(red_rho),
                           'red_p': pd.Series(red_p),
                           'red_T': pd.Series(red_T),
                           'pred_phase': pd.Series(phase)})
    out_df.to_csv(out_path)

def test_critical(comps):
    """Sweep sigma/epsilon scale factors and plot the resulting reduced critical points.

    For every pair (s, e) from the hard-coded lists below:

    * component 1 of a SAFT-VR Mie EoS for ``comps`` gets sigma scaled by
      ``s`` and eps/k scaled by ``e``;
    * the critical point is recomputed;
    * it is reduced with the scaled parameters (T* = T k / eps,
      p* = p sigma^3 / eps, rho* = rho sigma^3).

    Pairs for which redefinition or the critical calculation fails are
    reported on stdout and stored as NaN.

    Two figures are saved in the working directory:

    * 'Testing_critical_<comps>': p*_c and T*_c against rho*_c.
    * 'Critical_deviation_<comps>': deviation of p*_c and T*_c from the
      hard-coded reference point, plotted against s and against e.

    Args:
        comps: Component string understood by ``saftvrmie.init``.
    """
    # NOTE(review): this path is only used for the overwrite prompt; nothing
    # is written to it in this function.
    out_path = get_file_path('critical_params_', comps, 1, 1)
    check_overwrite(out_path)

    eos = saftvrmie.saftvrmie()
    eos.init(comps)
    m, sig, eps_div_k, la, lr = eos.get_pure_fluid_param(1)

    s_list = np.array([1, 1.1, 1.4, 1.6, 1.9, 2.0])
    e_list = np.array([1, 1.2, 1.5, 1.7, 2.0, 2.2])

    markers = ['o', 's', 'v', '1', '2', '|']  # one marker per s value
    # plt.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    cmap = plt.get_cmap('viridis')    # colours keyed on e
    cmap_s = plt.get_cmap('rainbow')  # colours keyed on s (deviation plot)
    norm = Normalize(vmin=min(e_list), vmax=max(e_list))
    norm_s = Normalize(vmin=min(s_list), vmax=max(s_list))

    # Legend handles: first one per e value, then one per s value.
    hands = [None] * (len(s_list) + len(e_list))

    # Reduced critical properties; rows index s, columns index e.
    T_c_list = np.full((len(s_list), len(e_list)), np.nan)
    rho_c_list = np.full_like(T_c_list, np.nan)
    p_c_list = np.full_like(T_c_list, np.nan)

    fig, (ax1, ax2) = plt.subplots(2, 1, sharex='all')
    for si, s in enumerate(s_list):
        for ei, e in enumerate(e_list):
            sigma = sig * s
            eps_i_div_k = eps_div_k * e
            eps_i = eps_i_div_k * Boltzmann
            eos.set_pure_fluid_param(1, m, sigma, eps_i_div_k, la, lr)

            # A failing pair is recorded as NaN instead of aborting the sweep.
            try:
                eos.redefine_critical_parameters(silent=False)
                T_c, V_c, p_c = eos.critical([1])
            except Exception:
                print('Critical calculation failed for e = ', e, ', s = ', s, sep='')
                T_c, V_c, p_c = np.nan, np.nan, np.nan

            red_Tc = T_c / eps_i_div_k
            red_pc = p_c * (sigma ** 3) / eps_i
            red_rho_c = (Avogadro / V_c) * (sigma ** 3)

            T_c_list[si, ei] = red_Tc
            rho_c_list[si, ei] = red_rho_c
            p_c_list[si, ei] = red_pc

            pe, = ax1.plot(red_rho_c, red_pc, marker=markers[si], color=cmap(norm(e)), label=str(e), linestyle='')
            ps, = ax2.plot(red_rho_c, red_Tc, marker=markers[si], color=cmap(norm(e)), label=str(s), linestyle='')

            if si == 0:
                hands[ei] = pe
            if ei == 0:
                hands[len(e_list) + si] = ps

    # Hard-coded reduced reference critical point (presumably Lennard-Jones
    # literature values - confirm). The density value is currently not plotted.
    Tc_rep = 1.312
    pc_rep = 0.1279
    rho_c_rep = 0.316

    ax1.set_ylabel(r'$p^*_c$')
    ax2.set_xlabel(r'$\rho^*_c$')
    ax2.set_ylabel(r'$T^*_c$')
    plt.tight_layout()
    plt.figlegend(handles=hands, ncol=2, title=r'$\epsilon / \epsilon_{'+comps+'}$      $\sigma / \sigma_{'+comps+'}$')
    plt.suptitle(comps)
    plt.savefig('Testing_critical_'+comps)

    fig, axs = plt.subplots(2, 2, sharex='col', sharey='row')
    # Left column: deviation against s, one line per e.
    for ei, e in enumerate(e_list):
        axs[0, 0].plot(s_list, p_c_list[:, ei] - pc_rep, color=cmap(norm(e)), label=e)
        axs[1, 0].plot(s_list, T_c_list[:, ei] - Tc_rep, color=cmap(norm(e)), label=e)

    axs[0, 0].legend(title=r'$\epsilon / \epsilon_{'+comps+'}$')
    axs[0, 0].set_ylabel(r'$\Delta p_c^*$')
    axs[1, 0].set_ylabel(r'$\Delta T_c^*$')
    axs[1, 0].set_xlabel(r'$\sigma / \sigma_{'+comps+'}$')

    # Right column: deviation against e, one line per s.
    for si, s in enumerate(s_list):
        axs[0, 1].plot(e_list, p_c_list[si] - pc_rep, color=cmap_s(norm_s(s)), label=s)
        axs[1, 1].plot(e_list, T_c_list[si] - Tc_rep, color=cmap_s(norm_s(s)), label=s)

    axs[0, 1].legend(title=r'$\sigma / \sigma_{'+comps+'}$')
    axs[1, 1].set_xlabel(r'$\epsilon / \epsilon_{'+comps+'}$')
    plt.suptitle(comps)
    plt.savefig('Critical_deviation_'+comps)

def binary_mixture(e=None, s=None, comps='AR,KR', red_T=1):
    """Compute a binary Pxy diagram and critical-pressure line, and save them as CSV.

    Both components of ``comps`` (expected to be exactly two) are given
    m=1, la=6, lr=12. They are built from component 1's sigma and eps/k,
    with component 2's sigma scaled by ``s`` and eps/k scaled by ``e``
    (None means 1).

    * ``eos.get_binary_pxy`` is evaluated at T = red_T * eps1/k.
    * Critical pressures are computed for 50 values of x1 in [0.19, 1]
      (``np.linspace`` default).

    The CSV path is ``get_file_path('binary_phase_envelope_', <first
    component>, e, s)``. Its columns are bubble_line, dew_line and pressure
    (elements 0, 1, 2 of the second ``get_binary_pxy`` result), plus
    crit_comp and crit_pres.

    Args:
        e: eps/k scale factor for component 2, or None for 1.
        s: sigma scale factor for component 2, or None for 1.
        comps: Comma-separated pair of components (default 'AR,KR').
        red_T: Reduced temperature k T / eps1 of the Pxy diagram (default 1).
    """
    if s is None:
        s = 1
    if e is None:
        e = 1

    out_path = get_file_path('binary_phase_envelope_', comps.split(',')[0], e, s)

    eos = saftvrmie.saftvrmie()
    eos.init(comps)
    _, sig1, eps1_div_k, _, _ = eos.get_pure_fluid_param(1)

    # Only component 2 is scaled; both share m=1, la=6, lr=12.
    eos.set_pure_fluid_param(1, 1, sig1, eps1_div_k, 6, 12)
    eos.set_pure_fluid_param(2, 1, s * sig1, e * eps1_div_k, 6, 12)
    eos.redefine_critical_parameters(silent=False)

    T = red_T * eps1_div_k
    _, LL1V, _ = eos.get_binary_pxy(T)

    x1_list = np.linspace(0.19, 1)
    pc_list = np.array([eos.critical([x1, 1 - x1])[2] for x1 in x1_list])

    df = pd.DataFrame({'bubble_line': pd.Series(LL1V[0]),
                       'dew_line': pd.Series(LL1V[1]),
                       'pressure': pd.Series(LL1V[2]),
                       'crit_comp': pd.Series(x1_list),
                       'crit_pres': pd.Series(pc_list)})

    df.to_csv(out_path)
    print('Saved to ', out_path)

def _rho_p_reducing_parameters(eos, comps, e, s):
    """Apply rho_p's e/s scaling to ``eos`` and return (sigma, eps_div_k) for reduced units.

    * Binary ``comps``: if scaled, component 2 is replaced by component 1
      scaled by (s, e). The mixture uses the arithmetic mean of the sigmas
      and the geometric mean of eps/k.
    * Single ``comps``: if scaled, component 1 itself is scaled, and its
      scaled parameters are returned.
    """
    scaled = e != 1 or s != 1  # factors below 1 count as scaling too
    params1 = eos.get_pure_fluid_param(1)
    _, sigma1, eps1_div_k, _, _ = params1

    if len(comps.split(',')) > 1:
        if scaled:
            eos.set_pure_fluid_param(2, params1[0], s * params1[1], e * params1[2], params1[3], params1[4])
            eos.redefine_critical_parameters()
        _, sigma2, eps2_div_k, _, _ = eos.get_pure_fluid_param(2)
        return 0.5 * (sigma1 + sigma2), np.sqrt(eps1_div_k * eps2_div_k)

    if scaled:
        eos.set_pure_fluid_param(1, params1[0], s * params1[1], e * params1[2], params1[3], params1[4])
        eos.redefine_critical_parameters(silent=False)
        return sigma1 * s, eps1_div_k * e

    return sigma1, eps1_div_k


def rho_p(comps, s=None, e=None, composition=None):
    """Compare SAFT-VR Mie pressures against reduced liquid p-rho-T data.

    Data handling:

    * Reads the 'pvT' sheet of data/VLE_data.xlsx and keeps liquid points
      flagged as confirmed.
    * For the ``N_T_vals`` reduced temperatures with the most points, the EoS
      pressure is computed at every data density.

    Output:

    * Densities and pressures are reduced with the sigma and eps/k from
      _rho_p_reducing_parameters.
    * They are saved, sorted by density, as '<T>,rho', '<T>,p_data' and
      '<T>,p_calc' columns to ``get_file_path('rho_p_', comps, e, s)``.
    * A record of the run is appended to meta.txt in the same directory.

    Args:
        comps: Comma-separated component string.
        s: sigma scale factor, or None for 1.
        e: eps/k scale factor, or None for 1.
        composition: Mole fractions, one per component (default [1]).

    Raises:
        ValueError: If ``composition`` does not sum to 1 or does not have one
            entry per component.
    """
    if composition is None:
        composition = [1]
    if abs(sum(composition) - 1) > 1e-10:
        raise ValueError('Composition must sum to 1')
    if len(composition) != len(comps.split(',')):
        raise ValueError('Composition and comps must be same length!')

    if s is None:
        s = 1
    if e is None:
        e = 1

    N_T_vals = 5

    out_path = get_file_path('rho_p_', comps, e, s)
    outdir = os.path.dirname(out_path)

    eos = saftvrmie.saftvrmie()
    eos.init(comps)
    sigma, eps_div_k = _rho_p_reducing_parameters(eos, comps, e, s)
    eps = eps_div_k * Boltzmann

    print('Finished setting up EOS')

    df = pd.read_excel('data/VLE_data.xlsx', sheet_name='pvT', skiprows=5)
    liq_df = df[(df['region'] == 'liquid') & (df[' (0 = outlier; \n1 = confirmed)'] == 1)]
    T_val_counts = liq_df['T'].value_counts()  # sorted by count, descending
    T_vals = T_val_counts.index
    # Never ask for more temperatures than the data contains.
    N_T_vals = min(N_T_vals, len(T_vals))
    # Fewest points among the temperatures actually used (the first N_T_vals).
    min_points = T_val_counts.iloc[:N_T_vals].min()

    print('Finished fetching data')
    print('Lowest number of data points being used is :', min_points)

    out_dict = {}
    for i in range(N_T_vals):
        red_T = T_vals[i]
        at_T = liq_df[liq_df['T'] == red_T]
        red_rho = at_T['ρ']
        p_data = at_T['p'] * eps / (sigma**3)

        T = red_T * eps_div_k
        rho = red_rho / (sigma**3)
        # Evaluate at a fixed volume V = 1, choosing mole numbers to match rho.
        V = 1
        N = rho * V / Avogadro

        pred_p = np.array([eos.pressure_tv(T, V, [xi * Ni for xi in composition])[0] for Ni in N])

        sorted_inds = np.argsort(red_rho.tolist())
        sorted_p = pred_p[sorted_inds]
        sorted_p_data = np.array(p_data.tolist())[sorted_inds]
        sorted_rho = np.sort(rho.tolist())

        out_dict[str(red_T)+',rho'] = (sigma ** 3) * sorted_rho
        out_dict[str(red_T)+',p_data'] = sorted_p_data * (sigma**3) / eps
        out_dict[str(red_T)+',p_calc'] = (sigma**3) * sorted_p / eps

    # Wrap in Series so columns of different lengths are padded with NaN.
    out_df = pd.DataFrame({k: pd.Series(v) for k, v in out_dict.items()})

    out_df.to_csv(out_path)
    print('Saved run to :', out_path)

    with open(outdir+'/meta.txt', 'a') as file:
        file.write(out_path+' time : '+datetime.now().strftime("%d/%m/%Y %H:%M:%S")+
                   '\ncomps = '+str(comps)+', x = '+str(composition) +
                   ', e = '+str(e) + ', s = '+str(s)+ ', N_T_vals = '+str(N_T_vals)+
                   ', minimum data points = ' + str(min_points)+
                   '\n\n')

    print('Updated meta file')

if __name__ == '__main__':
    # The module-level functions look `saftvrmie` up as a global, so this
    # import must run before any of them is called.
    from pyctp import saftvrmie
    # Close the import progress bar and draw the matching ruler beneath it.
    print('#|')
    print('-' * (n_imports + 1))

    binary_mixture(e=0.5)
    exit(0)

    # Unreachable while the exit(0) above is active: an epsilon sweep of the
    # argon phase envelope, kept so runs can be switched by hand.
    p_c = 1e5
    for e in (1, 1.3, 1.6, 1.9, 2.2):
        print(e)
        phase_envelope('AR', e=e)
