#!/usr/bin/env python3
# This file is part of granularhealing.

# granularhealing is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# granularhealing is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with granularhealing. If not, see <https://www.gnu.org/licenses/>.

"""
Functions for creating publication ready figures. The naming convention of the
functions follows the order in the publication.
"""

import json
import os

import animation
import numpy as np
from tqdm import tqdm

import granularhealing.analysis as ghanalysis
import granularhealing.config as ghconfig
import granularhealing.files as ghfiles
import granularhealing.plots as ghplots


@animation.wait("bar", color="green")
def figure1_shs_timeseries(base_folder):
    """
    Plots a time series of a slide-hold-slide test
    """
    data_folder = r"Data\SHS_Data\19_Glasperlen70-110GFZ_1kPa"
    exp_folder = os.path.join(base_folder, data_folder)
    file_list = [
        os.path.join(exp_folder, f)
        for f in os.listdir(exp_folder)
        if f.endswith(".tdms")
    ]

    exp = ghfiles.open_tdms(file_list[0])
    time = exp["data"]["time"]
    fric = exp["data"]["shear"] / exp["data"]["normal"]
    slc = slice(135000, 160000)
    slc_fric = np.concatenate(
        (fric[14500:70000], fric[85000:139000], fric[157000:-1])
    )

    fig = ghplots.plot_example_shs(time, fric, slc, slc_fric)
    fig.tight_layout()
    fig.savefig(os.path.join(base_folder, r"Output\Figure1_SHS_TimeSeries"))
    fig.savefig(
        os.path.join(base_folder, r"Output\Figure1_SHS_TimeSeries.pdf")
    )


def figure4_5_healing_rates(
    base_folder,
    *,
    repick=False,
    create_shs_simple=True,
    create_overviews=True,
    create_hold_compaction=True,
    get_reloading_modulus=True,
    create_extra_overviews=False,
):
    """
    Automatically pick SHS files and create all overview plots.

    Every subfolder of ``<base_folder>/Data/SHS_Data`` whose name ends with
    ``kPa`` is treated as one experiment and processed in sorted order. For
    each experiment, ``<exp_name>.ini`` is loaded from the experiment folder
    and the enabled per-experiment plots are created. When
    ``create_overviews`` is set, the per-experiment results are collected,
    combined into Figure 4a (healing rates), Figure 4b (compaction rates)
    and Figure 5 (compaction vs. healing) in ``<base_folder>/Output``, and
    dumped to ``<base_folder>/Data/SHS_Data/all_shs_data.json``.

    Parameters
    ----------
    base_folder : str
        Root folder containing the ``Data`` and ``Output`` subfolders.
    repick : bool, optional
        Written as ``overwrite`` into the ``main`` section of every
        experiment config (presumably forces previously picked values to be
        recomputed -- confirm in ``granularhealing.plots``).
    create_shs_simple : bool, optional
        Run ``ghplots.plot_shs_simple`` (healing rates, stable friction).
    create_overviews : bool, optional
        Collect the per-experiment results, create the overview and
        correlation plots, and write the JSON summary.
    create_hold_compaction : bool, optional
        Run ``ghplots.plot_hold_compact`` (stressing and compaction rates).
    get_reloading_modulus : bool, optional
        Run ``ghplots.plot_reloading_modulus``.
    create_extra_overviews : bool, optional
        Additionally create the overview/correlation plots that are not part
        of the publication (stable friction, stressing rate, shear modulus).
        Only effective together with ``create_overviews``. Default ``False``.
    """
    # Path components instead of a backslash literal -> portable across OSes.
    selected_folder = os.path.join(base_folder, "Data", "SHS_Data")
    output_folder = os.path.join(base_folder, "Output")
    folder_list = sorted(
        os.path.join(selected_folder, f)
        for f in os.listdir(selected_folder)
        if os.path.isdir(os.path.join(selected_folder, f))
        and f.endswith("kPa")
    )

    healing_rates = []
    healing_rates_err = []
    stable_fri = []
    stable_fri_e = []
    str_rates = []
    str_rates_e = []
    comp_rates = []
    comp_rates_e = []
    rel_mod = []
    rel_mod_e = []

    for exp_folder_path in tqdm(folder_list):
        exp_name = os.path.basename(exp_folder_path)
        cfg = ghconfig.get_config(
            os.path.join(exp_folder_path, exp_name + ".ini")
        )
        cfg.set("main", "overwrite", "True" if repick else "False")

        if create_shs_simple:
            popt, perr, stbl, stbl_e = ghplots.plot_shs_simple(
                exp_folder_path, cfg
            )
            if create_overviews:
                healing_rates.append(popt)
                healing_rates_err.append(perr)
                stable_fri.append(stbl)
                stable_fri_e.append(stbl_e)

        if create_hold_compaction:
            sr, se, cr, ce = ghplots.plot_hold_compact(exp_folder_path, cfg)
            if create_overviews:
                str_rates.append(sr)
                str_rates_e.append(se)
                comp_rates.append(cr)
                comp_rates_e.append(ce)

        if get_reloading_modulus:
            rm, rme = ghplots.plot_reloading_modulus(exp_folder_path, cfg)
            if create_overviews:
                rel_mod.append(rm)
                rel_mod_e.append(rme)

    if not create_overviews:
        return

    exp_names = [os.path.basename(efp) for efp in folder_list]

    if healing_rates:
        ghplots.create_overview_plot(
            output_folder,
            exp_names,
            healing_rates,
            healing_rates_err,
            xlabel=r"Healing rates $b$ (per decade)",
            fname="Figure4a_AllHealing",
        )
    if create_extra_overviews and stable_fri:
        ghplots.create_overview_plot(
            output_folder,
            exp_names,
            stable_fri,
            stable_fri_e,
            xlabel=r"Apparent frictional sliding strength $\mu=\frac{\tau}{\sigma_N}$ ()",
            fname="AllStable",
        )
    if create_extra_overviews and str_rates:
        ghplots.create_overview_plot(
            output_folder,
            exp_names,
            str_rates,
            str_rates_e,
            xlabel=r"Stressing rate during hold (per decade)",
            fname="AllStressing",
        )
    if comp_rates:
        ghplots.create_overview_plot(
            output_folder,
            exp_names,
            comp_rates,
            comp_rates_e,
            xlabel=r"Compaction rate during hold (per decade)",
            fname="Figure4b_AllCompaction",
        )

    if comp_rates and healing_rates:
        ghplots.correlate_plot(
            output_folder,
            exp_names,
            comp_rates,
            comp_rates_e,
            healing_rates,
            healing_rates_err,
            xlabel=r"Compaction rate during hold (per decade)",
            ylabel=r"Healing rates $b$ (per decade)",
            fname="Figure5_CompactionVsHealing",
        )

    if create_extra_overviews and rel_mod:
        ghplots.create_overview_plot(
            output_folder,
            exp_names,
            rel_mod,
            rel_mod_e,
            xlabel=r"Shear modulus increase (Pa per decade)",
            fname="AllShearMod",
        )
    if create_extra_overviews and comp_rates and rel_mod:
        ghplots.correlate_plot(
            output_folder,
            exp_names,
            comp_rates,
            comp_rates_e,
            rel_mod,
            rel_mod_e,
            xlabel=r"Compaction rate during hold (per decade)",
            ylabel=r"Shear modulus increase (Pa per decade)",
            fname="CompactionVsShearMod",
        )
    if create_extra_overviews and healing_rates and rel_mod:
        ghplots.correlate_plot(
            output_folder,
            exp_names,
            healing_rates,
            healing_rates_err,
            rel_mod,
            rel_mod_e,
            xlabel=r"Healing rates $b$ (per decade)",
            ylabel=r"Shear modulus increase (Pa per decade)",
            fname="HealingVsShearMod",
        )

    all_shs_data = {
        "exp_names": exp_names,
        "healing_rates": healing_rates,
        "healing_rates_err": healing_rates_err,
        "stable_fri": stable_fri,
        "stable_fri_e": stable_fri_e,
        "str_rates": str_rates,
        "str_rates_e": str_rates_e,
        "comp_rates": comp_rates,
        "comp_rates_e": comp_rates_e,
        "rel_mods": rel_mod,
        "rel_mods_e": rel_mod_e,
    }

    with open(
        os.path.join(selected_folder, "all_shs_data.json"),
        "wt",
        encoding="utf-8",
    ) as jsonfile:
        json.dump(all_shs_data, jsonfile, indent=4, sort_keys=True)


@animation.wait("bar", color="green")
def figure6_forces(base_folder):
    """
    Compares forces needed for reactivation
    """
    rst_data = ghfiles.get_all_data(base_folder)
    shs_data = ghfiles.get_heal_data(base_folder)
    save_folder = os.path.join(base_folder, "Output")

    heal_medians = ghanalysis.get_healing_stat_per_mat(shs_data)
    rst_medians = ghanalysis.get_fricprops_stat_per_mat(rst_data)
    ghplots.plot_force_mulugeta(heal_medians, rst_medians, save_folder)


@animation.wait("bar", color="green")
def figure6_angles(base_folder):
    """Compares optimal angles for measured healing rates"""
    rst_data = ghfiles.get_all_data(base_folder)
    shs_data = ghfiles.get_heal_data(base_folder)
    save_folder = os.path.join(base_folder, "Output")

    heal_medians = ghanalysis.get_healing_stat_per_mat(shs_data)
    rst_medians = ghanalysis.get_fricprops_stat_per_mat(rst_data)
    ghplots.plot_opt_angles(heal_medians, rst_medians, save_folder)


@animation.wait("bar", color="green")
def figure3_grain_char(base_folder):
    """Overview plot of all quality scores"""
    csv_path = os.path.join(base_folder, r"Data", "quality_scores.csv")
    save_folder = os.path.join(base_folder, "Output")
    quality_scores = ghfiles.get_quality_score(csv_path)
    ghplots.plot_quality_score(quality_scores, save_folder)


@animation.wait("bar", color="green")
def figure7_qualiscore_friction(base_folder):
    """Compares all quality scores with various friction measures."""
    csv_path = os.path.join(base_folder, r"Data", "quality_scores.csv")
    save_folder = os.path.join(base_folder, "Output")

    quality_scores = ghfiles.get_more_quality_score(csv_path)
    heal_data = ghanalysis.get_more_healing_stat_per_sample(
        ghfiles.get_heal_data(base_folder)
    )
    rst_data = ghanalysis.get_more_fricprops_stat_per_sample(
        ghfiles.get_all_data(base_folder)
    )

    # Clean all non-existent quality scores from data
    qual_mats = quality_scores["annot_mat"]
    heal_mats = [k for k in heal_data.keys()]
    for h in heal_mats:
        if h not in qual_mats:
            del heal_data[h]
    rst_mats = [k for k in rst_data.keys()]
    for r in rst_mats:
        if r not in qual_mats:
            del rst_data[r]

    # Clean all non-existent heal_mats
    heal_mats = [k for k in heal_data.keys()]
    rst_mats = [k for k in rst_data.keys()]
    for r in rst_mats:
        if r not in heal_mats:
            del rst_data[r]
    slc = []
    qual_mats = quality_scores["annot_mat"]
    for ii, r in enumerate(qual_mats):
        if r in heal_mats:
            slc.append(ii)
    for k in quality_scores.keys():
        quality_scores[k] = [quality_scores[k][ii] for ii in slc]

    # Clean all non-existent rst_mats
    heal_mats = [k for k in heal_data.keys()]
    rst_mats = [k for k in rst_data.keys()]
    for r in heal_mats:
        if r not in rst_mats:
            del heal_data[r]
    slc = []
    qual_mats = quality_scores["annot_mat"]
    for ii, r in enumerate(qual_mats):
        if r in rst_mats:
            slc.append(ii)
    for k in quality_scores.keys():
        quality_scores[k] = [quality_scores[k][ii] for ii in slc]

    pears, rhos = ghplots.plot_quali_fric(
        quality_scores, heal_data, rst_data, save_folder
    )
