import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap

# Colormap used to colour one curve pair per ratio x.  Looked up through
# pyplot because matplotlib.cm.get_cmap is deprecated (Matplotlib 3.7) and
# removed in 3.9.
cmap = plt.get_cmap('viridis')

# Model parameters (physical meaning presumed from the names -- confirm).
a = 2    # length scale: the dispersion depends on k only through k * a
M = 1    # divides both terms of omega^2 (presumably the mass)
C = 2    # reference constant; the second constant is C2 = x * C
C1 = C   # first constant, held fixed while the ratio x = C2/C1 is varied

# Only referenced by the commented-out vlines call near the end of the script.
k_invar = (1/a) * np.arccos((2 + np.sqrt(5))/8)

# Relative size below which a negative discriminant is treated as round-off.
_EDGE_RTOL = 1e-9


def _omega_sq_parts(x, k):
    """Return ``(mean, half_split)`` with ``omega**2 = mean +/- half_split``.

    ``x`` is the ratio C2/C1 and ``k`` the wave number; both may be scalars
    or NumPy arrays that broadcast against each other.
    """
    C2 = x * C
    total = C1 + C2
    disc = total**2 - 8 * C1 * C2 * (1 - np.cos(k * a))
    # At the band edge disc is analytically zero, but round-off can leave it
    # at about -1e-13, which would turn the square root into NaN.  Snap such
    # tiny negatives to zero; clearly negative values (k outside the allowed
    # band) are left alone and still produce NaN.
    disc = np.where((disc < 0) & (disc > -_EDGE_RTOL * total**2), 0.0, disc)
    return total / (2 * M), np.sqrt(disc) / (2 * M)


def w1_fun(x, k):
    """Upper branch omega_1(k) for the ratio ``x = C2/C1``."""
    mean, half_split = _omega_sq_parts(x, k)
    return np.sqrt(mean + half_split)


def w2_fun(x, k):
    """Lower branch omega_2(k) for the ratio ``x = C2/C1``."""
    mean, half_split = _omega_sq_parts(x, k)
    return np.sqrt(mean - half_split)

# Ratios x = C2/C1 to plot: five values from 1 up to 7 + 4*sqrt(3) and their
# reciprocals.  At x = 7 +/- 4*sqrt(3) the arccos argument below equals -1,
# i.e. k_lim reaches pi/a.
x1_list = np.linspace(1, 7 + 4 * np.sqrt(3), 5)
x2_list = np.sort(1/x1_list)
# x2_list ends with 1/1 == 1, which is also x1_list[0]; drop it so that x = 1
# is not drawn (and listed in the legend) twice.
x_list = np.concatenate((x2_list[:-1], x1_list))

plt.figure(figsize=(10,5))

w_top = 0.0  # largest finite omega_1 plotted, used as height of the vlines
for i, x in enumerate(x_list):
    # Module-level C2 is kept in step with x for any code that reads it as a
    # global.
    C2 = x * C

    # Clip the arccos argument: at the extreme ratios it is analytically -1
    # and round-off may push it just outside [-1, 1], which would make k_lim
    # (and the whole curve) NaN.
    cos_lim = np.clip(1 - (C1 + C2)**2 / (8 * C1 * C2), -1.0, 1.0)
    k_lim = (1/a) * np.arccos(cos_lim)
    print(k_lim)
    k_line = np.linspace(k_lim, - k_lim, 500)

    w1 = w1_fun(x, k_line)
    w2 = w2_fun(x, k_line)

    color = cmap(i/len(x_list))
    plt.plot(k_line, w1, label=str(round(x,2)), color=color)
    plt.plot(k_line, w2, color=color, linestyle='--')

    # nanmax ignores NaN end points; max(w_top, nan) keeps w_top.
    w_top = max(w_top, np.nanmax(w1))

plt.xlabel('k')
plt.ylabel(r'$\omega$')

plt.legend(title = r'x = $C_2$/$C_1$')
plt.vlines([-np.pi/a, np.pi/a], 0, w_top, color='black')

# Curve traced by the band edge (k_lim, omega) as the ratio x = C2/C1 sweeps
# the interval [7 - 4*sqrt(3), 7 + 4*sqrt(3)].
x_list = np.linspace(7 - 4 * np.sqrt(3), 7 + 4 * np.sqrt(3), 500)

C2 = x_list * C

# The arccos argument is analytically -1 at both ends of the interval; clip
# it so round-off cannot turn the end points into NaN.
k_lim = (1/a) * np.arccos(np.clip(1 - (C1 + C2)**2 / (8 * C1 * C2), -1.0, 1.0))
# k_lim is defined as the wave number where the inner square root of the
# dispersion relation vanishes, so both branches equal sqrt((C1 + C2)/(2 M))
# there.  Using the closed form avoids taking the square root of a quantity
# that is zero up to round-off (which gives scattered NaNs and a broken line).
w1 = np.sqrt((C1 + C2) / (2 * M))

plt.plot(k_lim, w1, label=r'w1 = w2', color = 'red')
# A legend only lists artists that exist when it is created, so rebuild it
# here to include the entry for the red curve.
plt.legend(title = r'x = $C_2$/$C_1$')


#plt.vlines([k_invar, - k_invar], 0, w2_fun(x, k_invar), color='black')
plt.savefig('oving5_3')
plt.show()