import numpy as np, matplotlib.pyplot as plt

def trapes(a, b, func=lambda x: np.exp(x)):
    """Approximate the integral of func over [a, b] with one trapezoid.

    Only the two endpoints are used: (b - a) * (func(a) + func(b)) / 2.
    The default integrand is e^x.
    """
    bredde = b - a
    return bredde * (func(a) + func(b)) / 2

def flytt_gitter(x_points, a, b):
    """Map points from the reference interval [-1, 1] affinely onto [a, b].

    -1 maps to a and 1 maps to b.

    Parameters
    ----------
    x_points : scalar, list or numpy array
        Points in [-1, 1]. Lists are accepted; they are converted with
        np.asarray because plain-list arithmetic (``list + 1``) is a TypeError.
    a, b : float
        Endpoints of the target interval.

    Returns
    -------
    numpy array (or numpy scalar for scalar input) of mapped points.
    """
    x_points = np.asarray(x_points, dtype=float)
    return a + (b - a) * ((x_points + 1) / 2)

def lagrange_polynom(x, x_points):
    """Evaluate every Lagrange basis polynomial for the nodes x_points at x.

    Element i of the result is
        L_i(x) = prod_{k != i} (x - x_k) / (x_i - x_k),
    so L_i(x_j) is 1 when i == j and 0 otherwise.

    Parameters
    ----------
    x : float or numpy array
        Evaluation point(s); arrays are handled element-wise.
    x_points : sequence of float
        Interpolation nodes. They must be distinct; a repeated node gives a
        division by zero.

    Returns
    -------
    list
        One value (or array, if x is an array) per node, in node order.
    """
    polynomer = []
    for i, x_i in enumerate(x_points):
        l_i = 1
        for k, x_k in enumerate(x_points):
            if k != i:
                l_i *= (x - x_k) / (x_i - x_k)
        polynomer.append(l_i)
    # The original computed these values but fell off the end, returning None.
    return polynomer

def gen_chevbyesky_ekstremalgitter(n, a=0, b=0):
    """Return the n Chebyshev extremal points cos(k*pi/(n-1)) for k = 0..n-1.

    The points run from 1 (k = 0) down to -1 (k = n-1). If a or b is
    non-zero, the points are mapped from [-1, 1] onto [a, b] with
    flytt_gitter. With the defaults a = b = 0 they stay on [-1, 1].
    n = 1 raises ZeroDivisionError because the spacing pi/(n-1) is undefined.
    """
    punkter = np.zeros(n)  # np.zeros rejects a negative n with ValueError
    siste_indeks = n - 1
    punkter[:] = [np.cos(np.pi * k / siste_indeks) for k in range(n)]
    return flytt_gitter(punkter, a, b) if (a or b) else punkter

def gauss_lobatto(n, a=0, b=0):
    """Return the n Gauss-Lobatto nodes on [-1, 1], optionally mapped to [a, b].

    Only n == 7 is implemented. The nodes are -1, 1, 0 and
    ±sqrt(5/11 ± (2/11)*sqrt(5/3)), returned in ascending order.

    Parameters
    ----------
    n : int
        Number of nodes; must be 7.
    a, b : float
        If either is non-zero, the nodes are mapped onto [a, b] with
        flytt_gitter. This matches gen_chevbyesky_ekstremalgitter; the original
        ``a and b`` test silently skipped intervals such as [0, 1].

    Returns
    -------
    numpy.ndarray of the 7 nodes.

    Raises
    ------
    ValueError
        If n != 7. Previously the function silently returned None.
    """
    if n != 7:
        raise ValueError(f'gauss_lobatto: kun n = 7 er implementert (fikk n = {n})')

    p1 = 1.0
    p2 = np.sqrt((5/11) + (2/11) * np.sqrt(5/3))
    p3 = np.sqrt((5/11) - (2/11) * np.sqrt(5/3))
    # A numpy array (not a list) so the affine mapping can do arithmetic on it.
    x_points = np.array([-p1, -p2, -p3, 0.0, p3, p2, p1])

    if a or b:
        x_points = flytt_gitter(x_points, a=a, b=b)
    return x_points

def gen_clenshaw_curtis_vekter(n):
    """Return the Clenshaw-Curtis weights for n Chebyshev extremal nodes on [-1, 1].

    The nodes are x_i = cos(i*pi/N) with N = n - 1 (i = 0..N). The weights
    follow the standard closed form, which differs for even and odd N:

      N odd :  w_0 = w_N = 1/N^2
               w_i = (2/N) * (1 - sum_{j=1}^{(N-1)/2} 2 cos(2 j i pi/N) / (4 j^2 - 1))
      N even:  w_0 = w_N = 1/(N^2 - 1)
               w_i = (2/N) * (1 - sum_{j=1}^{N/2-1} 2 cos(2 j i pi/N) / (4 j^2 - 1)
                                - cos(i pi) / (N^2 - 1))

    The original loop ran j only up to (n-2)/2 - 1, dropping the last term.
    For example, n = 4 gave weights summing to 14/9 instead of 2.

    Parameters
    ----------
    n : int
        Number of nodes (n >= 2).

    Returns
    -------
    numpy.ndarray of length n; the weights sum to 2 (the length of [-1, 1]).

    Raises
    ------
    ValueError
        If n < 2 (there is no Clenshaw-Curtis rule with fewer than 2 nodes).
    """
    if n < 2:
        raise ValueError('gen_clenshaw_curtis_vekter: n må være minst 2')

    N = n - 1
    vekter = np.zeros(n)

    if N % 2 == 0:
        vekter[0] = vekter[-1] = 1 / (N**2 - 1)
    else:
        vekter[0] = vekter[-1] = 1 / N**2

    # (N - 1)//2 equals (N-1)/2 for odd N and N/2 - 1 for even N,
    # i.e. the correct upper summation limit in both cases.
    j = np.arange(1, (N - 1) // 2 + 1)
    koeff = 2 / (4 * j**2 - 1)

    for i in range(1, N):
        verdi_av_sum = np.sum(koeff * np.cos((2 * j * i * np.pi) / N))
        if N % 2 == 0:
            # cos(N * i*pi/N) = cos(i*pi) = (-1)**i, written exactly.
            verdi_av_sum += (-1) ** i / (N**2 - 1)
        vekter[i] = (2 / N) * (1 - verdi_av_sum)

    return vekter

def clenshaw_curtis_kvadratur(a, b, n, func=lambda x: np.exp(x), tol=0):
    """Integrate func over [a, b] with Clenshaw-Curtis quadrature, refining until converged.

    Starting from n, the node count is increased by 2 per iteration, so the
    first rule evaluated uses n + 2 nodes. Iteration stops once two successive
    estimates differ by at most tol. The final node count, estimate and last
    change are printed.

    The original ignored a and b and always integrated over [-1, 1]. Nodes
    and weights are now mapped onto [a, b]; for [a, b] = [-1, 1] the results
    are unchanged.

    Parameters
    ----------
    a, b : float
        Integration limits.
    n : int
        Starting node count. It must be even; otherwise a message is printed
        and 0 is returned.
    func : callable
        Integrand, called once per node with a scalar. Defaults to e^x.
    tol : float
        Tolerance on the change between successive estimates. 0 selects the
        default 1e-5.

    Returns
    -------
    float
        The final quadrature estimate (0 if n was odd).
    """
    if n % 2:
        print('n må være partall!')
        return 0

    if tol == 0:
        tol = 1e-5

    # Affine map from the reference interval [-1, 1] onto [a, b], written as
    # midpoint + half_length * x so that [a, b] = [-1, 1] reproduces the
    # reference nodes exactly. The half-length is also the Jacobian that
    # rescales the reference weights.
    midtpunkt = (a + b) / 2
    halv_lengde = (b - a) / 2

    endring = tol + 1
    quad = 0

    while endring > tol:
        n += 2

        x_points = midtpunkt + halv_lengde * gen_chevbyesky_ekstremalgitter(n)
        vekter = halv_lengde * gen_clenshaw_curtis_vekter(n)

        ny_quad = sum(func(x) * w for x, w in zip(x_points, vekter))

        endring = abs(quad - ny_quad)
        quad = ny_quad

    print('n :', n)
    print('Kvadratur :', quad)
    print('Siste endring :', endring)
    return quad

def _sammensatt_trapes(n, a=-1.0, b=1.0, func=np.exp):
    """Composite trapezoidal rule for func on n equally spaced points in [a, b] (n >= 2)."""
    y = func(np.linspace(a, b, n))
    h = (b - a) / (n - 1)
    # Interior points get weight h; the two endpoints get h/2.
    return h * (np.sum(y) - (y[0] + y[-1]) / 2)


def oppg1():
    """Task 1: find the smallest n where the composite trapezoidal rule is accurate to 1e-10.

    The rule uses n equally spaced points on [-1, 1] to approximate the
    integral of e^x, whose exact value is e - 1/e. The function prints n,
    the quadrature value and the exact integral.

    The original scanned n = 2, 3, 4, ... and re-summed all n - 1 trapezoids
    in a Python loop each time. About 88 500 points are needed, so that is
    billions of function calls. Here each rule is vectorized with numpy.
    The smallest n is found by doubling n and then bisecting, which takes
    about 35 rule evaluations in total.
    """
    tol = 1e-10
    integral = np.exp(1) - np.exp(-1)

    def feil(n):
        return abs(_sammensatt_trapes(n) - integral)

    # The search relies on the error decreasing as n grows. For the convex
    # integrand e^x the trapezoid error is about integral * h**2 / 12.
    # Near the threshold, rounding noise can shift the reported n by a few
    # points, just as with the original sequential summation.
    # lav = 1 (no intervals) never meets the tolerance, so it brackets from below.
    lav, hoy = 1, 2
    while feil(hoy) > tol:
        lav, hoy = hoy, 2 * hoy

    # Invariant: feil(lav) > tol (or lav == 1) and feil(hoy) <= tol.
    while hoy - lav > 1:
        midt = (lav + hoy) // 2
        if feil(midt) > tol:
            lav = midt
        else:
            hoy = midt

    n = hoy
    quad = _sammensatt_trapes(n)

    print('n :', n)
    print('Kvadratur :', quad)
    print('Integral :', integral)

def oppg2():
    """Task 2: Clenshaw-Curtis estimate of the integral of e^x over [-1, 1].

    Starts from n = 2 and uses the quadrature routine's default tolerance.
    """
    clenshaw_curtis_kvadratur(-1, 1, 2, func=np.exp)

def oppg3():
    """Task 3: Clenshaw-Curtis estimate of the integral of e^x * sqrt(1 - x^2) over [-1, 1].

    Starts from n = 2 with tolerance 1e-8.
    """
    def integrand(x):
        return np.exp(x) * np.sqrt(1 - (x**2))

    clenshaw_curtis_kvadratur(-1, 1, 2, func=integrand, tol=1e-8)

def oppg4():
    """Task 4 (interactive): read interval endpoints a and b from standard input.

    Each value is read with input() and converted with float(), which raises
    ValueError on non-numeric input.

    NOTE(review): in the visible code a and b are only read, never used.
    Confirm whether the task is unfinished or continues beyond this excerpt.
    """
    a = float(input('a = '))
    b = float(input('b = '))
