import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import constants as c
from c_curves import t_star_func

# Install a catch-all 'ignore' filter so that no Python warning is shown once
# this module has been imported. This is a process-wide side effect.
# NOTE(review): presumably meant to hide numerical RuntimeWarnings raised by
# the calculations further down - confirm before removing.
import warnings
warnings.filterwarnings('ignore')

def T_weld_func(t, y, qvd):
    """Return the temperature T(t, y, qvd) at time t and distance y from the weld.

    Evaluates

        T = c.T0_weld + qvd / (c.rhoc * sqrt(4*pi*c.a*t)) * exp(-y**2 / (4*c.a*t))

    using NumPy functions, so t, y and qvd may be scalars or arrays that
    broadcast against each other.

    The original note on this function states the result is in Kelvin; the
    plotting functions in this file label the same quantity in degrees
    Celsius, so the actual unit follows that of c.T0_weld.
    """
    # Amplitude of the pulse; it decays as 1/sqrt(t).
    amplitude = qvd / (c.rhoc * np.sqrt(4*np.pi * c.a * t))
    # Gaussian fall-off with distance; the width grows with time.
    falloff = np.exp(- y**2 / (4*c.a*t))
    return c.T0_weld + amplitude * falloff

def get_Tt_weld_curves():
    """Precompute temperature curves for every distance, time and qvd value.

    The distances are split into seven groups whose spacing grows with the
    distance from the weld; the groups have different lengths, so the result
    is a list with one 3-D array per group rather than a single array.

    Returns:
        temps:  list of arrays from T_weld_vectorized, one per distance group,
                each indexed as [time, distance within group, qvd].
        y_vals: list of 1-D arrays holding the distances of each group.
        t_line: 1-D array of 100 log-spaced times from 1e-3 to 1e2.

    The three return values are meant to be passed on to plot_Tt_weld_V3.
    """
    # (start, stop, step) for each distance group; stop is exclusive.
    group_specs = [
        (0, 2, 0.2),
        (2, 5, 0.5),
        (5, 10, 1),
        (10, 20, 2),
        (20, 50, 5),
        (50, 100, 10),
    ]
    y_vals = [np.arange(first, last, step=spacing) for first, last, spacing in group_specs]
    # The outermost distance forms a group of its own.
    y_vals.append(np.array([100]))

    t_line = np.logspace(-3, 2, 100)
    # Slider grid scaled by 1000 (the slider unit is converted before use).
    qvd_array = np.arange(c.qvd_slider_min, c.qvd_slider_max, step=c.qvd_slider_step) * 1000

    temps = []
    for y_group in y_vals:
        temps.append(T_weld_vectorized(t_line, y_group, qvd_array))
    return temps, y_vals, t_line

def T_weld_vectorized(t, y, qvd):
    """Evaluate the weld temperature for all combinations of t, y and qvd.

    Takes t, y and qvd as vectors and returns a 3-D array of shape
    (len(t), len(y), len(qvd)); element [i, j, k] is the temperature at time
    t[i], distance y[j] and heat input qvd[k], i.e. the same expression as
    T_weld_func evaluated on the full grid.
    """
    n_y = len(y)
    n_qvd = len(qvd)

    # Time-dependent factors, one value per entry of t.
    inv_width = 1/(4*c.a*t)
    inv_norm = 1/(c.rhoc * np.sqrt(4*np.pi * c.a * t))

    # Grids of shape (len(y), len(qvd)): -y**2 repeated along the qvd axis
    # and qvd repeated along the y axis.
    neg_y_squared = np.multiply.outer(-y**2, np.ones(n_qvd))
    qvd_grid = np.multiply.outer(np.ones(n_y), qvd)

    # Prepend the time axis while applying the time dependency, giving
    # shape (len(t), len(y), len(qvd)).
    exponent = np.multiply.outer(inv_width, neg_y_squared)
    prefactor = np.multiply.outer(inv_norm, qvd_grid)

    return c.T0_weld + prefactor * np.exp(exponent)

def plot_Tt_weld_V1(qvd, ylog = True):
    """Plot temperature versus time for a range of distances from the weld.

    Uses T_weld_func for every curve. The distances are organised in groups;
    each group gets one colour, and the first curve of a group is dotted and
    carries the legend label.

    Args:
        qvd:  heat input in the slider unit; it is multiplied by 1000 before
              being passed to T_weld_func.
        ylog: if truthy, the temperature axis is logarithmic, times span
              1e-3..1e2 and distances go up to 100; otherwise the temperature
              axis is linear, times span 1e-3..1e0 and only distances below
              10 are shown.

    Draws on a new matplotlib figure; nothing is returned.
    """
    qvd = qvd*1000  # fix the units

    near_groups = [
        np.arange(0, 2, step=0.2),
        np.arange(2, 5, step=0.5),
        np.arange(5, 10, step=1),
    ]
    if ylog:
        t_line = np.logspace(-3, 2, 100)
        y_vals = near_groups + [
            np.arange(10, 20, step=2),
            np.arange(20, 50, step=5),
            np.arange(50, 100, step=10),
            [100],
        ]
    else:
        t_line = np.logspace(-3, 0, 100)
        y_vals = near_groups

    # One colour per group, taken from the first 70 % of the colormap.
    # plt.get_cmap is used because matplotlib.cm.get_cmap no longer exists
    # in Matplotlib >= 3.9.
    max_color = 0.7
    cmap = plt.get_cmap('YlOrRd_r')
    color_list = [cmap(max_color * (i/len(y_vals))) for i in range(len(y_vals))]

    plt.figure(figsize=[15,7])
    for vals, color in zip(y_vals, color_list):
        for position, y in enumerate(vals):
            curve = T_weld_func(t_line, y, qvd)
            if position == 0:
                # First curve of the group: dotted and labelled.
                plt.plot(t_line, curve, color = color,
                         label = 'y = '+str(round(y,3))+'mm', linestyle = ':')
            else:
                plt.plot(t_line, curve, color = color)
    plt.legend(loc = 'upper right', fontsize = 16)
    plt.xscale('log')
    if ylog:
        plt.yscale('log')
    plt.xlabel('Time [s]', fontsize = 16)
    plt.ylabel('Temperature [\u2103]', fontsize = 16)
    plt.title('Temperature as a function of time at different distances from weld.\n'
              'Distance increases linearly from one dotted line to the next by '
              r'$\Delta y = \frac{h}{10}$,'
              '\nwhere $h$ is the distance at the furthest of the dotted lines')

def plot_Tt_weld_V2(qvd, ylog = True):
    """Plot temperature versus time for a range of distances from the weld.

    Same plot as plot_Tt_weld_V1, but each distance group is evaluated in one
    call to T_weld_vectorized over the whole slider grid of qvd values, and
    the column belonging to the requested qvd is then picked out.

    Args:
        qvd:  heat input in the slider unit; expected to lie on the grid
              defined by c.qvd_slider_min / c.qvd_slider_max /
              c.qvd_slider_step.
        ylog: if truthy, the temperature axis is logarithmic, times span
              1e-3..1e2 and distances go up to 100; otherwise the temperature
              axis is linear, times span 1e-3..1e0 and only distances below
              10 are shown.

    Draws on a new matplotlib figure; nothing is returned.
    """
    # Round to the nearest grid index: plain int() truncates, and float
    # quotients such as 0.6/0.1 == 5.999... would select the wrong column.
    qvd_index = int(round((qvd - c.qvd_slider_min)/c.qvd_slider_step))
    qvd_array = np.arange(c.qvd_slider_min,c.qvd_slider_max,step=c.qvd_slider_step) * 1000

    near_groups = [
        np.arange(0, 2, step=0.2),
        np.arange(2, 5, step=0.5),
        np.arange(5, 10, step=1),
    ]
    if ylog:
        t_line = np.logspace(-3, 2, 100)
        y_vals = near_groups + [
            np.arange(10, 20, step=2),
            np.arange(20, 50, step=5),
            np.arange(50, 100, step=10),
            np.array([100]),
        ]
    else:
        t_line = np.logspace(-3, 0, 100)
        y_vals = near_groups

    # One colour per group, taken from the first 70 % of the colormap.
    # plt.get_cmap is used because matplotlib.cm.get_cmap no longer exists
    # in Matplotlib >= 3.9.
    max_color = 0.7
    cmap = plt.get_cmap('YlOrRd_r')
    color_list = [cmap(max_color * (i/len(y_vals))) for i in range(len(y_vals))]

    plt.figure(figsize=[15,7])
    for y, color in zip(y_vals, color_list):
        # temps is indexed as [time, distance within group, qvd]
        temps = T_weld_vectorized(t_line, y, qvd_array)

        # First curve of the group: dotted and labelled.
        plt.plot(t_line, temps[:,0,qvd_index], color=color,
                 label='y = ' + str(round(y[0], 3)) + 'mm', linestyle=':')
        # Remaining curves of the group: solid, same colour, no label.
        for temp in np.transpose(temps[:,1:, qvd_index]):
            plt.plot(t_line, temp, color = color)

    plt.legend(loc = 'upper right', fontsize = 16)
    plt.xscale('log')
    if ylog:
        plt.yscale('log')
    plt.xlabel('Time [s]', fontsize = 16)
    plt.ylabel('Temperature [\u2103]', fontsize = 16)
    plt.title('Temperature as a function of time at different distances from weld.\n'
              'Distance increases linearly between each pair of dotted lines by '
              r'$\Delta y = \frac{h}{10}$,'
              '\nwhere $h$ is the distance at the furthest of the two.')

def plot_Tt_weld_V3(qvd, Tt_curves, y_vals, t_line, ylog = True):
    """Plot temperature versus time from precomputed curves.

    Intended to give the same plot as plot_Tt_weld_V1 / plot_Tt_weld_V2 with
    a smoother slider response, because nothing is recomputed: the curves are
    only looked up in Tt_curves.

    Args:
        qvd:       heat input in the slider unit; expected to lie on the grid
                   defined by c.qvd_slider_min and c.qvd_slider_step.
        Tt_curves: sequence of arrays indexed as
                   [time, distance within group, qvd], one per distance group
                   (first return value of get_Tt_weld_curves).
        y_vals:    sequence of distance arrays, one per group (second return
                   value of get_Tt_weld_curves).
        t_line:    1-D array of times (third return value of
                   get_Tt_weld_curves).
        ylog:      if truthy, every group and every time sample is drawn on a
                   logarithmic temperature axis; otherwise only the first
                   three groups and the first 60 time samples are drawn on a
                   linear temperature axis.

    Draws on a new matplotlib figure; nothing is returned.
    """
    # Round to the nearest grid index: plain int() truncates, and float
    # quotients such as 0.6/0.1 == 5.999... would select the wrong column.
    qvd_index = int(round((qvd - c.qvd_slider_min)/c.qvd_slider_step))

    if ylog:
        n_times = len(t_line)
    else:
        # Linear mode only shows the groups closest to the weld and the early
        # part of the time axis. The truncation happens before the colours
        # are chosen so that the groups get the same colours as in V1 / V2.
        Tt_curves = Tt_curves[:3]
        y_vals = y_vals[:3]
        n_times = 60

    # One colour per group, taken from the first 70 % of the colormap.
    # plt.get_cmap is used because matplotlib.cm.get_cmap no longer exists
    # in Matplotlib >= 3.9.
    max_color = 0.7
    cmap = plt.get_cmap('YlOrRd_r')
    color_list = [cmap(max_color * (i / len(y_vals))) for i in range(len(y_vals))]

    plt.figure(figsize=[15, 7])

    times = t_line[:n_times]
    for temps, y_group, color in zip(Tt_curves, y_vals, color_list):
        # First curve of the group: dotted and labelled.
        plt.plot(times, temps[:n_times, 0, qvd_index], color=color,
                 label='y = ' + str(round(y_group[0], 3)) + 'mm', linestyle=':')
        # Remaining curves of the group: solid, same colour, no label.
        for temp in np.transpose(temps[:n_times, 1:, qvd_index]):
            plt.plot(times, temp, color=color)

    plt.legend(loc='upper right', fontsize=16)
    plt.xscale('log')

    if ylog:
        plt.yscale('log')

    plt.xlabel('Time [s]', fontsize=16)
    plt.ylabel('Temperature [\u2103]', fontsize=16)
    plt.title('Temperature as a function of time at different distances from weld.\n'
              'Distance increases linearly between each pair of dotted lines by '
              r'$\Delta y = \frac{h}{10}$, where $h$ is the distance at the furthest of the two.')


def find_max_y_weld(qvd):
    """Return the largest distance from the weld where the temperature exceeds c.Tsol_weld.

    The temperature from T_weld_func peaks at t = y**2 / (2*c.a); inserting
    that time gives the peak temperature at distance y as

        T_peak(y) = c.T0_weld + a1 / y,
        a1 = qvd * exp(-0.5) / (c.rhoc * sqrt(2*pi)).

    Setting T_peak(y) = c.Tsol_weld has the closed-form solution
    y = a1 / (c.Tsol_weld - c.T0_weld), which is returned directly. This is
    the root the earlier Newton iteration (start 0.1, tolerance 1e-6)
    converged to, without the cost of the iteration and without its failure
    to terminate when no positive root exists.

    Args:
        qvd: heat input with unit kJ/mm^2 (multiplied by 1000 internally).

    Returns:
        The distance y, in the length unit used by the constants module.
        np.inf is returned when c.Tsol_weld <= c.T0_weld, since the peak
        temperature then reaches c.Tsol_weld at every distance.
    """
    qvd = qvd*1000  # fix the units

    a1 = (qvd * np.exp(-0.5)) / (c.rhoc * np.sqrt(2*np.pi))
    temperature_rise = c.Tsol_weld - c.T0_weld

    if temperature_rise <= 0:
        # The base temperature is already at or above the threshold, so
        # there is no finite maximum distance.
        return np.inf

    return a1 / temperature_rise

def find_t_weld(qvd, y, T=c.Teq):
    """Find the time t(y, T, qvd) at which the temperature at distance y equals T.

    The solution on the falling side of the temperature curve (to the right
    of the maximum) is sought. Newton's method is used with the start value
    t0 = 1.5 * (y + 1)**2 / (2*c.a), chosen because the maximum lies at
    t = y**2 / (2*c.a); iteration stops once the Newton step is no larger
    than 1e-5.

    Args:
        qvd: heat input with unit kJ/mm^2 (multiplied by 1000 internally).
        y:   distance from the weld.
        T:   target temperature, c.Teq by default.

    Returns:
        The time found by the iteration, or -1 when y is larger than
        find_max_y_weld(qvd), in which case no search is attempted.
    """
    qvd = 1000*qvd  # fix the unit

    def residual(t):
        # Zero when the temperature at (t, y) equals the target T.
        return T_weld_func(t, y, qvd) - T

    def residual_slope(t):
        # Analytical time derivative of T_weld_func at fixed y and qvd.
        shape = ((y**2)/(2*c.a*t)) - 1
        scale = qvd/(2*c.rhoc*np.sqrt(4*np.pi*c.a))
        return shape * scale * np.power(t, -1.5) * np.exp((-y**2)/(4*c.a*t))

    # Start to the right of the peak, which is at t = y**2 / (2*c.a).
    t_now = 1.5*((y+1)**2)/(2*c.a)

    if y > find_max_y_weld(qvd/1000):
        return -1

    tol = 1e-5
    while True:
        step = - residual(t_now) / residual_slope(t_now)
        # Written as "not >" so that a NaN step also ends the iteration.
        if not abs(step) > tol:
            return t_now
        t_now = t_now + step

def vec_weld_scheill_integrator(t_start, t_end, y_list, qvd_list):
    """Integrate 1 / t_star_func(T) over time for every (y, qvd) pair.

    A log-spaced time grid of 100 points from 10**-1.5 up to t_end[0][-1] is
    used. For each (y, qvd) pair only the samples with
    t_start <= t <= t_end contribute; all other samples are masked out.

    Args:
        t_start:  start times, one per (y, qvd) pair, indexed [y, qvd].
        t_end:    end times, same layout as t_start.
        y_list:   1-D array of distances.
        qvd_list: 1-D array of heat inputs (multiplied by 1000 internally).

    Returns:
        The summed integrand times the time increments, reduced over the time
        axis, with one value per (y, qvd) pair.
    """
    qvd_scaled = qvd_list * 1000  # fix the units

    # The last element of the first row in t_end is taken to be the highest
    # time value, and therefore sets the upper end of the grid.
    t_grid = np.logspace(-1.5, np.log10(t_end[0][-1]), 100)

    # True wherever a sample lies outside [t_start, t_end]; the result has
    # the same shape as temps: (len(t), len(y), len(qvd)).
    before_start = np.less.outer(t_grid, t_start)
    after_end = np.greater.outer(t_grid, t_end)
    outside_window = np.logical_or(before_start, after_end)

    temps = np.ma.masked_array(
        T_weld_vectorized(t_grid, y_list, qvd_scaled), outside_window)
    integrand = 1/t_star_func(temps)

    # Time increments repeated over the (y, qvd) grid.
    weights = np.multiply.outer(np.diff(t_grid), np.ones(temps[0].shape))

    # There is one fewer increment than there are time samples, so the first
    # integrand sample is dropped.
    I = np.sum(weights*integrand[1:], axis=0)
    return I

def plot_X_weld():
    """Plot the transformed fraction against heat input for several distances.

    For 50 heat inputs between 0.001 and 0.2 and distances 0, 2.5, ..., 30
    the integral from vec_weld_scheill_integrator is evaluated and converted
    to a fraction X = 1 - (1 - c.Xc)**(I**c.n). One curve per distance is
    drawn on a new matplotlib figure; every fourth curve is solid and carries
    a legend label, the others are dashed. Nothing is returned.
    """
    qvd_line = np.linspace(0.001,0.2, 50)
    y_vals = np.arange(0, 31, step=2.5)

    # Build two matrices, one with the start times of the integration and one
    # with the end times.
    # Integration starts where T = Teq - 5, to avoid overflow in t_star_func.
    # The physical explanation given for this is that it takes a very long
    # time (10^300 seconds) before anything interesting happens unless T is
    # more than 5 degrees below the equilibrium temperature.
    # Each row has constant y, each column has constant qvd:
    # t_start[i][k] = t_start(y_vals[i], qvd_line[k])
    t_start = np.array([[find_t_weld(qvd, y, T = c.Teq-5) for qvd in qvd_line] for y in y_vals])
    t_end = np.array([[find_t_weld(qvd, y, T=c.Tmin_weld) for qvd in qvd_line] for y in y_vals])

    I = vec_weld_scheill_integrator(t_start, t_end, y_vals, qvd_line)

    # Fractions from the Scheil integral; masked values are set to zero.
    X = 1 - (1 - c.Xc)**(I**c.n)
    X = X.filled(0)

    plt.figure(figsize=[15,7])
    # plt.get_cmap is used because matplotlib.cm.get_cmap no longer exists
    # in Matplotlib >= 3.9.
    cmap = plt.get_cmap('YlOrRd_r')
    max_col = 0.85
    for i, fractions in enumerate(X):
        color = cmap(max_col * i/len(X))
        if i % 4 == 0:
            # Every fourth distance is drawn solid and gets a legend entry.
            plt.plot(qvd_line, fractions, color = color,
                     label=str(round(y_vals[i], 0))+'mm', linestyle = '-')
        else:
            plt.plot(qvd_line, fractions, color = color, linestyle = '--')

    plt.legend(fontsize = 14)
    plt.xlabel(r'$\frac{q_0}{vd}$ $\left[\frac{kJ}{mm^2}\right]$', fontsize = 16)
    plt.ylabel('Fraction transformed [–]', fontsize = 14)
    plt.title('Degree of transformation completion as a function of welding effect per area,\n'
              'for different distances from weld.', fontsize = 14)
