import matplotlib.pyplot as plt
from differensial import H_func, bb_H_func
from numerikk import Z_solver_dict, step_func_dict
import ipywidgets as wid
from IPython.display import display
import numpy as np

def _step_func_settings(H0, Hv, dt, t_end, step_func):
    """Return ``(Hv, plot_freq)`` adjusted for the chosen step-function key.

    For the implicit-Euler/bisection variant, ``Hv`` is made safe for the
    bisection root finder: values above ``H0`` are clamped to ``H0 - 1``, and
    an ``Hv`` of 0 (requested, or produced by that clamp) is raised to 1.
    ``plot_freq`` is only computed for the implicit Euler variants; any other
    step function gets ``None``.
    """
    if 'Implicit Euler, Bisection' in step_func:
        # Clamp before the zero check so that H0 == 1 (where H0 - 1 == 0)
        # cannot slip a zero Hv through to the bisection solver.
        if Hv > H0:
            print('Implicit Euler with Bisection root finder does not function if Hv > H0.\n'
                  'So im putting Hv = H0 - 1\n'
                  'Use Implicit Euler with Newton if you want Hv > H0.')
            Hv = H0 - 1
        if Hv == 0:
            print('The bisection root finder has a hard time in implicit euler if Hv = 0\n'
                  'so im putting Hv = 1 to avoid negative numbers in log-function when finding Z.\n'
                  'Use Implicit Euler with Newton if you want Hv = 0.')
            Hv = 1
        return Hv, (t_end // np.sqrt(dt)) * 5
    if step_func == 'Implicit Euler, Newton':
        return Hv, t_end // np.power(dt, 0.8)
    return Hv, None


def _copy_lines(src_ax, dst_ax):
    """Re-plot every line of ``src_ax`` onto ``dst_ax``, keeping colour, marker and label."""
    for line in src_ax.get_lines():
        x, y = line.get_data()
        dst_ax.plot(x, y, color=line.get_color(), marker=line.get_marker(),
                    label=line.get_label())


def main(H0 = 20, Hv = 1, dt = 1, t_end = 1000, Z_solver = 'Newton', step_func = 'RK4'):
    """Solve for H(t) with the chosen solvers and plot the result.

    Draws two panels side by side:
      * left: H over time, with ``Hv`` as a dashed horizontal reference line;
      * right: the f(Z) curves/markers returned by ``H_func`` (copied from the
        axes it returns), plus a black zero line up to ``max_z``.

    Args:
        H0: initial H value passed to ``H_func``.
        Hv: H_v value passed to ``H_func`` and drawn as the reference line.
            May be adjusted (with a printed notice) for the bisection solver.
        dt: time step passed to ``H_func``.
        t_end: end time passed to ``H_func``.
        Z_solver: key into ``Z_solver_dict`` selecting the root finder for Z.
        step_func: key into ``step_func_dict`` selecting the ODE step function.
    """
    Z_solver = Z_solver_dict[Z_solver]
    # plot_freq is forwarded to H_func; presumably it controls how often
    # intermediate Z solutions are recorded on z_ax -- confirm in H_func.
    Hv, plot_freq = _step_func_settings(H0, Hv, dt, t_end, step_func)
    step_func = step_func_dict[step_func]
    H_vals, t_vals, max_z, z_ax = H_func(H0=H0, Hv=Hv, dt=dt, t_end=t_end,
                                         Z_solver=Z_solver, step_func=step_func,
                                         plot_freq=plot_freq)

    fig, axs = plt.subplots(1, 2, figsize=[20, 7])

    _copy_lines(z_ax, axs[1])
    # z_ax only serves as a container for the lines copied above; close its
    # figure so plt.show() does not display it a second time and figures do
    # not pile up across interactive redraws.
    plt.close(z_ax.figure)

    axs[1].plot([0, max_z], [0, 0], color='Black')
    axs[1].set_xlabel('Z [–]')
    axs[1].set_ylabel(r'$f(Z) = Z-\ln{(1-Z)} + K(H)$')
    axs[1].set_title(r'Markers display numerical solution of $f(Z) = 0$' +
                     '\nTolerance is dynamic, and depends on H and the chosen solver')
    axs[1].legend()

    axs[0].plot(t_vals, H_vals, label='[%H]')
    axs[0].plot(t_vals, [Hv] * len(t_vals), linestyle='--', color='orange', label=r'$H_v$')
    axs[0].grid(linestyle='-')
    axs[0].legend()
    axs[0].set_xlabel('t [s]')
    axs[0].set_ylabel('%H [p.p.m]')

    plt.show()

def run():
    """Build and display the interactive control panel for ``main``.

    Creates integer sliders for H0, Hv and dt, plus dropdowns whose options
    are the keys of ``Z_solver_dict`` and ``step_func_dict``. These are wired
    into ``main`` through ``ipywidgets.interactive``. The sliders use
    ``continuous_update=False``, so values are only sent when the slider is
    released. The controls are laid out manually above the output area of
    the interactive widget.
    """
    slider_opts = dict(step=1, continuous_update=False)
    sliders = {
        'H0': wid.IntSlider(min=1, max=20, **slider_opts),
        'Hv': wid.IntSlider(min=0, max=10, **slider_opts),
        'dt': wid.IntSlider(min=1, max=100, **slider_opts),
    }
    solver_menu = wid.Dropdown(options=Z_solver_dict.keys(), description='Solver for Z')
    ode_menu = wid.Dropdown(options=step_func_dict.keys(), description='ODE-solver')

    top_row = wid.HBox([solver_menu, ode_menu])
    bottom_row = wid.HBox(list(sliders.values()))
    controls = wid.VBox([top_row, bottom_row])

    panel = wid.interactive(main, **sliders, Z_solver=solver_menu, step_func=ode_menu)
    # The last child of an interactive widget is its output area; show it
    # under our custom control layout instead of the default stacked layout.
    plot_area = panel.children[-1]

    display(wid.VBox([controls, plot_area]))
