import ipywidgets as wid
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
import needle_drop
import mc_erf
import ran_test

import time

# Pool of 20000 samples drawn once at import time and reused by every
# redraw of the 'Test np.random.random()' case, so moving the slider does
# not trigger a fresh draw.
ran_vals = np.random.random(size=20000)
# Positions of the samples that fall below / at-or-above 0.5.
below = np.flatnonzero(ran_vals < 0.5)
above = np.flatnonzero(ran_vals >= 0.5)

def main(case=None, needle_finder=None, end_N_needle=1000, max_N_erf=10000,
         erf_b=1, test_random_N=1000):
    """Draw the figure for the selected demo and return the selection.

    Parameters
    ----------
    case : str or None
        Which demo to draw: 'Needle drop', 'Error function' or
        'Test np.random.random()'. Any other value draws nothing.
    needle_finder : hashable
        Key looked up in ``needle_drop.needle_finder_dict``; only read for
        the needle-drop case.
    end_N_needle : int
        Forwarded to the ``needle_drop`` plotting helpers.
    max_N_erf : int
        Forwarded to ``mc_erf.plot`` as ``N``.
    erf_b : float
        Forwarded to ``mc_erf.plot`` as ``b``.
    test_random_N : int
        Forwarded to the ``ran_test`` plotting helpers as their first
        argument.

    Returns
    -------
    The ``case`` argument, unchanged.
    """
    if case == 'Needle drop':
        # Two side-by-side panels: convergence on the left, integrand on the right.
        _, (convergence_ax, integrand_ax) = plt.subplots(1, 2, figsize=[15, 7])
        finder = needle_drop.needle_finder_dict[needle_finder]

        plt.sca(convergence_ax)
        needle_drop.plot_convergence(100, end_N_needle, finder)

        plt.sca(integrand_ax)
        needle_drop.plot_integrand(end_N_needle, finder)
        return case

    if case == 'Error function':
        plt.figure(figsize=[15, 7.5])
        mc_erf.plot(N=max_N_erf, b=erf_b)
        return case

    if case == 'Test np.random.random()':
        _, ((avg_ax, portion_ax), (points_ax, hist_ax)) = plt.subplots(
            2, 2, figsize=[15, 7])

        started = time.process_time()
        # (target axes, plotting helper, extra positional arguments), listed in
        # drawing order; every helper draws on the current pyplot axes.
        panels = (
            (avg_ax, ran_test.plot_avg, ()),
            (portion_ax, ran_test.plot_portion, ()),
            (hist_ax, ran_test.plot_hist, ()),
            (points_ax, ran_test.plot_points, (below, above)),
        )
        for axis, draw, extra in panels:
            plt.sca(axis)
            draw(test_random_N, ran_vals, *extra)
        # Report the CPU time spent drawing the four panels.
        print(time.process_time() - started)
        return case

    return case

def run():
    """Build and display the interactive demo selector.

    Shows a 'Case' dropdown followed by an output area. The output area
    holds the controls that belong to the selected case, stacked above the
    output widget into which ``main`` renders its figure. The default
    selection is rendered immediately; afterwards the area is rebuilt every
    time the dropdown value changes.

    Returns
    -------
    None
    """
    case = wid.Dropdown(options=['Needle drop', 'Error function', 'Test np.random.random()'],
                        description='Case')
    out = wid.Output()

    display(case)

    # A concrete list rather than a dict view is passed as the options.
    needle_finder = wid.Dropdown(options=list(needle_drop.needle_finder_dict), description='Method')
    end_N_needle = wid.IntSlider(min=1000, max=100000, step=1000, continuous_update=False,
                                 description='Max N')
    max_N_erf = wid.IntSlider(min=10000, max=100000, step=1000, continuous_update=False,
                              description='Max N')
    test_random_N = wid.IntSlider(min=1000, max=20000, step=500, continuous_update=True,
                                  description='Max N')
    erf_b = wid.FloatSlider(min=0.1, max=2.5, step=0.1, continuous_update=False,
                            description='erf() parameter')

    output = wid.interactive(main, case=case,
                             needle_finder=needle_finder, end_N_needle=end_N_needle,
                             max_N_erf=max_N_erf, erf_b=erf_b,
                             test_random_N=test_random_N)

    # The last child of the interactive container is shown on its own so that
    # only the controls relevant to the current case appear next to it.
    plot_area = output.children[-1]

    # Controls to show for each dropdown label.
    controls_by_case = {
        'Needle drop': wid.HBox([needle_finder, end_N_needle]),
        'Error function': wid.HBox([max_N_erf, erf_b]),
        'Test np.random.random()': test_random_N,
    }

    def display_case(change):
        """Swap in the controls for the newly selected case."""
        case_controls = controls_by_case.get(change['new'])
        if case_controls is None:
            # Unknown label: leave the current display untouched instead of
            # failing on an unbound variable.
            return

        out.clear_output()
        with out:
            display(wid.VBox([case_controls, plot_area]))

    case.observe(display_case, names='value')

    display(out)

    # The observer only fires on a *change*, so without this the default
    # selection would stay blank until the user picked another case and came
    # back. Run main() once with the current widget values, then show the
    # matching controls.
    output.update()
    display_case({'new': case.value})





