import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from c_curves import t_star_func
import constants as c
import scipy.special as ss
import warnings
import time

# Accepts x and t as scalars,
# or as arrays with the same shape,
# or one of them as an array and the other as a scalar.
def T_func(x, t):
    """Return the Jominy-bar temperature at position ``x`` and time ``t``.

    Evaluates ``T0_jominy + (Ti_jominy - T0_jominy) * erf(x / (2*sqrt(a_jominy*t)))``
    using the values from the ``constants`` module. ``x`` and ``t`` may be
    scalars, equally shaped arrays, or one array and one scalar; the result
    follows NumPy broadcasting.
    """
    T0 = c.T0_jominy
    span = c.Ti_jominy - T0
    eta = x / (2 * np.sqrt(c.a_jominy * t))
    return T0 + span * ss.erf(eta)

# Takes in a 1D array of positions and a 1D array of times.
# Utilizes np.tensordot to generate a 2D array of shape (len(x), len(t)) containing temperatures
# for all combinations of x and t
def T_func_vectorized(x, t, celcius = False):
    """Temperature for every combination of positions ``x`` and times ``t``.

    Parameters
    ----------
    x : array_like
        1D array of positions.
    t : array_like
        1D array of times (a list is also accepted).
    celcius : bool, optional
        If true, 273 is subtracted from the result. This presumably converts
        constants given in kelvin to degrees Celsius; note the offset used is
        273, not 273.15.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(x), len(t))``. Element ``[i, j]`` is the
        temperature at ``x[i]`` and ``t[j]``.
    """
    t = np.asarray(t)
    inv_length = 1 / (2 * np.sqrt(c.a_jominy * t))
    # tensordot with axes=0 is the outer product: arg[i, j] = x[i] * inv_length[j].
    arg = np.tensordot(x, inv_length, axes=0)
    temps = c.T0_jominy + (c.Ti_jominy - c.T0_jominy) * ss.erf(arg)
    if celcius:
        temps = temps - 273
    return temps

#Plots temperature as a function of time for different positions on a jominy stick
def plot_Tt():
    """Plot temperature versus time at several positions along a Jominy bar.

    Positions come from
    ``np.arange(c.Tt_x_min_jominy, c.Tt_x_max_jominy, c.Tt_x_step_jominy)``.
    Times come from
    ``np.logspace(c.t_min_jominy, c.t_max_jominy, c.t_points_jominy)``, so the
    t limits are log10 exponents. Temperatures are computed with
    ``T_func_vectorized(..., celcius=True)``.

    Every fourth curve, starting with the first, is dotted and labelled with
    its position. The others are unlabelled to keep the legend short. The plot
    is drawn into a new figure; ``plt.show()`` is not called.
    """
    x_list = np.arange(c.Tt_x_min_jominy, c.Tt_x_max_jominy, step=c.Tt_x_step_jominy)
    t_list = np.logspace(c.t_min_jominy, c.t_max_jominy, c.t_points_jominy)

    t0 = time.process_time()
    temp_curves = T_func_vectorized(x_list, t_list, celcius=True)
    t1 = time.process_time()

    # .size is the total number of evaluated points (np.product was removed in NumPy 2.0).
    print('Calculating T(x,t) in ', temp_curves.size, ' points (all displayed) took ', round(t1-t0, 4), 's', sep='')

    # Use only the lower 70 % of the reversed colormap so the lightest colours
    # are never used for the curves.
    max_color = 0.7
    n_curves = len(x_list)
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
    cmap = plt.get_cmap('YlOrRd_r')
    color_list = [cmap(max_color * (i / n_curves)) for i in range(n_curves)]

    plt.figure(figsize=[15, 7])
    for i, (temps_at_x, x) in enumerate(zip(temp_curves, x_list)):
        if i % 4 == 0:
            plt.plot(t_list, temps_at_x, color=color_list[i], label=str(round(x, 2))+'mm', linestyle=":")
        else:
            plt.plot(t_list, temps_at_x, color=color_list[i])

    plt.title('Temperature as a function of time at different position on a Jominy stick, distance between lines is 5mm.', fontsize = 14)
    plt.xlabel("Time [s]", fontsize = 14)
    plt.ylabel("Temperature [\u2103]", fontsize = 14)
    plt.xscale("log")
    plt.legend(loc = 'upper right', fontsize = 14)

###########################################
###########################################
###            Part B below             ###
###########################################
###########################################

# time for a given temperature can be found explicitly as long as T0 < T < T_i
def find_t(x,T):
    """Return the time at which position ``x`` reaches temperature ``T``.

    Inverts ``T = T0 + (Ti - T0) * erf(x / (2*sqrt(a*t)))``, where
    ``T0 = c.T0_jominy``, ``Ti = c.Ti_jominy`` and ``a = c.a_jominy``, giving::

        t = ( x / (2 * sqrt(a) * erfinv((T - T0) / (Ti - T0))) ) ** 2

    Parameters
    ----------
    x : scalar or array_like
        Position(s).
    T : scalar or array_like
        Target temperature(s). The inversion is only meaningful for
        ``T0 < T < Ti``.

    Returns
    -------
    float or numpy.ndarray
        The time(s). At ``T == T0`` the result is ``inf``.

    Warns
    -----
    UserWarning
        If any ``T`` lies outside ``[T0_jominy, Ti_jominy]``.
    """
    if np.any(np.asarray(T) < c.T0_jominy) or np.any(np.asarray(T) > c.Ti_jominy):
        warnings.warn('Du har prøvd å finne t, for en temperatur utenfor (Ti,T0) i Jominy')
    u = (T - c.T0_jominy) / (c.Ti_jominy - c.T0_jominy)
    # This is sqrt(t); the previous version returned it directly and forgot to square.
    sqrt_t = np.asarray(x) / (2 * np.sqrt(c.a_jominy)) * (1 / ss.erfinv(u))
    return sqrt_t ** 2

# Uses T_func_vectorized to calculate temperature for all combinations of x and t in x_list, t_list
# Evaluates the scheil integral for every x in x_list
# t_start[i] and t_end[i] are the start and end times of integration corresponding to the position x_list[i]
def scheil_integrator(x_list, t_list, t_start, t_end):
    """Evaluate the Scheil sum of ``1 / t_star(T(x, t))`` for every ``x`` in ``x_list``.

    Temperatures come from ``T_func_vectorized(x_list, t_list)``. Only times
    with ``t_start[i] <= t <= t_end[i]`` contribute for position ``x_list[i]``.

    The sum uses a left-rectangle rule: the step widths are the forward
    differences of ``t_list``, and the last point reuses the final width.
    ``t_list`` therefore needs at least two points.

    Returns
    -------
    numpy.ndarray
        One integral value per entry of ``x_list``.
    """
    dt = np.diff(t_list)
    dt = np.append(dt, dt[-1])

    temps = T_func_vectorized(x_list, t_list)

    # mask[i, j] is True when t_list[j] lies outside [t_start[i], t_end[i]].
    # The ufunc .outer calls give shape (len(t_list), *shape(t_start)), and the
    # transpose turns that into the same layout as the temperature array.
    # This replaces a Python-level loop over t_list.
    mask = (np.less.outer(t_list, t_start) | np.greater.outer(t_list, t_end)).transpose()
    temps = np.ma.masked_array(temps, mask)

    integrand = 1 / t_star_func(temps)
    # np.ma.filled works for both masked and plain arrays. np.where re-applies
    # the window in case t_star_func returns an unmasked array.
    integrand = np.where(mask, 0.0, np.ma.filled(integrand, 0))

    return np.dot(integrand, dt)

def plot_Xx():
    """Plot the degree of transformation X against position on the Jominy bar.

    Positions are ``np.logspace(c.Xx_x_min, c.Xx_x_max, c.Xx_x_points)``.
    For each position, the Scheil sum ``I`` is taken between the time the
    point cools to ``c.Teq - 15`` and the time it reaches ``c.T0_jominy + 1``.
    The sum is converted to ``X = 1 - (1 - c.Xc) ** (I ** c.n)``.

    The computation time is printed. The plot is drawn into a new figure;
    ``plt.show()`` is not called.
    """
    t0 = time.process_time()
    x_line = np.logspace(c.Xx_x_min, c.Xx_x_max, c.Xx_x_points)

    # Start time must be sufficiently far below equilibrium temperature, or else t^* will overflow.
    # Physical explanation: as T approaches T_eq, the transformation time approaches infinity.
    t_start = find_t(x_line, c.Teq - 15)
    t_end = find_t(x_line, c.T0_jominy + 1)

    # Span the whole union of integration windows. min/max does not rely on
    # x_line being increasing, unlike indexing the first/last entries.
    t_list = np.logspace(np.log10(np.min(t_start)), np.log10(np.max(t_end)), c.jominy_scheil_integrand_points)

    I = scheil_integrator(x_line, t_list, t_start, t_end)

    X = 1 - (1 - c.Xc) ** (I ** c.n)
    t1 = time.process_time()

    print('Calculating X(x) in ', len(X) ,' points took ', round(t1-t0, 3), 's', sep='')

    plt.figure(figsize=[15, 7])
    plt.plot(x_line, X)
    plt.grid(which='both')
    plt.xscale('log')
    plt.title('Degree of transformation completion as a function of position on Jominy stick', fontsize = 14)
    plt.xlabel('Distance from cold end [mm]', fontsize = 14)
    plt.ylabel('Fraction of transformation completion [–]', fontsize = 14)