#!/usr/bin/env python3
# This file is part of granularhealing.

# granularhealing is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# granularhealing is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with granularhealing. If not, see <https://www.gnu.org/licenses/>.
import json
import os

import granularhealing.angles_unc as ghangl_unc
import granularhealing.cfit as ghfit
import granularhealing.groups as ghgroups
import granularhealing.plots as ghplots
import matplotlib.pyplot as plt
import numpy as np
import uncertainties as unc
from tqdm import tqdm
from uncertainties import unumpy as unp


def main():
    """
    Load both data sets, reduce them per sample and plot the results.

    Entries that appear in only one of the two reduced dictionaries are
    dropped, so that plotting only sees keys present in both.
    """
    save_folder = r"\Materials\angle_forces"

    all_data = get_all_data()
    shs_data = get_heal_data()
    heal_data = get_healing_stat_per_sample(shs_data)
    rst_data = get_fricprops_stat_per_sample(all_data)

    # Determine the unmatched keys up front, before anything is deleted,
    # so both comparisons are made against the unmodified dictionaries.
    only_in_heal = [name for name in heal_data if name not in rst_data]
    only_in_rst = [name for name in rst_data if name not in heal_data]
    for name in only_in_heal:
        del heal_data[name]
    for name in only_in_rst:
        del rst_data[name]

    plot_opt_angles(heal_data, rst_data, save_folder)


def plot_opt_angles(heal_medians, rst_medians, save_folder):
    """
    Create and save one two-panel figure per key of ``heal_medians``.

    The left panel shows several angle curves over time together with the
    shaded "lockup" area, the right panel shows the ratio of the two forces
    returned by ``force_mulugeta``. Every figure is written twice into
    ``save_folder``: once without file extension and once as ``.pdf``.

    Parameters
    ----------
    heal_medians : dict
        Maps a key to the value used as exponent of time
        (``time ** heal_medians[key]``).
    rst_medians : dict
        Maps the same keys to a dict providing at least ``"mur"``,
        ``"mup"`` and ``"mat"``; ``"mat"`` must be one of the material
        names listed in the colour/density tables below.
    save_folder : str
        Directory the figures are saved to.
    """
    colours = {
        "quartz sand": "C0",
        "corundum sand": "C1",
        "feldspar sand": "C2",
        "garnet sand": "C3",
        "glass beads": "C4",
        "zircon sand": "C5",
        "foam glass": "C6",
    }
    densities = {
        "quartz sand": unc.ufloat(1424, 145),
        "corundum sand": unc.ufloat(1797, 218),
        "feldspar sand": unc.ufloat(1160, 72),
        "garnet sand": unc.ufloat(2087, 70),
        "glass beads": unc.ufloat(1372, 91),
        "zircon sand": unc.ufloat(2569, 70),
        "foam glass": unc.ufloat(377, 163),
    }

    time = np.logspace(0, 6)
    # Positions along the time axis at which the force ratio is annotated.
    third = int(len(time) / 3)
    annotated = (0, third, 2 * third)

    for mat in tqdm(heal_medians.keys()):
        fig, axes = plt.subplots(ncols=2, sharex=True, figsize=(10, 6))
        angle_ax, force_ax = axes[0], axes[1]
        props = rst_medians[mat]

        # Friction coefficients: time dependent (reactivation) and constant.
        mu_react = props["mur"] * time ** heal_medians[mat]
        mu_static = props["mup"] * np.ones_like(mu_react)

        opt_react = ghangl_unc.opt_angle(mu_react)
        lock_react = ghangl_unc.lock_angle(mu_react)
        opt_static = ghangl_unc.opt_angle(mu_static)
        angle_pre = np.ones_like(mu_react) * (
            90 - ghangl_unc.opt_angle(props["mup"])
        )

        misorientation = angle_pre - lock_react
        opt_difference = angle_pre - opt_react

        # Fit the nominal lockup angle against log10(time); the first fit
        # parameter and its error become the panel title.
        popt, perr = ghfit.normal_fit(
            np.log10(time),
            unp.nominal_values(lock_react),
        )
        angle_ax.set_title(ghfit.sign_str(popt[0], perr[0]))

        mat_colour = colours[props["mat"]]
        # (data, label, colour, line style) in drawing order.
        angle_curves = (
            (angle_pre, "Angle of preexisting fault", mat_colour, "--"),
            (opt_static, "Angle of new fault", mat_colour, "-"),
            (opt_react, "Optimal angle for reactivation", "k", "--"),
            (misorientation, "misorientation", "k", "-"),
            (opt_difference, "difference optimal angles", "k", ":"),
        )
        for curve, label, colour, style in angle_curves:
            add_erb(angle_ax, time, curve, lbl=label, col=colour, mark=style)

        angle_ax.fill_between(
            time,
            unp.nominal_values(lock_react),
            90 * np.ones_like(time),
            color="k",
            alpha=0.25,
            label="Lockup region for reactivated faults",
            zorder=-1,
        )

        mat_rho = densities[props["mat"]]
        force_pre = force_mulugeta(mu_react, angle_pre, 1, mat_rho)
        force_new = force_mulugeta(mu_static, opt_static, 1, mat_rho)
        force_ratio = force_pre / force_new

        add_erb(
            force_ax,
            time,
            force_ratio,
            col=mat_colour,
            mark="--",
            lbl="Reactivated fault",
        )
        for idx in annotated:
            force_ax.annotate(
                str(force_ratio[idx]),
                (time[idx], unp.nominal_values(force_ratio[idx])),
            )

        angle_ax.set_xscale("log")
        for ax in axes:
            ax.set_xlabel("Time (s)")
            ax.set_xlim(10**0, 10**6)
        angle_ax.set_ylabel("Angle (°)")
        angle_ax.set_ylim(0, 90)

        force_ax.set_ylabel("Shear Force per unit area ($\\frac{N}{m}$)")
        # Make the force panel symmetric about zero, based on the limits
        # matplotlib chose for the data drawn so far.
        rng = np.max(np.abs(force_ax.get_ylim()))
        force_ax.axhline(0, color="k")
        force_ax.fill_between(
            time,
            np.zeros_like(time),
            np.full_like(time, -rng),
            color="k",
            alpha=0.1,
            zorder=-1,
        )
        force_ax.annotate(
            "Lockup Region",
            (0.5, 0.05),
            xycoords="axes fraction",
            horizontalalignment="center",
        )
        force_ax.set_ylim(-rng, rng)

        fig.suptitle(mat)
        ghplots.numerate_axes(fig)
        fig.tight_layout()
        for file_tail in ("_angle_forces", "_angle_forces.pdf"):
            fig.savefig(os.path.join(save_folder, mat + file_tail))
        plt.close(fig)


def pm_log(y):
    """
    Return a sign-preserving base-10 logarithm of every element of ``y``.

    Elements below zero are mapped to ``-log10(-value)``, all other
    elements to ``log10(value)``. Each result is rebuilt as a new
    ``ufloat`` from the nominal value and standard deviation of the
    logarithm. The output container is created with ``np.zeros_like(y)``.
    """
    out = np.zeros_like(y)
    for idx, value in enumerate(y):
        negative = value < 0
        logged = unp.log10(-value if negative else value)
        rebuilt = unc.ufloat(
            unp.nominal_values(logged),
            unp.std_devs(logged),
        )
        out[idx] = -rebuilt if negative else rebuilt
    return out


def add_erb(ax, x, y, col="", mark="-", lbl="", alpha=1):
    """
    Draw ``y`` against ``x`` as a line with a shaded uncertainty band.

    The line follows the nominal values of ``y``; the band spans nominal
    value plus/minus standard deviation and is drawn at a quarter of the
    line's opacity.

    Parameters
    ----------
    ax : axes object providing ``plot`` and ``fill_between``
    x : x values
    y : values carrying uncertainties (read via ``unumpy``)
    col : colour used for both line and band
    mark : line style of the line
    lbl : legend label of the line
    alpha : opacity of the line
    """
    centre = unp.nominal_values(y)
    spread = unp.std_devs(y)
    ax.plot(x, centre, label=lbl, color=col, linestyle=mark, alpha=alpha)
    ax.fill_between(
        x,
        centre + spread,
        centre - spread,
        color=col,
        alpha=alpha * 0.25,
    )


def force_mulugeta(mu, angle_in, h, rho):
    """
    Evaluate ``rho * 9.81 * h**2 * cot(a) / 2 * (mu + tan(a)) / (1 - mu * tan(a))``.

    ``a`` is ``angle_in`` after conversion with ``ghangl_unc.deg2rad``.
    The constant 9.81 is presumably the gravitational acceleration in
    m/s^2 -- confirm against the formula's source.

    Wherever the result switches from a positive value to a negative one,
    the first negative element is replaced by NaN.

    Parameters
    ----------
    mu : array of friction coefficients
    angle_in : array of angles, passed to ``ghangl_unc.deg2rad``
    h : height entering the formula squared
    rho : density entering the formula linearly

    Returns
    -------
    Array with one value per element of the inputs.
    """
    angle = ghangl_unc.deg2rad(angle_in)
    tan_angle = unp.tan(angle)
    weight_term = (rho * 9.81 * h**2 * ghangl_unc.cot(angle)) / 2
    friction_term = (mu + tan_angle) / (1 - mu * tan_angle)
    result = weight_term * friction_term
    # Starting at 1 means every element has a predecessor to compare with.
    for idx in range(1, len(result)):
        previous, current = result[idx - 1], result[idx]
        if previous > 0 and current < 0:
            result[idx] = np.nan
    return result


def get_all_data(data_path=r"\Materials\Data\results\000_all_data.json"):
    """
    Load the combined data set from a JSON file.

    Parameters
    ----------
    data_path : str or path-like, optional
        Location of the JSON file. Defaults to the path this function has
        always read, so calling it without arguments behaves as before.

    Returns
    -------
    The decoded JSON content (whatever ``json.load`` produces for the file).

    Raises
    ------
    OSError
        If the file cannot be opened.
    json.JSONDecodeError
        If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8: without ``encoding`` the platform default
    # is used, which is not UTF-8 everywhere and can break on non-ASCII text.
    with open(data_path, "rt", encoding="utf-8") as json_file:
        return json.load(json_file)


def get_heal_data(data_path=r"\Materials\Data\SHS-Data\all_shs_data.json"):
    """
    Load the SHS data set from a JSON file.

    Parameters
    ----------
    data_path : str or path-like, optional
        Location of the JSON file. Defaults to the path this function has
        always read, so calling it without arguments behaves as before.

    Returns
    -------
    The decoded JSON content (whatever ``json.load`` produces for the file).

    Raises
    ------
    OSError
        If the file cannot be opened.
    json.JSONDecodeError
        If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8: without ``encoding`` the platform default
    # is used, which is not UTF-8 everywhere and can break on non-ASCII text.
    with open(data_path, "rt", encoding="utf-8") as json_file:
        return json.load(json_file)


def get_healing_stat_per_sample(shs_data, stat_fnc=np.mean):
    """
    Reduce the healing rates of the shs data to one value per material.

    Every experiment in ``shs_data["exp_names"]`` is assigned to a material
    through the ``"annot_mat"`` mapping of the group dictionary. Its healing
    rate and error are combined into a ``ufloat`` and collected per
    material. ``stat_fnc`` is then applied to each non-empty collection;
    materials without any experiment are left out of the result.
    """
    annot_mat = ghgroups.get_group_dict()["annot_mat"]
    collected = {mat: [] for mat in annot_mat.values()}

    for idx, exp_name in enumerate(shs_data["exp_names"]):
        bucket = collected[annot_mat[exp_name]]
        rate = unc.ufloat(
            shs_data["healing_rates"][idx],
            shs_data["healing_rates_err"][idx],
        )
        bucket.append(rate)

    return {mat: stat_fnc(rates) for mat, rates in collected.items() if rates}


def get_fricprops_stat_per_sample(rst_data, stat_fnc=np.mean):
    """
    Reduce the rst data to one value per property and material.

    For each sample in ``rst_data["name"]`` the values ``"mup"``, ``"mur"``,
    ``"cp"`` and ``"cr"`` (paired with their errors ``"mupe"``, ``"mure"``,
    ``"cpe"``, ``"cre"``) are collected under the sample's ``"annot_mat"``
    entry of the group dictionary, together with a material name stored
    under ``"mat"``. ``stat_fnc`` is applied to every collected list;
    entries that never received a material name are removed.
    """
    picks = ("mup", "mur", "cp", "cr")
    value_and_error = (
        ("mup", "mupe"),
        ("mur", "mure"),
        ("cp", "cpe"),
        ("cr", "cre"),
    )

    group_dict = ghgroups.get_group_dict()
    output = {
        mat: {"mup": [], "mur": [], "cp": [], "cr": [], "mat": ""}
        for mat in group_dict["annot_mat"].values()
    }
    in_data = {
        value_key: unp.uarray(rst_data[value_key], rst_data[error_key])
        for value_key, error_key in value_and_error
    }

    for idx, name in enumerate(rst_data["name"]):
        entry = output[group_dict["annot_mat"][name]]
        for pick in picks:
            entry[pick].append(in_data[pick][idx])
        entry["mat"] = group_dict["mats"][group_dict["mat_group"][name]]

    # Iterate over a snapshot of the keys because entries may be deleted.
    for key in list(output):
        entry = output[key]
        if not entry["mat"]:
            del output[key]
            continue
        for pick in picks:
            entry[pick] = stat_fnc(entry[pick])
    return output


# Run the analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()
