# To run this script, run
# `python studentized_range_mpmath_ref.py`
# in the "scipy/stats/tests/" directory

# This script generates a JSON file "./data/studentized_range_mpmath_ref.json"
# that is used to compare the accuracy of `studentized_range` functions against
# precise (20 DOP) results generated using `mpmath`.

# Equations in this file have been taken from
# https://en.wikipedia.org/wiki/Studentized_range_distribution
# and have been checked against the following reference:
# Lund, R. E., and J. R. Lund. "Algorithm AS 190: Probabilities and
# Upper Quantiles for the Studentized Range." Journal of the Royal
# Statistical Society. Series C (Applied Statistics), vol. 32, no. 2,
# 1983, pp. 204-210. JSTOR, www.jstor.org/stable/2347300. Accessed 18
# Feb. 2021.

# Note: I would have preferred to use pickle rather than JSON, but decided
# against it due to security concerns.
import itertools
from collections import namedtuple
import json
import time

import os
from multiprocessing import Pool, cpu_count

from mpmath import gamma, pi, sqrt, quad, inf, mpf, mp
from mpmath import npdf as phi
from mpmath import ncdf as Phi

results_filepath = "data/studentized_range_mpmath_ref.json"
# Use all but one CPU for worker processes, with a minimum of one worker.
num_pools = max(cpu_count() - 1, 1)

# A computed reference value paired with the case that produced it.
MPResult = namedtuple("MPResult", ["src_case", "mp_result"])

# A PDF/CDF evaluation point: quantile `q`, number of groups `k`, degrees of
# freedom `v`, plus the tolerances stored alongside the reference value.
CdfCase = namedtuple("CdfCase",
                     ["q", "k", "v", "expected_atol", "expected_rtol"])

# A moment evaluation: moment order `m`, `k`, `v`, plus tolerances.
MomentCase = namedtuple("MomentCase",
                        ["m", "k", "v", "expected_atol", "expected_rtol"])

# Load previously generated JSON results (so already-computed cases can be
# skipped), or init a new dict if none exist. The file is opened in a `with`
# block so its handle is closed as soon as it has been parsed.
if os.path.isfile(results_filepath):
    with open(results_filepath, mode="r") as results_file:
        res_dict = json.load(results_file)
else:
    res_dict = dict()

# Frame out data structure. Store data with the function type as a top level
# key to allow future expansion
res_dict["COMMENT"] = ("!!!!!! THIS FILE WAS AUTOGENERATED BY RUNNING "
                       "`python studentized_range_mpmath_ref.py` !!!!!!")
res_dict.setdefault("cdf_data", [])
res_dict.setdefault("pdf_data", [])
res_dict.setdefault("moment_data", [])

# Tolerances stored with every PDF/CDF reference value.
general_atol, general_rtol = 1e-11, 1e-11

# mpmath working precision, in decimal digits.
mp.dps = 24

# Evaluation grid for the PDF and CDF: every (q, k, nu) combination below
# becomes one case.
cp_q = [0.1, 1, 4, 10]
cp_k = [3, 10, 20]
cp_nu = [3, 10, 20, 50, 100, 120]

cdf_pdf_cases = [
    CdfCase(q, k, nu, general_atol, general_rtol)
    for q, k, nu in itertools.product(cp_q, cp_k, cp_nu)
]

# Moment tolerances are looser than the PDF/CDF ones.
mom_atol, mom_rtol = 1e-9, 1e-9
# Moments of order 0..4 for k=3, v=10.
# These are EXTREMELY slow - Multiple days each in worst case.
moment_cases = [
    MomentCase(order, 3, 10, mom_atol, mom_rtol)
    for order in range(5)
]


def write_data():
    """Serialize the whole in-memory ``res_dict`` to ``results_filepath``.

    The target file is overwritten on every call; the JSON is pretty-printed
    with a two-space indent.
    """
    with open(results_filepath, mode="w") as out_file:
        json.dump(res_dict, out_file, indent=2)


def to_dict(named_tuple):
    """Return a plain ``dict`` mapping each field of *named_tuple* to its value.

    Keys appear in field-declaration order.
    """
    return dict(zip(named_tuple._fields, named_tuple))


def mp_res_to_dict(mp_result):
    """Convert an ``MPResult`` into a JSON-serializable dict.

    The source case becomes a nested dict (via ``to_dict``) and the mpmath
    value is cast to a Python ``float``.
    """
    case_dict = to_dict(mp_result.src_case)
    # np assert can't handle mpf, so precision beyond a double is
    # deliberately dropped here.
    as_float = float(mp_result.mp_result)
    return {"src_case": case_dict, "mp_result": as_float}


def cdf_mp(q, k, nu):
    """Straightforward implementation of studentized range CDF.

    Evaluates the double integral for ``F(q; k, nu)`` (see the Wikipedia
    reference at the top of this file) with mpmath Gauss-Legendre quadrature
    over ``s`` in ``[0, inf)`` and ``z`` in ``(-inf, inf)``.

    Parameters
    ----------
    q : real
        Point at which the CDF is evaluated.
    k : real
        Number of groups.
    nu : real
        Degrees of freedom.

    Returns
    -------
    mpf
        The value returned by ``mpmath.quad``.
    """
    q, k, nu = mpf(q), mpf(k), mpf(nu)

    # The normalizing constant does not depend on (s, z), so compute it once
    # rather than at every quadrature node (it needs a gamma evaluation and
    # two large powers).
    # NOTE: it is still applied inside the integrand, not to the final
    # result. mpmath's quad stopping test appears to use an absolute error
    # estimate, and for large nu the unscaled integrand is tiny. Scaling
    # after integration could therefore change when quad stops.
    norm = (sqrt(2 * pi) * k * nu ** (nu / 2)
            / (gamma(nu / 2) * 2 ** (nu / 2 - 1)))

    def inner(s, z):
        return phi(z) * (Phi(z + q * s) - Phi(z)) ** (k - 1)

    def outer(s, z):
        return s ** (nu - 1) * phi(sqrt(nu) * s) * inner(s, z)

    def whole(s, z):
        return norm * outer(s, z)

    res = quad(whole, [0, inf], [-inf, inf],
               method="gauss-legendre", maxdegree=10)
    return res


def pdf_mp(q, k, nu):
    """Straightforward implementation of studentized range PDF.

    Evaluates the double integral for ``f(q; k, nu)`` (see the Wikipedia
    reference at the top of this file) with mpmath Gauss-Legendre quadrature
    over ``s`` in ``[0, inf)`` and ``z`` in ``(-inf, inf)``.

    Parameters
    ----------
    q : real
        Point at which the PDF is evaluated.
    k : real
        Number of groups.
    nu : real
        Degrees of freedom.

    Returns
    -------
    mpf
        The value returned by ``mpmath.quad``.
    """
    q, k, nu = mpf(q), mpf(k), mpf(nu)

    # The normalizing constant does not depend on (s, z), so compute it once
    # rather than at every quadrature node (it needs a gamma evaluation and
    # two large powers).
    # NOTE: it is still applied inside the integrand, not to the final
    # result. mpmath's quad stopping test appears to use an absolute error
    # estimate, and for large nu the unscaled integrand is tiny. Scaling
    # after integration could therefore change when quad stops.
    norm = (sqrt(2 * pi) * k * (k - 1) * nu ** (nu / 2)
            / (gamma(nu / 2) * 2 ** (nu / 2 - 1)))

    def inner(s, z):
        return phi(z + q * s) * phi(z) * (Phi(z + q * s) - Phi(z)) ** (k - 2)

    def outer(s, z):
        return s ** nu * phi(sqrt(nu) * s) * inner(s, z)

    def whole(s, z):
        return norm * outer(s, z)

    res = quad(whole, [0, inf], [-inf, inf],
               method="gauss-legendre", maxdegree=10)
    return res


def moment_mp(m, k, nu):
    """Implementation of the studentized range moment.

    Computes the ``m``-th raw moment by integrating ``q**m`` times the
    studentized range PDF over ``q`` in ``[0, inf)``. This is a triple
    integral over (q, s, z) evaluated with mpmath Gauss-Legendre quadrature,
    so it is very slow.

    Parameters
    ----------
    m : real
        Moment order.
    k : real
        Number of groups.
    nu : real
        Degrees of freedom.

    Returns
    -------
    mpf
        The value returned by ``mpmath.quad``.
    """
    m, k, nu = mpf(m), mpf(k), mpf(nu)

    # The PDF's normalizing constant does not depend on (q, s, z), so compute
    # it once instead of at every node of the triple quadrature.
    # NOTE: it is still applied inside the integrand, not to the final
    # result. mpmath's quad stopping test appears to use an absolute error
    # estimate, so the integrand's magnitude is kept unchanged.
    norm = (sqrt(2 * pi) * k * (k - 1) * nu ** (nu / 2)
            / (gamma(nu / 2) * 2 ** (nu / 2 - 1)))

    def inner(q, s, z):
        return phi(z + q * s) * phi(z) * (Phi(z + q * s) - Phi(z)) ** (k - 2)

    def outer(q, s, z):
        return s ** nu * phi(sqrt(nu) * s) * inner(q, s, z)

    def pdf(q, s, z):
        return norm * outer(q, s, z)

    def whole(q, s, z):
        return q ** m * pdf(q, s, z)

    res = quad(whole, [0, inf], [0, inf], [-inf, inf],
               method="gauss-legendre", maxdegree=10)
    return res


def result_exists(set_key, case):
    """Searches the results dict for a result in the set that matches a case.

    A stored result matches when its ``"src_case"`` dict compares equal to
    ``to_dict(case)``. Every field, including the tolerances, must be equal.

    Returns
    -------
    bool
        True if such a result exists in ``res_dict[set_key]``.

    Raises
    ------
    ValueError
        If *set_key* is not a top-level key of ``res_dict``.
    """
    if set_key not in res_dict:
        raise ValueError(f"{set_key} not present in data structure!")

    case_dict = to_dict(case)
    # any() stops at the first match instead of building a list of all
    # matches just to check that it is non-empty.
    return any(res["src_case"] == case_dict  # dict comparison
               for res in res_dict[set_key])


def run(case, run_lambda, set_key, index=0, total_cases=0):
    """Evaluate a single case and package its result.

    Calls ``run_lambda(case)``, prints progress and elapsed time, and returns
    ``(index, set_key, result_dict)``. ``result_dict`` is the JSON-ready
    form produced by ``mp_res_to_dict``. ``run_cases`` submits this function
    to a process pool with ``write_result`` as the completion callback.
    """
    started = time.perf_counter()
    value = run_lambda(case)
    elapsed = time.perf_counter() - started

    print(f"Finished {index + 1}/{total_cases} in batch. "
          f"(Took {elapsed}s)")

    packaged = mp_res_to_dict(MPResult(case, value))
    return index, set_key, packaged


def write_result(res):
    """Pool callback: store one finished result and persist all results.

    *res* is the ``(index, set_key, result_dict)`` tuple returned by ``run``.
    The result is inserted into ``res_dict[set_key]`` and the whole results
    file is rewritten immediately. Completed work therefore stays on disk
    even if the run is interrupted later.
    """
    position, key, payload = res
    # `position` is the case's index within its batch. Jobs complete in any
    # order and the list may already contain results loaded from disk, so
    # the list order is not guaranteed to match the case order.
    res_dict[key].insert(position, payload)
    write_data()


def run_cases(cases, run_lambda, set_key):
    """Runs an array of cases and writes to file.

    Cases that already have a result in ``res_dict[set_key]`` are skipped.
    The remaining cases run in a ``multiprocessing.Pool`` of ``num_pools``
    workers. The ``write_result`` callback stores and saves each result as
    soon as it finishes.

    Parameters
    ----------
    cases : sequence of namedtuple
        Cases to evaluate (e.g. ``CdfCase`` or ``MomentCase`` instances).
    run_lambda : callable
        Module-level function mapping a case to its mpmath result. It must be
        picklable so it can be sent to worker processes.
    set_key : str
        Top-level key in ``res_dict`` under which results are stored.

    Raises
    ------
    RuntimeError
        If any job raised. Results of successful jobs have already been
        written, so a re-run only recomputes the failed cases.
    """
    # Generate jobs to run from cases that do not have a result in
    # the previously loaded JSON.
    job_args = [(case, run_lambda, set_key, index, len(cases))
                for index, case in enumerate(cases)
                if not result_exists(set_key, case)]

    print(f"{len(cases) - len(job_args)}/{len(cases)} cases won't be "
          f"calculated because their results already exist.")

    # Run all using multiprocess. Leaving the `with` block terminates the
    # pool, which also cleans up workers if join() is interrupted.
    with Pool(num_pools) as pool:
        jobs = [pool.apply_async(run, args=args, callback=write_result)
                for args in job_args]
        pool.close()
        pool.join()

    # apply_async never raises worker exceptions on its own; get() re-raises
    # them. Without this check, failed cases would vanish silently.
    failures = []
    for args, job in zip(job_args, jobs):
        try:
            job.get()
        except Exception as exc:
            failures.append((args[0], exc))

    for case, exc in failures:
        print(f"Case {case} failed: {exc!r}")

    if failures:
        raise RuntimeError(f"{len(failures)}/{len(job_args)} cases in "
                           f"'{set_key}' failed; see messages above.")


def run_pdf(case):
    """Compute the mpmath PDF reference value for a ``CdfCase``."""
    q, k, nu = case.q, case.k, case.v
    return pdf_mp(q, k, nu)


def run_cdf(case):
    """Compute the mpmath CDF reference value for a ``CdfCase``."""
    q, k, nu = case.q, case.k, case.v
    return cdf_mp(q, k, nu)


def run_moment(case):
    """Compute the mpmath raw-moment reference value for a ``MomentCase``."""
    order, k, nu = case.m, case.k, case.v
    return moment_mp(order, k, nu)


def main():
    """Generate (or complete) all reference data.

    Runs the PDF, CDF and moment batches in that order. Each batch skips
    cases already present in the results file and writes new results as
    they finish.
    """
    t_start = time.perf_counter()

    total_cases = 2 * len(cdf_pdf_cases) + len(moment_cases)
    print(f"Processing {total_cases} test cases")

    print(f"Running 1st batch ({len(cdf_pdf_cases)} PDF cases). "
          "These take about 30s each.")
    run_cases(cdf_pdf_cases, run_pdf, "pdf_data")

    print(f"Running 2nd batch ({len(cdf_pdf_cases)} CDF cases). "
          "These take about 30s each.")
    run_cases(cdf_pdf_cases, run_cdf, "cdf_data")

    print(f"Running 3rd batch ({len(moment_cases)} moment cases). "
          "These take anywhere from a few hours to days each.")
    run_cases(moment_cases, run_moment, "moment_data")

    print(f"Test data generated in {time.perf_counter() - t_start}s")


if __name__ == "__main__":
    main()
