#!/usr/bin/env python3
# This file is part of granularhealing.

# granularhealing is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# granularhealing is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with granularhealing. If not, see <https://www.gnu.org/licenses/>.

"""
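cfit.py: Model functions for scipy.optimize.curve_fit and scipy.odr, plus
helpers for bootstrap/ODR fitting, confidence and prediction bands, and
formatting values with their uncertainties.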

"""
import numpy as np
import uncertainties as unc
import uncertainties.unumpy as unp
from scipy import odr, optimize, stats


def poly1(x, a, b):
    """Straight line ``a * x + b`` (slope ``a``, intercept ``b``)."""
    slope_term = a * x
    return slope_term + b


def poly2(x, a, b, c):
    """Quadratic polynomial ``a * x**2 + b * x**2 + c`` in expanded form: ``a * x**2 + b * x + c``."""
    quadratic = a * x**2
    linear = b * x
    return quadratic + linear + c


def poly3(x, a, b, c, d):
    """Cubic polynomial ``a * x**3 + b * x**2 + c * x + d``."""
    cubic = a * x**3
    quadratic = b * x**2
    linear = c * x
    return cubic + quadratic + linear + d


def power1(x, a, b, c):
    """Offset power law ``a * x**b + c``."""
    scaled_power = a * x**b
    return scaled_power + c


def exp1(x, a, b, c):
    """Exponential with arbitrary base ``b``: ``a * b**x + c``."""
    scaled_exponential = a * b**x
    return scaled_exponential + c


def natexp1(x, a, b):
    """Scaled and offset natural exponential ``a * np.exp(x) + b``."""
    growth = np.exp(x)
    return a * growth + b


def natexp2(x, a, b, c):
    """Natural exponential with rate ``b``: ``a * np.exp(b * x) + c``."""
    growth = np.exp(b * x)
    return a * growth + c


def expdecay(x, a, b):
    """Exponential decay ``a * np.exp(-b * x)`` (amplitude ``a``, rate ``b``)."""
    exponent = -b * x
    return a * np.exp(exponent)


def expdecay2(x, a, b, c):
    """Exponential decay towards the offset ``c``: ``a * np.exp(-b * x) + c``."""
    exponent = -b * x
    decaying_part = a * np.exp(exponent)
    return decaying_part + c


def natlog1(x, a, b):
    """Logarithmic model ``a + b * np.log(x)`` (natural logarithm)."""
    log_term = b * np.log(x)
    return a + log_term


def poly1_odr(B, x):
    """
    Straight line in scipy.odr's ``f(beta, x)`` calling convention.

    Equivalent to ``poly1(x, B[0], B[1])``: ``B[0]`` is the slope and
    ``B[1]`` the intercept.
    """
    slope, intercept = B[0], B[1]
    return slope * x + intercept


def fit_bootstrap(
    p0,
    datax,
    datay,
    fnc=poly1,
    yerr_systematic=0.0,
    numit=100,
    nsigma=2.0,
    rng=None,
):
    """
    Fit ``fnc`` to the data and estimate the parameter errors by bootstrapping
    resampled residuals.
    (https://en.wikipedia.org/wiki/Bootstrapping_(statistics)#Resampling_residuals)

    First the data is fitted with ``scipy.optimize.leastsq`` (nonlinear least
    squares) using the model ``fnc``. The standard deviation of the residuals
    of that fit is combined in quadrature with ``yerr_systematic``. The result
    is the scale of a zero-mean Gaussian. Random deviations drawn from it are
    added to ``datay``, and each of the ``numit`` randomized data sets is
    refitted starting from ``p0``. The spread of the refitted parameters gives
    the error estimate.

    Arguments:
        p0: initial parameter guess used for every fit.
        datax, datay: data arrays (anything ``np.asarray`` accepts).
        fnc: model function ``fnc(x, *params)``.
        yerr_systematic: additional absolute y uncertainty, added in
            quadrature to the residual standard deviation.
        numit: number of bootstrap refits.
        nsigma: multiple of the bootstrap standard deviation returned as the
            error. For Gaussian errors, 1.0 corresponds to about 68.3 %
            confidence and 2.0 (default) to about 95.4 %.
        rng: optional ``numpy.random.Generator`` for reproducible results.
            If None, the global ``numpy.random`` state is used.

    Returns:
        (pfit, perr): ``pfit`` is the mean of the bootstrap fits, not the
        initial fit. ``perr`` is ``nsigma`` times their standard deviation.
    """
    datax = np.asarray(datax)
    datay = np.asarray(datay)
    # Keep the global numpy RNG as the default so seeding via np.random.seed
    # behaves as before; a Generator makes individual calls reproducible.
    draw_normal = np.random.normal if rng is None else rng.normal

    def errfunc(p, x, y):
        return fnc(x, *p) - y

    # Fit first time
    pfit, _ = optimize.leastsq(
        errfunc, p0, args=(datax, datay), full_output=0, maxfev=10000
    )

    # Scale of the random deviations: residual stdev plus systematic error.
    residuals = errfunc(pfit, datax, datay)
    sigma_res = np.std(residuals)
    sigma_err_total = np.sqrt(sigma_res**2 + yerr_systematic**2)

    # Randomize residuals and add them to data, then refit (`numit` times).
    ps = []
    for _ in range(numit):
        random_delta = draw_normal(0.0, sigma_err_total, len(datay))
        random_datay = datay + random_delta

        randomfit, _ = optimize.leastsq(
            errfunc, p0, args=(datax, random_datay), full_output=0, maxfev=10000
        )
        ps.append(randomfit)
    ps = np.array(ps)

    pfit_bootstrap = np.mean(ps, 0)
    perr_bootstrap = nsigma * np.std(ps, 0)
    return pfit_bootstrap, perr_bootstrap


def normal_fit(
    x, y, fnc=poly1, rel=None, sig=None, p0=None, getcov=False, nsigma=2
):
    """
    Returns the parameters and errors given by optimize.curve_fit.
    The precision of the measurements is propagated into the estimated fit
    errors.

    Keyword arguments:
        fnc: function to fit to, called as ``fnc(x, *params)``.
        rel: relative precision of data. Takes precedence over ``sig``; when
            it is truthy, ``sig`` is ignored.
        sig: absolute precision of data, added in quadrature to each
            parameter's standard error.
        p0: initial parameter guess forwarded to ``curve_fit``.
        getcov: if True, also return the covariance matrix from ``curve_fit``.
        nsigma: multiple of the (inflated) standard error that is returned
            as the error. The default of 2 gives two-sigma errors.

    Returns:
        (popt, perr) or, if ``getcov`` is True, (popt, perr, pcov), where
        ``perr = nsigma * sigma``.
    """
    popt, pcov = optimize.curve_fit(fnc, x, y, p0=p0, maxfev=10000)

    # One-sigma standard errors of the parameters.
    sigma = np.sqrt(np.diag(pcov))
    if rel:
        # NOTE(review): this scales the parameter *error* by ``rel``, which
        # amounts to sigma * sqrt(1 + rel**2). A relative data precision would
        # more plausibly scale the parameter *values* (rel * popt). Confirm
        # the intent before relying on it.
        sigma = np.sqrt(sigma**2 + (rel * sigma) ** 2)
    elif sig:
        sigma = np.sqrt(sigma**2 + sig**2)

    perr = nsigma * sigma
    if getcov:
        return popt, perr, pcov
    return popt, perr


def predband(x, xd, yd, popt, func, conf=0.95):
    """
    Calculates the prediction band of the regression model at the
    desired confidence level.

    Clarification of the difference between confidence and prediction bands:

    "The prediction bands are further from the best-fit line than the
    confidence bands, a lot further if you have many data points. The 95%
    prediction band is the area in which you expect 95% of all data points
    to fall. In contrast, the 95% confidence band is the area that has a
    95% chance of containing the true regression line."
    (from
    http://www.graphpad.com/guides/prism/6/curve-fitting/index.htm?reg_graphing_tips_linear_regressio.htm
    )

    Arguments:
    - x: x values at which to calculate the prediction band (array-like).
    - xd, yd: data arrays (array-like).
    - popt: parameters of fit.
    - func: model function, called as ``func(x, *popt)``.
    - conf: desired confidence level, by default 0.95. The half-width uses
      the Student's t quantile with ``len(xd) - len(popt)`` degrees of freedom.

    Returns:
    (lower, upper) prediction band evaluated at ``x``.

    Raises:
    ValueError: if there are not more data points than fit parameters, so
    the residual standard deviation cannot be estimated.

    The leverage term ``1 + 1/N + (x - mean)^2 / Sxx`` is the formula for
    simple linear regression; for other models it is an approximation.

    References:
    1. http://www.JerryDallal.com/LHSP/slr.htm, Introduction to Simple Linear
    Regression, Gerard E. Dallal, Ph.D.

    Code adapted from:
    https://codereview.stackexchange.com/questions/84414/obtaining-prediction-bands-for-regression-model

    and
    http://astropython.blogspot.com.ar/2011/12/calculating-prediction-band-of-linear.html

    """
    # asarray is a no-op for ndarrays and lets lists work with models that
    # use array-only operations such as ``x**2``.
    x = np.asarray(x)
    xd = np.asarray(xd)
    yd = np.asarray(yd)

    alpha = 1.0 - conf  # significance
    n_data = xd.size  # data sample size
    n_params = len(popt)  # number of parameters
    dof = n_data - n_params
    if dof <= 0:
        raise ValueError(
            "predband needs more data points ({}) than fit parameters "
            "({})".format(n_data, n_params)
        )
    # Quantile of Student's t distribution for p=(1-alpha/2)
    q = stats.t.ppf(1.0 - alpha / 2.0, dof)
    # Stdev of an individual measurement (residual standard error)
    se = np.sqrt(1.0 / dof * np.sum((yd - func(xd, *popt)) ** 2))
    # Auxiliary definitions
    sx = (x - xd.mean()) ** 2
    sxd = np.sum((xd - xd.mean()) ** 2)
    # Predicted values (best-fit model)
    yp = func(x, *popt)
    # Half-width of the prediction band
    dy = q * se * np.sqrt(1.0 + (1.0 / n_data) + (sx / sxd))
    # Lower & upper prediction bands.
    return yp - dy, yp + dy


def confband(xq, xd, yd, popt, pcov, func, conf=0.95):
    """
    Calculates the confidence band of the regression model at the
    desired confidence level.

    Clarification of the difference between confidence and prediction bands:

    "The prediction bands are further from the best-fit line than the
    confidence bands, a lot further if you have many data points. The 95%
    prediction band is the area in which you expect 95% of all data points
    to fall. In contrast, the 95% confidence band is the area that has a
    95% chance of containing the true regression line."
    (from
    http://www.graphpad.com/guides/prism/6/curve-fitting/index.htm?reg_graphing_tips_linear_regressio.htm
    )

    The band comes from linear error propagation:
    1. The fit parameters are turned into correlated ``uncertainties`` values
       using ``pcov``.
    2. They are passed through ``func``.
    3. The resulting standard deviation is scaled by the two-sided
       standard-normal quantile for ``conf``.

    Arguments:
    - xq: x values at which to calculate the confidence band.
    - xd, yd: data arrays. Not used by the calculation; accepted so the
      signature mirrors ``predband``.
    - popt: parameters of fit.
    - pcov: covariance matrix of fit.
    - func: model function, called as ``func(xq, *params)`` with
      ``uncertainties`` values as parameters. NOTE(review): models that apply
      numpy ufuncs such as ``np.exp`` to the parameters may not propagate
      through ``uncertainties``; confirm for the model in use.
    - conf: desired confidence level, by default 0.95 (about 1.96 sigma).

    Returns:
    (lower, upper) confidence band evaluated at ``xq``.
    """
    # Two-sided Gaussian quantile for the requested confidence level,
    # e.g. ~1.96 for conf=0.95 and ~1.0 for conf=0.683.
    alpha = 1.0 - conf
    z = stats.norm.ppf(1.0 - alpha / 2.0)

    corr_val = unc.correlated_values(popt, pcov)
    y = func(xq, *corr_val)
    nom = unp.nominal_values(y)
    std = unp.std_devs(y)
    lcb = nom - z * std
    ucb = nom + z * std

    return lcb, ucb


def sign_str(x, xe=None, polarity=None, stdout=False, sigdig=None):
    """
    Returns a string with value 'x \\pm xe'. The error handling is done with
    the uncertainties package so that the returned significant digits are in
    accordance with error reporting. Optionally the number of decimal places
    can be fixed with the `sigdig` option.

        Keyword arguments:
            polarity: Prepends "+ " to positive values. Only applied when
                `sigdig` is given; the uncertainties-based paths ignore it.
            stdout: Changes from LaTeX to Unicode output.
            sigdig: Predetermined number of decimal places (the convention
                returned by `sign`). Negative values round to tens,
                hundreds, ...; 0 or None use the uncertainties formatting.
    """
    # uncertainties' "P" (pretty-print) format already renders the Unicode
    # "±", and its "L" format renders LaTeX with "\pm". Selecting the format
    # spec directly means no string replacement is needed.
    unc_spec = "{:P}" if stdout else "{:L}"

    # if the input is already a unc ufloat
    if isinstance(x, unc.UFloat):
        return unc_spec.format(x)

    # Use uncertainties to choose the significant digits from the error
    if not sigdig:
        return unc_spec.format(unc.ufloat(x, std_dev=xe))

    # Legacy formatting for a predetermined number of decimal places
    sym = "\u00b1" if stdout else "\\pm"
    s = sigdig
    if s > 0:
        out = ("%." + str(s) + "f " + sym + " %." + str(s) + "f") % (x, xe)
    else:
        # Round to the requested power of ten (like "%.Nf" does for N > 0)
        # and print without decimals.
        out = "%.0f %s %.0f" % (round(x, s), sym, round(xe, s))
    if polarity and x > 0:
        out = "+ " + out
    return out


def sign(x):
    """
    Return the decimal position of the leading significant digit of `x`.

    This is ``-floor(log10(|x|))``: the number of decimal places needed so
    that the first significant digit of `x` is the last digit shown, e.g.
    ``sign(0.05) == 2``, ``sign(3.2) == 0``, ``sign(150) == -2``.
    Negative inputs are handled via their absolute value.

    Raises:
        ValueError: if `x` is zero or not finite (no leading digit exists).
    """
    magnitude = abs(x)
    if magnitude == 0 or not np.isfinite(magnitude):
        raise ValueError("sign() needs a finite, non-zero value, got %r" % (x,))
    return int(-np.floor(np.log10(magnitude)))


def fit_list(popt, perr):
    """
    Format every fitted parameter together with its error.

    Each (value, error) pair from `popt` and `perr` is passed to `sign_str`
    with ``polarity=True``; the formatted strings are returned in order.
    """
    return [
        sign_str(value, error, polarity=True) for value, error in zip(popt, perr)
    ]


def orthogonal_distance_regression(
    x, xe, y, ye, func=poly1_odr, max_iter=200, it_steps=10, beta0=None
):
    """
    Runs an orthogonal distance regression to fit y = f(x) with errors in x and y.

    Arguments:
        x, y: data arrays.
        xe, ye: standard deviations of x and y, passed as ``sx``/``sy`` to
            ``scipy.odr.RealData``.
        func: model in scipy.odr form ``func(B, x)``.
        max_iter: upper bound on the total number of ODR iterations. While
            the iteration limit is hit, the fit is restarted.
        it_steps: number of extra iterations per restart. Values < 1
            disable restarting.
        beta0: initial parameter guess. Defaults to ``[0.0, 1.0]``, which
            suits two-parameter models such as ``poly1_odr``; pass one entry
            per parameter for other models.

    Returns:
        (popt, perr, pcov): ``beta``, ``sd_beta`` and ``cov_beta`` of the
        final ``scipy.odr.Output``.
    """
    if beta0 is None:
        beta0 = [0.0, 1.0]
    fit_model = odr.Model(func)
    data = odr.RealData(x, y, sx=xe, sy=ye)
    odr_reg = odr.ODR(data, fit_model, beta0=beta0)
    out = odr_reg.run()
    # A first ODR run performs at most 50 iterations (scipy's default maxit).
    n_iter = 50
    # ``stopreason`` is a list of messages, so test membership, not equality.
    # Restart in ``it_steps`` chunks without exceeding ``max_iter`` in total.
    while (
        "Iteration limit reached" in out.stopreason
        and it_steps > 0
        and n_iter + it_steps <= max_iter
    ):
        out = odr_reg.restart(it_steps)
        n_iter += it_steps
    popt = out.beta
    perr = out.sd_beta
    pcov = out.cov_beta
    return (popt, perr, pcov)
