import math
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
from sympy.utilities.iterables import multiset_permutations  # If you have a problem, Python has a library for it

# Lattice shape (rows, cols) used in this exercise for each supported particle count N.
_LATTICE_SHAPES = {
    4: (2, 2),
    6: (2, 3),
    8: (2, 4),
    12: (3, 4),
    16: (4, 4),
    20: (5, 4),
    24: (6, 4),
}


def lattice_shape(N):
    """Return the (rows, cols) lattice shape used for a system of N particles.

    Parameters
    ----------
    N : int
        Number of particles; must be one of 4, 6, 8, 12, 16, 20, 24.

    Returns
    -------
    tuple of int
        The lattice shape, with ``rows * cols == N``.

    Raises
    ------
    ValueError
        If no shape is defined for N. Previously this silently returned None,
        which only surfaced later as a confusing reshape/indexing failure.
    """
    try:
        return _LATTICE_SHAPES[N]
    except KeyError:
        supported = ', '.join(str(n) for n in sorted(_LATTICE_SHAPES))
        raise ValueError(
            f'No lattice shape defined for N = {N}; supported values: {supported}'
        ) from None

# Periodic boundary conditions were stressed so heavily that I expected to need
# them in many places, so I built them into a class right away.
class Lattice:
    def __init__(self, array):
        self.contains = np.array(array)
        self.shape = self.contains.shape

    def __getitem__(self, key): #periodic boundary conditions
        if type(key) == tuple:
            i, j = key
            if i >= self.shape[0] and j >= self.shape[1]:
                return self[i - self.shape[0], j - self.shape[1]]
            elif i >= self.shape[0]:
                return self[i - self.shape[0], j]
            elif j >= self.shape[1]:
                return self[i, j - self.shape[1]]
            else:
                return self.contains[i,j]

        else:
            if key >= self.shape[0]:
                return self[key - self.shape[0]]
            else:
                return self.contains[key]

    def __setitem__(self, key, value):
        if key >= len(self.contains):
            self[key - len(self.contains)] = value
        else:
            self.contains[key] = value

    def __len__(self):
        return len(self.contains)

    def __str__(self):
        return str(self.contains)

def count_AB(lattice): #Funker så lenge type(lattice) == Lattice
    m_AB = 0
    for row in range(len(lattice)):
        for col in range(len(lattice[1])):
            if lattice[row, col] == 'A':
                if lattice[row + 1, col] == 'B':
                    m_AB += 1
                if lattice[row - 1, col] == 'B':
                    m_AB += 1
                if lattice[row, col + 1] == 'B':
                    m_AB += 1
                if lattice[row, col - 1] == 'B':
                    m_AB += 1
    return m_AB

def create_arrays_and_count(N):
    if N%2: #For å hindre teite feil som er vanskelige å spore
        raise ValueError('N must be even!')

    unique_configs = int(np.math.factorial(N) / (np.math.factorial(N // 2) * np.math.factorial(N // 2)))
    m_AB = np.zeros(unique_configs) # aldri append! Det er møkktreigt!
    particles = np.repeat(['A', 'B'], N//2) #lager et sett med like mange A og B, og med lengde N
    shape = lattice_shape(N)
    t0 = time.process_time() #litt tidtaking for lættis, N = 24 tok ≈ 9 min, så dette går åpenbart ikke for makroskopiske systemer
    for i, config in enumerate(multiset_permutations(particles)):
        lattice = Lattice(np.array(config).reshape(shape))
        m_AB[i] = count_AB(lattice) #teller antall interaksjoner
        if i%100000 == 0 and i != 0: #Kjedelig å vente når du ikke vet hvor lenge det er igjen
            t = time.process_time()
            t_rem = ((t - t0) / (i + 1)) * (unique_configs - (i + 1))
            m_rem = t_rem//60
            s_rem = t_rem - m_rem * 60
            frac_fin = round(i / unique_configs, 3)
            print(round(frac_fin*100,1), '% finished.', sep='', end=' ')
            print(int(m_rem), 'min, ', round(s_rem,0), 's remaining')
    return m_AB

for i in [4, 6, 8, 12, 16, 20, 24]:
    print('i =', i)
    m_AB = create_arrays_and_count(i)
    states, degeneracy = np.unique(m_AB, return_counts=True) #Hvis du har et problem har python et bibliotek. C er raskere enn list-comprehentions.
    normalized_mAB = m_AB / max(degeneracy)
    var_m_AB = np.var(normalized_mAB) #den normaliserte variansen synker med økende i
    #Creates a bar chart of density of states: x-axis = microstate, y-axis = degeneracy
    y_pos = np.arange(len(states)) #skjønner ikke helt hva poenget med dette er?
    plt.bar(states, degeneracy)
    #plt.xticks(y_pos, states, fontsize=7, rotation=30) #hvorforrr??
    plt.title(r'$\sigma^2 = $'+str(var_m_AB))
    plt.savefig('density_of_states/N_' + str(i))
    plt.clf()

    #Saves m_AB and i for future use for system size = i_max (in this case 24)
    with open('state_files/m_AB_'+str(i)+'.pkl', 'wb') as f:
        pickle.dump([m_AB,i], f) #lagrer like gjerne alt sånn i tilfelle
