#!/usr/bin/env python3
# This file is part of granularhealing.

# granularhealing is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# granularhealing is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with granularhealing. If not, see <https://www.gnu.org/licenses/>.

"""
Plotting functionality for various plots e.g.:

 - Individual axis plots.
 - Stylesheets
 - Plot prototypes

"""

import configparser
import os
import string

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import uncertainties as unc
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
from matplotlib.offsetbox import AnchoredText
from mpl_toolkits.axes_grid1 import make_axes_locatable
from pathvalidate import sanitize_filename
from scipy.optimize import curve_fit
from scipy.stats import pearsonr, spearmanr
from uncertainties import unumpy as unp

import granularhealing.angles_unc as ghangl_unc
import granularhealing.cfit as ghcfit
import granularhealing.groups as ghgroups
import granularhealing.processing as ghprocess
from granularhealing.main import SAVEINTERMEDS

# from scipy.stats import pearsonr


def plot_reloading_modulus(
    folder_path: os.PathLike, cfg: configparser.ConfigParser
):
    """
    Plots the reloading modulus against hold time and fits its trend.

    The picks of the experimental series in ``folder_path`` are loaded, the
    modulus values are drawn with error bars and a fit of the modulus versus
    ``log10(hold time)`` is overlaid together with its confidence bounds.
    The figure is only written to disk (into an ``elasticity_plots`` folder
    next to ``folder_path``) when ``SAVEINTERMEDS`` is set.

    Returns the first fit parameter and its error as a tuple.
    """
    mpl.use("Agg")
    parent_dir, series_name = os.path.split(folder_path)
    out_dir = os.path.join(parent_dir, "elasticity_plots")
    os.makedirs(out_dir, exist_ok=True)
    out_base = os.path.join(out_dir, series_name)

    picks = ghprocess.get_pick_data(folder_path, cfg)
    hold_times = picks["hold_times"]
    # Each reload region entry is indexed as [0][0] for the value and
    # [1][0] for its error.
    moduli = []
    moduli_err = []
    for region in picks["reload_region"]:
        moduli.append(region[0][0])
        moduli_err.append(region[1][0])

    # Data points
    fig, ax = plt.subplots()
    ax.errorbar(
        hold_times,
        moduli,
        yerr=moduli_err,
        linewidth=0,
        elinewidth=1,
        marker="s",
        capsize=2,
    )

    # Fit in log10(time) and draw the fit with its bounds
    xq, yq, ucb, lcb, popt, perr = get_fit(np.log10(hold_times), moduli)
    slope, slope_err = popt[0], perr[0]
    fit_times = 10**xq
    ax.semilogx(fit_times, yq, color="C1")
    for bound in (lcb, ucb):
        ax.semilogx(fit_times, bound, color="C1", linestyle=":")
    fit_note = "b = %s" % ghcfit.sign_str(slope, slope_err)
    ax.add_artist(AnchoredText(fit_note, loc="lower right"))

    # Axis cosmetics
    ax.set_xscale("log")
    ax.set_ylim(5000, 25000)
    ax.set_xlabel(r"hold time $t_h$ (s)")
    ax.set_ylabel(r"shear modulus $E$ (Pa)")

    # Final adjustments and save
    fig.tight_layout()
    if SAVEINTERMEDS:
        for target in (out_base, out_base + ".pdf"):
            fig.savefig(target)
    plt.close(fig)
    return slope, slope_err


def correlate_plot(
    base_folder, exp_names, x, xe, y, ye, xlabel, ylabel, fname="all"
):
    """
    Creates a simple x vs. y plot to see if things are correlated.

    Every experiment is drawn as one error-bar marker coloured by its
    material group and a legend lists the material groups. The figure is
    saved in ``base_folder`` as ``fname`` and ``fname + ".pdf"``.

    Parameters
    ----------
    base_folder : folder the two output files are written to.
    exp_names : experiment names, used to look up the material group.
    x, xe : x values and their errors, one per experiment.
    y, ye : y values and their errors, one per experiment.
    xlabel, ylabel : axis labels.
    fname : base name of the output files.
    """
    group_dict = ghgroups.get_group_dict()
    mat_group = group_dict["mat_group"]
    fig, ax = plt.subplots()

    # The sequences are walked in lockstep, one marker per experiment.
    for name, x_val, x_unc, y_val, y_unc in zip(exp_names, x, xe, y, ye):
        ax.errorbar(
            x_val,
            y_val,
            xerr=x_unc,
            yerr=y_unc,
            marker="s",
            capsize=2,
            color="C%i" % mat_group[name],
        )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    leg_handles = [
        Line2D([], [], color="C%i" % jj, label=lbl, marker="s")
        for jj, lbl in enumerate(group_dict["mats"])
    ]
    ax.legend(handles=leg_handles, loc="best", fontsize="small")
    fig.tight_layout()
    fig.savefig(os.path.join(base_folder, fname))
    fig.savefig(os.path.join(base_folder, fname + ".pdf"))
    plt.close(fig)


def plot_shs_simple(folder_path: str, cfg: configparser.ConfigParser):
    """
    Plots the friction change after a hold against the hold time.

    The friction change is the picked peak minus the stable friction of the
    experimental series in ``folder_path``. A fit in ``log10(hold time)`` is
    overlaid with its confidence bounds. The figure is only written to disk
    (into an ``shs_plots`` folder next to ``folder_path``) when
    ``SAVEINTERMEDS`` is set.

    Returns a tuple of the first fit parameter, its error, the stable
    friction and the stable friction error.
    """
    # Set paths and get picks
    mpl.use("Agg")
    parent_dir, series_name = os.path.split(folder_path)
    out_dir = os.path.join(parent_dir, "shs_plots")
    os.makedirs(out_dir, exist_ok=True)
    out_base = os.path.join(out_dir, series_name)
    picks = ghprocess.get_pick_data(folder_path, cfg)

    hold_times = picks["hold_times"]
    delta_mu = np.array(picks["picked_peaks"]) - picks["stable_friction"]

    # Data points
    fig, ax = plt.subplots()
    ax.semilogx(hold_times, delta_mu, "s", base=10)

    # Fit data and plot fit
    xq, yq, ucb, lcb, popt, perr = get_fit(np.log10(hold_times), delta_mu)
    fit_times = 10**xq
    ax.semilogx(fit_times, yq, color="C1")
    for bound in (lcb, ucb):
        ax.semilogx(fit_times, bound, color="C1", linestyle=":")
    fit_note = "b = %s" % ghcfit.sign_str(popt[0], perr[0])
    ax.add_artist(AnchoredText(fit_note, loc="lower right"))

    # Same labels and limits for every series so plots are comparable
    ax.set_xlabel(r"hold time $t_h$ (s)")
    ax.set_ylabel(r"apparent friction change $\Delta\mu_p$")
    annotations = ghgroups.get_group_dict()["annot"]
    ax.set_title(annotations[series_name])
    ax.set_xlim(1, 10**5)
    ax.set_ylim(-0.02, 0.2)

    # Final adjustments and save
    fig.tight_layout()
    if SAVEINTERMEDS:
        for target in (out_base, out_base + ".pdf"):
            fig.savefig(target)
    plt.close(fig)
    return (
        popt[0],
        perr[0],
        picks["stable_friction"],
        picks["stable_friction_err"],
    )


def plot_hold_compact(folder_path: str, cfg: configparser.ConfigParser):
    """
    Plots the friction and lid position change during holds against hold time.

    For the experimental series in ``folder_path`` two stacked panels are
    created: the difference between initial hold and pre-reload friction
    (top) and the same difference for the lid position (bottom). The last
    hold of the series is left out of both panels. Each panel gets a fit
    added by ``add_hold_compact``. The figure is only written to disk (into
    a ``hold_compaction`` folder next to ``folder_path``) when
    ``SAVEINTERMEDS`` is set.

    Returns a tuple of stressing rate, its error, compaction rate and its
    error.
    """
    # Set paths and load dataset
    mpl.use("Agg")
    parent_dir, series_name = os.path.split(folder_path)
    out_dir = os.path.join(parent_dir, "hold_compaction")
    os.makedirs(out_dir, exist_ok=True)
    out_base = os.path.join(out_dir, series_name + "_hold_compaction")
    picks = ghprocess.get_pick_data(folder_path, cfg)

    def _without_last(key):
        """Returns the picks for ``key`` as array without the last hold."""
        return np.array(picks[key][:-1])

    hold_times = picks["hold_times"][:-1]
    fig, (ax_friction, ax_lid) = plt.subplots(nrows=2, sharex=True)

    # Top panel: friction change during the hold
    friction_change = _without_last("init_hold_fri") - _without_last(
        "pre_reload_fri"
    )
    str_rate, str_rate_e = add_hold_compact(
        ax_friction, hold_times, friction_change, "stressing rate"
    )

    # Bottom panel: lid position change during the hold
    lid_change = _without_last("init_hold_lid") - _without_last(
        "pre_reload_lid"
    )
    comp_rate, comp_rate_e = add_hold_compact(
        ax_lid, hold_times, lid_change, "compaction rate"
    )

    ax_lid.set_xlabel("Hold Time (s)")
    ax_friction.set_ylabel("Friction change ()")
    ax_lid.set_ylabel("Compaction (mm)")
    ax_lid.set_ylim(-0.3, 0.1)
    fig.tight_layout()
    if SAVEINTERMEDS:
        fig.savefig(out_base)
    plt.close(fig)
    return (str_rate, str_rate_e, comp_rate, comp_rate_e)


def add_hold_compact(ax, x, y, lbl):
    """
    Draws ``y`` over ``x`` on a log x-axis and overlays a fit.

    The fit is done on ``log10(x)``. The fit and its confidence bounds are
    drawn and a text box with ``lbl`` and the first fit parameter is
    anchored in the lower right of ``ax``.

    Returns the first fit parameter and its error as a tuple.
    """
    ax.semilogx(x, y, "s")

    # Fit data and plot fit
    xq, yq, ucb, lcb, popt, perr = get_fit(np.log10(x), y)
    fit_x = 10**xq
    ax.semilogx(fit_x, yq, color="C1")
    for bound in (lcb, ucb):
        ax.semilogx(fit_x, bound, color="C1", linestyle=":")

    rate, rate_err = popt[0], perr[0]
    note = "%s = %s" % (lbl, ghcfit.sign_str(rate, rate_err))
    ax.add_artist(AnchoredText(note, loc="lower right"))
    return (rate, rate_err)


def get_fit(x, y, func=ghcfit.poly1, xq=None, popt=None, perr=None, pcov=None):
    """
    Gets fit for plotting.

    Parameters
    ----------
    x, y : data the fit and the confidence band are based on.
    func : model function, called as ``func(x, *popt)``.
    xq : query points for the fitted curve. Any array-like is accepted;
        when omitted 50 points between 0 and 5 are used.
    popt : precomputed fit parameters. When omitted ``func`` is fitted to
        the data with ``curve_fit``. When given, ``pcov`` should be given
        as well because it is handed on to ``ghcfit.confband``.
    perr : precomputed parameter errors belonging to ``popt``. When
        omitted they are taken as two times the square root of the
        diagonal of ``pcov`` (if that is available).
    pcov : covariance matrix belonging to ``popt``.

    Returns
    -------
    Tuple ``(xq, yq, ucb, lcb, popt, perr)`` with the query points, the
    fitted curve, the upper and lower confidence bounds, the fit parameters
    and their errors.
    """
    # Test against None instead of the array type so that lists and tuples
    # supplied by the caller are used rather than silently discarded.
    xq = np.linspace(0, 5) if xq is None else np.asarray(xq)

    if popt is None:
        popt, pcov = curve_fit(func, x, y)
        perr = 2 * np.sqrt(np.diag(pcov))
    else:
        popt = np.asarray(popt)
        if perr is None and pcov is not None:
            perr = 2 * np.sqrt(np.diag(pcov))

    ucb, lcb = ghcfit.confband(xq, x, y, popt, pcov, func)
    yq = func(xq, *popt)

    return (xq, yq, ucb, lcb, popt, perr)


def summarize_data(exp_names, x_data, x_err):
    """
    Uses the annotation field to summarize all results for individual
    materials and returns them ready to plot with the overview.

    Experiments that share the same annotation are merged into one entry:
    the value is the mean of the individual values with their errors
    propagated. The merged entry carries the name of the first benchmark
    experiment of the group, or of the first experiment when the group
    contains no benchmark. Annotations with a single experiment are passed
    through unchanged.

    Returns a tuple ``(exp_names, x_data, x_err)`` of lists with one entry
    per unique annotation, ordered by the sorted unique annotations.
    """
    annotations = ghgroups.get_group_dict()["annot"]
    # The set of benchmarks does not depend on the experiment, fetch it once.
    benchmarks = ghgroups.get_benchmarks()

    # One bucket per unique sample name; np.unique fixes the output order.
    sample_names = [annotations[exp] for exp in exp_names]
    buckets = {name: [] for name in np.unique(sample_names)}
    for name, exp, x, xe in zip(sample_names, exp_names, x_data, x_err):
        buckets[name].append((exp, x, xe, exp in benchmarks))

    # Now check each bucket and summarize
    exp_names_smr = []
    x_data_smr = []
    x_err_smr = []
    for entries in buckets.values():
        exps = [entry[0] for entry in entries]
        values = [entry[1] for entry in entries]
        errors = [entry[2] for entry in entries]
        bench_flags = [entry[3] for entry in entries]

        if len(entries) == 1:
            exp_names_smr.append(exps[0])
            x_data_smr.append(values[0])
            x_err_smr.append(errors[0])
            continue

        # If one of the experiments is a benchmark, its name represents the
        # group so that the benchmark status is kept.
        representative = next(
            (exp for exp, flag in zip(exps, bench_flags) if flag), exps[0]
        )
        exp_names_smr.append(representative)
        data_mean = np.mean(unp.uarray(values, errors))
        x_data_smr.append(data_mean.nominal_value)
        x_err_smr.append(data_mean.std_dev)

    return (exp_names_smr, x_data_smr, x_err_smr)


def create_overview_plot(
    base_folder,
    exp_names,
    x,
    x_err,
    xlabel,
    fname="all",
    summarize=True,
    ax=None,
    leg_axis=None,
    title=None,
):
    """
    Creates an overview of data in x.

    The entries are sorted by value and drawn one per row with
    ``add_overview_to_axis``.

    Parameters
    ----------
    base_folder : folder the figure is written to (only without ``ax``).
    exp_names : experiment names belonging to ``x``.
    x, x_err : values and their errors.
    xlabel : label of the x-axis.
    fname : base name of the output files (only without ``ax``).
    summarize : if true, the data is first condensed with
        ``summarize_data``.
    ax : existing axis to draw into. When omitted a new figure is created,
        saved as ``fname`` and ``fname + ".pdf"`` and closed again.
    leg_axis : forwarded to ``add_overview_to_axis`` when ``ax`` is given;
        a true value suppresses the legend on the axis. It is ignored
        when a new figure is created.
    title : optional axis title.
    """
    if summarize:
        exp_names, x, x_err = summarize_data(exp_names, x, x_err)

    order = np.argsort(x)
    exp_names = np.array(exp_names)[order]
    x = np.array(x)[order]
    x_err = np.array(x_err)[order]
    group_dict = ghgroups.get_group_dict()

    # Compare with None explicitly: an axis object should never be judged
    # by its truth value.
    standalone = ax is None
    if standalone:
        height = 1 / 8  # Height allocated per entry
        fig, ax = plt.subplots(figsize=(8, 2 + height * len(x)))
        # A standalone figure always gets its legend on the axis.
        leg_axis = False

    add_overview_to_axis(
        exp_names,
        ax,
        x,
        x_err,
        group_dict,
        xlabel,
        title,
        leg_axis=leg_axis,
    )
    ax.set_frame_on(False)
    ax.grid(True, which="major", axis="x")

    if standalone:
        fig.tight_layout()
        fig.savefig(os.path.join(base_folder, fname))
        fig.savefig(os.path.join(base_folder, fname + ".pdf"))
        plt.close(fig)


def add_overview_to_axis(
    exp_names, ax, x, x_err, group_dict, xlabel, title, leg_axis=False
):
    """
    Adds overview data to the axis.

    Every entry is drawn in its own row (row index = position in ``x``) as
    a marker with a horizontal error bar, coloured by material group, and
    annotated with its sample annotation to the right of the error bar.
    Benchmark experiments get a star marker, all others a square.

    Parameters
    ----------
    exp_names : experiment names belonging to ``x``.
    ax : axis to draw into.
    x, x_err : values and errors; they are added element-wise, so they
        are expected to be arrays.
    group_dict : dictionary providing ``"mat_group"`` and ``"annot"``
        lookups by experiment name.
    xlabel : label of the x-axis.
    title : optional axis title.
    leg_axis : when false, a legend is added to ``ax``.
    """
    data_range = np.max(x + x_err) - np.min(x - x_err)
    # Annotations are shifted right of the error bar by 5 % of the range.
    label_offset = 0.05 * data_range
    # The benchmarks are the same for every row, fetch them once.
    benchmarks = ghgroups.get_benchmarks()

    for row, (name, value, err) in enumerate(zip(exp_names, x, x_err)):
        color = "C%i" % group_dict["mat_group"][name]
        ax.errorbar(
            value,
            row,
            xerr=err,
            linestyle=None,
            marker="*" if name in benchmarks else "s",
            capsize=2,
            color=color,
        )
        ax.annotate(
            group_dict["annot"][name],
            (value + err + label_offset, row),
            color=color,
            verticalalignment="center",
            annotation_clip=False,
        )

    if title:
        ax.set_title(title)
    ax.set(yticklabels=[])  # remove the tick labels
    ax.tick_params(left=False)  # remove the ticks
    ax.set_xlabel(xlabel)

    if not leg_axis:
        leg_handles = get_leg_handles(group_dict)
        ax.legend(handles=leg_handles, loc="best", fontsize="small")


def get_leg_handles(group_dict=None):
    """
    Creates legend handles to match with data.

    One square handle per material in ``group_dict["mats"]`` (coloured
    ``C0``, ``C1``, ...) is followed by two black handles that explain the
    star and the square marker. The group dictionary is fetched from
    ``ghgroups`` when none is passed.
    """
    if not group_dict:
        group_dict = ghgroups.get_group_dict()

    def _marker_only(color, marker, label):
        """Returns a legend handle that shows a marker without a line."""
        return Line2D(
            [], [], color=color, marker=marker, label=label, linewidth=0
        )

    handles = []
    for idx, material in enumerate(group_dict["mats"]):
        handles.append(_marker_only("C%i" % idx, "s", material))
    handles.append(_marker_only("k", "*", "in Klinkmüller\net al., 2016"))
    handles.append(_marker_only("k", "s", "new materials"))
    return handles


def plot_react_with_healing_err(
    rho,
    mu_p,
    coh_p,
    mu_s,
    coh_s,
    b,
    reference_pressure=1000,
    base_folder="",
    fname="",
    plot_differences=True,
):
    """
    Plots the difference between new and reactivated fault pressure depending
    on time.

    For 100 log-spaced times between 1 and 10^4 the friction of the
    pre-existing fault is scaled as ``mu_s * t**b``. For both the
    pre-existing and a new fault the pressure is evaluated with
    ``passive_pressure_coulomb`` (uses the cohesion) and with
    ``force_triangle`` (does not take a cohesion). The figure is shown
    interactively with ``plt.show()`` and is not saved.

    Parameters
    ----------
    rho : density; together with ``reference_pressure`` it sets the height
        ``h = reference_pressure / (rho * 9.81)``.
    mu_p, coh_p : friction and cohesion used for the new fault.
    mu_s, coh_s : friction and cohesion used for the pre-existing fault.
    b : exponent of the time dependence of ``mu_s``.
    reference_pressure : pressure used to derive the height.
    base_folder : currently unused.
    fname : used as the title of the plot.
    plot_differences : if true, the difference "new minus pre-existing" is
        plotted for both methods, otherwise the four individual curves.
    """
    time = np.logspace(0, 4, 100)
    h = reference_pressure / (rho * 9.81)

    angle_ext = ghangl_unc.active_fault_angle(mu_p)
    angle_comp = ghangl_unc.passive_fault_angle(mu_s)
    pressures_new = []
    pressures_pre = []
    pressures_new_tri = []
    pressures_pre_tri = []
    for t in time:
        mu_s_new = mu_s * t**b
        pressures_pre.append(
            ghangl_unc.passive_pressure_coulomb(
                h, mu_s_new, angle_ext, coh_s, rho
            )
        )
        pressures_new.append(
            ghangl_unc.passive_pressure_coulomb(
                h, mu_p, angle_comp, coh_p, rho
            )
        )
        pressures_pre_tri.append(
            ghangl_unc.force_triangle(h, mu_s_new, angle_ext, rho)
        )
        pressures_new_tri.append(
            ghangl_unc.force_triangle(h, mu_p, angle_comp, rho)
        )

    fig, ax = plt.subplots()
    if plot_differences:
        difference = np.array(pressures_new) - np.array(pressures_pre)
        # Compare the force triangle results with each other; mixing in the
        # Coulomb result for the pre-existing fault would compare two
        # different methods.
        difference_tri = np.array(pressures_new_tri) - np.array(
            pressures_pre_tri
        )
        rng = np.max(np.abs(unp.nominal_values(difference)))

        ax.errorbar(
            time,
            unp.nominal_values(difference),
            unp.std_devs(difference),
            label="Coulomb Analysis (with cohesion)",
        )
        ax.errorbar(
            time,
            unp.nominal_values(difference_tri),
            unp.std_devs(difference_tri),
            label="Force Triangle (no cohesion)",
        )
        ax.set_xscale("log")
        ax.annotate(
            "Reactivation is preferred",
            (0.95 * np.max(time), 0.15 * rng),
            horizontalalignment="right",
            verticalalignment="center",
        )
        ax.annotate(
            "New fault is preferred",
            (0.95 * np.max(time), -0.15 * rng),
            horizontalalignment="right",
            verticalalignment="center",
        )
        ax.axhline(0, color="k")
        ax.set_ylim(-rng * 1.1, rng * 1.1)
        ax.set_ylabel("Contact Force Difference Along Edge (N/m)")
    else:
        ax.errorbar(
            time,
            unp.nominal_values(pressures_pre),
            unp.std_devs(pressures_pre),
            label="Pressure for pre-existing fault at angle of %s°"
            % (ghcfit.sign_str(angle_ext)),
        )
        ax.errorbar(
            time,
            unp.nominal_values(pressures_new),
            unp.std_devs(pressures_new),
            label="Pressure for a new fault at angle of %s°"
            % (ghcfit.sign_str(angle_comp)),
        )
        ax.errorbar(
            time,
            unp.nominal_values(pressures_pre_tri),
            unp.std_devs(pressures_pre_tri),
            label="pre-existing fault with triangle",
        )
        ax.errorbar(
            time,
            unp.nominal_values(pressures_new_tri),
            unp.std_devs(pressures_new_tri),
            label="a new fault with triangle",
        )
        ax.set_ylabel("Contact Force Along Edge (N/m)")

    # Finalize Plot
    ax.legend()
    ax.set_xlabel("Time (s)")
    ax.set_title(fname)
    fig.tight_layout()
    plt.show()


def numerate_axes(fig: Figure, n=0, step=1):
    """
    Adds numbering to all axes in a figure.

    Every ``step``-th axis of ``fig`` is labelled with a bold lowercase
    letter in parentheses just outside its upper left corner. The letter
    for the axis at position ``ii`` is the one at index
    ``int((ii + n) / step)`` of the alphabet, so ``n`` shifts the start of
    the numbering.
    """
    label_style = {
        "xycoords": "axes fraction",
        "fontweight": "bold",
        "fontsize": "xx-large",
        "verticalalignment": "center",
        "horizontalalignment": "center",
    }
    all_axes = fig.get_axes()
    for position in range(0, len(all_axes), step):
        letter = string.ascii_lowercase[int((position + n) / step)]
        all_axes[position].annotate(
            f"({letter})", (-0.05, 1.05), **label_style
        )


def plot_example_shs(time, fric, slc, slc_fric):
    """
    Plots example time series for publication.

    The left panel shows the complete friction series and marks the detail
    window, the right panel shows the detail window ``slc`` with arrows
    for the friction change and the hold time. Both panels show the mean
    of ``slc_fric`` and a band of two standard deviations around it.

    Parameters
    ----------
    time, fric : complete time and friction series.
    slc : slice selecting the detail window; its ``start`` and ``stop``
        are used to index ``time``.
    slc_fric : friction values used for the mean and the standard
        deviation.

    The arrows are placed at fixed sample positions (6900, 12000, 12500
    and 17050) inside the detail window, so the window has to contain
    more samples than that.

    Returns the created figure.
    """
    # Fixed sample positions inside the detail window
    peak_arrow_idx = 12500
    peak_label_idx = 12000
    hold_start_idx = 6900
    hold_end_idx = 17050

    fric_slc = fric[slc]
    time_slc = time[slc]
    memu = np.mean(slc_fric)
    memu_std = 2 * np.std(slc_fric)
    peak = np.max(fric_slc)
    # Height of the hold time arrow: halfway between mean and minimum
    hold_arrow_y = memu - 0.5 * (memu - np.min(fric_slc))

    fig = plt.figure(figsize=(12, 5))
    gs = GridSpec(nrows=1, ncols=3)
    overview = fig.add_subplot(gs[0, :2])
    detail = fig.add_subplot(gs[0, -1])
    axes = [overview, detail]

    overview.plot(time, fric, label="Data")
    detail.plot(time_slc, fric_slc)
    overview.axvline(time[slc.start], color="k", label="Detail")
    overview.axvline(time[slc.stop], color="k")
    for ax in axes:
        ax.axhline(memu, color="C1", label="mean $\\mu_d$")
        ax.axhline(
            memu + memu_std, color="C1", linestyle=":", label="2$\\sigma$"
        )
        ax.axhline(memu - memu_std, color="C1", linestyle=":")

    # Friction change: vertical arrow between the mean and the peak
    detail.axhline(peak, color="C2")
    detail.annotate(
        "",
        (time_slc[peak_arrow_idx], peak),
        (time_slc[peak_arrow_idx], memu),
        arrowprops={"arrowstyle": "<->"},
    )
    detail.annotate(
        "$\\Delta\\mu_p$",
        (time_slc[peak_label_idx], peak - 0.5 * (peak - memu)),
        horizontalalignment="right",
        verticalalignment="center",
    )

    # Hold time: horizontal arrow below the mean
    detail.annotate(
        "",
        (time_slc[hold_start_idx], hold_arrow_y),
        (time_slc[hold_end_idx], hold_arrow_y),
        arrowprops={"arrowstyle": "<->"},
    )
    hold_mid_idx = int(hold_start_idx + 0.5 * (hold_end_idx - hold_start_idx))
    detail.annotate(
        "$t_h$",
        (
            time_slc[hold_mid_idx],
            memu - 0.55 * (memu - np.min(fric_slc)),
        ),
        horizontalalignment="center",
        verticalalignment="top",
    )

    detail.sharey(overview)
    overview.legend()
    overview.set_ylabel("Apparent friction $\\mu=\\frac{\\tau}{\\sigma_N}$")
    for ax in axes:
        ax.set_xlabel("Time(s)")

    numerate_axes(fig)
    return fig


def add_erb(ax, x, y, col="", mark="-", lbl="", alpha=1):
    """
    Plots values with uncertainties as a line with a shaded error band.

    The line shows the nominal values of ``y`` and the band spans one
    standard deviation to either side. Points whose band straddles zero
    (upper bound above and lower bound below zero) are blanked out in the
    line and in the band. ``y`` itself is left untouched.

    Parameters
    ----------
    ax : axis to draw into.
    x : x values.
    y : values with uncertainties, as understood by ``unumpy``.
    col : colour of the line and the band.
    mark : line style of the line.
    lbl : legend label of the line.
    alpha : opacity of the line; the band uses a quarter of it.
    """
    # Work on a private float copy so the caller's array is not modified.
    nominal = np.array(unp.nominal_values(y), dtype=float)
    spread = unp.std_devs(y)
    upper = nominal + spread
    lower = nominal - spread

    # The upper bound is never below the lower bound, so a band can only
    # contain zero with a positive upper and a negative lower bound.
    straddles_zero = (upper > 0) & (lower < 0)
    nominal[straddles_zero] = np.nan
    upper[straddles_zero] = np.nan
    lower[straddles_zero] = np.nan

    ax.plot(
        x,
        nominal,
        label=lbl,
        color=col,
        linestyle=mark,
        alpha=alpha,
    )
    ax.fill_between(
        x,
        upper,
        lower,
        color=col,
        alpha=alpha * 0.25,
    )


def plot_force_mulugeta(heal_medians, rst_medians, save_folder):
    """
    Plots the force for a new and a reactivated fault over time per material.

    One panel of a 2x4 grid is used per key of ``heal_medians``. For
    materials with a healing value and a valid peak friction, the friction
    of the reactivated fault is scaled as ``mur * time**heal`` and the
    forces from ``ghangl_unc.force_mulugeta`` are drawn with ``add_erb``
    for the new fault (solid) and the reactivated fault (dashed). The
    figure is saved in ``save_folder`` as
    ``Figure6_ShearForce_with_Healing_Mulugeta`` (default format and PDF)
    and closed afterwards.

    Parameters
    ----------
    heal_medians : dictionary material -> healing value; a falsy value
        leaves the panel of that material empty.
    rst_medians : dictionary material -> dictionary with the entries
        ``"mup"``, ``"cp"``, ``"mur"`` and ``"cr"``.
    save_folder : folder the figure is written to.

    Note that the uncertainties in ``rst_medians`` are overwritten in place
    with fixed values (0.01 for the friction and 10 for the cohesion
    entries) for every plotted material.
    """
    fig, axes = plt.subplots(
        nrows=2, ncols=4, sharex=True, sharey=True, figsize=(14, 8)
    )

    matcol = {
        "quartz sand": "C0",
        "corundum sand": "C1",
        "feldspar sand": "C2",
        "garnet sand": "C3",
        "glass beads": "C4",
        "zircon sand": "C5",
        "foam glass": "C6",
    }
    mean_rhos = {
        "quartz sand": unc.ufloat(1424, 145),
        "corundum sand": unc.ufloat(1797, 218),
        "feldspar sand": unc.ufloat(1160, 72),
        "garnet sand": unc.ufloat(2087, 70),
        "glass beads": unc.ufloat(1372, 91),
        "zircon sand": unc.ufloat(2569, 70),
        "foam glass": unc.ufloat(377, 163),
    }
    matname = {
        "quartz sand": "Quartz Sand",
        "corundum sand": "Corundum Sand",
        "feldspar sand": "Feldspar Sand",
        "garnet sand": "Garnet Sand",
        "glass beads": "Glass Beads",
        "zircon sand": "Zircon Sand",
        "foam glass": "Foam Glass",
    }
    # Fixed uncertainties that replace the ones of the input values
    fixed_errors = (("mup", 0.01), ("cp", 10), ("mur", 0.01), ("cr", 10))

    time = np.logspace(0, 6, 100)
    for ii, mat in enumerate(heal_medians):
        panel = axes.flat[ii]
        if heal_medians[mat] and not unp.isnan(rst_medians[mat]["mup"]):
            params = rst_medians[mat]
            for key, fixed_err in fixed_errors:
                params[key] = unc.ufloat(params[key].nominal_value, fixed_err)

            # friction on reactivated fault (time dependent)
            react_mu = params["mur"] * time ** heal_medians[mat]
            # friction to create new fault
            static_mu = params["mup"] * np.ones_like(react_mu)

            # angle of fault in extension (w.r.t. horizontal)
            angle_pre = np.ones_like(react_mu) * (
                90 - ghangl_unc.opt_angle(params["mup"])
            )
            # angle of fault in compression (w.r.t horizontal)
            angle_new = np.ones_like(react_mu) * (
                ghangl_unc.opt_angle(params["mup"])
            )

            force_pre = ghangl_unc.force_mulugeta(
                react_mu, angle_pre, 0.05, mean_rhos[mat]
            )
            force_new = ghangl_unc.force_mulugeta(
                static_mu, angle_new, 0.05, mean_rhos[mat]
            )

            add_erb(
                panel,
                time,
                force_new,
                col=matcol[mat],
                lbl="New fault",
            )
            add_erb(
                panel,
                time,
                force_pre,
                col=matcol[mat],
                mark="--",
                lbl="Reactivated fault",
            )
        panel.set_title(matname[mat])

    for ax in axes.flat:
        ax.set_xscale("log")
        ax.set_xlim(10**0, 10**6)
        ax.axhline(0, color="k")
        # Shade everything below zero
        ax.fill_between(
            time,
            0 * time,
            -400 * np.ones_like(time),
            color="k",
            alpha=0.1,
            zorder=-1,
        )
        ax.annotate(
            "Lockup Region",
            (0.5, 0.05),
            xycoords="axes fraction",
            horizontalalignment="center",
        )
    for ax in axes[1]:
        ax.set_xlabel("Time (s)")

    axes[0][0].legend()
    axes[0][0].set_ylim(-350, 350)
    axes[0][0].set_xlim(10**0, 10**6)
    axes[0][0].set_ylabel("Shear force per unit area ($\\frac{N}{m}$)")
    axes[1][0].set_ylabel("Shear force per unit area ($\\frac{N}{m}$)")
    numerate_axes(fig, n=7)
    fig.tight_layout()
    fig.savefig(
        os.path.join(save_folder, "Figure6_ShearForce_with_Healing_Mulugeta")
    )
    fig.savefig(
        os.path.join(
            save_folder, "Figure6_ShearForce_with_Healing_Mulugeta.pdf"
        )
    )
    # Release the figure like the other figure writers of this module do.
    plt.close(fig)


def plot_opt_angles(heal_medians, rst_medians, save_folder):
    """
    Plots optimal angles according to Bonini et al.

    One panel of a 2x4 grid is used per key of ``heal_medians``. For
    materials with a healing value and a valid peak friction, the friction
    of the reactivated fault is scaled as ``mur * time**heal``. Each such
    panel shows the angle of the preexisting fault, the angle of a new
    fault, the optimal angle for reactivation and a shaded region between
    the lock angle of the reactivated fault and 90 degrees. The figure is
    saved in ``save_folder`` as ``Figure5_OptimalAngles`` (default format
    and PDF) and closed afterwards.

    Parameters
    ----------
    heal_medians : dictionary material -> healing value; a falsy value
        leaves the panel of that material empty.
    rst_medians : dictionary material -> dictionary with the entries
        ``"mup"``, ``"cp"``, ``"mur"`` and ``"cr"``.
    save_folder : folder the figure is written to.

    Note that the uncertainties in ``rst_medians`` are overwritten in place
    with fixed values (0.01 for the friction and 10 for the cohesion
    entries) for every plotted material.
    """
    fig, axes = plt.subplots(
        nrows=2, ncols=4, sharex=True, sharey=True, figsize=(14, 8)
    )

    matcol = {
        "quartz sand": "C0",
        "corundum sand": "C1",
        "feldspar sand": "C2",
        "garnet sand": "C3",
        "glass beads": "C4",
        "zircon sand": "C5",
        "foam glass": "C6",
    }
    matname = {
        "quartz sand": "Quartz Sand",
        "corundum sand": "Corundum Sand",
        "feldspar sand": "Feldspar Sand",
        "garnet sand": "Garnet Sand",
        "glass beads": "Glass Beads",
        "zircon sand": "Zircon Sand",
        "foam glass": "Foam Glass",
    }
    # Fixed uncertainties that replace the ones of the input values
    fixed_errors = (("mup", 0.01), ("cp", 10), ("mur", 0.01), ("cr", 10))

    time = np.logspace(0, 6)
    for ii, mat in enumerate(heal_medians):
        if heal_medians[mat] and not unp.isnan(rst_medians[mat]["mup"]):
            panel = axes.flat[ii]
            params = rst_medians[mat]
            for key, fixed_err in fixed_errors:
                params[key] = unc.ufloat(params[key].nominal_value, fixed_err)

            # friction on reactivated fault (time dependent)
            react_mu = params["mur"] * time ** heal_medians[mat]
            # friction to create new fault
            static_mu = params["mup"] * np.ones_like(react_mu)
            opt_angle_react_mu = ghangl_unc.opt_angle(react_mu)
            lock_angle_react_mu = ghangl_unc.lock_angle(react_mu)
            opt_angle_static_mu = ghangl_unc.opt_angle(static_mu)

            angle_pre = np.ones_like(react_mu) * (
                90 - ghangl_unc.opt_angle(params["mup"])
            )

            add_erb(
                panel,
                time,
                angle_pre,
                lbl="Angle of preexisting fault",
                col=matcol[mat],
                mark="--",
            )
            add_erb(
                panel,
                time,
                opt_angle_static_mu,
                lbl="Angle of new fault",
                col=matcol[mat],
            )
            add_erb(
                panel,
                time,
                opt_angle_react_mu,
                lbl="Optimal angle for reactivation",
                col="k",
                mark="--",
            )

            panel.fill_between(
                time,
                unp.nominal_values(lock_angle_react_mu),
                90 * np.ones_like(time),
                color="k",
                alpha=0.25,
                label="Lockup region for reactivated faults",
                zorder=-1,
            )

            panel.set_title(matname[mat])

    # The axes are shared, so scale and limits apply to all panels.
    axes[0][0].set_xscale("log")
    for ax in axes[1]:
        ax.set_xlabel("Time (s)")
        ax.set_ylim(0, 90)
        ax.set_xlim(10**0, 10**6)
    axes[0][0].legend()
    axes[0][0].set_ylabel("Angle (°)")
    axes[1][0].set_ylabel("Angle (°)")
    numerate_axes(fig)
    fig.tight_layout()
    fig.savefig(os.path.join(save_folder, "Figure5_OptimalAngles"))
    fig.savefig(os.path.join(save_folder, "Figure5_OptimalAngles.pdf"))
    # Release the figure like the other figure writers of this module do.
    plt.close(fig)


def plot_quality_score(quality_scores, base_folder):
    """
    Plots an overview of all quality scores.

    Four panels side by side show the sphericity, roundness, surface and
    average quality score of every sample in ``quality_scores``, each as
    an overview plot without error bars. A common legend is attached to
    the figure, which is saved in ``base_folder`` as
    ``Figure3_All_QualityScores`` (default format and PDF).
    """
    fig, axes = plt.subplots(
        ncols=4, sharex=True, sharey=True, figsize=(16, 6.5)
    )
    # (column in quality_scores, panel title), in panel order
    panels = (
        ("Sphericity (quali-score)", "Sphericity"),
        ("Roundness (quali-score)", "Roundness"),
        ("Surface (quali-score)", "Surface Roughness"),
        ("Average Quality Score", "Average Quality Score"),
    )

    for axis, (column, heading) in zip(axes, panels):
        scores = quality_scores[column]
        create_overview_plot(
            base_folder,
            quality_scores["Sample"],
            scores,
            np.zeros_like(scores),
            xlabel="Quality Score",
            fname="All_QualityScores",
            ax=axis,
            title=heading,
            leg_axis=True,
        )

    fig.legend(
        handles=get_leg_handles(),
        loc="upper left",
        bbox_to_anchor=(0, 0.9),
        fontsize="small",
    )
    numerate_axes(fig)
    fig.tight_layout()
    for target in ("Figure3_All_QualityScores", "Figure3_All_QualityScores.pdf"):
        fig.savefig(os.path.join(base_folder, target))
    plt.close(fig)


def plot_quali_fric(quality_scores, heal_data, rst_data, base_folder):
    """
    Plots quality measures against friction.

    A grid with one row per measured quantity (healing rate, compaction
    rate, the three friction coefficients and the three cohesions) and one
    column per quality score is created. Every panel shows the data with
    error bars coloured by material group, the Pearson and Spearman
    correlation with their p-values (rounded up to two decimals) and a fit
    from ``get_fit``. The figure is saved in ``base_folder`` as
    ``Figure7_All_Friction_QualityScores`` (default format and PDF).

    Returns two arrays of shape (4, number of quantities) with the Pearson
    and the Spearman coefficients, indexed as ``[score][quantity]``.
    """
    titles = {
        "Sphericity (quali-score)": "Sphericity",
        "Roundness (quali-score)": "Roundness",
        "Surface (quali-score)": "Surface Roughness",
        "Average Quality Score": "Average Quality Score",
    }
    group_dict = ghgroups.get_group_dict()
    colors = [
        "C%i" % group_dict["mat_group"][k] for k in quality_scores["Sample"]
    ]

    # "str_rates" and "rel_mods" are not plotted at present.
    shs_lbls = {
        "healing_rates": "Healing rate $b$ ()",
        "comp_rates": "Compaction rate $c$ ()",
    }
    fri_lbls = {
        "mup": "Peak friction $\\mu_p$ ()",
        "mur": "React. friction $\\mu_r$ ()",
        "mud": "Stable friction $\\mu_s$ ()",
        "cp": "Peak cohesion $C_p$ (Pa)",
        "cr": "React. cohesion $C_r$ (Pa)",
        "cd": "Stable cohesion $C_s$ (Pa)",
    }

    # Collect one (values, errors, label) entry per row of the grid:
    # first the healing quantities, then the friction parameters.
    rows = []
    for source, labels in ((heal_data, shs_lbls), (rst_data, fri_lbls)):
        for key, label in labels.items():
            values = [
                source[mat][key].nominal_value
                for mat in quality_scores["annot_mat"]
            ]
            errors = [
                source[mat][key].std_dev
                for mat in quality_scores["annot_mat"]
            ]
            rows.append((values, errors, label))

    n_rows = len(rows)
    fig, axes = plt.subplots(
        ncols=4,
        nrows=n_rows,
        sharex=True,
        sharey="row",
        figsize=(1.3 * 210 * 0.03937008, 2 * 297 * 0.03937008),
        dpi=72,
    )
    pears = np.zeros((4, n_rows))
    rhos = np.zeros((4, n_rows))

    # y-limits per kind of quantity; the first keyword found in the label
    # decides.
    y_limits = (
        ("friction", (-0.05, 0.85)),
        ("cohesion", (-110, 210)),
        ("Healing", (-0.0, 0.055)),
        ("Compaction", (-0.075, 0.0)),
    )

    for row, (values, errors, label) in enumerate(rows):
        # The correlation text goes to the top for the healing rate and to
        # the bottom for everything else.
        if "Healing" in label:
            note_pos, note_align = (0.05, 0.95), "top"
        else:
            note_pos, note_align = (0.05, 0.05), "bottom"

        for col, score_key in enumerate(titles):
            panel = axes[row][col]
            scores = quality_scores[score_key]
            for xx, yy, yee, c in zip(scores, values, errors, colors):
                panel.errorbar(
                    xx,
                    yy,
                    yerr=yee,
                    color=c,
                    capsize=2,
                    marker="o",
                    markersize=2,
                )

            pear, pval = pearsonr(scores, values)
            pears[col][row] = pear
            rho, pval2 = spearmanr(scores, values)
            rhos[col][row] = rho
            panel.annotate(
                "r=%.2f (p=%.2f)\n$\\rho=$%.2f (p=%.2f)"
                % (
                    pear,
                    np.ceil(pval * 100) / 100,
                    rho,
                    np.ceil(pval2 * 100) / 100,
                ),
                note_pos,
                verticalalignment=note_align,
                xycoords="axes fraction",
                fontsize="small",
            )

            xq, yq, ucb, lcb, popt, perr = get_fit(scores, values)
            panel.plot(xq, yq, color="k", alpha=0.3)

        first_panel = axes[row][0]
        first_panel.set_ylabel(label)
        for keyword, limits in y_limits:
            if keyword in label:
                first_panel.set_ylim(*limits)
                break

    for col, heading in enumerate(titles.values()):
        axes[0][col].set_title(heading)
        axes[-1][col].set_xlabel("Quality Score ()")
    axes[0][0].set_xlim(0.5, 4.5)
    numerate_axes(fig, step=4)
    fig.tight_layout()
    fig.savefig(
        os.path.join(base_folder, "Figure7_All_Friction_QualityScores")
    )
    fig.savefig(
        os.path.join(base_folder, "Figure7_All_Friction_QualityScores.pdf")
    )
    plt.close(fig)

    return pears, rhos


def plot_correlation_with_histograms(
    x, y, yerr, ylabel, xlabel, title, base_folder
):
    """
    Plots a single correlation plot with histograms on x and y axis.

    The data is drawn with error bars in the main axis; a histogram of
    ``x`` is attached on top and a histogram of ``y`` on the right. The
    y-limits depend on the keyword found in ``ylabel``. The figure is
    saved as PNG in ``base_folder``; the file name is built from ``title``
    and ``ylabel`` with blanks, ``$`` and parentheses removed.
    """
    fig, ax = plt.subplots(figsize=(8, 8), dpi=150)
    ax.errorbar(x, y, yerr=yerr, fmt="s", capsize=3)

    # Attach the histogram axes to the top and the right of the main axis;
    # size and pad are given in inches.
    divider = make_axes_locatable(ax)
    hist_top = divider.append_axes("top", 1.2, pad=0.1, sharex=ax)
    hist_right = divider.append_axes("right", 1.2, pad=0.1, sharey=ax)

    # The tick labels facing the main axis would only repeat its labels
    hist_top.xaxis.set_tick_params(labelbottom=False)
    hist_right.yaxis.set_tick_params(labelleft=False)
    hist_top.hist(x, bins="auto")
    hist_right.hist(y, bins="auto", orientation="horizontal")

    safe_name = sanitize_filename(f"{title}_{ylabel}")
    safe_name = safe_name.translate(str.maketrans("", "", " $()"))

    # The first keyword found in the label decides on the limits
    y_limits = (
        ("friction", (0, 0.8)),
        ("cohesion", (-20, 200)),
        ("Healing", (0, 0.04)),
        ("Compaction", (-0.06, 0)),
    )
    for keyword, limits in y_limits:
        if keyword in ylabel:
            ax.set_ylim(*limits)
            break

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(os.path.join(base_folder, safe_name + ".png"))
    plt.close(fig)
