from models.kineticgas import KineticGas
from pycThermopack.pyctp import cubic
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
import time
import scipy.optimize as opt
import matplotlib.cm as cm

# Absolute project root, hard-coded for one specific machine; edit before
# running elsewhere.
root_path = '/home/ubuntu/Home/Documents/7_semester/irrev_prosjekt/'
# Data directory under the project root (not referenced by the functions in
# this part of the file).
data_path = root_path + 'data/'
# Output directory prefix used when building figure save paths.
plots_path = root_path + 'plots/alpha_T0/'

def fit_func(n, a, b):
    '''
    Model function exp(a * n^b) - exp(a), used as the curve-fit target for runtime data.

    The subtracted term makes the function evaluate to zero at n = 1.
    :param n (float or array): Independent variable
    :param a (float): Scale inside the exponential
    :param b (float): Power applied to n
    :return: exp(a * n**b) - exp(a), same shape as n
    '''
    exponent = a * n ** b
    offset = np.exp(a)
    return np.exp(exponent) - offset

def plot_convergence(comps):
    '''
    Plot d_-1, d_0 and d_1 and the model construction time against the order of approximation.

    Builds a KineticGas model for every order from 1 to 10 (SRK cubic equation of state,
    mole fractions [0.2, 0.8]) and draws four stacked panels sharing one x-axis: one per
    coefficient and one for the CPU time spent constructing the model.
    The figure is written to 'convergence2' (relative to the working directory) at 600 dpi.
    :param comps (str): Component identifier string, passed on to the equation of state and KineticGas
    '''
    orders = np.arange(1, 11)
    coeffs = np.zeros((len(orders), 3))
    runtimes = np.zeros(len(orders))

    fig = plt.figure(figsize=(8, 9))
    layout = gs.GridSpec(ncols=1, nrows=4, figure=fig, hspace=0.1)

    # The first panel owns the x-axis; the remaining three share it.
    top_ax = fig.add_subplot(layout[0])
    axs = [top_ax]
    for row in range(1, 4):
        axs.append(fig.add_subplot(layout[row], sharex=top_ax))

    # Only the bottom panel keeps its x tick labels.
    for ax in axs[:-1]:
        plt.setp(ax.get_xticklabels(), visible=False)

    eos = cubic.cubic()
    eos.init(comps, 'SRK')
    for idx, order in enumerate(orders):
        print('Computing', order, 'order approximation')
        # Only the KineticGas construction is inside the timed span.
        start = time.process_time()
        model = KineticGas(comps, eos, mole_fracs=[0.2, 0.8], N=order)
        runtimes[idx] = time.process_time() - start
        coeffs[idx] = [model.d_1, model.d0, model.d1]

    y_labels = (r'$d_{-1}$', r'$d_0$', r'$d_1$')
    # Transposing gives one row per coefficient; zip stops after the three coefficient panels.
    for ax, series, y_label in zip(axs, coeffs.transpose(), y_labels):
        ax.plot(orders, series, color='black')
        ax.set_ylabel(y_label, fontsize=14)

    time_ax = axs[-1]
    time_ax.plot(orders, runtimes, color='black')
    time_ax.set_ylabel('Time [s]', fontsize=14)
    time_ax.set_xlabel('Order of approximation', fontsize=14)

    fig.savefig('convergence2', dpi=600)

def plot_time_fit(comps):
    '''
    Time the KineticGas construction for orders 15 to 21 and fit the runtimes.

    Two fits are made: fit_func via scipy's curve_fit (initial guess a = b = 1) and a
    5th-degree polynomial. Measured times (solid black), the fit_func curve (dashed black)
    and the polynomial (dotted red) are drawn on the current axes, the fitted fit_func
    parameters are printed, and the current figure is saved as 'runtime'.
    :param comps (str): Component identifier string, passed on to the equation of state and KineticGas
    '''
    orders = np.arange(15, 22)
    runtimes = np.zeros(len(orders))

    eos = cubic.cubic()
    eos.init(comps, 'SRK')
    for idx, order in enumerate(orders):
        print('Computing', order, 'order approximation')
        start = time.process_time()
        # The model is built only to be timed; nothing is read from it.
        _model = KineticGas(comps, eos, mole_fracs=[0.2, 0.8], N=order)
        runtimes[idx] = time.process_time() - start

    print('Fitting')
    exp_params, _ = opt.curve_fit(fit_func, orders, runtimes, p0=(1, 1))
    poly_coeffs = np.polyfit(orders, runtimes, 5)

    # Dense grid over the measured range for drawing the fitted curves.
    dense = np.linspace(orders.min(), orders.max(), 50)
    ax = plt.gca()
    ax.plot(dense, fit_func(dense, *exp_params), linestyle='--', color='black')
    ax.plot(orders, runtimes, color='black')
    ax.plot(dense, np.polyval(poly_coeffs, dense), color='red', linestyle=':')

    print('exp :', exp_params)

    plt.savefig('runtime')

def files_298():
    '''
    Unfinished stub: builds the '298_files/' output directory path under plots_path and discards it.

    Nothing is read, written or returned.
    '''
    save_path = ''.join((plots_path, '298_files/'))

def plot_test(comps, N, T=298, save=False):
    '''
    Plot the KineticGas `soret` attribute times T against the first mole fraction for an N'th order approximation.

    The first mole fraction runs over 50 evenly spaced values from 0.1 to 0.9; the second is
    its complement. The curve is added to the current axes with N as its label, so repeated
    calls with different N accumulate on the same plot.
    NOTE(review): T only scales the result and is not passed to KineticGas - confirm the model
    is evaluated at the same temperature.
    :param comps (str): Two component identifiers separated by a comma
    :param N (int): Order of approximation
    :param T (float): Factor multiplied with the `soret` attribute (defaults to 298)
    :param save (bool): If True, add a legend and save the figure to plots_path + '<comp1>_<comp2>' at 600 dpi
    '''
    # Unpacking raises ValueError unless exactly two components are given.
    first, second = comps.split(',')
    x_list = np.linspace(0.1, 0.9, 50)

    eos = cubic.cubic()
    eos.init(comps, 'SRK')

    scaled_soret = np.zeros(len(x_list))
    for idx, x1 in enumerate(x_list):
        model = KineticGas(comps, eos, mole_fracs=[x1, 1 - x1], N=N)
        scaled_soret[idx] = model.soret * T

    plt.plot(x_list, scaled_soret, label=N)
    print('Plotted for :', comps, N)

    if not save:
        return
    plt.legend(title = 'N')
    plt.savefig(plots_path + first + '_' + second, dpi=600)

