# -*- coding: utf-8 -*-
"""
Spyder Editor
Authors: Michael Warsitzka, Matthias Rosenau, Malte Ritter
When using the data, please cite it as specified in the data description of this data publication.
"""
#%%=====================IMPORT================================================
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os,re
from scipy import stats 
#%%================PARAMETERS FOR RST / SHEAR CELL================================================
# Geometry of the shear cell and test settings used by the conversions below.
# NOTE(review): the lever lengths carry no unit in the original comments;
# presumably metres, like the area -- confirm.
A = 0.022619   # area of shear zone (= surface area of lid) (m^2)
li = 0.0776    # inner lever (center cell to center of shear zone)
lo = 0.1250    # outer lever (center cell to hinge point)
v = 3          # shear velocity (mm/min)
#%%================NAMES====================================================
path_in = './../Data files/'   # relative path of input data
#%%===================LOAD DATA=========================================================
def _find_table(rstname, tag, candidates):
    """Return the file name in ``candidates`` that belongs to ``rstname``/``tag``.

    A candidate matches when its name contains both ``rstname`` and ``tag``.
    Names starting with ``rstname + '_'`` are preferred so that a measurement
    called e.g. ``sand1`` does not pick up a file of ``sand10``.

    Raises:
        IOError: if no candidate matches.
    """
    matches = [name for name in candidates if rstname in name and tag in name]
    if not matches:
        raise IOError("no '.txt' file containing '{}' and '{}' found in '{}'"
                      .format(rstname, tag, path_in))
    preferred = [name for name in matches if name.startswith(rstname + '_')]
    return (preferred or matches)[0]


def _linear_regression(x, y):
    """Least-squares fit of ``y = mu * x + C``.

    Returns:
        Tuple ``(mu, C, std_mu, std_C)``: slope, intercept and their
        standard errors (residual variance uses ``len(x) - 2`` degrees of
        freedom).
    """
    n = len(x)
    sum_x, sum_y = np.sum(x), np.sum(y)
    sum_xx, sum_xy = np.sum(x**2), np.sum(x*y)
    denom = n*sum_xx - sum_x**2
    mu = (n*sum_xy - sum_x*sum_y)/denom
    coh = (sum_xx*sum_y - sum_x*sum_xy)/denom
    resid_var = np.sum((y - coh - mu*x)**2)/(n - 2)
    return mu, coh, np.sqrt(n*resid_var/denom), np.sqrt(sum_xx*resid_var/denom)


def _hist_with_normal_fit(ax, values, bins, title, xlabel, ndigits, fit_label=None):
    """Draw a density histogram of the non-NaN ``values`` plus a fitted normal pdf."""
    valid = values[~np.isnan(values)]
    # 'density' replaces the 'normed' keyword that newer matplotlib no longer accepts.
    ax.hist(valid, bins=bins, density=True, color='royalblue', edgecolor='black')
    mean, std = stats.norm.fit(valid)   # fitted once, reused for curve and text
    grid = np.linspace(np.nanmin(values), np.nanmax(values), len(values))
    line_kwargs = {'linewidth': 2}
    if fit_label is not None:
        line_kwargs['label'] = fit_label
    ax.plot(grid, stats.norm.pdf(grid, mean, std), 'r--', **line_kwargs)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    # NOTE(review): the histogram is normalised to a density, yet the axis is
    # labelled 'Counts' -- label kept unchanged from the original figures.
    ax.set_ylabel('Counts')
    text = 'Mean: ' + str(round(mean, ndigits)) + '\n' + \
        'Std.: ' + str(round(std, ndigits)) + '\n' + \
        ' (' + str(valid.size) + ' data)'
    ax.text(0.98, 0.84, text,
            horizontalalignment='right',
            verticalalignment='bottom',
            transform=ax.transAxes)


def histplot(axrow, coef, coh, tit_mu, tit_C, bins=None):
    """Plot histograms of friction coefficients and cohesions in one axes row.

    Args:
        axrow: pair of axes; ``axrow[0]`` receives the friction coefficients,
            ``axrow[1]`` the cohesions.
        coef: array of friction coefficients (NaN entries are ignored).
        coh: array of cohesions (NaN entries are ignored).
        tit_mu: title of the friction-coefficient histogram.
        tit_C: title of the cohesion histogram.
        bins: number of histogram bins; defaults to the square root of the
            number of non-NaN entries in ``coef``.
    """
    if bins is None:
        bins = max(1, int(np.round(np.sqrt(np.sum(~np.isnan(coef))))))
    _hist_with_normal_fit(axrow[0], coef, bins, tit_mu,
                          r'Friction coefficient $\mu$', 3,
                          fit_label='normal distribution')
    _hist_with_normal_fit(axrow[1], coh, bins, tit_C,
                          'Cohesion $C$ [Pa]', 2)


# The directory is listed once; generated figures are PDFs and therefore never
# interfere with the '.txt' lookups below.
files = [name for name in os.listdir(path_in) if name.endswith(".txt")]
rstnames = [name for name in files if name.endswith('_peak.txt')]  # names of measurements

# Figure defaults shared by all plots:
plt.rcParams['savefig.format'] = 'pdf'
plt.rcParams['font.size'] = 10
plt.rcParams['font.family'] = 'Arial Unicode MS'
plt.rcParams['savefig.dpi'] = 1000

linecolor = ['royalblue', 'r', 'g']
marker = ['o', '^', 's']
label1 = ['Peak friction', 'Dynamic friction', 'Reactivation friction']
label2 = ['(Peak)', '(Dynamic)', '(Reactivation)']
title_mu = [r'Peak friction coefficient $\mu_P$',
            r'Dynamic friction coefficient $\mu_D$',
            r'Reactivation friction coefficient $\mu_R$']
title_C = [r'Peak cohesion $C_P$',
           r'Dynamic cohesion $C_D$',
           r'Reactivation cohesion $C_R$']
ts_colors = ['red', 'orange', 'gold', 'green', 'royalblue']

for peak_file in rstnames:
    rstname = peak_file[:-len('_peak.txt')]
    # Each table is read with its header line skipped.
    data_peak = np.loadtxt(os.path.join(path_in, peak_file), skiprows=1)
    data_dyn = np.loadtxt(os.path.join(path_in, _find_table(rstname, 'dynamic', files)), skiprows=1)
    data_react = np.loadtxt(os.path.join(path_in, _find_table(rstname, 'reactivation', files)), skiprows=1)
    data_ts = np.loadtxt(os.path.join(path_in, _find_table(rstname, '_ts', files)), skiprows=1)
    data = [data_peak, data_dyn, data_react]
    m = data_peak.shape[0]          # number of measurements (rows)
    data_rst = np.asanyarray(data)  # the three tables are expected to share one shape

#%%=======================RST STANDARD CORRELATION=================================================
    plt.rcParams['figure.figsize'] = (8, 8)
    fig1 = plt.figure()
    for series, color, symbol, name, tag in zip(data_rst, linecolor, marker, label1, label2):
        # column 0 is used as normal stress (x), column 1 as shear stress (y)
        x, y = series[:, 0], series[:, 1]
        # Correlation of normal and shear stress:  y = mu * x + C
        mu, cohesion, std_mu, std_C = _linear_regression(x, y)
        fit = np.polyval([mu, cohesion], x)
        plt.plot(x, y,
                 marker=symbol,
                 mfc=color,
                 ms=6, lw=0,
                 markeredgewidth=0,
                 label=name)
        plt.plot(x, fit,
                 '-', color=color,
                 linewidth=1,
                 label=r'Lin. Regr. ' + tag +
                 r': $\tau$=({:.3f}$\pm${:.3f})$\sigma_N$+({:.2f}$\pm${:.2f})'.format(
                     mu, std_mu, cohesion, std_C))
    plt.xlabel(r'Normal stress $\sigma_N$ [Pa]')
    plt.ylabel(r'Shear stress $\tau$ [Pa]')
    plt.xlim(0, 2400)
    plt.ylim(0, 2000)
    plt.xticks(np.arange(0, 2400, 200))
    plt.yticks(np.arange(0, 2000, 200))
    plt.legend(fontsize=8,
               facecolor='w', edgecolor='k',
               framealpha=1, loc='upper left')
    plt.grid(color='gray', linestyle='-', linewidth=0.5)
    fig1.suptitle(rstname + ' (' + str(m) + ' measurements)', y=0.92)
    plt.savefig(os.path.join(path_in, rstname + '_linregr'),
                bbox_inches='tight',
                edgecolor='w')
    plt.close()

#%%=============POINTWISE CORRELATION OF FRICTION PARAMETERS=============================
    # Friction coefficients (slope, M) and cohesions (y-axis intercept, C)
    # from two-point regressions: entry [i, k, l] pairs row k with row k+l
    # of data set i.
    M = np.zeros((len(data), m, m))
    C = np.zeros((len(data), m, m))
    # Equal normal stresses give a division by zero; those entries are
    # discarded below, so the warnings are silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        for i, series in enumerate(data):
            sigma, tau = series[:, 0], series[:, 1]
            for k in range(m - 1):
                # l starts at 1: pairing a point with itself is always 0/0
                for l in range(1, m - k):
                    M[i, k, l] = (tau[k+l] - tau[k])/(sigma[k+l] - sigma[k])
                    C[i, k, l] = tau[k] - M[i, k, l]*sigma[k]
    # inf, -inf, never-filled (0) and exactly-zero entries are excluded
    M[~np.isfinite(M) | (M == 0)] = np.nan
    C[~np.isfinite(C) | (C == 0)] = np.nan

#%%======================PLOT HISTOGRAMS===========================================
    plt.rcParams['figure.figsize'] = (10, 14)
    fig2, axes = plt.subplots(3, 2)
    plt.subplots_adjust(hspace=.25, wspace=.25)
    # square-root rule on the peak data; the same bin count is used for all rows
    nbins = max(1, int(np.round(np.sqrt(np.sum(~np.isnan(M[0]))))))
    for axrow, coef_i, coh_i, tit_mu_i, tit_C_i in zip(axes, M, C, title_mu, title_C):
        histplot(axrow, coef_i, coh_i, tit_mu_i, tit_C_i, bins=nbins)
    fig2.suptitle(rstname, y=0.92)
    plt.savefig(os.path.join(path_in, rstname + '_hist'),
                bbox_inches='tight',
                edgecolor='w')
    plt.close()

#%%=======================CONVERT TIME SERIES=================================================
    dfts = pd.DataFrame(np.zeros(data_ts.shape))
    dfts.iloc[:, 0] = data_ts[:, 0] * v/60  # convert time [s] to shear displacement [mm]
    # Remaining columns: scale by the lever ratio and divide by the shear area.
    # NOTE(review): the original comment calls this "load (kg) to stress (Pa)";
    # no gravity factor is applied here -- confirm the unit of the input.
    dfts.iloc[:, 1:] = (data_ts[:, 1:]*lo)/(li*A)
    # Sort the stress columns by their maximum, keep displacement first and
    # drop the first row.
    stress_order = list(dfts.iloc[:, 1:].max().sort_values().index)
    dfts = dfts.iloc[1:, [0] + stress_order]

#%%=======================PLOT TIME SERIES===========================================
    plt.rcParams['figure.figsize'] = (10, 8)
    t = int(m/3)    # rows i, i+t and i+2*t are averaged as one normal-stress level
    fig3 = plt.figure()
    displacement = dfts.iloc[:, 0]
    for i in range(t):
        normalstress_avg = int((data_peak[i, 0] + data_peak[i+t, 0] + data_peak[i+2*t, 0])/3)
        color = ts_colors[i % len(ts_colors)]   # cycle instead of running out of colours
        # zero line that only carries the legend entry of this group
        plt.plot(displacement, np.zeros(len(displacement)),
                 linewidth=0.5, color=color, label=str(normalstress_avg) + ' Pa')
        plt.plot(displacement, dfts.iloc[:, i*3+1:(i+1)*3+1],
                 linewidth=0.5, color=color)
    plt.legend(fontsize=8,
               facecolor='w',
               edgecolor='k',
               framealpha=1,
               loc='upper right',
               title=r"Normal stress $\sigma_N$")

    plt.xlabel('Shear displacement $d$ [mm]')
    plt.ylabel(r'Shear stress $\tau$ [Pa]')
    plt.xlim(0, max(displacement))
    # Round the upper limit UP to the next 1000 so that no curve is clipped.
    y_top = max(1000.0, float(np.ceil(dfts.max().max()/1000.0)*1000))
    plt.ylim(0, y_top)
    plt.yticks(np.arange(0, y_top, 200))
    fig3.suptitle(rstname, y=0.92)
    plt.savefig(os.path.join(path_in, rstname + '_ts'),
                bbox_inches='tight',
                edgecolor='w')
    plt.close()