#!/usr/bin/env python3
# This file is part of granularhealing.

# granularhealing is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# granularhealing is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with granularhealing. If not, see <https://www.gnu.org/licenses/>.

""" File handling functions for granularhealing """

import csv
import json
import os
from tkinter import Tk, filedialog
from typing import List

import nptdms
import numpy as np

import granularhealing.groups as ghgroups


def get_folder_list(
    root_path: str = None, unpack: bool = False, **kwargs: dict
) -> List[str]:
    """
    Lists the sub-directories found directly inside `root_path`.

    root_path: Directory to scan. When falsy, the user is asked to pick one
    through `ask_for_folder()`, which receives any extra keyword arguments
    (they end up in `filedialog.askdirectory()`).
    unpack: When True, return a tuple `(full_paths, names)`; otherwise return
    only the list of full paths. Both lists follow `os.listdir` order.
    """
    # No path given: fall back to the interactive picker.
    if not root_path:
        root_path = ask_for_folder(**kwargs)

    names = []
    full_paths = []
    for entry in os.listdir(root_path):
        candidate = os.path.join(root_path, entry)
        if os.path.isdir(candidate):
            names.append(entry)
            full_paths.append(candidate)

    if not unpack:
        return full_paths
    return (full_paths, names)


def ask_for_folder(**kwargs) -> os.PathLike:
    """
    Asks user for folder path in a safer way than simply calling filedialog.

    A hidden Tk root window is created for the dialog and always destroyed
    afterwards. Keyword arguments are passed to `filedialog.askdirectory()`.
    Returns whatever `askdirectory()` returns (a path string; an empty value
    when the user cancels).
    """
    # Create the root before entering `try`. If Tk() itself fails (e.g. no
    # display is available), there is nothing to destroy, and the original
    # error must propagate instead of an UnboundLocalError from `finally`.
    root = Tk()
    # Deleting the root object after every call prevents some issues with
    # Windows explorer hanging after frequent calls to filedialog.
    try:
        root.withdraw()
        folder_path = filedialog.askdirectory(**kwargs)
    finally:
        root.destroy()

    return folder_path


def ask_for_file(**kwargs) -> os.PathLike:
    """
    Asks user for file path in a safer way than simply calling filedialog.

    A hidden Tk root window is created for the dialog and always destroyed
    afterwards. Keyword arguments are passed to `filedialog.askopenfilename()`.
    Returns whatever `askopenfilename()` returns (a path string; an empty
    value when the user cancels).
    """
    # Create the root before entering `try`. If Tk() itself fails, `finally`
    # must not try to destroy an unbound name and mask the real error.
    root = Tk()
    try:
        root.withdraw()
        file_path = filedialog.askopenfilename(**kwargs)
    finally:
        root.destroy()

    return file_path


def get_existing_picks(pick_file_path: str):
    """
    Loads a previously saved pick file.

    Returns the decoded JSON content, with its "hold_times" entry (if any)
    converted to seconds by `decode_timedelta`. Returns False when the file
    does not exist or does not contain valid JSON.
    """
    if not os.path.exists(pick_file_path):
        return False

    try:
        with open(pick_file_path, "rt") as json_file:
            picks = json.load(json_file)
            # Hold times are stored as timedelta strings; turn them back
            # into floats.
            if "hold_times" in picks.keys():
                picks["hold_times"] = decode_timedelta(picks["hold_times"])
    except json.decoder.JSONDecodeError:
        return False

    return picks


def decode_timedelta(inputs: List) -> List:
    """
    Converts timedelta strings to timedelta in seconds.

    Each string is expected as "<number> <unit>", the form produced by
    `str(np.timedelta64(...))` (e.g. "1500000 microseconds"), which is what
    `AltNumpyEncoder` writes for `np.timedelta64` values. Recognised units
    (nanoseconds up to weeks, singular or plural) are honoured. A missing or
    unrecognised unit falls back to microseconds, the unit this function
    originally assumed for every input.

    inputs: Iterable of timedelta strings.
    Returns a list of floats in seconds, in input order.
    """
    # unit -> (multiplier, divisor). These are integers so the microsecond
    # case reproduces float(value) / 10**6 exactly.
    to_seconds = {
        "nanoseconds": (1, 10**9),
        "microseconds": (1, 10**6),
        "milliseconds": (1, 10**3),
        "seconds": (1, 1),
        "minutes": (60, 1),
        "hours": (3600, 1),
        "days": (86400, 1),
        "weeks": (604800, 1),
    }

    def _convert(inp_str):
        """Convert a single "<number> <unit>" string to seconds."""
        parts = inp_str.split(" ")
        unit = parts[1].lower() if len(parts) > 1 else ""
        if not unit.endswith("s"):
            unit += "s"  # accept singular forms such as "1 second"
        mult, div = to_seconds.get(unit, to_seconds["microseconds"])
        return float(parts[0]) * mult / div

    return [_convert(inp_str) for inp_str in inputs]


def save_picks(pick_file_path: str, picks: dict):
    """
    Serializes `picks` as JSON into `pick_file_path`, overwriting the file.

    Keys are sorted and the output is indented by 4 spaces. NumPy values are
    converted by `AltNumpyEncoder`.
    """
    with open(pick_file_path, "wt") as out_file:
        json.dump(
            obj=picks,
            fp=out_file,
            cls=AltNumpyEncoder,
            indent=4,
            sort_keys=True,
        )


class AltNumpyEncoder(json.JSONEncoder):
    """
    Custom encoder for numpy data types.

    Forked from numpyencoder.NumpyEncoder (https://github.com/hmallen/numpyencoder)

    Conversions:
    - datetime64 / timedelta64 -> their `str()` form
    - integer scalars -> int
    - floating scalars -> float
    - complex scalars -> {"real": ..., "imag": ...}
    - ndarray -> nested list via `tolist()`
    - bool_ -> bool
    - void -> None

    Abstract scalar types (`np.integer`, `np.floating`, `np.complexfloating`)
    are used instead of concrete aliases. `np.float_` and `np.complex_` no
    longer exist in NumPy 2.0, and referencing them raised AttributeError.
    """

    def default(self, obj):
        # Checked first: np.timedelta64 is a subclass of np.signedinteger and
        # would otherwise be turned into a bare int by the integer branch.
        if isinstance(obj, (np.datetime64, np.timedelta64)):
            return str(obj)

        if isinstance(obj, np.integer):
            return int(obj)

        if isinstance(obj, np.floating):
            return float(obj)

        if isinstance(obj, np.complexfloating):
            return {"real": obj.real, "imag": obj.imag}

        if isinstance(obj, np.ndarray):
            return obj.tolist()

        if isinstance(obj, np.bool_):
            return bool(obj)

        if isinstance(obj, np.void):
            return None

        # Let the base class raise the usual TypeError.
        return json.JSONEncoder.default(self, obj)


def open_tdms(tdms_file_path: str) -> dict:
    """
    Opens a ring shear tester tdms file and returns it as a dictionary.

    The channel naming scheme of the "Untitled" group decides the device:
    - "Shear Force" present -> ring shear tester (device "rst")
    - "Force Y" present     -> box (device "box")
    - "Force" present       -> spring slider (device "spring-slider"); the
      normal force is derived from a weight parsed from the file name.

    Returns a dict with:
    - "name": file name without extension
    - "data": arrays "shear", "normal", "lid_disp", "velocity", "time",
      "load_disp" plus the "device" tag
    - "dt": sample interval taken from the time track
    - "Fs": sampling rate, 1 / dt
    - "wf_start_time": the "wf_start_time" property of the channel the time
      track was taken from, or None if that property is missing

    Raises KeyError if the file matches none of the known naming schemes or
    a channel required by the matched scheme is missing.
    """
    fname = os.path.basename(tdms_file_path)
    name = os.path.splitext(fname)[0]
    unknown_scheme = "Unknown channel naming scheme."

    with nptdms.TdmsFile(tdms_file_path) as tdms_file:
        group = tdms_file["Untitled"]
        channel_names = [channel.name for channel in group.channels()]
        data = dict()
        # Channel that provides the time track and wf_start_time.
        time_channel = None
        try:
            # Ring shear tester
            if "Shear Force" in channel_names:
                time_channel = group["Shear Force"]
                data["shear"] = time_channel[()]
                data["normal"] = group["Normal Force"][()]
                data["lid_disp"] = group["Lid Displacement"][()]
                data["velocity"] = group["Velocity"][()]
                data["device"] = "rst"

            # Box
            elif "Force Y" in channel_names:
                time_channel = group["Force Y"]
                data["shear"] = time_channel[()]
                data["normal"] = group["Force X"][()]
                data["lid_disp"] = np.zeros_like(data["normal"])
                data["velocity"] = group["Velocity Y"][()]
                data["device"] = "box"

            # Spring Slider
            elif "Force" in channel_names:
                # NOTE(review): assumes the file name looks like
                # "<prefix>_<weight><1-char unit>_..." with the weight in
                # grams (it is divided by 1000 below). Confirm the convention.
                weight = float(fname.split("_")[1][:-1])
                time_channel = group["Velocity"]
                # `[()]` reads the channel into an array. The raw channel
                # object does not support the arithmetic done below.
                data["shear"] = group["Force"][()]
                data["velocity"] = time_channel[()]
                data["lid_disp"] = np.zeros_like(data["velocity"])
                data["normal"] = np.ones_like(data["velocity"]) * (
                    9.81 * (weight / 1000)
                )
                data["device"] = "spring-slider"
        except KeyError as err:
            raise KeyError(unknown_scheme) from err

        if time_channel is None:
            raise KeyError(unknown_scheme)

        data["time"] = time_channel.time_track()
        wf_start_time = time_channel.properties.get("wf_start_time")

    # Calculates all other relevant information. Uniform sampling is
    # assumed. Taking the difference of the first two samples (instead of
    # time[1] alone) stays correct when the time track has a start offset.
    time = data["time"]
    dt = time[1] - time[0]
    Fs = 1 / dt
    data["load_disp"] = np.cumsum(data["velocity"] * dt)
    rst = {
        "name": name,
        "data": data,
        "dt": dt,
        "Fs": Fs,
        "wf_start_time": wf_start_time,
    }

    return rst


def get_file_list(folder_path: str, ext: str) -> List[str]:
    """
    Collects the entries of `folder_path` whose names end with `ext`.

    Returns full paths (folder joined with the entry name) in `os.listdir`
    order.
    """
    matching_names = [
        entry for entry in os.listdir(folder_path) if entry.endswith(ext)
    ]
    return [os.path.join(folder_path, entry) for entry in matching_names]


def get_all_data(base_folder):
    """
    Loads the aggregated friction data JSON located under `base_folder`.

    Reads `<base_folder>/Data/Friction_Data/000_all_data.json` and returns
    the decoded JSON content.
    """
    # Join the components individually. A hard-coded backslash path only
    # resolves on Windows; on POSIX it becomes a single odd file name.
    data_path = os.path.join(
        base_folder, "Data", "Friction_Data", "000_all_data.json"
    )
    with open(data_path, "rt") as json_file:
        all_data = json.load(json_file)
    return all_data


def get_heal_data(base_folder):
    """
    Loads the aggregated slide-hold-slide data JSON under `base_folder`.

    Reads `<base_folder>/Data/SHS_Data/all_shs_data.json` and returns the
    decoded JSON content.
    """
    # Join the components individually so the path works on every OS,
    # not only on Windows.
    data_path = os.path.join(
        base_folder, "Data", "SHS_Data", "all_shs_data.json"
    )
    with open(data_path, "rt") as json_file:
        all_data = json.load(json_file)
    return all_data


def get_quality_score(csv_path):
    """
    Gets quality scores and calculates weighted average.

    Reads a semicolon-delimited CSV whose first row is the header. Every cell
    is converted to float when possible and kept as a string otherwise. The
    file must contain a "Sample" column and the three "(quali-score)"
    columns listed in `weights`.

    Returns a dict with:
    - "Sample": the sample column as read
    - one entry per quali-score column holding `5 - value` (presumably to
      invert a 0-5 scale so that higher is better; confirm)
    - "Average Quality Score": per-sample mean of the weighted, inverted
      scores (an np.ndarray)

    Raises ValueError if the CSV is empty, KeyError if a required column is
    missing, and IndexError if a row has more cells than the header.
    """
    data = None
    keys = []
    # The csv module expects newline="" so that quoted fields containing
    # line breaks are parsed correctly.
    with open(csv_path, "rt", newline="") as csv_file:
        reader = csv.reader(csv_file, delimiter=";")
        for row in reader:
            if data is None:
                # First row is the header: one list per column name.
                keys = row
                data = {k: [] for k in keys}
                continue
            for ii, r in enumerate(row):
                try:
                    r = float(r)
                except ValueError:
                    pass  # keep non-numeric cells (e.g. sample names) as str
                data[keys[ii]].append(r)

    if data is None:
        raise ValueError(f"Quality score file is empty: {csv_path}")

    weights = {
        "Sphericity (quali-score)": 0.3,
        "Roundness (quali-score)": 0.9,
        "Surface (quali-score)": 1.8,
    }

    weighted_vals = []
    quality_scores = dict()
    quality_scores["Sample"] = data["Sample"]
    for k, w in weights.items():
        weighted_vals.append([(5 - v) * w for v in data[k]])
        quality_scores[k] = [5 - v for v in data[k]]
    # The weights sum to 3.0 == len(weights), so the plain mean of the
    # weighted values equals the weighted mean sum(w * x) / sum(w).
    quality_scores["Average Quality Score"] = np.average(weighted_vals, axis=0)

    return quality_scores


def get_more_quality_score(csv_path):
    """
    Gets quality scores and calculates weighted average, plus the material
    annotation of every sample.

    Produces the same entries as `get_quality_score(csv_path)`, with an
    extra "annot_mat" list placed right after "Sample". Each element is
    `ghgroups.get_group_dict()["annot_mat"][sample]`.

    Raises KeyError if a sample has no entry in the "annot_mat" mapping.
    """
    group_dict = ghgroups.get_group_dict()
    # Reuse the shared CSV parsing and scoring instead of duplicating it.
    base_scores = get_quality_score(csv_path)

    quality_scores = dict()
    quality_scores["Sample"] = base_scores["Sample"]
    quality_scores["annot_mat"] = [
        group_dict["annot_mat"][k] for k in base_scores["Sample"]
    ]
    # Append the remaining entries in get_quality_score's order, so the key
    # order stays Sample, annot_mat, per-column scores, average.
    for key, value in base_scores.items():
        if key != "Sample":
            quality_scores[key] = value

    return quality_scores
