from kempers_V2.numeric_minimizer import NumericSoret
from kempers_V2.kempers import Kempers_HS, Kempers89
from pykingas import KineticGas
import numpy as np
from scipy.optimize import root
from pyctp import saftvrmie
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap, ScalarMappable
from matplotlib.colors import Normalize, LogNorm
#from mpl_toolkits.mplot3d import axes3d
plt.style.use('default')
from integrator import eos_differentials as ed
from global_params import FLTEPS


# Quick interdiffusion check: D12 from the kinetic-gas model (BH=True, N=1)
# over a temperature sweep at fixed molar volume Vm and composition x.
# comps = 'AR,C1'  # alternative mixture; swap with the line below to use it
comps = 'AR,HE'
kin = KineticGas(comps)
T_list = np.linspace(250, 500)
Vm = 0.3
x1 = 0.5
x = np.array([x1, 1 - x1])
D12_hs = np.empty_like(T_list)
for i, T in enumerate(T_list):
    D12_hs[i] = kin.interdiffusion(T, Vm, x, BH=True, N=1)

# NOTE(review): the 1e4 factor presumably converts m^2/s -> cm^2/s -- confirm
# the units returned by KineticGas.interdiffusion.
plt.plot(T_list, D12_hs * 1e4)
plt.show()

# Early exit: while this remains, nothing below runs when the file is executed
# as a script. `raise SystemExit` is used instead of the `exit` builtin, which
# is injected by the `site` module for interactive use and is not guaranteed.
raise SystemExit(0)

# ---- Module-level state shared by the functions below ---------------------
# The helpers read these globals (T, rho, x, n, nT, V, V_bulb, p, dT, eos,
# kin, num, kempers, kempers89) directly; several plotting routines also
# reassign some of them via `global`.
comp1, comp2 = comps.split(',')
eos = saftvrmie.saftvrmie()
eos.init(comps)
# Replaces the `kin` created above; this one is constructed with BH=True.
kin = KineticGas(comps, BH=True)
kempers = Kempers_HS(comps)
kempers89 = Kempers89(comps)

# Reference state: temperature, total density (V = nT / rho), composition.
T = 300
rho = 10
x1 = 0.3
x = np.array([x1, 1 - x1])
nT = 1e3
n = nT * x
V = nT / rho
V_bulb = V / 2  # two bulbs of equal volume sharing V

dT = 1e-2
# Reference pressure of the whole system (2 * V_bulb == V at this point).
p, = eos.pressure_tv(T, 2 * V_bulb, n)

num = NumericSoret(comps, dT=dT, nT=nT)
num.set_bulb_values(T, rho)
# NOTE(review): V_bulb is replaced by the numeric model's bulb-A volume; `p`
# above was computed with the earlier V / 2. Presumably equal -- confirm.
V_bulb = num.V_A

def mid(lst):
    """Return the midpoints and spacings of consecutive entries of ``lst``.

    Parameters
    ----------
    lst : array_like
        1-D sequence of values (e.g. grid nodes).

    Returns
    -------
    (midpoints, dlst) : tuple of ndarray
        ``midpoints[i] = (lst[i] + lst[i + 1]) / 2`` and
        ``dlst[i] = lst[i + 1] - lst[i]``; both have length ``len(lst) - 1``.
    """
    lst = np.asarray(lst)
    dlst = np.diff(lst)
    # Half the spacing: adding the full spacing would just reproduce lst[1:].
    return lst[:-1] + 0.5 * dlst, dlst

def set_HS(eos, val=True):
    """Toggle the chain and a1/a2/a3 contributions of ``eos`` off or on.

    With ``val=True`` each of ``model_control_chain``, ``model_control_a3``,
    ``model_control_a2`` and ``model_control_a1`` is called with ``False``
    (presumably leaving only the hard-sphere reference, per the function
    name); with ``val=False`` they are called with ``True``. The calls are
    made in that order. The same ``eos`` object is returned.
    """
    enable_terms = not val
    for control in ('model_control_chain', 'model_control_a3',
                    'model_control_a2', 'model_control_a1'):
        getattr(eos, control)(enable_terms)
    return eos

def eq_pres_condition(dn, dT):
    """Return p_A - p_B for a trial mole shift ``dn`` and temperature difference ``dT``.

    Bulb A holds ``n / 2 - dn`` at ``T - dT / 2``; bulb B holds ``n / 2 + dn``
    at ``T + dT / 2``; both have volume ``V_bulb`` (module globals).
    Note this shifts each bulb by the full ``dn``, unlike ``get_dp`` /
    ``get_n_x_and_T`` which use ``(n -/+ dn) / 2``.
    """
    half_n = n / 2
    p_a, = eos.pressure_tv(T - 0.5 * dT, V_bulb, half_n - dn)
    p_b, = eos.pressure_tv(T + 0.5 * dT, V_bulb, half_n + dn)
    return p_a - p_b

def eq_pres_T_condition(dT, dn):
    """Pressure residual p_A - p_B with ``dT`` as the first (unknown) argument.

    Identical to ``eq_pres_condition(dn, dT)``; the swapped argument order lets
    ``scipy.optimize.root`` solve for ``dT`` at a fixed ``dn``.
    """
    return eq_pres_condition(dn, dT)

def check_dp_kin():
    """Diagnostic: solve for the bulb mole shift implied by the kinetic S_T,
    then for the dT that equalises the bulb pressures, and print both.

    NOTE(review): ``get_dn_rootfun`` is not defined anywhere in this file as
    shown, so this function raises NameError at the first ``root`` call.
    TODO: supply the residual function or remove this diagnostic.
    """
    # Kinetic Soret coefficient S_T = alpha_T0 / T at the module-level state.
    ST_kin = kin.alpha_T0(T, 1 / rho, x) / T
    # NOTE: `(ST_kin)` is not a tuple; scipy's root wraps a non-tuple `args`
    # into a 1-tuple (verify for the installed scipy version).
    res = root(get_dn_rootfun, x0=[0, 0], args=(ST_kin))
    dn = res.x

    # Solve for the temperature difference that gives p_A == p_B at this dn.
    res_T = root(eq_pres_T_condition, x0=[0], args=(dn))
    dT = res_T.x[0]  # local; the module-level dT is not modified
    print('dn from ST_kin :', res)
    print('dp :', eq_pres_condition(dn, dT))
    print('dT : ', dT)

def compare_kin_kemp_num():
    """Plot S_T of component 1 (x1e3) vs T for the kinetic, numeric and Kempers models.

    Uses the module-level ``rho`` and ``x``; prints the fractional progress of
    the temperature sweep.
    """
    temps = np.linspace(250, 450, 30)
    last_idx = len(temps) - 1
    ST_num, ST_kin, ST_kemp = (np.empty_like(temps) for _ in range(3))
    for k, temp in enumerate(temps):
        print(k / last_idx)
        ST_num[k] = num.get_Soret(temp, rho, x)[0]
        ST_kin[k] = kin.alpha_T0(temp, 1/rho, x)[0] / temp
        ST_kemp[k] = kempers.get_Soret(temp, rho, x)[0]

    curves = ((ST_kin, 'Kinetic'), (ST_num, 'Numeric'), (ST_kemp, 'Kempers'))
    for values, name in curves:
        plt.plot(temps, values * 1e3, label=name)
    plt.legend()
    plt.show()

def get_drho(ST, dT, dx=None):
    """Linearised density difference between the bulbs.

    drho = sum(drho/dx * dx) + drho/dT * dT, with the derivatives from
    ``ed.drhodx_Tp`` and ``ed.drhodT_px`` at the module-level state
    (T, p, x, n). If ``dx`` is None it is taken from ``get_dx(ST, dT)``.
    """
    # Checked
    composition_shift = get_dx(ST, dT) if dx is None else dx
    drho_dx = ed.drhodx_Tp(eos, T, p, x, 1)
    drho_dT = ed.drhodT_px(eos, T, p, n, 1)
    return sum(drho_dx * composition_shift) + drho_dT * dT

def get_dn(ST, dT):
    """Linearised mole-number difference dn = n_B - n_A between the bulbs.

    Combines the change in total moles (``get_dnT``, from the linearised
    density change) with the composition shift (``get_dx``):
    ``dn = x * dnT + (nT / 2) * dx``, using module-level ``x`` and ``nT``.
    """
    # Checked
    # dn = nB - nA
    dnT = get_dnT(ST, dT)  # evaluates get_drho internally
    dx = get_dx(ST, dT)
    return x * dnT + (nT / 2) * dx

def get_Helmholtz_ST(ST, dT):
    """Total Helmholtz energy of both bulbs for the split implied by ``ST``.

    The mole shift comes from the linearised ``get_dn`` (equal pressure
    implied) and is evaluated with ``get_Helmholtz_dn``.
    """
    return get_Helmholtz_dn(get_dn(ST, dT), dT)

def get_Helmholtz_dn(dn, dT):
    """Sum of the Helmholtz energies of bulbs A and B for a mole shift ``dn``.

    Bulb contents and temperatures come from ``get_n_x_and_T(dn, dT)``; both
    bulbs have volume ``V_bulb``.
    """
    state = get_n_x_and_T(dn, dT)
    n_a, n_b = state[0], state[1]
    t_a, t_b = state[4], state[5]
    a_a, = eos.helmholtz_tv(t_a, V_bulb, n_a)
    a_b, = eos.helmholtz_tv(t_b, V_bulb, n_b)
    return a_a + a_b

def get_dn2(dn1, dT):
    """Solve for the component-2 shift dn2 that equalises the bulb pressures.

    Given ``dn1`` (component-1 shift) and ``dT`` (bulb A at ``T - dT/2``,
    bulb B at ``T + dT/2``), find ``dn2`` such that p_A == p_B, with
    ``n_i,A = (n_i - dn_i) / 2`` and ``n_i,B = (n_i + dn_i) / 2``.

    If the root finder does not converge, a warning is printed and the last
    iterate is still returned.
    """
    n1A = (n[0] - dn1) / 2
    n1B = (n[0] + dn1) / 2
    TA = T - dT / 2
    TB = T + dT / 2

    def p_constraint(dn2):
        # root passes a length-1 array; use its only entry.
        n2A = (n[1] - dn2[0]) / 2
        n2B = (n[1] + dn2[0]) / 2
        pA, = eos.pressure_tv(TA, V_bulb, [n1A, n2A])
        pB, = eos.pressure_tv(TB, V_bulb, [n1B, n2B])
        return pA - pB

    res = root(p_constraint, x0=[0])
    # `not` also handles numpy.bool_ results, which `is False` would miss.
    if not res.success:
        print('get_dn2 did not converge for dn1 =', dn1)

    return res.x[0]

def get_dn2_diff(dn1, dT):
    """First-order estimate of dn2 that keeps the pressure unchanged.

    Solves dp/dn1 * dn1 + dp/dn2 * dn2 + dp/dT * dT = 0 for dn2, with the
    derivatives from ``eos.pressure_tv`` at the module-level (T, V, n).
    """
    _, dp_dT, dp_dn = eos.pressure_tv(T, V, n, dpdt=True, dpdn=True)
    numerator = dp_dn[0] * dn1 + dp_dT * dT
    return -numerator / dp_dn[1]

def test_get_dn2():
    """Compare ``get_dn2`` (root solve) with ``get_dn2_diff`` (linearised).

    Top panel: bulb pressure difference obtained with the solved dn2.
    Bottom panel: ``get_dn2 - get_dn2_diff`` vs dn1, one curve per dT.

    Result:
    dn2_diff is correct, but error at dT > 0 is substantial.
    Small errors in dn2 propagate to substantial errors in dp.
    When computing Delta n2, use get_dn2().
    When computing grad(n2), use get_dn2_diff().
    """
    dn1_list = np.linspace(-n[0], n[0])
    dn2_err = np.empty_like(dn1_list)
    dp_list = np.empty_like(dn1_list)
    dT_list = [0.01, 0.05, 0.1, 0.5, 1]
    T_norm = LogNorm(vmin=min(dT_list), vmax=max(dT_list))
    T_cmap = get_cmap('cool')
    _, axs = plt.subplots(2, 1, sharex='all')
    for dT in dT_list:
        for i, dn1 in enumerate(dn1_list):
            # Solve once per point: get_dn2 runs a root finder.
            dn2 = get_dn2(dn1, dT)
            dn2_err[i] = dn2 - get_dn2_diff(dn1, dT)
            dp_list[i] = get_dp(np.array([dn1, dn2]), dT)
        color = T_cmap(T_norm(dT))
        axs[0].plot(dn1_list, dp_list, color=color)
        axs[1].plot(dn1_list, dn2_err, color=color, label=dT)

    axs[0].set_ylabel(r'$\Delta p$ [Pa]')
    axs[1].set_ylabel(r'$\Delta n_2 - \Delta n_2^{\mathrm{lin}}$ [mol]')
    axs[1].set_xlabel(r'$\Delta n_1$ [mol]')
    axs[1].legend(title=r'$\Delta T$ [K]')
    plt.show()

def get_dnT(ST, dT):
    """Difference in total moles between the bulbs: ``V_bulb * get_drho(ST, dT)``."""
    # Exact, Checked
    density_change = get_drho(ST, dT)
    return density_change * V_bulb

def get_dp(dn, dT):
    """Pressure difference p_A - p_B for the split defined by ``dn`` and ``dT``.

    Bulb A: ``(n - dn) / 2`` at ``T - dT/2``; bulb B: ``(n + dn) / 2`` at
    ``T + dT/2``; both in ``V_bulb`` (module globals).
    """
    half_dT = 0.5 * dT
    p_a, = eos.pressure_tv(T - half_dT, V_bulb, (n - dn)/ 2)
    p_b, = eos.pressure_tv(T + half_dT, V_bulb, (n + dn)/ 2)
    return p_a - p_b

def get_dx(ST, dT):
    """Linearised mole-fraction difference between the bulbs.

    ``dx = -x (1 - x) S_T dT`` with the module-level composition ``x``.
    """
    # Checked
    binary_weight = x * (1 - x)
    return -binary_weight * ST * dT

def get_dp_2(drho, dT, ST):
    """Linearised pressure change from density, temperature and composition changes.

    ``dp = dp/drho * drho + dp/dT * dT + nT * sum(dp/dn * dx)``, with ``dx``
    from ``get_dx(ST, dT)`` and derivatives at the module-level (T, V, n).
    """
    dp_drho = ed.dpdrho_Tx(eos, T, V, n)
    composition_shift = get_dx(ST, dT)
    _, dp_dT, dp_dn = eos.pressure_tv(T, V, n, dpdt=True, dpdn=True)
    density_term = dp_drho * drho
    thermal_term = dp_dT * dT
    composition_term = nT * sum(dp_dn * composition_shift)
    return density_term + thermal_term + composition_term

def get_n_x_and_T(dn, dT):
    """Split the module-level mixture ``n`` between bulbs A and B.

    Returns
    -------
    (nA, nB, xA, xB, TA, TB)
        ``nA = (n - dn) / 2``, ``nB = (n + dn) / 2``, the corresponding mole
        fractions, and ``TA = T - dT / 2``, ``TB = T + dT / 2``.
    """
    nA = (n - dn) / 2
    nB = (n + dn) / 2
    xA = nA / sum(nA)
    xB = nB / sum(nB)
    TA = T - dT / 2
    TB = T + dT / 2
    return (nA, nB, xA, xB, TA, TB)

def get_num_dx(dn):
    """Mole-fraction difference x_B - x_A for the split ``n / 2 -/+ dn``."""
    n_a = (n / 2) - dn
    n_b = (n / 2) + dn
    return n_b / sum(n_b) - n_a / sum(n_a)

def test_dn():
    """Plot the relative deviation of the linearised ``get_dn`` from ``num.get_dn``.

    For component 1, ``(dn - num_dn) / num_dn`` is plotted vs T, one curve per
    dT. Mutates the globals ``T``, ``dT`` and ``p`` and the attribute ``num.dT``.
    """
    global T, dT, nT, rho, x, n, num, p
    T_list = np.linspace(250, 450)
    dn = np.empty_like(T_list)
    num_dn = np.empty_like(T_list)

    dT_list = [50, 25, 10, 5, 1, 0.5, 0.1]
    norm = LogNorm(vmin=min(dT_list), vmax=max(dT_list))
    cmap = get_cmap('cool')

    for dT in dT_list:
        num.dT = dT
        for i, T in enumerate(T_list):
            # get_dn -> get_drho reads the global p, so refresh it for this T.
            p, = eos.pressure_tv(T, nT / rho, n)
            ST = num.get_Soret(T, rho, x)
            dn[i] = get_dn(ST, num.dT)[0]
            num_dn[i] = num.get_dn(T, rho, x)[0]

        plt.plot(T_list, (dn - num_dn) / num_dn, color=cmap(norm(dT)), label=dT)
    plt.ylabel(r'$\frac{\Delta n_1 - \Delta_{num} n_1}{\Delta_{num} n_1}$ [-]')
    plt.xlabel(r'$T$ [K]')
    plt.legend(title=r'$\Delta T$ [K]')
    plt.show()

def kempers89_condition(ST, dT):
    """Evaluate the Kempers-89 equilibrium residual for a trial Soret coefficient.

    The bulb split comes from the linearised ``get_dn(ST, dT)``. With
    ``g = mu / T`` per bulb and ``v`` the ``dvdn`` output of
    ``eos.specific_volume`` at the mean bulb pressure, the residual for each
    component i except the last (N) is::

        g_A,i - g_B,i - (v_A,i + v_B,i) / (v_A,N + v_B,N) * (g_A,N - g_B,N)

    Returns
    -------
    (eq_set, dp)
        The residual array and the pressure difference ``p_B - p_A``.
    """
    dn = get_dn(ST, dT)
    n_a, n_b, x_a, x_b, t_a, t_b = get_n_x_and_T(dn, dT)
    mu_a, = eos.chemical_potential_tv(t_a, V_bulb, n_a)
    mu_b, = eos.chemical_potential_tv(t_b, V_bulb, n_b)

    p_a, = eos.pressure_tv(t_a, V_bulb, n_a)
    p_b, = eos.pressure_tv(t_b, V_bulb, n_b)
    # Use the mean pressure to minimise the effect of any residual numerical
    # pressure difference between the bulbs.
    p_mean = 0.5 * (p_a + p_b)

    pressure_gap = p_b - p_a

    _, v_a = eos.specific_volume(t_a, p_mean, x_a, 1, dvdn=True)
    _, v_b = eos.specific_volume(t_b, p_mean, x_b, 1, dvdn=True)

    g_a = mu_a / t_a
    g_b = mu_b / t_b
    weights = (v_a[:-1] + v_b[:-1]) / (v_a[-1] + v_b[-1])
    residual = g_a[:-1] - g_b[:-1] - weights * (g_a[-1] - g_b[-1])
    return residual, pressure_gap

def plot_dp():
    """Plot the bulb pressure difference from the linearised dn vs dT.

    S_T comes from ``num.get_Soret`` at the module-level state. Mutates the
    global ``dT`` (ends at the last sweep value) and ``num.dT``.
    """
    global num, T, dT, rho, x
    dT_values = np.linspace(1e-3, 10)
    pressure_gap = np.empty_like(dT_values)
    for k, dT in enumerate(dT_values):
        num.dT = dT
        soret = num.get_Soret(T, rho, x)
        pressure_gap[k] = get_dp(get_dn(soret, dT), dT)

    plt.plot(dT_values, pressure_gap)
    plt.ylabel(r'$\Delta p$ [Pa]')
    plt.xlabel(r'$\Delta T$ [K]')
    plt.show()

def test_kempers89_condition():
    """Plot the Kempers-89 residual against the trial Soret coefficient S_T,1.

    The trial S_T is 1.5 x (kinetic BH alpha_T0 / T) over a T sweep; one curve
    per dT. Switches both the global ``eos`` and ``num.eos`` via ``set_HS`` and
    leaves them that way; also mutates the globals ``T``, ``dT`` and ``num.dT``.
    """
    global T, dT, nT, rho, x, n, num, p, eos
    T_list = np.linspace(300, 500, 30)
    eq1 = np.empty_like(T_list)
    ST_list = np.empty_like(T_list)
    dp_list = np.empty_like(T_list)

    eos = set_HS(eos)
    num.eos = set_HS(num.eos)
    dT_list = [10, 5, 1, 0.5, 0.1]
    norm = LogNorm(vmin=min(dT_list), vmax=max(dT_list))
    cmap = get_cmap('cool')
    for dT in dT_list:
        num.dT = dT
        for i, T in enumerate(T_list):
            ST = kin.alpha_T0(T, 1 / rho, x, BH=True) / T
            # NOTE(review): hard-coded 1.5x scaling of the kinetic S_T --
            # presumably a deliberate perturbation; confirm.
            ST *= 1.5
            ST_list[i] = ST[0]
            eq_set, dp_list[i] = kempers89_condition(ST, dT)
            # Binary mixture: the residual has a single entry.
            eq1[i] = eq_set[0]

        plt.plot(ST_list, eq1, color=cmap(norm(dT)), label=dT)
    plt.ylabel('Eq. (1.8)')
    # The x data is ST_list (component 1), not T.
    plt.xlabel(r'$S_{T,1}$ [K$^{-1}$]')
    plt.legend(title=r'$\Delta T$ [K]')
    plt.title(r''+comps+r' $x_{'+comp1+'}$ = '+str(round(x[0], 2)))
    plt.show()

def compare_kin_to_num(constraint):
    """Plot S_T,1 (x1e3) vs T: kinetic (BH) model vs ``NumericSoret``.

    One pair of curves per dT (solid: kinetic, dashed: numeric, same colour).
    ``constraint`` is forwarded to ``num.get_Soret``. Switches the global
    ``eos`` via ``set_HS`` and mutates the globals ``T``, ``dT`` and ``num.dT``.
    """
    global T, dT, nT, rho, x, n, num, p, eos
    T_list = np.linspace(300, 500, 30)
    ST_num_list = np.empty_like(T_list)

    eos = set_HS(eos)
    # The kinetic S_T depends on (T, rho, x) only, not on dT: evaluate it once
    # instead of once per dT.
    ST_list = np.empty_like(T_list)
    for i, T_i in enumerate(T_list):
        ST_list[i] = kin.soret(T_i, 1 / rho, x, BH=True)[0]

    dT_list = [10, 5, 1, 0.5, 0.1]
    norm = LogNorm(vmin=min(dT_list), vmax=max(dT_list))
    cmap = get_cmap('cool')
    for dT in dT_list:
        num.dT = dT
        for i, T in enumerate(T_list):
            ST_num_list[i] = num.get_Soret(T, rho, x, constraint=constraint)[0]

        color = cmap(norm(dT))
        plt.plot(T_list, ST_list * 1e3, color=color, label=dT)
        plt.plot(T_list, ST_num_list * 1e3, color=color, linestyle='--')
    plt.ylabel(r'$S_{T,1}$ [mK$^{-1}$]')
    plt.xlabel(r'$T$ [K]')
    plt.legend(title=r'$\Delta T$ [K]')
    plt.title(r''+comps+r' $x_{'+comp1+'}$ = '+str(round(x[0], 2)))
    plt.show()

def compare_bulb_ratios(constraint='p'):
    """Plot the numeric S_T,1 (x1e3) vs dT for several bulb volume ratios.

    A fresh local ``NumericSoret`` is used (the module-level ``num`` is not
    touched); its eos and the global ``eos`` are switched via ``set_HS``.
    The kinetic S_T at the module-level (T, rho, x) is drawn as a red dashed
    reference line. ``constraint`` is forwarded to ``num.get_Soret``.
    """
    global T, dT, nT, rho, x, n, p, eos
    dT_list = np.linspace(0.01, 15, 30)
    ST_num_list = np.empty_like(dT_list)

    eos = set_HS(eos)
    num = NumericSoret(comps)
    num.eos = set_HS(num.eos)
    r_list = [10, 5, 1, 0.5, 0.1]
    norm = LogNorm(vmin=min(r_list), vmax=max(r_list))
    cmap = get_cmap('viridis')
    for r in r_list:
        # NOTE(review): confirm NumericSoret reads the `bulb_ratios` attribute.
        num.bulb_ratios = r
        for i, dT_i in enumerate(dT_list):
            num.dT = dT_i
            ST_num_list[i] = num.get_Soret(T, rho, x, constraint=constraint)[0]

        plt.plot(dT_list, ST_num_list * 1e3, color=cmap(norm(r)), label=r)

    # The kinetic S_T does not depend on dT: evaluate once.
    # NOTE(review): unlike other calls, BH=True is not passed here -- confirm.
    ST_kin = kin.soret(T, 1 / rho, x)[0]
    plt.plot(dT_list, np.full_like(dT_list, ST_kin) * 1e3, linestyle='--', color='r')

    plt.ylabel(r'$S_{T,1}$ [mK$^{-1}$]')
    plt.xlabel(r'$\Delta T$ [K]')
    plt.legend(title=r'$V_A / V_B$ [-]')
    plt.title(r''+comps+r' $x_{'+comp1+'}$ = '+str(round(x[0], 2)))
    plt.show()

def compare_89_kin():
    """Plot S_T,1 (x1e3) vs T: kinetic (BH) model vs Kempers-89 in HS mode.

    Calls ``kempers89.set_HS(True)`` and leaves it set; mutates the global ``T``.
    """
    global T, dT, nT, rho, x, n, num, p, eos
    kempers89.set_HS(True)
    T_list = np.linspace(300, 500, 30)
    ST_list = np.empty_like(T_list)
    ST_kemp_list = np.empty_like(T_list)

    for i, T in enumerate(T_list):
        ST_kemp_list[i] = kempers89.get_Soret(T, rho, x)[0]
        # kin.soret is used directly as S_T everywhere else in this file
        # (alpha_T0 is the quantity that needs dividing by T), so no / T here.
        ST_list[i] = kin.soret(T, 1 / rho, x, BH=True)[0]

    plt.plot(T_list, ST_list * 1e3, label='Kinetic')
    plt.plot(T_list, ST_kemp_list * 1e3, label='Kempers89 (HS)')
    plt.ylabel(r'$S_{T,1}$ [mK$^{-1}$]')
    plt.xlabel(r'$T$ [K]')
    plt.legend()
    plt.title(r''+comps+r' $x_{'+comp1+'}$ = '+str(round(x[0], 2)))
    plt.show()

def plot_helmholtz_3d(ax):
    """Draw a wireframe of the total bulb Helmholtz energy over a (dn1, dn2) grid.

    ``ax`` must be a 3-D axes (``projection='3d'``). Uses the module-level ``dT``.
    """
    dn1_grid, dn2_grid = np.meshgrid(np.linspace(-0.05, 0.055, 30),
                                     np.linspace(-0.01, 0.17, 30))
    A_grid = np.empty_like(dn1_grid)
    for idx in np.ndindex(A_grid.shape):
        dn = np.array([dn1_grid[idx], dn2_grid[idx]])
        A_grid[idx] = get_Helmholtz_dn(dn, dT)

    ax.plot_wireframe(dn1_grid, dn2_grid, A_grid)

def plot_helmholtz_3d_surf(ax):
    """Draw the total bulb Helmholtz energy over a (dn1, dn2) grid as a surface.

    ``ax`` must be a 3-D axes (``projection='3d'``). Uses the module-level
    ``dT``. Returns the surface artist.
    """
    dn1_grid, dn2_grid = np.meshgrid(np.linspace(-0.05, 0.055, 30),
                                     np.linspace(-0.01, 0.17, 30))
    A_grid = np.empty_like(dn1_grid)
    for idx in np.ndindex(A_grid.shape):
        dn = np.array([dn1_grid[idx], dn2_grid[idx]])
        A_grid[idx] = get_Helmholtz_dn(dn, dT)

    # Axes3D has no `plot_` method; plot_surface draws the surface.
    surf = ax.plot_surface(dn1_grid, dn2_grid, A_grid)
    return surf

def plot_helmholtz_3d_valid_pressure_curve(ax):
    """Overlay the Helmholtz energy along the equal-pressure curve on a 3-D axes.

    For each dn1 in [-0.05, 0], dn2 is solved with ``get_dn2`` so that the
    bulb pressures match; the curve (dn1, dn2, A) is drawn in red. Uses the
    module-level ``dT``.
    """
    dn1_list = np.linspace(-0.05, 0, 30)
    dn2_list = np.empty_like(dn1_list)
    A_list = np.empty_like(dn1_list)
    for i, dn1 in enumerate(dn1_list):
        dn2 = get_dn2(dn1, dT)
        dn2_list[i] = dn2
        A_list[i] = get_Helmholtz_dn(np.array([dn1, dn2]), dT)

    ax.plot(dn1_list, dn2_list, A_list, color='r', label=r'$\Delta p = 0$')

def scatter_helmholtz_num_kin_kemp(ax):
    """Scatter each model's predicted (dn1, dn2, A) point on a 3-D axes.

    The numeric model supplies dn directly (``constraint=None``) and prints
    mole-conservation, volume and pressure diagnostics. The other models
    supply S_T, which ``get_dn`` converts to dn. A is evaluated with
    ``get_Helmholtz_dn`` after ``set_HS(eos, False)``.
    """
    models = [num, kin, kempers89, kempers]
    model_names = ['Numeric', 'Kinetic', 'K-89', 'K-HS']
    for model, name in zip(models, model_names):
        set_HS(eos, False)
        if name == 'Numeric':
            dn = model.get_dn(T, rho, x, constraint=None)
            (nA, nB, xA, xB, TA, TB) = get_n_x_and_T(dn, dT)
            pA, = eos.pressure_tv(TA, V_bulb, nA)
            vA, = eos.specific_volume(TA, pA, xA, 1)
            pB, = eos.pressure_tv(TB, V_bulb, nB)
            # Pass mole fractions (as for bulb A), not mole numbers.
            vB, = eos.specific_volume(TB, pB, xB, 1)
            print('Mole conservation :', n, sum(nA), sum(nB))
            print('Numeric Delta V =', vB * sum(nB) - vA * sum(nA))
            print('Numeric Delta p =', pB - pA)
        else:
            ST = model.get_Soret(T, rho, x)
            dn = get_dn(ST, dT)
        # Re-assert non-HS mode before evaluating A (defensive).
        set_HS(eos, False)
        A = get_Helmholtz_dn(dn, dT)
        ax.scatter(dn[0], dn[1], A, label=name)

def compare_89_num_3d():
    """3-D figure of the Helmholtz landscape, the dp = 0 curve and model points.

    Sets the global ``T`` to 300 before drawing.
    """
    global T
    T = 300
    ax = plt.figure().add_subplot(projection='3d')
    for draw in (plot_helmholtz_3d,
                 plot_helmholtz_3d_valid_pressure_curve,
                 scatter_helmholtz_num_kin_kemp):
        draw(ax)
    ax.set_xlabel(r'$\Delta n_1$')
    ax.set_ylabel(r'$\Delta n_2$')
    plt.legend()
    plt.show()

def compare_89_num_2d():
    """2-D view of the equal-pressure curve and model predictions vs dn1.

    For each dT in [1, 0.5, 0.1] (line styles '-', '--', ':'), dn2 is solved
    with ``get_dn2`` for dn1 in [-4, 0]. The top-left panel shows p_A along
    that curve and the bottom-left panel the total Helmholtz energy; both are
    coloured by p_A. Each model's prediction (Numeric, Kinetic, K-89) is
    scattered on top. In the pressure panel, red/green mark p_A/p_B when they
    differ by more than FLTEPS, and blue marks points where they agree.
    Mutates the global ``dT`` and ``num.dT``.

    NOTE(review): the pressure colour normalisation is rebuilt for every dT,
    so colours are not comparable across dT and the colorbar reflects only
    the last dT.
    """
    global T, rho, dT
    dn1_list = np.linspace(-4, 0, 30)
    dn2_list = np.empty_like(dn1_list)
    A_list = np.empty_like(dn1_list)
    p_list = np.empty_like(dn1_list)
    dT_styles = ['-', '--', ':']
    dT_list = [1, 0.5, 0.1]

    # K-HS is not part of the overlay.
    models = [num, kin, kempers89]
    model_names = ['Numeric', 'Kinetic', 'K-89']
    markers = ['o', 'v', 'x']
    p_cmap = get_cmap('cividis')

    _, axs = plt.subplots(2, 2, figsize=(10, 6), sharex='col', gridspec_kw={'width_ratios': (1, 0.1)})
    # Only the bottom-right slot is used (as the colorbar axes).
    axs[0, 1].axis('off')
    for dT_idx, dT_val in enumerate(dT_list):
        dT = dT_val
        num.dT = dT
        for i, dn1 in enumerate(dn1_list):
            dn2 = get_dn2(dn1, dT)
            dn2_list[i] = dn2
            dn = np.array([dn1, dn2])
            nA, _, _, _, TA, _ = get_n_x_and_T(dn, dT)
            # On the solved dp = 0 curve p_A ~= p_B, so only p_A is recorded.
            pA, = eos.pressure_tv(TA, V_bulb, nA)
            p_list[i] = pA
            A_list[i] = get_Helmholtz_dn(dn, dT)

        num.set_HS(False)
        kempers89.set_HS(False)
        set_HS(eos, False)

        p_norm = Normalize(vmin=min(p_list), vmax=max(p_list))
        style = dT_styles[dT_idx]
        # Draw segment by segment so each piece is coloured by its pressure.
        for i in range(len(dn1_list) - 1):
            seg = slice(i, i + 2)
            color = p_cmap(p_norm(p_list[i]))
            axs[1, 0].plot(dn1_list[seg], A_list[seg], color=color, linestyle=style)
            axs[0, 0].plot(dn1_list[seg], p_list[seg], color=color, linestyle=style)
        for model, name, marker in zip(models, model_names, markers):
            set_HS(eos, False)
            ST = model.get_Soret(T, rho, x)
            dn = get_dn(ST, dT)
            set_HS(eos, False)
            A = get_Helmholtz_dn(dn, dT)
            # Leading '_' hides repeated legend entries after the first dT.
            axs[1, 0].scatter(dn[0], A, label=(dT_idx > 0)*'_'+name, marker=marker)
            nA, nB, _, _, TA, TB = get_n_x_and_T(dn, dT)
            pA, = eos.pressure_tv(TA, V_bulb, nA)
            pB, = eos.pressure_tv(TB, V_bulb, nB)
            if abs(pA - pB) > FLTEPS:
                axs[0, 0].scatter(dn[0], pA, marker=marker, color='r')
                axs[0, 0].scatter(dn[0], pB, marker=marker, color='g')
            else:
                axs[0, 0].scatter(dn[0], pA, marker=marker, color='b')

    plt.colorbar(ScalarMappable(norm=p_norm, cmap=p_cmap), cax=axs[1, 1], ax=axs[1, 0])
    axs[1, 0].legend()
    plt.show()

def plot_valid_dn():
    """Plot dn1 + dn2 along the equal-pressure curve at T = 300.

    dn2 is solved with ``get_dn2`` for dn1 in [-n1, n1] at the module-level
    ``dT``. ``get_dn2`` reads the global ``T``, so ``T`` is set globally
    (as in ``compare_89_num_3d``); a local assignment would have no effect.
    """
    global T
    T = 300
    dn1_list = np.linspace(-n[0], n[0], 30)
    dn2_list = np.empty_like(dn1_list)
    for i, dn1 in enumerate(dn1_list):
        dn2_list[i] = get_dn2(dn1, dT)

    plt.plot(dn1_list, dn2_list + dn1_list)
    plt.xlabel(r'$\Delta n_1$ [mol]')
    plt.ylabel(r'$\Delta n_1 + \Delta n_2$ [mol]')
    plt.show()

# Entry point: run one of the comparison/diagnostic plots defined above.
# NOTE: unreachable while the early exit near the top of the file remains.
compare_89_num_2d()

# Alternative experiments (uncomment one to run it instead):
#plot_dp()
#compare_kin_to_num('p')
#test_kempers89_condition()