from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.pipeline import Pipeline
from sklearn.cross_decomposition import PLSRegression
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, train_test_split, LeaveOneOut, GroupKFold, GroupShuffleSplit, StratifiedGroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, accuracy_score, ConfusionMatrixDisplay, confusion_matrix, precision_score, recall_score, f1_score
import itertools

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import argparse

import csv

import seaborn as sns 
def _select_rows(data, selector):
    """Return the rows of *data* chosen by a boolean mask or integer positions.

    ``DataFrame[int_array]`` selects *columns*, so DataFrames must go through
    ``.iloc`` to get positional row selection; NumPy arrays index directly.
    """
    if isinstance(data, pd.DataFrame):
        return data.iloc[selector]
    return data[selector]


def _fixed_forest():
    """Build the fixed-hyperparameter forest used by the grouped validations."""
    return RandomForestClassifier(random_state=42,
                                  min_samples_split=10,
                                  min_samples_leaf=10,
                                  max_depth=10,
                                  max_features='sqrt',
                                  n_estimators=100)


def _save_confusion_heatmap(conf_mat, tick_labels, *, fmt, annot_kws,
                            linewidths, label_fontsize, label_pad, title=None):
    """Draw *conf_mat* as a greyscale heatmap, save it as TIFF and PDF, show it."""
    fig, ax = plt.subplots(figsize=(4, 4), constrained_layout=True)

    sns.heatmap(conf_mat, annot=True, fmt=fmt, cmap='Greys',
                annot_kws=annot_kws,
                cbar_kws={"shrink": 0.8},
                linewidths=linewidths, linecolor='gray',
                xticklabels=tick_labels,
                yticklabels=tick_labels,
                ax=ax)

    # Rotate x-axis labels to prevent overlap
    plt.xticks(rotation=45, ha="right", fontsize=8)
    plt.yticks(rotation=0, fontsize=8)

    # Keep the cells square
    ax.set_aspect("equal")

    plt.xlabel("Predicted Label", fontsize=label_fontsize, labelpad=label_pad)
    plt.ylabel("True Label", fontsize=label_fontsize, labelpad=label_pad)
    if title is not None:
        plt.title(title, fontsize=10)

    # Save at 300 dpi as TIFF and PDF, then display
    plt.savefig("confusion_matrix.tiff", dpi=300, bbox_inches='tight', format='tiff')
    plt.savefig("confusion_matrix.pdf", dpi=300, bbox_inches='tight', format='pdf')
    plt.show()


def do_pls_da(file_path, grouping, dim_red, validation_type, grid_type):
    """Reduce a CSV dataset chunk-wise and classify it with a random forest.

    The file is read in chunks of 30000 rows. Rows containing missing values
    are dropped, numeric columns become the features, and the object-dtype
    columns ``grouping`` and ``ID`` supply the class label and the sample id.

    Args:
        file_path: Path of the CSV file to read.
        grouping: Name of the object-dtype column holding the class label.
        dim_red: ``"pls_da"`` or ``"pca"`` to project each (standardised)
            chunk onto 10 components; any other value keeps the raw numeric
            columns.
        validation_type: ``"GroupKFold"`` holds out one ID per class for every
            combination of IDs; ``"GroupShuffleSplit"`` runs five grouped
            80/20 splits; any other value uses a single shuffled 80/20 split.
        grid_type: Only used with the single split: ``"random"`` runs a
            randomised search, ``"complete"`` a grid search, anything else
            (including ``None``) trains one forest with fixed settings and
            writes ``confusion_matrix.csv`` and ``random_forest_metrics.csv``.

    Raises:
        ValueError: If no rows survive the missing-value filter.
        KeyError: If ``grouping`` or ``ID`` is not an object-dtype column.

    Side effects:
        Prints progress and metrics, shows matplotlib figures and, depending
        on the branch, writes ``confusion_matrix.tiff`` / ``.pdf`` / ``.csv``
        and ``random_forest_metrics.csv`` in the working directory.

    NOTE(review): the scaler and the PLS/PCA model are re-fitted on every
    chunk, so files larger than one chunk are projected into different spaces
    per chunk. They are also fitted before any train/test split, so test rows
    influence the projection. Both are kept as-is here.
    """
    chunk_size = 30000
    n_components = 10
    pls_model = PLSRegression(n_components=n_components)
    pca_model = PCA(n_components=n_components)
    reduces = dim_red in ("pls_da", "pca")

    X_transformed = []
    y_all = []
    ids = []

    # Read the data in chunks
    for chunk in pd.read_csv(file_path, chunksize=chunk_size, encoding='unicode_escape'):
        # For the raw dataset, first drop ['line', 'shot', 'Unnamed: 9468'].
        chunk = chunk.dropna()

        cat_cols_values = chunk.select_dtypes(include=['object'])
        print(cat_cols_values)

        # dropna() above already removed every NaN, so no fill is needed.
        X_chunk = chunk.select_dtypes(include=['number'])
        y_chunk = cat_cols_values[grouping]
        id_chunk = cat_cols_values['ID']

        print(f"Processing chunk with shape: {chunk.shape}")
        print(f"Shape of numerical data (X_chunk): {X_chunk.shape}")

        # Skip empty chunks before anything is fitted on them
        if X_chunk.shape[0] == 0:
            print("Empty chunk, skipping...")
            continue

        print(id_chunk)
        print(X_chunk)

        if reduces:
            # Standardise only when a projection will actually consume it
            X_chunk_scaled = StandardScaler().fit_transform(X_chunk)
            print("Scaling successful for current chunk.")

        if dim_red == "pls_da":
            print("Fit and transform the chunk with PLS")
            print(y_chunk)
            # PLS needs a numeric target; the encoding is local to this chunk,
            # which is consistent because the model is refit per chunk.
            y_chunk_encoded = LabelEncoder().fit_transform(y_chunk)
            # fit_transform returns (x_scores, y_scores) when y is given
            X_transformed.append(pls_model.fit_transform(X_chunk_scaled, y_chunk_encoded)[0])
        elif dim_red == "pca":
            print("Fit and transform the chunk with PCA")
            X_transformed.append(pca_model.fit_transform(X_chunk_scaled))
        else:
            print("data is not transformed")
            X_transformed.append(X_chunk)

        y_all.append(y_chunk)
        ids.append(id_chunk)

    if not X_transformed:
        raise ValueError(
            f"No usable rows found in {file_path!r} after dropping rows with missing values."
        )

    # Concatenate the transformed features and labels into single arrays
    if reduces:
        X_transformed = np.vstack(X_transformed)
    else:
        X_transformed = pd.concat(X_transformed)

    y_all = np.concatenate(y_all)
    y_all_le = LabelEncoder().fit_transform(y_all)
    ids = np.concatenate(ids)  # All IDs
    class_labels = np.unique(y_all)

    print(y_all)
    # Check the final shapes of the transformed data
    print("Transformed data shape by", dim_red, X_transformed.shape)
    print("Labels shape:", y_all.shape)
    print("IDs shape:", ids.shape)
    print(np.unique(ids))

    if validation_type == "GroupKFold":
        # NOTE: despite the name this is not sklearn's GroupKFold. Every
        # combination of "one ID per class" is held out in turn, so the number
        # of fits is the product of the per-class ID counts.
        unique_combinations = pd.DataFrame({grouping: y_all, "ID": ids}).drop_duplicates()
        print(unique_combinations)
        y_true = []
        y_pred = []

        # One array of distinct IDs per class
        group_ids = unique_combinations.groupby(grouping)["ID"].unique()

        # Iterate lazily: the full product can be very large
        for test_ids_per_group in itertools.product(*group_ids):
            test_mask = np.zeros(len(ids), dtype=bool)
            for test_id in test_ids_per_group:
                test_mask |= (ids == test_id)
            train_mask = ~test_mask

            X_train = _select_rows(X_transformed, train_mask)
            X_test = _select_rows(X_transformed, test_mask)
            y_train, y_test = y_all[train_mask], y_all[test_mask]
            print(test_ids_per_group)

            rf_model = _fixed_forest()
            rf_model.fit(X_train, y_train)
            y_pred_sample = rf_model.predict(X_test)
            print(accuracy_score(y_test, y_pred_sample))

            # Pool predictions over all held-out combinations
            y_true.extend(y_test)
            y_pred.extend(y_pred_sample)

        # Evaluate the pooled performance
        print("GroupKFold Classification Report:")
        print(classification_report(y_true, y_pred))
        print(f"GroupKFold Accuracy: {accuracy_score(y_true, y_pred)}")

        # Row-normalised confusion matrix in percent
        conf_mat = confusion_matrix(y_true, y_pred, labels=class_labels, normalize='true') * 100

        # Abbreviate labels (first 3 characters)
        short_labels = [label[:3] for label in class_labels]

        _save_confusion_heatmap(conf_mat, short_labels,
                                fmt='.0f',
                                annot_kws={"size": 7, "rotation": 45},
                                linewidths=0.7,
                                label_fontsize=10, label_pad=10,
                                title="Confusion Matrix (Percentages)")

    elif validation_type == "GroupShuffleSplit":
        gss = GroupShuffleSplit(n_splits=5, test_size=0.2, random_state=42)

        for train_index, test_index in gss.split(X_transformed, y_all_le, groups=ids):
            # Positional selection that also works when X is a DataFrame
            X_train = _select_rows(X_transformed, train_index)
            X_test = _select_rows(X_transformed, test_index)
            y_train, y_test = y_all_le[train_index], y_all_le[test_index]

            rf_model = _fixed_forest()
            rf_model.fit(X_train, y_train)
            y_pred_sample = rf_model.predict(X_test)

            # Evaluate the performance of this split
            print(classification_report(y_test, y_pred_sample))
            print(f"Accuracy: {accuracy_score(y_test, y_pred_sample)}")

    else:
        training, testing, training_labels, testing_labels = train_test_split(
            X_transformed, y_all, test_size=0.2, random_state=42, shuffle=True)

        if grid_type == "random":
            random_grid = {
                'n_estimators': [10, 50, 100, 200],       # Number of trees
                'max_features': ['log2', 'sqrt'],         # Features considered per split
                'max_depth': [5, 10, 20, 50, 100],        # Maximum tree depth
                'min_samples_split': [10, 15, 20],        # Minimum samples to split a node
                'min_samples_leaf': [10, 15, 20],         # Minimum samples per leaf
                'bootstrap': [True, False],               # Bootstrap the training rows or not
                'criterion': ['gini', 'entropy', 'log_loss'],
            }
            rf_base = RandomForestClassifier(class_weight="balanced")
            rf_random = RandomizedSearchCV(estimator=rf_base,
                                           param_distributions=random_grid,
                                           n_iter=4, cv=5,
                                           verbose=2,
                                           random_state=42, n_jobs=4)
            rf_random.fit(training, training_labels)

            preds = rf_random.predict(testing)
            print(rf_random.score(training, training_labels))
            print(rf_random.score(testing, testing_labels))
            print(rf_random.best_params_)
            print(classification_report(testing_labels, preds))
            # `labels` must be the distinct classes, not the per-sample vector
            conf_mat = confusion_matrix(testing_labels, preds, labels=class_labels, normalize=None)
            print(conf_mat)

        elif grid_type == "complete":
            param_grid = {
                'n_estimators': np.linspace(180, 240, 20, dtype=int),
                'max_depth': np.linspace(80, 120, 10, dtype=int),
                'min_samples_split': [18, 20, 22],
                'min_samples_leaf': [8, 10, 12],
                'criterion': ['gini', 'entropy', 'log_loss'],
            }

            # Base model; the searched parameters come from param_grid
            rf_grid = RandomForestClassifier(bootstrap=True)
            rf_model = GridSearchCV(estimator=rf_grid, param_grid=param_grid,
                                    cv=5, n_jobs=4, verbose=2)
            rf_model.fit(training, training_labels)

            preds = rf_model.predict(testing)
            print(rf_model.best_params_)
            conf_mat = confusion_matrix(testing_labels, preds, labels=class_labels, normalize=None)
            print(conf_mat)

        else:
            rf_model = RandomForestClassifier(random_state=42,
                                              min_samples_split=20,
                                              min_samples_leaf=10,
                                              max_depth=100,
                                              max_features='sqrt',
                                              n_estimators=100,
                                              criterion='log_loss',
                                              class_weight="balanced")
            rf_model.fit(training, training_labels)
            y_pred_sample = rf_model.predict(testing)
            print(rf_model.score(training, training_labels))
            print(rf_model.score(testing, testing_labels))

            # Row-normalised confusion matrix in percent
            conf_mat = confusion_matrix(testing_labels, y_pred_sample,
                                        labels=rf_model.classes_, normalize='true') * 100

            # Hand-written display names, assumed to follow the sorted class
            # order of the original dataset; they are only usable when the
            # class count matches, otherwise fall back to the real labels.
            labels = ["Brazil", "China_HL", "China_IM", "China_He", "China_Sh", "Germany", "Korea", "Madagascar",
                      "Mozambique", "Namibia", "Norway", "Russia", "Ukraine"]
            if len(labels) != len(rf_model.classes_):
                labels = list(rf_model.classes_)

            print("Confusion Matrix Data:\n", conf_mat)
            pd.DataFrame(conf_mat, index=rf_model.classes_, columns=rf_model.classes_).to_csv(
                "confusion_matrix.csv", index=True)

            _save_confusion_heatmap(conf_mat, labels,
                                    fmt='.1f',
                                    annot_kws={"size": 6.5, "clip_on": False},
                                    linewidths=1.2,
                                    label_fontsize=8, label_pad=5)

            disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=rf_model.classes_)
            disp.plot(xticks_rotation=45)
            plt.show()

            # Metrics Calculation
            metrics = {
                "Accuracy": accuracy_score(testing_labels, y_pred_sample),
                "Precision": precision_score(testing_labels, y_pred_sample, average="macro"),
                "Recall": recall_score(testing_labels, y_pred_sample, average="macro"),
                "F1-Score": f1_score(testing_labels, y_pred_sample, average="macro")
            }

            # Feature importances, most important first
            feature_frame = pd.DataFrame(X_transformed)
            feature_importances_df = pd.DataFrame({
                "Feature": feature_frame.columns,
                "Importance": rf_model.feature_importances_
            }).sort_values(by="Importance", ascending=False)

            # Saving Metrics and Feature Importances to CSV
            output_file = "random_forest_metrics.csv"
            with open(output_file, mode="w", newline="") as file:
                writer = csv.writer(file)
                writer.writerow(["Metric", "Value"])
                writer.writerows(metrics.items())

                writer.writerow([])  # Blank line for separation
                writer.writerow(["Feature", "Importance"])
                # Column-wise iteration keeps integer feature ids as integers
                # (row-wise iterrows() would up-cast them to floats).
                writer.writerows(zip(feature_importances_df["Feature"],
                                     feature_importances_df["Importance"]))

            print(f"Metrics and feature importances have been saved to {output_file}.")
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the analysis once.
    parser = argparse.ArgumentParser(description="Carry out dimension reduction and classification on a LIBS dataset")
    parser.add_argument("file_path", type=str, help="Path to the LIBS datafile.")
    parser.add_argument("grouping", type=str,
                        help="Name of the categorical column to use as the class label.")
    parser.add_argument("dim_red", type=str,
                        help="Dimension reduction technique: 'pls_da' or 'pca'; "
                             "any other value leaves the features untransformed.")
    parser.add_argument("validation_type", type=str,
                        help="Validation scheme: 'GroupKFold' or 'GroupShuffleSplit'; "
                             "any other value uses a single train/test split.")
    parser.add_argument("--grid_type", type=str, default=None,
                        help="Hyperparameter search for the single train/test split: "
                             "'random' or 'complete'; omit to train with fixed settings.")

    args = parser.parse_args()

    do_pls_da(args.file_path, args.grouping, args.dim_red, args.validation_type, args.grid_type)
