from models.kineticgas_old import KineticGas
from pycThermopack.pyctp import cubic
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
import time
import scipy.optimize as opt
import matplotlib.cm as cm

# Project root; the data and plot directories below are sub-folders of it.
root_path = '/home/ubuntu/Home/Documents/7_semester/irrev_prosjekt/'
data_path = f'{root_path}data/'
plots_path = f'{root_path}plots/alpha_T0/'

def fit_func(n, a, b):
    '''
    Model curve exp(a * n**b) - exp(a), used to fit runtime against order of approximation.

    Subtracting exp(a) makes the model equal zero at n = 1.
    :param n: Order(s) of approximation (scalar or numpy array).
    :param a: Fit parameter scaling the exponent.
    :param b: Fit parameter raising n inside the exponent.
    :return: Model value(s) with the same shape as n.
    '''
    offset = np.exp(a)
    growth = np.exp(a * n ** b)
    return growth - offset

def plot_convergence(comps, n_max=10, filename='convergence2'):
    '''
    Plot d_{-1}, d_0, d_1 and the computation time against the order of approximation.

    For each order N = 1, ..., n_max, a KineticGas model is built with an SRK equation of state
    at mole fractions [0.2, 0.8]. Its d_1, d0 and d1 attributes are recorded, together with the
    CPU time spent constructing it. The four quantities are drawn in a column of panels that
    share the x-axis. The figure is then saved at 600 dpi.

    :param comps (str): Comma-separated component names passed to the EoS and to KineticGas.
    :param n_max (int): Highest order of approximation to compute (default 10).
    :param filename (str): Path passed to savefig (default 'convergence2', i.e. relative to the
                           current working directory).
    :return: The matplotlib Figure, so the caller can show or close it.
    '''
    n_list = np.arange(1, n_max + 1)
    d_list = np.zeros((len(n_list), 3))
    time_list = np.zeros(len(n_list))

    eos = cubic.cubic()
    eos.init(comps, 'SRK')
    for i, n in enumerate(n_list):
        print('Computing', n, 'order approximation')
        t0 = time.process_time()
        analytical = KineticGas(comps, eos, mole_fracs=[0.2, 0.8], N=n)
        # process_time measures CPU time of this process, not wall-clock time.
        time_list[i] = time.process_time() - t0
        d_list[i] = [analytical.d_1, analytical.d0, analytical.d1]

    # sharex=True makes matplotlib show x tick labels only on the bottom panel.
    fig, axs = plt.subplots(4, 1, figsize=(8, 9), sharex=True, gridspec_kw={'hspace': 0.1})

    label_list = [r'$d_{-1}$', r'$d_0$', r'$d_1$']
    # Rows of d_list.T are the d_1, d0, d1 series respectively.
    for ax, d_values, label in zip(axs[:3], d_list.T, label_list):
        ax.plot(n_list, d_values, color='black')
        ax.set_ylabel(label, fontsize=14)

    time_ax = axs[-1]
    time_ax.plot(n_list, time_list, color='black')
    time_ax.set_ylabel('Time [s]', fontsize=14)
    time_ax.set_xlabel('Order of approximation', fontsize=14)

    fig.savefig(filename, dpi=600)
    return fig

def plot_time_fit(comps, n_min=15, n_max=21, poly_deg=5, filename='runtime'):
    '''
    Time KineticGas for orders n_min..n_max and fit the runtime with two models.

    For each order, a KineticGas model (SRK EoS, mole fractions [0.2, 0.8]) is constructed and
    the CPU time spent is recorded. The timings are then fitted with fit_func via
    scipy.optimize.curve_fit (initial guess a = 1, b = 1), and with a polynomial of degree
    poly_deg. The measurements and both fits are drawn on a new figure, which is saved.

    :param comps (str): Comma-separated component names passed to the EoS and to KineticGas.
    :param n_min (int): Lowest order of approximation to time (default 15).
    :param n_max (int): Highest order of approximation to time, inclusive (default 21).
    :param poly_deg (int): Degree of the polynomial fit (default 5).
    :param filename (str): Path passed to savefig (default 'runtime').
    :return: Fitted (a, b) parameters of fit_func.
    '''
    n_list = np.arange(n_min, n_max + 1)
    time_list = np.zeros(len(n_list))

    eos = cubic.cubic()
    eos.init(comps, 'SRK')
    for i, n in enumerate(n_list):
        print('Computing', n, 'order approximation')
        t0 = time.process_time()
        KineticGas(comps, eos, mole_fracs=[0.2, 0.8], N=n)  # constructed only to be timed
        time_list[i] = time.process_time() - t0

    print('Fitting')
    fit, _ = opt.curve_fit(fit_func, n_list, time_list, p0=(1, 1))
    poly = np.polyfit(n_list, time_list, poly_deg)

    # Use a fresh figure so these curves never land on whatever axes happen to be current.
    fig, ax = plt.subplots()
    n_line = np.linspace(n_list.min(), n_list.max(), 50)
    ax.plot(n_line, fit_func(n_line, fit[0], fit[1]), linestyle='--', color='black',
            label='exponential fit')
    ax.plot(n_list, time_list, color='black', label='measured')
    ax.plot(n_line, np.polyval(poly, n_line), color='red', linestyle=':',
            label='polynomial fit')
    ax.set_xlabel('Order of approximation')
    ax.set_ylabel('Time [s]')
    ax.legend()

    print('exp :', fit)

    fig.savefig(filename)
    return fit

def files_298():
    '''
    Return the output directory plots_path + '298_files/'.

    NOTE(review): looks like a stub for T = 298 output files; nothing in view writes to this
    directory or creates it, so confirm the intended use.
    :return (str): The directory path, ending in '/'.
    '''
    return plots_path + '298_files/'

def plot_test(comps, N, T=298, save=False):
    '''
    Plot soret * T from kinetic gas theory against mole fraction for the N'th order approximation.

    The mole fraction x of the first component is swept over 50 points in [0.1, 0.9]. The
    second component has mole fraction 1 - x. The curve is drawn on the current pyplot axes
    (no new figure is created) with label N, so successive calls overlay curves for different N.

    :param comps (str): Two comma-separated component names, passed to the SRK EoS and to
                        KineticGas. They are also used to build the output filename.
    :param N (int): Order of approximation.
    :param T (float): Temperature multiplied onto the Soret coefficient (default 298).
    :param save (bool): If truthy, add a legend titled 'N' and save the figure to
                        plots_path + '<comp1>_<comp2>' at 600 dpi (default False).
    '''
    # Split first so a malformed `comps` fails before any expensive computation.
    comp1, comp2 = comps.split(',')
    x_list = np.linspace(0.1, 0.9, 50)

    eos = cubic.cubic()
    eos.init(comps, 'SRK')

    alpha_list = np.zeros(len(x_list))
    for i, x in enumerate(x_list):
        alpha_T0 = KineticGas(comps, eos, mole_fracs=[x, 1 - x], N=N)
        alpha_list[i] = alpha_T0.soret * T

    plt.plot(x_list, alpha_list, label=N)
    print('Plotted for :', comps, N)
    if save:
        save_path = plots_path + comp1 + '_' + comp2
        plt.legend(title='N')
        plt.savefig(save_path, dpi=600)

