'''
Author: Vegard G. Jervell
Purpose: Plot simulated data produced by the KempersXX models and demonstrate usage of the models
Requires: Matplotlib, Pandas, Numpy, ThermoPack
Notes: Assumes that this file is placed in a directory named 'plotting', that data files are
        in a directory named 'data' in the project root directory, and that the models are
        in a directory named 'models', also in the project root directory.
'''

from models.kempers01 import Kempers01
from models.kempers89 import Kempers89
from models.modKempers89_ import Mod_Kempers89
from pycThermopack.pyctp import cubic, pcsaft, extended_csp, cpa
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import os

# Project root directory: the parent of the directory containing this file. A trailing
# separator is kept so sub-paths can be appended by plain string concatenation below.
# (str.strip('plotting') removes any of the characters p, l, o, t, i, n, g from both
# ends rather than the suffix 'plotting', so it only worked for that exact folder name.)
root_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '')
data_path = root_path + 'data/'

# Styles are defined globally so that each equation of state (EoS) is drawn
# the same way in every figure.
lines_dict = {'VdW' : '-', 'SRK' : '--', 'PR' : '-.', 'PT' : ':',
              'SW' : (0, (1, 2, 1, 2, 4, 2)), 'PC-SAFT' : '-', 'SPUNG' : '-', 'CPA' : '-'}

marker_dict = {'VdW' : '', 'SRK' : '', 'PR' : '', 'PT' : '', 'SW' : '',
               'PC-SAFT' : 'v', 'SPUNG' : 'o', 'CPA' : 's'}

# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap is the supported replacement and returns the same colormap.
eos_cmap = plt.get_cmap('viridis')
# NOTE: 'SPUNG' and 'CPA' share the last colour; they are distinguished by marker.
color_dict = {'VdW' : eos_cmap(0), 'SRK' : eos_cmap(1/6), 'PR' : eos_cmap(2/6),
              'PT' : eos_cmap(3/6), 'SW' : eos_cmap(4/6),
              'PC-SAFT' : eos_cmap(5/6), 'SPUNG' : eos_cmap(6/6), 'CPA' : eos_cmap(6/6)}

class DataPlotter:
    """Plot Soret coefficients from a KempersXX model together with experimental data.

    Each public plotting method reads its data file(s) from ``data_path``,
    evaluates ``self.KempersXX`` with one or more ThermoPack equations of
    state (EoS) and saves the resulting figure below ``self.plots_path``.

    Args:
        model: Model identifier: 'K89' (Kempers89), 'M-K89' (Mod_Kempers89)
            or 'K01' (Kempers01).
        mode: Forwarded as ``mode=`` to the model's ``get_soret_comp`` /
            ``get_soret_temp`` and used as a sub-directory of the output path.
        points: Number of points used for every model curve.
        no_h0: Not used by this class; accepted for backwards compatibility.
        pred: Draw model predictions when True.
        exp: Draw experimental data points when True.

    Raises:
        ValueError: If ``model`` is not one of the identifiers above.
    """

    def __init__(self, model, mode='cov', points=50, no_h0=False, pred=True, exp=True):
        self.points = points
        self.mode = mode
        self.model_name = model
        self.pred = pred
        self.exp = exp
        if model == 'K89':
            self.plots_path = root_path + 'plots/kempers89/' + mode + '/'
            self.KempersXX = Kempers89
        elif model == 'M-K89':
            self.plots_path = root_path + 'plots/kempers89/' + 'no_h0' + '/' + mode + '/'
            self.KempersXX = Mod_Kempers89
        elif model == 'K01':
            self.plots_path = root_path + 'plots/kempers01/' + mode + '/'
            self.KempersXX = Kempers01
        else:
            raise ValueError("'model' must be either 'K89', 'M-K89' or 'K01'")

    @staticmethod
    def _save_legend(save_path, eos_keys, labels=None):
        """Save a stand-alone legend figure: one entry per EoS key plus 'Experimental'.

        Args:
            save_path: Full output file path (including extension).
            eos_keys: Keys into ``lines_dict`` / ``marker_dict`` / ``color_dict``.
            labels: Optional mapping from EoS key to the legend text shown for it;
                keys not in the mapping are shown as-is.
        """
        labels = labels or {}
        fig = plt.figure()
        for key in eos_keys:
            # Dummy zero-length lines only exist to create legend handles.
            plt.plot(0, 0, linestyle=lines_dict[key], marker=marker_dict[key],
                     color=color_dict[key], label=labels.get(key, key))
        plt.scatter(0, 0, color='black', marker='x', label='Experimental')
        plt.legend(ncol=3)
        plt.savefig(save_path, dpi=600)
        print('Saved', save_path)
        plt.close(fig)

    def etoh_h2o_T(self, eos, eos_name, cpa_flag=False, ylim=None):
        """Plot the Soret coefficient of ethanol/water against temperature.

        One model curve and one set of experimental points is drawn for every
        ethanol weight fraction (column 'c') in ``ethanol_water.csv``.

        Args:
            eos: Initialised ThermoPack EoS object handed to the model.
            eos_name: Name used in the plot title and the output file name.
            cpa_flag: If True, append '_CPA' to the file name and ' CPA' to the title.
            ylim: Optional ``(ymin, ymax)``; also appends '_ylim' to the file name.
        """
        save_path = self.plots_path + 'etoh_h2o/ethanol_water_temp_' + eos_name
        if cpa_flag is True:
            save_path += '_CPA'
            eos_name += ' CPA'
        # Column 'c' is the ethanol weight fraction; the other column headers are
        # temperatures in degrees Celsius.
        data = pd.read_csv(data_path + 'ethanol_water.csv', na_values='NaN')
        temp_list = [int(x) for x in data.columns[1:]]
        cons_list = data['c']

        comps = 'ETOH,H2O'
        temp_ax = np.linspace(min(temp_list), max(temp_list), self.points) + 273

        # Experimental data are given as weight fraction ethanol, so they must be
        # converted to mole fractions before being passed to the model.
        M_h2o = 18.02
        M_etoh = 46.07

        fig = plt.figure(figsize=(10, 5))
        ax = plt.subplot(111)
        ax.set_position([0.08, 0.11, 0.75, 0.77])
        cmap = plt.get_cmap('viridis')

        # Row i holds the measured s_T at every temperature for cons_list[i].
        ST_matr = np.array([[x for x in data[str(T)]] for T in temp_list]).transpose()

        max_c = max(cons_list)
        for ST, c in zip(ST_matr, cons_list):
            color = cmap(c / max_c)
            if self.pred is True:
                x_etoh = (c / M_etoh) / ((c / M_etoh) + ((1 - c) / M_h2o))
                x_h2o = 1 - x_etoh
                model = self.KempersXX(comps, eos, x=[x_etoh, x_h2o])
                plt.plot(temp_ax - 273, model.get_soret_temp(temp_ax, mode=self.mode)[0] * 1e3,
                         color=color)
            if self.exp is True:
                plt.scatter(temp_list, ST, color=color, label=str(c), marker='x')

        plt.xlabel(r'T [$^{\circ}$C]', fontsize=14)
        plt.ylabel(r'$s_T$ [mK$^{-1}$]', fontsize=14)

        if ylim is not None:
            plt.ylim(ylim[0], ylim[1])
            save_path += '_ylim'

        legend = plt.legend(title=r'$\omega_{EtOH}$ [–]', bbox_to_anchor=(1.01, 1.015), loc='upper left')
        plt.setp(legend.get_title(), fontsize=14)

        plt.sca(ax)
        plt.title(eos_name, fontsize=14)

        plt.savefig(save_path, dpi=600)
        print('Saved', save_path + '.png')
        plt.close(fig)

    def etoh_h2o_c(self, eos, eos_name, cpa_flag=False, ylim=None):
        """Plot the Soret coefficient of ethanol/water against ethanol weight fraction.

        One model curve (index 1 of the model's ``get_soret_comp`` result) and one
        set of experimental points is drawn for every temperature column of
        ``ethanol_water.csv``.

        Args:
            eos: Initialised ThermoPack EoS object handed to the model.
            eos_name: Name used in the plot title and the output file name.
            cpa_flag: If True, append '_CPA' to the file name and ' CPA' to the title.
            ylim: Optional ``(ymin, ymax)``; also appends '_ylim' to the file name.
        """
        save_path = self.plots_path + 'etoh_h2o/ethanol_water_cons_' + eos_name
        if cpa_flag is True:
            save_path += '_CPA'
            eos_name += ' CPA'
        data = pd.read_csv(data_path + 'ethanol_water.csv', na_values='NaN')
        temp_list = [int(x) for x in data.columns[1:]]
        cons_list = data['c']

        comps = 'ETOH,H2O'

        M_h2o = 18.02
        M_etoh = 46.07

        # The x axis is in weight fraction ethanol; the model needs mole fractions.
        w_etoh_ax = np.linspace(min(cons_list), max(cons_list), self.points)
        w_h2o_ax = 1 - w_etoh_ax
        x_etoh_ax = (w_etoh_ax / M_etoh) / ((w_etoh_ax / M_etoh) + (w_h2o_ax / M_h2o))

        cmap = plt.get_cmap('cool')
        max_T = max(temp_list)

        fig = plt.figure(figsize=(10, 5))
        ax = plt.subplot(111)
        ax.set_position([0.1, 0.11, 0.75, 0.77])

        for T in temp_list:
            print(T)
            color = cmap(T / max_T)
            model = self.KempersXX(comps, eos, temp=T + 273)
            if self.pred is True:
                plt.plot(w_etoh_ax, model.get_soret_comp(x_etoh_ax, mode=self.mode)[1] * 1e3,
                         color=color)

            if self.exp is True:
                plt.scatter(cons_list, data[str(T)], color=color, label=str(T), marker='x')

        plt.xlabel(r'$\omega_{{EtOH}}$ [–]', fontsize=14)
        plt.ylabel(r'$s_T$ [mK$^{-1}$]', fontsize=14)

        if ylim is not None:
            plt.ylim(ylim[0], ylim[1])
            save_path += '_ylim'

        legend = plt.legend(title=r'T [$^{\circ}$C]', bbox_to_anchor=(1.015, 1.015), loc='upper left', fontsize=14)
        plt.setp(legend.get_title(), fontsize=14)

        plt.title(eos_name)
        plt.savefig(save_path, dpi=600)
        print('Saved', save_path + '.png')
        plt.close(fig)

    def toluene_n_hexane_c(self, eos, eos_name):
        """Plot the Soret coefficient of toluene in n-hexane against toluene mole fraction.

        One curve / point set is drawn per temperature column (degrees Celsius)
        of ``Toluene_n_hexane.xlsx``.

        Args:
            eos: Initialised ThermoPack EoS object handed to the model.
            eos_name: Name used in the plot title and the output file name.
        """
        plotname = 'toluene_n_hexane_cons'
        save_path = self.plots_path + 'toluene_hexane/' + plotname + '_' + eos_name
        data = pd.read_excel(data_path + 'Toluene_n_hexane.xlsx')
        temp_list = [float(x) for x in data.columns[1:]]
        cons_list = data['x_toluene:']

        comps = 'TOLU,NC6'
        tolu_ax = np.linspace(min(cons_list), max(cons_list), self.points)

        cmap = plt.get_cmap('cool')
        max_T = max(temp_list)

        fig = plt.figure()

        for T in temp_list:
            color = cmap(T / max_T)
            if self.pred is True:
                model = self.KempersXX(comps, eos, temp=T + 273)
                soret = model.get_soret_comp(tolu_ax, mode=self.mode)[0] * 1e3
                plt.plot(tolu_ax, soret, color=color)
            if self.exp is True:
                plt.scatter(cons_list, data[T], color=color, label=str(T), marker='x')

        plt.xlabel(r'$x_{{Toluene}}$ [–]', fontsize=14)
        plt.ylabel(r'$s_T$ [mK$^{-1}$]', fontsize=14)

        plt.legend(title=r'T [$^{\circ}$C]', bbox_to_anchor=(1.015, 1.015), loc='upper left')
        plt.title('Soret coefficient of toluene in n-hexane, computed with\n'
                  + self.model_name + ' using ' + eos_name + ' equation of state')

        plt.savefig(save_path, dpi=600)
        print('Saved ', save_path + '.png')
        plt.close(fig)

    def toluene_n_hexane_T(self, eos, eos_name):
        """Plot the Soret coefficient of toluene in n-hexane against temperature.

        One curve / point set is drawn per toluene mole fraction listed in the
        'x_toluene:' column of ``Toluene_n_hexane.xlsx``.

        Args:
            eos: Initialised ThermoPack EoS object handed to the model.
            eos_name: Name used in the plot title and the output file name.
        """
        plotname = 'toluene_n_hexane_temp'
        save_path = self.plots_path + 'toluene_hexane/' + plotname + '_' + eos_name

        data = pd.read_excel(data_path + 'Toluene_n_hexane.xlsx')

        temp_list = [float(x) for x in data.columns[1:]]
        cons_list = data['x_toluene:'].to_list()

        comps = 'TOLU,NC6'
        temps = np.linspace(min(temp_list), max(temp_list), self.points) + 273

        cmap = plt.get_cmap('viridis')
        max_c = max(cons_list)
        fig = plt.figure()

        for c in cons_list:
            color = cmap(c / max_c)
            if self.pred is True:
                model = self.KempersXX(comps, eos, x=np.array([c, 1 - c]))
                soret = model.get_soret_temp(temps, mode=self.mode)[0] * 1e3
                plt.plot(temps - 273, soret, color=color)
            if self.exp is True:
                # Row of measurements for this composition, minus the composition column.
                measured = data.loc[data['x_toluene:'] == c].values.flatten()[1:]
                plt.scatter(temp_list, measured, color=color, label=str(c), marker='x')

        plt.xlabel(r'$T$ [$^{\circ}$C]', fontsize=14)
        plt.ylabel(r'$s_T$ [mK$^{-1}$]', fontsize=14)

        plt.legend(title='$x_{Toluene}$ [-]', bbox_to_anchor=(1.015, 1.015), loc='upper left')
        plt.title('Soret coefficient of toluene in n-hexane computed with\n' + self.model_name
                  + ' using ' + eos_name + ' equation of state')
        plt.savefig(save_path, dpi=600)
        print('Saved ', save_path + '.png')
        plt.close(fig)

    def data_298(self, filename, comps, compname):
        """Compare all EoS predictions with data at 298 K for one binary mixture.

        Reads ``data/298_files/<filename>_298.xlsx`` (columns 'c' and 'ST'),
        saves the plot and a separate legend figure under ``298_files/``.

        Args:
            filename: Base file name, e.g. 'benzene_hexane'; the part after the
                last '_' is used as the solvent name in the title.
            comps: ThermoPack component string, e.g. 'BENZENE,NC6'.
            compname: Display name of the first component.
        """
        data = pd.read_excel(data_path + '298_files/' + filename + '_298.xlsx')
        x_data = data['c']
        ST_data = data['ST']

        x_axis = np.linspace(min(x_data), max(x_data), self.points)

        eos_list = ['VdW', 'SRK', 'PR', 'PT', 'SW', 'PC-SAFT', 'SPUNG']

        eos = cubic.cubic()
        eos.init(comps, 'VdW')
        model = self.KempersXX(comps, eos, temp=298, pres=1e5)

        fig, ax = plt.subplots()
        if self.pred is True:
            # Cubic EoS: reuse one model and only swap the EoS.
            for eos_key in eos_list[:-2]:
                eos = cubic.cubic()
                eos.init(comps, eos_key)
                model.set_eos(eos)
                plt.plot(x_axis, model.get_soret_comp(x_axis, mode=self.mode)[0] * 1e3,
                         linestyle=lines_dict[eos_key], label=eos_key, color=color_dict[eos_key])

            eos = pcsaft.pcsaft()
            eos.init(comps)
            model = self.KempersXX(comps, eos, temp=298)
            plt.plot(x_axis, model.get_soret_comp(x_axis, mode=self.mode)[0] * 1e3,
                     linestyle=lines_dict['PC-SAFT'], marker=marker_dict['PC-SAFT'],
                     label=eos_list[-2], color=color_dict['PC-SAFT'], markevery=5)

            eos = extended_csp.ext_csp()
            eos.init(comps, 'SRK', 'Classic', 'vdW', 'NIST_MEOS', 'C3')
            model = self.KempersXX(comps, eos, temp=298)
            plt.plot(x_axis, model.get_soret_comp(x_axis, mode=self.mode)[0] * 1e3,
                     linestyle=lines_dict['SPUNG'], marker=marker_dict['SPUNG'],
                     label=eos_list[-1], color=color_dict['SPUNG'], markevery=5)

        if self.exp is True:
            plt.scatter(x_data, ST_data, marker='x', color='black', label='Experimental')

        plt.xlabel('$x_{' + compname + '}$[-]', fontsize=14)
        ax.set_ylabel('$S_T$ [mK$^{-1}$]', fontsize=14)
        plt.title('Soret coefficient of ' + compname + ' in ' + filename.split('_')[-1] +
                  ' at 298K using ' + self.model_name)

        save_path = self.plots_path + '298_files/' + filename + '_298'
        # Suffix only when one of the two layers is switched off.
        if self.pred and not self.exp:
            save_path += '_pred'
        elif self.exp and not self.pred:
            save_path += '_exp'

        plt.savefig(save_path, dpi=600)
        print('Saved', save_path + '.png')
        plt.close(fig)

        self._save_legend(self.plots_path + '298_files/legend.png', eos_list)

    def plot_298(self):
        """Run ``data_298`` for benzene and toluene in n-hexane and n-heptane."""
        solutes = [('benzene', 'BENZENE', 'benzene'), ('toluene', 'TOLU', 'toluene')]
        solvents = [('_hexane', ',NC6'), ('_heptane', ',NC7')]

        for solute_name, comp_code, compname in solutes:
            for solvent_name, solvent_code in solvents:
                self.data_298(solute_name + solvent_name, comp_code + solvent_code, compname)

    def n_alkanes(self):
        """Compare all EoS predictions with 298 K data for several n-alkane mixtures.

        For each mixture, reads ``data/n-alkanes/<C1>_<C2>_298K.xlsx`` (columns
        'c' and 'ST') and saves a plot under ``n_alkanes/``; a shared legend
        figure is saved once at the end.
        """
        mixtures = ['NC10,NC5', 'NC12,NC6', 'NC12,NC7', 'NC12,NC8']

        eos_list = ['VdW', 'SRK', 'PR', 'PT', 'SW', 'PC-SAFT', 'SPUNG']

        for comps in mixtures:
            comp1, comp2 = comps.split(',')

            save_path = self.plots_path + 'n_alkanes/' + comp1 + '_' + comp2

            data = pd.read_excel(data_path + 'n-alkanes/' + comp1 + '_' + comp2 + '_298K.xlsx')
            x1_list = data['c']
            soret_list = data['ST']

            x_axis = np.linspace(min(x1_list), max(x1_list), self.points)

            fig, ax = plt.subplots()
            if self.pred:
                for eos_key in eos_list[:-2]:
                    print(eos_key, comps)
                    eos = cubic.cubic()
                    eos.init(comps, eos_key)
                    model = self.KempersXX(comps, eos, temp=298)
                    plt.plot(x_axis, model.get_soret_comp(x_axis, mode=self.mode)[1] * 1e3,
                             linestyle=lines_dict[eos_key], label=eos_key,
                             color=color_dict[eos_key])

                eos = pcsaft.pcsaft()
                eos.init(comps)
                model = self.KempersXX(comps, eos, temp=298)
                plt.plot(x_axis, model.get_soret_comp(x_axis, mode=self.mode)[1] * 1e3,
                         linestyle=lines_dict['PC-SAFT'], marker=marker_dict['PC-SAFT'],
                         label=eos_list[-2], color=color_dict['PC-SAFT'], markevery=5)

                eos = extended_csp.ext_csp()
                eos.init(comps, 'SRK', 'Classic', 'vdW', 'NIST_MEOS', 'C3')
                model = self.KempersXX(comps, eos, temp=298)
                plt.plot(x_axis, model.get_soret_comp(x_axis, mode=self.mode)[1] * 1e3,
                         linestyle=lines_dict['SPUNG'], marker=marker_dict['SPUNG'],
                         label=eos_list[-1], color=color_dict['SPUNG'], markevery=5)

            if self.exp:
                plt.scatter(x1_list, soret_list, marker='x', color='black', label='Experimental')

            if self.pred and not self.exp:
                save_path += '_pred'
            elif self.exp and not self.pred:
                save_path += '_exp'
                plt.yticks(fontsize=15)
                plt.xticks(fontsize=15)

            plt.xlabel(r'$x_{' + comp1 + '}$ [-]', fontsize=14)
            plt.ylabel(r'$S_T$ [mK$^{-1}$]', fontsize=14)
            plt.title(comps, fontsize=16)
            plt.savefig(save_path, dpi=600)
            print('Saved', save_path + '.png')
            plt.close(fig)

        # The legend does not depend on the mixture, so save it once.
        self._save_legend(self.plots_path + 'n_alkanes/legend.png', eos_list)

    def cold_gases(self):
        """Compare cubic EoS and SPUNG predictions with argon/methane data at 88 K.

        Reads ``data/cold_gases/AR_C1_88K.xlsx`` (columns 'c' and 'ST') and saves
        the plot plus a legend figure under ``cold_gases/``.
        """
        comps = 'AR,C1'
        T = 88

        save_path = self.plots_path + 'cold_gases/AR_C1_' + str(T) + 'K'
        data = pd.read_excel(data_path + 'cold_gases/AR_C1_' + str(T) + 'K.xlsx')
        x_list = data['c']
        soret_list = data['ST']

        x_axis = np.linspace(min(x_list), max(x_list), self.points)

        eos_list = ['VdW', 'SRK', 'PR', 'PT', 'SW']

        fig = plt.figure()

        if self.pred is True:
            for eos_key in eos_list:
                print(eos_key)
                eos = cubic.cubic()
                eos.init(comps, eos_key)
                model = self.KempersXX(comps, eos, temp=T)
                plt.plot(x_axis, model.get_soret_comp(x_axis, mode=self.mode)[1] * 1e3,
                         color=color_dict[eos_key], linestyle=lines_dict[eos_key], label=eos_key)

            # Evaluate SPUNG at the same temperature as the data and the cubic curves.
            eos = extended_csp.ext_csp()
            eos.init(comps, 'SRK', 'Classic', 'vdW', 'NIST_MEOS', 'C3')
            model = self.KempersXX(comps, eos, temp=T)
            plt.plot(x_axis, model.get_soret_comp(x_axis, mode=self.mode)[1] * 1e3,
                     linestyle=lines_dict['SPUNG'], marker=marker_dict['SPUNG'],
                     label='SPUNG', color=color_dict['SPUNG'], markevery=5)

        if self.exp is True:
            plt.scatter(x_list, soret_list, marker='x', color='black', label='Experimental')

        plt.legend()
        plt.xlabel(r'$x_{AR}$ [-]', fontsize=14)
        plt.ylabel(r'$S_T$ [mK$^{-1}$]', fontsize=14)
        plt.title('Soret coefficient of ' + comps + ' at ' + str(T) + 'K, computed with ' + self.model_name)

        plt.savefig(save_path, dpi=600)
        print('Saved', save_path + '.png')
        plt.close(fig)

        self._save_legend(self.plots_path + 'cold_gases/legend.png', eos_list + ['SPUNG'])

    def propanol_h2o(self):
        """Compare cubic EoS and SRK-CPA predictions with data for 1-propanol/water.

        Reads ``data/Isopropanol.xlsx`` (columns 'c' and 'ST') and saves the plot
        plus a legend figure directly under ``self.plots_path``.
        """
        save_path = self.plots_path + 'propanol_h2o'

        data = pd.read_excel(data_path + 'Isopropanol.xlsx')
        x_list = data['c']
        soret_list = data['ST']

        x_axis = np.linspace(min(x_list), max(x_list), self.points)

        comps = 'PROP1OL,H2O'

        fig = plt.figure()

        eos_keys = ['VdW', 'SRK', 'PR', 'PT', 'SW']
        if self.pred is True:
            for key in eos_keys:
                eos = cubic.cubic()
                eos.init(comps, key)
                model = self.KempersXX(comps, eos)
                plt.plot(x_axis, model.get_soret_comp(x_axis, mode=self.mode)[0] * 1e3,
                         linestyle=lines_dict[key], color=color_dict[key], label=key)

            eos = cpa.cpa()
            eos.init(comps, 'SRK')
            model = self.KempersXX(comps, eos)
            plt.plot(x_axis, model.get_soret_comp(x_axis, mode=self.mode)[0] * 1e3,
                     linestyle=lines_dict['CPA'], color=color_dict['CPA'],
                     marker=marker_dict['CPA'], label='SRK-CPA', markevery=3)

        if self.exp is True:
            plt.scatter(x_list, soret_list, marker='x', color='black', label='Experimental')

        plt.ylim(-50, 50)
        plt.xlabel(r'$x_{H_2O}$ [-]')
        plt.ylabel(r'$S_T$ [mK$^{-1}$]')

        plt.legend()

        plt.savefig(save_path, dpi=600)
        print('Saved', save_path + '.png')
        plt.close(fig)

        # Include the CPA curve in the legend figure, labelled as in the main plot.
        self._save_legend(self.plots_path + 'propanol_legend.png', eos_keys + ['CPA'],
                          labels={'CPA': 'SRK-CPA'})