#!/usr/bin/env python3
# This file is part of granularhealing.

# granularhealing is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# granularhealing is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with granularhealing. If not, see <https://www.gnu.org/licenses/>.
import json
import os

# import matplotlib as mpl
import numpy as np
import uncertainties as unc
from granularhealing import cfit as ghcfit
from granularhealing import errors as gherrors
from granularhealing import plots as ghplots
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from rstevaluation import files as rstfiles
from tqdm import tqdm

# When True, get_load_regions ignores any cached region JSON files and asks
# the user to pick the fitting regions again. It is set temporarily while
# main() retries a project whose evaluation raised a ValueError.
REPICK = False


def main():
    """
    Evaluates the elastic moduli of all projects listed in the all-data file.

    For every project the peak and static moduli are fitted against normal
    stress (``plot_all_moduli``) and the per-project figures are written to
    the results folder. Afterwards, overview plots comparing the modulus at
    1 kPa normal stress across all evaluated projects are created.

    Projects whose data folder is missing, or whose data is rejected with a
    ``GranularHealingError``, are skipped. If evaluating a project raises a
    ``ValueError``, its fitting regions are picked again interactively once.
    """
    # all_data_file_path = ghfiles.ask_for_file(filetype=(("*.json", "*.json"),))
    all_data_file_path = "Materials/Data/results/000_all_data.json"
    pick_folder, _ = os.path.split(all_data_file_path)
    base_folder, _ = os.path.split(pick_folder)
    data_folder = os.path.join(base_folder, "picked")
    output_folder = os.path.join(base_folder, "results")
    all_data = read_all_data(all_data_file_path)

    pe_all, pe_all_e, pe_all_e_conf = [], [], []
    se_all, se_all_e, se_all_e_conf = [], [], []
    names = []

    for ii, proj in tqdm(
        enumerate(all_data["name"]), total=len(all_data["name"])
    ):
        try:
            moduli, normstress = _get_moduli_with_repick(
                all_data["picks"][ii], proj, data_folder
            )
        except (FileNotFoundError, gherrors.GranularHealingError):
            # Missing data folder or unusable measurement: skip the project.
            continue

        curve_fits = plot_all_moduli(
            moduli, normstress, proj, output_folder, show=False
        )
        peak, static = curve_fits["peak"], curve_fits["static"]
        pe_all.append(peak["e1000"])
        se_all.append(static["e1000"])
        pe_all_e.append(peak["e1000_std"])
        se_all_e.append(static["e1000_std"])
        pe_all_e_conf.append(peak["e1000_std_conf"])
        se_all_e_conf.append(static["e1000_std_conf"])
        names.append(proj)

    # plot_all_moduli stores e1000 divided by 1000 (annotated there as kPa),
    # so the overview axes are labelled in kPa.
    overview_plots = (
        (pe_all, pe_all_e, "E-Modulus of Peak (kPa)", "000_EModulus_peak"),
        (se_all, se_all_e, "E-Modulus of Static (kPa)", "000_EModulus_static"),
        (
            pe_all,
            pe_all_e_conf,
            "E-Modulus of Peak (kPa) conf.",
            "000_EModulus_peak_conf",
        ),
        (
            se_all,
            se_all_e_conf,
            "E-Modulus of Static (kPa) conf.",
            "000_EModulus_static_conf",
        ),
    )
    for values, errors, label, fname in overview_plots:
        ghplots.create_overview_plot(
            output_folder, names, values, errors, label, fname
        )


def _get_moduli_with_repick(picks, proj, data_folder):
    """
    Calls get_moduli and, on ValueError, retries once with REPICK set so the
    fitting regions are picked again interactively.

    The first attempt is made outside any ``except`` clause so that errors
    raised by the retry (e.g. FileNotFoundError) reach the caller's handlers.
    REPICK is always restored afterwards, even if the retry fails.
    """
    global REPICK
    try:
        return get_moduli(picks, proj, data_folder)
    except ValueError:
        pass
    previous = REPICK
    REPICK = True
    try:
        return get_moduli(picks, proj, data_folder)
    finally:
        REPICK = previous


def plot_all_moduli(moduli, normstress, proj, output_folder, show=True):
    """
    Plots elastic modulus versus normal stress for one project, with the peak
    picks in the left panel and the static picks in the right panel.

    Each panel shows the data with error bars and a straight-line fit obtained
    by orthogonal distance regression. It also shows the confidence band
    (dotted) and prediction band (dash-dotted), and annotates the modulus at
    a normal stress of 1000 Pa. The figure is saved as
    ``<proj>_EMod.png`` and ``<proj>_EMod.pdf`` in ``output_folder``.

    Returns a dict keyed by ``"peak"`` and ``"static"``. Each entry holds the
    fit parameters (``popt``, ``perr``, ``pcov``) and the fitted modulus at
    1000 Pa divided by 1000 (``e1000``). It also holds the distance from that
    value to the lower prediction band (``e1000_std``) and to the lower
    confidence band (``e1000_std_conf``), both divided by 1000 as well.
    """
    fig, axes = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(10, 5))
    curve_fits = {}
    for idx, region in enumerate(("peak", "static")):
        colour = f"C{idx}"
        axis = axes[idx]

        # Split the uncertain values into nominal values and standard deviations.
        x_val = np.array([s.nominal_value for s in normstress[region]])
        y_val = np.array([m.nominal_value for m in moduli[region]])
        x_err = np.array([s.std_dev for s in normstress[region]])
        y_err = np.array([m.std_dev for m in moduli[region]])

        popt, perr, pcov = ghcfit.orthogonal_distance_regression(
            x_val, x_err, y_val, y_err, ghcfit.poly1_odr
        )
        x_query = np.linspace(0, np.max(x_val) * 1.15)
        conf_lo, conf_hi = ghcfit.confband(
            x_query, x_val, y_val, popt, pcov, ghcfit.poly1
        )
        pred_lo, pred_hi = ghcfit.predband(
            x_query, x_val, y_val, popt, ghcfit.poly1
        )
        # Evaluate the fit and the lower band edges at 1000 Pa normal stress.
        pred_lo_1k, _ = ghcfit.predband(1000, x_val, y_val, popt, ghcfit.poly1)
        e1000 = ghcfit.poly1(1000, *popt)
        conf_lo_1k, _ = ghcfit.confband(
            1000, x_val, y_val, popt, pcov, ghcfit.poly1
        )

        e1000_kpa = e1000 / 1000
        e1000_pred_err = (e1000 - pred_lo_1k) / 1000
        curve_fits[region] = dict(
            popt=popt,
            perr=perr,
            pcov=pcov,
            e1000=e1000_kpa,
            e1000_std=e1000_pred_err,
            e1000_std_conf=(e1000 - conf_lo_1k) / 1000,
        )

        axis.errorbar(
            x_val, y_val, xerr=x_err, yerr=y_err, fmt="d", color=colour
        )
        axis.plot(
            x_query, ghcfit.poly1(x_query, *popt), linestyle="-", color=colour
        )
        band_lines = (
            (conf_lo, ":"),
            (pred_lo, "-."),
            (conf_hi, ":"),
            (pred_hi, "-."),
        )
        for band, style in band_lines:
            axis.plot(x_query, band, linestyle=style, color=colour)
        axis.set_xlabel("Normal Stress (Pa)")
        label = (
            "E-Modulus @ 1 kPa=%skPa\nincreases by\n%s Pa per Pa $\\sigma_N$"
            % (
                ghcfit.sign_str(e1000_kpa, e1000_pred_err),
                ghcfit.sign_str(popt[0], perr[0]),
            )
        )
        axis.annotate(
            label,
            (0.025, 0.975),
            xycoords="axes fraction",
            verticalalignment="top",
            color=colour,
        )

    # Ratio of static to peak modulus, with propagated uncertainty.
    static_fit = curve_fits["static"]
    peak_fit = curve_fits["peak"]
    static_val = unc.ufloat(static_fit["e1000"], static_fit["e1000_std"])
    peak_val = unc.ufloat(peak_fit["e1000"], peak_fit["e1000_std"])
    factor = static_val / peak_val
    axes[1].annotate(
        "%s-fold increase from peak to static" % ghcfit.sign_str(factor),
        (0.025, 0.75),
        xycoords="axes fraction",
    )

    legend_entries = [
        Line2D([], [], color="C0", label="Peak"),
        Line2D([], [], color="C1", label="Static"),
        Line2D([], [], linestyle=":", color="k", label="Conf. Band (95%)"),
        Line2D([], [], linestyle="-.", color="k", label="Pred. Band (95%)"),
    ]
    axes[0].set_ylabel("Elastic Modulus (Pa)")
    axes[1].legend(handles=legend_entries, loc="upper right", fontsize="small")
    fig.suptitle(proj)
    axes[0].set_xlim(left=0)
    axes[0].set_ylim(0, 500000)
    fig.tight_layout()
    for extension in (".png", ".pdf"):
        fig.savefig(os.path.join(output_folder, proj + "_EMod" + extension))
    if show:
        plt.show()
    plt.close(fig)
    return curve_fits


def get_moduli(picks: dict, proj: str, data_folder: os.PathLike):
    """
    Fits an elastic modulus to the peak and static loading regions of every
    measurement file (``.asc``, ``.dat``, ``.tdms``) in a project folder.

    ``picks`` is iterated in parallel with the measurement files, and each
    item is indexed with ``"peak"`` and ``"static"``. (It appears to be a
    list of dicts despite the annotation; verify against the caller.)

    Returns two dicts keyed by ``"peak"`` and ``"static"``:
    the fitted moduli (slope of the linear fit, as ufloats) and the mean
    normal stress of each file with twice its standard deviation as
    uncertainty.

    Raises FileNotFoundError if the project folder does not exist.
    """
    proj_folder = os.path.join(data_folder, proj)
    if not os.path.exists(proj_folder):
        raise FileNotFoundError("Folder for project '%s' not found" % proj)

    region_names = ("peak", "static")
    all_moduli = {name: [] for name in region_names}
    all_normstress = {name: [] for name in region_names}

    suffixes = (".asc", ".dat", ".tdms")
    # NOTE(review): os.listdir returns entries in arbitrary order, yet files
    # are paired with `picks` by position below; confirm the picks were stored
    # in this same order (sorting both sides would make this deterministic).
    measurement_files = [
        entry for entry in os.listdir(proj_folder) if entry.endswith(suffixes)
    ]

    for fpath, pick in zip(measurement_files, picks):
        regions = get_load_regions(proj_folder, fpath, pick)
        for name in region_names:
            bounds = regions[name]
            strain, stress, _, mean_std = get_data(
                proj_folder, fpath, pick[name], regions["offset"]
            )
            # The stored start/stop indices refer to the arrays from get_data.
            fit_slice = slice(bounds["start"], bounds["stop"])
            popt, perr, _ = ghcfit.normal_fit(
                strain[fit_slice], stress[fit_slice], getcov=True
            )
            all_moduli[name].append(unc.ufloat(popt[0], perr[0]))
            all_normstress[name].append(
                unc.ufloat(mean_std[0], 2 * mean_std[1])
            )
    return all_moduli, all_normstress


def get_load_regions(
    proj_folder: os.PathLike, fpath: os.PathLike, picks: dict
):
    """
    Returns the regions used for fitting the elastic moduli of one file.

    Regions are cached next to the data file as ``<name>.json``. The cache is
    used unless it is missing, cannot be decoded, or the module-level
    ``REPICK`` flag is set. In those cases the regions are picked
    interactively with ``pick_load_regions`` and the cache file is
    (re)written. A successfully read cache is returned without being
    rewritten.
    """
    fname, _ = os.path.splitext(fpath)
    full_path = os.path.join(proj_folder, fname + ".json")
    if os.path.exists(full_path) and not REPICK:
        try:
            with open(full_path, "rt") as json_file:
                return json.load(json_file)
        except json.decoder.JSONDecodeError:
            pass  # Corrupt cache: fall through and pick the regions again.
    # The cache handle is closed before the (interactive) picking starts.
    regions = pick_load_regions(proj_folder, fpath, picks)
    with open(full_path, "wt") as json_file:
        json.dump(regions, json_file)
    return regions


def pick_load_regions(
    proj_folder: os.PathLike, fpath: os.PathLike, picks: dict
):
    """
    Interactively picks the peak and the static (reload) fitting region of one
    file. The result also stores the ``offset`` (number of samples before
    each pick that are loaded), so the region indices can be reproduced
    later.
    """
    samples_before_pick = 150
    return {
        "peak": pick_load_region(
            proj_folder, fpath, picks["peak"], samples_before_pick
        ),
        "static": pick_load_region(
            proj_folder, fpath, picks["static"], samples_before_pick
        ),
        "offset": samples_before_pick,
    }


def pick_load_region(
    proj_folder: os.PathLike, fpath: os.PathLike, pick: int, offset: int
):
    """
    Allows the user to interactively pick the region for fitting the elastic
    modulus.

    The loading curve around ``pick`` (as cut by ``get_data``) is plotted
    together with the default region suggested by ``get_data``. The first two
    clicks define the region by their strain values, in either order. If
    fewer than two points are clicked, the default region is kept.

    Returns
    -------
    dict
        ``{"start": int, "stop": int}``: indices into the strain/stress arrays
        returned by ``get_data`` for the same ``pick`` and ``offset``.
    """
    strain, stress, slc, _ = get_data(proj_folder, fpath, pick, offset)
    fig, ax = plt.subplots()
    ax.plot(strain, stress, "s-")
    ax.plot(strain[slc], stress[slc], "s")
    _, folder = os.path.split(proj_folder)
    ax.set_title(folder)
    # showMaximized only exists on Qt-based figure windows; on other backends
    # (or without a window) just keep the default size instead of crashing.
    manager = plt.get_current_fig_manager()
    try:
        manager.window.showMaximized()
    except AttributeError:
        pass
    clicks = plt.ginput(n=-1, show_clicks=True)
    plt.close(fig)
    if len(clicks) < 2:
        start, stop = slc.start, slc.stop
    else:
        # Sort so that clicking right-to-left still yields a non-empty slice.
        start, stop = sorted(_strain_index(strain, x) for x, _ in clicks[:2])
    return {"start": start, "stop": stop}


def _strain_index(strain, x):
    """Index of the first sample with strain >= x (len(strain) if none)."""
    hits = np.flatnonzero(strain >= x)
    return int(hits[0]) if hits.size else len(strain)


def get_data(
    proj_folder: os.PathLike, fpath: os.PathLike, pick: int, offset: int
):
    """
    Loads and preprocesses data ready for fitting.

    The window runs from ``offset`` samples before ``pick`` up to (but
    excluding) 10 samples after it. It is clipped to the available data.

    Returns
    -------
    strain, stress : numpy.ndarray
        Displacement divided by 40, and shear stress, inside the window.
    slc : slice
        Default fitting region. It runs from the first sample reaching 10 % of
        the window's stress range to the first sample reaching 90 % of it.
    normstress : tuple
        Mean and standard deviation of the normal stress over the whole file.

    Raises
    ------
    gherrors.GranularHealingError
        Via ``check_min_displacement`` if the data is not usable.
    ValueError
        If the window is empty (e.g. ``pick`` lies beyond the data), because
        numpy min/max reductions reject empty arrays.
    """
    data = load_data(proj_folder, fpath)
    check_min_displacement(data, max_disp=0.5)  # checking if data is usable
    # Only the start needs clamping: a negative start would index from the
    # end. Slicing already clips an overlong end to the array length.
    window = slice(max(pick - offset, 0), pick + 10)
    # NOTE(review): 40 presumably is the reference length that turns
    # displacement into strain — confirm the unit.
    strain = data["displacement"][window] / 40
    stress = data["shearstress"][window]
    stress_min = np.min(stress)
    rng = np.max(stress) - stress_min
    low_end = stress_min + 0.1 * rng
    hig_end = stress_min + 0.9 * rng
    slc = slice(
        int(np.flatnonzero(stress >= low_end)[0]),
        int(np.flatnonzero(stress >= hig_end)[0]),
    )
    return (
        strain,
        stress,
        slc,
        (np.mean(data["normalstress"]), np.std(data["normalstress"])),
    )


def check_min_displacement(data, max_disp=0.5):
    """
    Assures that there is not more than `max_disp` displacement between
    consecutive data points. If the displacement is too large, then we have
    too few data points to estimate the Young's Moduli from the curves.

    The displacement per sample is estimated as the mean velocity times the
    sampling interval (the spacing of the first two time stamps). This is
    the same product ``load_data`` integrates to obtain the displacement.

    Raises
    ------
    gherrors.GranularHealingError
        If the estimated displacement per sample exceeds ``max_disp``.
    """
    mean_vel = np.mean(data["velocity"])
    dt = data["time"][1] - data["time"][0]
    avg_displ = mean_vel * dt
    if avg_displ > max_disp:
        raise gherrors.GranularHealingError(
            "Displacement between data points too large. Can't estimate E-Mod."
        )


def plot_stress_strain_fit(
    strain, stress, slc, popt, pcov, perr, popt2=None, show=False
):
    """
    Plots the intermediate result for fitting the shear modulus to the loading
    curve. The plot shows the full stress-strain curve, the linear fit over
    ``slc`` with its confidence band, and the fitted slope as an annotation.

    Parameters
    ----------
    strain, stress : numpy.ndarray
        Loading curve.
    slc : slice
        Region that was fitted.
    popt, pcov, perr
        Fit parameters, covariance and standard errors; ``popt[0]`` is
        reported as the shear modulus.
    popt2 : optional
        Currently unused; kept for backward compatibility.
    show : bool, optional
        If True, display the figure before closing it. By default the figure
        is closed immediately without being displayed.
    """
    fig, ax = plt.subplots()
    ax.plot(strain, stress)

    fit_strain = strain[slc]
    ax.plot(fit_strain, ghcfit.poly1(fit_strain, *popt))
    # Unpack confband as (lower, upper), the order plot_all_moduli uses.
    lcf, ucf = ghcfit.confband(
        fit_strain, fit_strain, stress[slc], popt, pcov, ghcfit.poly1
    )
    ax.plot(fit_strain, ucf, color="C1", linestyle=":")
    ax.plot(fit_strain, lcf, color="C1", linestyle=":")

    ax.annotate(
        "Shear modulus: %s Pa" % ghcfit.sign_str(popt[0], perr[0]),
        (0.95, 0.05),
        color="C1",
        xycoords="axes fraction",
        horizontalalignment="right",
    )
    if show:
        plt.show()
    plt.close(fig)


def load_data(proj_folder, fpath):
    """
    Reads a measurement file with the rstevaluation reader that matches its
    extension (``.asc``, ``.dat`` or ``.tdms``) and derives extra columns.

    The reader's dict is modified and returned, with these changes:

    - time is shifted to start at zero;
    - negative velocities are set to zero in place;
    - if a ``"shearforce"`` entry exists, shear stress is computed with
      ``force_2_stress`` and normal stress as normal force / 0.022619 m^2;
    - displacement is the cumulative sum of velocity times ``time[1]``.
    """
    readers = {
        ".asc": rstfiles.readasc,
        ".dat": rstfiles.readdat,
        ".tdms": rstfiles.readtdms,
    }
    extension = os.path.splitext(fpath)[1]
    data = readers[extension](proj_folder, fpath)

    # Corrections
    start_time = data["time"][0]
    data["time"] = data["time"] - start_time
    velocity = data["velocity"]
    velocity[velocity < 0] = 0

    # Derived quantities
    if "shearforce" in data:
        data["shearstress"] = force_2_stress(data["shearforce"])
        data["normalstress"] = data["normalforce"] / 0.022619  # m^2
    time_step = data["time"][1]
    data["displacement"] = np.cumsum(time_step * velocity)
    return data


def read_all_data(fpath: os.PathLike):
    """
    Returns the parsed content of the JSON file at ``fpath``, which holds the
    picked data of all projects.
    """
    with open(fpath, "rt") as handle:
        return json.load(handle)


def force_2_stress(force):
    """
    Converts the force measured at the sensors into shear stress (Pa) using
    the lever relation given for the testing machine in ASTM Standard D6773:
    stress = force * rs / (rm * ad).
    """
    rs, rm = 0.125, 0.0776  # lever radii, m
    ad = 0.022619  # area, m^2
    return force * rs / (rm * ad)


if __name__ == "__main__":
    # Run the full evaluation only when executed as a script, not on import.
    main()
