from scipy.integrate import quad, dblquad
from scipy.optimize import fsolve
import scipy.constants as const
from numpy import pi, sqrt, exp, cos, inf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# NOTE(review): CALLS is not referenced anywhere in the visible code;
# presumably a debugging call counter - confirm before removing.
CALLS = 0

# NOTE(review): not referenced in the visible code, and the first entry
# (39.95) differs slightly from the 39.948 used for m1 below - confirm which
# value is intended.
mole_weights = np.array([39.95, 4])
# Lennard-Jones parameters of component 1: diameter in metres and well depth
# given as epsilon / k_B (presumably in kelvin) multiplied by k_B, i.e. joules.
sigma_1 = 3.405 * 1e-10 # m
epsilon_1 = 122.1 * const.Boltzmann
# Same quantities for component 2.
sigma_2 = 2.64 * 1e-10 # m
epsilon_2 = 10.9 * const.Boltzmann

# Unlike-pair parameters: arithmetic mean of the diameters and geometric mean
# of the well depths.
sigma_12 = 0.5 * (sigma_1 + sigma_2)
epsilon_12 = sqrt(epsilon_1 * epsilon_2)

# Mass per particle: molar weights divided by Avogadro's number.
# NOTE(review): if the molar weights are in g/mol these masses are in grams,
# not kilograms - confirm against the SI units used for sigma and epsilon.
m1, m2 = np.array([39.948, 4]) / const.Avogadro

# Lookup tables keyed by pair index: 1 and 2 for the like pairs, 12 for the
# unlike pair.
sigma = {1 : sigma_1, 2: sigma_2, 12 : sigma_12}
epsilon = {1 : epsilon_1, 2: epsilon_2, 12 : epsilon_12}

# NOTE(review): looks like component identifiers matching the two parameter
# sets above; not referenced in the visible code.
comps = 'AR,HE'
def mie(r, ij):
    """Return the 12-6 Lennard-Jones pair potential at separation ``r``.

    ``ij`` selects the entry of the module-level ``sigma`` and ``epsilon``
    tables (1, 2 or 12).  ``r`` must be in the same length unit as
    ``sigma[ij]``; the result carries the unit of ``epsilon[ij]``.
    """
    well_depth = epsilon[ij]
    ratio = sigma[ij] / r
    repulsion = ratio ** 12
    attraction = ratio ** 6
    return 4 * well_depth * (repulsion - attraction)

def dmie_dr(r, ij):
    """Return the derivative of the 12-6 potential with respect to ``r``.

    ``ij`` selects the entry of the module-level ``sigma`` and ``epsilon``
    tables; ``r`` must be in the same length unit as ``sigma[ij]``.
    """
    diameter = sigma[ij]
    repulsive_slope = -12 * (diameter ** 12 / r ** 13)
    attractive_slope = 6 * (diameter ** 6 / r ** 7)
    return 4 * epsilon[ij] * (repulsive_slope + attractive_slope)

def reduced_mie(reduced_r, ij):
    """Return the 12-6 potential for a separation given in units of sigma.

    The well depth is divided by the Boltzmann constant, so the result is
    ``epsilon[ij] / k_B`` times the dimensionless 12-6 shape function.
    """
    depth_over_k = epsilon[ij] / const.Boltzmann
    inverse_r = 1 / reduced_r
    return 4 * depth_over_k * (inverse_r ** 12 - inverse_r ** 6)

def reduced_dmie_dr(reduced_r, ij):
    """Return the derivative of ``reduced_mie`` with respect to ``reduced_r``.

    As in ``reduced_mie``, the well depth is divided by the Boltzmann
    constant before it scales the dimensionless derivative.
    """
    depth_over_k = epsilon[ij] / const.Boltzmann
    repulsive_slope = -12 * (1 / reduced_r ** 13)
    attractive_slope = 6 * (1 / reduced_r ** 7)
    return 4 * depth_over_k * (repulsive_slope + attractive_slope)

def newton(f, dfdr, r0, tol=1e-3, max_iter=11):
    """Find a root of ``f`` by Newton-Raphson iteration starting near ``r0``.

    The starting point is first doubled until ``f`` is non-negative there.
    Newton steps are then taken until ``abs(f)`` no longer exceeds ``tol`` or
    ``max_iter`` steps have been used.  Running out of steps is not an
    error: the last iterate is returned as-is.

    Parameters
    ----------
    f : callable
        Scalar function of one variable whose root is sought.
    dfdr : callable
        Derivative of ``f``.
    r0 : float
        Initial guess.
    tol : float, optional
        Absolute tolerance on ``f`` at the returned point.
    max_iter : int, optional
        Maximum number of Newton steps.

    Returns
    -------
    float
        The last iterate.

    Raises
    ------
    ValueError
        If doubling ``r0`` can no longer change it (``r0`` is zero or has
        overflowed to infinity) while ``f`` is still negative.
    """
    f_val = f(r0)
    while f_val < 0:
        doubled = r0 * 2
        # Doubling 0 or inf is a no-op; without this check the loop would
        # spin forever on a function that never becomes non-negative.
        if doubled == r0:
            raise ValueError(
                'newton: f stays negative while doubling the starting point'
            )
        r0 = doubled
        f_val = f(r0)

    for _ in range(max_iter):
        # Written as "not >" so that a NaN residual ends the iteration.
        if not abs(f_val) > tol:
            break
        step = -f_val / dfdr(r0)
        r0 = r0 + step
        f_val = f(r0)
    return r0

def chi(b, g2, ij):
    """Return the deflection angle for impact parameter ``b`` and ``g2``.

    ``b`` is given in units of ``sigma[ij]`` and is converted to the unit of
    ``sigma`` internally.  ``g2`` enters only through the energy-like
    denominator ``0.5 * mu * g2``, where ``mu`` is the reduced mass of ``m1``
    and ``m2``.  ``ij`` selects the potential parameters.

    The lower integration limit is located with ``newton``, the integrand is
    integrated with the trapezoidal rule on a logarithmic grid up to r = 1,
    and the remainder is approximated analytically.  Every call also shows a
    diagnostic plot of the integrand and prints the lower limit and the
    numerical integral.
    """
    reduced_mass = (m1 * m2) / (m1 + m2)
    impact = b * sigma[ij]
    energy = 0.5 * reduced_mass * g2

    def radicand(r):
        # Expression under the square root of the deflection integrand.
        return 1 - (mie(r, ij)) / energy - (impact / r)**2

    def radicand_slope(r):
        # Derivative of ``radicand`` with respect to r, used by Newton.
        return - (dmie_dr(r, ij) / energy) + 2 * (impact**2 / r**3)

    def angle_integrand(r):
        return 1 / (r ** 2 * sqrt(radicand(r)))

    # Lower integration limit: a root of the radicand, searched from sigma.
    R = newton(radicand, radicand_slope, sigma[ij])

    # Logarithmic grid from R up to 10**0 = 1.
    radii = np.logspace(np.log10(R), 0, 1000)
    samples = angle_integrand(radii)
    res = np.trapz(samples, radii)

    # Diagnostic plot of the integrand, with the lower limit marked.
    plt.plot(radii, angle_integrand(radii))
    plt.vlines(R, 0, max(angle_integrand(radii)))
    plt.xlabel(r'r [m]')
    plt.ylabel(r'$\frac{d\chi}{dr}$')
    plt.yscale('log')
    plt.xscale('log')
    plt.title(', '.join([str(round(impact / sigma[ij], 0)), str(g2)]))
    plt.show()

    print(R, res)
    # The "+ 1" is the tail beyond the grid: the integral of 1 / x^2 from
    # 1 to infinity equals 1.
    return pi - 2 * impact * (res + 1)

def W(T, l, r, ij, b_lim=30, g_lim=10):
    """Integrate the collision-integral weight over impact parameter and g2.

    The integrand is ``exp(-g2) * g2**(r + 1) * (1 - cos(chi)**l) * b2``,
    where ``b2`` is the impact parameter in units of ``sigma[ij]`` and
    ``chi`` is the deflection angle returned by ``chi()``.  Before
    integrating, the integrand is sampled on a fixed grid and shown as a 3-D
    wireframe.

    Parameters
    ----------
    T : float
        Temperature used to scale ``g2`` before it is handed to ``chi()``.
    l, r : int
        Orders of the collision integral.
    ij : int
        Pair index into the ``sigma`` / ``epsilon`` tables (1, 2 or 12).
    b_lim : float, optional
        Upper integration limit for ``b2``, in units of ``sigma[ij]``.
    g_lim : float, optional
        Upper integration limit for ``g2``.

    Returns
    -------
    float
        The value of the double integral reported by ``dblquad``.
    """
    def integrand(b2, g2):
        # chi() converts the reduced impact parameter to a length itself, so
        # b2 is passed through unscaled.
        # NOTE(review): chi() divides the potential by 0.5 * mu * <second
        # argument>, so the denominator here is 0.5 * mu * k_B * T * g2 -
        # confirm that this is the intended energy scale.
        deflection = chi(b2, const.Boltzmann * T * g2, ij)
        # The weight is (1 - cos(chi)**l): the power applies to the cosine,
        # not to the angle.
        return exp(-g2) * g2 ** (r + 1) * (1 - cos(deflection) ** l) * b2

    # Sample the integrand on a (g2, b2) grid for the preview plot.
    b2_vals = np.linspace(0.9, 50, 100)
    g_vals = np.linspace(1e-10, 1e-9)
    b2, g = np.meshgrid(b2_vals, g_vals)
    integrand_vals = np.zeros_like(b2)
    # Rows follow g_vals and columns follow b2_vals; iterate over both
    # dimensions so that the whole grid is filled.
    n_rows, n_cols = b2.shape
    for i in range(n_rows):
        for j in range(n_cols):
            integrand_vals[i, j] = integrand(b2[i, j], g[i, j])

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_wireframe(b2, g, integrand_vals)
    ax.set_xlabel(r'b [$\sigma_{12}$]')
    ax.set_ylabel(r'$g^2$')
    plt.show()

    # dblquad calls func(y, x): the first pair of limits belongs to the
    # integrand's second argument (g2), the second pair to its first (b2).
    return dblquad(integrand, 0, g_lim, 0, b_lim)[0]

if __name__ == '__main__':
    # Survey the deflection angle over a grid of reduced impact parameters
    # (units of sigma_12) and g2 values for the unlike pair.
    b_lim = np.linspace(0, 50, 10)
    g_lim = np.linspace(1e-10, 15, 10)

    b_lim, g_lim = np.meshgrid(b_lim, g_lim)

    chi_vals = np.empty_like(b_lim)
    n_rows, n_cols = b_lim.shape
    for i in range(n_rows):
        for j in range(n_cols):
            print(i, j)
            chi_vals[i, j] = chi(b_lim[i, j], g_lim[i, j], 12)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # Plot the angles computed above; the z axis is labelled chi.
    ax.plot_wireframe(b_lim, g_lim, chi_vals)
    ax.set_xlabel(r'b [$\sigma_{12}$]')
    ax.set_ylabel(r'$g^2$')
    ax.set_zlabel(r'$\chi$')
    plt.show()