import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.colors import LogNorm

def plot_midline_vert():
    """Plot normalised u-velocity profiles along the vertical midline (x = 0.05 m).

    For the three wall-velocity cases stored under the file tags ``1``, ``01``
    and ``001`` (u_wall = 1, 0.1 and 0.01), reads column 0 of
    ``midline_data/midline_vert_<tag>.csv`` (first row skipped as a header),
    divides each profile by its own maximum and draws it on the current axes
    against y in [0, 0.1] m. Curves are coloured on a logarithmic viridis
    scale by their nominal wall velocity. The y-axis ticks and label are moved
    to the right-hand side so this panel can sit to the right of another one.
    """
    cmap = plt.get_cmap('viridis')
    # Same log colour mapping (and nominal u_wall values) as plot_midline_hor,
    # so a given case gets the same colour and legend text in both panels.
    norm = LogNorm(vmin=0.001, vmax=1)

    for tag, u_wall in (('1', 1), ('01', 0.1), ('001', 0.01)):
        line = np.genfromtxt('midline_data/midline_vert_' + tag + '.csv',
                             delimiter=',', usecols=0, skip_header=1)
        # nanmax: genfromtxt turns blank cells into NaN, which must not
        # poison the normalisation of the whole profile.
        u_max = np.nanmax(line)
        # Per-file sample positions, so files may differ in resolution.
        axis = np.linspace(0, 0.1, len(line))
        plt.plot(axis, line / u_max, color=cmap(norm(u_wall)),
                 label=r'$u_{wall} =$ ' + str(u_wall))

    ax = plt.gca()
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position('right')

    plt.ylabel(r'$u/u_{max}$ [-]')
    plt.xlabel('y [m]')
    # Dashed guides: the domain centre (y = 0.05 m) and u = 0.
    plt.vlines(0.05, -0.35, 1, linestyles='--', colors='black')
    plt.hlines(0, 0, 0.1, linestyles='--', colors='black')
    plt.ylim(-0.35, 1)
    plt.xlim(0, 0.1)
    plt.legend()
    plt.title('x = 0.05 m')

def plot_midline_hor():
    """Plot normalised v-velocity profiles along the horizontal midline (y = 0.05 m).

    Loads column 1 of ``midline_data/midline_hor_<tag>.csv`` (header row
    skipped) for the tags ``1``, ``01`` and ``001``, which correspond to wall
    velocities 1, 0.1 and 0.01. Each profile is scaled by its largest absolute
    value and drawn on the current axes over x in [0, 0.1] m, coloured on a
    logarithmic viridis scale by wall velocity.
    """
    tags = ['1', '01', '001']
    wall_speeds = [1, 0.1, 0.01]
    profiles = np.array([
        np.genfromtxt('midline_data/midline_hor_' + tag + '.csv',
                      delimiter=',', usecols=1, skip_header=1)
        for tag in tags
    ])

    colours = get_cmap('viridis')
    log_scale = LogNorm(vmin=0.001, vmax=1)
    x_coords = np.linspace(0, 0.1, len(profiles[0]))
    for profile, speed in zip(profiles, wall_speeds):
        scaled = profile / max(abs(profile))
        plt.plot(x_coords, scaled, color=colours(log_scale(speed)),
                 label=r'$u_{wall} = $' + str(speed))

    plt.ylabel(r'$v/\vert v_{min} \vert$')
    plt.xlabel('x [m]')
    # Dashed guides through the domain centre (x = 0.05 m) and at v = 0.
    plt.vlines(0.05, -1.1, 0.65, linestyles='--', colors='black')
    plt.hlines(0, 0, 0.1, linestyles='--', colors='black')
    plt.xlim(0, 0.1)
    plt.ylim(-1.1, 0.65)
    plt.title('y = 0.05 m')
    # No legend is drawn here on purpose. In plot_midline this panel sits
    # beside plot_midline_vert, which draws its own legend.

def plot_mesh():
    """Plot F_y and computation time against mesh resolution on twin y-axes.

    Reads whitespace-separated ``meshes.txt`` (first line skipped as a header).
    Its columns are used as follows:

    - column 0: mesh cell count
    - column 1: computation time [s]
    - column 3: the y component of a three-component vector in columns 2-4,
      plotted as F_y

    F_y is drawn in red on the left axis and time in blue on a twin right
    axis. A single combined legend is added, then the figure is shown.
    """
    data = np.genfromtxt('meshes.txt', skip_header=1).transpose()
    meshes = data[0]
    times = data[1]
    # Only the y component is needed. Index it directly rather than
    # unpacking exactly three trailing columns, which raises a ValueError
    # whenever the file has any other number of columns.
    f_y = data[3]

    _, ax = plt.subplots()
    twx = ax.twinx()

    leg1, = ax.plot(meshes, f_y, label=r'$F_y$', color='red')
    leg2, = twx.plot(meshes, times, label='time [s]', color='blue')

    ax.set_ylabel(r'$F_y$ [N m$^3$ kg$^{-1}$]')
    twx.set_ylabel('Computation time [s]')
    ax.set_xlabel('Mesh cells (x,y)')

    # One legend covering both axes. Attach it to the twin axes explicitly
    # instead of relying on pyplot's implicit "current axes".
    twx.legend(loc='upper center', handles=[leg1, leg2])
    plt.show()

def plot_midline():
    """Show the horizontal and vertical midline profiles side by side.

    Creates a 1x2 figure of size (9, 4). It draws ``plot_midline_hor`` in
    the left panel and ``plot_midline_vert`` in the right panel, then shows
    the figure.
    """
    _, panels = plt.subplots(1, 2, figsize=(9, 4))
    for panel, draw in zip(panels, (plot_midline_hor, plot_midline_vert)):
        # The panel helpers draw through pyplot, so make each panel current first.
        plt.sca(panel)
        draw()
    plt.show()

def plot_wall_velocity():
    """Compare the v-velocity along the line x = 0.256 mm for three geometries.

    For each of ``baffle``, ``corner`` and ``plain``, reads column 1 of
    ``left_wall/<name>.csv`` (header row skipped). Each series is plotted
    against y in [0, 0.1] m with its own colour and line style, and the
    figure is shown.
    """
    cmap = plt.get_cmap('viridis')
    # (file stem, legend label, colormap position, line style)
    cases = (
        ('baffle', 'Baffle', 0, ':'),
        ('corner', 'Corner', 0.4, '--'),
        ('plain', 'Plain', 0.8, '-'),
    )
    for stem, label, shade, style in cases:
        v = np.genfromtxt('left_wall/' + stem + '.csv', delimiter=',',
                          skip_header=1, usecols=1)
        # Build the y axis per file so the geometries may have different
        # sample counts without a shape-mismatch error in plt.plot.
        y_ax = np.linspace(0, 0.1, len(v))
        plt.plot(y_ax, v, label=label, color=cmap(shade), linestyle=style)

    plt.legend(title='Geometry')
    plt.xlabel('y [m]')
    # Keep the unit outside the closing "$" and separate m and s, so the
    # label is not read as milliseconds.
    plt.ylabel(r'v [m s$^{-1}$]')
    plt.title(r'$x = 0.256$ mm')
    plt.show()

# Only run the mesh plot when executed as a script, not when imported
# (e.g. to reuse the other plotting functions).
if __name__ == '__main__':
    plot_mesh()