'''
Author: Vegard G. Jervell
Date: December 2020
Purpose: Implementation of the Chapman-Enskog solutions to the Boltzmann equations for a binary system
        as proposed by Tompson, Tipton and Loyalka, in the paper
        "Chapman–Enskog solutions to arbitrary order in Sonine polynomials III:
        Diffusion, thermal diffusion, and thermal conductivity in a binary, rigid-sphere, gas mixture"
        doi : https://doi.org/10.1016/j.euromechflu.2008.12.002

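Requires: numpy, scipy, matplotlib, pandas (with an Excel engine for pandas.read_excel), an initialized
          thermopack equation of state, and the parameter file mie_dev.xlsx in the same directory as this module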
'''

import math
import os

import numpy as np
import scipy.linalg as lin
import scipy.constants as const
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.gridspec as gs
import pandas as pd

#fac = np.math.factorial

def fac(n):
    '''
    Exact factorial n! using arbitrary-precision Python integers.

    0 and 1 return 1 immediately. For any other value the product 2 * 3 * ... * n is
    accumulated, so integers below 0 also yield 1 (empty product).

    :param n (int): Integer argument (anything accepted by range())
    :return (int): n!
    '''
    if n in (0, 1):
        return 1
    product = 1
    for factor in range(2, n + 1):
        product *= factor
    return product

def summation(start, stop, func, args=None):
    '''
    Inclusive finite sum func(start) + func(start + 1) + ... + func(stop).

    If args is anything other than None it is forwarded as a second positional
    argument, i.e. the terms are func(idx, args). An empty range (stop < start) sums to 0.
    '''
    indices = range(start, stop + 1)
    if args is None:
        return sum(func(idx) for idx in indices)
    return sum(func(idx, args) for idx in indices)

def delta(i, j):
    '''Kronecker delta: 1 when i == j, otherwise 0.'''
    return 1 if i == j else 0

def w(l, r):
    '''
    Dimensionless rigid-sphere collision-integral factor used by KineticGas.omega:

        w(l, r) = 1/4 * (2 - (1 + (-1)^l) / (l + 1)) * (r + 1)!

    The factorial comes from the standard-library math module. The former np.math alias
    was only a re-export of that module and no longer exists in NumPy 2.0.

    :param l (int): Collision-integral index l (l >= 0)
    :param r (int): Collision-integral index r (r >= -1)
    :return (float): w(l, r)
    '''
    return 0.25 * (2 - ((1 / (l + 1)) * (1 + (-1) ** l))) * math.factorial(r + 1)

class KineticGas():

    def __init__(self, comps, eos, mole_fracs=[0.5,0.5], N=1):
        '''
        :param comps (str): Comma-separated list of components, following Thermopack-convention
        :param eos (thermopack): An initialized equation of state, only used to get mole weights
        :param mole_fracs (array-like): list of mole fractions
        :param N (int > 0): Order of approximation.
                            Be aware that N > 10 can be detrimental to runtime. This should be implemented in C++ or Fortran.
        '''

        #Packing out variables in an n-component compatible way
        complist = comps.split(',')
        self.mole_weights = np.array([eos.compmoleweight(eos.getcompindex(comp)) for comp in complist])

        self.m0 = np.sum(self.mole_weights)
        self.M = self.mole_weights/self.m0

        self.sigmaij = self.get_hard_sphere_radius(comps)
        self.sigma = np.diag(self.sigmaij)

        #This part is only binary-compatible
        self.M1, self.M2 = self.M
        self.x1, self.x2 = mole_fracs
        self.m1, self.m2 = self.mole_weights
        self.sigma1, self.sigma2 = self.sigma
        self.sigma12 = self.sigmaij[0, 1]

        #Calculate the (temperature independent) soret-coefficient at the ideal gas state
        self.T = 100 #Set T to a dummy value, because intermediate expressions depend on T, even though final result does not.
        self.N = N

        pq_range = np.arange(-N, N + 1, 1)
        self.A_matr = np.empty((2 * N + 1, 2 * N + 1), float)
        for i, p in enumerate(pq_range):
            for j, q in zip([j for j in range(i, len(pq_range))], pq_range[i:]):
                self.A_matr[i, j] = self.a(int(p), int(q)) # a(p,q) gives NaN or inf if type(p) == np.int64 or type(q) == np.int64... dont ask why
                self.A_matr[j, i] = self.A_matr[i, j]

        delta_0 = (3 / 2) * np.sqrt(const.Boltzmann * self.T / np.sum(self.mole_weights))
        b = np.zeros(2 * N + 1)
        b[int((len(b) - 1) / 2)] = delta_0
        d = lin.solve(self.A_matr, b)

        d_1, d0, d1 = d[int(((len(d) + 1) / 2) - 2): int(((len(d) + 1) / 2) + 1)]
        self.soret = - (5 / (2 * d0)) * ((self.x1 * d1 / np.sqrt(self.M1)) + (self.x2 * d_1 / np.sqrt(self.M2)))

        self.d_1, self.d0, self.d1 = d_1, d0, d1

    def interdiffusion(self, T):
        delta_0 = (3 / 2) * np.sqrt(const.Boltzmann * self.T / np.sum(self.mole_weights))
        b = np.zeros(len(self.A_matr))
        b[int((len(b) - 1) / 2)] = delta_0
        d = lin.solve(self.A_matr, b)
        d_1, d0, d1 = d[int(((len(d) + 1) / 2) - 2): int(((len(d) + 1) / 2) + 1)]

        return 0.5 * self.x1 * self.x2 * np.sqrt(2 * const.Boltzmann * T / self.m0) * d0

    def thermal_diffusion(self,T):
        delta_0 = (3 / 2) * np.sqrt(const.Boltzmann * self.T / np.sum(self.mole_weights))
        b = np.zeros(len(self.A_matr))
        b[int((len(b) - 1) / 2)] = delta_0
        d = lin.solve(self.A_matr, b)
        d_1, d0, d1 = d[int(((len(d) + 1) / 2) - 2): int(((len(d) + 1) / 2) + 1)]

        return - (5/4) * self.x1 * self.x2 * np.sqrt(2 * const.Boltzmann * T / self.m0) * \
               ((self.x1 * d1 / np.sqrt(self.M1)) + (self.x2 * d_1/np.sqrt(self.M2)))

    def alpha_T0(self, T):
        delta_0 = (3 / 2) * np.sqrt(const.Boltzmann * self.T / np.sum(self.mole_weights))
        b = np.zeros(len(self.A_matr))
        b[int((len(b) - 1) / 2)] = delta_0
        d = lin.solve(self.A_matr, b)
        d_1, d0, d1 = d[int(((len(d) + 1) / 2) - 2): int(((len(d) + 1) / 2) + 1)]

        alpha_T0 = - (5 / (2 * d0)) * ((self.x1 * d1 / np.sqrt(self.M1)) + (self.x2 * d_1 / np.sqrt(self.M2)))

        return np.array([alpha_T0, -alpha_T0])

    def get_hard_sphere_radius(self, comps):
        '''
        Get hard-sphere diameters, assumed to be equal to Mie-potential sigma parameter. Gets Mie-parameters from the file mie_dev.xlsx.

        :param comps: Comma seperated list of components
        :return: N x N matrix of hard sphere diameters, where sigma_ij = 0.5 * (sigma_i + sigma_j),
                such that the diagonal is the radius of each component, and off-diagonals are the average diameter of
                component i and j.
        '''
        df = pd.read_excel(os.path.dirname(__file__)+'/mie_dev.xlsx')
        sigma_i = np.array([df.loc[df['comp'] == comp]['sigma'].iloc[0] for comp in comps.split(',')])

        sigma_ij = 0.5 * np.sum(np.meshgrid(sigma_i, np.vstack(sigma_i)), axis=0)
        return sigma_ij

    def omega(self, ij, l, r):
        if ij in (1, 2):
            return self.sigma[ij - 1] ** 2 * np.sqrt((np.pi * const.Boltzmann * self.T) / self.mole_weights[ij - 1]) * w(l, r)

        elif ij in (12, 21):
            return 0.5 * self.sigma12 ** 2 * np.sqrt(2 * np.pi * const.Boltzmann * self.T / (self.m0 * self.M1 * self.M2)) * w(l, r)
        else:
            raise ValueError('(' + str(ij) + ', ' + str(l) + ', ' + str(r) + ') are non-valid arguments for omega.')

    def A(self, p, q, r, l):
        def inner(i):
            return ((8 ** i * fac(p + q - 2 * i) * (-1) ** l * (-1) ** (r + i) * fac(r + 1) * fac(
                2 * (p + q + 2 - i)) * 4 ** r) /
                    (fac(p - i) * fac(q - i) * fac(l) * fac(i + 1 - l) * fac(r - i) * fac(p + q + 1 - i - r) * fac(
                        2 * r + 2)
                     * fac(p + q + 2 - i) * 4 ** (p + q + 1))) * ((i + 1 - l) * (p + q + 1 - i - r) - l * (r - i))

        return summation(l - 1, min(p, q, r, p + q + 1 - r), inner)

    def A_prime(self, p, q, r, l):
        F = (self.M1 ** 2 + self.M2 ** 2) / (2 * self.M1 * self.M2)
        G = (self.M1 - self.M2) / self.M2

        def inner(w, args):
            i, k = args
            return ((8 ** i * fac(p + q - 2 * i - w) * (-1) ** (r + i) * fac(r + 1) * fac(
                2 * (p + q + 2 - i - w)) * 2 ** (2 * r) * F ** (i - k) * G ** w) /
                    (fac(p - i - w) * fac(q - i - w) * fac(r - i) * fac(p + q + 1 - i - r - w) * fac(2 * r + 2) * fac(
                        p + q + 2 - i - w)
                     * 4 ** (p + q + 1) * fac(k) * fac(i - k) * fac(w))) * (
                               2 ** (2 * w - 1) * self.M1 ** i * self.M2 ** (p + q - i - w)) * 2 * (
                           self.M1 * (p + q + 1 - i - r - w) * delta(k, l) - self.M2 * (r - i) * delta(k, l - 1))

        def sum_w(k, i):
            return summation(0, min(p, q, p + q + 1 - r) - i, inner, args=(i, k))

        def sum_k(i):
            return summation(l - 1, min(l, i), sum_w, args=i)

        return summation(l - 1, min(p, q, r, p + q + 1 - r), sum_k)

    def A_tripleprime(self, p, q, r, l):
        if l % 2 != 0:
            return 0

        def inner(i):
            return ((8 ** i * fac(p + q - (2 * i)) * 2 * (-1) ** (r + i) * fac(r + 1) * fac(
                2 * (p + q + 2 - i)) * 2 ** (2 * r)) /
                    (fac(p - i) * fac(q - i) * fac(l) * fac(i + 1 - l) * fac(r - i) * fac(p + q + 1 - i - r) * fac(
                        2 * r + 2)
                     * fac(p + q + 2 - i) * 4 ** (p + q + 1))) * (((i + 1 - l) * (p + q + 1 - i - r)) - l * (r - i))

        return 0.5 ** (p + q + 1) * summation(l - 1, min(p, q, r, p + q + 1 - r), inner)

    def H_ij(self, p, q, ij):
        M1, M2 = self.M1, self.M2

        if ij == 21:  # swap indices
            M1, M2 = M2, M1

        def inner(r, l):
            return self.A(p, q, r, l) * self.omega(12, l, r)

        def sum_r(l):
            return summation(l, p + q + 2 - l, inner, args=l)

        val = 8 * M2 ** (p + 0.5) * M1 ** (q + 0.5) * summation(1, min(p, q) + 1, sum_r)

        return val

    def H_i(self, p, q, ij):
        
        if ij == 21:  # swap indices
            self.M1, self.M2 = self.M2, self.M1

        def inner(r, l):
            return self.A_prime(p, q, r, l) * self.omega(12, l, r)

        def sum_r(l):
            return summation(l, p + q + 2 - l, inner, args=l)

        val = 8 * summation(1, min(p, q) + 1, sum_r)

        if ij == 21:  # swap back
            self.M1, self.M2 = self.M2, self.M1

        return val

    def H_simple(self, p, q, i):
        def inner(r, l):
            return self.A_tripleprime(p, q, r, l) * self.omega(i, l, r)

        def sum_r(l):
            return summation(l, p + q + 2 - l, inner, args=l)

        return 8 * summation(2, min(p, q) + 1, sum_r)

    def a(self, p, q):
        if p == 0 or q == 0:
            if p > 0:
                return self.M1 ** 0.5 * self.x1 * self.x2 * self.H_i(p, q, 12)
            elif p < 0:
                return - self.M2 ** 0.5 * self.x1 * self.x2 * self.H_i(-p, q, 21)
            elif q > 0:
                return self.M1 ** 0.5 * self.x1 * self.x2 * self.H_i(p, q, 12)
            elif q < 0:
                return - self.M2 ** 0.5 * self.x1 * self.x2 * self.H_i(p, -q, 21)
            else:  # p == 0 and q == 0
                return self.M1 * self.x1 * self.x2 * self.H_i(p, q, 12)

        elif p > 0 and q > 0:
            return self.x1 ** 2 * (self.H_simple(p, q, 1)) + self.x1 * self.x2 * self.H_i(p, q, 12)

        elif p > 0 and q < 0:
            return self.x1 * self.x2 * self.H_ij(p, -q, 12)

        elif p < 0 and q > 0:
            return self.x1 * self.x2 * self.H_ij(-p, q, 21)

        else:  # p < 0 and q < 0
            return self.x2 ** 2 * self.H_simple(-p, -q, 2) + self.x1 * self.x2 * self.H_i(-p, -q, 21)

    def get_alpha_T0(self, T):
        return np.array([self.soret * T, - self.soret * T])

    def plot_test(self, N, compare=True, save=False):
        '''
        Plot soret coefficient from kinetic_x gas theory for N'th order approximation with different m2/m1 and sigma2/sigma1 ratios
        :param N (int): Order of approximation
        :param compare (bool): Use same limits as Tipton, Tompson and Loyalka
        :param save (str): Filename to save figure as (defaults to False)
        '''
        cmap = cm.get_cmap('viridis')

        x_list = np.linspace(0.001, 0.999, 50)

        s_list = np.zeros((10, 50))

        sigma_list = [2, 1, 0.5]
        self.sigma2 = 1

        if compare is True:
            m_list = np.array([1, 2, 3, 4, 5, 8, 10])
            fig = plt.figure(figsize=(10, 5))
            grid = gs.GridSpec(ncols=4, nrows=1, figure=fig, wspace=0.5, width_ratios=[1,1,1,0])

            axs = [None for i in range(3)]
            for i in range(3):
                axs[i] = fig.add_subplot(grid[i])

            lim_list = [(0.05, -0.11), (0.05, -0.15), (0, -0.2)]

            plt.sca(axs[0])
            plt.hlines(0, 0, 1, colors='black', alpha=0.5)

        else:
            m_list = np.arange(0.35, 0.55, 0.05)
            fig = plt.figure(figsize = (10,5))
            grid = gs.GridSpec(ncols=3, nrows=1, figure=fig, wspace=0)

            axs = [None for i in range(3)]
            axs[0] = fig.add_subplot(grid[0])

            for i in (1,2):
                axs[i] = fig.add_subplot(grid[i], sharey=axs[0])
                plt.setp(axs[i].get_yticklabels(), visible=False)

        for i in range(3):
            print('Making plot', i+1)
            plt.sca(axs[i])
            self.sigma1 = sigma_list[i]
            self.sigma = np.array([self.sigma1, self.sigma2])
            self.sigma12 = 0.5 * (self.sigma1 + self.sigma2)

            for s_line, m in enumerate(m_list):
                print('m =', m)
                self.mole_weights = np.array([1, m])

                self.m0 = sum(self.mole_weights)
                self.m1, self.m2 = self.mole_weights
                self.M = self.mole_weights / np.sum(self.mole_weights)
                self.M1, self.M2 = self.M

                for k, x in enumerate(x_list):
                    self.mole_fracs = np.array([x, 1 - x])
                    self.x1, self.x2 = self.mole_fracs

                    pq_range = np.arange(-N, N + 1, 1)
                    self.A_matr = np.empty((2 * N + 1, 2 * N + 1), float)
                    for ind, p in enumerate(pq_range):
                        for j, q in zip([j for j in range(ind, len(pq_range))], pq_range[ind:]):
                            self.A_matr[ind, j] = self.a(int(p), int(q))
                            self.A_matr[j, ind] = self.A_matr[ind, j]

                    delta_0 = (3 / 2) * np.sqrt(const.Boltzmann * self.T / self.m0)
                    b = np.zeros(2 * N + 1)
                    b[int((len(b) - 1) / 2)] = delta_0
                    d = lin.solve(self.A_matr, b)

                    d_1, d0, d1 = d[int(((len(d) + 1) / 2) - 2) : int(((len(d) + 1) / 2) + 1)]
                    self.soret = - (5 / (2 * d0)) * ((self.x1 * d1 / np.sqrt(self.M1)) + (self.x2 * d_1 / np.sqrt(self.M2)))
                    s_list[s_line, k] = self.soret

            for m, s in zip(m_list, s_list):
                plt.plot(x_list, s, label = round(m,1), color = cmap(m / max(m_list)))

            if compare is True:
                print(lim_list[i])
                plt.ylim(lim_list[i][1],lim_list[i][0])

            plt.xlim(0,1)
            plt.title(r'$\frac{\sigma_2}{\sigma_1} = $'+str(round(self.sigma2/self.sigma1, 2)), fontsize=14)

        legend = plt.legend(title=r'$\frac{m_2}{m_1}$', bbox_to_anchor=[1, 1.025])
        plt.setp(legend.get_title(), fontsize=14)
        #plt.suptitle(r'Calculated $k_T$ values for some theoretical mixtures')

        if save:
            plt.savefig(save, dpi=600)
        #plt.show()
