import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

def weight(x, y, x_data, y_data, i):
    """Return the normalised leave-one-out weight of data point ``i`` at (x, y).

    With ``d_k`` the squared distance from (x, y) to data point ``k``, the
    weight is ``prod_{k != i} d_k / sum_n prod_{k != n} d_k``.  Summed over
    all ``i`` the weights equal 1.  When (x, y) lies exactly on a single data
    point ``j``, that point's weight is 1 and every other weight is 0.  If
    (x, y) coincides with two or more data points, every product contains a
    zero, the denominator is 0, and the result is nan.

    Parameters
    ----------
    x, y : float
        Evaluation point.  Every call in this file passes scalars.
    x_data, y_data : numpy.ndarray
        Coordinates of the data points (same length).
    i : int
        Index of the data point whose weight is returned.
    """
    sq_dist = (x - x_data)**2 + (y - y_data)**2

    def leave_one_out(k):
        # Product of every squared distance except the k-th one.
        return np.prod(sq_dist[:k]) * np.prod(sq_dist[k+1:])

    normaliser = np.sum([leave_one_out(k) for k in range(len(x_data))])
    return leave_one_out(i) / normaliser

def plot_weights(*, basis_index=None):
    """Plot the weight-based interpolant of a fixed 2-D sample set.

    By default the surface ``sum_n weight(x, y, x_data, y_data, n) * z_data[n]``
    is evaluated on a 50x50 grid over [0, 10] x [0, 10].  It is drawn as a
    black wireframe, with the sample points overlaid in red.

    Parameters
    ----------
    basis_index : int or None, optional (keyword-only)
        ``None`` (default) plots the interpolated surface described above.
        An integer ``n`` plots the single weight function ``w_n`` instead.
        The weights at data point ``n`` are marked in green, the other data
        points in red, and the z-axis is limited to [0, 1.1].
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    x_data = np.array([1, 2, 4, 6, 9])
    y_data = np.array([0, 6, 8, 1, 4])
    z_data = np.array([1, 3, 4, 5, 2])

    x_line = np.linspace(0, 10, 50)
    y_line = np.linspace(0, 10, 50)

    # meshgrid's default 'xy' indexing gives arrays of shape (len(y_line), len(x_line)).
    # This matches the row-per-y layout of the nested comprehensions below.
    X, Y = np.meshgrid(x_line, y_line)

    if basis_index is None:
        # Weighted sum of the sample values; the weights at each grid point sum to 1.
        interpol = np.sum(
            [[[weight(x, y, x_data, y_data, k) * z_data[k] for x in x_line]
              for y in y_line]
             for k in range(len(z_data))],
            axis=0,
        )
        ax.plot_wireframe(X, Y, interpol, color='black', alpha=0.5)
        ax.scatter(x_data, y_data, z_data, color='red', s=60)
    else:
        n = basis_index
        surface = np.array([[weight(x, y, x_data, y_data, n) for x in x_line]
                            for y in y_line])
        weights_at_data = np.array([weight(x, y, x_data, y_data, n)
                                    for x, y in zip(x_data, y_data)])
        # Mask selecting every data point except the n-th.
        others = np.ones(len(x_data), dtype=bool)
        others[n] = False

        ax.plot_wireframe(X, Y, surface, rcount=20, ccount=20, color='black',
                          alpha=0.5, label=r'$w_n(c,T)$')
        ax.scatter(x_data[others], y_data[others], weights_at_data[others],
                   color='red', s=60, label=r'$\tilde{P}_{k \neq n}$')
        ax.scatter(x_data[n], y_data[n], weights_at_data[n], s=60,
                   color='green', label=r'$\tilde{P}_n$')
        ax.set_zlim(0, 1.1)
        plt.legend()

    plt.show()


def lagrange(x, x_data, i):
    """Evaluate the i-th 1-D Lagrange basis polynomial at ``x``.

    ``L_i(x) = prod_{k != i} (x - x_k) / prod_{k != i} (x_i - x_k)``.
    ``L_i(x_data[j])`` is 1 for ``j == i`` and 0 for any other node.
    ``x_data[i]`` must differ from every other node; otherwise the
    denominator is zero.
    """
    node = x_data[i]
    before, after = x_data[:i], x_data[i+1:]
    numerator = np.prod(x - before) * np.prod(x - after)
    denominator = np.prod(node - before) * np.prod(node - after)
    return numerator / denominator

def plot_1d_lagrange():
    """Plot every 1-D Lagrange basis polynomial for a fixed node set.

    Each basis polynomial ``L_i`` is sampled at 50 points over [0, 11], and
    the nodes are marked on the x-axis.  The pointwise sum of all basis
    polynomials is drawn as a dashed red line.  Mathematically that sum is
    identically 1; in floating point it is 1 up to rounding.
    """
    x_data = np.array([0, 1, 2, 5, 6, 10, 11])
    x_line = np.linspace(0, 11, 50)

    # Row i holds L_i sampled on x_line.
    basis = np.array([[lagrange(x, x_data, i) for x in x_line]
                      for i in range(len(x_data))])

    # Draw on the Axes object throughout rather than mixing it with pyplot's implicit state.
    _, ax = plt.subplots()
    ax.scatter(x_data, np.zeros_like(x_data))
    for curve in basis:
        ax.plot(x_line, curve)

    ax.plot(x_line, basis.sum(axis=0), linestyle='--', color='red')
    plt.show()

def lagrange_2d(x, y, x_data, y_data, i):
    """Return the tensor-product Lagrange weight of data point ``i`` at (x, y).

    The weight is the 1-D Lagrange basis on the x coordinates times the one
    on the y coordinates: ``lagrange(x, x_data, i) * lagrange(y, y_data, i)``.
    At data point ``j`` it is 1 for ``j == i`` and 0 otherwise.  This requires
    ``x_data[i]`` to differ from the other x nodes and ``y_data[i]`` to differ
    from the other y nodes; otherwise a denominator is zero.  Unlike a true
    2-D basis, these weights do not in general sum to 1 away from the data
    points.
    """
    # Reuse the shared 1-D basis instead of duplicating its formula per axis.
    return lagrange(x, x_data, i) * lagrange(y, y_data, i)

def plot_lagrange_2d():
    """Plot the tensor-product Lagrange interpolant of a fixed 2-D sample set.

    The surface ``sum_i lagrange_2d(x, y, x_data, y_data, i) * z_data[i]`` is
    evaluated on a 50x50 grid over [0, 10] x [0, 10].  It is drawn as a black
    wireframe, with the sample points overlaid in red.
    """
    xs = np.array([1, 2, 4, 6, 9])
    ys = np.array([0, 6, 8, 1, 4])
    zs = np.array([1, 3, 4, 5, 2])

    grid_x = np.linspace(0, 10, 50)
    grid_y = np.linspace(0, 10, 50)
    X, Y = np.meshgrid(grid_x, grid_y)

    # One layer per data point (rows follow grid_y, columns follow grid_x):
    # that point's weight times its sample value.
    layers = []
    for idx in range(len(xs)):
        layers.append([[lagrange_2d(gx, gy, xs, ys, idx) * zs[idx]
                        for gx in grid_x]
                       for gy in grid_y])
    surface = np.sum(np.array(layers), axis=0)

    ax = plt.figure().add_subplot(111, projection='3d')
    ax.plot_wireframe(X, Y, surface, color='black', alpha=0.5)
    ax.scatter(xs, ys, zs, color='red', s=60)

    plt.show()

if __name__ == "__main__":
    # Draw the demo figures only when run as a script; each plt.show() blocks,
    # so running them at import time would stall any importer.
    plot_weights()
    plot_lagrange_2d()