from sympy import diff, latex, symbols, plot
import numpy as np
from matplotlib.cm import get_cmap

# Colour map used to tint each plotted curve according to its parameter value.
cmap = get_cmap(name='plasma')

# Radial distance: the single independent variable of every potential below.
(r,) = symbols('r,')

sig = 5  # Fill in a value for sigma

# Well depths to compare; each curve is coloured by eps / max(eps_list).
eps_list = np.linspace(2, 10, 5)
eps_max = max(eps_list)

# Sample away from r = 0, where (sig/r)**12 diverges and evaluating the
# expression only produces inf/nan (and numpy divide-by-zero warnings).
# At r = 0.5*sig the potential is already 4*eps*(2**12 - 2**6) ~ 1.6e4*eps,
# far above the plotted y-range, so no visible part of any curve is lost.
r_start = 0.5 * sig
r_stop = 3 * sig

# The first Plot carries all axis settings (labels, limits, legend).
eps = eps_list[0]
V = 4 * eps * ((sig / r) ** 12 - (sig / r) ** 6)
p = plot(V, (r, r_start, r_stop), xlabel='Radial distance', ylabel='Potential',
         axis_center=(0, 0), ylim=(-eps_max, 4 * eps_max), xlim=(-0.5, r_stop),
         show=False, line_color=cmap(eps / eps_max),
         label=rf'$\epsilon = {eps:g}$', legend=True)

# p.extend() merges only the data series of the additional plots, so axis
# options passed to them would be ignored; only per-curve styling is given.
for eps in eps_list[1:]:
    V = 4 * eps * ((sig / r) ** 12 - (sig / r) ** 6)
    p.extend(plot(V, (r, r_start, r_stop), show=False,
                  line_color=cmap(eps / eps_max),
                  label=rf'$\epsilon = {eps:g}$'))

p.show()

eps = 3  # Fixed well depth for the sigma sweep

# Length scales to compare; each curve is coloured by sig / max(sig_list).
sig_list = np.linspace(2, 10, 5)
sig_max = max(sig_list)

# All curves share one radial window, wide enough for the largest sigma, so
# every curve (including the first) spans the whole x-axis.
r_end = 3 * sig_max

# Each curve starts at 0.5*sig rather than r = 0 to avoid the (sig/r)**12
# singularity; there V is ~1.6e4*eps, well above the visible y-range.
# The first Plot carries all axis settings. The y-range is tied to this
# sweep's fixed eps (well bottom at -eps), not to the epsilon sweep's values.
sig = sig_list[0]
V = 4 * eps * ((sig / r) ** 12 - (sig / r) ** 6)
p = plot(V, (r, 0.5 * sig, r_end), xlabel='Radial distance', ylabel='Potential',
         axis_center=(0, 0), ylim=(-eps, 4 * eps), xlim=(-0.5, r_end),
         show=False, line_color=cmap(sig / sig_max),
         label=rf'$\sigma = {sig:g}$', legend=True)

# p.extend() merges only the data series of the additional plots, so axis
# options passed to them would be ignored; only per-curve styling is given.
for sig in sig_list[1:]:
    V = 4 * eps * ((sig / r) ** 12 - (sig / r) ** 6)
    p.extend(plot(V, (r, 0.5 * sig, r_end), show=False,
                  line_color=cmap(sig / sig_max),
                  label=rf'$\sigma = {sig:g}$'))

p.show()