import math

import numpy as np
import scipy.linalg as lin
import scipy.constants as const
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.gridspec as gs
from numba import jit, njit

# Factorial used by the A / A' / A''' coefficient sums below.
# `np.math` was only an alias of the standard-library `math` module; it was
# deprecated in NumPy 1.25 and removed in NumPy 2.0, so use `math` directly.
fac = math.factorial

# Temperature; presumably in K, since it is combined with
# scipy.constants.Boltzmann (J/K) in `omega` and `get_KT`.
T = 1000
# Enters the right-hand side of `get_KT` as 3 / (2 n); presumably the number
# density -- confirm.
n = 5

# Collision diameters of species 1 and 2 (used squared in `omega`).
sigma1 = 1
sigma2 = 2
sigma = np.array([sigma1, sigma2])

# Masses of species 1 and 2. The identifier keeps its historical spelling
# because the other functions reference it by that name.
m1 = 3
m2 = 4
m0 = m1 + m2
mole_wheights = np.array([m1, m2])
# Mass ratios M_i = m_i / (m1 + m2).
M = mole_wheights / m0
M1, M2 = M

# Mole fractions of species 1 and 2.
mole_fracs = np.array([0.2, 0.8])
x1, x2 = mole_fracs

def summation(start, stop, func, args=None):
    """Sum ``func`` over the integers start, start + 1, ..., stop (inclusive).

    Parameters
    ----------
    start, stop : int
        Inclusive bounds. If ``start > stop`` the range is empty and 0 is
        returned.
    func : callable
        Called as ``func(i)``, or as ``func(i, args)`` when ``args`` is given.
    args : object, optional
        Extra argument forwarded unchanged to every call; ``None`` means
        nothing is forwarded.

    Returns
    -------
    The built-in ``sum`` of all terms (0 for an empty range).
    """
    term = func if args is None else (lambda i: func(i, args))
    return sum(map(term, range(start, stop + 1)))

def delta(i, j):
    """Kronecker delta: return the int 1 when ``i == j``, otherwise 0."""
    return 1 if i == j else 0

def w(l, r):
    """Return the dimensionless factor multiplying the collision integrals in `omega`.

    Evaluates ``0.25 * (2 - (1 + (-1)**l) / (l + 1)) * (r + 1)!``, i.e.
    ``0.5 * (r + 1)!`` for odd l and ``l / (2 (l + 1)) * (r + 1)!`` for even l.
    NOTE(review): this matches the usual rigid-sphere form of Omega^(l)(r);
    confirm against the reference used for the rest of the file.

    Parameters
    ----------
    l, r : int
        Integral orders; ``r + 1`` must be non-negative (``math.factorial``
        raises ``ValueError`` otherwise).

    Returns
    -------
    float
    """
    # `np.math.factorial` no longer exists in NumPy 2.0; `math` is what it aliased.
    return 0.25 * (2 - ((1 / (l + 1)) * (1 + (-1)**l))) * math.factorial(r + 1)

def omega(ij, l, r):
    """Return the collision integral Omega^(l)(r) for the requested interaction.

    Parameters
    ----------
    ij : int
        1 or 2 for like collisions of species 1 or 2; 12 or 21 (equivalent)
        for unlike collisions.
    l, r : int
        Integral orders, forwarded to `w`.

    Returns
    -------
    float
        ``sigma_i**2 * sqrt(pi k T / m_i) * w(l, r)`` for like collisions and
        ``0.5 * sigma12**2 * sqrt(2 pi k T / (m0 M1 M2)) * w(l, r)`` for unlike
        ones, where ``sigma12 = (sigma1 + sigma2) / 2`` and ``m0`` is the total
        mass (so ``m0 M1 M2`` is the reduced mass when M matches
        mole_wheights). Reads the module-level sigma, T, M and mole_wheights.

    Raises
    ------
    ValueError
        If ``ij`` is not one of 1, 2, 12, 21.
    """
    # Unlike collisions: arithmetic-mean diameter and reduced mass.
    if ij in (12, 21):
        frac_1, frac_2 = M
        total_mass = np.sum(mole_wheights)
        sigma_mix = 0.5 * sum(sigma)
        return 0.5 * sigma_mix**2 * np.sqrt(2 * np.pi * const.Boltzmann * T / (total_mass * frac_1 * frac_2)) * w(l, r)

    if ij not in (1, 2):
        raise ValueError('('+str(ij)+', '+str(l)+', '+str(r)+') are non-valid arguments for omega.')

    # Like collisions of a single species (index 0 or 1 into the arrays).
    species = ij - 1
    return sigma[species]**2 * np.sqrt((np.pi * const.Boltzmann * T) / mole_wheights[species]) * w(l, r)

def A(p, q, r, l):
    """Return the coefficient A(p, q, r, l) used by `H_ij`.

    Evaluates the finite sum, over i = l - 1, ..., min(p, q, r, p + q + 1 - r),
    of the factorial expression in ``inner``. The result is 0 when that index
    range is empty. `H_ij` multiplies it by ``omega(12, l, r)``.
    NOTE(review): this looks like the Chapman-Cowling A_{pqrl} coefficient of
    the Sonine-polynomial bracket integrals; confirm against the source
    formula before modifying.

    Parameters
    ----------
    p, q : int
        Expansion orders (the callers pass positive values).
    r, l : int
        Collision-integral orders paired with ``omega(12, l, r)``.

    Returns
    -------
    float, or the int 0 when the summation range is empty.
    """
    def inner(i):
        # Term i of the sum. The summation bounds below keep every
        # i-dependent factorial argument (i + 1 - l, r - i, p - i, q - i,
        # p + q + 1 - i - r, p + q - 2i) non-negative.
        return ((8**i * fac(p + q - 2 * i) * (-1)**l * (-1)**(r + i) * fac(r + 1) * fac(2 * (p + q + 2 - i)) * 4**r) /
                (fac(p - i) * fac(q - i) * fac(l) * fac(i + 1 - l) * fac(r - i) * fac(p + q + 1 - i - r) * fac(2 * r + 2)
                 * fac(p + q + 2 - i) * 4**(p + q + 1))) * ((i + 1 - l) * (p + q + 1 - i - r) - l * (r - i))

    return summation(l - 1, min(p, q, r, p + q + 1 - r), inner)

def A_prime(p, q, r, l):
    """Return the coefficient A'(p, q, r, l) used by `H_i`.

    Evaluates the triple sum over
        i = l - 1, ..., min(p, q, r, p + q + 1 - r),
        k = l - 1, ..., min(l, i),
        w = 0, ..., min(p, q, p + q + 1 - r) - i
    of the factorial expression in ``inner``. Empty ranges contribute 0.

    Unlike `A`, the result depends on the module-level mass ratios M1 and M2
    (through F, G and the explicit M1/M2 powers). They are read at call time;
    `H_i` temporarily exchanges them to evaluate the species-2 variant.
    NOTE(review): presumably the Chapman-Cowling A'_{pqrl} coefficient;
    confirm against the source formula before modifying.

    Parameters
    ----------
    p, q : int
        Non-negative expansion orders.
    r, l : int
        Collision-integral orders paired with ``omega(12, l, r)`` in `H_i`.

    Returns
    -------
    float, or the int 0 when every summation range is empty.
    """
    # Mass-ratio combinations built from the current module-level M1, M2.
    F = (M1**2 + M2**2)/(2 * M1 * M2)
    G = (M1 - M2)/M2

    def inner(w, args):
        # Term (i, k, w). `args` is the (i, k) tuple forwarded by `summation`;
        # the parameter name `w` shadows the module-level function `w` here.
        # k only takes the values l - 1 and l, and delta(k, l) /
        # delta(k, l - 1) select which bracketed term contributes.
        i, k = args
        return ((8**i * fac(p + q - 2 * i - w) * (-1)**(r + i) * fac(r + 1) * fac(2 * (p + q + 2 - i - w)) * 2**(2 * r) * F**(i - k) * G**w)/
                (fac(p - i - w) * fac(q - i - w) * fac(r - i) * fac(p + q + 1 - i - r - w) * fac(2 * r + 2) * fac(p + q + 2 - i - w)
                * 4**(p + q + 1) * fac(k) * fac(i - k) * fac(w))) * (2**(2 * w - 1) * M1**i * M2**(p + q - i - w)) * 2 * (
                M1 * (p + q + 1 - i - r - w) * delta(k,l) - M2 * (r - i) * delta(k, l - 1) )

    # Innermost sum over w for fixed (k, i); `summation` calls sum_w(k, i).
    def sum_w(k, i):
        return summation(0, min(p, q, p + q + 1 - r) - i, inner, args=(i, k))

    # Middle sum over k for fixed i.
    def sum_k(i):
        return summation(l - 1, min(l, i), sum_w, args=i)

    # Outer sum over i.
    return summation(l - 1, min(p, q, r, p + q + 1 - r), sum_k)

def A_tripleprime(p, q, r, l):
    """Return the coefficient A'''(p, q, r, l) used by `H_simple`.

    Returns 0 for odd l. For even l it returns 0.5**(p + q + 1) times the sum
    over i = l - 1, ..., min(p, q, r, p + q + 1 - r) of the factorial
    expression in ``inner`` (the sum is 0 if that range is empty).
    NOTE(review): presumably the Chapman-Cowling like-molecule coefficient;
    confirm against the source formula before modifying.

    Parameters
    ----------
    p, q : int
        Non-negative expansion orders.
    r, l : int
        Collision-integral orders paired with ``omega(i, l, r)`` in `H_simple`.

    Returns
    -------
    float, or the int 0 (odd l or empty summation range).
    """
    # Odd-l terms vanish, so the sum is skipped entirely.
    if l % 2 != 0:
        return 0

    def inner(i):
        # Same factorial structure as the term in `A`, with an extra factor 2
        # and without (-1)**l, which equals 1 for the even l reaching here.
        return ((8**i * fac(p + q - (2 * i)) * 2 * (-1)**(r + i) * fac(r + 1) * fac(2 * (p + q + 2 - i)) * 2**(2 * r))/
                (fac(p - i) * fac(q - i) * fac(l) * fac(i + 1 - l) * fac(r - i) * fac(p + q + 1 - i - r) * fac(2 * r + 2)
                * fac(p + q + 2 - i) * 4**(p + q + 1))) * (((i + 1 - l) * (p + q + 1 - i - r)) - l * (r - i))

    return 0.5**(p + q + 1) * summation(l - 1, min(p, q, r, p + q + 1 - r), inner)

def H_ij(p, q, ij):
    """Return the cross-species bracket term used when p and q have opposite signs.

    Computes
        8 * Mb**(p + 1/2) * Ma**(q + 1/2)
          * sum_{l=1}^{min(p,q)+1} sum_{r=l}^{p+q+2-l} A(p, q, r, l) * omega(12, l, r)
    with (Ma, Mb) = (M1, M2), or (M2, M1) when ``ij == 21`` (species roles
    exchanged). M1 and M2 are the module-level mass ratios; they are read but
    never modified.

    Parameters
    ----------
    p, q : int
        Positive expansion orders (`a` passes absolute values).
    ij : int
        21 exchanges the species; any other value (the callers use 12) does not.

    Returns
    -------
    float
    """
    # Exchange the species with local names instead of mutating the globals.
    # A(...) and omega(12, ...) never read M1/M2, so only this prefactor
    # depends on the exchange. This also means an exception or interrupt can
    # no longer leave the module-level M1/M2 swapped.
    if ij == 21:
        mass_a, mass_b = M2, M1
    else:
        mass_a, mass_b = M1, M2

    def inner(r, l):
        return A(p, q, r, l) * omega(12, l, r)

    def sum_r(l):
        return summation(l, p + q + 2 - l, inner, args=l)

    return 8 * mass_b**(p + 0.5) * mass_a**(q + 0.5) * summation(1, min(p, q) + 1, sum_r)

def H_i(p, q, ij):
    """Return the mixed-collision bracket term built from `A_prime`.

    Computes
        8 * sum_{l=1}^{min(p,q)+1} sum_{r=l}^{p+q+2-l} A'(p, q, r, l) * omega(12, l, r).
    `A_prime` reads the module-level M1 and M2. For ``ij == 21`` they are
    exchanged for the duration of the computation, so the species-2 variant
    is obtained, and restored afterwards.

    Parameters
    ----------
    p, q : int
        Non-negative expansion orders (`a` passes absolute values).
    ij : int
        21 exchanges the species; any other value (the callers use 12) does not.

    Returns
    -------
    float
    """
    global M1, M2

    swapped = ij == 21
    if swapped:
        M1, M2 = M2, M1

    def inner(r, l):
        return A_prime(p, q, r, l) * omega(12, l, r)

    def sum_r(l):
        return summation(l, p + q + 2 - l, inner, args=l)

    # Restore the globals even if the sum raises or is interrupted; otherwise
    # every later calculation would silently use exchanged masses.
    try:
        return 8 * summation(1, min(p, q) + 1, sum_r)
    finally:
        if swapped:
            M1, M2 = M2, M1

def H_simple(p, q, i):
    """Return the like-species bracket contribution for species ``i``.

    Computes
        8 * sum_{l=2}^{min(p,q)+1} sum_{r=l}^{p+q+2-l} A'''(p, q, r, l) * omega(i, l, r)
    with the pure-species collision integrals ``omega(i, ...)`` (i is 1 or 2).
    `a` weights it by x1**2 for i = 1 and by x2**2 for i = 2.
    """
    def r_sum(l):
        # Inner sum over r for a fixed l (upper bound inclusive).
        return sum(
            A_tripleprime(p, q, r, l) * omega(i, l, r)
            for r in range(l, p + q + 2 - l + 1)
        )

    return 8 * sum(r_sum(l) for l in range(2, min(p, q) + 1 + 1))

def a(p, q):
    """Return the matrix element a(p, q) of the linear system solved in `get_KT`.

    Positive indices go with species 1 (x1, ``H_simple(..., 1)``, ij = 12) and
    negative ones with species 2 (x2, ``H_simple(..., 2)``, ij = 21). The
    helpers always receive the absolute values of p and q:

    * both > 0:      x1**2 H_simple(p, q, 1) + x1 x2 H_i(p, q, 12)
    * both < 0:      x2**2 H_simple(|p|, |q|, 2) + x1 x2 H_i(|p|, |q|, 21)
    * p > 0 > q:     x1 x2 H_ij(p, |q|, 12)
    * p < 0 < q:     x1 x2 H_ij(|p|, q, 21)
    * one index 0:   +sqrt(M1) x1 x2 H_i(.., 12) if the other is positive,
                     -sqrt(M2) x1 x2 H_i(.., 21) if it is negative
    * both 0:        M1 x1 x2 H_i(0, 0, 12)
    """
    if p != 0 and q != 0:
        # Both indices nonzero: their signs select the species block.
        if p > 0:
            if q > 0:
                return x1**2 * (H_simple(p, q, 1)) + x1 * x2 * H_i(p, q, 12)
            return x1 * x2 * H_ij(p, -q, 12)
        if q > 0:
            return x1 * x2 * H_ij(-p, q, 21)
        return x2**2 * H_simple(-p, -q, 2) + x1 * x2 * H_i(-p, -q, 21)

    # At least one index is 0, so p + q carries the sign of the other one.
    nonzero_sign = p + q
    if nonzero_sign > 0:
        return M1**0.5 * x1 * x2 * H_i(p, q, 12)
    if nonzero_sign < 0:
        return - M2**0.5 * x1 * x2 * H_i(abs(p), abs(q), 21)
    return M1 * x1 * x2 * H_i(p, q, 12)

def get_KT(N):
    """Return the thermal diffusion ratio k_T from an order-N expansion.

    Assembles the (2N + 1) x (2N + 1) matrix of elements ``a(p, q)`` for
    p, q = -N, ..., N (row index q, column index p). It then solves
    ``matrix @ d = b``, where b is zero except for ``delta_0`` in the p = 0
    slot, and returns

        k_T = -(5 / (2 d_0)) * (x1 d_1 / sqrt(M1) + x2 d_{-1} / sqrt(M2)).

    The system is linear and only the ratios d_{+-1} / d_0 enter k_T. The
    result therefore does not depend on the (nonzero) magnitude of
    ``delta_0``. Species data (x1, x2, M1, M2, T, n, mole_wheights, ...) are
    read from module-level globals.

    Parameters
    ----------
    N : int
        Truncation order; must be at least 1 so that d_{-1} and d_1 exist.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If N < 1.
    """
    if N < 1:
        raise ValueError('get_KT needs N >= 1 (d_-1 and d_1 are required), got ' + str(N))

    pq_range = np.arange(-N, N + 1, 1)

    # Named `matrix` so it does not shadow the coefficient function `A`.
    matrix = np.array([[a(p, q) for p in pq_range] for q in pq_range])

    # Only the p = 0 equation carries a source term; index N of pq_range is p = 0.
    delta_0 = 3/(2 * n) * np.sqrt(const.Boltzmann * T / np.sum(mole_wheights))
    b = np.zeros(2 * N + 1)
    b[N] = delta_0
    d = lin.solve(matrix, b)

    # Indices N - 1, N, N + 1 hold the coefficients for p = -1, 0, +1.
    d_1, d0, d1 = d[N - 1], d[N], d[N + 1]
    soret = - (5/(2 * d0)) * ( (x1 * d1/np.sqrt(M1)) + (x2 * d_1 / np.sqrt(M2)) )
    return soret

def plot_test(N, compare=True):
    """Plot k_T against the mole fraction x1 for several size and mass ratios.

    Draws three side-by-side panels, one per sigma1 in (2, 1, 0.5) with
    sigma2 fixed. Each panel has one curve per mass ratio m2/m1 (species 1
    has mass 1), computed with ``get_KT(N)`` on 50 mole fractions in
    [0.001, 0.999]. The module-level species parameters are overwritten
    during the computation and restored afterwards, even on error.

    Parameters
    ----------
    N : int
        Truncation order passed to `get_KT`.
    compare : bool, optional
        If truthy: mass ratios (1, 2, 3, 4, 5, 8, 10), independent y axes with
        fixed limits and a zero reference line on the first panel. Otherwise:
        mass ratios 1..10 and a y axis shared by all panels.
    """
    global sigma, mole_wheights, mole_fracs, M, sigma1, x1, x2, M1, M2
    saved_state = (sigma, mole_wheights, mole_fracs, M, sigma1, x1, x2, M1, M2)

    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; pyplot keeps get_cmap.
    cmap = plt.get_cmap('viridis')

    x_list = np.linspace(0.001, 0.999, 50)
    sigma_list = [2, 1, 0.5]

    if compare:
        m_list = np.array([1, 2, 3, 4, 5, 8, 10])
        fig = plt.figure(figsize=(10, 5))
        grid = gs.GridSpec(ncols=3, nrows=1, figure=fig, wspace=0.3)
        axs = [fig.add_subplot(grid[i]) for i in range(3)]

        # (top, bottom) y limits for each panel.
        lim_list = [(0.05, -0.11), (0.05, -0.15), (0, -0.2)]

        # Zero reference line on the first panel only.
        plt.sca(axs[0])
        plt.hlines(0, 0, 1, colors='black', alpha=0.5)
    else:
        m_list = np.arange(1, 11, 1)
        fig = plt.figure(figsize=(10, 5))
        grid = gs.GridSpec(ncols=3, nrows=1, figure=fig, wspace=0)

        axs = [fig.add_subplot(grid[0])]
        for i in (1, 2):
            axs.append(fig.add_subplot(grid[i], sharey=axs[0]))
            plt.setp(axs[i].get_yticklabels(), visible=False)

    # One row per mass ratio and one column per mole fraction, sized from the
    # actual lists rather than hard-coded.
    s_list = np.empty((len(m_list), len(x_list)), float)

    try:
        for i in range(3):
            print('Making plot', i+1)
            plt.sca(axs[i])
            sigma1 = sigma_list[i]
            sigma = np.array([sigma1, sigma2])

            for j, m in enumerate(m_list):
                print('m =', m)
                mole_wheights = np.array([1, m])
                M = mole_wheights / np.sum(mole_wheights)
                M1, M2 = M
                for k, x in enumerate(x_list):
                    mole_fracs = np.array([x, 1 - x])
                    x1, x2 = mole_fracs
                    s_list[j, k] = get_KT(N)

            for m, s in zip(m_list, s_list):
                plt.plot(x_list, s, label=m, color=cmap(m / 10))

            if compare:
                plt.ylim(lim_list[i][1], lim_list[i][0])

            plt.xlim(0, 1)
            plt.title(r'$\frac{\sigma_2}{\sigma_1} = $'+str(round(sigma2/sigma1, 1)))
    finally:
        # Leave the module-level species parameters as they were before the call.
        sigma, mole_wheights, mole_fracs, M, sigma1, x1, x2, M1, M2 = saved_state

    plt.legend(title=r'$\frac{m_2}{m_1}$', bbox_to_anchor=[1, 1.025])
    plt.suptitle(r'Calculated $k_T$ values for some theoretical mixtures')
    #plt.savefig('alpha_t0_analytical', dpi=600)
    plt.show()

def plot_test_temp(N):
    """Plot k_T against the mole fraction x1 for several temperatures.

    Evaluates ``get_KT(N)`` for T in (10, 100, 1000, 10000) on 50 mole
    fractions in [0.001, 0.999]. Curves are coloured by log(T) / log(max T).
    The module-level T, x1, x2 and mole_fracs are overwritten during the
    computation and restored afterwards, even on error.

    Parameters
    ----------
    N : int
        Truncation order passed to `get_KT`.
    """
    global T, x1, x2, mole_fracs
    saved_state = (T, x1, x2, mole_fracs)

    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; pyplot keeps get_cmap.
    cmap = plt.get_cmap('viridis')

    T_list = [10, 100, 1000, 10000]
    x_list = np.linspace(0.001, 0.999, 50)
    s_list = np.empty((len(T_list), len(x_list)), float)

    try:
        for i, temp in enumerate(T_list):
            print('Plotting for T =',temp)
            T = temp
            for j, x in enumerate(x_list):
                x1 = x
                x2 = 1 - x
                mole_fracs = np.array([x1, x2])
                s_list[i, j] = get_KT(N)
    finally:
        T, x1, x2, mole_fracs = saved_state

    # A dedicated loop variable: reusing the name `T` here would rebind the
    # module-level temperature declared global above.
    log_t_max = np.log(max(T_list))
    for s, temp in zip(s_list, T_list):
        plt.plot(x_list, s, color=cmap(np.log(temp) / log_t_max), label=temp)

    plt.legend()
    plt.show()

# NOTE: Everything in the triple-quoted string below is disabled code. A bare
# string literal is evaluated and discarded, so none of these functions exist
# at runtime. It holds hand-expanded reference expressions used to
# cross-check the general implementation:
#   * a_qq / a_pq: explicit matrix elements for |p|, |q| <= 2;
#   * test_vs_other: builds both matrices for N = 2 and prints a_pq - a;
#   * test: compares H_simple, H_i and H_ij at p = q = 1 with the
#     corresponding closed forms;
#   * matr_print: prints a matrix row by row.
# To use any of it, move the code out of the string.
# NOTE(review): the a_pq error message contains the typo "ar Invalid".
'''
def a_qq(q):
    if q == 0:
        return np.prod(mole_fracs) * (8 * np.prod(M) * omega(12, 1, 1))

    elif q in (-1, 1):
        if q == -1:
            i, j = 1, 0
        else:
            i, j = 0, 1

        return mole_fracs[i]**2 * 4 * omega(i+1, 2, 2) +\
               np.prod(mole_fracs) * (10 * (5 * M[j]**3 + 6 * M[j] * M[i]**2)* omega(12, 1, 1)
                                      - 40 * M[j]**3 * omega(12, 1, 2) + 8 * M[j]**3 * omega(12, 1, 3)
                                      + 16 * M[j]**2 * M[i] * omega(12, 2, 2))
    elif q in (-2, 2):
        if q == -2:
            i, j = 1, 0
        else:
            i, j = 0, 1

        return mole_fracs[i]**2 * ((77/4) * omega(i+1, 2, 2) - 7*omega(i+1, 2, 3) + omega(i+1, 2, 4))\
            + np.prod(mole_fracs) * ((35/8) * (35 * M[j]**5 + 168 * M[j]**3 * M[i]**2 + 40 * M[j]*M[i]**4) * omega(12, 1, 1)
                                     - 49 * (5 * M[j]**5 + 12 * M[j]**3 * M[i]**2) * omega(12, 1, 2)
                                     + (133 * M[j]**5 + 108 * M[j]**3 * M[i]**2) * omega(12, 1, 3) - 28 * M[j]**5 * omega(12, 1, 4)
                                     + 2 * M[j]**5 * omega(12, 1, 5) + 28 * (7 * M[j]**4 * M[i] + 4 * M[j]**2 * M[i]**3) * omega(12, 2, 2)
                                     - 112 * M[j]**4 * M[i] * omega(12, 2, 3) + 16 * M[j]**4 * M[i] * omega(12, 2, 4)
                                     + 16 * M[j]**3 * M[i]**2 * omega(12, 3, 3))

    else:
        raise ValueError('q must be  in (-2, -1, 0, 1, 2)')

def a_pq(p, q):
    M1, M2 = M
    if p == q:
        return a_qq(q)

    elif (p, q) in [(1, -1), (-1, 1)]:
        return np.prod(mole_fracs) * (M1 * M2)**(3/2) * \
                                     (-110 * omega(12, 1, 1)
                                      + 40 * omega(12, 1, 2)
                                      - 8  * omega(12, 1, 3)
                                      + 16 * omega(12, 2, 2))

    elif (p, q) in [(2, -2), (-2, 2)]:
        return np.prod(mole_fracs) * (M1 * M2)**(5/2) * (-(8505/8) * omega(12, 1, 1) + 833 * omega(12, 1, 2) - 241 * omega(12, 1, 3)
                                      + 28 * omega(12, 1, 4) - 2 * omega(12, 1, 5) + 308 * omega(12, 2, 2) - 112 * omega(12, 2, 3)
                                            + 16 * omega(12, 2, 4) - 16 * omega(12, 3, 3))

    elif (p, q) in [(0, -1), (-1, 0), (1, 0), (0, 1)]:
        if (p, q) in [(0,-1), (-1, 0)]:
            i, j = 0, 1
            c = 1
        else: # (p, q) in [(1, 0), (0, 1)]:
            i, j = 1, 0
            c = -1

        return np.prod(mole_fracs) * c * (-20 * M[i]**2 * np.sqrt(M[j]) * omega(12, 1, 1)
                                      + 8 * M[i]**2 * np.sqrt(M[j]) * omega(12, 1, 2))

    elif (p, q) in [(-2, -1), (-1, -2), (1, 2), (2, 1)]:
        if p in (-2, -1) and q in (-2, -1):
            i, j = 1, 0
        else: # p in (1, 2) and q in (1, 2):
            i, j = 0, 1

        return mole_fracs[i] ** 2 * (7 * omega(i+1, 2, 2) - 2 * omega(i+1, 2, 3))\
                + np.prod(mole_fracs) * ((35/2) * (5 * M[j]**4 + 12 * M[j]**2 * M[i]**2) * omega(12, 1, 1)
                                         -21 * (5 * M[j]**4 + 4 * M[j]**2 * M[i]**2) * omega(12, 1, 2) + 38 * M[j]**4 * omega(12, 1, 3)
                                         - 4 * M[j]**4 * omega(12, 1, 4) + 56 * M[j]**3 * M[i] * omega(12, 2, 2)
                                         - 16 * M[j]**3 * M[i] * omega(12, 2, 3))

    elif p in (-2, 0, 2) and q in (-2, 0, 2):
        if p in (-2, 0) and q in (-2, 0):
            i, j = 1, 0
            c = 1
        else: # p in (0, 2) and q in (0, 2)
            i, j = 0, 1
            c = -1

        return np.prod(mole_fracs) * M[j]**3 * M[i]**0.5 * c * (-35 * omega(12, 1, 1) + 28 * omega(12, 1, 2) - 4 * omega(12, 1, 3) )

    elif (p, q) in [(1, -2), (-2, 1), (2, -1), (-1, 2)]:
        if p in (-2, 1) and q in (-2, 1):
            i, j = 1, 0
        else: # p in (2, -1) and q in (2, -1)
            i, j = 0, 1

        return np.prod(mole_fracs) * M[j]**(5/2) * M[i]**(3/2) * \
               (-(595/2) * omega(12, 1, 1) + 189 * omega(12, 1, 2) - 38 * omega(12, 1, 3)
                + 4 * omega(12, 1, 4) + 56 * omega(12, 2, 2) - 16 * omega(12, 2, 3))

    else:
        raise ValueError('(' + str(p) + ', ' + str(q) + ') ar Invalid values for p and q.')

def test_vs_other():

    N = 2
    pq_range = np.arange(-N, N + 1, 1)

    A1 = np.array([[a_pq(p, q) for p in pq_range] for q in pq_range])
    A2 = np.array([[a(p, q) for p in pq_range] for q in pq_range])

    matr_print(A1 - A2)

def test():
    print('H_simple :')
    print(H_simple(1, 1, 1))
    print(4 * omega(1, 2, 2))
    print()

    print('H_i :')
    print(H_i(1, 1, 12))
    print((10 * (5 * M2**3 + 6 * M2 * M1**2)* omega(12, 1, 1)
                                          - 40 * M2**3 * omega(12, 1, 2) + 8 * M2**3 * omega(12, 1, 3)
                                          + 16 * M2**2 * M1 * omega(12, 2, 2)))


    print()
    print('H_ij :')
    print(H_ij(1, 1, 12))
    print((M1 * M2)**(3/2) * (-110 * omega(12, 1, 1) + 40 * omega(12, 1, 2) - 8  * omega(12, 1, 3) + 16 * omega(12, 2, 2)))

def matr_print(matr):
    for line in matr:
        for x in line:
            print(round(x,20), end=' '*(25 - len(str(round(x,20)))))
        print()
'''
