from plot import plot
import matplotlib.pyplot as plt
import ipywidgets.widgets as wid
from IPython.display import display

# x-axis label for each parameter that can be swept; insertion order is also
# the panel order (row-major) used for the 2x2 'All' layout.
_PANELS = {
    'HR': r'Heating rate [K$s^{-1}$]',
    'CR': r'Cooling rate [K$s^{-1}$]',
    'T_iso': 'Isothermal temperature [°C]',
    't_iso': 'Time at isotherm [s]',
}
_YLABEL = r'Average grain size [$\mu$m]'
_LABEL_FONTSIZE = 16
_TICK_LABELSIZE = 15


def _plot_panel(ax, swept, **values):
    """Draw one grain-size curve on *ax*, sweeping *swept* and fixing the rest.

    *values* holds the current CR/HR/T_iso/t_iso; the swept one is replaced
    by ``'var'`` before calling ``plot`` (presumably ``plot`` varies whichever
    argument is ``'var'`` — confirm against the ``plot`` module).
    """
    plt.sca(ax)
    plot(**{**values, swept: 'var'})
    plt.xlabel(_PANELS[swept], fontsize=_LABEL_FONTSIZE)
    plt.ylabel(_YLABEL, fontsize=_LABEL_FONTSIZE)
    # Applied after plotting so both layouts style ticks identically.
    plt.tick_params(labelsize=_TICK_LABELSIZE)


def main(CR=1, HR=1, T_iso=800, t_iso=100, show='All'):
    """Plot average grain size against one (or every) processing parameter.

    Parameters
    ----------
    CR : float
        Cooling rate [K/s]; held fixed unless it is the swept parameter.
    HR : float
        Heating rate [K/s]; held fixed unless it is the swept parameter.
    T_iso : int or float
        Isothermal temperature [°C]; held fixed unless swept.
    t_iso : int or float
        Time at isotherm [s]; held fixed unless swept.
    show : str
        ``'All'`` for a 2x2 grid with one sweep per parameter, or one of
        ``'HR'``, ``'CR'``, ``'T_iso'``, ``'t_iso'`` for a single panel.

    Raises
    ------
    ValueError
        If *show* is not one of the recognised options.
    """
    # Validate before creating any figure so a bad value never yields an
    # empty plot.
    if show != 'All' and show not in _PANELS:
        raise ValueError(
            f"show must be 'All' or one of {list(_PANELS)}, got {show!r}"
        )

    if show == 'All':
        _, axes = plt.subplots(2, 2, figsize=[25, 17])
        # axes.flat walks [0,0], [0,1], [1,0], [1,1] -> HR, CR, T_iso, t_iso.
        panels = zip(_PANELS, axes.flat)
    else:
        _, ax = plt.subplots(1, 1, figsize=[25, 10])
        panels = [(show, ax)]

    for swept, ax in panels:
        _plot_panel(ax, swept, CR=CR, HR=HR, T_iso=T_iso, t_iso=t_iso)

    plt.show()

def run():
    """Build and display the interactive grain-size explorer.

    A 'Display' dropdown picks which parameter sweep ``main`` draws (initially
    the first option), and one slider per processing parameter (CR, HR,
    T_iso, t_iso) supplies the fixed values.  Sliders use
    ``continuous_update=False`` so the figure is only redrawn when a slider
    is released.
    """
    display_options = [
        ('Heating rate', 'HR'),
        ('Cooling rate', 'CR'),
        ('Isothermal temperature', 'T_iso'),
        ('Time at isotherm', 't_iso'),
        ('All', 'All'),
    ]
    selector = wid.Dropdown(options=display_options, description='Display')

    # Cooling and heating rate sliders share the same range and resolution.
    rate_settings = dict(min=0.2, max=2.5, step=0.1, continuous_update=False)
    cooling_rate = wid.FloatSlider(**rate_settings)
    heating_rate = wid.FloatSlider(**rate_settings)
    iso_temperature = wid.IntSlider(min=900, max=1400, step=5,
                                    continuous_update=False)
    iso_time = wid.IntSlider(min=1, max=500, step=1, continuous_update=False)

    slider_row = wid.HBox([cooling_rate, heating_rate, iso_temperature, iso_time])
    controls = wid.VBox([selector, slider_row])

    explorer = wid.interactive(
        main,
        CR=cooling_rate,
        HR=heating_rate,
        T_iso=iso_temperature,
        t_iso=iso_time,
        show=selector,
    )
    # The last child of an ``interactive`` is its output area; only that is
    # shown, beneath the custom control layout built above.
    figure_area = explorer.children[-1]

    display(wid.VBox([controls, figure_area]))