import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import warnings
import os
from scipy.interpolate import CloughTocher2DInterpolator as interpol

class Alpha_T0_empirical:
    '''
    Empirical thermal diffusion factor alpha_T0 for a binary mixture, approximated from the
    experimental data in alpha_t0.xlsx (read from the directory containing this module).

    Components, composition and approximation method are fixed at construction. The object is
    then evaluated at a temperature with ``obj(T)`` or ``obj.get_alpha_T0(T)``. Every evaluator
    returns ``np.array([alpha, -alpha])``, ordered as ``(self.comp1, self.comp2)``.
    '''

    def __init__(self, comps, mole_fracs=(0.5, 0.5), method='wheights'):
        '''
        Callable object; composition and components are set upon initialization.
        __call__(T, plot=False): returns alpha_T0 at temperature T and the fixed composition,
        approximated from experimental data.

        :param comps: 'comp1,comp2'
        :param mole_fracs: mole fractions of the components, in the order given by ``comps``
        :param method: 'wheights' (default): weighted average of enclosing points,
                       'ct': CloughTocher interpolation,
                       'plane': plane through three enclosing points
        :raises ValueError: if the pair is not in the database, or ``method`` is not recognised

        If the database stores the pair as 'comp2,comp1', the components are swapped so that
        ``self.comp1``/``self.comp2`` follow the database order, and ``self.flip`` is True.
        '''
        self.comp1, self.comp2 = comps.split(',')
        self.comps = comps
        self.flip = False

        df = pd.read_excel(os.path.join(os.path.dirname(__file__), 'alpha_t0.xlsx'))

        pair = self._locate_pair(df, self.comp1, self.comp2)
        if pair is None:
            raise ValueError('Components ' + comps + ' not found in database.')
        self.comp1, self.comp2, self.flip = pair

        self.data = df.loc[(df['Comp1'] == self.comp1) & (df['Comp2'] == self.comp2)]
        self._prepare_arrays()

        # self.c is the mole fraction of self.comp1, so it follows the swap.
        # NOTE(review): assumes the database column c_avg is the concentration of Comp1 - confirm.
        self.c = mole_fracs[1] if self.flip else mole_fracs[0]

        if method == 'wheights':
            self.get_alpha_T0 = self.get_alpha_T0_wheights
        elif method == 'plane':
            self.get_alpha_T0 = self.get_alpha_T0_plane
        elif method == 'ct':
            self.get_alpha_T0 = self.get_alpha_T0_CloughTocher
        else:
            raise ValueError("method must be either 'wheights', 'plane' or 'ct'")

    def __call__(self, T, plot=False):
        '''Evaluate the method chosen at construction; same as ``self.get_alpha_T0(T, plot)``.'''
        return self.get_alpha_T0(T, plot)

    @staticmethod
    def _locate_pair(df, comp1, comp2):
        '''
        Find how a component pair is stored in df (columns 'Comp1' and 'Comp2').

        :return tuple or None : (comp1, comp2, False) if stored in the given order,
                                (comp2, comp1, True) if stored reversed, None if absent
        '''
        def stored(first, second):
            # Compare column values; `x in series` would test the index labels instead.
            return bool(((df['Comp1'] == first) & (df['Comp2'] == second)).any())

        if stored(comp1, comp2):
            return comp1, comp2, False
        if stored(comp2, comp1):
            return comp2, comp1, True
        return None

    def _prepare_arrays(self):
        '''Cache the columns of self.data as arrays, together with their ranges and spans.'''
        self.c_list = np.array(self.data['c_avg'].tolist())
        self.T_list = np.array(self.data['T_avg'].tolist())
        self.alpha_list = np.array(self.data['Alpha_T_0'].tolist())

        self.cT_ranges = (self.c_list.min(), self.c_list.max(), self.T_list.min(), self.T_list.max())
        self.num_points = len(self.c_list)

        # Spans normalise c and T in distance computations; a zero span (all points share one
        # value) is replaced by 1 to avoid dividing by zero.
        self.c_span = (self.cT_ranges[1] - self.cT_ranges[0]) or 1
        self.T_span = (self.cT_ranges[3] - self.cT_ranges[2]) or 1

        # The CloughTocher interpolator is built lazily from these arrays.
        self._ct_interpolator = None

    def _warn_if_out_of_range(self, c, T):
        '''Warn if (c, T) lies outside the bounding box of the experimental data.'''
        min_c, max_c, min_T, max_T = self.cT_ranges
        if c < min_c or c > max_c or T < min_T or T > max_T:
            warnings.warn('\nc = ' + str(round(c, 0)) + ' T = ' + str(round(T, 0)) + ' is outside range '
                          + '\nc : (' + str(round(min_c, 0)) + ', ' + str(round(max_c, 0)) + '), T : ('
                          + str(round(min_T, 0)) + ', ' + str(round(max_T, 0)) + ')')

    def _distances(self, c, T):
        '''
        :return tuple : (dist, order) where dist holds the squared span-normalised distances from
                        (c, T) to every data point and order holds the data indices sorted by
                        increasing distance (ties keep database order)
        '''
        dist = ((self.c_list - c) / self.c_span) ** 2 + ((self.T_list - T) / self.T_span) ** 2
        return dist, np.argsort(dist, kind='stable')

    def _improves_selection(self, idx, selected_idx, c, T):
        '''check_valid for data point idx against the data points selected_idx, in span-normalised coordinates.'''
        return self.check_valid(self.c_list[idx] / self.c_span, self.T_list[idx] / self.T_span,
                                self.c_list[selected_idx] / self.c_span,
                                self.T_list[selected_idx] / self.T_span,
                                c / self.c_span, T / self.T_span)

    @staticmethod
    def _distance_weights(r):
        '''
        Normalised weights for points at squared distances r.

        weight_i is proportional to the product of all r_j with j != i, i.e. to 1/r_i when no r_j
        is zero. A single point at zero distance gets all the weight; several such points share it.
        '''
        r = np.asarray(r, dtype=float)
        at_target = r == 0
        if np.count_nonzero(at_target) > 1:
            # The product formula would give 0/0 here; average the coincident points instead.
            return at_target / np.count_nonzero(at_target)
        products = np.array([np.prod(r[:i]) * np.prod(r[i + 1:]) for i in range(len(r))])
        return products / products.sum()

    @staticmethod
    def _fit_error(close_c, close_T, close_alpha, c, T, fit_alpha):
        '''Mean squared (unnormalised) distance in (c, T, alpha) space from the fitted point to the
        points used for the fit; reported as "R^2" in warnings and plots.'''
        return sum((close_c - c) ** 2 + (close_T - T) ** 2 + (close_alpha - fit_alpha) ** 2) / len(close_alpha)

    def get_alpha_T0_wheights(self, T, plot=False):
        '''
        Approximate alpha_T0 at temperature T and the fixed composition by a weighted average of
        nearby experimental points.

        Starting from the nearest data point, further points are added in order of increasing
        span-normalised distance. Any point that check_valid rejects is skipped. Selection stops
        when the chosen points enclose (c, T) or the candidates run out. Each chosen point is
        weighted by the product of the squared normalised distances of the other chosen points,
        so closer points count more.

        :param T (float) : Temperature (K)
        :param plot (bool or '2d' or '3d') : display fitted data (see plot_fit)
        :return ndarray : [alpha_T0, -alpha_T0]
        '''
        self.plot = plot
        c = self.c * 100

        self._warn_if_out_of_range(c, T)
        dist, order = self._distances(c, T)

        selected = np.zeros(self.num_points, dtype=bool)
        selected[0] = True  # the nearest point is always used
        close_c = self.c_list[order[selected]]
        close_T = self.T_list[order[selected]]

        rank = 1
        while rank < self.num_points and not self.check_inside(close_c, close_T, c, T):
            # Advance to the next candidate that does not point away from (c, T).
            while rank < self.num_points and not self._improves_selection(order[rank], order[selected], c, T):
                rank += 1
            if rank < self.num_points:
                selected[rank] = True
                close_c = self.c_list[order[selected]]
                close_T = self.T_list[order[selected]]
            rank += 1
        ran_out = rank >= self.num_points

        chosen = order[selected]
        close_alpha = self.alpha_list[chosen]
        fit_alpha = sum(close_alpha * self._distance_weights(dist[chosen]))

        R2 = self._fit_error(close_c, close_T, close_alpha, c, T, fit_alpha)
        if plot:
            self.plot_fit(dist, close_c, close_T, close_alpha, c, T, fit_alpha, R2)

        if R2 > 500 or ran_out:
            warnings.warn('\nAlpha_T0 at c = ' + str(round(c, 1)) + ' T = ' + str(round(T, 0))
                          + ' for ' + self.comps + ' may be a bad approximation (R^2 = ' + str(round(R2, 0)) + ')')

        return np.array([fit_alpha, -fit_alpha])

    def get_alpha_T0_CloughTocher(self, T, plot=False):
        '''
        Approximate alpha_T0 at temperature T and the fixed composition by CloughTocher
        interpolation of the experimental data. Outside the convex hull of the data, the result is
        the interpolator's fill value (scipy default: nan).

        :param T (float) : Temperature (K)
        :param plot : stored in self.plot; this method draws no plot
        :return ndarray : [alpha_T0, -alpha_T0]
        '''
        self.plot = plot
        c = self.c * 100

        self._warn_if_out_of_range(c, T)

        # Build the interpolator once per data set and reuse it on later calls.
        interpolator = getattr(self, '_ct_interpolator', None)
        if interpolator is None:
            interpolator = interpol(np.column_stack((self.c_list, self.T_list)), self.alpha_list)
            self._ct_interpolator = interpolator

        fit_alpha = np.ravel(interpolator([[c, T]]))[0]
        return np.array([fit_alpha, -fit_alpha])

    def _select_plane_points(self, order, c, T):
        '''
        Choose the data points for get_alpha_T0_plane; returns data indices in order of distance.

        Each data point in turn, nearest first, is used as an anchor. The nearest other point that
        check_valid accepts is added, then the nearest point beyond that one which makes the three
        points enclose (c, T). The first such triangle is returned. If no anchor yields one, the
        one or two points found for the nearest anchor are returned.
        '''
        fallback = order[:0]
        for anchor in range(self.num_points):
            mask = np.zeros(self.num_points, dtype=bool)
            mask[anchor] = True
            others = [rank for rank in range(self.num_points) if rank != anchor]

            second = next((pos for pos, rank in enumerate(others)
                           if self._improves_selection(order[rank], order[mask], c, T)), None)
            if second is not None:
                mask[others[second]] = True
                for rank in others[second + 1:]:
                    trial = mask.copy()
                    trial[rank] = True
                    if self.check_inside(self.c_list[order[trial]], self.T_list[order[trial]], c, T):
                        return order[trial]

            if anchor == 0:
                fallback = order[mask]
        return fallback

    def get_alpha_T0_plane(self, T, plot=False):
        '''
        Approximate alpha_T0 at temperature T and the fixed composition by fitting a plane through
        the closest three data points that enclose (c, T). If no enclosing triangle exists, a
        distance-weighted average of the points found for the nearest anchor is used instead.

        :param T (float) : Temperature (K)
        :param plot (bool or '2d' or '3d') : display fitted data (see plot_fit)
        :return ndarray : [alpha_T0, -alpha_T0]
        '''
        self.plot = plot
        c = self.c * 100

        self._warn_if_out_of_range(c, T)
        dist, order = self._distances(c, T)

        chosen = self._select_plane_points(order, c, T)
        close_c = self.c_list[chosen]
        close_T = self.T_list[chosen]
        close_alpha = self.alpha_list[chosen]

        if len(chosen) == 3:
            points = np.column_stack((close_c, close_T, close_alpha))
            normal = np.cross(points[1] - points[0], points[2] - points[0])
            # Solve normal . (c, T, alpha) = normal . points[0] for alpha. normal[2] != 0 because
            # check_inside only accepts triangles that strictly enclose (c, T).
            fit_alpha = (np.dot(normal, points[0]) - normal[0] * c - normal[1] * T) / normal[2]
        else:
            fit_alpha = sum(close_alpha * self._distance_weights(dist[chosen]))

        R2 = self._fit_error(close_c, close_T, close_alpha, c, T, fit_alpha)
        if plot:
            self.plot_fit(dist, close_c, close_T, close_alpha, c, T, fit_alpha, R2)

        if R2 > 100:
            warnings.warn('\nAlpha_T0 at c = ' + str(round(c, 1)) + ' T = ' + str(round(T, 0))
                          + ' for ' + self.comps + ' may be a bad fit (R^2 = ' + str(round(R2, 0)) + ')')
        return np.array([fit_alpha, -fit_alpha])

    def check_inside(self, x, y, c, T):
        '''
        :param x (ndarray) : list of x-points
        :param y (ndarray) : list of y-points
        :param c (float) : x-point
        :param T (float) : y-point

        :return bool : True if (c, T) lies strictly inside a triangle formed by (x[0], y[0]) and
                       two other points of (x, y). Always False for fewer than three points.
        '''
        if len(x) < 3:
            return False

        dx = np.asarray(x) - c
        dy = np.asarray(y) - T

        def cross(a, b):
            # z-component of the cross product of the vectors from (c, T) to points a and b
            return dx[a] * dy[b] - dy[a] * dx[b]

        for i in range(1, len(dx)):
            # Point i is counter-clockwise (+1) or not (-1) of point 0 as seen from (c, T). The
            # target is enclosed if some j lies on the other side of point 0 and turns the same
            # way from point i.
            orient = 1 if cross(0, i) > 0 else -1
            for j in range(1, len(dx)):
                if j != i and orient * cross(0, j) < 0 and orient * cross(i, j) > 0:
                    return True
        return False

    def check_valid(self, new_x, new_y, x_list, y_list, c, T):
        '''
        :param new_x (float) : x-point
        :param new_y (float) : y-point
        :param x_list (ndarray) : list of x-points
        :param y_list (ndarray) : list of y-points
        :param c (float) : x-value at point that is being approximated
        :param T (float) : y-value at point that is being approximated

        :return bool : False if, for any existing point, the direction to (new_x, new_y) has a
                       negative dot product with the direction to (c, T); True otherwise
        '''
        x_arr = np.asarray(x_list)
        y_arr = np.asarray(y_list)
        dots = (new_x - x_arr) * (c - x_arr) + (new_y - y_arr) * (T - y_arr)
        return not np.any(dots < 0)

    def plot_fit(self, dist, close_c, close_T, close_alpha, c, T, fit_alpha, R2):
        '''
        Plot the points used for a fit; good for debugging and to view errors.
        Activated by calling an evaluator with plot=True, plot='2d' or plot='3d' (stored in self.plot).

        '2d' or True : all data in the (c, T) plane coloured by alpha, plus the selected points
                       and the requested point. The figure is saved to
                       plots/selected/<c>_<T>.png (the directory is created if needed) and closed.
        '3d' or True : all data, the selected points and the fitted value in 3d, shown with plt.show().

        :param dist: distances to every data point (unused; kept for interface compatibility)
        '''
        # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap remains.
        self.cmap = plt.get_cmap('plasma')
        max_alpha = max(self.alpha_list)

        if self.plot == '2d' or self.plot is True:
            plt.scatter(self.data['c_avg'], self.data['T_avg'],
                        color=[self.cmap(alpha / max_alpha) for alpha in self.alpha_list])
            plt.scatter(close_c, close_T, color='green', marker='x', label='Interpolation points')
            plt.scatter(c, T, color='black', label='Point to approximate')
            plt.xlabel('c [vol%]')
            plt.ylabel('T [K]')
            plt.legend()

            out_dir = os.path.join('plots', 'selected')
            os.makedirs(out_dir, exist_ok=True)
            # Explicit .png suffix; otherwise a '.' in the rounded numbers would be taken as the format.
            fname = os.path.join(out_dir, str(round(c, 2)) + '_' + str(round(T, 0)) + '.png')
            print(fname)
            plt.savefig(fname)
            plt.close()

        if self.plot == '3d' or self.plot is True:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')

            ax.scatter(self.data['c_avg'], self.data['T_avg'], self.data['Alpha_T_0'], color='blue', marker='x')
            ax.scatter(close_c, close_T, close_alpha, color='red')
            ax.scatter(c, T, fit_alpha, color='black')

            ax.set_xlabel('c [vol%]')
            ax.set_ylabel('T [K]')
            ax.set_zlabel(r'$\alpha_T^\circ$ [-]')

            plt.title(r'$R^2 = $' + str(round(R2, 0)))
            plt.show()

        if self.plot not in ('2d', '3d', True):
            print("Argument 'plot' can be:\n"
                  "'2d': plot 2d selection of data points\n"
                  "'3d': plot 3d fit and selection of data points\n"
                  "True: plot both the above")

    def plot_points(self, ax):
        '''
        Scatter the experimental data points in 3d.
        :param ax: A matplotlib Axes3D instance on which to plot
        '''
        ax.scatter(self.data['c_avg'], self.data['T_avg'], self.data['Alpha_T_0'], color='red', marker='x', s=40)

    def plot_mesh(self, dim_1d=False):
        '''
        Plot the data set selected at construction together with the approximation from the
        method selected at construction, evaluated over the data's (c, T) range. Warnings raised
        while evaluating are suppressed during this call only, and self.c is restored afterwards.
        plt.show() is not called.

        :param dim_1d: Set to True if data is 1d. The approximation is then evaluated along the
                       diagonal of the (c, T) range instead of on a 25 x 25 grid.
        '''
        min_c, max_c, min_T, max_T = self.cT_ranges
        # Divide by 100 to match the unit of self.c (evaluators compare self.c * 100 with c_avg).
        c_list_1d = np.linspace(min_c, max_c, 25) * 0.01
        T_list_1d = np.linspace(min_T, max_T, 25)

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        self.plot_points(ax)

        saved_c = self.c

        def alpha_at(c_frac, T):
            # The evaluators read the composition from self.c, so set it for each evaluation point.
            self.c = c_frac
            return self.get_alpha_T0(T)[0]

        try:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                if dim_1d:
                    alpha_vals_1d = np.array([alpha_at(ci, Ti) for ci, Ti in zip(c_list_1d, T_list_1d)])
                    ax.plot(c_list_1d * 100, T_list_1d, alpha_vals_1d, color='black')
                else:
                    c_grid, T_grid = np.meshgrid(c_list_1d, T_list_1d)
                    alpha_vals = np.array([[alpha_at(ci, Ti) for ci, Ti in zip(c_row, T_row)]
                                           for c_row, T_row in zip(c_grid, T_grid)])
                    ax.plot_wireframe(c_grid * 100, T_grid, alpha_vals, color='black', alpha=0.5)
        finally:
            self.c = saved_c

        plt.title(r'Approximated $\alpha_T^\circ$ values for ' + self.comps)
        ax.set_xlabel('c [%vol]')
        ax.set_ylabel('T [K]')
        ax.set_zlabel(r'$\alpha_T^\circ$ [–]')

class SkipCounter:
    '''
    Increasing integer counter that never takes the value ``skip``.

    Starts at ``val``, or at ``skip + 1`` if ``val == skip``. ``counter += n`` advances it in
    place and steps one further if it would land on ``skip``. ``counter + n`` returns an advanced
    copy and leaves ``counter`` unchanged. Usable as a list/array index (``__index__``) and
    comparable with ``<``.
    '''

    def __init__(self, val, skip):
        self.val = skip + 1 if val == skip else val
        self.skip = skip

    def __iadd__(self, other):
        '''Advance in place by ``other``, stepping over ``skip`` if the result would equal it.'''
        if self.val + other == self.skip:
            self.val += other + 1
        else:
            self.val += other
        return self

    # Kept for backward compatibility with code that called the old, non-standard name.
    __lADD__ = __iadd__

    def __add__(self, other):
        '''Return a new counter advanced by ``other``; ``self`` is not modified.'''
        result = type(self).__new__(type(self))
        result.val, result.skip = self.val, self.skip
        result += other
        return result

    def __lt__(self, other):
        return self.val < other

    def __index__(self):
        return self.val

    def __repr__(self):
        return 'SkipCounter(val=' + str(self.val) + ', skip=' + str(self.skip) + ')'

def test_wheights():
    '''
    Visual comparison of two distance-weighting schemes on a synthetic problem.

    Four reference points sit at the corners of the square [0.01, 0.1] x [0.01, 0.1] and carry
    the values f(x, y) = x**2 + y. For every query point on a 5 x 5 grid covering the square,
    f is estimated as a weighted average of the corner values using:
        blue: weights proportional to exp(-sqrt(r))
        red:  weights proportional to the product of the other corners' r
    where r is the squared distance from the query point to each corner.

    Three figures of the absolute estimation error are shown in turn with plt.show(): against
    the mean of r, against the smallest r, and against both in 3d.
    '''
    corner_x = np.array([0.01, 0.1, 0.01, 0.1])
    corner_y = np.array([0.01, 0.1, 0.1, 0.01])
    corner_f = corner_x ** 2 + corner_y

    mean_r, smallest_r, err_exp, err_prod = [], [], [], []
    query_axis = np.linspace(0.01, 0.1, 5)
    for qx in query_axis:
        for qy in query_axis:
            r = (corner_x - qx) ** 2 + (corner_y - qy) ** 2

            decay = np.exp(-np.sqrt(r))
            w_exp = decay / sum(decay)

            leave_one_out = [np.prod(r[:k]) * np.prod(r[k + 1:]) for k in range(len(r))]
            norm = sum(leave_one_out)
            w_prod = [p / norm for p in leave_one_out]

            exact = qx ** 2 + qy
            mean_r.append(sum(r) / len(r))
            smallest_r.append(min(r))
            err_exp.append(sum(corner_f * w_exp) - exact)
            err_prod.append(sum(corner_f * w_prod) - exact)

    abs_exp = abs(np.array(err_exp))
    abs_prod = abs(np.array(err_prod))

    # 2d figures: error vs. mean squared distance, then vs. smallest squared distance.
    for x_axis in (mean_r, smallest_r):
        plt.scatter(x_axis, abs_exp, color='b')
        plt.scatter(x_axis, abs_prod, color='r')
        plt.show()

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(mean_r, smallest_r, abs_exp, color='b')
    ax.scatter(mean_r, smallest_r, abs_prod, color='r')
    ax.set_xlabel('R2')
    ax.set_ylabel('r_min')
    plt.show()

class DB_Builder:
    '''
    Build a compact summary of the main database file (alpha_t0.xlsx): for every component pair,
    the range of concentration (c_avg) and temperature (T_avg) covered by its data points.

    Usage: DB_Builder('alpha_t0.xlsx').build_db('outfile.txt')
    '''

    def __init__(self, filename):
        self.filename = filename
        self._df = None  # spreadsheet contents, read lazily and cached by _frame()

    def _frame(self):
        '''Return the spreadsheet as a DataFrame, reading it from disk only on first use.'''
        if getattr(self, '_df', None) is None:
            self._df = pd.read_excel(self.filename)
        return self._df

    def get_cT_range(self, comp1, comp2):
        '''
        :return tuple : (min_c, max_c, min_T, max_T) over the rows with Comp1 == comp1 and Comp2 == comp2
        :raises ValueError : if there are no such rows
        '''
        df = self._frame()
        data = df.loc[(df['Comp1'] == comp1) & (df['Comp2'] == comp2)]
        if data.empty:
            raise ValueError('No data for components ' + str(comp1) + ',' + str(comp2))
        return min(data['c_avg']), max(data['c_avg']), min(data['T_avg']), max(data['T_avg'])

    def build_db(self, outfile):
        '''
        Write the summary to outfile: a header line, then one line per component pair
        ("Comp1, Comp2, min_c, max_c, min_T, max_T"), grouped by Comp1 with a blank line after
        each group. Concentrations are rounded to whole numbers.

        Pairs and groups appear in order of first occurrence in the spreadsheet. The previous
        set-based grouping gave a different order on every run because string hashing is
        randomised.
        '''
        df = self._frame()

        partners = {}
        for c1, c2 in dict.fromkeys(zip(df['Comp1'], df['Comp2'])):
            partners.setdefault(c1, []).append(c2)

        lines = ['Comp1, Comp2, min_c, max_c, min_T, max_T\n']
        for c1, mix in partners.items():
            for c2 in mix:
                min_c, max_c, min_T, max_T = (float(v) for v in self.get_cT_range(c1, c2))
                lines.append(c1 + ', ' + c2 + ', '
                             + str(round(min_c, 0)) + ', ' + str(round(max_c, 0)) + ', '
                             + str(min_T) + ', ' + str(max_T) + '\n')
            lines.append('\n')

        # All ranges are computed before the file is opened, so a lookup error leaves no partial file.
        with open(outfile, 'w') as file:
            file.writelines(lines)