import numpy as np
from copy import deepcopy
from math import log, factorial

# Model parameters read by Transition and Lattice.

# Only when this is exactly True (checked with `is True`) does every occupied
# site get the extra DENSITY_DIFFERENCE * row energy term.
GRAVITY: bool = True
# Energy added per unit of row index for an occupied site when GRAVITY is on.
DENSITY_DIFFERENCE: int = 100
# Energy contributed by each neighbouring site that is NOT occupied
# (neighbours outside the grid count as not occupied).
SITE_INTERACT_ENERGY: int = 30
# Energy contributed by each neighbouring site that IS occupied.
SELF_INTERACT_ENERGY: int = 0
# Temperature used in the transition weight exp(-E / T).
T: int = 50

class Pos:
    """An integer ``(row, col)`` coordinate on the lattice.

    Positions compare equal and hash by ``(row, col)``, so they can be used
    as dict keys and set members. ``pos`` is the ``(row, col)`` tuple
    captured at construction time.
    """

    def __init__(self, row, col):
        self.row = row
        self.col = col
        # Snapshot only: it is not refreshed if row/col are reassigned later.
        self.pos = (row, col)

    def __repr__(self):
        return 'Pos(%r, %r)' % (self.row, self.col)

    def _apply_transition(self, transition):
        """Return the transition's destination if it starts here, else a copy of self."""
        if self == transition.pos_from:
            return Pos(transition.pos_to.row, transition.pos_to.col)
        return Pos(self.row, self.col)

    def __add__(self, other):
        """Return a new Pos.

        * ``Pos + Pos``: component-wise sum.
        * ``Pos + Transition``: the transition's destination when this Pos is
          the transition's origin, otherwise an unchanged copy of this Pos.

        Any other operand type returns ``NotImplemented`` so Python raises
        ``TypeError`` instead of the expression silently evaluating to None.
        """
        if isinstance(other, Pos):
            return Pos(self.row + other.row, self.col + other.col)
        if isinstance(other, Transition):
            return self._apply_transition(other)
        return NotImplemented

    # Both supported forms are symmetric, so the reflected operation is identical.
    __radd__ = __add__

    def __sub__(self, other):
        """Return a new Pos.

        * ``Pos - Pos``: component-wise difference.
        * ``Pos - Transition``: same result as ``Pos + Transition``.

        Any other operand type returns ``NotImplemented``.
        """
        if isinstance(other, Pos):
            return Pos(self.row - other.row, self.col - other.col)
        if isinstance(other, Transition):
            return self._apply_transition(other)
        return NotImplemented

    # NOTE(review): Python only calls this for a non-Pos left operand (a Pos on
    # the left is handled by __sub__). For ``Transition - Pos`` it applies the
    # transition exactly like __sub__; other types get NotImplemented.
    __rsub__ = __sub__

    def __eq__(self, other):
        # Non-Pos operands are not comparable: returning NotImplemented lets
        # Python fall back to its default (identity) comparison instead of
        # raising AttributeError on ``other.row``.
        if not isinstance(other, Pos):
            return NotImplemented
        return self.row == other.row and self.col == other.col

    def __ne__(self, other):
        return not self == other

    def __str__(self):
        return str(self.pos)

    def __hash__(self):
        # Must agree with __eq__: equal positions share the same (row, col).
        return hash((self.row, self.col))

class Transition:
    """A candidate move of the occupied site ``pos_from`` by the offset ``move``.

    ``energy`` is the energy the moving site would have at ``pos_to`` minus
    the energy recorded for it at ``pos_from`` in ``pos_energies``. Only the
    moving site's own terms enter this difference; the energy changes of its
    neighbours are not included.

    Args:
        pos_from: current ``Pos`` of the site (a deep copy is stored).
        move: ``Pos`` offset applied to ``pos_from`` (a deep copy is stored).
        pos_energies: dict keyed by occupied ``Pos``. Its keys decide which
            neighbours count as occupied, and ``pos_energies[pos_from]`` is the
            moving site's current energy.
    """

    def __init__(self, pos_from, move, pos_energies):
        self.pos_from = deepcopy(pos_from)
        self.pos_to = pos_from + move
        self.move = deepcopy(move)

        # Stepping back along the move lands on the site being vacated; the
        # moving particle itself must never count as an occupied neighbour.
        vacated_offset = Pos(0, 0) - self.move

        energy_at_target = 0
        for offset in (Pos(-1, 0), Pos(1, 0), Pos(0, -1), Pos(0, 1)):
            neighbour_occupied = (
                self.pos_to + offset in pos_energies and offset != vacated_offset
            )
            if neighbour_occupied:
                energy_at_target += SELF_INTERACT_ENERGY
            else:
                energy_at_target += SITE_INTERACT_ENERGY

        if GRAVITY is True:
            energy_at_target += DENSITY_DIFFERENCE * self.pos_to.row

        self.energy = energy_at_target - pos_energies[self.pos_from]

    def __str__(self):
        return '%s -> %s, E = %s' % (self.pos_from, self.pos_to, round(self.energy, 2))

class Lattice:
    """A ``nrows`` x ``ncols`` grid of sites, some of which are occupied.

    Occupied sites evolve one move at a time through ``perform_transition``,
    which picks a single-site move with weight ``exp(-E / T)``.

    Attributes:
        nrows, ncols: grid dimensions.
        positions: list of occupied ``Pos`` (a deep copy of the input).
        npoints: total number of sites, ``nrows * ncols``.
        N: number of occupied sites (fixed at construction).
        strlat: grid of ``'x'`` (occupied) / ``' '`` (empty), indexed [row][col].
        numlat: the same grid holding ``1`` / ``0``.
        pos_energies: dict mapping each occupied ``Pos`` to its site energy.
    """

    def __init__(self, nrows, ncols, positions):
        self.nrows = nrows
        self.ncols = ncols
        self.positions = deepcopy(positions)

        self.npoints = nrows * ncols
        self.N = len(self.positions)

        # set_image() creates strlat/numlat, so they need no separate init here.
        self.set_image()
        self.pos_energies = self.set_energies()

    @staticmethod
    def _neighbour_offsets():
        """Return the four nearest-neighbour offsets: up, down, left, right."""
        return [Pos(-1, 0), Pos(1, 0), Pos(0, -1), Pos(0, 1)]

    def _site_energy(self, p, occupied):
        """Return the energy of the occupied site ``p``.

        Each of the four neighbours contributes SELF_INTERACT_ENERGY when it
        is in ``occupied`` and SITE_INTERACT_ENERGY otherwise (neighbours
        outside the grid are never occupied). When GRAVITY is exactly True,
        ``DENSITY_DIFFERENCE * p.row`` is added.
        """
        energy = 0
        for offset in self._neighbour_offsets():
            if p + offset in occupied:
                energy += SELF_INTERACT_ENERGY
            else:
                energy += SITE_INTERACT_ENERGY
        if GRAVITY is True:
            energy += DENSITY_DIFFERENCE * p.row
        return energy

    def set_image(self):
        """Rebuild ``strlat`` and ``numlat`` from ``positions``."""
        self.strlat = [[' '] * self.ncols for _ in range(self.nrows)]
        self.numlat = [[0] * self.ncols for _ in range(self.nrows)]
        for p in self.positions:
            self.strlat[p.row][p.col] = 'x'
            self.numlat[p.row][p.col] = 1

    def set_energies(self):
        """Return a dict mapping each occupied Pos to its site energy."""
        # A set gives O(1) neighbour lookups instead of scanning the list.
        occupied = set(self.positions)
        pos_energies = {p: 0 for p in self.positions}
        for p in self.positions:
            pos_energies[p] += self._site_energy(p, occupied)
        return pos_energies

    def get_local_lattice(self):
        """Return a new Lattice just large enough to hold every occupied site.

        The new grid spans rows ``0..max row`` and columns ``0..max col``;
        with no occupied sites it is 0 x 0.
        """
        ncols = max((p.col for p in self.positions), default=-1) + 1
        nrows = max((p.row for p in self.positions), default=-1) + 1
        return Lattice(nrows, ncols, self.positions)

    def get_energy(self):
        """Return the sum of the site energies of all occupied sites.

        Each pair of occupied neighbours contributes SELF_INTERACT_ENERGY once
        from each side.
        """
        occupied = set(self.positions)
        return sum(self._site_energy(p, occupied) for p in self.positions)

    def get_entropy(self):
        """Return ln C(M, N), where M is the site count of the local lattice.

        Integer floor division is exact for a binomial coefficient and avoids
        the float OverflowError that ``/`` raises on very large factorials.
        """
        M = self.get_local_lattice().npoints
        n_ways = factorial(M) // (factorial(self.N) * factorial(M - self.N))
        return log(n_ways)

    def valid_move(self, move, occupied=None):
        """Return True if the position ``move`` is inside the grid and unoccupied.

        Args:
            move: candidate destination ``Pos`` (a target position, not an offset).
            occupied: optional precomputed collection of occupied positions;
                defaults to ``self.positions``.
        """
        if occupied is None:
            occupied = self.positions
        if move in occupied:
            return False
        if not 0 <= move.row < self.nrows:
            return False
        if not 0 <= move.col < self.ncols:
            return False
        return True

    def get_possible_transitions(self):
        """Return every single-step move of an occupied site onto a free in-grid site.

        Order: by position in ``positions``, then up, down, left, right.
        """
        occupied = set(self.positions)
        moves = self._neighbour_offsets()
        return [
            Transition(p, m, self.pos_energies)
            for p in self.positions
            for m in moves
            if self.valid_move(p + m, occupied)
        ]

    def perform_transition(self):
        """Randomly apply one possible transition, weighted by ``exp(-E / T)``.

        Updates ``positions``, ``pos_energies`` and the images.

        Returns:
            The applied Transition, or None if no move is possible.
        """
        transitions = self.get_possible_transitions()
        if not transitions:
            return None

        energies = np.array([t.energy for t in transitions], dtype=float)
        # Shifting by the minimum leaves the normalised probabilities unchanged
        # but stops exp() from underflowing to all zeros (NaN probabilities).
        factors = np.exp(-(energies - energies.min()) / T)
        cdf = np.cumsum(factors)
        cdf /= cdf[-1]

        # side='right' maps ran in [cdf[i-1], cdf[i]) to i: every ran in [0, 1)
        # yields a valid index and zero-probability moves are never chosen.
        ran = np.random.rand()
        trans_ind = min(int(np.searchsorted(cdf, ran, side='right')),
                        len(transitions) - 1)
        chosen = transitions[trans_ind]

        idx = self.positions.index(chosen.pos_from)
        self.positions[idx] = Pos(chosen.pos_to.row, chosen.pos_to.col)

        self.pos_energies = self.set_energies()
        self.set_image()
        return chosen

    def __str__(self):
        """Render the grid with column indices on top and row indices on the left."""
        lines = ['   ' + ''.join(' ' + str(c) + ' ' for c in range(self.ncols))]
        lines.append('   ' + ' - ' * self.ncols + '|')
        for r, row in enumerate(self.strlat):
            lines.append(str(r) + ' |' + ''.join(' ' + val + ' ' for val in row) + '|')
        lines.append('  |' + ' - ' * self.ncols + '|')
        return '\n'.join(lines)
