import constants as c
import numpy as np
from numerikk import Z_newton, K_func, rk4_step_func
import matplotlib.pyplot as plt
from scipy.integrate import ode

# dH/dt = C1*(H*Z)^2 + C2*H + C3, where Z = solver(H)
class dH_dt_func():
    """Callable right-hand side dH/dt = C1*(H*Z)^2 + C2*H + C3.

    ``Z`` is obtained from ``solver(H)`` on every call. The coefficients
    C1, C2 and C3 are computed once in ``__init__`` from values in the
    ``constants`` module and from ``Hv``.

    As a diagnostic, every ``plot_freq``-th plotting call draws the curve
    ``z + ln(1 - z) + K_func(H)`` in a small window around ``Z`` on
    ``self.ax`` and marks the point at ``z = Z``.
    NOTE(review): presumably ``solver`` finds the root of that curve --
    confirm against ``numerikk``.
    """

    def __init__(self, Hv=1, solver=Z_newton, plot_freq=10):
        """Set up the coefficients, the Z solver and the diagnostic figure.

        Args:
            Hv: Value entering the constant term as ``C3 = -C2 * Hv``.
            solver: Callable ``solver(H) -> Z``.
            plot_freq: A curve is drawn on every ``plot_freq``-th call made
                with plotting enabled. The counter starts at 1, so a value
                of 0 never triggers a plot.
        """
        C0 = -(100 * c.M_H / c.M)
        self.C1 = C0 * ((2 * np.power(c.f_H * c.K_H, 2) * c.G_m) / c.p_in)
        self.C2 = C0 * ((c.k_ts * c.A_s * c.rho) / (100 * c.M_H))
        self.C3 = -self.C2 * Hv

        self.solver = solver
        self.call_counter = 0   # plotting calls since the last drawn curve
        self.plot_freq = plot_freq
        self.max_z = 0          # largest right-hand edge of any plotted z-window

        self.fig, self.ax = plt.subplots()

    def _plot_z_curve(self, H, Z):
        """Draw z + ln(1 - z) + K around Z on self.ax and update self.max_z."""
        K = K_func(H)
        z_line = np.linspace(0.99*Z - 0.00025, 1.01*Z + 0.00025, 30)
        lines = self.ax.plot(z_line, z_line + np.log(1 - z_line) + K,
                             label='H = ' + str(round(H, 3)))
        # Mark the solver's Z on the curve, in the same colour as the curve.
        self.ax.plot(Z, Z + np.log(1 - Z) + K, marker='o',
                     color=lines[-1].get_color())

        if z_line[-1] > self.max_z:
            self.max_z = z_line[-1]

    def __call__(self, H, plot_z=True):
        """Return dH/dt at ``H``.

        Args:
            H: Current value of H.
            plot_z: When truthy, this call counts towards ``plot_freq`` and
                may draw a diagnostic curve.

        Returns:
            ``C1 * (H*Z)**2 + C2 * H + C3`` with ``Z = self.solver(H)``.
        """
        Z = self.solver(H)

        if plot_z:
            self.call_counter += 1
            if self.call_counter == self.plot_freq:
                self.call_counter = 0
                self._plot_z_curve(H, Z)

        # Use the coefficients computed in __init__; the original read
        # c.C1 / c.C2 from the constants module, leaving self.C1 / self.C2
        # unused and inconsistent with self.C3.
        return (self.C1 * np.power(H*Z, 2)) + (self.C2 * H) + self.C3

def H_func(H0 = 20, Hv = 1, dt = 1, t_end = 100, Z_solver = Z_newton, step_func = rk4_step_func, plot_freq=None):
    """Integrate dH/dt with a fixed step ``dt`` from t = 0 up to (not including) ``t_end``.

    Args:
        H0: Initial value, stored as the first entry of the result.
        Hv: Forwarded to ``dH_dt_func``.
        dt: Step size of the time grid ``np.arange(0, t_end, step=dt)``.
        t_end: Exclusive upper bound of the time grid.
        Z_solver: Forwarded to ``dH_dt_func`` as its solver and to
            ``step_func`` as the ``solver`` keyword.
        step_func: Called as ``step_func(H, dt, dH_dt, solver=Z_solver)``;
            its return value is stored as the next H.
        plot_freq: Plot frequency forwarded to ``dH_dt_func``. ``None``
            selects ``len(t_list) // 3``.

    Returns:
        Tuple ``(H_values, t_list, max_z, ax)``: the H array, the time grid,
        and ``max_z`` / ``ax`` taken from the ``dH_dt_func`` instance. The
        axes have been removed from their figure and the figure closed
        before returning.
    """
    t_list = np.arange(0, t_end, step=dt)
    H_values = np.zeros(len(t_list))
    H_values[0] = H0

    # Identity check instead of "== None"; one constructor call for both cases.
    if plot_freq is None:
        plot_freq = len(t_list) // 3
    dH_dt = dH_dt_func(Hv=Hv, solver=Z_solver, plot_freq=plot_freq)

    # Each value depends on the previous one, so the loop stays index-based.
    for i in range(len(H_values) - 1):
        H_values[i+1] = step_func(H_values[i], dt, dH_dt, solver=Z_solver)

    dH_dt.ax.remove()  # to keep matplotlib happy
    plt.close(dH_dt.fig)
    return H_values, t_list, dH_dt.max_z, dH_dt.ax

def bb_H_func(H0 = 20, Hv = 1, dt = 1, t_end = 100, Z_solver = Z_newton, integrator = c.scipy_integrator_method):
    """Integrate dH/dt with ``scipy.integrate.ode`` ("black box" variant of H_func).

    Args:
        H0: Initial value at t = 0.
        Hv: Forwarded to ``dH_dt_func``.
        dt: Spacing of the output time grid ``np.arange(0, t_end, step=dt)``.
        t_end: Exclusive upper bound of the time grid.
        Z_solver: Forwarded to ``dH_dt_func`` as its solver.
        integrator: Integrator name passed to ``ode.set_integrator``.

    Returns:
        Tuple ``(H_values, t_list, max_z, ax)``: the H array, the time grid,
        and ``max_z`` / ``ax`` taken from the ``dH_dt_func`` instance. The
        axes have been removed from their figure and the figure closed
        before returning.
    """
    # Forward Z_solver; the original accepted it but silently ignored it.
    dH_dt = dH_dt_func(Hv=Hv, solver=Z_solver)

    def bb_dH_dt(t, y):
        # scipy's ode calls the right-hand side as f(t, y), with y a length-1
        # array here. The model does not depend on t. H is passed on as a
        # plain float so that dH_dt's scalar code path (round, plot label)
        # keeps working.
        return [dH_dt(float(y[0]))]

    ode_solver = ode(bb_dH_dt).set_integrator(integrator)
    ode_solver.set_initial_value(H0)  # start time defaults to t = 0

    t_list = np.arange(0, t_end, step=dt)
    H_values = np.zeros(len(t_list))

    # The first grid point is the initial condition itself; no integration
    # step is needed to reach it.
    if len(t_list) > 0:
        H_values[0] = H0
    for i in range(1, len(t_list)):
        # integrate() returns a length-1 array; store its single element.
        H_values[i] = ode_solver.integrate(t_list[i])[0]

    dH_dt.ax.remove()  # to keep matplotlib happy
    plt.close(dH_dt.fig)
    return H_values, t_list, dH_dt.max_z, dH_dt.ax

