import os, sys, platform

# Make the soret_model package and its compiled C++ bindings importable.
# NOTE: these are resolved relative to the parent of the *current working
# directory* (not this file's directory), so the script must be launched from
# a sibling directory of soret_model. Paths are appended in this order.
sys.path.extend(
    os.path.dirname(os.getcwd()) + subdir
    for subdir in ('/soret_model/cpp/release', '/soret_model', '/soret_model/cpp')
)

import matplotlib.pyplot as plt
from matplotlib import cm
import pandas as pd
import numpy as np
from datetime import datetime

ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = ROOT_PATH + '/data/benchmark'

# Plot styles are defined globally so that a given EoS looks the same in every figure.
lines_dict = {'VdW': '-', 'SRK': '--', 'PR': '-.', 'PT': ':',
              'SW': (0, (1, 2, 1, 2, 4, 2)), 'PC-SAFT': '-', 'SPUNG': '-', 'CPA': '-', 'SAFT-VR-MIE': '-'}

marker_dict = {'VdW': '', 'SRK': '', 'PR': '', 'PT': '', 'SW': '',
               'PC-SAFT': 'v', 'SPUNG': 'o', 'CPA': 's', 'SAFT-VR-MIE': '2'}

# ``matplotlib.cm.get_cmap`` is deprecated since Matplotlib 3.7 and removed in 3.9,
# so prefer the ``matplotlib.colormaps`` registry (Matplotlib >= 3.5) and only fall
# back to ``cm.get_cmap`` on older installations.
try:
    from matplotlib import colormaps as _colormaps
except ImportError:
    eos_cmap = cm.get_cmap('viridis')
else:
    eos_cmap = _colormaps['viridis']

# Colours are evenly spaced along the colormap; note that 'SPUNG' and 'CPA'
# share the top colour of the map, and 'SAFT-VR-MIE' is drawn in black.
color_dict = {'VdW': eos_cmap(0), 'SRK': eos_cmap(1 / 6), 'PR': eos_cmap(2 / 6),
              'PT': eos_cmap(3 / 6), 'SW': eos_cmap(4 / 6),
              'PC-SAFT': eos_cmap(5 / 6), 'SPUNG': eos_cmap(6 / 6), 'CPA': eos_cmap(6 / 6), 'SAFT-VR-MIE': 'black'}


class BenchmarkPlotter:
    """Plot Soret-coefficient benchmark predictions against experimental data.

    Predictions are read from ``ROOT_PATH/output/benchmark/V<version>/<mode>``,
    experimental data from ``DATA_PATH``, and figures are written to
    ``ROOT_PATH/plots/benchmark/V<version>/<mode>``.
    """

    # EoS curves drawn on the liquid-mixture plots (``data_298`` and ``n_alkanes``).
    # The last two entries are additionally drawn with markers.
    _LIQUID_EOS = ('VdW', 'SRK', 'PR', 'PT', 'SW', 'PC-SAFT', 'SPUNG')

    def __init__(self, mode='cov', version=None, pred=True, exp=True):
        """Resolve the input/output directories of a benchmark run and log the run.

        Args:
            mode: Sub-directory (inside the run directory) to read predictions
                from and write plots to.
            version: Run number used as ``V<version>``. If None, the first
                version without an existing plot directory is used and a new
                ``meta.txt`` is written to the plot directory. Otherwise plots are
                added to that version (after confirmation if the mode directory
                already exists) and its ``meta.txt`` is appended to.
            pred: Whether to draw model predictions.
            exp: Whether to draw experimental data.

        Raises:
            FileNotFoundError: If there is no prediction data for the chosen run/mode.
        """
        self.pred = pred
        self.exp = exp
        self.model_name = 'Kempers01, KineticGas'

        if version is None:
            # Use the first version number that has not been plotted yet.
            version = 0
            while os.path.isdir(ROOT_PATH + '/plots/benchmark/V' + str(version)):
                version += 1
            self._set_paths(mode, version)
            # makedirs also creates plots/benchmark if it does not exist yet.
            os.makedirs(self.outpath)
            self._write_meta(append=False)
        else:
            self._set_paths(mode, version)
            if os.path.isdir(self.outpath):
                print('This BenchmarkPlotter instance may overwrite files in', self.outpath)
                verify = input("'Y' to verify")
                if verify != 'Y':
                    sys.exit(0)
            else:
                # The version directory itself may not exist yet.
                os.makedirs(self.outpath)
            self._write_meta(append=True)

    def _set_paths(self, mode, version):
        """Set input/output paths for run ``V<version>``; raise if input is missing."""
        tag = 'V' + str(version)
        self.outdir = ROOT_PATH + '/plots/benchmark/' + tag
        self.outpath = self.outdir + '/' + mode
        self.indir = ROOT_PATH + '/output/benchmark/' + tag
        self.inpath = self.indir + '/' + mode
        if not os.path.isdir(self.inpath):
            raise FileNotFoundError('No data matching this plot run! (' + tag + ')')

    def _write_meta(self, append):
        """Write a run header (date, copy of the input meta file) to ``<outdir>/meta.txt``.

        With ``append`` the header follows a separator line at the end of the
        existing file; otherwise the file is overwritten.
        """
        with open(self.outdir + '/meta.txt', 'a' if append else 'w') as ofile:
            if append:
                ofile.write('\n#########################\n')
            ofile.write('Date : ' + datetime.now().strftime("%d/%m/%Y %H:%M:%S") + '\n')
            with open(self.indir + '/meta.txt', 'r') as ifile:
                ofile.write('\nMeta file in input dir :\n')
                ofile.writelines(ifile)
            ofile.write('\nContains plots for :\n')

    @staticmethod
    def _get_cmap(name):
        """Return the named colormap on both old and new Matplotlib versions."""
        try:
            from matplotlib import colormaps  # Matplotlib >= 3.5
        except ImportError:
            return cm.get_cmap(name)  # removed in Matplotlib 3.9
        return colormaps[name]

    @staticmethod
    def _plot_eos_predictions(x_axis, pred_data, eos_list):
        """Plot the predicted S_T column of ``pred_data`` for each EoS in ``eos_list``.

        Values are multiplied by 1e3 to match the mK^-1 axis label (presumably
        the predictions are stored in K^-1 — confirm against the data writer).
        All but the last two EoS are plain lines; the last two also get markers
        on every fifth point.
        """
        for eos_key in eos_list[:-2]:
            plt.plot(x_axis, pred_data[eos_key] * 1e3,
                     linestyle=lines_dict[eos_key], label=eos_key, color=color_dict[eos_key])
        for eos_key in eos_list[-2:]:
            plt.plot(x_axis, pred_data[eos_key] * 1e3,
                     linestyle=lines_dict[eos_key], marker=marker_dict[eos_key],
                     label=eos_key, color=color_dict[eos_key], markevery=5)

    def update_meta(self, filename):
        """Append ``filename`` as one line to ``<outpath>/meta.txt``."""
        # NOTE(review): __init__ writes the 'Contains plots for :' header to
        # <outdir>/meta.txt, but entries are appended to <outpath>/meta.txt
        # (the mode sub-directory). Confirm which file is intended.
        with open(self.outpath + '/meta.txt', 'a') as meta:
            meta.write(filename + '\n')

    def data_298(self, filename, comps, compname):
        """Plot S_T at 298 K for one binary mixture into ``<outpath>/298_files``.

        Args:
            filename: Base name shared by the experimental ``<filename>_298.xlsx``
                and predicted ``<filename>.csv`` files; the part after the last
                '_' is used as the solvent name in the title.
            comps: Comma-separated component codes; the first one selects the
                ``x_<code>`` column used as x-axis.
            compname: Display name of the first component.
        """
        exp_data = pd.read_excel(DATA_PATH + '/298_files/' + filename + '_298.xlsx')
        x_data = exp_data['c']
        ST_data = exp_data['ST']

        pred_data = pd.read_csv(self.inpath + '/298_files/' + filename + '.csv')
        x_axis = pred_data['x_' + comps.split(',')[0]]

        fig, ax = plt.subplots()
        if self.pred:
            # Every EoS curve is scaled the same way, as in the other plots.
            self._plot_eos_predictions(x_axis, pred_data, self._LIQUID_EOS)

        if self.exp:
            plt.scatter(x_data, ST_data, marker='x', color='black', label='Experimental')

        plt.xlabel('$x_{' + compname + '}$[-]', fontsize=14)
        ax.set_ylabel('$S_T$ [mK$^{-1}$]', fontsize=14)
        plt.title('Soret coefficient of ' + compname + ' in ' + filename.split('_')[-1] +
                  ' at 298K using ' + self.model_name)

        save_path = self.outpath + '/298_files/' + filename + '_298'

        fig.legend()
        plt.savefig(save_path, dpi=600)
        print('Saved', save_path + '.png')
        self.update_meta(filename)
        plt.close(fig)

    def plot_298(self):
        """Plot S_T at 298 K for benzene/toluene in hexane/heptane (four figures)."""
        # exist_ok: the directory survives from earlier runs when re-plotting a version.
        os.makedirs(self.outpath + '/298_files', exist_ok=True)

        solutes = [('benzene', 'BENZENE'), ('toluene', 'TOLU')]
        solvents = [('hexane', 'NC6'), ('heptane', 'NC7')]

        for solute_name, solute_code in solutes:
            for solvent_name, solvent_code in solvents:
                self.data_298(solute_name + '_' + solvent_name,
                              solute_code + ',' + solvent_code,
                              solute_name)

    def n_alkanes(self):
        """Plot S_T at 298 K for four n-alkane binaries into ``<outpath>/n_alkanes``."""
        # exist_ok: the directory survives from earlier runs when re-plotting a version.
        os.makedirs(self.outpath + '/n_alkanes', exist_ok=True)
        mixtures = ['NC10,NC5', 'NC12,NC6', 'NC12,NC7', 'NC12,NC8']

        for comps in mixtures:
            comp1, comp2 = comps.split(',')
            pair = comp1 + '_' + comp2
            save_path = self.outpath + '/n_alkanes/' + pair

            exp_data = pd.read_excel(DATA_PATH + '/n-alkanes/' + pair + '_298K.xlsx')
            x1_list = exp_data['c']
            soret_list = exp_data['ST']

            pred_data = pd.read_csv(self.inpath + '/n_alkanes/' + pair + '.csv')
            # NOTE(review): the x data is x_<comp2>, but the axis label below reads
            # x_<comp1>; confirm which mole fraction is intended.
            x_axis = pred_data['x_' + comp2]

            fig, _ = plt.subplots()
            if self.pred:
                self._plot_eos_predictions(x_axis, pred_data, self._LIQUID_EOS)

            if self.exp:
                plt.scatter(x1_list, soret_list, marker='x', color='black', label='Experimental')

            # Tag figures that show only one data source.
            if self.pred and not self.exp:
                save_path += '_pred'
            elif self.exp and not self.pred:
                save_path += '_exp'
                plt.yticks(fontsize=15)
                plt.xticks(fontsize=15)

            plt.xlabel(r'$x_{' + comp1 + '}$ [-]', fontsize=14)
            plt.ylabel(r'$S_T$ [mK$^{-1}$]', fontsize=14)
            plt.title(comps, fontsize=16)
            fig.legend()
            plt.savefig(save_path, dpi=600)
            print('Saved', save_path + '.png')
            self.update_meta(pair)
            plt.close(fig)

    def cold_gases(self):
        """Plot S_T of the C1/AR mixture at 88 K to ``<outpath>/AR_C1``."""
        comps = 'C1,AR'
        T = 88

        save_path = self.outpath + '/AR_C1'
        exp_data = pd.read_excel(DATA_PATH + '/cold_gases/AR_C1_' + str(T) + 'K.xlsx')
        x_list = exp_data['c']
        soret_list = exp_data['ST']

        pred_data = pd.read_csv(self.inpath + '/AR_C1.csv')
        x_axis = pred_data['x_Ar']

        eos_list = ['VdW', 'SRK', 'PR', 'PT', 'SW', 'SPUNG', 'SAFT-VR-MIE']

        fig = plt.figure()

        if self.pred:
            for eos_key in eos_list:
                plt.plot(x_axis, pred_data[eos_key] * 1e3,
                         color=color_dict[eos_key], linestyle=lines_dict[eos_key], label=eos_key,
                         marker=marker_dict[eos_key], markevery=5)

        if self.exp:
            plt.scatter(x_list, soret_list, marker='x', color='black', label='Experimental')

        fig.legend()
        plt.xlabel(r'$x_{AR}$ [-]', fontsize=14)
        plt.ylabel(r'$S_T$ [mK$^{-1}$]', fontsize=14)
        plt.title('Soret coefficient of ' + comps + ' at ' + str(T) + 'K, computed with ' + self.model_name)

        plt.savefig(save_path, dpi=600)
        print('Saved', save_path + '.png')
        self.update_meta('AR_C1')
        plt.close(fig)

    def etoh_h2o_T(self, eos_name, ylim=None):
        """Plot ethanol/water S_T against temperature, one curve per ethanol mass fraction.

        Args:
            eos_name: EoS whose predictions (``ethanol_water/temp_<eos_name>.csv``) are plotted.
            ylim: Optional (lower, upper) y-limits; adds a ``_ylim`` suffix to the file name.
        """
        os.makedirs(self.outpath + '/ethanol_water', exist_ok=True)

        save_path = self.outpath + '/ethanol_water/temp_' + eos_name

        # Experimental table: column 'c' holds concentrations, the remaining
        # column names are temperatures.
        exp_data = pd.read_csv(DATA_PATH + '/ethanol_water.csv', na_values='NaN')
        temp_list = [int(x) for x in exp_data.columns[1:]]
        cons_list = exp_data['c']

        pred_data = pd.read_csv(self.inpath + '/ethanol_water/temp_' + eos_name + '.csv')
        temp_ax = pred_data['Temp']

        fig = plt.figure(figsize=(10, 5))
        ax = plt.subplot(111)
        ax.set_position([0.08, 0.11, 0.75, 0.77])
        cmap = self._get_cmap('viridis')

        # Row k holds the measured S_T at every temperature for concentration cons_list[k].
        ST_matr = np.array([[x for x in exp_data[str(T)]] for T in temp_list]).transpose()

        max_c = max(cons_list)
        for ST, c in zip(ST_matr, cons_list):
            color = cmap(c / max_c)
            if self.pred:
                # Subtracting 273 maps the prediction temperatures onto the
                # deg C axis (presumably 'Temp' is in K — confirm).
                plt.plot(temp_ax - 273, pred_data[str(c)] * 1e3, color=color)
            if self.exp:
                plt.scatter(temp_list, ST, color=color, label=str(c), marker='x')

        plt.xlabel(r'T [$^{\circ}$C]', fontsize=14)
        plt.ylabel(r'$s_T$ [mK$^{-1}$]', fontsize=14)

        if ylim is not None:
            plt.ylim(ylim[0], ylim[1])
            save_path += '_ylim'

        legend = plt.legend(title=r'$\omega_{EtOH}$ [–]', bbox_to_anchor=(1.01, 1.015), loc='upper left')
        plt.setp(legend.get_title(), fontsize=14)

        plt.title(eos_name, fontsize=14)

        plt.savefig(save_path, dpi=600)
        print('Saved', save_path + '.png')
        self.update_meta('EtOH_H2O_T_' + eos_name)
        plt.close(fig)

    def etoh_h2o_c(self, eos_name, ylim=None):
        """Plot ethanol/water S_T against ethanol mass fraction, one curve per temperature.

        Args:
            eos_name: EoS whose predictions (``ethanol_water/cons_<eos_name>.csv``) are plotted.
            ylim: Optional (lower, upper) y-limits; adds a ``_ylim`` suffix to the file name.
        """
        os.makedirs(self.outpath + '/ethanol_water', exist_ok=True)

        save_path = self.outpath + '/ethanol_water/cons_' + eos_name

        exp_data = pd.read_csv(DATA_PATH + '/ethanol_water.csv', na_values='NaN')
        temp_list = [int(x) for x in exp_data.columns[1:]]
        cons_list = exp_data['c']

        pred_data = pd.read_csv(self.inpath + '/ethanol_water/cons_' + eos_name + '.csv')
        w_etoh_ax = pred_data['w_EtOH']

        cmap = self._get_cmap('cool')
        max_T = max(temp_list)

        fig = plt.figure(figsize=(10, 5))
        ax = plt.subplot(111)
        ax.set_position([0.1, 0.11, 0.75, 0.77])

        for T in temp_list:
            color = cmap(T / max_T)
            if self.pred:
                plt.plot(w_etoh_ax, pred_data[str(T)] * 1e3, color=color)
            if self.exp:
                plt.scatter(cons_list, exp_data[str(T)], color=color, label=str(T), marker='x')

        plt.xlabel(r'$\omega_{{EtOH}}$ [–]', fontsize=14)
        plt.ylabel(r'$s_T$ [mK$^{-1}$]', fontsize=14)

        if ylim is not None:
            plt.ylim(ylim[0], ylim[1])
            save_path += '_ylim'

        # Raw string: '\c' in a normal literal is an invalid escape sequence.
        legend = plt.legend(title=r'T [$^{\circ}$C]', bbox_to_anchor=(1.015, 1.015), loc='upper left', fontsize=14)
        plt.setp(legend.get_title(), fontsize=14)

        plt.title(eos_name)
        plt.savefig(save_path, dpi=600)
        print('Saved', save_path + '.png')
        self.update_meta('EtOH_H2O_c_' + eos_name)
        plt.close(fig)


if __name__ == "__main__":
    # Re-plot every benchmark figure of run V3 in mode 'com'
    # (asks for confirmation if that mode directory already exists).
    plotter = BenchmarkPlotter(mode='com', version=3)

    for make_figures in (plotter.plot_298, plotter.n_alkanes, plotter.cold_gases):
        make_figures()

    # Ethanol/water: temperature plot, then concentration plot, for each EoS.
    for eos in ('CPA', 'PC-SAFT'):
        plotter.etoh_h2o_T(eos)
        plotter.etoh_h2o_c(eos)

