from models.kempers01 import Kempers01
import pandas as pd
import numpy as np
from numpy import sqrt
import matplotlib.pyplot as plt
from scipy.constants import Boltzmann, Avogadro
from scipy.optimize import fsolve
import file_handling as fh
import os
from datetime import datetime

# Directory containing this file; the plot and output directories used further
# down in the file are built from it.
ROOT_PATH = os.path.split(os.path.abspath(__file__))[0]
DATA_PATH = '{}/data/mie_data'.format(ROOT_PATH)

# Input tables. NOTE: these two paths are relative to the current working
# directory, not to ROOT_PATH.
df = pd.read_csv('data/mie_simulation.csv')
mie_df = pd.read_excel('models/mie.xlsx')

# Sub-tables holding the rows in which one species-1 parameter is not exactly 1.
df_m = df[df['m1'].ne(1.0)]
df_s = df[df['sigma1'].ne(1.00)]
df_e = df[df['epsilon1'].ne(1.00)]

class MiePlotter:
    """Plot model output for one Mie benchmark run against the simulation table.

    Input CSV files are read from ``ROOT_PATH/output/mie/V<version>/<mode>``
    and figures are saved to ``ROOT_PATH/plots/mie/V<version>/<mode>``. The
    names of the saved figures are logged in
    ``ROOT_PATH/plots/mie/V<version>/meta.txt``.

    The plotting methods also read the module-level tables ``df``, ``df_s``,
    ``df_e`` and ``df_m``.
    """

    def __init__(self, mode='cov', eos_name='SAFT-VR-MIE', version=None):
        """Resolve input/output directories and start the meta file entry.

        Asks for interactive confirmation before doing anything.

        :param mode: Name of the sub-directory inside the version directory;
            also used as a prefix in legend labels.
        :param eos_name: Suffix of the input CSV file names
            (``sigma_<eos_name>.csv`` etc.).
        :param version: Run number. If None, the first number for which no
            ``plots/mie/V<number>`` directory exists is used.
        :raises FileNotFoundError: If the input directory for the run is missing.
        :raises SystemExit: If the user does not answer 'Y' to a prompt.
        """
        verify = input(
            'Initialising MiePlotter with version : ' + str(version) + '. "Y" to continiue.')
        if verify != 'Y':
            # Same effect as exit(0), without relying on the site-provided builtin.
            raise SystemExit(0)

        self.eos_name = eos_name
        self.mode = mode

        # Without an explicit version, use the first run number that has no
        # plot directory yet.
        new_run = version is None
        if new_run:
            version = 0
            while os.path.isdir(ROOT_PATH + '/plots/mie/V' + str(version)):
                version += 1

        self.outdir = ROOT_PATH + '/plots/mie/V' + str(version)
        self.outpath = self.outdir + '/' + mode
        self.indir = ROOT_PATH + '/output/mie/V' + str(version)
        self.inpath = self.indir + '/' + mode
        if not os.path.isdir(self.inpath):
            raise FileNotFoundError('No data matching this plot run! (V' + str(version) + ')')

        if os.path.isdir(self.outpath):
            # Only reachable with an explicit version: plots for this mode
            # already exist and may be replaced.
            print('This BenchmarkPlotter instance may overwrite files in', self.outpath)
            verify = input("'Y' to verify")
            if verify != 'Y':
                raise SystemExit(0)
        else:
            # makedirs also creates the version directory (and its parents)
            # when they do not exist yet.
            os.makedirs(self.outpath)

        self._write_meta_header(append=not new_run)

    def _write_meta_header(self, append):
        """Write the date, the input meta file and a plot-list header to meta.txt.

        :param append: If True, add a separator and append to the existing
            file; otherwise start a new file.
        """
        with open(self.outdir + '/meta.txt', 'a' if append else 'w') as ofile:
            if append:
                ofile.write('\n#########################\n')
            ofile.write('Date : ' + datetime.now().strftime("%d/%m/%Y %H:%M:%S") + '\n')
            with open(self.indir + '/meta.txt', 'r') as ifile:
                ofile.write('\nMeta file in input dir :\n')
                ofile.writelines(ifile)

            ofile.write('\nContains plots for :\n')

    def update_meta(self, filename):
        """Append ``filename`` to the meta file of this run.

        The entry goes to ``<outdir>/meta.txt``, the file in which ``__init__``
        writes the 'Contains plots for :' header, so that the header is
        followed by the list of plots.
        """
        with open(self.outdir + '/meta.txt', 'a') as file:
            file.write(filename + '\n')

    @staticmethod
    def _parameter_colors(epsilon, sigma, mass, tol=0.001):
        """Return one colour per row, by which parameter differs from 1.

        'b' if epsilon deviates from 1 by more than ``tol``, else 'r' if sigma
        does, else 'g' if mass does, else 'black'.
        """
        colors = []
        for e, s, m in zip(epsilon, sigma, mass):
            if abs(e - 1) > tol:
                colors.append('b')
            elif abs(s - 1) > tol:
                colors.append('r')
            elif abs(m - 1) > tol:
                colors.append('g')
            else:
                # No parameter varied: still give the point a valid colour.
                colors.append('black')
        return colors

    @staticmethod
    def _add_deviation_legends(ax, markers):
        """Build the two legends of a deviation figure and lay out the axes.

        ``ax[1]`` and ``ax[2]`` only hold dummy points that provide legend
        entries (marker meaning and colour meaning); both are shrunk to zero
        size and hidden, while ``ax[0]`` is widened.
        """
        plt.sca(ax[1])
        for marker, label in zip(markers, ['p rep', 'p calc']):
            plt.scatter([1], [1], marker=marker, color='black', label=label)

        plt.sca(ax[2])
        for color, label in zip(['r', 'g', 'b'], [r'$\sigma$', 'm', r'$\epsilon$']):
            plt.scatter([1], [1], color=color, label=label)

        box0 = ax[0].get_position()
        box2 = ax[2].get_position()

        ax[0].set_position([box0.x0, box0.y0, box2.x0 + 0.4 * box2.width, box0.height])
        box0 = ax[0].get_position()

        ax[1].set_position([box0.x0 + box0.width, box0.y0 + box0.height + 0.015, 0, 0])
        ax[1].set_axis_off()
        ax[2].set_position([box0.x0 + box0.width, box0.y0 + box0.height - 0.1, 0, 0])
        ax[2].set_axis_off()

        ax[1].legend(bbox_to_anchor=(1, 1), loc='upper left')
        ax[2].legend(bbox_to_anchor=(1, 1), loc='upper left')

    def plot_sigma(self):
        """Save the three figures for the sigma1/sigma2 series.

        Reads ``sigma_<eos_name>.csv`` and saves
        'pressure_and_volume_vs_sigma', 'pressure_and_volume_sigma_compare'
        and 'Sigma_mie' to the output directory.
        """
        data = pd.read_csv(self.inpath + '/' + 'sigma' + '_' + self.eos_name + '.csv')

        s_list = data['s1/s2']
        # Both pressures scaled to the unit used on the axes labelled [bar].
        p_reported = data['p_reported'] * 1e2
        p_calc = data['p_calc'] / 1e5

        fig, axs = plt.subplots(2, 1, figsize=(10, 8), sharex='all')
        plt.sca(axs[0])
        plt.plot(s_list, p_reported, label='reported')
        plt.plot(s_list, p_calc, color='black', linestyle='--', label='calculated')
        plt.ylabel('p [bar]')
        plt.yscale('log')
        plt.legend()

        plt.sca(axs[1])
        plt.plot(s_list, data['vm_reported'])
        plt.plot(s_list, data['vm_calc'], color='black', linestyle='--')
        plt.xlabel(r'$\frac{\sigma_1}{\sigma_2}$')
        plt.ylabel(r'$v_m$ [m$^3$/mol]')

        plt.suptitle(r'Varying $\frac{\sigma_1}{\sigma_2}$')
        plt.tight_layout()
        plt.savefig(self.outpath + '/pressure_and_volume_vs_sigma')
        self.update_meta('pressure_and_volume_vs_sigma')
        # Close instead of clear, so finished figures are not kept alive.
        plt.close(fig)

        fig, axs = plt.subplots(1, 2, figsize=(10, 8))
        plt.sca(axs[0])
        # x and y use the same unit and the same rows, so the dashed line is
        # the identity line (as in plot_epsilon).
        plt.plot(p_calc, p_reported)
        plt.plot(p_calc, p_calc, color='black', linestyle='--')
        plt.xlabel(r'Calc. p [bar]')
        plt.ylabel('Reported p [bar]')
        plt.yscale('log')
        plt.xscale('log')

        plt.sca(axs[1])
        plt.plot(data['vm_calc'], data['vm_reported'])
        plt.plot(data['vm_calc'], data['vm_calc'], color='black', linestyle='--')
        plt.xlabel(r'Calc. $v_m$ [m$^3$/mol]')
        plt.ylabel(r'Reported $v_m$ [m$^3$/mol]')

        plt.suptitle(r'Varying $\frac{\sigma_1}{\sigma_2}$')
        plt.savefig(self.outpath + '/pressure_and_volume_sigma_compare')
        self.update_meta('pressure_and_volume_sigma_compare')

        print('Saved p and v plots')
        plt.close(fig)

        # Same figure size as the figure this plot used to be drawn on.
        fig = plt.figure(figsize=(10, 8))
        plt.plot(s_list, data['ST_p_reported'] * 1e3, label=self.mode + r', $p_{rep.}$', color='r')
        plt.plot(s_list, data['ST_p_calc'] * 1e3, label=self.mode + r', $p_{calc.}$', color='b')
        # Points from the module-level table of rows with sigma1 != 1.
        plt.scatter(df_s['sigma1'], df_s['ST'], marker='x', color='black')
        plt.xlabel(r'$\frac{\sigma_1}{\sigma_2}$')
        plt.ylabel(r'$S_T$ [mK$^{-1}$]')
        plt.legend()
        plt.ylim(-10, 1)
        plt.savefig(self.outpath + '/Sigma_mie')
        plt.close(fig)
        self.update_meta('Sigma_mie')

        print('Saved sigma plots to :', self.outpath)

    def plot_epsilon(self):
        """Save the four figures for the epsilon1/epsilon2 series.

        Reads ``epsilon_<eos_name>.csv`` and saves
        'pressure_and_volume_vs_eps', 'pT-epsilon',
        'pressure_and_volume_eps_compare' and 'epsilon_mie_test' to the
        output directory.
        """
        data = pd.read_csv(self.inpath + '/epsilon_' + self.eos_name + '.csv')

        e_list = data['e1/e2']
        # Both pressures scaled to the unit used on the axes labelled [bar].
        p_reported = data['p_reported'] * 1e2
        p_calc = data['p_calc'] / 1e5

        fig, axs = plt.subplots(3, 1, figsize=(10, 8), sharex='all')
        plt.sca(axs[0])
        plt.plot(e_list, p_reported, label='reported p')
        plt.plot(e_list, p_calc, color='black', linestyle='--', label='computed p')
        plt.xlabel(r'$\frac{\epsilon_1}{\epsilon_2}$')
        plt.ylabel('p [bar]')
        plt.legend()

        plt.sca(axs[1])
        plt.plot(e_list, data['vm_reported'])
        plt.plot(e_list, data['vm_calc'], color='black', linestyle='--')
        plt.xlabel(r'$\frac{\epsilon_1}{\epsilon_2}$')
        plt.ylabel(r'$v_m$ [m$^3$/mol]')

        plt.sca(axs[2])
        plt.plot(e_list, data['T_reported'])
        plt.plot(e_list, data['T_calc'], color='black', linestyle='--')
        plt.xlabel(r'$\frac{\epsilon_1}{\epsilon_2}$')
        plt.ylabel(r'T [K]')

        plt.suptitle(r'Varying $\frac{\epsilon_1}{\epsilon_2}$')

        plt.savefig(self.outpath + '/pressure_and_volume_vs_eps')
        self.update_meta('pressure_and_volume_vs_eps')
        # Close instead of clear, so finished figures are not kept alive.
        plt.close(fig)

        fig, axs = plt.subplots(2, 1, figsize=(8, 5))
        ax = axs[0]
        # Second x-axis on top showing the reported temperature of each point.
        twn = ax.twiny()
        p1, = ax.plot(e_list, p_reported, label='rep.', color='b')
        twn.plot(data['T_reported'], p_reported, linestyle='--', color='b')

        p2, = ax.plot(e_list, p_calc, label='calc.', color='r')

        plt.legend(handles=[p1, p2])

        ax.set_xticks([])
        twn.set_xlabel(r'T [K] (rep.)')
        twn.set_xticks(data['T_reported'])
        # Whole-number labels; works for integer and float temperatures alike.
        twn.set_xticklabels(['{:.0f}'.format(T) for T in data['T_reported']])

        plt.sca(ax)
        plt.ylabel('p [bar]')

        ax = axs[1]
        twn = ax.twiny()
        p3, = ax.plot(e_list, data['vm_reported'], label='rep.', color='b')
        twn.plot(data['T_reported'], data['vm_reported'], linestyle='--',
                 color='b')

        p4, = ax.plot(e_list, data['vm_calc'], label='calc.', color='r')

        plt.legend(handles=[p3, p4])

        ax.set_xlabel(r'$\epsilon_1 / \epsilon_2$')
        twn.set_xticks([])

        plt.sca(ax)
        plt.ylabel(r'$v_m$ [m$^3$/mol]')

        plt.savefig(self.outpath + '/pT-epsilon')
        self.update_meta('pT-epsilon')
        plt.close(fig)

        fig, axs = plt.subplots(1, 3, figsize=(10, 8))
        plt.sca(axs[0])
        plt.plot(p_calc, p_reported)
        plt.plot(p_calc, p_calc, color='black', linestyle='--')
        plt.xlabel(r'Calc. p [bar]')
        plt.ylabel('Reported p [bar]')

        plt.sca(axs[1])
        plt.plot(data['vm_calc'], data['vm_reported'])
        plt.plot(data['vm_calc'], data['vm_calc'], color='black', linestyle='--')
        plt.xlabel(r'Calc. $v_m$ [m$^3$/mol]')
        plt.ylabel(r'Reported $v_m$ [m$^3$/mol]')

        plt.sca(axs[2])
        plt.plot(data['T_calc'], data['T_reported'])
        plt.plot(data['T_calc'], data['T_calc'], color='black', linestyle='--')
        plt.xlabel(r'Calc. T [K]')
        plt.ylabel(r'Reported T[K]')

        plt.suptitle(r'Varying $\frac{\epsilon_1}{\epsilon_2}$')

        plt.savefig(self.outpath + '/pressure_and_volume_eps_compare')
        self.update_meta('pressure_and_volume_eps_compare')

        print('Saved p and v plots')
        plt.close(fig)

        # Same figure size as the figure this plot used to be drawn on.
        fig = plt.figure(figsize=(10, 8))
        plt.plot(e_list, data['ST_p_reported'] * 1e3, label=self.mode + r', $p_{rep.}$', color='r')
        plt.plot(e_list, data['ST_p_calc'] * 1e3, label=self.mode + r', $p_{calc.}$', color='b')
        # Points from the module-level table of rows with epsilon1 != 1.
        plt.scatter(df_e['epsilon1'], df_e['ST'], marker='x', color='black')
        plt.xlabel(r'$\frac{\epsilon_1}{\epsilon_2}$')
        plt.ylabel(r'$S_T$ [mK$^{-1}$]')
        plt.legend()
        plt.ylim(-25, 50)
        plt.savefig(self.outpath + '/epsilon_mie_test')
        self.update_meta('epsilon_mie_test')
        plt.close(fig)

        print('Saved epsilon plots to : ', self.outpath)

    def plot_mass(self):
        """Save the S_T figure for the m1/m2 series as 'mass_mie_test'.

        Reads ``mass_<eos_name>.csv``.
        """
        data = pd.read_csv(self.inpath + '/mass_' + self.eos_name + '.csv')
        fig, ax = plt.subplots()

        # All three curves are scaled by 1e3, like the S_T curves in
        # plot_sigma and plot_epsilon, to match the mK^-1 axis label.
        p1, = ax.plot(data['m1/m2'], data['ST_p_reported'] * 1e3, color='b',
                      label=self.mode + ', p reported')
        p2, = ax.plot(data['m1/m2'], data['Kinetic'] * 1e3, color='r', label='Kin.')
        p3, = ax.plot(data['m1/m2'], data['ST_p_calc'] * 1e3, color='g',
                      label=self.mode + ', p calc.')
        # x and y both come from the module-level table of rows with m1 != 1,
        # so the point pairs cannot get misaligned.
        # NOTE(review): assumes 'm1' is the ratio m1/m2, in the same way
        # plot_sigma/plot_epsilon use 'sigma1'/'epsilon1' - confirm.
        ax.scatter(df_m['m1'], df_m['ST'])

        plt.xlabel(r'$\frac{m_1}{m_2}$')
        plt.legend(handles=[p1, p2, p3])
        ax.set_ylabel(r'$S_T$ [mK$^{-1}$]')
        plt.savefig(self.outpath + '/mass_mie_test')
        self.update_meta('mass_mie_test')
        # Close instead of clear, so the finished figure is not kept alive.
        plt.close(fig)

        print('Saved mass_mie_test to : ', self.outpath + '/mass_mie_test')

    def deviation(self):
        """Save the absolute and relative S_T deviation figures.

        Reads ``deviation_<eos_name>.csv`` and combines it row by row with the
        module-level table ``df``. Saves 'deviation_abs' and 'deviation_rel'
        to the output directory. Points are coloured by the parameter that
        differs from 1 (see ``_parameter_colors``); the marker tells whether
        the prediction used the reported ('x') or the calculated ('v') pressure.
        """
        in_df = pd.read_csv(self.inpath + '/deviation_' + self.eos_name + '.csv')

        colors = self._parameter_colors(df['epsilon1'], df['sigma1'], df['m1'])
        markers = ['x', 'v']

        # Deviation between the reported S_T and the two predictions.
        dev_p_rep = in_df['ST_reported'] - in_df['ST_p_reported'] * 1e3
        dev_p_calc = in_df['ST_reported'] - in_df['ST_p_calc'] * 1e3

        fig, ax = plt.subplots(1, 3, figsize=(10, 5))
        plt.sca(ax[0])

        plt.scatter(df['p'] * 1e2, dev_p_rep, marker='x', label='p rep', color=colors)
        plt.scatter(in_df['p_calc'] * 1e-5, dev_p_calc, marker='v', label='p calc.', color=colors)
        plt.xlabel('p [bar]')
        plt.ylabel(r'$S_{T,sim} - S_{T,pred}$ [mK$^{-1}$]')
        plt.ylim(-20, 50)
        plt.grid()

        self._add_deviation_legends(ax, markers)

        plt.savefig(self.outpath + '/deviation_abs')
        self.update_meta('deviation_abs')
        # Close instead of clear, so finished figures are not kept alive.
        plt.close(fig)

        fig, ax = plt.subplots(1, 3, figsize=(10, 5))
        plt.sca(ax[0])

        p_diff = df['p'] * 1e2 - in_df['p_calc'] * 1e-5
        st_scale = np.abs(in_df['ST_reported'])
        plt.scatter(p_diff, dev_p_rep / st_scale, marker='x', color=colors)
        plt.scatter(p_diff, dev_p_calc / st_scale, marker='v', color=colors)
        plt.grid()
        plt.ylim(-2, 2)
        plt.ylabel(r'$(S_{T,sim} - S_{T,pred}) / |S_{T,sim}|$ [-]')
        plt.xlabel(r'$p_{sim} - p_{pred}$ [bar]')

        self._add_deviation_legends(ax, markers)

        plt.savefig(self.outpath + '/deviation_rel')
        self.update_meta('deviation_rel')
        plt.close(fig)

        print('Saved deviation plots to : ', self.outpath)

if __name__ == '__main__':
    # Script entry point: produce every plot for run V6 in 'com' mode.
    plotter = MiePlotter(mode='com', eos_name='SAFT-VR-MIE', version=6)
    for make_plots in (plotter.plot_sigma,
                       plotter.plot_epsilon,
                       plotter.plot_mass,
                       plotter.deviation):
        make_plots()