from math import factorial as f
from scipy.special import binom as bin
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np

def part(m, n):
    """Return C(m, n)**2 / C(2m, 2n), evaluated through factorials.

    The quotient is written out as
    ``(m!)**2 * (2n)! * (2m - 2n)! / ((2m)! * (n! * (m - n)!)**2)``.
    Numerator and denominator are built as exact Python integers and only
    the final division produces a float.

    Raises:
        ValueError: from ``math.factorial`` when an argument is negative,
            e.g. when ``n > m``.
    """
    m_fact = f(m)
    numerator = m_fact**2 * f(2*n) * f(2*m - 2*n)
    per_choice = f(n) * f(m - n)
    denominator = f(2*m) * per_choice**2
    return numerator / denominator

def part2(m, n):
    """Return log(C(m, n)**2) / log(C(2m, 2n)), or NaN when n > m.

    The binomial coefficients come from ``scipy.special.binom`` (imported
    here as ``bin``), so the computation is done in floating point.

    NOTE(review): when both binomials equal 1 (for instance n == m) both
    logarithms are zero and the quotient is 0/0, which evaluates to NaN.
    """
    if m < n:
        return np.nan

    log_squared = np.log(bin(m, n)**2)
    log_doubled = np.log(bin(2*m, 2*n))
    return log_squared / log_doubled

def part3(m, n):
    """Return C(m, n)**2 / C(2m, 2n), or NaN when n > m.

    The binomial coefficients come from ``scipy.special.binom`` (imported
    here as ``bin``), so the result is a floating-point value.
    """
    if m < n:
        return np.nan

    squared_choice = bin(m, n)**2
    doubled_choice = bin(2*m, 2*n)
    return squared_choice / doubled_choice

# Sample m and n on the same coarse grid: 10, 20, ..., 190.
m_list = np.arange(10, 200, 10)
n_list = np.arange(10, 200, 10)

# Evaluate both surfaces in row-major order (rows follow m, columns follow n)
# and fold the flat results back into a 2-D grid. Entries with n > m are NaN.
z = np.array(
    [part2(m, n) for m in m_list for n in n_list]
).reshape(m_list.size, n_list.size)
z2 = np.array(
    [part3(m, n) for m in m_list for n in n_list]
).reshape(m_list.size, n_list.size)

# Expand the 1-D axes into coordinate grids whose layout matches z and z2:
# m_list[i, j] is the i-th m value, n_list[i, j] is the j-th n value.
m_list, n_list = np.meshgrid(m_list, n_list, indexing='ij')

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')

# Log-ratio surface in the default colour, plain ratio surface in red.
ax.plot_wireframe(m_list, n_list, z)
ax.plot_wireframe(m_list, n_list, z2, color='red')
ax.set_xlabel('m')
ax.set_ylabel('n')
plt.show()