import numpy as np
from math import factorial, sqrt, exp, log
import matplotlib.pyplot as plt
import pickle

def calculate_Q_exact(states, multiplicity, T):
    """Return the exact partition function Q(T) = sum_i g_i * exp(-E_i / T).

    Vectorised with numpy because explicit for-loops in Python are very slow.

    Parameters
    ----------
    states : array_like
        Distinct state values E_i (1-D).
    multiplicity : array_like
        Multiplicity g_i of each entry in ``states`` (same length as ``states``).
    T : float or array_like
        Temperature(s). A scalar gives a scalar result; a 1-D array gives one
        value of Q per temperature.

    Returns
    -------
    numpy.ndarray or numpy.float64
        Q evaluated at each temperature in ``T``.
    """
    beta = 1 / np.asarray(T, dtype=float)
    # Outer product via broadcasting: one row per temperature, one column per
    # state. The trailing new axis lets a scalar T work as well as an array.
    boltzmann = np.exp(beta[..., np.newaxis] * -np.asarray(states))
    # Weight each state by its multiplicity and sum over the states axis.
    return np.sum(multiplicity * boltzmann, axis=-1)

def calculate_Q_approx(N, T):
    """Return the approximate partition function (equation 6 of the exercise).

    Q_approx(T) = C(N, N//2) * exp(-N / T). For even N the prefactor equals
    N! / ((N/2)!)**2; for odd N it is the largest binomial coefficient C(N, N//2).

    Parameters
    ----------
    N : int
        System size.
    T : float or array_like
        Temperature(s).

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Q_approx at each temperature.

    Raises
    ------
    OverflowError
        If C(N, N//2) is too large to be represented as a float.
    """
    # Exact integer binomial: avoids computing the huge N! and ((N//2)!)**2.
    from math import comb

    # For odd N, N! / ((N//2)!)**2 is not a binomial coefficient (it overcounts
    # by a factor N//2 + 1); comb(N, N//2) gives the correct central count.
    prefactor = float(comb(N, N // 2))
    return prefactor * np.exp(-N / np.asarray(T, dtype=float))

def partition_function_plot(states, multiplicity, N, temp=None):
    """Plot the exact and approximate partition functions and their ratio.

    Two figures are produced. The first shows Q_exact and Q_approx against T
    and is saved as 'Parition_function'. The second shows Q_exact / Q_approx
    on log-log axes and is saved as 'Parition_function_ratio'. Each figure is
    shown and then closed.

    Parameters
    ----------
    states : array_like
        Distinct state values, forwarded to calculate_Q_exact.
    multiplicity : array_like
        Multiplicity of each entry in ``states``.
    N : int
        System size, forwarded to calculate_Q_approx.
    temp : array_like, optional
        Temperatures to evaluate at. Defaults to 100 evenly spaced points
        in [1, 10].
    """
    if temp is None:
        temp = np.linspace(1, 10, 100)
    temp = np.asarray(temp, dtype=float)

    Q_exact_list = calculate_Q_exact(states, multiplicity, temp)
    Q_approx_list = calculate_Q_approx(N, temp)
    Q_ratio = Q_exact_list / Q_approx_list

    # NOTE: the output filenames are misspelled ("Parition"); they are kept
    # as-is in case anything downstream relies on them.

    # Explicit figures so nothing is drawn into a figure that was already
    # open/active before this function was called.
    fig, ax = plt.subplots()
    ax.plot(temp, Q_exact_list, label=r'$Q_{exact}$')
    ax.plot(temp, Q_approx_list, label=r'$Q_{approx}$')
    ax.legend()
    ax.set_xlabel('T [-]')
    fig.savefig('Parition_function')
    plt.show()
    plt.close(fig)

    fig, ax = plt.subplots()
    ax.plot(temp, Q_ratio, label=r'$Q_{ratio}$')
    ax.legend()
    ax.set_xlabel('T [-]')
    ax.set_xscale('log')
    ax.set_yscale('log')
    fig.savefig('Parition_function_ratio')
    plt.show()
    plt.close(fig)


# Load the pickled data: the list of all microstates (m_AB) and the system
# size (N), stored together as a 2-tuple.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('m_AB.pkl', 'rb') as pkl_file:
    m_AB, N = pickle.load(pkl_file)

# Collapse the microstate list into its distinct values and how often each occurs.
states, multiplicity = np.unique(m_AB, return_counts=True)

# Produce and save the partition-function plots.
partition_function_plot(states=states, multiplicity=multiplicity, N=N)