import matplotlib.pyplot as plt
import numpy as np
from oving4.lattice import Pos, Lattice, Transition
import random
import time
from matplotlib.animation import FuncAnimation

# Lattice geometry (number of rows and columns of sites).
rows, cols = 20, 20
# Number of particles placed on the lattice.
N = 100
# Temperature T in the Boltzmann factor exp(-E / T); only relevant when
# interaction energies are used.
T = 50
# Number of simulation steps, and how often (in steps) a lattice snapshot is stored.
num_steps, dump_interval = 3000, 10

def initialize_left():
    """Place the N particles in a compact block against the left edge.

    Particles fill the lattice column by column, top to bottom, starting at
    column 0: particle ``i`` goes to row ``i % rows`` of column ``i // rows``.
    The last column may be only partially filled. Uses the module-level
    ``N``, ``rows`` and ``cols``.

    Returns:
        A list of ``N`` distinct ``Pos(row, col)`` objects.

    Raises:
        ValueError: if ``N`` particles do not fit on a ``rows`` x ``cols``
            lattice.
    """
    if N > rows * cols:
        raise ValueError('Cannot place %d particles on a %dx%d lattice' % (N, rows, cols))
    # A single column-major formula covers every case, including N < rows
    # (only a partial first column).
    return [Pos(i % rows, i // rows) for i in range(N)]

def initialize_bot():
    """Place the N particles in a compact block along the bottom edge.

    Three empty columns are kept on each side, so each filled row holds
    ``cols - 6`` particles (columns 3 .. cols-4). Rows are filled from the
    bottom row (``rows - 1``) upwards, left to right. A final partial row sits
    directly above the last full row. Uses the module-level ``N``, ``rows``
    and ``cols``.

    Returns:
        A list of ``N`` distinct ``Pos(row, col)`` objects.

    Raises:
        ValueError: if the lattice is too narrow for the side margins, or the
            particles do not fit in the available area.
    """
    margin = 3                  # empty columns kept on each side
    width = cols - 2 * margin   # particles per filled row
    if width <= 0:
        raise ValueError('Lattice with %d columns is too narrow for the side margins' % cols)
    if N > width * rows:
        raise ValueError('Cannot place %d particles in %d rows of width %d' % (N, rows, width))
    # Row-major fill from the bottom. Particle i is in the (i // width)-th row
    # counted from the bottom, at column margin + i % width. The partial last
    # row therefore lands one row above the last full row instead of on top
    # of it.
    return [Pos(rows - 1 - i // width, margin + i % width) for i in range(N)]

def valid_move(move, positions):
    """Tell whether ``move`` is a free lattice site.

    The site is rejected when it is already listed in ``positions`` or when
    its row/column lies outside ``0 .. rows-1`` / ``0 .. cols-1``
    (module-level ``rows`` and ``cols``).
    """
    # The occupancy test runs first, as in the bounds-last ordering below.
    if move in positions:
        return False
    return bool(0 <= move.row < rows and 0 <= move.col < cols)

def possible_transitions(positions):
    """Collect every single-step move currently available to the particles.

    For each particle, in order, the four nearest-neighbour steps
    (row -1, row +1, col -1, col +1) are tried. A ``Transition`` is created
    for each step whose target site ``valid_move`` accepts.
    """
    steps = (Pos(-1, 0), Pos(1, 0), Pos(0, -1), Pos(0, 1))
    return [
        Transition(particle, step, positions)
        for particle in positions
        for step in steps
        if valid_move(particle + step, positions)
    ]

def perform_transition_energy(positions, transitions, temperature=None):
    """Apply one transition drawn with Boltzmann weights ``exp(-E / T)``.

    Args:
        positions: list of particle positions; updated in place.
        transitions: candidate transitions, each exposing an ``energy``
            attribute.
        temperature: the ``T`` in the Boltzmann factor. Defaults to the
            module-level ``T``.

    Returns:
        ``positions`` (the same list object), after the chosen transition has
        been added to every entry. If ``transitions`` is empty, no move is
        possible and ``positions`` is returned unchanged.
    """
    if temperature is None:
        temperature = T
    if len(transitions) == 0:
        return positions

    energies = np.array([t.energy for t in transitions], dtype=float)
    # Shifting by the minimum energy leaves the normalised probabilities
    # unchanged. It prevents exp() overflow, and gives the lowest-energy
    # transition weight 1, so the normalising sum can never be 0.
    factors = np.exp(-(energies - energies.min()) / temperature)
    probs = factors / factors.sum()

    # Sample an index from the discrete distribution. np.random.choice always
    # returns a valid index, including for draws that land exactly on a
    # cumulative boundary. Zero-probability transitions are never picked.
    chosen = transitions[np.random.choice(len(transitions), p=probs)]

    # The chosen transition is added to every entry. What that does to each
    # position is defined by Pos's addition with a Transition, which is not
    # visible here (presumably only the matching particle moves).
    for i in range(len(positions)):
        positions[i] += chosen
    return positions

def perform_transition(positions, transitions):
    """Apply one transition picked uniformly at random.

    The selected transition is added to every entry of ``positions`` in
    place, and the same list is returned. An empty ``transitions`` raises
    ``ValueError`` (empty random range).
    """
    picked = transitions[random.randrange(len(transitions))]
    for idx, _ in enumerate(positions):
        positions[idx] += picked
    return positions

def save_image(lattice, TrNum):
    """Display ``lattice.numlat`` as an image with the 'binary' colormap.

    Despite its name, this shows the figure interactively rather than writing
    it to disk. ``TrNum`` (the step number in the main loop) is currently
    unused. It is intended for a file name such as
    ``'figs/Lattice' + str(TrNum) + '.png'`` passed to ``plt.savefig``.
    """
    plt.imshow(lattice.numlat, cmap='binary')
    plt.show()

pos = initialize_bot()
t_list = list(range(num_steps))
entropy = np.zeros(num_steps)
internal_energy = np.zeros(num_steps)
lattice = Lattice(rows, cols, pos)
# A snapshot is stored at every step i with i % dump_interval == 0, including
# i = 0. That is ceil(num_steps / dump_interval) snapshots. Plain floor
# division undercounts by one when num_steps is not a multiple of
# dump_interval, and the last store would then index past the array.
n_snapshots = (num_steps + dump_interval - 1) // dump_interval
images = np.zeros((n_snapshots, rows, cols))
t0 = time.process_time()
img = 0
for i in range(num_steps):
    # Record observables of the current state before advancing it.
    entropy[i] = lattice.get_entropy()
    internal_energy[i] = lattice.get_energy()

    if i % 100 == 0 and i != 0:
        # Extrapolate the CPU time used so far to the remaining steps.
        print('Time remaining :', round((time.process_time() - t0)*((num_steps - i)/i), 0))

    if i % dump_interval == 0:
        images[img] = lattice.numlat
        img += 1
        # To display each stored frame as well, call save_image(lattice, i) here.

    lattice.perform_transition()


# The code below plots the local entropy, the internal energy U, and U - T*S (the free energy) as functions of time.
#plt.clf()
#fig, axs = plt.subplots(3, 1, sharex='all')
#axs[0].plot(t_list, entropy)
#axs[1].plot(t_list, internal_energy)
#axs[2].plot(t_list, internal_energy - T * entropy)
#plt.show()

np.save('img_r'+str(rows)+'_c'+str(cols)+'_N'+str(N)+'_bot', images)

fig, ax = plt.subplots()
# Create the image artist once and only swap its data per frame. Calling
# imshow every frame would stack a new artist on the axes each time. The
# colour limits are fixed to the range of the whole run so that frames stay
# comparable.
image = ax.imshow(images[0], cmap='binary', vmin=images.min(), vmax=images.max())

def update(frame):
    """Show snapshot number ``frame`` of ``images``."""
    image.set_data(images[frame])
    return (image,)

# Without ``frames``, FuncAnimation counts frames indefinitely and ``update``
# would index past the last snapshot, so bound it to the stored snapshots.
ani = FuncAnimation(fig, update, frames=len(images), interval=1)
plt.show()