from numpy import pi, sin, cos, log, exp
import numpy as np, matplotlib.pyplot as plt
import time

def los_xcosx():
    """Solve x = cos(x) by fixed-point iteration and print the result.

    Starts at x = 1 and repeats x <- cos(x) until |x - cos(x)| is at most
    1e-10. Prints the final x followed by the number of iterations taken.
    """
    tol = 1e-10
    current = 1
    steps = 0

    # Keep the next iterate around so cos is evaluated once per pass.
    following = cos(current)
    while abs(current - following) > tol:
        current = following
        following = cos(current)
        steps += 1

    print(current, steps)

def newton2D():
    """Solve x**2 + y**2 = 4, x*y = 1 with Newton's method and print the result.

    The iteration starts at (2, 0) and stops once the squared Euclidean
    distance between two successive iterates is at most ``tol``. Prints that
    final squared distance, then the approximate root.

    Raises:
        RuntimeError: if the iteration has not converged after ``max_iter``
            Newton steps.
        numpy.linalg.LinAlgError: if the Jacobian is singular at an iterate.
    """
    def residual(p):
        # F(p) = (f(p), g(p)) for the two equations being solved.
        return np.array([p[0]**2 + p[1]**2 - 4, p[0]*p[1] - 1])

    def jacobian(p):
        # Rows are the gradients of f and g respectively.
        return np.array([[2*p[0], 2*p[1]],
                         [p[1], p[0]]])

    def step(p):
        # Solve J * delta = F instead of forming the inverse explicitly.
        return p - np.linalg.solve(jacobian(p), residual(p))

    def sq_dist(p, q):
        # One convergence measure, used both before and inside the loop.
        return np.sum((p - q)**2)

    tol = 1e-15
    max_iter = 100

    x0 = np.array([2.0, 0.0])
    x1 = step(x0)
    diff = sq_dist(x0, x1)

    iterations = 1
    while diff > tol:
        if iterations >= max_iter:
            raise RuntimeError(
                f"Newton iteration did not converge in {max_iter} steps"
            )
        x0, x1 = x1, step(x1)
        diff = sq_dist(x0, x1)
        iterations += 1

    print(diff)
    print(x1)

def fikspunkt():
    """Find a fixed point of g(x) = exp(1/x) by iteration and print it.

    Starts at x = 2 and iterates until two successive iterates differ by at
    most 1e-10. Prints, one per line: the number of evaluations of g, the
    last difference, and the final iterate.
    """
    def g(x):
        return exp(1/x)

    tol = 1e-10

    previous = 2
    current = g(previous)
    evaluations = 1  # the call above counts as the first one

    while abs(previous - current) > tol:
        previous, current = current, g(current)
        evaluations += 1

    print(evaluations)
    print(abs(previous - current))
    print(current)

def gitter(a,b,n):
    """Scatter-plot n+1 cosine-spaced grid points mapped onto [a, b].

    Builds the points cos(i*pi/n), i = 0..n (the Chebyshev extreme points on
    [-1, 1]), maps them affinely to the interval [a, b] and shows them on a
    horizontal line at height 0. The figure is displayed with plt.show().

    Args:
        a: left end of the target interval.
        b: right end of the target interval.
        n: the grid has n + 1 points.
    """
    indices = range(n + 1)

    ekstremal = np.array([cos(k*pi/n) for k in indices])
    # Zeros cos(pi*(2i+1)/(2n+2)); computed but not plotted at present.
    nullpunkt = np.array([cos(pi*(2*k + 1)/(2*n + 2)) for k in indices])

    heights = [0] * (n + 1)

    # Map [-1, 1] -> [0, 1] first, then stretch and shift onto [a, b].
    unit_interval = (ekstremal + 1)/2
    ab_ekstremal = a + ((b - a) * unit_interval)

    plt.scatter(ab_ekstremal, heights, label = 'ab_eks', marker = '.')
    plt.legend()
    plt.show()

def eks_euler(x,y,h,f):
    """Take one explicit (forward) Euler step for y' = f(x, y).

    Args:
        x: current abscissa.
        y: current solution value.
        h: step size.
        f: callable f(x, y) giving the slope.

    Returns:
        The value y + h*f(x, y).
    """
    slope = f(x, y)
    return y + h*slope

def impl_euler(x,y,h):
    """Return y / (1 - 2*h*x).

    This is the closed-form implicit (backward) Euler update for the specific
    equation y' = 2*x*y, where x is the abscissa of the new point.

    Args:
        x: abscissa used in the denominator.
        y: current solution value.
        h: step size.
    """
    denominator = 1 - h*2*x
    return y/denominator

def rk4(x,y,h,f):
    """Take one classical fourth-order Runge-Kutta step for y' = f(x, y).

    Args:
        x: current abscissa.
        y: current solution value.
        h: step size.
        f: callable f(x, y) giving the slope.

    Returns:
        The new solution value after one step of size h.
    """
    x_mid = x + h/2
    x_end = x + h

    k1 = f(x, y)                      # slope at the start of the step
    k2 = f(x_mid, y + (h*k1)/2)       # midpoint slope using k1
    k3 = f(x_mid, y + (h*k2)/2)       # midpoint slope using k2
    k4 = f(x_end, y + h*k3)           # slope at the end of the step

    weighted_slopes = k1 + 2*k2 + 2*k3 + k4
    return y + (h/6)*weighted_slopes

def ode(n, h, analyse = True):
    """Integrate from (x, y) = (0.01, 1) with three methods and plot the results.

    Runs n steps of size h with explicit Euler, implicit Euler and RK4. The
    explicit Euler and RK4 curves are plotted; the implicit Euler values are
    computed but not plotted. The figure is displayed with plt.show().

    Args:
        n: number of grid points / steps.
        h: step size.
        analyse: when true, also plot exp(x**2) on the grid as a reference.
    """
    # NOTE(review): f is the derivative of x**x and ignores y, whereas the
    # reference curve exp(x**2) and impl_euler both belong to y' = 2*x*y.
    # This looks like a leftover from switching problems; confirm which
    # right-hand side is intended.
    def f(x, y):
        return (log(x) + 1)*exp(x*log(x))

    x_start, y_start = 0.01, 1

    ekspl_euler = np.zeros(n)
    i_euler = np.zeros(n)
    rk4_punkter = np.zeros(n)

    x_gitter = np.array([x_start + h * k for k in range(n)])

    if analyse:
        analytisk = [exp(punkt ** 2) for punkt in x_gitter]
        plt.plot(x_gitter,analytisk, label = 'analytisk', marker='x')

    # Explicit Euler: store the current value, then advance.
    x, y = x_start, y_start
    for k in range(n):
        ekspl_euler[k] = y
        y, x = eks_euler(x, y, h, f), x + h

    plt.plot(x_gitter,ekspl_euler, label = 'e_euler')

    # Implicit Euler: the update is evaluated at the new abscissa, so x is
    # advanced before the step. The result is not plotted.
    x, y = x_start, y_start
    for k in range(n):
        i_euler[k] = y
        x = x + h
        y = impl_euler(x, y, h)

    # Classical Runge-Kutta: store the current value, then advance.
    x, y = x_start, y_start
    for k in range(n):
        rk4_punkter[k] = y
        y, x = rk4(x, y, h, f), x + h

    plt.plot(x_gitter,rk4_punkter, label = 'rk4')

    plt.legend()
    plt.show()

