import numpy as np, matplotlib.animation as animation, matplotlib.pyplot as plt, time
# Solve the 1-D heat equation u_t = u_xx on [a, b] with the explicit
# forward-Euler / central-difference scheme. Only the interior entries of each
# time level are updated, so u stays 0 at both ends of the interval.
a, b = 0, 1            # spatial interval
t0, t_slutt = 0, 1     # start and end time

n = 150                # number of grid points, including both endpoints
x = np.linspace(a, b, n)
# linspace with n points has n - 1 intervals; h must match the actual spacing
# of `x`, otherwise the stencil silently uses the wrong diffusivity.
h = (b - a) / (n - 1)
# r = k / h**2 = 0.4; the explicit scheme needs r <= 1/2 to be stable.
k = 0.4 * h**2

n_steps = int((t_slutt - t0) / k)   # number of stored time levels
U = np.zeros((n_steps, n))          # U[j] is the solution at time t0 + j*k
U[0] = x * (1 - x)                  # initial condition (zero at both ends)

I = np.identity(n - 2)

# Second-difference matrix tridiag(-1, 2, -1) on the n - 2 interior points.
# np.eye's diagonal offset places the -1 entries on the super- and
# sub-diagonal without any shape-sensitive slicing.
A = 2 * I - np.eye(n - 2, k=1) - np.eye(n - 2, k=-1)

# One explicit time step on the interior: u_{j+1} = (I - k/h**2 * A) u_j.
M = I - k / h**2 * A

tid = time.process_time()
for i in range(len(U) - 1):
    U[i + 1, 1:-1] = M @ U[i, 1:-1]

print(time.process_time() - tid)
print(len(U))
fig, ax = plt.subplots()
# NOTE: `k` and `a` are reused from here on as the animation state read and
# written by `animate` (current row index and step direction); the time step
# and interval start above are no longer needed at this point.
k = 0
a = 1

y_list = U
x_list = x

def animate(i):
    """Draw the next frame of a back-and-forth playback of ``y_list``.

    Called by ``FuncAnimation`` once per timer tick. The frame number ``i`` is
    ignored: the playback position lives in the module-level globals ``k``
    (index of the row of ``y_list`` to draw) and ``a`` (step direction, +1
    forwards, -1 backwards). When the index would leave ``y_list`` the
    direction is reversed, so the rows are played forwards and backwards
    indefinitely.
    """
    global k, a
    n_frames = len(y_list)

    # Keep the index valid even if the globals were modified elsewhere.
    k = min(max(k, 0), n_frames - 1)
    y = y_list[k]

    # Advance; at either end reflect instead of stepping out of range, so the
    # turn-around row is drawn once rather than twice in a row. The clamp
    # covers a single-row y_list, where there is nowhere to reflect to.
    k += a
    if not 0 <= k < n_frames:
        a = -a
        k = min(max(k + 2 * a, 0), n_frames - 1)

    # Draw on the figure's own axes rather than whatever pyplot considers
    # "current"; clearing resets the limits, so set them again each frame.
    ax.clear()
    ax.plot(x_list, y)
    ax.set_ylim(0, 1)
    ax.set_xlim(0, 1)


# Drive `animate` from a timer (interval is given in milliseconds). The
# Animation object is bound to a module-level name so it stays referenced
# while the window is open; matplotlib stops animations whose object has been
# garbage-collected. `frames` only sets how many frame numbers are cycled;
# `animate` ignores that number and tracks its own position.
ani = animation.FuncAnimation(
    fig,
    animate,
    interval=0.01,
    frames=360,
)

plt.show()
