import numpy as np
import constants as c
from scipy.optimize import fsolve

def K_func(H):
    """Return the dimensionless parameter K = c.K / H used by the Z solvers."""
    return c.K/H

# Newton iteration: z_{n+1} = z_n + step
def Z_newton(H, max_iter=100):
    """Solve z + ln(1 - z) + K = 0 for z with Newton's method, K = K_func(H).

    Starting from z = 0.5, the iteration is z_{n+1} = z_n + step with
    step = (1 - z) * f(z) / z, which equals -f(z) / f'(z) for
    f(z) = z + ln(1 - z) + K and f'(z) = -z / (1 - z).

    Parameters
    ----------
    H : float
        Argument forwarded to K_func.
    max_iter : int, optional
        Maximum number of Newton updates before giving up.

    Returns
    -------
    float
        The root z, accepted once |step| <= K.

    Raises
    ------
    RuntimeError
        If |step| is still larger than K after max_iter updates.
    """
    K = K_func(H)

    # The stopping threshold on |step| is K itself.
    tolerance = K

    def newton_step(z):
        return ((1 - z) * (z + np.log(1-z) + K)) / z

    z = 0.5
    step = newton_step(z)
    iterations = 0

    while abs(step) > tolerance:
        if iterations >= max_iter:
            raise RuntimeError('Z_newton did not converge after '
                               + str(max_iter) + ' iterations (z = '
                               + str(z) + ', step = ' + str(step) + ')')
        # ln(1 - z) is undefined for z >= 1.  Starting left of the root, a
        # full Newton step can overshoot past 1 (e.g. for large K), which
        # would turn every later iterate into NaN.  Damp the step until the
        # new iterate stays inside the domain.
        while z + step >= 1:
            step /= 2
        z += step
        step = newton_step(z)
        iterations += 1

    return z

def Z_bisection(H):
    """Solve z + ln(1 - z) + K = 0 for z by bisection, with K = K_func(H).

    The search interval is [0, 0.999999999].  It is halved until its
    half-width drops to K or below, and the final midpoint is returned.
    A midpoint that is an exact root is returned immediately.

    Raises
    ------
    Exception
        If neither half of the current interval shows a sign change.
    """
    K = K_func(H)

    def f(z):
        return z + np.log(1-z) + K

    lo, hi = 0, 0.999999999
    mid = (lo + hi) /2
    half_width = (hi - lo) / 2
    tolerance = K

    f_lo = f(lo)
    f_hi = f(hi)
    f_mid = f(mid)

    while half_width > tolerance:
        if f_lo * f_mid < 0:
            # The root lies in the left half: pull the right edge in.
            hi, f_hi = mid, f_mid
        elif f_mid * f_hi < 0:
            # The root lies in the right half: push the left edge up.
            lo, f_lo = mid, f_mid
        elif f_mid == 0:
            return mid
        else:
            raise Exception('ledge : ' + str(lo) + ' ' + str(f_lo)
                            + '\nredge = ' + str(hi) + ' ' + str(f_hi)
                            + '\ncenter = ' + str(mid) + ' ' + str(f_mid))

        mid = (lo + hi) / 2
        f_mid = f(mid)
        half_width = half_width/2

    return mid

def Z_fsolve(H):
    """Solve z + ln(1 - z) + K = 0 with scipy's fsolve, K = K_func(H).

    The solve starts from the initial guess z = 0.5.  The scalar root is
    taken from the solution array that fsolve returns.
    """
    K = K_func(H)

    def residual(z):
        return z + np.log(1-z) + K

    solution = fsolve(residual, 0.5)
    return solution[0]

# Takes y_n, dy/dx and the step length;
# returns y_{n+1} after one RK4 step.
def rk4_step_func(y, h, dy_dx, solver = None):
    """Advance y by one classical fourth-order Runge-Kutta step.

    Parameters
    ----------
    y : value supporting arithmetic (float, array, ...)
        Current value y_n.
    h : float
        Step length.
    dy_dx : callable
        Right-hand side f(y) of the autonomous ODE dy/dx = f(y).
    solver : ignored
        Accepted only so all step functions share one signature.

    Returns
    -------
    The approximation y_{n+1}.
    """
    slope_start = dy_dx(y)
    slope_mid_a = dy_dx(y + h * slope_start / 2)
    slope_mid_b = dy_dx(y + h * slope_mid_a / 2)
    slope_end = dy_dx(y + h * slope_mid_b)

    # Simpson-like weights 1-2-2-1.
    weighted = slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end
    return y + (h / 6) * weighted

def heuns_step_func(y, h, dy_dx, solver = None):
    """Advance y by one step of Heun's method (explicit trapezoidal rule).

    An explicit Euler predictor y + h*f(y) is corrected by averaging the
    slopes at y and at the predictor.

    Parameters
    ----------
    y : value supporting arithmetic (float, array, ...)
        Current value y_n.
    h : float
        Step length.
    dy_dx : callable
        Right-hand side f(y) of the autonomous ODE dy/dx = f(y).
    solver : ignored
        Accepted only so all step functions share one signature.

    Returns
    -------
    The approximation y_{n+1}.
    """
    # Evaluate the slope at y once and reuse it for the predictor and the
    # corrector, so dy_dx is called twice per step rather than three times.
    slope = dy_dx(y)
    predictor = y + h * slope
    return y + (h/2) * (slope + dy_dx(predictor))

def impl_euler_bisec_step_func(y, h, dy_dx, solver = None):
    """Advance y one implicit trapezoidal step, solving for y_{n+1} by bisection.

    Finds next_y satisfying next_y = y + (h/2) * (dy_dx(y) + dy_dx(next_y)).
    The bracket starts at [y - 0.001, y + 0.001] and is widened in steps of
    0.01 on whichever side is needed.  It is then halved until its
    half-width is at most 1e-5.

    The widening assumes the residual decreases in next_y, which holds when
    (h/2) * d(dy_dx)/dy < 1.  NOTE(review): for stiff problems with large h
    this may not hold, and the widening loops may not terminate.

    Parameters
    ----------
    y : float
        Current value y_n.
    h : float
        Step length.
    dy_dx : callable
        Right-hand side f(y) of the autonomous ODE dy/dx = f(y).
    solver : ignored
        Accepted only so all step functions share one signature.

    Returns
    -------
    float
        Approximation of y_{n+1}.

    Raises
    ------
    ValueError
        If the sign change is lost during bisection (e.g. a NaN residual).
    """
    # The slope at the known point does not change during the solve.
    slope = dy_dx(y)

    def nullfunc(next_y):
        # Residual of the trapezoidal equation; zero at the sought next_y.
        return y + (h/2)*( slope + dy_dx(next_y) ) - next_y

    redge = y + 0.001
    ledge = y - 0.001
    # Widen the bracket until the residual is >= 0 at the left edge and
    # <= 0 at the right edge.  Only one side moves for a decreasing residual.
    # Widening both sides handles decreasing as well as increasing solutions.
    while nullfunc(ledge) < 0:
        ledge -= 0.01
    while nullfunc(redge) > 0:
        redge += 0.01

    f_l = nullfunc(ledge)
    f_r = nullfunc(redge)
    # An edge may already be an exact root; the sign tests below would
    # otherwise never select a half.
    if f_l == 0:
        return ledge
    if f_r == 0:
        return redge

    center = 0.5*(ledge + redge)
    f_c = nullfunc(center)

    error = (redge - ledge) / 2
    tolerance = 10**(-5)

    while error > tolerance:
        if f_l * f_c < 0:
            redge, f_r = center, f_c
        elif f_c * f_r < 0:
            ledge, f_l = center, f_c
        elif f_c == 0:
            return center
        else:
            raise ValueError('bisection lost the sign change: ledge = '
                             + str(ledge) + ' ' + str(f_l)
                             + ', redge = ' + str(redge) + ' ' + str(f_r)
                             + ', center = ' + str(center) + ' ' + str(f_c))

        center = (ledge + redge) / 2
        f_c = nullfunc(center)
        error = error/2

    return center

def impl_euler_newton_step_func(H, h, dy_dx, solver = Z_newton):
    """Advance H one implicit trapezoidal step, solving for H_{n+1} by Newton's method.

    Solves next_H = H + (h/2) * (dy_dx(H) + dy_dx(next_H)), starting the
    iteration at next_H = H.  Newton updates are applied until |update|
    is at most K_func(H).

    The residual's derivative is written out explicitly.  It reads the
    attributes dy_dx.C1 and dy_dx.C2, the constant c.K, and solver(next_H).
    NOTE(review): this is presumably the analytic derivative of dy_dx with
    respect to H; verify against dy_dx's definition.

    Parameters
    ----------
    H : float
        Current value H_n.
    h : float
        Step length.
    dy_dx : callable with attributes C1, C2
        Right-hand side of the ODE.
    solver : callable, optional
        Z solver passed next_H inside the derivative; defaults to Z_newton.

    Returns
    -------
    float
        Approximation of H_{n+1}.
    """
    def residual(candidate):
        return H + (h / 2) * ( dy_dx(H) + dy_dx(candidate) ) - candidate

    def residual_slope(candidate):
        inner = (4*dy_dx.C1* c.K *(1 - solver(candidate))) / candidate
        return -1 + (h / 2) * ( inner + dy_dx.C2)

    tol = K_func(H)
    next_H = H

    while True:
        delta = residual(next_H) / residual_slope(next_H)
        # Stop once the update is small; a NaN update also stops here.
        if not abs(delta) > tol:
            return next_H
        next_H = next_H - delta


# Name -> function that solves z + ln(1 - z) + K_func(H) = 0 for a given H.
Z_solver_dict = {
    'Newton': Z_newton,
    'Bisection': Z_bisection,
    'scipy.fsolve': Z_fsolve,
}

# Name -> single-step integrator with signature (y, h, dy_dx, solver).
step_func_dict = {
    'RK4': rk4_step_func,
    'Heuns': heuns_step_func,
    'Implicit Euler, Newton': impl_euler_newton_step_func,
    'Implicit Euler, Bisection': impl_euler_bisec_step_func,
}