import numpy as np
import matplotlib.pyplot as plt
import time

def find_pi_meanval(N, timer=False, plot=False):
    """Approximate pi with a Monte Carlo mean-value estimate.

    f(theta) = sin(theta) / 2 integrates to 1 over [0, pi], so its mean value
    on that interval is 1 / pi. Drawing N uniform samples of theta and
    averaging f therefore gives pi ~= N / sum(f(theta_i)).

    Parameters
    ----------
    N : int
        Number of random samples.
    timer : bool
        If true, print how long the estimate took (process time).
    plot : bool
        If true, also scatter a subsample of the (theta, f(theta)) points on
        the current matplotlib axes.

    Returns
    -------
    float
        The approximation of pi. It is returned whether or not *plot* is set.
    """
    t0 = time.process_time()
    theta_vals = np.pi * np.random.random(N)
    func_vals = 0.5 * np.sin(theta_vals)
    # ndarray.sum runs in C. The builtin sum() walks the array element by
    # element in Python and would dominate the reported time for large N.
    pi_approx = N / func_vals.sum()
    t1 = time.process_time()

    if timer:
        print('Mean value approximation of pi with', N, 'random values took', round(t1 - t0, 4), 'seconds')

    if plot:
        # Aim for about (N + 10000) / 110 plotted points. Clamp the stride to
        # at least 1, because for N < 92 the ratio truncates to 0 and a
        # zero stride is invalid.
        num_plotted_points = (N / 110) + 1000 / 11
        step = max(1, int(N / num_plotted_points))
        plt.plot(theta_vals[::step], func_vals[::step],
                 marker='.', linestyle='', markersize=1.5,
                 color='red')

    return pi_approx

def find_pi_norand(N, timer=False, plot=False):
    """Approximate pi from f(theta) = sin(theta) / 2 sampled on an even grid.

    np.linspace(0, pi, N) includes both endpoints, so the N samples span
    N - 1 intervals of width pi / (N - 1). f vanishes at both endpoints, so
    the trapezoid rule reduces to

        1 = integral_0^pi f ~= (pi / (N - 1)) * sum(f(theta_i)),

    which gives pi ~= (N - 1) / sum(f(theta_i)). Dividing N instead of N - 1
    would bias the result upward by a factor N / (N - 1).

    Parameters
    ----------
    N : int
        Number of grid points. N < 2 yields a non-finite result, because the
        only sample is f(0) = 0.
    timer : bool
        If true, print how long the estimate took (process time).
    plot : bool
        If true, also scatter a subsample of the (theta, f(theta)) points on
        the current matplotlib axes.

    Returns
    -------
    float
        The approximation of pi. It is returned whether or not *plot* is set.
    """
    t0 = time.process_time()
    theta_vals = np.linspace(0, np.pi, N)
    func_vals = 0.5 * np.sin(theta_vals)
    pi_approx = (N - 1) / func_vals.sum()
    t1 = time.process_time()

    if timer:
        print('Mean value approximation of pi with', N, 'evenly distributed values took', round(t1 - t0, 4), 'seconds')

    if plot:
        # Aim for about (N + 10000) / 110 plotted points. Clamp the stride to
        # at least 1, because for N < 92 the ratio truncates to 0.
        num_plotted_points = (N / 110) + 1000 / 11
        step = max(1, int(N / num_plotted_points))
        plt.plot(theta_vals[::step], func_vals[::step],
                 marker='.', linestyle='', markersize=1.5,
                 color='red')

    return pi_approx

def find_pi_randcheck(N, timer=False, plot=False):
    """Approximate pi by hit-or-miss Monte Carlo under f(theta) = sin(theta) / 2.

    Points (theta, D) are drawn uniformly from the box [0, pi) x [0, 0.5).
    The area under f is 1 and the box's area is pi / 2, so the fraction of
    points with D < f(theta) tends to 2 / pi. That gives pi ~= 2 * N / count.

    Parameters
    ----------
    N : int
        Number of random points.
    timer : bool
        If true, print how long the estimate took (process time).
    plot : bool
        If true, also scatter a subsample of the points on the current
        matplotlib axes. Points above the curve are red; the rest are green.

    Returns
    -------
    float
        The approximation of pi. It is returned whether or not *plot* is set.
        If no point lands under the curve, numpy division yields inf.
    """
    t0 = time.process_time()
    # D is drawn before theta, so seeded runs reproduce earlier results.
    D_vals = 0.5 * np.random.random(N)
    theta_vals = np.pi * np.random.random(N)
    func_vals = 0.5 * np.sin(theta_vals)
    # The element-wise comparison gives a boolean array. Summing it counts the
    # True entries in C, which matches a Python loop with a counter but is faster.
    count = (D_vals < func_vals).sum()
    pi_approx = 2 * N / count
    t1 = time.process_time()

    if timer:
        print('Random count approximation of pi with', N, 'random values took', round(t1 - t0, 4), 'seconds')

    if plot:
        # Aim for about (N + 10000) / 110 plotted points. Clamp the stride to
        # at least 1, because for N < 92 the ratio truncates to 0.
        num_plotted_points = (N / 110) + 1000 / 11
        step = max(1, int(N / num_plotted_points))
        theta_sub = theta_vals[::step]
        D_sub = D_vals[::step]
        above = D_sub > func_vals[::step]
        # Two vectorized calls instead of one plt.plot artist per point.
        plt.plot(theta_sub[above], D_sub[above], marker='.', linestyle='',
                 markersize=1, color='red')
        plt.plot(theta_sub[~above], D_sub[~above], marker='.', linestyle='',
                 markersize=1.5, color='green')

    return pi_approx

def plot_integrand(N, finder):
    """Draw f(theta) = sin(theta) / 2 on [0, pi] and overlay *finder*'s sample points.

    Parameters
    ----------
    N : int
        Sample count forwarded to *finder*.
    finder : callable
        A pi estimator called as finder(N, plot=True). It adds its own points
        to the current axes.
    """
    curve_x = np.linspace(0, np.pi, 100)
    curve_y = 0.5 * np.sin(curve_x)
    plt.plot(curve_x, curve_y, label=r'$f(\theta) = \frac{1}{2}sin(\theta)$')
    plt.xlabel(r'$\theta$ [rad]')
    plt.ylabel(r'$f(\theta)$')

    # The estimator scatters its sampled points onto the same axes.
    finder(N, plot=True)

    # The legend is built last so it reflects everything drawn above.
    plt.title(r'Points used to calculate $\pi$')
    plt.legend()

def plot_convergence(start, stop, finder):
    """Plot how *finder*'s estimate of pi changes with the number of samples.

    About 100 sample counts are taken from np.arange(start, stop, step). Each
    is passed to *finder*. Only the last count in the range is run with
    timer=True, so a single timing line is printed. A dashed horizontal line
    marks the true value of pi.

    Parameters
    ----------
    start, stop : int
        Range of sample counts, interpreted as by np.arange (stop excluded).
    finder : callable
        A pi estimator called as finder(N) or finder(N, timer=True).
    """
    step = int((stop - start) / 100)
    if step == 0:
        # A span shorter than 100 truncates the step to 0, which np.arange
        # rejects. Fall back to unit spacing in the direction of the range.
        step = 1 if stop >= start else -1
    N_vals = np.arange(start, stop, step=step)

    # Only run the timer on the final sample count.
    pi_vals = [finder(N) for N in N_vals[:-1]]
    pi_vals.append(finder(N_vals[-1], timer=True))

    plt.plot(N_vals, pi_vals, label=r'Approximated value of $\pi$')
    plt.plot([N_vals[0], N_vals[-1]], [np.pi, np.pi], linestyle='--', color='orange', label=r'$\pi$')

    plt.xlabel('Number of values')
    plt.ylabel(r'Calculated value of $\pi$')

    plt.legend()

# Human-readable method names mapped to the pi estimators defined above. Each
# estimator accepts (N, timer=False, plot=False).
needle_finder_dict = {
    'Mean value': find_pi_meanval,
    'Random counter': find_pi_randcheck,
    'Non-random mean value': find_pi_norand,
}