import os, sys, platform
# Make the 'soret_model' package directory and its 'cpp' subdirectory
# importable. Both are resolved relative to the PARENT of the current working
# directory, so the script must be launched from a sibling of 'soret_model'.
sys.path.append(os.path.dirname(os.getcwd())+'/soret_model')
sys.path.append(os.path.dirname(os.getcwd())+'/soret_model/cpp')


# Select the compiled extension build: 'release_ubuntu' on Linux,
# 'release_mac' on every other platform (no separate Windows build is handled).
if (platform.system() == 'Linux'):
    sys.path.append(os.path.dirname(os.getcwd()) + '/soret_model/cpp/release_ubuntu')
    from release_ubuntu.KineticGas import cpp_KineticGas, cpp_tests

else:
    sys.path.append(os.path.dirname(os.getcwd()) + '/soret_model/cpp/release_mac')
    from release_mac.KineticGas import cpp_KineticGas, cpp_tests


# Python implementation that the compiled cpp_KineticGas is compared against.
from kinetic_gas_benchmark import KineticGas
import time
from scipy.linalg import solve
import numpy as np
# Absolute tolerance used when comparing values from the two implementations.
FLTEPS = 1e-6

def test_kingas():
    """Compare the C++ KineticGas binding against the Python implementation.

    Builds ``cpp_KineticGas`` and ``KineticGas`` from the same fixed inputs
    and checks that ``A``, ``A_prime``, the triple-prime variant, ``H_ij``,
    ``H_i`` and ``H_simple`` agree to within ``FLTEPS`` (absolute difference).

    Returns:
        int: 0 when every comparison passes, otherwise the code of the first
        failing comparison:

        * 1 -- ``A(1, 1, 1, 1)``
        * 2 -- ``A_prime(1, 1, 1, 1)``
        * 3 -- triple-prime variant at ``(1, 1, 1, 1)``
        * 4 / 5 -- ``H_ij(p, q, 12)`` / ``H_ij(p, q, 21)``
        * 6 / 7 -- ``H_i(p, q, 12)`` / ``H_i(p, q, 21)``
        * 8 / 9 -- ``H_simple(p, q, 1)`` / ``H_simple(p, q, 2)``
    """
    print('Testing KinGas')
    mole_weights = [18.0, 32.0]
    sigmaij = [[3.1, 4.05],
               [4.05, 5]]
    mole_fracs = [0.3, 0.7]
    N = 3

    k_cpp = cpp_KineticGas(mole_weights, sigmaij, mole_fracs, 300, 1e5, N)
    k_py = KineticGas(mole_weights, sigmaij, mole_fracs, 300, 1e5, N)

    p, q, r, l = 1, 1, 1, 1
    dA = k_cpp.A(p, q, r, l) - k_py.A(p, q, r, l)
    dA_prime = k_cpp.A_prime(p, q, r, l) - k_py.A_prime(p, q, r, l)
    # NOTE: the binding spells this method 'A_trippleprime' (double p) while
    # the Python class uses 'A_tripleprime'; both spellings are intentional.
    dA_tripleprime = k_cpp.A_trippleprime(p, q, r, l) - k_py.A_tripleprime(p, q, r, l)

    if abs(dA) > FLTEPS:
        return 1
    if abs(dA_prime) > FLTEPS:
        return 2
    if abs(dA_tripleprime) > FLTEPS:
        return 3

    # (failure code, method name, last argument) for every (p, q) comparison,
    # listed in the order the checks are run.
    grid_checks = (
        (4, 'H_ij', 12),
        (5, 'H_ij', 21),
        (6, 'H_i', 12),
        (7, 'H_i', 21),
        (8, 'H_simple', 1),
        (9, 'H_simple', 2),
    )

    pq_range = np.arange(-N, N+1, 1)
    for p in pq_range:
        for q in pq_range:
            for code, method, index in grid_checks:
                cpp_value = getattr(k_cpp, method)(p, q, index)
                py_value = getattr(k_py, method)(p, q, index)
                # Compare the magnitude of the difference: a signed comparison
                # would let negative discrepancies slip through unnoticed.
                if abs(cpp_value - py_value) > FLTEPS:
                    print(cpp_value, py_value)
                    print('Failed on (p, q) =', p, q)
                    return code

    print("KineticGas passed testing")
    return 0

def time_kingas():
    """Time both KineticGas implementations on one fixed input and print a report.

    The Python timing covers construction of ``KineticGas`` only. The C++
    timing covers construction of ``cpp_KineticGas`` plus reading its
    ``A_matrix`` and ``delta_vector``; the subsequent ``scipy`` solve and the
    Soret evaluation are timed separately. CPU time is measured with
    ``time.process_time``.

    Returns:
        int: always 0.
    """
    print()
    print('#' * 50)
    print('\nTiming KineticGas')
    print()

    mole_weights = [18.0, 32.0]
    sigmaij = [[3.1, 4.05],
               [4.05, 5]]
    np_sigmaij = np.array(sigmaij)
    mole_fracs = [0.3, 0.7]
    T = 300
    p = 1e5
    N = 12

    x1 = mole_fracs[0]
    x2 = mole_fracs[1]

    # Each mole weight divided by the sum of both weights.
    total_weight = sum(mole_weights)
    M1 = mole_weights[0] / total_weight
    M2 = mole_weights[1] / total_weight

    # --- Python implementation: time the constructor only -----------------
    py_start = time.process_time()
    k_py = KineticGas(mole_weights, np_sigmaij, mole_fracs, T, p, N)
    py_end = time.process_time()
    py_time = py_end - py_start
    print('Python :', k_py.soret, py_time)

    # --- C++ implementation: constructor plus matrix/vector retrieval -----
    cpp_start = time.process_time()
    k_cpp = cpp_KineticGas(mole_weights, sigmaij, mole_fracs, T, p, N)
    A = np.array(k_cpp.A_matrix)
    delta = k_cpp.delta_vector
    cpp_end = time.process_time()

    # --- Linear solve and Soret evaluation (reported as 'scipy') ----------
    d = solve(A, delta)
    d_1 = d[N-1]
    d0 = d[N]
    d1 = d[N + 1]
    prefactor = 5 / (2 * d0)
    weighted_sum = (x1 * d1 / np.sqrt(M1)) + (x2 * d_1 / np.sqrt(M2))
    soret = -prefactor * weighted_sum
    solve_end = time.process_time()

    cpp_time = cpp_end - cpp_start
    print('C++ :', soret, cpp_time)
    print('scipy :', solve_end - cpp_end)

    print("\nC++ was", round(py_time / cpp_time, 2), "times faster than Python.")
    print('\n', '#' * 50, '\n')
    # The string below is an inert expression statement holding disabled
    # consistency checks between the two implementations; it is never executed.
    '''
    for i in range(len(A)):
        for j in range(len(A)):
            eps = abs(A[i, j] - k_py.A_matr[j, i])
            if eps > FLTEPS:
                print()
                print('#'*50)
                print()
                for i in range(len(A)):
                    for j in range(len(A)):
                        print(round(abs(A[i, j] - k_py.A_matr[j, i]), 10), end = ' '*(15 - len(str(round(abs(A[i, j] - k_py.A_matr[j, i]), 10)))))
                print()
                return 1
    
    if abs(soret - k_py.soret) > FLTEPS:
        print(soret, k_py.soret)
        return 2
    '''
    return 0
    

if __name__ == "__main__":
    # Script entry point: run every test, abort with a non-zero exit status
    # if any of them fails, and only then run the timing comparison.
    print('RUNNING TESTS')
    tests = [cpp_tests,
             test_kingas]

    # test_kingas returns 0 on success and a non-zero code on failure;
    # cpp_tests is assumed to follow the same convention -- TODO confirm.
    results = [test() for test in tests]

    # Check each code individually: summing could hide failures if codes of
    # opposite sign happened to cancel out.
    if any(code != 0 for code in results):
        print('TEST FAILED WITH EXIT CODE', results)
        # sys.exit is always available, unlike the site-provided exit().
        # Keep the summed code as exit status, but never report success (0)
        # when a test has failed.
        sys.exit(sum(results) or 1)

    print('ALL TESTS OK')
    time_kingas()
