import numpy as np, matplotlib.pyplot as plt, matplotlib.animation as ani, time
from numpy import sin, cos, pi, e, exp, log

def m_print(m):
    """Print a matrix (any iterable of iterables of numbers) row by row.

    Each value is rounded to 2 decimals and left-aligned in a 5-character
    column. Values at least 5 characters wide are followed by a single space
    so neighbouring entries never run together. A blank line is printed after
    the last row.
    """
    for row in m:
        for x in row:
            s = str(round(x, 2))
            # Pad to a 5-char column, but always keep at least one separating space.
            print(s, end=' ' * max(1, 5 - len(s)))
        print()
    print()

def v_print(v):
    """Print a vector (any iterable of numbers) on a single line.

    Each value is rounded to 2 decimals and left-aligned in a 5-character
    column. Values at least 5 characters wide are followed by a single space
    so neighbouring entries never run together.
    """
    for x in v:
        s = str(round(x, 2))
        # Pad to a 5-char column, but always keep at least one separating space.
        print(s, end=' ' * max(1, 5 - len(s)))
    print()

# Solve u_t = u_xx on [x0, xn] with u(x, 0) = sin(pi x) and u = 0 at both ends.
# Space uses central second differences; time uses backward (implicit) Euler.
t0 = time.process_time()  # CPU-time stamp for timing the solve
x0, xn = 0, 1  # spatial domain
tn = 1  # final time

n = 100  # number of grid points, including both boundary points
x_points = np.linspace(x0, xn, n)
# n linspace points form n-1 intervals, so the spacing is (xn-x0)/(n-1).
# Using /n here would scale the discrete Laplacian for a different grid.
h = (xn - x0) / (n - 1)
k = h  # time step

# Second-difference matrix on the n-2 interior points: 2 on the diagonal,
# -1 on the first sub/super-diagonals. (`k=` is np.eye's diagonal offset,
# unrelated to the time step.)
A = 2 * np.eye(n - 2) - np.eye(n - 2, k=1) - np.eye(n - 2, k=-1)

# A backward-Euler step solves (I + k/h^2 A) u_new = u_old on the interior.
# The system matrix never changes, so it is inverted once and reused.
M = np.linalg.inv(np.identity(n - 2) + (k / h**2) * A)

# Rows are the time levels t = 0, k, 2k, ..., tn. round() guards against
# tn/k landing just below an integer (e.g. 2.9999999999999996).
n_steps = int(round(tn / k))
U = np.zeros((n_steps + 1, n))
U[0] = sin(pi * x_points)
for i in range(n_steps):
    # Boundary columns stay 0 (homogeneous Dirichlet); only interior points advance.
    U[i + 1, 1:-1] = M @ U[i, 1:-1]

print('imp_tid: ', time.process_time() - t0)

fig, ax = plt.subplots()


def animate(i):
    """Redraw the axes with the solution at time level ``i``.

    FuncAnimation calls this with the frame index ``i``. The index wraps
    modulo ``len(U)``, so a repeating or over-long animation never indexes
    past the last computed time level.
    """
    ax.clear()
    ax.plot(x_points, U[i % len(U)])
    ax.set_xlim(x_points[0], x_points[-1])
    ax.set_ylim(0, 1)


# One frame per computed time level. The animation must stay bound to a name,
# because matplotlib stops an animation whose object is garbage-collected. It is
# bound to `anim` so the `ani` module alias is not shadowed.
anim = ani.FuncAnimation(fig, animate, frames=len(U), interval=50)
plt.show()