import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.optimize as opt
import matplotlib.cm as cm
import matplotlib.ticker as mtick

def plot_moduli():
    """Plot rule-of-mixtures moduli against fibre volume fraction with load-test data.

    The upper panel shows the longitudinal modulus from the linear (Voigt)
    rule of mixtures together with the measured ``E11``; the lower panel shows
    the transverse modulus from the inverse (Reuss) rule of mixtures together
    with the measured ``E22`` and ``E33``. Dashed horizontal lines mark the
    fibre and matrix moduli used in each formula.

    Measured values come from ``data/properties.xlsx`` (column ``fracs`` is in
    percent and is converted to a fraction). The figure is saved to
    ``plots/theoretical_moduli`` and then shown.
    """
    fibre_long, fibre_trans, matrix = 233000, 15000, 4100

    vf = np.linspace(0, 1, 100)
    vm = 1 - vf
    e_long = vf * fibre_long + vm * matrix            # linear rule of mixtures
    e_trans = 1 / (vf / fibre_trans + vm / matrix)    # inverse rule of mixtures

    measured = pd.read_excel('data/properties.xlsx')
    vf_measured = measured['fracs'].to_numpy() * 1e-2  # percent -> fraction

    fig, (top, bottom) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)

    # Upper panel: longitudinal modulus (artists added in legend order).
    top.plot(vf, e_long, label=r'$E_1$ Theoretical', color='blue')
    top.scatter(vf_measured, measured['E11'], marker='x', color='black', label=r'$E_{11}$ Load test')
    for bound in (fibre_long, matrix):
        top.plot(vf, np.full(vf.shape, bound), color='black', linestyle='--')

    # Lower panel: transverse modulus.
    bottom.plot(vf, e_trans, label=r'$E_t$ Theoretical', color='red')
    bottom.scatter(vf_measured, measured['E22'], marker='x', color='black', label=r'$E_{22}$ Load test')
    bottom.scatter(vf_measured, measured['E33'], marker='+', color='black', label=r'$E_{33}$ Load test')
    for bound in (fibre_trans, matrix):
        bottom.plot(vf, np.full(vf.shape, bound), color='black', linestyle='--')

    top.legend()
    bottom.legend()
    bottom.set_xlabel(r'$V_f$ [–]', fontsize=14)
    top.set_ylabel(r'$E_1$ [MPa]', fontsize=14)
    bottom.set_ylabel(r'$E_t$ [MPa]', fontsize=14)

    fig.suptitle('Composite moduli')
    fig.savefig('plots/theoretical_moduli', dpi=600)
    plt.show()

def mesh_convergence():
    """Plot homogenised moduli against inverse mesh size to show mesh convergence.

    Reads ``data/mesh.xlsx`` (columns ``mesh``, ``E11``, ``E22``, ``E33``) and
    plots each modulus against ``1 / mesh``. Dashed reference lines are drawn
    at the value in the last row of ``E11`` (upper panel) and ``E22`` (lower
    panel). A single legend covering all three curves is placed on the upper
    panel. The figure is saved to ``plots/convergence`` and then shown.
    """
    table = pd.read_excel('data/mesh.xlsx')
    inv_size = 1 / np.array(table['mesh'].to_list())

    fig, (upper, lower) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
    handles = []

    (curve,) = upper.plot(inv_size, table['E11'], color='blue', label=r'$E_{11}$')
    handles.append(curve)
    upper.plot(inv_size, np.full(inv_size.shape, table['E11'].iloc[-1]), color='black', linestyle='--')

    (curve,) = lower.plot(inv_size, table['E22'], color='red', label=r'$E_{22}$')
    handles.append(curve)
    (curve,) = lower.plot(inv_size, table['E33'], color='green', linestyle=':', label=r'$E_{33}$')
    handles.append(curve)
    lower.plot(inv_size, np.full(inv_size.shape, table['E22'].iloc[-1]), color='black', linestyle='--')

    upper.set_ylabel(r'$E_1$ [MPa]', fontsize=14)
    lower.set_ylabel(r'$E_t$ [MPa]', fontsize=14)
    lower.set_xlabel(r'(mesh size)$^{-1}$ [arb. units]', fontsize=14)

    # Legend on the upper panel deliberately includes the lower-panel curves too.
    upper.legend(handles=handles, fontsize=14)
    fig.savefig('plots/convergence', dpi=600)
    plt.show()

def halplin_tsai_Et(Vf, ksi, *, E2f=15000, Em=4100):
    """Return the transverse modulus predicted by the Halpin-Tsai equation.

        E_t = Em * (1 + ksi * eta * Vf) / (1 - eta * Vf)
        eta = (E2f - Em) / (E2f + ksi * Em)

    At ``Vf = 0`` this gives ``Em`` and at ``Vf = 1`` it gives ``E2f``.
    The function name keeps its historical spelling so existing callers work.

    Parameters
    ----------
    Vf : float or numpy.ndarray
        Fibre volume fraction(s); arrays are evaluated element-wise.
    ksi : float or numpy.ndarray
        Halpin-Tsai reinforcing parameter (the quantity fitted in this script).
    E2f : float, keyword-only
        Transverse fibre modulus (default: the previously hard-coded 15000).
    Em : float, keyword-only
        Matrix modulus (default: the previously hard-coded 4100).

    Returns
    -------
    float or numpy.ndarray
        Predicted transverse modulus, in the same units as ``E2f`` and ``Em``.
    """
    # The moduli are keyword-only so that scipy.optimize.curve_fit, which
    # counts positional arguments to infer the free parameters when p0 is
    # omitted, still fits only ``ksi``.
    eta = (E2f - Em) / (E2f + ksi * Em)
    return Em * (1 + ksi * eta * Vf) / (1 - eta * Vf)

def halplin_tsai_G(Vf, ksi, *, G2f=9000, Em=4100, nu_m=0.37):
    """Return the shear modulus predicted by the Halpin-Tsai equation.

        G = Gm * (1 + ksi * eta * Vf) / (1 - eta * Vf)
        eta = (G2f - Gm) / (G2f + ksi * Gm)

    The matrix shear modulus is derived from ``Em`` and ``nu_m`` with the
    isotropic relation ``Gm = Em / (2 * (1 + nu_m))``. At ``Vf = 0`` the
    result is ``Gm`` and at ``Vf = 1`` it is ``G2f``.

    Parameters
    ----------
    Vf : float or numpy.ndarray
        Fibre volume fraction(s); arrays are evaluated element-wise.
    ksi : float or numpy.ndarray
        Halpin-Tsai reinforcing parameter.
    G2f : float, keyword-only
        Fibre shear modulus (default: the previously hard-coded 9000).
    Em : float, keyword-only
        Matrix Young's modulus (default: the previously hard-coded 4100).
    nu_m : float, keyword-only
        Matrix Poisson ratio (default: the previously hard-coded 0.37).

    Returns
    -------
    float or numpy.ndarray
        Predicted shear modulus, in the same units as ``G2f`` and ``Em``.
    """
    # Keyword-only, like halplin_tsai_Et, so curve_fit's positional-argument
    # introspection still sees ``ksi`` as the only free parameter.
    Gm = Em / (2 * (1 + nu_m))
    eta = (G2f - Gm) / (G2f + ksi * Gm)
    return Gm * (1 + ksi * eta * Vf) / (1 - eta * Vf)

def fit_halplin_tsai():
    """Fit the Halpin-Tsai parameter xi to measured transverse moduli and plot it.

    Reads ``data/properties.xlsx`` (column ``fracs`` in percent, plus ``E22``
    and ``E33``). Pools the E22 and E33 measurements as samples of the same
    transverse modulus, and fits ``ksi`` of ``halplin_tsai_Et`` with
    ``scipy.optimize.curve_fit`` (initial guess 0.5). The fitted curve, the
    inverse rule of mixtures and the measurements are plotted on a fresh
    figure, which is saved to ``plots/halplin_tsai`` and then shown.
    """
    E2f = 15000
    Em = 4100

    data = pd.read_excel('data/properties.xlsx')
    fracs = data['fracs'].to_numpy(dtype=float) * 1e-2  # percent -> fraction

    # E22 and E33 both measure the transverse modulus, so fit them jointly.
    xdata = np.concatenate((fracs, fracs))
    ydata = np.concatenate((data['E22'].to_numpy(dtype=float),
                            data['E33'].to_numpy(dtype=float)))

    popt, _ = opt.curve_fit(halplin_tsai_Et, xdata, ydata, p0=[0.5])
    xi = float(popt[0])

    Vf = np.linspace(0, 1, 100)
    Et_model = 1 / (Vf / E2f + (1 - Vf) / Em)  # inverse rule of mixtures

    # Use a dedicated figure: bare plt.* calls would otherwise draw onto
    # whatever figure is still current (e.g. the previous plot when
    # plt.show() is a no-op on a non-interactive backend).
    fig, ax = plt.subplots()
    ax.scatter(xdata, ydata, marker='x', color='black', label='Load test')
    ax.plot(Vf, halplin_tsai_Et(Vf, xi), label=rf'Halpin-Tsai, $\xi = ${round(xi, 3)}')
    ax.plot(Vf, Et_model, label=r'$E_t$ from eq. (2.6)')
    ax.hlines([E2f, Em], 0, 1, colors='black', linestyles='--')
    ax.legend(loc='center left')
    ax.set_ylabel(r'$E_t$ [MPa]')
    ax.set_xlabel(r'$V_f$ [-]')
    fig.savefig('plots/halplin_tsai', dpi=600)
    plt.show()

def pertubations():
    """Compare the sensitivity of E_t to a perturbation of E_2f: analytic vs. tabulated.

    For fibre volume fractions 0.3, 0.5 and 0.7 this plots:

    * solid lines: the relative change of the inverse-rule-of-mixtures
      transverse modulus when E_2f is perturbed by -450 ... +450 around 15000;
    * dotted lines with markers: the relative changes read from
      ``data/pertubation.xlsx``.

    For each volume fraction the sheet must contain columns ``E2F_<pct>`` and
    ``Et_<pct>`` (e.g. ``E2F_30``, ``Et_30``). The first row of each column is
    the reference value for the relative changes. Both axes are formatted as
    percentages. The figure is saved to ``plots/pertubations`` and then shown.
    """
    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9;
    # pyplot.get_cmap works across versions.
    cmap = plt.get_cmap('plasma')
    fig, ax = plt.subplots(figsize=(9, 6.5))

    data = pd.read_excel('data/pertubation.xlsx')

    def color_scaler(V):
        # NOTE(review): maps 0.3 -> -0.075 (below range, so the colormap
        # clips it to its lowest colour by default), 0.5 -> 0.375 and
        # 0.7 -> 0.825; confirm this is the intended colour spread.
        return 0.9*V/(0.4) - 0.3/0.4

    p_list = np.linspace(-450, 450, 100)
    E2f_0 = 15000
    E2f = E2f_0 + p_list
    Em = 4100
    dE2f = p_list / E2f_0  # analytic x values are the same for every Vf

    # Iterate over integer percentages so the column names are exact;
    # int(Vf * 100) truncates floats (e.g. int(0.29 * 100) == 28).
    for pct in (30, 50, 70):
        Vf = pct / 100

        data_E2f = data[f'E2F_{pct}'].to_numpy(dtype=float)
        data_Et = data[f'Et_{pct}'].to_numpy(dtype=float)

        data_dE2f = (data_E2f - data_E2f[0]) / data_E2f[0]
        data_dEt = (data_Et - data_Et[0]) / data_Et[0]

        # Order the points by x while keeping each (x, y) pair together;
        # sorting the two arrays independently would mis-pair the data.
        order = np.argsort(data_dE2f, kind='stable')
        data_dE2f, data_dEt = data_dE2f[order], data_dEt[order]

        Vm = 1 - Vf
        Et_0 = 1 / (Vf / E2f_0 + Vm / Em)
        Et = 1 / (Vf / E2f + Vm / Em)
        dEt = (Et - Et_0) / Et_0

        color = cmap(color_scaler(Vf))
        ax.plot(dE2f, dEt, color=color, label=Vf)
        ax.plot(data_dE2f, data_dEt, color=color, linestyle=':', marker='x')

    ax.grid()
    ax.set_xlabel(r'$\frac{\Delta E_{2f}}{E_{2f}^\circ}$ [-]', fontsize=15)
    ax.set_ylabel(r'$\frac{\Delta E_{t}}{E_{t}^\circ}$ [-]', fontsize=15)
    ax.legend(title=r'$V_f$')
    ax.xaxis.set_major_formatter(mtick.PercentFormatter(1, decimals=1))
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(1, decimals=1))
    fig.savefig('plots/pertubations', dpi=600)
    plt.show()

# Produce every figure only when run as a script, not when imported.
if __name__ == "__main__":
    plot_moduli()
    mesh_convergence()
    fit_halplin_tsai()
    pertubations()