import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap

# Colormap that gives each x-ratio curve in the plot its own colour.
# NOTE(review): matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and
# removed in 3.9; newer versions need matplotlib.colormaps['viridis'] (and the
# get_cmap import at the top of the file has to go as well).
cmap = get_cmap('viridis')

# Model parameters.
a = 2   # spacing: k enters only as cos(k*a)
M = 1   # divides every omega**2 term (presumably the mass)
C = 2   # reference coupling; each plotted case uses C1 = C/x and C2 = x*C
C1 = C  # value of C/x at x = 1

# Positive wavenumber at which cos(k*a) = (2 + sqrt(5))/8.
k_invar = (1/a) * np.arccos((2 + np.sqrt(5))/8)


def _couplings(x):
    """Return (C1, C2) = (C/x, x*C) for the coupling ratio x = C2/C1."""
    return C/x, x * C


def _discriminant(c1, c2, k):
    """Return the inner square-root argument (c1 + c2)**2 - 8*c1*c2*(1 - cos(k*a)).

    It turns negative once 1 - cos(k*a) exceeds (c1 + c2)**2 / (8*c1*c2).
    """
    return (c1 + c2)**2 - 8 * c1 * c2 * (1 - np.cos(k*a))


def w1_fun(x, k):
    """Upper branch sqrt(((C1 + C2) + sqrt(D)) / (2*M)) with C1 = C/x, C2 = x*C.

    D is ``_discriminant(C1, C2, k)``. The couplings are derived from ``x``
    (previously ``x`` was ignored and module-level C1/C2 were read instead).
    ``k`` may be a scalar or a numpy array; wherever D < 0 the result is NaN
    (numpy emits a RuntimeWarning).

    NOTE(review): a harmonic chain should give a real omega for every k. For
    equal masses with alternating springs the usual textbook result is
    M*omega**2 = (C1 + C2) ± sqrt(C1**2 + C2**2 + 2*C1*C2*cos(k*a)).
    Confirm this expression against the exercise before trusting the curves.
    """
    c1, c2 = _couplings(x)
    return np.sqrt((c1 + c2)/(2 * M) + np.sqrt(_discriminant(c1, c2, k))/(2 * M))


def w2_fun(x, k):
    """Lower branch: same expression as ``w1_fun`` with the inner root subtracted."""
    c1, c2 = _couplings(x)
    return np.sqrt((c1 + c2)/(2 * M) - np.sqrt(_discriminant(c1, c2, k))/(2 * M))


# Coupling ratios x = C2/C1: five values from 1 up to sqrt(7 + 4*sqrt(3)) = 2 + sqrt(3),
# their reciprocals in ascending order, and the concatenation of both
# (so x = 1 appears twice, once at the end of each half).
x1_list = np.linspace(1, np.sqrt(7 + 4 * np.sqrt(3)), 5)
x2_list = np.sort(1/x1_list)
x_list = np.concatenate((x2_list, x1_list))
# Plot both branches for every coupling ratio x in x_list, then mark k = ±pi/a.
plt.figure(figsize=(10,5))
half = len(x_list) // 2  # x_list[:half] holds the ratios <= 1, x_list[half:] those >= 1
w_top = 0.0              # largest upper-branch value plotted; height of the ±pi/a lines

for i, x in enumerate(x_list):
    # Couplings for this case (C1*C2 == C**2 for every x). Assigned at module
    # level so anything reading C1/C2 from module scope sees the current case.
    C2 = x * C
    C1 = C/x

    # |k| at which the inner square root of the branch formula reaches zero.
    # Analytically the arccos argument is 1 - (x + 1/x)**2 / 8, which is
    # exactly -1 at the extreme ratios 2 ± sqrt(3). Clipping stops rounding
    # from pushing it below -1 (arccos -> NaN) on either side of x = 1.
    cos_arg = np.clip(1 - (C1 + C2)**2 / (8 * C1 * C2), -1.0, 1.0)
    k_lim = (1/a) * np.arccos(cos_arg)
    k_line = np.linspace(k_lim, - k_lim, 500)

    w1 = w1_fun(x, k_line)
    w2 = w2_fun(x, k_line)
    # The end points sit where the inner root is ~0 and may round to NaN;
    # nanmax skips them (builtin max returns NaN when w1[0] is NaN).
    w_top = max(w_top, np.nanmax(w1))

    color = cmap(i/len(x_list))
    # Ratios <= 1: faded solid lines; ratios >= 1: dotted lines.
    style = {'alpha': 0.3} if i < half else {'linestyle': ':'}
    plt.plot(k_line, w1, label=str(round(x,2)), color=color, **style)
    plt.plot(k_line, w2, color=color, **style)

plt.xlabel('k')
plt.ylabel(r'$\omega$')
plt.legend(title = r'x = $C_2$/$C_1$')
plt.vlines([-np.pi/a, np.pi/a], 0, w_top, color='black')
plt.title('C2 = xC, C1 = C/x')
# Optional (disabled) markers at ±k_invar.
#plt.vlines([k_invar, - k_invar], 0, w2_fun(x, k_invar), color='black')
plt.savefig('oving5_2')
plt.show()