'''
Created on January 29, 2024 

This Python script was created by Prof. Joline Uichanco (jolineu@umich.edu) and is meant for research and not commercial use.
If using this code for research, please cite this publication "M. Jamalzadeh, et al., vol. X, Analyst, 2024"


This Python script implements Principal Component Regression (PCR) to calibrate a quantitative model
to predict the concentrations of Dopamine and Serotonin from the Fast Scan Cyclic Voltammetry (FSCV) data.

Input files:
1. Training_CVs.csv (see provided template for the format)
2. Testing_CVs.csv (see provided template for the format)
3. cycle_voltage.csv (only used in plotting)

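Output files:
1. Concentration_predictions_PCR.csv (where PCR predictions are stored; written when to_predict_flag is True)
2. Diagnostic K matrix.pdf (plot of the K matrix; used to diagnose the PCR model)
3. PC_data.csv (all estimated principal components, one row per voltage point)
4. PC_data.pdf (plot of the first 8 PCs; written when plot_PCA_flag is True)
5. Training_CV_plots.pdf (plots of all CVs used for training the PCR model; written when plot_CV_flag is True)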

Python packages needed to be installed (with conda or pip install):
1. scikit-learn (a package needed for all the statistical functions such as PCA, linear regression, and cross-fold validation)
2. numpy (a package for handling numerical arrays)
3. pandas (a package for fast handling of data)
4. matplotlib (a package for plotting)
5. seaborn (a package used here to set the plotting style)

@author: Joline Uichanco (jolineu@umich.edu)
'''

import numpy as np
import csv

#data handling
import pandas as pd

#plotting
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
sns.set_style("white")

from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score



# ---------------------------------------------------------------------------
# User options that can be changed
# ---------------------------------------------------------------------------

# Method options
to_predict_flag: bool = True  # set to True to predict concentrations for the testing CVs
best_pc_num: int = 2  # number of PCs used in prediction (user-defined)

# Sensitivity correction: if the training sensor and the testing sensor have
# different sensitivities, set these multipliers to rescale the predicted
# concentrations and account for the difference (1 = no adjustment).
DA_scale: float = 1
SER_scale: float = 1

# Visualization options
plot_CV_flag: bool = True  # set to True to visualize the CV training data
plot_PCA_flag: bool = True  # set to True to visualize the Principal Components estimated from the CV training data

# K-fold cross-validation option. Note that best_pc_num must still be changed
# manually after inspecting the result of the test.
Kfold_validation_test_flag: bool = False  # set to True to select best_pc_num based on a K-fold cross-validation test


'''
Training and testing data preparation
'''
# read the voltage sweep data; the first column holds the potential of each
# sample point of the CV (used as the x-axis of the plots and as the index of PC_data.csv)
df_vol = pd.read_csv("cycle_voltage.csv")
voltage_list = df_vol.iloc[:, 0].to_numpy()

# read the training and testing CVs into pandas DataFrames
df_train = pd.read_csv("Training_CVs.csv")
df_train.sort_index(inplace=True)
df_test = pd.read_csv("Testing_CVs.csv")
df_test.sort_index(inplace=True)

# the concentration (target) columns; every other column is a current sample of the CV
targets = ["C_DA", "C_5HT"]

# fail early with a clear message instead of an opaque KeyError further down
for csv_name, df_check in (("Training_CVs.csv", df_train), ("Testing_CVs.csv", df_test)):
    missing_targets = [t for t in targets if t not in df_check.columns]
    if missing_targets:
        raise ValueError(f"{csv_name} is missing the concentration column(s): {missing_targets}")

# divide the dataset into X (currents) and y (concentrations) for training and testing
X_train = df_train.drop(targets, axis=1)
y_train = df_train[targets]
num_samples = X_train.shape[0]

X_test = df_test.drop(targets, axis=1)
y_test = df_test[targets]

# the voltage sweep must provide one potential per current column; otherwise the
# plots and the PC_data.csv export (indexed by voltage) fail with obscure errors
if len(voltage_list) != X_train.shape[1]:
    raise ValueError(
        f"cycle_voltage.csv has {len(voltage_list)} voltage points but "
        f"Training_CVs.csv has {X_train.shape[1]} current columns; they must match"
    )



'''
Data visualizations (plotting CVs)
'''
if plot_CV_flag:
    # Classify each training sample by its concentrations:
    #   pure 5-HT : C_DA == 0 and C_5HT > 0
    #   pure DA   : C_5HT == 0 and C_DA > 0
    #   mix       : everything else (including samples where neither analyte is > 0)
    conc_DA = y_train["C_DA"].to_numpy()
    conc_5HT = y_train["C_5HT"].to_numpy()
    is_pure_5HT = (conc_DA == 0) & (conc_5HT > 0)
    is_pure_DA = (conc_5HT == 0) & (conc_DA > 0)
    is_mix = ~(is_pure_5HT | is_pure_DA)
    has_mix = bool(is_mix.any())

    # three panels when mixtures are present, otherwise only the two pure-analyte panels
    n_panels, fig_size, left_margin = (3, (8, 2), 0.05) if has_mix else (2, (5, 2), 0.08)
    cv_fig, cv_axes = plt.subplots(1, n_panels, figsize=fig_size, dpi=200, sharey=False)
    plt.subplots_adjust(left=left_margin, bottom=0.2, right=.98, top=.9, wspace=.2)

    # styling shared by every panel
    for panel_ax in cv_fig.axes:
        for spine in panel_ax.spines.values():
            spine.set_linewidth(0.1)
        panel_ax.tick_params(axis='both', which='major', labelsize=6, pad=-3)
        panel_ax.minorticks_on()
        panel_ax.grid(visible=False, which='both', linewidth=0.75, alpha=0.4)

    X_DA = X_train.iloc[np.flatnonzero(is_pure_DA), :]
    X_5HT = X_train.iloc[np.flatnonzero(is_pure_5HT), :]
    X_mix = X_train.iloc[np.flatnonzero(is_mix), :]

    panels = [(X_DA, 'Training, Pure DA'), (X_5HT, 'Training, Pure 5-HT')]
    if has_mix:
        panels.append((X_mix, 'Training, Mix'))

    with plt.style.context('ggplot'):
        for panel_ax, (X_panel, panel_label) in zip(cv_axes, panels):
            # one curve per CV: transpose so each column is a sample
            panel_ax.plot(voltage_list, X_panel.to_numpy().T, linewidth=0.5)
            panel_ax.set_xlabel('Potential (V)', fontsize=8)
            panel_ax.set_ylabel('Current (nA)', fontsize=8)
            panel_ax.set_title(f'{panel_label} (n={len(X_panel.index)})', fontsize=8)

    plt.savefig('Training_CV_plots.pdf', dpi=200)
    

'''
    Conduct K-fold cross-validation to determine the best number of pcs to choose
    The results of this test can be used to re-run PCR with a different choice of best_pc_num
'''
if Kfold_validation_test_flag:
    # setup the figure
    fig = plt.figure(figsize=(3, 2), dpi=200)
    plt.subplots_adjust(left=0.2, bottom=0.2, right=.98, top=.9,
                        wspace=.2, hspace=0.1)
    ax = plt.gca()
    for spine in ax.spines.values():
        spine.set_linewidth(0.1)
    ax.tick_params(axis='both', which='major', labelsize=6, pad=-3)
    ax.minorticks_on()
    ax.grid(visible=True, which='both', linewidth=0.75, alpha=0.4)

    # Candidate numbers of PCs: 1 to 15 inclusive (same range as the explained-variance
    # plot), capped by min(n_samples, n_features), the most components PCA can estimate.
    max_pc_test = min(min(X_train.shape), 15)
    num_pc_all_test = np.arange(1, max_pc_test + 1)

    # KFold requires n_splits <= number of samples, so small training sets get fewer folds.
    # The splitter and estimator do not depend on the candidate, so build them once.
    n_splits = min(10, X_train.shape[0])
    cv = KFold(n_splits=n_splits, random_state=5, shuffle=True)
    reg = LinearRegression(fit_intercept=False)  # PCR must have no intercept

    # performance is measured by the mean absolute error (MAE; lower is better)
    mae = []
    for num_pc_test in num_pc_all_test:
        # Generate first M principal components (PCs) where M = num_pc_test
        pca = PCA(num_pc_test)
        pca.fit(X_train)
        Vc = pca.components_.T

        # Compute the PC scores of the training set
        Xproj = np.dot(Vc.T, X_train.T)

        # NOTE(review): the PCs are estimated from the whole training set before it is
        # split into folds, so the held-out folds are not fully independent of the PCA
        # step and the MAE may be somewhat optimistic.
        performance_scores = cross_val_score(reg, Xproj.T, y_train, scoring='neg_mean_absolute_error',
                                             cv=cv, n_jobs=-1)
        # scikit-learn returns the *negated* MAE (higher is better); flip the sign
        mae.append(-np.mean(performance_scores))

    plt.plot(num_pc_all_test, mae)
    plt.xlabel('Num of PCs', fontsize=8)
    plt.ylabel('Mean absolute error', fontsize=8)
    plt.title(f'K-fold cross validation test (K={n_splits})', fontsize=8)

    # report the candidate with the lowest MAE (best_pc_num still has to be set manually)
    best_pc_cv = int(num_pc_all_test[int(np.argmin(mae))])
    print(f"K-fold cross validation: lowest mean absolute error with {best_pc_cv} PCs "
          f"(best_pc_num is currently set to {best_pc_num})")
            


'''
    Estimate the Principal Components (PCs)
'''
# Generate ALL the principal components of the training CVs
pca = PCA()
pca.fit(X_train)
Vc = pca.components_.T  # one column per PC, one row per current column (voltage point)
eig = pca.explained_variance_
num_pcs = len(eig)

# Save the principal components to a csv file (rows indexed by voltage)
df_vc = pd.DataFrame(Vc, index=voltage_list, columns=["PC" + str(i + 1) for i in range(num_pcs)])
df_vc.to_csv("PC_data.csv", index_label='voltage')

# visualizations for PCA
if plot_PCA_flag:
    # create a 4x2 grid of panels (room for the first 8 PCs)
    f, axs = plt.subplots(4, 2, figsize=(3.5, 5), dpi=200, sharex=True, sharey=True)
    # set the properties common among subplots
    for ax in f.axes:
        for spine in ax.spines.values():
            spine.set_linewidth(0.1)
        ax.tick_params(axis='both', which='major', labelsize=6, pad=-3)
        ax.minorticks_on()
        ax.grid(visible=False, which='both', linewidth=0.75, alpha=0.4)

    # plot the first 8 PCs (fewer if fewer were estimated), filling the grid row by row
    for i in range(min(8, num_pcs)):
        row_ind, col_ind = divmod(i, 2)
        pc_ax = axs[row_ind, col_ind]
        pc_ax.plot(voltage_list, Vc[:, i])
        pc_ax.set_title('PC ' + str(i + 1), fontsize=8)
        if col_ind == 0:
            pc_ax.set_ylabel('Current (nA)', fontsize=8)
        if row_ind == 3:
            pc_ax.set_xlabel('Potential (V)', fontsize=8)
    plt.tight_layout()
    plt.gcf().subplots_adjust(left=0.15, bottom=0.08, right=0.99, top=.95, wspace=0, hspace=.3)
    # save the plot as a pdf
    plt.savefig('PC_data.pdf', dpi=200)

    # create a plot showing how much of the variance in the data the first (up to) 15 PCs explain
    fig = plt.figure(figsize=(3, 2), dpi=200)
    plt.subplots_adjust(left=0.2, bottom=0.2, right=.98, top=.9,
                        wspace=.2, hspace=0.1)
    ax = plt.gca()
    for spine in ax.spines.values():
        spine.set_linewidth(0.1)
    ax.tick_params(axis='both', which='major', labelsize=6, pad=-3)
    ax.minorticks_on()
    ax.grid(visible=True, which='both', linewidth=0.75, alpha=0.4)

    num_pcs_shown = min(15, num_pcs)
    # explained_variance_ratio_ holds fractions of the total variance; scale by 100
    # so the curve matches the "%" in the axis label
    cum_var_pct = 100 * np.cumsum(pca.explained_variance_ratio_[:num_pcs_shown])
    plt.plot(range(1, num_pcs_shown + 1), cum_var_pct, 'r-', linewidth=1)
    plt.title('Variance of data explained by PCs', fontsize=8)
    plt.ylabel('Cumulative variance (%)', fontsize=8)
    plt.xlabel('Number of PCs', fontsize=8)


# ---------------------------------------------------------------------------
# Calibrate the PCR model using the training data set
# ---------------------------------------------------------------------------
# Keep only the first best_pc_num principal components for the regression.
pca = PCA(n_components=best_pc_num)
Vc = pca.fit(X_train).components_.T

# PC scores of the training CVs: the raw (not mean-centered) training CVs are
# projected onto the retained PCs -> one row per PC, one column per training sample.
Xproj = Vc.T @ X_train.to_numpy().T

# Regress the concentrations on the PC scores (PCR must have no intercept).
reg = LinearRegression(fit_intercept=False)
reg.fit(Xproj.T, y_train)



# ---------------------------------------------------------------------------
# Diagnostic test of the PCR model using the K-matrix.
# Each k-vector is the PCR model's estimate of the current response to a unit
# concentration of one pure analyte. If a k-vector does not resemble the CV shape
# of its analyte, a different best_pc_num should probably be chosen.
# ---------------------------------------------------------------------------
# Recover the K matrix from the training data:
#   F    = C . pinv(scores)   regression coefficients on the PC scores (same as reg.coef_)
#   Kinv = F . Vc^T           maps a CV directly to concentrations
#   K    = pinv(Kinv)         one column per target (DA, then 5-HT)
Xprojinv = np.linalg.pinv(Xproj)
F = np.asarray(y_train.T) @ Xprojinv
Kinv = F @ Vc.T
K = np.linalg.pinv(Kinv)

# figure layout and axis styling
fig = plt.figure(figsize=(3, 2), dpi=200)
plt.subplots_adjust(left=0.15, bottom=0.15, right=.95, top=.9, wspace=.2, hspace=0.1)
ax = plt.gca()
for spine in ax.spines.values():
    spine.set_linewidth(0.1)
ax.tick_params(axis='both', which='major', labelsize=6, pad=-3)
ax.minorticks_on()

ax.set_title(f'Diagnostic test of PCR model ({best_pc_num} PCs) using K-matrix', fontsize=8)
ax.plot(voltage_list, K)
ax.set_xlabel('Potential (V)', fontsize=8)
ax.set_ylabel('Current (nA)', fontsize=8)
ax.legend(['k vector (DA)', 'k vector (5-HT)'], fontsize=8, frameon=False)

plt.savefig('Diagnostic K matrix.pdf', dpi=200)



'''
    Make out-of-sample predictions using the regression model
'''
if to_predict_flag:
    # PC scores of the test CVs: the same raw (not mean-centered) projection onto
    # the retained PCs that was used for the training set
    test_scores = X_test.to_numpy() @ Vc

    # Use the PCR model to predict the concentrations
    y_test_pred = pd.DataFrame(reg.predict(test_scores), index=y_test.index, columns=y_test.columns)
    # Adjust the prediction based on difference in sensitivity of the training and testing sensors (if needed)
    y_test_pred['C_DA'] = DA_scale * y_test_pred['C_DA']
    y_test_pred['C_5HT'] = SER_scale * y_test_pred['C_5HT']

    # save predictions on a csv, interleaving actual and predicted values per analyte
    with open('Concentration_predictions_PCR.csv', 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter=',')
        writer.writerow(["C_DA (Actual)", "C_DA (Predicted)", "C_5HT (Actual)", "C_5HT (Predicted)"])
        for (da_true, ht_true), (da_pred, ht_pred) in zip(y_test.to_numpy().tolist(),
                                                          y_test_pred.to_numpy().tolist()):
            writer.writerow([da_true, da_pred, ht_true, ht_pred])

    print("Predictions have been saved to a csv file")

# show all of the generated figures (also when no predictions are requested)
plt.show()


