import numba as nb
import numpy as np
import scipy.linalg as lin
import scipy.constants as const
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.gridspec as gs
import pandas as pd
import os
import time
from ctypes import *

# Look for the shared library next to this file first. dlopen() treats a name without a slash as a
# key for the dynamic-linker search path (LD_LIBRARY_PATH, system dirs), not as a path relative to
# this module, so the bare name alone only works when that path happens to contain it. The bare
# name is kept as a fallback for setups that rely on the linker search path.
_KINETICGAS_LIB = 'kineticgas.cpython-38-x86_64-linux-gnu.so'
_kineticgas_local = os.path.join(os.path.dirname(os.path.abspath(__file__)), _KINETICGAS_LIB)
fast_gas = cdll.LoadLibrary(_kineticgas_local if os.path.isfile(_kineticgas_local) else _KINETICGAS_LIB)

# Module-level handle to the C function; H_ij in Alpha_T0_analytical calls it as A(p, q, r, l).
A = getattr(fast_gas, 'A')
print(A)

# Declare the C signatures so ctypes converts arguments and return values correctly.
# NOTE(review): the file name is that of a CPython extension module; this assumes it also exports
# plain C-linkage symbols A, A_prime and A_tripleprime - confirm against the C++ sources.
fast_gas.A.restype = c_double
fast_gas.A.argtypes = [c_int, c_int, c_int, c_int]
fast_gas.A_prime.restype = c_double
fast_gas.A_prime.argtypes = [c_double, c_double, c_int, c_int, c_int, c_int]
fast_gas.A_tripleprime.restype = c_double
fast_gas.A_tripleprime.argtypes = [c_int, c_int, c_int, c_int]

# Smoke test of the binding. argtypes are c_int, so plain Python ints are passed and converted by
# ctypes; byref() only accepts ctypes instances, and byref(1) raised TypeError at import time.
print(fast_gas.A(1, 2, 1, 1))

class Alpha_T0_analytical():

    def __init__(self, comps, eos, mole_fracs=[0.5, 0.5], N=5):
        '''
        Callable object, composition and components are set upon initialization
        __call__(temp): returns alpha_t0 at the given temperature

        :param comps (str): Comma-separated list of components, following Thermopack-convention
        :param eos (thermopack): An initialized equation of state, only used to get mole weights
        :param mole_fracs (array-like): list of mole fractions
        :param N (int > 0): Order of approximation.
                            Be aware that N > 10 can be detrimental to runtime. This should be implemented in C++ or Fortran.
        '''

        # Packing out variables in an n-component compatible way
        complist = comps.split(',')
        self.mole_weights = np.array([10, 15])

        self.m0 = np.sum(self.mole_weights)

        self.M = mole_fracs / np.sum(mole_fracs)

        self.sigmaij = np.array([[1,1.5],[1.5, 2]])
        self.sigma = np.diag(self.sigmaij)

        # This part is only binary-compatible
        self.M1, self.M2 = self.M
        self.M1, self.M2 = 0.1,0.9
        self.M = np.array([0.1, 0.9])
        self.x1, self.x2 = mole_fracs
        self.m1, self.m2 = self.mole_weights
        self.sigma1, self.sigma2 = self.sigma
        self.sigma12 = self.sigmaij[0, 1]

        # Calculate the (temperature independent) soret-coefficient at the ideal gas state
        self.T = 100  # Set T to a dummy value, because intermediate expressions depend on T, even though final result does not.
        self.N = N

        pq_range = np.arange(-N, N + 1, 1)
        t0 = time.process_time()
        A = np.empty((2 * N + 1, 2 * N + 1), float)
        for i, p in enumerate(pq_range):
            for j, q in zip([j for j in range(i, len(pq_range))], pq_range[i:]):
                A[i, j] = self.a(int(p), int(
                    q))  # a(p,q) gives NaN or inf if type(p) == np.int64 or type(q) == np.int64... dont ask why
                A[j, i] = A[i, j]
        print('Computed A_matr in ', time.process_time() - t0)
        for line in A:
            for x in line:
                print(round(x,15),' '*(10 - len(str(round(x,15)))) ,end=' ')
            print()
        delta_0 = (3 / 2) * np.sqrt(const.Boltzmann * self.T / np.sum(self.mole_weights))
        b = np.zeros(2 * N + 1)
        b[int((len(b) - 1) / 2)] = delta_0
        d = lin.solve(A, b)

        d_1, d0, d1 = d[int(((len(d) + 1) / 2) - 2): int(((len(d) + 1) / 2) + 1)]
        self.soret = - (5 / (2 * d0)) * ((self.x1 * d1 / np.sqrt(self.M1)) + (self.x2 * d_1 / np.sqrt(self.M2)))

        self.d_1, self.d0, self.d1 = d_1, d0, d1

    def get_hard_sphere_radius(self, comps):
        '''
        Get hard-sphere diameters, assumed to be equal to Mie-potential sigma parameter. Gets Mie-parameters from the file mie.xlsx.

        :param comps: Comma seperated list of components
        :return: N x N matrix of hard sphere diameters, where sigma_ij = 0.5 * (sigma_i + sigma_j),
                such that the diagonal is the radius of each component, and off-diagonals are the average diameter of
                component i and j.
        '''
        df = pd.read_excel(os.path.dirname(__file__) + '/mie.xlsx')
        sigma_i = np.array([df.loc[df['comp'] == comp]['sigma'].iloc[0] for comp in comps.split(',')])

        sigma_ij = 0.5 * np.sum(np.meshgrid(sigma_i, np.vstack(sigma_i)), axis=0)
        return sigma_ij

    def omega(self, ij, l, r):
        if ij in (1, 2):
            return self.sigma[ij - 1] ** 2 * np.sqrt(
                (np.pi * const.Boltzmann * self.T) / self.mole_weights[ij - 1]) * w(l, r)

        elif ij in (12, 21):
            return 0.5 * self.sigma12 ** 2 * np.sqrt(
                2 * np.pi * const.Boltzmann * self.T / (self.m0 * self.M1 * self.M2)) * w(l, r)
        else:
            raise ValueError('(' + str(ij) + ', ' + str(l) + ', ' + str(r) + ') are non-valid arguments for omega.')

    def H_ij(self, p, q, ij):
        M1, M2 = self.M1, self.M2

        if ij == 21:  # swap indices
            M1, M2 = M2, M1

        def inner(r, l):
            return A(p, q, r, l) * self.omega(12, l, r)

        def sum_r(l):
            return summation(l, p + q + 2 - l, inner, args=l)

        val = 8 * M2 ** (p + 0.5) * M1 ** (q + 0.5) * summation(1, min(p, q) + 1, sum_r)

        return val

    def H_i(self, p, q, ij):

        if ij == 21:  # swap indices
            self.M1, self.M2 = self.M2, self.M1

        val = 0
        for l in range(1, min(p, q) + 1):
            for r in range(l, p + q + 2 - l):
                val += A_prime(p, q, r, l, self.M1, self.M2) * self.omega(12, l, r)
        val *= 8

        if ij == 21:  # swap back
            self.M1, self.M2 = self.M2, self.M1

        return val

    def H_simple(self, p, q, i):
        def inner(r, l):
            return A_tripleprime(p, q, r, l) * self.omega(i, l, r)

        def sum_r(l):
            return summation(l, p + q + 2 - l, inner, args=l)

        return 8 * summation(2, min(p, q) + 1, sum_r)

    def a(self, p, q, id='hey'):
        if p == 0 or q == 0:
            if p > 0:
                return self.M1 ** 0.5 * self.x1 * self.x2 * self.H_i(p, q, 12)
            elif p < 0:
                return - self.M2 ** 0.5 * self.x1 * self.x2 * self.H_i(-p, q, 21)
            elif q > 0:
                return self.M1 ** 0.5 * self.x1 * self.x2 * self.H_i(p, q, 12)
            elif q < 0:
                return - self.M2 ** 0.5 * self.x1 * self.x2 * self.H_i(p, -q, 21)
            else:  # p == 0 and q == 0
                return self.M1 * self.x1 * self.x2 * self.H_i(p, q, 12)

        elif p > 0 and q > 0:
            return self.x1 ** 2 * (self.H_simple(p, q, 1)) + self.x1 * self.x2 * self.H_i(p, q, 12)

        elif p > 0 and q < 0:
            return self.x1 * self.x2 * self.H_ij(p, -q, 12)

        elif p < 0 and q > 0:
            return self.x1 * self.x2 * self.H_ij(-p, q, 21)

        else:  # p < 0 and q < 0
            return self.x2 ** 2 * self.H_simple(-p, -q, 2) + self.x1 * self.x2 * self.H_i(-p, -q, 21)

    def get_alpha_T0(self, T):
        # pq_range = np.arange(-self.N, self.N + 1, 1)
        # A_matr = np.array([[self.a(p, q) for p in pq_range] for q in pq_range])
        # delta_0 = (3 / 2) * np.sqrt(const.Boltzmann * T / np.sum(self.mole_weights))
        # b = np.zeros(2 * self.N + 1)
        # b[int((len(b) - 1) / 2)] = delta_0
        # d = lin.solve(A_matr, b)
        #
        # d_1, d0, d1 = d[int(((len(d) + 1) / 2) - 2): int(((len(d) + 1) / 2) + 1)]
        # soret = - (5 / (2 * d0)) * ((self.x1 * d1 / np.sqrt(self.M1)) + (self.x2 * d_1 / np.sqrt(self.M2)))

        return self.soret * T

    def plot_test(self, N, compare=True, save=False):
        '''
        Plot soret coefficient from kinetic gas theory for N'th order approximation with different m2/m1 and sigma2/sigma1 ratios
        :param N (int): Order of approximation
        :param compare (bool): Use same limits as Tipton, Tompson and Loyalka
        :param save (str): Filename to save figure as (defaults to False)
        '''
        cmap = cm.get_cmap('viridis')

        x_list = np.linspace(0.001, 0.999, 50)

        s_list = np.empty((10, 50), float)

        sigma_list = [2, 1, 0.5]

        if compare is True:
            m_list = np.array([1, 2, 3, 4, 5, 8, 10])
            fig = plt.figure(figsize=(10, 5))
            grid = gs.GridSpec(ncols=3, nrows=1, figure=fig, wspace=0.3)

            axs = [None for i in range(3)]
            for i in range(3):
                axs[i] = fig.add_subplot(grid[i])

            lim_list = [(0.05, -0.11), (0.05, -0.15), (0, -0.2)]

            plt.sca(axs[0])
            plt.hlines(0, 0, 1, colors='black', alpha=0.5)

        else:
            m_list = np.arange(1, 11, 1)
            fig = plt.figure(figsize=(10, 5))
            grid = gs.GridSpec(ncols=3, nrows=1, figure=fig, wspace=0)

            axs = [None for i in range(3)]
            axs[0] = fig.add_subplot(grid[0])

            for i in (1, 2):
                axs[i] = fig.add_subplot(grid[i], sharey=axs[0])
                plt.setp(axs[i].get_yticklabels(), visible=False)

        for i in range(3):
            print('Making plot', i + 1)
            plt.sca(axs[i])
            self.sigma1 = sigma_list[i]
            self.sigma = np.array([self.sigma1, self.sigma2])

            for j, m in enumerate(m_list):
                print('m =', m)
                self.mole_weights = np.array([1, m])
                self.M = self.mole_weights / np.sum(self.mole_weights)
                self.M1, self.M2 = self.M
                for k, x in enumerate(x_list):
                    self.mole_fracs = np.array([x, 1 - x])
                    self.x1, self.x2 = self.mole_fracs
                    s_list[j, k] = self.get_alpha_T0(10) / 10

            for m, s in zip(m_list, s_list):
                plt.plot(x_list, s, label=m, color=cmap(m / 10))

            if compare is True:
                plt.ylim(lim_list[i][1], lim_list[i][0])

            plt.xlim(0, 1)
            plt.title(r'$\frac{\sigma_2}{\sigma_1} = $' + str(round(self.sigma2 / self.sigma1, 1)))

        plt.legend(title=r'$\frac{m_2}{m_1}$', bbox_to_anchor=[1, 1.025])
        plt.suptitle(r'Calculated $k_T$ values for some theoretical mixtures')

        if save:
            plt.savefig(save, dpi=600)
        plt.show()
