import numpy as np
import constants as c
from analytical import D_func

def dx_explicit(D, dt):
    """Return the grid spacing ``sqrt(2 * D * dt)`` used by the explicit scheme.

    With this spacing the mesh ratio ``D * dt / dx**2`` equals 1/2 (up to
    rounding), which is the stability limit of the explicit (FTCS) scheme
    for 1-D diffusion.  Works element-wise on NumPy arrays.
    """
    return np.sqrt(2 * dt * D)

def get_A(n, k):
    """Build the dense ``(n, n)`` second-difference matrix scaled by ``k``.

    * Rows ``1 .. n-2`` hold the stencil ``k * [-1, 2, -1]`` centred on the
      diagonal (columns ``i-1, i, i+1``).
    * Row 0 holds the same stencil in columns 0..2, so it is identical to
      row 1; node 0 therefore receives exactly the same update as node 1.
    * The last row is left all zero, so when an identity matrix is added to
      this matrix, the last node is held fixed.

    Parameters
    ----------
    n : int
        Matrix size (number of grid nodes); must be at least 3.
    k : float
        Scale factor applied to the stencil.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n, n)``.

    Raises
    ------
    ValueError
        If ``n < 3``; the three-point stencil does not fit.
    """
    if n < 3:
        raise ValueError(f"get_A needs at least 3 nodes, got n={n}")

    A = np.zeros((n, n))
    stencil = k * np.array([-1, 2, -1])

    # Row 0: stencil placed in columns 0..2 (centred on node 1), so the row
    # for node 0 duplicates the row for node 1.
    A[0, :3] = stencil

    # Interior rows: tridiagonal stencil centred on the diagonal.
    rows = np.arange(1, n - 1)
    A[rows, rows - 1] = stencil[0]
    A[rows, rows] = stencil[1]
    A[rows, rows + 1] = stencil[2]

    # Last row intentionally stays zero.
    return A

def eksplisitt(T, t_stop, dt):
    """Solve the 1-D diffusion problem with the explicit (forward Euler) scheme.

    The spacing comes from ``dx_explicit``, so ``D*dt/dx**2 == 1/2``.  The
    first stored level is the initial state: ``c.C0`` everywhere except the
    last node, which is set to ``c.Cs``.  The last row of the step matrix is
    an identity row, so that node keeps the value ``c.Cs`` for all steps.

    Parameters
    ----------
    T :
        Passed to ``D_func`` to obtain the diffusion coefficient ``D``.
    t_stop : float
        Simulated time span.  The number of stored levels is
        ``t_stop / dt`` rounded down, with a tiny tolerance so that
        floating-point noise (e.g. ``0.3 / 0.1 == 2.9999999999999996``) does
        not drop a level.
    dt : float
        Time step.

    Returns
    -------
    (list of numpy.ndarray, numpy.ndarray)
        One concentration profile per stored level (times ``0, dt, 2*dt, ...``)
        and the matching x positions.  The first node (at ``x = -dx``) is
        dropped from both; it only exists to impose ``dC/dx = 0`` at ``x = 0``.

    Raises
    ------
    ValueError
        If ``t_stop < dt``, so that not even the initial level fits.
    """
    D = D_func(T)
    dx = dx_explicit(D, dt)

    # NOTE(review): np.arange excludes its stop value, so the last node (where
    # Cs is imposed) lies strictly below L/2.  implisitt/crank_nicholson use
    # L/2 + dx as the stop instead (ending at L/2 when dx divides L/2) --
    # confirm this difference is intended.
    x_line = np.arange(-dx, c.L / 2, step=dx)
    n = len(x_line)
    k = -(dt * D / dx**2)

    # Forward Euler: C^{i+1} = (I + A) C^i.
    M = get_A(n, k) + np.identity(n)

    # Tolerance guards against int() truncating ratios like 2.9999999999999996.
    n_steps = int(t_stop / dt + 1e-9)
    if n_steps < 1:
        raise ValueError(f"t_stop ({t_stop}) must be at least dt ({dt})")

    C = np.zeros((n_steps, n))
    C[0] = c.C0
    C[0, -1] = c.Cs

    for i in range(n_steps - 1):
        C[i + 1] = M @ C[i]

    # Drop the first node of every profile; it only serves the x = 0 boundary.
    return [row[1:] for row in C], x_line[1:]

def implisitt(T, t_stop, dt):
    """Solve the 1-D diffusion problem with the implicit (backward Euler) scheme.

    Each step solves ``(I + A) C^{i+1} = C^i`` with
    ``A = get_A(n, D*dt/dx**2)`` on a grid of spacing ``c.dx_imp``.  The
    first stored level is ``c.C0`` everywhere except the last node, which is
    set to ``c.Cs``.  The last row of ``I + A`` is an identity row, so that
    node stays at ``c.Cs``.

    Parameters
    ----------
    T :
        Passed to ``D_func`` to obtain the diffusion coefficient ``D``.
    t_stop : float
        Simulated time span.  The number of stored levels is
        ``t_stop / dt`` rounded down, with a tiny tolerance against
        floating-point noise.
    dt : float
        Time step.

    Returns
    -------
    (list of numpy.ndarray, numpy.ndarray)
        One concentration profile per stored level (times ``0, dt, 2*dt, ...``)
        and the matching x positions, with the first node (``x = -dx``)
        removed from both.

    Raises
    ------
    ValueError
        If ``t_stop < dt``.
    """
    D = D_func(T)
    dx = c.dx_imp

    # NOTE(review): np.arange with a float step can gain or lose an end point
    # through rounding; confirm n is as expected for the dx in constants.
    x_points = np.arange(-dx, c.L / 2 + dx, step=dx)
    n = len(x_points)
    k = (dt * D) / (dx**2)

    # The system matrix is identical for every step, so invert it once; each
    # step is then a single matrix-vector product.
    M = np.linalg.inv(np.identity(n) + get_A(n, k))

    # Tolerance guards against int() truncating ratios like 2.9999999999999996.
    n_steps = int(t_stop / dt + 1e-9)
    if n_steps < 1:
        raise ValueError(f"t_stop ({t_stop}) must be at least dt ({dt})")

    C = np.zeros((n_steps, n))
    C[0] = c.C0
    C[0, -1] = c.Cs

    for i in range(n_steps - 1):
        C[i + 1] = M @ C[i]

    # Drop the first node of every profile; it only serves the x = 0 boundary.
    return [row[1:] for row in C], x_points[1:]

def crank_nicholson(T, t_stop, dt):
    """Solve the 1-D diffusion problem with the Crank-Nicolson scheme.

    Each step applies ``(2I + A) C^{i+1} = (2I - A) C^i`` with
    ``A = get_A(n, D*dt/dx**2)`` on a grid of spacing ``c.dx_crank``.  The
    first stored level is ``c.C0`` everywhere except the last node, which is
    set to ``c.Cs``.  The last rows of both matrices are ``2 * e_n``, so that
    node stays at ``c.Cs``.

    Parameters
    ----------
    T :
        Passed to ``D_func`` to obtain the diffusion coefficient ``D``.
    t_stop : float
        Simulated time span.  The number of stored levels is
        ``t_stop / dt`` rounded down, with a tiny tolerance against
        floating-point noise.
    dt : float
        Time step.

    Returns
    -------
    (list of numpy.ndarray, numpy.ndarray)
        One concentration profile per stored level (times ``0, dt, 2*dt, ...``)
        and the matching x positions, with the first node (``x = -dx``)
        removed from both.

    Raises
    ------
    ValueError
        If ``t_stop < dt``.
    """
    D = D_func(T)
    dx = c.dx_crank

    # NOTE(review): np.arange with a float step can gain or lose an end point
    # through rounding; confirm n is as expected for the dx in constants.
    x_points = np.arange(-dx, c.L / 2 + dx, step=dx)
    n = len(x_points)
    # The mesh ratio must use the same spacing as the grid above.
    k = (dt * D) / (dx**2)

    A = get_A(n, k)
    eye = np.identity(n)

    # Step matrix M = (2I + A)^{-1} (2I - A).  solve() computes it without
    # forming the inverse, which is cheaper and numerically more accurate.
    M = np.linalg.solve(2 * eye + A, 2 * eye - A)

    # Tolerance guards against int() truncating ratios like 2.9999999999999996.
    n_steps = int(t_stop / dt + 1e-9)
    if n_steps < 1:
        raise ValueError(f"t_stop ({t_stop}) must be at least dt ({dt})")

    C = np.zeros((n_steps, n))
    C[0] = c.C0
    C[0, -1] = c.Cs

    for i in range(n_steps - 1):
        C[i + 1] = M @ C[i]

    # Drop the first node of every profile; it only serves the x = 0 boundary.
    return [row[1:] for row in C], x_points[1:]



