import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib import colors as mpl_c

# Shared colormap used by every figure in this script. ``plt.get_cmap`` is
# used because ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and
# removed in 3.9, which made importing this script fail on current releases.
cmap = plt.get_cmap('inferno')

def plot_c_rel_av_DT():
    """Plot the steady-state relative concentration difference against -ΔT.

    One curve is drawn for each of ten Soret coefficients S_T, log-spaced
    from 1e-4 to 1e-2, using

        Δc/<c> = 2 (exp(-S_T ΔT) - 1) / (1 + exp(-S_T ΔT))

    for ΔT between -30 and -10 K. The x axis shows -ΔT and the y axis is
    logarithmic. Curves are coloured with the module-level ``cmap``. The
    figure is left as the current pyplot figure; it is neither saved nor
    shown here.
    """
    S_T_axis = np.logspace(-4, -2, 10)
    D_T_axis = np.linspace(-30, -10, 100)

    def DC(S_T, D_T):
        # Relative concentration difference Δc/<c> for the given S_T and ΔT.
        factor = np.exp(-S_T * D_T)
        return 2 * (factor - 1) / (1 + factor)

    plt.figure(figsize=(10, 5))
    ax = plt.subplot(111)

    # Colour by log(S_T). The smallest S_T maps to 0 and each decade above it
    # moves 0.25 along the colormap, so the largest S_T here maps to 0.5.
    log_ST_min = np.log(S_T_axis[0])
    plots = []
    for ST in S_T_axis:
        line, = plt.plot(-D_T_axis, DC(ST, D_T_axis),
                         color=cmap(-np.log(ST) / log_ST_min + 1),
                         label=str(round(np.log10(ST), 1)))
        plots.append(line)

    # Shrink the axes horizontally to leave room for the legend on the right.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.85, box.height])

    plt.yscale('log')
    # Braces must balance: mathtext rejects a stray '}' when rendering.
    plt.ylabel(r'$\frac{\Delta c}{\langle c \rangle}$', fontsize=16)
    plt.xlabel(r'$-\Delta T$ [K]')

    plt.legend(handles=plots, title=r'$\log{S_T}$', bbox_to_anchor=(1.025, 1), loc='upper left')
    plt.title('Steady state relative concentration gradient as a function of temperature gradient\n'
              'for different orders of magnitude of the Soret coefficient')

def plot_alt(style='log'):
    """Plot the steady-state concentration difference ΔC against S_T.

    A curve is drawn for every combination of four mean concentrations <c>
    (5 to 10 mol/L, one line style each) and seven temperature differences
    ΔT (-400 to -1 K, log-spaced, one colour each), using

        ΔC = 2 <c> (1 - exp(S_T ΔT)) / (1 + exp(S_T ΔT)).

    Two extra axes are hidden and only carry the line-style and colour
    legends.

    Parameters
    ----------
    style : str
        ``'lin'`` samples S_T linearly on [-0.05, 0.05] and shows
        0 <= S_T <= 0.05. Any other value samples S_T log-spaced on
        [1e-3, 1e-1] and uses a logarithmic x axis.

    The figure is left as the current pyplot figure; nothing is saved here.
    """
    color_scale = 1.2

    D_T_axis = -np.logspace(np.log10(400), np.log10(1), 7)
    C_bar_list = np.linspace(5, 10, 4)

    if style == 'lin':
        ST_list = np.linspace(-0.05, 0.05, 100)
    else:
        ST_list = np.logspace(-3, -1, 100)

    def DC_func(C_bar, ST, DT):
        # Steady-state concentration difference for mean concentration C_bar.
        growth = np.exp(ST * DT)
        return 2 * C_bar * (1 - growth) / (1 + growth)

    # Colour by log10(-ΔT). The largest |ΔT| lands at 1/color_scale of the
    # colormap, presumably to stay clear of its pale top end.
    log_DT_max = np.log10(-D_T_axis[0])

    def DT_color(DT):
        return cmap(np.log10(-DT) / (color_scale * log_DT_max))

    styles = ['-', '--', '-.', ':']

    _, ax = plt.subplots(1, 3, figsize=(10, 5))

    plt.sca(ax[0])
    for i, C_bar in enumerate(C_bar_list):
        for DT in D_T_axis:
            plt.plot(ST_list, DC_func(C_bar, ST_list, DT), color=DT_color(DT),
                     linestyle=styles[i])

    # Generating legend: dummy lines on ax[1] (line style per <c>) and ax[2]
    # (colour per ΔT). Both axes are hidden below; only their legends show.
    plt.sca(ax[1])
    for i, C_bar in enumerate(C_bar_list):
        plt.plot(C_bar_list, linestyle=styles[i], color='black', label=str(round(C_bar, 2)))
    plt.sca(ax[2])
    for DT in D_T_axis:
        # '.3g' rather than int(): int() truncates the log-spaced values
        # (2.71 K would show as '2', and a value landing just below an
        # integer would lose a unit).
        plt.plot(D_T_axis, color=DT_color(DT), label=f'{-DT:.3g}')

    # Placing legend: widen ax[0], then park the two zero-size legend axes at
    # its right edge.
    box0 = ax[0].get_position()
    box2 = ax[2].get_position()

    # NOTE(review): the new width is derived from box2.x0 (an absolute figure
    # coordinate) — confirm this is the intended right edge.
    ax[0].set_position([box0.x0, box0.y0, box2.x0 + 0.4 * box2.width, box0.height])
    box0 = ax[0].get_position()

    ax[1].set_position([box0.x0 + box0.width, box0.y0 + box0.height + 0.015, 0, 0])
    ax[1].set_axis_off()
    ax[2].set_position([box0.x0 + box0.width, box0.y0 + box0.height - 0.25, 0, 0])
    ax[2].set_axis_off()

    # Displaying plot
    plt.sca(ax[0])

    if style == 'lin':
        plt.xlim(0, 0.05)
        plt.ylim(0, 15)
    else:
        plt.xscale('log')
        plt.xlim(0.001, 0.1)
        plt.ylim(0, 5)
    plt.xlabel(r'$S_T$ [K$^{-1}$]')
    plt.ylabel(r'$\Delta C$ [mol/L]')
    ax[1].legend(title=r'$\langle c \rangle$ [mol/L]',
                 bbox_to_anchor=(1, 1), loc='upper left')
    ax[2].legend(title=r'$-\Delta T$ [K]', bbox_to_anchor=(1, 1), loc='upper left')
    #plt.suptitle('Steady state concentration gradient as a function of Soret-coefficient\n'
    #             'for different temperature gradients and total concentrations')

def plot_DC_rel_av_ST():
    """Plot the steady-state relative concentration difference against S_T.

    One curve is drawn per temperature difference ΔT (50 values log-spaced
    from -100 K to -10 K), using

        Δc/<c> = 2 (1 - exp(S_T ΔT)) / (1 + exp(S_T ΔT))

    for 0 <= S_T <= 0.05 K^-1. Inverting this relation gives
    S_T = ln((2 - Δc/<c>) / (2 + Δc/<c>)) / ΔT.

    A hand-drawn colour bar (stacked horizontal lines) in a second axes shows
    the ΔT colour scale. The figure is left as the current pyplot figure;
    nothing is saved here.
    """
    D_T_axis = -np.logspace(2, 1, 50)
    # ΔT levels used to paint the colour bar, spanning the plotted ΔT range.
    cbar_levels = np.linspace(D_T_axis.min(), D_T_axis.max(), 200)

    ST_list = np.linspace(0, 0.05, 100)

    def DC_rel_func(ST, DT):
        growth = np.exp(ST * DT)
        return 2 * (1 - growth) / (1 + growth)

    # Curves and colour bar share one mapping: log10(-ΔT) divided by 1.2x the
    # log10 of the largest |ΔT| (D_T_axis[0] == -100 == cbar_levels[0]).
    color_norm = 1.2 * np.log10(-D_T_axis[0])

    def DT_color(DT):
        return cmap(np.log10(-DT) / color_norm)

    _, axs = plt.subplots(1, 2, figsize=(10, 5))

    plt.sca(axs[0])
    for DT in D_T_axis:
        plt.plot(ST_list, DC_rel_func(ST_list, DT), color=DT_color(DT),
                 label=r'$\Delta T = $' + str(int(DT)) + ' K')

    # Generate colorbar: one horizontal line per ΔT level.
    plt.sca(axs[1])
    for level in cbar_levels:
        plt.hlines(level, 0, 1, color=DT_color(level))
    plt.ylim(cbar_levels.min(), cbar_levels.max())
    plt.xlim(0, 1)
    axs[1].get_xaxis().set_visible(False)
    axs[1].yaxis.tick_right()

    # Place colorbar: widen the main axes and put a narrow strip to its right.
    box0 = axs[0].get_position()
    axs[0].set_position([box0.x0, box0.y0, 1.8 * box0.width, box0.height])
    axs[1].set_position([box0.x0 + 1.9 * box0.width, box0.y0, 0.05, box0.height])
    axs[1].set_title(r'$\Delta T$ [K]')

    plt.sca(axs[0])
    plt.xlim(0, 0.05)
    plt.ylim(0, 2)
    #plt.xscale('log')
    plt.ylabel(r'$\frac{\Delta c}{\langle c \rangle}$ [-]', fontsize=16)
    plt.xlabel(r'$S_T$ [K$^{-1}$]')
    #plt.suptitle('Steady state relative concentration gradient versus Soret coefficient\n'
    #             'for different temperature gradients')

def flux_krefter():
    """Plot the scaled mass flux J_1/D_12 against the temperature gradient.

    The plotted quantity is dc/dx + (c_1 S_T) dT/dx, with:

    - three concentration gradients (0, 2.5 and 5 mol m^-4), one colour each;
    - five values of the product c_1 S_T in [-0.05, 0.05], one line style each;
    - dT/dx running from 0 to 20 K m^-1.

    Two extra axes are hidden and only carry the line-style and colour
    legends. The figure is left as the current pyplot figure; nothing is
    saved here.
    """
    dcdx_ax = np.linspace(0, 5, 3)
    # Values of the product c_1*S_T (labelled as such in the legend).
    ST_ax = np.linspace(-0.05, 0.05, 5)

    dTdx_ax = np.linspace(0, 20, 100)

    # NOTE(review): Fick + Soret flux is commonly written
    # J_1 = -D_12 (dc/dx + c_1 S_T dT/dx); no overall minus sign is applied
    # here — confirm the intended sign convention.
    def JD12(dc, ST, dT):
        return dc + ST * dT

    styles = ['-.', '--', '-', ':', (0, (3, 2, 1, 2, 1, 2))]

    # Colour by concentration gradient; the largest gradient maps to 1/1.2
    # of the colormap.
    def dcdx_color(dcdx):
        return cmap(dcdx / (1.2 * dcdx_ax[-1]))

    _, ax = plt.subplots(1, 3, figsize=(10, 5))

    plt.sca(ax[0])
    for i, ST in enumerate(ST_ax):
        for dcdx in dcdx_ax:
            plt.plot(dTdx_ax, JD12(dcdx, ST, dTdx_ax),
                     color=dcdx_color(dcdx),
                     linestyle=styles[i])

    # Generating legend: dummy lines on ax[1] (line style per c_1 S_T) and
    # ax[2] (colour per dc/dx). Both axes are hidden below.
    plt.sca(ax[1])
    for i, ST in enumerate(ST_ax):
        plt.plot(ST, linestyle=styles[i], color='black', label=str(round(ST, 2)))
    plt.sca(ax[2])
    for dcdx in dcdx_ax:
        plt.plot(dcdx_ax, color=dcdx_color(dcdx), label=str(round(dcdx, 1)))

    # Placing legend: widen ax[0], then park the two zero-size legend axes at
    # its right edge.
    box0 = ax[0].get_position()
    box2 = ax[2].get_position()

    ax[0].set_position([box0.x0, box0.y0, box2.x0 + 0.4 * box2.width, box0.height])
    box0 = ax[0].get_position()

    ax[1].set_position([box0.x0 + box0.width, box0.y0 + box0.height + 0.015, 0, 0])
    ax[1].set_axis_off()
    ax[2].set_position([box0.x0 + box0.width, box0.y0 + box0.height - 0.3, 0, 0])
    ax[2].set_axis_off()

    ax[1].legend(title=r'$c_1S_T$',
                 bbox_to_anchor=(1, 1), loc='upper left')
    ax[2].legend(title=r'$\frac{dc}{dx}$ [mol m$^{-4}$]', bbox_to_anchor=(1, 1), loc='upper left')

    plt.sca(ax[0])
    plt.ylabel(r'$\frac{J_1}{D_{12}}$ [mol m$^{-4}$]', fontsize=12)
    # 'K m^-1' (kelvin per metre); the former 'Km^-1' read as per-kilometre.
    plt.xlabel(r'$\frac{dT}{dx}$ [K m$^{-1}$]')
    plt.suptitle('Mass flux as a function of temperature gradient for different concentration gradients\n'
                 'and concentration-Soret coefficient products')

# Select which figures to regenerate by (un)commenting the calls below.
# flux_krefter()
# plt.savefig('fig1')
plot_DC_rel_av_ST()
plt.savefig('fig2_lin')  # linear S_T axis
plt.show()
# plot_c_rel_av_DT()
# plt.savefig('fig3')
fig4_style = 'log'
plot_alt(style=fig4_style)
# Name the file after the axis style actually plotted. Previously the
# log-style figure was written to 'fig4_lin'.
plt.savefig(f'fig4_{fig4_style}')
plt.show()