#!/usr/bin/env python3
# This file is part of granularhealing.

# granularhealing is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# granularhealing is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with granularhealing. If not, see <https://www.gnu.org/licenses/>.
"""
Analysis functions for granular healing.

 - Getting means of data.
 - Curve fitting.
 - Statistics.

"""

import numpy as np
import uncertainties as unc
import uncertainties.unumpy as unp

import granularhealing.groups as ghgroups


def get_vel(exp):
    """
    Return the mean velocity over the leading samples of an experiment.

    Only the first 100 entries of ``exp["data"]["velocity"]`` are averaged.
    A shorter series is averaged over however many entries it has.
    """
    leading = exp["data"]["velocity"][:100]
    return np.mean(leading)


def get_normload(exp):
    """
    Return the normal load from the first 100 ``normal`` samples of ``exp``.

    The mean of ``exp["data"]["normal"][:100]`` is returned unchanged when
    ``exp["data"]["device"]`` is ``"spring-slider"`` (str or bytes).
    Otherwise, including when no ``device`` entry exists, the mean is
    divided by 0.022619.
    """
    avg_normal = np.mean(exp["data"]["normal"][:100])
    # NOTE(review): presumably a sample contact area used to turn force
    # into stress -- confirm units before relying on this.
    divisor = 0.022619

    try:
        raw_device = exp["data"]["device"]
    except KeyError:
        # Files without a device tag are treated like non-spring-slider runs.
        return avg_normal / divisor

    # The device tag may be stored as bytes; fall back to the raw value
    # when it has no ``decode`` method (e.g. it is already a str).
    try:
        device_name = raw_device.decode()
    except AttributeError:
        device_name = raw_device

    if device_name != "spring-slider":
        return avg_normal / divisor
    return avg_normal


def get_healing_stat_per_mat(shs_data, stat_fnc=np.mean):
    """
    Apply ``stat_fnc`` to the healing rates of every material.

    Each experiment in ``shs_data["exp_names"]`` is assigned to the material
    ``mats[mat_group[name]]`` from the group dictionary. Its healing rate and
    error are paired into a ``ufloat``. Every material listed in ``mats``
    appears in the result, even if no experiment maps to it (``stat_fnc`` is
    then called on an empty list).
    """
    groups = ghgroups.get_group_dict()
    by_mat = {mat_name: [] for mat_name in groups["mats"]}

    for idx, exp_name in enumerate(shs_data["exp_names"]):
        mat_name = groups["mats"][groups["mat_group"][exp_name]]
        rate = unc.ufloat(
            shs_data["healing_rates"][idx],
            shs_data["healing_rates_err"][idx],
        )
        by_mat[mat_name].append(rate)

    return {mat_name: stat_fnc(rates) for mat_name, rates in by_mat.items()}


def get_fricprops_stat_per_mat(rst_data, stat_fnc=np.mean):
    """
    Apply ``stat_fnc`` to the frictional properties of every material.

    The properties ``mup``, ``mur``, ``cp`` and ``cr`` are paired with their
    error columns (``mupe``, ``mure``, ``cpe``, ``cre``) into uncertainty
    arrays. They are then grouped by the material ``mats[mat_group[name]]``
    of each entry in ``rst_data["name"]``. The result maps each material to
    a dict of property -> statistic. Materials without data still appear
    (``stat_fnc`` receives an empty list).
    """
    prop_err_pairs = (
        ("mup", "mupe"),
        ("mur", "mure"),
        ("cp", "cpe"),
        ("cr", "cre"),
    )
    props = [prop for prop, _ in prop_err_pairs]

    groups = ghgroups.get_group_dict()
    per_mat = {
        mat_name: {prop: [] for prop in props} for mat_name in groups["mats"]
    }
    measured = {
        prop: unp.uarray(rst_data[prop], rst_data[err_key])
        for prop, err_key in prop_err_pairs
    }

    for idx, sample_name in enumerate(rst_data["name"]):
        bucket = per_mat[groups["mats"][groups["mat_group"][sample_name]]]
        for prop in props:
            bucket[prop].append(measured[prop][idx])

    for stats in per_mat.values():
        for prop in props:
            stats[prop] = stat_fnc(stats[prop])
    return per_mat


def get_healing_stat_per_sample(shs_data, stat_fnc=np.mean):
    """
    Apply ``stat_fnc`` to the healing rates of every annotated sample.

    Experiments are grouped by ``annot_mat[name]`` from the group
    dictionary, with each healing rate and error paired into a ``ufloat``.
    Samples that received no experiments are left out of the result.
    """
    groups = ghgroups.get_group_dict()
    sample_of = groups["annot_mat"]
    collected = {sample: [] for sample in sample_of.values()}

    for idx, exp_name in enumerate(shs_data["exp_names"]):
        collected[sample_of[exp_name]].append(
            unc.ufloat(
                shs_data["healing_rates"][idx],
                shs_data["healing_rates_err"][idx],
            )
        )

    # Keep only samples that actually collected data.
    return {
        sample: stat_fnc(values)
        for sample, values in collected.items()
        if values
    }


def get_more_fricprops_stat_per_sample(rst_data, stat_fnc=np.mean):
    """
    Apply ``stat_fnc`` to the frictional properties of every annotated sample.

    Entries of ``rst_data["name"]`` are grouped by ``annot_mat[name]``. Each
    sample gets the properties ``mup``, ``mur``, ``mud``, ``cp``, ``cr`` and
    ``cd``, each paired with its error column into an uncertainty value.
    The sample's material name ``mats[mat_group[name]]`` is stored under
    ``"mat"``. Samples whose ``"mat"`` stays empty (no data seen) are dropped.
    """
    prop_err_pairs = (
        ("mup", "mupe"),
        ("mur", "mure"),
        ("mud", "mude"),
        ("cp", "cpe"),
        ("cr", "cre"),
        ("cd", "cde"),
    )
    props = [prop for prop, _ in prop_err_pairs]

    groups = ghgroups.get_group_dict()
    per_sample = {}
    for sample in groups["annot_mat"].values():
        entry = {prop: [] for prop in props}
        entry["mat"] = ""
        per_sample[sample] = entry

    measured = {
        prop: unp.uarray(rst_data[prop], rst_data[err_key])
        for prop, err_key in prop_err_pairs
    }

    for idx, sample_name in enumerate(rst_data["name"]):
        entry = per_sample[groups["annot_mat"][sample_name]]
        for prop in props:
            entry[prop].append(measured[prop][idx])
        entry["mat"] = groups["mats"][groups["mat_group"][sample_name]]

    result = {}
    for sample, entry in per_sample.items():
        if not entry["mat"]:
            continue
        for prop in props:
            entry[prop] = stat_fnc(entry[prop])
        result[sample] = entry
    return result


def get_more_healing_stat_per_sample(shs_data, stat_fnc=np.mean):
    """
    Apply ``stat_fnc`` to several SHS quantities for every annotated sample.

    Experiments in ``shs_data["exp_names"]`` are grouped by
    ``annot_mat[name]``. Four quantities are collected per sample, each
    paired with its error column into an uncertainty value:
    ``comp_rates``, ``healing_rates``, ``str_rates`` and ``rel_mods``.
    The sample's material name ``mats[mat_group[name]]`` is stored under
    ``"mat"``. Samples whose ``"mat"`` stays empty (no data seen) are dropped.
    """
    qty_err_pairs = (
        ("comp_rates", "comp_rates_e"),
        ("healing_rates", "healing_rates_err"),
        ("str_rates", "str_rates_e"),
        ("rel_mods", "rel_mods_e"),
    )
    quantities = [qty for qty, _ in qty_err_pairs]

    groups = ghgroups.get_group_dict()
    per_sample = {}
    for sample in groups["annot_mat"].values():
        entry = {qty: [] for qty in quantities}
        entry["mat"] = ""
        per_sample[sample] = entry

    measured = {
        qty: unp.uarray(shs_data[qty], shs_data[err_key])
        for qty, err_key in qty_err_pairs
    }

    for idx, exp_name in enumerate(shs_data["exp_names"]):
        entry = per_sample[groups["annot_mat"][exp_name]]
        for qty in quantities:
            entry[qty].append(measured[qty][idx])
        entry["mat"] = groups["mats"][groups["mat_group"][exp_name]]

    result = {}
    for sample, entry in per_sample.items():
        if not entry["mat"]:
            continue
        for qty in quantities:
            entry[qty] = stat_fnc(entry[qty])
        result[sample] = entry
    return result
