#!/usr/bin/env python3
# This file is part of granularhealing.

# granularhealing is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# granularhealing is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with granularhealing. If not, see <https://www.gnu.org/licenses/>.
""" Provides single and batch processing utilities for granularhealing """
import configparser
import os
from typing import List

import numpy as np

import granularhealing.cfit as ghfit
import granularhealing.files as ghfiles
from granularhealing.errors import GranularHealingError


def pick_folder(folder_path: str, cfg: configparser.ConfigParser) -> bool:
    """
    Picks all ``.tdms`` measurement files of a folder in one batch.

    The combined picks of every file in ``folder_path`` are written to
    ``picks.json`` inside that same folder.

    Returns:
        bool: Always True; any failure surfaces as an exception raised by
        the file or picking helpers.
    """
    measurement_files = ghfiles.get_file_list(folder_path, ".tdms")
    ghfiles.save_picks(
        os.path.join(folder_path, "picks.json"),
        get_hold_phases(measurement_files, cfg),
    )
    return True


def _sample_timestamp(exp: dict, index):
    """Converts a sample index of ``exp`` into an absolute timestamp.

    The offset ``index * exp["dt"] * 1000`` is truncated to whole
    milliseconds and added to ``exp["wf_start_time"]``.
    """
    return exp["wf_start_time"] + np.timedelta64(
        int(index * exp["dt"] * 1000), "ms"
    )


def _reorder(values: list, order) -> list:
    """Returns ``values`` rearranged so entry ``k`` is ``values[order[k]]``."""
    return [values[ii] for ii in order]


def get_hold_phases(
    meas_file_list: List[str], cfg: configparser.ConfigParser
) -> dict:
    """
    Gets the hold phases across multiple files.

    For every measurement file the hold intervals are detected from the
    velocity channel. Several quantities are picked:

    * at each hold start: the shear and lid displacement;
    * at each hold end: the reloading peak, the pre-reload shear and lid
      displacement, and the reloading fit;
    * per file: the stable friction.

    Hold boundaries are converted to absolute timestamps and sorted
    chronologically. The lists picked alongside them are reordered the same
    way, so they stay aligned. Boundaries closer than 70 s to their
    successor are then thinned out (see the inline comments).

    Args:
        meas_file_list: Paths of the measurement files to process.
        cfg: Configuration. Reads ``post_stop``, ``post_shift``,
            ``post_region`` and ``pre_loading`` from the ``[pick]`` section.

    Returns:
        dict: Cleaned hold start/end timestamps, hold times, the mean
        stable friction (and mean error) over all files, and the per-hold
        picks selected with the cleaned end indices.
    """
    # Picking parameters, read once instead of once per hold
    post_stop = cfg.getfloat("pick", "post_stop")
    post_shift = cfg.getfloat("pick", "post_shift")
    post_region = cfg.getfloat("pick", "post_region")
    pre_loading = cfg.getfloat("pick", "pre_loading")
    min_gap = np.timedelta64(70, "s")

    # Lists filled in lockstep with the hold starts
    abs_hold_starts = []
    init_hold_fri = []
    init_hold_lid = []
    # Lists filled in lockstep with the hold ends
    abs_hold_ends = []
    picked_vals = []
    pre_reload_fri = []
    pre_reload_lid = []
    reload_region = []
    # One entry per file
    stable_vals = []
    stable_vals_err = []

    for meas_file_path in meas_file_list:
        exp = ghfiles.open_tdms(meas_file_path)

        hold_starts, hold_ends = get_hold_intervals(exp["data"]["velocity"])

        for start_hold in hold_starts:
            abs_hold_starts.append(_sample_timestamp(exp, start_hold))
            init_hold_fri.append(
                pick_value(
                    exp,
                    start_hold,
                    region=post_stop,
                    shift=-0.5,
                    signal="shear",
                    pick_function="max",
                )
            )
            init_hold_lid.append(
                pick_value(
                    exp,
                    start_hold,
                    region=post_stop,
                    shift=-0.5,
                    pick_function="max",
                    signal="lid_disp",
                )
            )

        for end_hold in hold_ends:
            abs_hold_ends.append(_sample_timestamp(exp, end_hold))
            # Reloading peak with the parameters from the config
            picked_vals.append(
                pick_value(
                    exp,
                    end_hold,
                    shift=post_shift,
                    region=post_region,
                )
            )
            # Friction and lid position before reloading
            pre_reload_fri.append(
                pick_value(
                    exp,
                    end_hold,
                    region=pre_loading,
                    forward=False,
                    pick_function="mean",
                    signal="shear",
                )
            )
            pre_reload_lid.append(
                pick_value(
                    exp,
                    end_hold,
                    region=pre_loading,
                    forward=False,
                    pick_function="mean",
                    signal="lid_disp",
                )
            )
            # Fit of the reloading segment
            reload_region.append(get_reload_area(exp, end_hold, cfg))

        # Stable part of the time series
        stbl, stbl_e = get_stable_friction(exp)
        stable_vals.append(stbl)
        stable_vals_err.append(stbl_e)

    # Sort chronologically and apply the same permutation to every list
    # that was filled alongside, so per-hold entries stay aligned.
    end_order = np.argsort(abs_hold_ends, kind="stable")
    abs_hold_ends = _reorder(abs_hold_ends, end_order)
    picked_vals = _reorder(picked_vals, end_order)
    pre_reload_fri = _reorder(pre_reload_fri, end_order)
    pre_reload_lid = _reorder(pre_reload_lid, end_order)
    reload_region = _reorder(reload_region, end_order)

    start_order = np.argsort(abs_hold_starts, kind="stable")
    abs_hold_starts = _reorder(abs_hold_starts, start_order)
    init_hold_fri = _reorder(init_hold_fri, start_order)
    init_hold_lid = _reorder(init_hold_lid, start_order)

    # A start is kept only if the next start is more than min_gap later.
    # The final start is therefore always dropped.
    clean_hold_starts = [
        abs_hold_starts[ii]
        for ii, hold_diff in enumerate(np.diff(abs_hold_starts))
        if hold_diff > min_gap
    ]

    # An end is kept if the next end is more than min_gap later; the final
    # end is always kept.
    clean_hold_ends = [
        abs_hold_ends[ii]
        for ii, hold_diff in enumerate(np.diff(abs_hold_ends))
        if hold_diff > min_gap
    ]
    clean_hold_ends.append(abs_hold_ends[-1])

    # Walk the end gaps: index ii is always kept, and if the following end
    # lies within min_gap, index ii + 1 is skipped. The last pick is
    # always added.
    # NOTE(review): for two close ends, clean_hold_ends keeps the later one
    # while these indices keep the earlier pick. The start-based
    # init_hold_* lists are also selected with these end-based indices.
    # Confirm both are intended.
    clean_pick_indices = []
    diffs = np.diff(abs_hold_ends)
    ii = 0
    while ii < len(diffs):
        clean_pick_indices.append(ii)
        ii += 1 if diffs[ii] > min_gap else 2
    clean_pick_indices.append(-1)

    # zip pairs ends with starts positionally and truncates to the shorter
    hold_times = [
        end - start for end, start in zip(clean_hold_ends, clean_hold_starts)
    ]

    picks = {
        "hold_starts_time": clean_hold_starts,
        "hold_ends_time": clean_hold_ends,
        "hold_times": hold_times,
        "stable_friction": np.mean(stable_vals),
        "stable_friction_err": np.mean(stable_vals_err),
        "picked_peaks": clean_picks(picked_vals, clean_pick_indices),
        "pre_reload_fri": clean_picks(pre_reload_fri, clean_pick_indices),
        "pre_reload_lid": clean_picks(pre_reload_lid, clean_pick_indices),
        "init_hold_fri": clean_picks(init_hold_fri, clean_pick_indices),
        "init_hold_lid": clean_picks(init_hold_lid, clean_pick_indices),
        "reload_region": clean_picks(reload_region, clean_pick_indices),
    }
    return picks


def get_reload_area(
    exp: dict, start_index: int, cfg: configparser.ConfigParser
):
    """
    Fits the reloading segment that follows a hold phase.

    The segment runs from ``start_index`` up to the reloading peak. The peak
    is picked with ``post_shift`` and ``post_region`` from the ``[pick]``
    config section. ``ghfit.normal_fit`` is applied to the part of the
    segment between 10 % and 90 % of its shear-stress range. The x-values of
    the fit are the cumulative ``velocity * dt`` divided by 40.

    Args:
        exp (dict): Experiment data. Reads ``exp["Fs"]``, ``exp["dt"]`` and
            the ``"velocity"`` and ``"shear"`` entries of ``exp["data"]``.
        start_index (int): Sample index where the segment starts.
        cfg (configparser.ConfigParser): Picking configuration.

    Returns:
        tuple: ``(popt, perr)`` from ``ghfit.normal_fit``, or
        ``(None, None)`` in either of these cases:

        * no peak index could be picked;
        * the segment's shear stress is empty or all zero.
    """
    shift = cfg.getfloat("pick", "post_shift")
    region = cfg.getfloat("pick", "post_region")
    peak_offset = pick_value(
        exp,
        start_index,
        shift=shift,
        region=region,
        return_index=True,
    )
    if peak_offset is None:
        return None, None

    # pick_value reports the peak relative to the start of its pick region.
    # That region begins `shift` after start_index, so convert the offset
    # to an absolute sample index before slicing.
    region_start = create_pick_region(
        exp["Fs"], start_index, shift, region, True
    ).start
    slc = slice(start_index, region_start + peak_offset)

    load_disp = np.cumsum(exp["data"]["velocity"][slc] * exp["dt"])
    # NOTE(review): 40 presumably converts displacement to shear strain
    # (e.g. a gauge length); confirm the constant and its units.
    shear_strain = load_disp / 40
    shear_stress = exp["data"]["shear"][slc]
    if not shear_stress.any():
        return None, None

    low = np.min(shear_stress)
    spread = np.max(shear_stress) - low
    # First samples reaching 10 % and 90 % of the stress rise bound the fit
    start_fit = np.flatnonzero(shear_stress >= low + spread * 0.1)[0]
    end_fit = np.flatnonzero(shear_stress >= low + spread * 0.9)[0]

    popt, perr = ghfit.normal_fit(
        shear_strain[start_fit:end_fit], shear_stress[start_fit:end_fit]
    )
    return popt, perr


def clean_picks(pick_input, clean_indices):
    """Selects the entries of ``pick_input`` at ``clean_indices``, in order.

    Returns a new list; negative indices count from the end as usual.
    """
    return list(map(pick_input.__getitem__, clean_indices))


def get_stable_friction(exp: dict, vlim: float = 0.001):
    """
    Gets the stable (sliding) friction of an experiment.

    Friction is the shear signal divided by the mean of the normal signal.
    Only samples whose velocity exceeds ``vlim`` are used.

    Args:
        exp (dict): Experiment data. Reads the ``"shear"``, ``"normal"`` and
            ``"velocity"`` entries of ``exp["data"]``.
        vlim (float, optional): Velocity above which a sample counts as
            sliding. Defaults to 0.001.

    Returns:
        tuple: ``(stable, stable_err)``. ``stable`` is the mean friction of
        the sliding samples and ``stable_err`` is twice their standard
        deviation.
    """
    data = exp["data"]
    nstress = np.mean(data["normal"])
    sliding = np.asarray(data["velocity"]) > vlim
    friction = np.asarray(data["shear"])[sliding] / nstress

    stable = np.mean(friction)
    # Spread of the individual samples. The std of the scalar mean would
    # always be 0.
    stable_err = 2 * np.std(friction)

    return (stable, stable_err)


def get_hold_intervals(vel: np.ndarray, vlim: float = 0.001):
    """
    Gets starts and ends of hold intervals in the given velocity array.

    A sample is held when its velocity is below ``vlim``. Consecutive held
    samples more than 10 samples apart belong to different holds.

    Args:
        vel (np.ndarray): Velocity signal, read in full via ``vel[()]``.
        vlim (float, optional): Velocity below which a sample counts as
            held. Defaults to 0.001.

    Returns:
        tuple: ``(starts, ends)`` as 1-D integer arrays of sample indices.
        ``starts`` additionally ends with ``len(vel)``, so it has one more
        entry than ``ends``.

    Raises:
        IndexError: If no sample is below ``vlim``.
    """
    held = np.flatnonzero(vel[()] < vlim)
    # Positions in `held` after which the next held sample is >10 samples
    # away, i.e. where one hold ends and the next begins
    breaks = np.flatnonzero(np.diff(held) > 10)

    starts = [int(held[0])]
    starts.extend(int(held[b + 1]) for b in breaks)
    # NOTE(review): len(vel) is appended as an extra "start" and callers
    # pick values at it as well; confirm this sentinel is intended.
    starts.append(len(vel))

    ends = [int(held[b]) for b in breaks]
    ends.append(int(held[-1]))
    return (np.array(starts), np.array(ends))


def pick_value(
    exp: dict,
    start_index: int,
    signal: str = "friction",
    shift: float = 0,
    region: float = 0,
    forward: bool = True,
    pick_function: str = "max",
    return_index: bool = False,
) -> float:
    """Picks a value from a signal according to the specified parameters.

    For minima and maxima the position is always located on the shear
    stress. The returned value is then read from the channel selected by
    ``signal`` at that position.

    Args:
        exp (dict): Experiment data. Reads ``exp["Fs"]`` and the ``"shear"``,
            ``"normal"`` (friction only) and ``signal`` entries of
            ``exp["data"]``.
        start_index (int): Sample index the pick region is anchored to.
        signal (str, optional): Channel to pick from. ``"friction"`` reads
            the shear channel and divides the result by the mean normal
            signal. Defaults to 'friction'.
        shift (float, optional): Shift of the pick region. It is multiplied
            by ``exp["Fs"]`` to get samples (seconds if Fs is in Hz).
            Defaults to 0.
        region (float, optional): Length of the pick region, in the same
            units as ``shift``. Defaults to 0.
        forward (bool, optional): Pick forward in time. If False, the
            region extends backward from ``start_index``. Defaults to True.
        pick_function (str, optional): One of max, min, mean, median.
            Defaults to max.
        return_index (bool, optional): Return the pick position instead of
            the value. Defaults to False.

    Returns:
        float: The picked value. For max/min on an empty pick region the
        value is ``nan``. With ``return_index=True`` the function instead
        returns the position relative to the start of the pick region. That
        position is None for mean/median and for an empty region.

    Raises:
        KeyError: If ``pick_function`` is not one of the supported names.
        ValueError: From create_pick_region when the region lies entirely
            before the start of the data.
    """

    # True means the position is located first and the value is read there;
    # False means the function reduces the channel to a value directly.
    picking_fnc = {
        "max": (True, np.argmax),
        "min": (True, np.argmin),
        "mean": (False, np.mean),
        "median": (False, np.median),
    }

    slc = create_pick_region(exp["Fs"], start_index, shift, region, forward)

    channel_name = "shear" if signal == "friction" else signal
    channel = exp["data"][channel_name][slc]
    pick_data = exp["data"]["shear"][slc]
    prepick, pickfnc = picking_fnc[pick_function]

    # Stays None when no position can be determined
    pick_point = None
    if prepick:
        if pick_data.size > 0:
            pick_point = pickfnc(pick_data)
            pick_val = channel[pick_point]
        else:
            pick_val = np.nan
    else:
        pick_val = pickfnc(channel)

    if return_index:
        return pick_point
    if signal == "friction":
        pick_val = pick_val / np.mean(exp["data"]["normal"])
    return pick_val


def create_pick_region(Fs, start_index, shift, region, forward):
    """
    Builds the sample slice used for picking.

    ``shift`` and ``region`` are scaled by ``Fs`` and truncated to whole
    samples. Forward regions start ``shift`` after ``start_index``.
    Backward regions end ``shift`` before it. A negative start is clamped
    to 0.

    Raises:
        ValueError: If the region would end at or before sample 0.
    """
    offset = int(Fs * shift)
    width = int(Fs * region)

    if forward:
        lo = start_index + offset
        hi = lo + width
    else:
        hi = start_index - offset
        lo = hi - width

    if hi <= 0:
        raise ValueError("Region shift too large for backwards picking")
    return slice(max(lo, 0), hi)


def get_pick_data(folder_path: str, cfg: configparser.ConfigParser) -> dict:
    """
    Returns the picks stored in ``picks.json`` inside ``folder_path``.

    New picks are generated with ``pick_folder`` and then re-read in any of
    these cases:

    * no picks exist yet;
    * ``overwrite`` in the ``[main]`` config section is enabled;
    * the stored picks lack the ``"hold_times"`` entry.

    Raises:
        GranularHealingError: If ``pick_folder`` reports failure.
    """
    pick_file_path = os.path.join(folder_path, "picks.json")
    picks = ghfiles.get_existing_picks(pick_file_path)

    needs_picking = (
        not picks
        or cfg.getboolean("main", "overwrite")
        or "hold_times" not in picks.keys()
    )
    if not needs_picking:
        return picks

    if not pick_folder(folder_path, cfg):
        raise GranularHealingError("Creating picks failed.", "Picking Error")
    return ghfiles.get_existing_picks(pick_file_path)
