import matplotlib.pyplot as plt
import numpy as np
import matplotlib.cm as cm
import constants as c
import time
import c_curves

# Temperatures in K
def extr_eq(t, h):
    """Return the profile temperature after cooling for time ``t`` with coefficient ``h``.

    Evaluates the exponential cooling law

        T = T0 + (Ti - T0) * exp(-2*h*t / (w * rhoc))

    using ``c.T0_extrusion``, ``c.Ti_extrusion``, ``c.w_extrusion`` and
    ``c.rhoc_extrusion`` from the ``constants`` module.

    Parameters
    ----------
    t : float or array_like
        Time(s) at which to evaluate the temperature.
    h : float or array_like
        Heat transfer coefficient(s); must broadcast against ``t``.

    Returns
    -------
    float or numpy.ndarray
        Temperature in the units of the constants (Kelvin according to the
        comment above this function).
    """
    exponent = -(2 * h * t) / (c.w_extrusion * c.rhoc_extrusion)
    temperature_drop = c.Ti_extrusion - c.T0_extrusion
    return c.T0_extrusion + temperature_drop * np.exp(exponent)

def T_func_vec(t, h, celcius=False):
    """Evaluate the cooling law for every combination of ``h`` and ``t``.

    Computes ``T0 + (Ti - T0) * exp(-2*h*t / (w * rhoc))`` with the extrusion
    constants from ``constants``, taking the outer product over ``h`` and ``t``.

    Parameters
    ----------
    t : array_like
        Times at which to evaluate the temperature.
    h : float or array_like
        Heat transfer coefficient(s). Plain Python lists are accepted.
    celcius : bool, optional
        If truthy, 273 is subtracted from the result (Kelvin -> degrees
        Celsius). Otherwise the result is in the units of the constants
        (Kelvin according to this file's header comment).

    Returns
    -------
    numpy.ndarray
        Shape ``(len(h), len(t))`` for 1-D ``h`` and ``t``. For a scalar
        ``h`` the shape is ``np.shape(t)``.
    """
    # asarray makes `2 * h` numeric multiplication even when h is a Python
    # list (for a list it would be repetition).
    h = np.asarray(h)
    decay_rate = -(2 * h) / (c.w_extrusion * c.rhoc_extrusion)

    # tensordot with axes=0 is an outer product: one row per h, one column per t.
    temperature = c.T0_extrusion + (c.Ti_extrusion - c.T0_extrusion) * np.exp(
        np.tensordot(decay_rate, t, axes=0)
    )
    if celcius:
        # NOTE(review): 273 rather than 273.15 is kept deliberately so results
        # stay consistent with the rest of the project; confirm.
        temperature = temperature - 273
    return temperature


def plot_Tt():
    """Plot temperature versus time of the extruded profile for many heat transfer coefficients.

    The coefficients are grouped in bands. Within each band the spacing equals
    one tenth of the first value of the next band. The first curve of every
    band is drawn dotted and labelled in the legend; the other curves are
    solid. Colours are taken from the reversed ``YlOrRd`` colormap, restricted
    to its lower 70 %. Temperatures are shown in degrees Celsius on a
    logarithmic time axis.

    Prints how long the temperature evaluation took. Draws into a new figure
    and does not call ``plt.show()``.
    """
    t_list = np.logspace(c.t_min_extrusion, c.t_max_extrusion, c.t_points_extrusion)

    # A plain list is used because the bands have different lengths. Building
    # np.array from ragged sequences raises ValueError on NumPy >= 1.24.
    h_list = [
        np.arange(10, 50, step=5),
        np.arange(50, 250, step=25),
        np.arange(250, 1000, step=100),
        np.arange(1000, 5000, step=500),
        np.arange(5000, 10000, step=1000),
        np.array([10000]),
    ]
    num_lines = sum(len(band) for band in h_list)

    t0 = time.process_time()
    all_temps = [T_func_vec(t_list, h_vals, celcius=True) for h_vals in h_list]
    t1 = time.process_time()

    print('Calculating T(h,t) in ', num_lines*len(t_list), ' points (all displayed) took ', round(t1-t0, 4), 's', sep='')

    max_color = 0.7
    color_range = max_color * np.array([i/num_lines for i in range(num_lines)])
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    cmap = plt.get_cmap('YlOrRd_r')
    colors = iter([cmap(x) for x in color_range])

    plt.figure(figsize=[15, 7])
    for temp_matrix, h_vals in zip(all_temps, h_list):
        # The first curve of each band is dotted and gets a legend entry.
        plt.plot(t_list, temp_matrix[0], color=next(colors), label='h = '+str(h_vals[0]), linestyle=':')
        for temp_row in temp_matrix[1:]:
            plt.plot(t_list, temp_row, color=next(colors))

    plt.title('Temperature as a function of time in an extruded profile for different heat transfer coefficients.\n'
              r'Distance between lines is constant between each pair of dotted lines and equal $\frac{h}{10}$,'
              '\nwhere $h$ is the highest heat transfer coefficient of the two dotted lines.')
    plt.xlabel("Time [s]", fontsize=14)
    plt.ylabel("Temperature [\u2103]", fontsize=14)
    plt.legend(loc='lower left', fontsize=14)
    plt.xscale('log')
    plt.grid()

def plot_Xh():
    """Plot the fraction of transformation completion against the heat transfer coefficient.

    ``h`` runs over ``c.h_steps`` log-spaced values between ``c.h_min`` and
    ``c.h_max``. For each ``h``, the cooling curve from ``T_func_vec`` (in its
    default, non-Celsius units) is sampled at 1000 log-spaced times. The sum

        I(h) = sum_i dt_i / t*(T(h, t_i))

    is accumulated only over samples with ``c.T0_extrusion + 1 <= T <= c.Teq``,
    where ``t*`` comes from ``c_curves.t_star_func``. The plotted fraction is
    ``1 - (1 - c.Xc) ** (I ** c.n)``.

    Draws into a new figure and does not call ``plt.show()``.
    """
    h_list = np.logspace(np.log10(c.h_min), np.log10(c.h_max), c.h_steps)

    # Time window obtained by inverting
    # T(t) = T0 + (Ti - T0) * exp(-2*h*t / (w*rhoc)):
    #   t_max: the slowest cooling (h_min) reaches T0 + 1, the lower cut-off of the mask.
    #   t_min: the fastest cooling (h_max) reaches Teq, the upper cut-off of the mask.
    initial_excess = c.Ti_extrusion - c.T0_extrusion
    t_max = ((c.w_extrusion * c.rhoc_extrusion) / (-2 * c.h_min)) * np.log(1 / initial_excess)
    t_min = ((c.w_extrusion * c.rhoc_extrusion) / (-2 * c.h_max)) * np.log((c.Teq - c.T0_extrusion) / initial_excess)

    t_list = np.logspace(np.log10(t_min), np.log10(t_max), 1000)

    # One row per h, one column per t.
    temp_list = T_func_vec(t_list, h_list)
    # True where a sample lies outside T0 + 1 <= T <= Teq and must not contribute.
    mask = (temp_list > c.Teq) | (temp_list < c.T0_extrusion + 1)

    # 1/t* is evaluated on every sample; masked samples are then zeroed.
    integrand = 1 / c_curves.t_star_func(temp_list)
    masked_integrand = np.ma.MaskedArray(integrand, mask).filled(0)

    # Step widths of the log-spaced (hence non-uniform) time grid. The last
    # width is repeated so every sample gets a weight (rectangle rule).
    dt = np.diff(t_list)
    dt = np.concatenate((dt, dt[-1:]))

    # Row-wise weighted sum: one integral value per h.
    I = np.dot(masked_integrand, dt)

    fraction_list = 1 - (1 - c.Xc) ** (I ** c.n)

    plt.figure(figsize=[15, 7])
    plt.plot(h_list, fraction_list)
    plt.xscale("log")
    plt.xlabel(r"Heat transfer coefficient, $\left[\frac{W}{m^2 K}\right]$", fontsize=14)
    plt.ylabel("Fraction of transformation completion [-]", fontsize=14)
    plt.grid(which='both')