import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr
from scipy import stats

from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist

from sklearn.preprocessing import PolynomialFeatures
from sklearn.feature_selection import VarianceThreshold

from numpy import mean, std
import numpy as np
from sklearn.model_selection import LeaveOneOut, KFold
from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import LinearRegression, Ridge, Lasso, LogisticRegression, LogisticRegressionCV, RidgeCV, LassoCV
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, balanced_accuracy_score, confusion_matrix, mean_squared_error, mean_absolute_error

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import r2_score, mean_absolute_error

from sklearn.preprocessing import PolynomialFeatures
from utils_data import *

import statsmodels.api as sm

from collections import defaultdict



def load_and_pivot_data(remove_inconsistencies = False):
    """Load the DMS experimental table and reshape it to one row per catalyst.

    The table returned by ``read_experimental_data_DMS`` is pivoted on
    (Catalyst, ID) x (Solvent, Base) for the Yield/Conversion values and
    their standard deviations, then the three-level column index is
    flattened to ``<value>_<solvent>_<base>`` names.

    Parameters
    ----------
    remove_inconsistencies : bool, default False
        Forwarded unchanged to ``read_experimental_data_DMS``.

    Returns
    -------
    pandas.DataFrame
        The pivoted table with flat string column names.
    """
    raw = read_experimental_data_DMS(remove_inconsistencies = remove_inconsistencies)

    wide = raw.pivot(
        index=['Catalyst', 'ID'],
        columns=['Solvent', 'Base'],
        values=['Yield', 'Conversion', 'Yield_std', 'Conversion_std'],
    ).reset_index()

    join_levels = '{0[0]}_{0[1]}_{0[2]}'.format
    # Joining empty levels leaves a '__' behind (e.g. on the index columns);
    # every '__' occurrence is removed from the flattened name.
    wide.columns = [join_levels(levels).replace('__', '') for levels in wide.columns]

    return wide

def mylogreg(): #c
    """Build the unfitted L1-penalised logistic-regression estimator.

    Returns
    -------
    LogisticRegressionCV
        liblinear solver, balanced class weights, and the regularisation
        strength ``C`` selected by leave-one-out cross-validation over a
        fixed grid (0.01-0.99 in steps of 0.01, 1-4 in steps of 1,
        5-95 in steps of 5).
    """
    c_grid = [
        *np.arange(0.01, 1, 0.01),
        *np.arange(1, 5, 1),
        *np.arange(5, 100, 5),
    ]
    settings = dict(
        solver='liblinear',
        penalty='l1',
        Cs=c_grid,
        max_iter=5000,
        tol=1e-8,
        class_weight='balanced',
        cv=LeaveOneOut(),
        random_state=42,
    )
    return LogisticRegressionCV(**settings)

def find_best_c(data_pivot, featnamecols, ext_cat, scaler = None, showplot = False):
    """Fit the cross-validated logistic regression and report the selected C.

    Parameters
    ----------
    data_pivot : pandas.DataFrame
        Must contain an ``ID`` column, an ``fcluster`` column (class labels)
        and the columns listed in ``featnamecols``.
    featnamecols : list
        Names of the feature columns used as predictors.
    ext_cat : list-like
        ``ID`` values to exclude (external catalysts) before fitting.
    scaler : object, optional
        If truthy, its ``fit_transform`` is applied to the feature columns
        of the retained rows.
    showplot : bool, default False
        Unused; kept so existing callers do not break.

    Returns
    -------
    The ``C_`` attribute of the fitted ``LogisticRegressionCV``.
    """
    # Explicit copy: the filtered frame is written to below when scaling, and
    # writing into a bare slice triggers pandas' SettingWithCopyWarning.
    train = data_pivot[~data_pivot['ID'].isin(ext_cat)].copy()

    if scaler:
        train.loc[:, featnamecols] = scaler.fit_transform(train.loc[:, featnamecols])

    X = train.loc[:, featnamecols].values
    y = train.fcluster.values

    fitted = mylogreg().fit(X, y)
    bestc = fitted.C_
    print(f'BEST C = {bestc}')
    return bestc

def plot_coef(data_pivot, featnamecols, c, ext_cat, scaler = None, showplot = False):
    """Evaluate the logistic-regression classifier and optionally plot its coefficients.

    The estimator from ``mylogreg`` is evaluated by leave-one-out
    cross-validation, refitted on all retained rows, and the balanced
    accuracies plus the LOO confusion matrix are printed.

    Parameters
    ----------
    data_pivot : pandas.DataFrame
        Must contain ``ID``, ``fcluster`` and the ``featnamecols`` columns.
    featnamecols : list
        Names of the feature columns used as predictors.
    c : object
        Unused; kept so existing callers do not break.
    ext_cat : list-like
        ``ID`` values to exclude before fitting.
    scaler : object, optional
        If truthy, its ``fit_transform`` is applied to the feature columns.
    showplot : bool, default False
        When True, draw one bar plot of the non-zero coefficients per row of
        ``coef_``.

    Returns
    -------
    tuple
        ``(fitted_model, balanced_accuracy_on_training_rows)``.
    """
    # Explicit copy: the filtered frame is written to below when scaling.
    train = data_pivot[~data_pivot['ID'].isin(ext_cat)].copy()

    if scaler:
        train.loc[:, featnamecols] = scaler.fit_transform(train.loc[:, featnamecols])

    X = train.loc[:, featnamecols].values
    y = train.fcluster.values

    log_reg = mylogreg()

    # Leave-one-out evaluation: refit on n-1 rows, predict the held-out row.
    y_true, y_pred = [], []
    for train_index, test_index in LeaveOneOut().split(X):
        log_reg.fit(X[train_index], y[train_index])
        y_pred.append(log_reg.predict(X[test_index])[0])
        y_true.append(y[test_index][0])

    # Final fit on every retained row; this is the model that is returned.
    log_reg.fit(X, y)
    y_pred_train = log_reg.predict(X)

    balanced_acc_loocv = balanced_accuracy_score(y_true, y_pred)
    balanced_acc_train = balanced_accuracy_score(y, y_pred_train)
    print(f"Balanced Accuracy (LOO CV): {balanced_acc_loocv:.4f}")
    print(f"Balanced Accuracy (train): {balanced_acc_train:.4f}")
    cm = confusion_matrix(y_true, y_pred)
    print(f"CM (LOO CV):")
    print(cm)

    if showplot:
        for class_idx, coef in enumerate(log_reg.coef_):
            order = np.argsort(coef)
            # Keep only features the L1 penalty did not shrink to zero.
            kept = [(featnamecols[k], coef[k]) for k in order if abs(coef[k]) > 0]
            labels = [name for name, _ in kept]
            heights = [value for _, value in kept]

            plt.figure(figsize=(20, 3))
            plt.bar(range(len(heights)), heights)
            plt.xticks(range(len(heights)), labels, rotation=90)
            plt.title(f"Logistic Regression Coefficients (Class {class_idx})")
            plt.xlabel('Features (sorted)')
            plt.ylabel('Coefficient Value')
            plt.show()
    return log_reg, balanced_acc_train


# Helper function to scale data
def scale_data(df, scaler, ext_cat):
    """Scale a frame, fitting the scaler only on rows outside ``ext_cat``.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to scale; it is not modified.
    scaler : object
        Must provide ``fit_transform`` and ``transform``.
    ext_cat : list-like
        Index labels of the external rows. They are transformed with the
        scaler fitted on the remaining rows. May be empty.

    Returns
    -------
    pandas.DataFrame
        A scaled copy of ``df``.
    """
    df_scaled = df.copy()
    train_rows = [label for label in df_scaled.index if label not in ext_cat]
    df_scaled.loc[train_rows, :] = scaler.fit_transform(df_scaled.loc[train_rows, :])
    # With no external rows there is nothing to transform; calling transform
    # on a zero-row frame makes scikit-learn scalers raise.
    if len(ext_cat) > 0:
        df_scaled.loc[ext_cat, :] = scaler.transform(df_scaled.loc[ext_cat, :])
    return df_scaled

# Helper function to calculate LOOCV
def perform_loocv(X, y, cat, model, all_coef, comb):
    """Run leave-one-out CV with ``model`` and record each fold's parameters.

    Parameters
    ----------
    X, y, cat : array-like
        Features, targets and per-row labels; all are indexed with the
        integer index arrays produced by ``LeaveOneOut().split``.
    model : estimator
        Refitted in every fold; must expose ``fit``, ``predict``,
        ``coef_`` and ``intercept_``.
    all_coef : dict
        ``all_coef[comb]`` maps a key to a list; one value is appended to
        every list per fold (mutated in place).
    comb : hashable
        Key selecting the entry of ``all_coef`` to update.

    Returns
    -------
    tuple
        ``(actual, predicted, held_out_labels, all_coef)`` where the first
        three are NumPy arrays with one entry per fold.
    """
    observed, predicted, held_out = [], [], []
    tracked = all_coef[comb]

    for train_index, test_index in LeaveOneOut().split(X):
        held_out.append(cat[test_index])
        model.fit(X[train_index], y[train_index])
        prediction = model.predict(X[test_index])
        observed.append(y[test_index][0])
        predicted.append(prediction[0])

        # NOTE(review): the coefficient is picked by the key's position, which
        # also counts the 'intercept' key - correct only if 'intercept' comes
        # after every feature key. Confirm against the caller.
        for position, key in enumerate(list(tracked.keys())):
            if key == 'intercept':
                tracked['intercept'].append(model.intercept_)
            else:
                tracked[key].append(model.coef_[position])

    return np.array(observed), np.array(predicted), np.array(held_out).flatten(), all_coef

# Helper function to calculate K-fold CV (disabled; kept for reference)
#def perform_cv(X, y, cat, model, all_coef, comb):
#    cv = KFold(n_splits=6)
#    
#    actual_values, predicted_values, cats = [], [], []
#    for train_index, test_index in cv.split(X):
#        X_train, X_test = X[train_index], X[test_index]
#        y_train, y_test = y[train_index], y[test_index]
#        cat_train, cat_test = cat[train_index], cat[test_index]
#        cats.append(cat[test_index])
#        model.fit(X_train, y_train)
#        y_pred = model.predict(X_test)
#        actual_values.append(y_test)
#        predicted_values.append(y_pred)
#        
#        for n,c in enumerate(list(all_coef[comb].keys())):
#            if c == 'intercept':
#                all_coef[comb]['intercept'].append(model.intercept_)
#            else:
#                all_coef[comb][c].append(model.coef_[n])
#    model.fit(X, y)    
#    
#    for n,c in enumerate(list(all_coef[comb].keys())):
#        if c == 'intercept':
#            all_coef[comb]['intercept'].append(model.intercept_)
#        else:
#            all_coef[comb][c].append(model.coef_[n])
#    #todo analyze variance for each coefficient and plot catalyst vs coeff to see variation (outliers?)
#    
#    return np.hstack(np.array(actual_values)).flatten(), np.hstack(np.array(predicted_values)).flatten(), #np.hstack(np.array(cats).flatten()), all_coef#
#
#




########################
def cluster_data(D, scal, ext_cat, linkmethod, nclu, yeo =False, dG=False):
    """Assign a cluster label to every row of ``D``.

    Parameters
    ----------
    D : pandas.DataFrame
        One row per sample; it is copied, not modified.
    scal, ext_cat : object
        Unused; kept so existing callers do not break.
    linkmethod : str
        Linkage method passed to ``scipy.cluster.hierarchy.linkage``.
    nclu : int
        Maximum number of flat clusters (``criterion='maxclust'``).
    yeo : bool, default False
        Apply a Yeo-Johnson transform to every column first.
    dG : bool, default False
        If ``yeo`` is False, apply ``normalize_yield`` element-wise instead.

    Returns
    -------
    Cluster labels: a ``pandas.Categorical`` of tercile labels 0/1/2 when
    ``D`` has a single column, otherwise the array from ``fcluster``.
    """
    work = D.copy()
    if yeo:
        for col in work.columns:
            transformed, _fitted_lambda = stats.yeojohnson(work[col].values)
            work.loc[:, col] = transformed
    elif dG:
        for col in work.columns:
            work.loc[:, col] = np.array([normalize_yield(v) for v in work[col].values])

    # A single column cannot be clustered meaningfully by distance here, so
    # it is split into three equal-frequency bins instead.
    if work.values.shape[1] == 1:
        return pd.qcut(work.values.flatten(), 3, labels=[0, 1, 2])

    distances = pdist(work.values, metric = 'euclidean')
    tree = linkage(distances, linkmethod)
    return fcluster(tree, nclu, criterion='maxclust')

def calc_polyfeat(dp, scal, ext_cat, featcols, varthr, ecol, ecol_std):
    """Scale the descriptors and expand them to degree-2 polynomial features.

    The scaler is fitted on rows whose ``ID`` is not in ``ext_cat`` and then
    applied to the external rows. Note that the feature columns of the
    caller's ``dp`` are overwritten in place with the scaled values.

    Parameters
    ----------
    dp : pandas.DataFrame
        Must contain ``Catalyst``, ``ID``, ``featcols``, ``ecol`` and
        ``ecol_std`` columns.
    scal : object
        Scaler providing ``fit_transform`` and ``transform``.
    ext_cat : list-like
        ``ID`` values of the external rows.
    featcols : list
        Descriptor columns to scale and expand.
    varthr : float
        Threshold passed to ``VarianceThreshold``.
    ecol, ecol_std : list
        Experimental columns carried over unchanged.

    Returns
    -------
    tuple
        ``(frame, feature_names)``: the metadata/experimental columns joined
        with the retained polynomial features, and the retained names.
    """
    expander = PolynomialFeatures(degree=2, include_bias=False)

    in_train = ~dp.ID.isin(ext_cat)
    dp.loc[in_train, featcols] = scal.fit_transform(dp[in_train][featcols])
    dp.loc[~in_train, featcols] = scal.transform(dp[~in_train][featcols])

    # The expansion and the variance filter are fitted on all rows.
    expanded = expander.fit_transform(dp[featcols])
    names = expander.get_feature_names_out(input_features=featcols)

    variance_filter = VarianceThreshold(threshold=varthr)
    expanded = variance_filter.fit_transform(expanded)
    print(expanded.shape)
    names = variance_filter.get_feature_names_out(names)

    carried = dp.loc[:, ['Catalyst','ID'] + ecol+ecol_std]
    poly_frame = pd.DataFrame(expanded, index=dp.index, columns =names)
    return pd.concat([carried, poly_frame], axis = 1), names

def plotting_annot(axname, X, y, y_predstr, ytrues, ypreds, ytst, y_predsts, y_std,yloo_std,ytst_std, lab = None):  #axes.flatten()[cnt]
    """Draw a predicted-vs-actual plot on ``axname`` and return the metrics.

    Train, leave-one-out and test points are drawn as three scatter series
    whose marker size grows with the supplied standard deviations. A dotted
    identity line with a +/-20 band and a text box with the three MAE values
    are added.

    Parameters
    ----------
    axname : matplotlib.axes.Axes
        Axes to draw on.
    X : object
        Unused; kept so existing callers do not break.
    y, y_predstr : array-like
        Actual and predicted values on the training rows.
    ytrues, ypreds : array-like
        Actual and predicted values from leave-one-out CV.
    ytst, y_predsts : array-like
        Actual and predicted values on the test rows (may be empty).
    y_std, yloo_std, ytst_std : array-like
        Standard deviations used to size the markers of each series.
    lab : str, optional
        Axes title.

    Returns
    -------
    tuple
        ``(r2_train, r2_loo, r2_test, mae_train, mae_loo, mae_test)``.
        ``r2_test`` is NaN with two or fewer test points; ``mae_test`` is
        NaN when there are no test points.
    """
    # The unused adjusted-R2 and RMSE values were dropped: the former divided
    # by (n - p - 1) and raised ZeroDivisionError whenever n == p + 1.
    r2 = r2_score(ytrues, ypreds)
    mae = mean_absolute_error(ytrues, ypreds)

    r2_tr = r2_score(y, y_predstr)
    mae_tr = mean_absolute_error(y, y_predstr)

    r2_test = r2_score(ytst, y_predsts) if len(ytst) > 2 else np.nan
    # An empty test set would make mean_absolute_error raise.
    mae_test = mean_absolute_error(ytst, y_predsts) if len(ytst) > 0 else np.nan

    base_size = plt.rcParams['lines.markersize']

    def _marker_sizes(stds):
        # Larger experimental spread -> larger marker; +5 keeps zero-spread points visible.
        return np.array(list(base_size * np.array(stds) + 5))

    axname.scatter(y, y_predstr, label='Train', c = 'k', alpha = 0.5, s=_marker_sizes(y_std))
    axname.scatter(ytrues, ypreds, label='LOO', c = 'b', s=_marker_sizes(yloo_std))
    axname.scatter(ytst, y_predsts, label='Test', c = 'r', s=_marker_sizes(ytst_std))

    # Identity line with a +/-20 tolerance band over the 0-100 range.
    x = np.linspace(0, 100, 400)
    axname.axline((0, 0), slope=1, color='red', linestyle=':')
    axname.fill_between(x, x - 20, x + 20, color='red', alpha=0.1)

    box_color = 'gold'

    metrics_text = f"TrainMAE: {mae_tr:.2f}\n" \
                   f"LOOMAE: {mae:.2f}\n" \
                   f"TestMAE: {mae_test:.2f}"

    axname.text(0.05, 0.95, metrics_text, transform=axname.transAxes,
                fontsize=9, verticalalignment='top',
                bbox=dict(facecolor=box_color, alpha=0.1))

    axname.set_title(lab)
    axname.legend(['Train','LOOCV','Test'], loc='lower right')

    return r2_tr, r2, r2_test, mae_tr, mae, mae_test
    
    
    
def PLSRegressionCV(X, y):
    """Pick the number of PLS components by leave-one-out mean squared error.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature table (``X.values`` is used for fitting).
    y : numpy.ndarray
        Targets, indexed with integer index arrays.

    Returns
    -------
    int
        The candidate component count with the lowest mean LOO MSE.
    """
    loocv = LeaveOneOut()

    # Candidates are 1 .. min(n_features, 5) - 1 (exclusive upper bound, as
    # before). With a single feature that range is empty and np.argmin would
    # raise, so fall back to the only valid choice of one component.
    # NOTE(review): the exclusive bound means min(n_features, 5) itself is
    # never tried - confirm this is intended.
    upper = min([X.shape[1], 5])
    n_components_range = range(1, upper) if upper > 1 else range(1, 2)

    features = X.values
    mse_scores = []

    for n_components in n_components_range:
        pls = PLSRegression(n_components=n_components, scale=False)
        fold_errors = []
        for train_idx, test_idx in loocv.split(features):
            pls.fit(features[train_idx], y[train_idx])
            y_pred = pls.predict(features[test_idx])
            fold_errors.append(mean_squared_error(y[test_idx], y_pred))
        mse_scores.append(np.mean(fold_errors))

    optimal_n_components = n_components_range[int(np.argmin(mse_scores))]
    print(f'Optimal number of components: {optimal_n_components}')

    return optimal_n_components

def feat_sele_corr(dp, feature_names, ecol, ext_cat):
    """Select the single feature most correlated (in absolute value) with the target.

    Parameters
    ----------
    dp : pandas.DataFrame
        Must contain ``ID``, the ``ecol`` columns and the feature columns.
    feature_names : sequence of str
        Candidate feature columns (list or array).
    ecol : list of str
        Target column(s); features are ranked by their absolute Pearson
        correlation with these columns, in the given order.
    ext_cat : list-like
        ``ID`` values excluded from the correlation.

    Returns
    -------
    list
        One-element list holding the selected feature name.
    """
    ecol = list(ecol)
    # Accept arrays as well as lists (list + ndarray does not concatenate).
    feature_names = list(feature_names)

    train = dp[~dp.ID.isin(ext_cat)]
    D = train.loc[:, ['ID'] + ecol + feature_names].set_index('ID')

    # Drop every target row, not just the first one, so that a second target
    # column can never be returned as the "best feature".
    corr_with_target = D.corr().abs().drop(index=ecol, errors='ignore').loc[:, ecol]
    ranked = corr_with_target.sort_values(by=ecol, ascending=False)
    return [ranked.index[0]]

from itertools import combinations
from pathlib import Path
from time import time
from multiprocessing import Pool, cpu_count

def process_combination(combination, dp_corr, ecol, dG):
    """Score one pair of features by the leave-one-out R2 of a ridge model.

    Parameters
    ----------
    combination : tuple
        ``(d1, d2)`` feature column names.
    dp_corr : pandas.DataFrame
        Training rows containing both features and the ``ecol`` column(s).
    ecol : list of str
        Target column selector.
    dG : bool
        When True the target is passed through ``normalize_yield`` first.

    Returns
    -------
    tuple
        ``(d1, d2, r2)``.
    """
    d1, d2 = combination
    X = dp_corr.loc[:, [d1, d2]]
    y = dp_corr[ecol].values
    if dG:
        y = np.array([normalize_yield(i) for i in y])

    # ridge_cv takes (X, y, y_std, y_index) and returns six values; the
    # previous three-argument call / five-value unpack raised TypeError.
    _model, ypreds, ytrues, _yloo_std, _y_index_loo, _best_alpha = ridge_cv(X, y, None, None)

    r2 = r2_score(ytrues, ypreds)
    return (d1, d2, r2)


def feat_sele_set2(dp, feature_names, ecol, ext_cat, preferred_set, dG=False, cache_dir='/mnt/3_ML/best2'):
    """Select the feature pair with the highest leave-one-out R2.

    Scores for every pair are cached in
    ``<cache_dir>/<preferred_set[0]>_<preferred_set[1]>_<ecol[0]>.csv``
    (columns ``Pair`` and ``Value``). When that file exists the best pair is
    read from it; otherwise all pairs are scored in a process pool via
    ``process_combination`` and the file is written.

    Parameters
    ----------
    dp : pandas.DataFrame
        Must contain ``ID``, the features and the ``ecol`` column(s).
    feature_names : iterable of str
        Candidate features; all 2-combinations are evaluated.
    ecol : list of str
        Target column selector; ``ecol[0]`` is part of the cache file name.
    ext_cat : list-like
        ``ID`` values excluded from the evaluation.
    preferred_set : sequence
        Its first two items are part of the cache file name.
    dG : bool, default False
        Forwarded to ``process_combination``.
    cache_dir : str or Path, default '/mnt/3_ML/best2'
        Directory holding the cache file.

    Returns
    -------
    list
        ``[d1, d2]``; empty if nothing was selected.
    """
    import ast  # local import: only needed to parse cached pair labels

    cache_file = Path(cache_dir) / f'{preferred_set[0]}_{preferred_set[1]}_{ecol[0]}.csv'
    r2max = 0
    featcolsselected = []

    if cache_file.is_file():
        cached = pd.read_csv(cache_file).sort_values(by='Value', ascending=False)
        if not cached.empty:
            # Positional access: after sorting, label 0 is no longer the top row.
            best = cached.iloc[0]
            r2max = best['Value']
            # Pairs are stored as the text of a tuple, e.g. "('a', 'b')";
            # list() on that string would split it into characters.
            featcolsselected = list(ast.literal_eval(best['Pair']))
    else:
        train = dp[~dp.ID.isin(ext_cat)]
        pairs = list(combinations(feature_names, 2))
        dict_r2 = {}

        start = time()

        with Pool(processes=cpu_count()) as pool:
            results = pool.starmap(process_combination,
                                   [(pair, train, ecol, dG) for pair in pairs])

        for d1, d2, r2 in results:
            if r2 > r2max:
                featcolsselected = [d1, d2]
                r2max = r2
            dict_r2[(d1, d2)] = r2
            print(f'Processed combination ({d1}, {d2}) with R² = {r2}')

        print(f'{len(results)} combinations processed in {time()-start} seconds')

        dict_r2_df = pd.DataFrame(list(dict_r2.items()), columns=['Pair', 'Value'])
        # Create the directory first so the results of the search are not
        # lost to a missing-folder error.
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        dict_r2_df.to_csv(cache_file, index=False)

    print(f'R2MAX = {r2max}')
    return featcolsselected

def feat_sele_logi_regr(dp, feature_names, ecol, scal, ext_cat, linkmethod='average', nclu=3, yeo=False, dG=False, plotfig = False):
    """Select features via L1 logistic regression on clustered targets.

    The ``ecol`` columns are coerced to numeric, missing values are filled
    with each row's median across those columns, the rows are clustered with
    ``cluster_data`` and the labels are stored in ``dp['fcluster']`` (the
    caller's frame is modified). The features with a non-zero coefficient in
    the classifier returned by ``plot_coef`` are selected.

    Parameters
    ----------
    dp : pandas.DataFrame
        Must contain ``ID``, ``ecol`` and ``feature_names`` columns.
    feature_names : sequence of str
        Candidate feature columns.
    ecol : list of str
        Target columns used for clustering.
    scal, ext_cat, linkmethod, nclu, dG
        Forwarded to ``cluster_data`` (``ext_cat`` also to the classifier helpers).
    yeo : bool, default False
        Currently has no effect (see the note in the body).
    plotfig : bool, default False
        Show a seaborn clustermap of the targets.

    Returns
    -------
    list
        Selected feature names (unordered); empty when the smallest cluster
        has fewer than three members.
    """
    targets = dp.loc[:, ['ID'] + ecol].set_index('ID')
    targets[targets.columns] = targets[targets.columns].apply(pd.to_numeric, errors='coerce')
    print(f'{targets.isna().sum().sum()} missing values imputed')

    # Transpose so that fillna(median) works per row of the original table.
    by_condition = targets.T
    targets = by_condition.fillna(by_condition.median()).T

    # NOTE(review): yeo is hard-coded to False here, so the ``yeo`` argument
    # of this function is ignored - confirm whether that is intentional.
    labels = cluster_data(targets, scal, ext_cat, linkmethod, nclu, yeo=False, dG=dG)

    dp['fcluster'] = labels
    if min(dp['fcluster'].value_counts().values) < 3:
        return []

    if plotfig:
        palette = dict(zip(dp['fcluster'].unique(), "rbgkm"))
        row_colors = dp['fcluster'].map(palette)
        plt.figure()
        sns.clustermap(targets, method=linkmethod, row_colors=row_colors.values)
        plt.show()

    bestc = find_best_c(dp, feature_names, ext_cat, scaler=None)
    mylr, _balanced_acc = plot_coef(dp, feature_names, bestc, ext_cat, scaler=None)

    # Union, over the rows of coef_, of the features with a non-zero weight.
    names = np.array(feature_names)
    per_class = [list(names[np.abs(mylr.coef_[k, :]) > 0]) for k in range(mylr.coef_.shape[0])]
    flat = [name for group in per_class for name in group]
    return list(set(flat))
    
def normalize_yield(y):
    """Convert a percentage yield into a logarithmic, energy-like quantity.

    Computes ``R * T * ln(k * h / (Kb * T))`` with
    ``k = ln(1 / (1 - y/100)) / (nH * 3600)``, ``T = 333.15`` and ``nH = 2``.
    # NOTE(review): this looks like an Eyring-type activation free energy for
    # a first-order rate over 2 h at 60 degC, in kJ/mol - confirm.

    Parameters
    ----------
    y : float or numpy.ndarray
        Yield in percent.

    Returns
    -------
    float or numpy.ndarray
        The transformed value (not rescaled).
    """
    planck = 6.62607004E-34
    ref_conc = 1
    order_exponent = 0
    boltzmann = 1.380649E-23
    temperature = 333.15
    gas_const = 0.00831446261815324
    hours = 2

    rate = np.log(1 / (1 - (y / 100))) / (hours * 60 * 60)
    thermal = (1 / (ref_conc ** order_exponent)) * (boltzmann * temperature)
    return np.log(rate * planck / thermal) * gas_const * temperature

def denormalize_yield(normalized_result, maxdg=-100, mindg=-120):
    """Map a 0-100 scaled value back to a percentage yield.

    The input is first rescaled linearly from [0, 100] to
    [``mindg``, ``maxdg``], then the logarithmic transform applied by
    ``normalize_yield`` is inverted using the same constants.

    Parameters
    ----------
    normalized_result : float or numpy.ndarray
        Value on the 0-100 scale.
    maxdg, mindg : float
        Upper and lower ends of the unscaled range.

    Returns
    -------
    float or numpy.ndarray
        Yield in percent.
    """
    planck = 6.62607004E-34
    ref_conc = 1
    order_exponent = 0
    boltzmann = 1.380649E-23
    temperature = 333.15
    gas_const = 0.00831446261815324
    hours = 2

    # Undo the min-max scaling.
    unscaled = (normalized_result/100) * (maxdg - mindg) + mindg

    # Undo the logarithm, then solve the first-order expression for the yield.
    exp_term = np.exp(unscaled / (gas_const * temperature))
    exponent_arg = -hours * 60 * 60 * exp_term * (boltzmann * temperature) / planck / (1 / (ref_conc ** order_exponent))
    return 100 * (1 - np.exp(exponent_arg))

def feat_sele_lasso(dp, feature_names, ecol, scal, ext_cat, remove_incons, plotcoef = False, yeo=False, dG=False):
    """Select the features that receive a non-zero LassoCV coefficient.

    The target is the row-wise median of the ``ecol`` columns, optionally
    Yeo-Johnson transformed (``yeo``) or passed through ``normalize_yield``
    (``dG``). The regularisation strength is chosen by leave-one-out CV.

    Parameters
    ----------
    dp : pandas.DataFrame
        Must contain ``ID``, ``feature_names`` and ``ecol`` columns; it is
        not modified.
    feature_names : sequence of str
        Candidate feature columns.
    ecol : list of str
        Target columns.
    scal : object
        Unused; kept so existing callers do not break.
    ext_cat : list-like
        ``ID`` values excluded from the fit.
    remove_incons : str or falsy
        If truthy, ``/mnt/data/<remove_incons>.csv`` is read and the listed
        (catalyst, solvent, base) measurements are set to NaN before the
        median is taken.
    plotcoef : bool, default False
        Show a bar plot of the non-zero coefficients.
    yeo, dG : bool
        Target transforms; ``yeo`` takes precedence.

    Returns
    -------
    list
        Names of the features with a non-zero coefficient.
    """
    # Explicit copy: NaNs may be written below, and writing into a bare
    # filtered slice triggers pandas' SettingWithCopyWarning.
    dp_lasso = dp[~dp.ID.isin(ext_cat)].copy()
    if remove_incons:
        incons_df = pd.read_csv(f'/mnt/data/{remove_incons}.csv')
        incons_df['ID'] = [id_to_id[i] for i in incons_df['Catalyst']]
        tgt = ecol[0].split('_')[0]
        for _, row in incons_df.iterrows():
            dp_lasso.loc[dp_lasso.ID == row.ID, f'{tgt}_{row.Solvent}_{row.Base}'] = np.nan

    X = dp_lasso.loc[:, feature_names].values
    y = dp_lasso.loc[:, ecol].median(axis=1).values

    if yeo:
        y, _best_lambda = stats.yeojohnson(y)
    elif dG:
        y = np.array([normalize_yield(i) for i in y])

    mylr = LassoCV(cv=LeaveOneOut(),
                   alphas=np.arange(0.1,10,0.1),
                   max_iter=5000, random_state=42, tol=1e-5).fit(X, y)

    # Leave-one-out estimate of R2 at the selected alpha.
    y_true, y_pred = [], []
    for train_index, test_index in LeaveOneOut().split(X):
        fold_model = Lasso(alpha = mylr.alpha_, max_iter=5000, random_state=42, tol=1e-5)
        fold_model.fit(X[train_index], y[train_index])
        y_pred.append(fold_model.predict(X[test_index])[0])
        y_true.append(y[test_index][0])

    r2loo = r2_score(y_true, y_pred)
    print(f"R2 (LOO CV): {r2loo:.4f}")

    coef = mylr.coef_
    if plotcoef:
        order = np.argsort(coef)
        kept = [(feature_names[k], coef[k]) for k in order if abs(coef[k]) > 0]
        labels = [name for name, _ in kept]
        heights = [value for _, value in kept]

        plt.figure(figsize=(20, 3))
        plt.bar(range(len(heights)), heights)
        plt.xticks(range(len(heights)), labels, rotation=90)
        plt.title(f"Coefficients")
        plt.xlabel('Features (sorted)')
        plt.ylabel('Coefficient Value')
        plt.show()

    print(mylr.alpha_)

    featcolsselected = list(np.array(feature_names)[np.abs(coef)>0])
    return featcolsselected

def logi_cv(X,y, y_index):
    """Fit an L2 logistic regression with CV-tuned C and collect LOO predictions.

    Parameters
    ----------
    X : pandas.DataFrame
        Features.
    y : numpy.ndarray
        Class labels.
    y_index : numpy.ndarray or None
        Optional per-row identifiers, collected per held-out fold.

    Returns
    -------
    tuple
        ``(final_model, loo_predictions, loo_truths, held_out_ids, best_C)``;
        ``final_model`` is the ``LogisticRegressionCV`` fitted on all rows and
        ``best_C`` is the first entry of its ``C_``.
    """
    loo = LeaveOneOut()
    n_candidate_cs = 10
    final_model = LogisticRegressionCV(Cs=n_candidate_cs, cv=loo, penalty='l2',class_weight='balanced').fit(X, y)
    best_alpha = final_model.C_[0]

    features = X.values
    ypreds, ytrues, y_index_loo = [], [], []

    for train_index, test_index in loo.split(X):
        # Refit a plain model at the selected C on the n-1 training rows.
        fold_model = LogisticRegression(C=best_alpha, random_state=42,class_weight='balanced')
        fold_model.fit(features[train_index], y[train_index])

        ypreds.append(fold_model.predict(features[test_index, :]))
        ytrues.append(y[test_index])
        if y_index is not None:
            y_index_loo.append(y_index[test_index])

    return final_model, np.array(ypreds).flatten(), np.array(ytrues).flatten(), y_index_loo, best_alpha

def ridge_cv(X,y, y_std, y_index):
    """Fit a ridge regression with CV-tuned alpha and collect LOO predictions.

    Parameters
    ----------
    X : pandas.DataFrame
        Features.
    y : numpy.ndarray
        Targets.
    y_std : pandas object or None
        Optional per-row standard deviations (``.values`` is used).
    y_index : numpy.ndarray or None
        Optional per-row identifiers.

    Returns
    -------
    tuple
        ``(final_model, loo_predictions, loo_truths, loo_stds, held_out_ids,
        best_alpha)``; ``final_model`` is the ``RidgeCV`` fitted on all rows.
    """
    loo = LeaveOneOut()
    alpha_grid = [*np.logspace(-4, 1, 6), *np.arange(1,5,1)]
    final_model = RidgeCV(alphas=alpha_grid, cv=loo).fit(X, y)
    best_alpha = final_model.alpha_

    features = X.values
    ypreds, ytrues, yloo_std, y_index_loo = [], [], [], []

    for train_index, test_index in loo.split(X):
        # Refit a plain Ridge at the selected alpha on the n-1 training rows.
        fold_model = Ridge(alpha=best_alpha, random_state=42)
        fold_model.fit(features[train_index], y[train_index])

        ypreds.append(fold_model.predict(features[test_index, :]))
        ytrues.append(y[test_index])
        if y_std is not None:
            yloo_std.append(y_std.values[test_index])
        if y_index is not None:
            y_index_loo.append(y_index[test_index])

    return final_model, np.array(ypreds).flatten(), np.array(ytrues).flatten(), yloo_std, y_index_loo, best_alpha

def lasso_cv(X,y, y_std):
    """Fit a lasso regression with CV-tuned alpha and collect LOO predictions.

    Parameters
    ----------
    X : pandas.DataFrame
        Features.
    y : numpy.ndarray
        Targets.
    y_std : pandas object or None
        Optional per-row standard deviations (``.values`` is used).

    Returns
    -------
    tuple
        ``(final_model, loo_predictions, loo_truths, loo_stds, best_alpha)``;
        ``final_model`` is the ``LassoCV`` fitted on all rows.
    """
    loo = LeaveOneOut()
    alpha_grid = [*np.logspace(-4, 1, 6), *np.arange(1,5,1)]
    final_model = LassoCV(alphas=alpha_grid, cv=loo, random_state=42).fit(X, y)
    best_alpha = final_model.alpha_

    features = X.values
    ypreds, ytrues, yloo_std = [], [], []

    for train_index, test_index in loo.split(X):
        # Refit a plain Lasso at the selected alpha on the n-1 training rows.
        fold_model = Lasso(alpha=best_alpha, random_state=42)
        fold_model.fit(features[train_index], y[train_index])

        ypreds.append(fold_model.predict(features[test_index, :]))
        ytrues.append(y[test_index])
        if y_std is not None:
            yloo_std.append(y_std.values[test_index])

    return final_model, np.array(ypreds).flatten(), np.array(ytrues).flatten(), yloo_std, best_alpha

def pls_cv(X,y, y_std, y_index):
    """Fit a PLS regression with a CV-tuned component count and collect LOO predictions.

    Parameters
    ----------
    X : pandas.DataFrame
        Features.
    y : numpy.ndarray
        Targets.
    y_std : pandas object or None
        Optional per-row standard deviations (``.values`` is used).
    y_index : numpy.ndarray or None
        Optional per-row identifiers.

    Returns
    -------
    tuple
        ``(final_model, loo_predictions, loo_truths, loo_stds, held_out_ids,
        best_n_component)``; ``final_model`` is fitted on all rows.
    """
    loo = LeaveOneOut()
    best_n_component = PLSRegressionCV(X, y)

    features = X.values
    # Accept y_std=None, as ridge_cv and lasso_cv do, instead of failing on
    # ``None.values``.
    std_values = y_std.values if y_std is not None else None

    ypreds, ytrues, yloo_std, y_index_loo = [], [], [], []

    for train_index, test_index in loo.split(X):
        mod = PLSRegression(n_components=best_n_component, scale = False).fit(features[train_index], y[train_index])
        y_pred = mod.predict(features[test_index, :])[0]

        ypreds.append(y_pred)
        ytrues.append(y[test_index])
        if std_values is not None:
            yloo_std.append(std_values[test_index])
        if y_index is not None:
            y_index_loo.append(y_index[test_index])

    final_model = PLSRegression(n_components=best_n_component, scale = False).fit(X, y)
    return final_model, np.array(ypreds).flatten(), np.array(ytrues).flatten(), yloo_std, y_index_loo, best_n_component

# Scrambled target Ridge regression
#y_scrambled = np.random.permutation(y_train)  # Shuffle the target values



def initialize_dict(scaling = True):
    """Create the empty results accumulator: one independent list per column.

    Parameters
    ----------
    scaling : bool, default True
        Include the ``'scaling'`` column (placed right after ``'target'``).

    Returns
    -------
    dict
        Maps every result column name to a new empty list.
    """
    columns = ['approach', 'target']
    if scaling:
        columns.append('scaling')
    columns += [
        'algo', 'dG', 'condition',
        'R2train', 'R2LOO', 'R2test',
        'MAEtrain', 'MAELOO', 'MAEtest',
        'feats', 'num_feats',
    ]
    return {column: [] for column in columns}

def read_exp_data(approach):
    """Read the GC / Method2 experimental data via ``read_experimental_data_DMS_new``.

    Parameters
    ----------
    approach : str
        ``'restrictive'`` additionally passes ``std_thr=9`` to the reader;
        any other value uses the reader's default for that argument.

    Returns
    -------
    The first element of the reader's return value; the number of missing
    values in it is printed.
    """
    reader_kwargs = {'analysis': 'GC', 'plating': 'Method2'}
    if approach == 'restrictive':
        reader_kwargs['std_thr'] = 9

    data, _ = read_experimental_data_DMS_new(**reader_kwargs)
    print(f'{data.isna().sum().sum()} missing values')
    return data

def read_descr(path='/mnt/2_descriptors/all_df_no_cis_pdcl2.csv'):
    """Read the descriptor table and append the square of every descriptor.

    Columns containing any missing value are dropped. The first CSV column
    is used as the index and is returned as the ``ID`` column.

    Parameters
    ----------
    path : str
        CSV file to read.

    Returns
    -------
    tuple
        ``(frame, featcols)``: the frame has columns ``['ID'] + featcols``,
        where ``featcols`` lists the original descriptors followed by their
        ``'<name>**2'`` counterparts. ``len(featcols)`` is printed.
    """
    descriptors = pd.read_csv(path, index_col = 0).dropna(axis = 1)
    squared = pd.DataFrame(
        descriptors.values**2,
        index = descriptors.index,
        columns = [f'{name}**2' for name in descriptors.columns],
    )
    combined = pd.concat([descriptors, squared], axis = 1)

    featcols = combined.columns.to_list()
    print(len(featcols))

    combined = combined.reset_index()
    combined.columns = ['ID'] + featcols
    return combined, featcols

def filter_target(data, expcols, featcols):
    """Keep the yield targets (minus an exclusion list), their std columns and the features.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``ID`` plus every selected column.
    expcols : iterable of str
        Experimental column names to pick the yield columns from.
    featcols : list of str
        Feature columns to carry over.

    Returns
    -------
    tuple
        ``(data_yield, expcols_yield, expcols_yield_std)`` where the std
        names are derived by replacing ``'Yield_'`` with ``'Yield_std_'``.
    """
    excluded = (
        'Yield_Toluene_K2CO3',
        'Yield_Toluene_K3PO4',
        'Yield_Me-THF_K2CO3',
        'Yield_iPrOAc_K2CO3',
        'Yield_EtOH_Cs2CO3',
    )
    expcols_yield = [
        col for col in expcols
        if 'Yield' in col and 'std' not in col and col not in excluded
    ]
    expcols_yield_std = [col.replace('Yield_','Yield_std_') for col in expcols_yield]

    selected = ['ID', *expcols_yield, *expcols_yield_std, *featcols]
    return data.loc[:, selected], expcols_yield, expcols_yield_std

def calc_lst1_lst2(pair_of_feats, dict_correlated_intra, dict_correlated_inter):
    """Expand a (first, second) feature pair into two lists of related features.

    Each feature is extended with its intra-group correlated features and
    its inter-group correlated features, and then with the ``'<name>**2'``
    form of every entry. The second feature is looked up in the
    ``'steric_%V'`` group (skipped when it is not a key there). The first
    list is finally restricted to names containing ``'TPSS_TZVP_COSMO'``.

    Parameters
    ----------
    pair_of_feats : sequence
        ``(first, second)`` feature names.
    dict_correlated_intra : dict
        Group name -> {feature -> correlated features within the group}.
    dict_correlated_inter : dict
        Feature -> correlated features from other groups.

    Returns
    -------
    tuple of list
        ``(lst1, lst2)``.
    """
    first, second = pair_of_feats[0], pair_of_feats[1]

    group1 = [first]
    group1.extend(dict_correlated_intra['electronic_descriptors_TPSS_TZVP_COSMO'][first])
    group1.extend(dict_correlated_inter[first])
    group1 = group1 + [f'{name}**2' for name in group1]
    lst1 = [name for name in group1 if 'TPSS_TZVP_COSMO' in name]

    steric = dict_correlated_intra['steric_%V']
    group2 = [second]
    if second in steric:
        group2.extend(steric[second])
    group2.extend(dict_correlated_inter[second])
    lst2 = group2 + [f'{name}**2' for name in group2]

    return lst1, lst2

def calc_lst1_lst2_v2(pair_of_feats, dict_correlated_intra, dict_correlated_inter):
    """Expand a (first, second) feature pair, both looked up in the electronic group.

    Each feature is extended with its intra-group correlated features from
    ``'electronic_descriptors_TPSS_TZVP_COSMO'`` (skipped for the second
    feature when it is not a key there) and its inter-group correlated
    features, and then with the ``'<name>**2'`` form of every entry. Only
    the first list is restricted to names containing ``'TPSS_TZVP_COSMO'``.

    Parameters
    ----------
    pair_of_feats : sequence
        ``(first, second)`` feature names.
    dict_correlated_intra : dict
        Group name -> {feature -> correlated features within the group}.
    dict_correlated_inter : dict
        Feature -> correlated features from other groups.

    Returns
    -------
    tuple of list
        ``(lst1, lst2)``.
    """
    first, second = pair_of_feats[0], pair_of_feats[1]
    electronic = dict_correlated_intra['electronic_descriptors_TPSS_TZVP_COSMO']

    group1 = [first]
    group1.extend(electronic[first])
    group1.extend(dict_correlated_inter[first])
    group1 = group1 + [f'{name}**2' for name in group1]
    lst1 = [name for name in group1 if 'TPSS_TZVP_COSMO' in name]

    group2 = [second]
    if second in electronic:
        group2.extend(electronic[second])
    group2.extend(dict_correlated_inter[second])
    lst2 = group2 + [f'{name}**2' for name in group2]

    return lst1, lst2

def norm_y(y, ytst):
    """Transform yields with ``normalize_yield`` and min-max scale them to 0-100.

    The minimum and maximum are taken from the transformed training values
    only and reused for the test values, which may therefore fall outside
    the 0-100 range.

    Parameters
    ----------
    y, ytst : iterable of float
        Training and test yields.

    Returns
    -------
    tuple
        ``(y_scaled, ytst_scaled, maxdg, mindg)``.
    """
    train_dg = np.array([normalize_yield(value) for value in y])
    maxdg = max(train_dg)
    mindg = min(train_dg)
    y = ((train_dg-mindg)/(maxdg-mindg))*100

    test_dg = np.array([normalize_yield(value) for value in ytst])
    ytst = ((test_dg-mindg)/(maxdg-mindg))*100
    return y, ytst, maxdg, mindg

def denorm_y(ypreds, ytrues, y_predsts, y_predstr, y, ytst, maxdg, mindg):
    """Map six target/prediction collections back through ``denormalize_yield``.

    Args:
        ypreds, ytrues, y_predsts, y_predstr, y, ytst: Iterables of values to
            convert; each is processed element by element.
        maxdg: Forwarded to ``denormalize_yield`` as ``maxdg``.
        mindg: Forwarded to ``denormalize_yield`` as ``mindg``.

    Returns:
        Tuple of six numpy arrays, in the same order as the inputs.
    """
    def restore(values):
        # Element-wise inverse transform with the shared scaling bounds.
        return np.array([denormalize_yield(v, maxdg=maxdg, mindg=mindg) for v in values])

    return tuple(restore(values)
                 for values in (ypreds, ytrues, y_predsts, y_predstr, y, ytst))

def main_comp(dpxx, ext_cat, 
              featcolsselected, c, c_std, coeffs, 
              dG, res_dict, axes, cnt, 
              eq=True, conf_df = None, 
              algof = 'Ridge', labdp= 'Yield'):
    """Fit a regression model for condition ``c``, plot it and record its scores.

    Rows of ``dpxx`` whose ``ID`` is in ``ext_cat`` form the test set (minus
    the hard-coded IDs ``'comp_5'`` and ``'comp_6'``); all other rows are the
    training set. The model comes from ``ridge_cv`` when ``algof == 'Ridge'``
    and from ``pls_cv`` otherwise (both are project helpers; their six return
    values are presumably model, LOO predictions, LOO targets, LOO stds,
    LOO IDs and the tuned hyper-parameter -- verify against their definitions).

    Args:
        dpxx: DataFrame with an ``ID`` column, the feature columns and the
            target / target-std columns.
        ext_cat: IDs held out as the external test set.
        featcolsselected: List of feature column names used as model input.
        c: Name of the target column.
        c_std: Name of the column holding the target's standard deviation.
        coeffs: Dict updated in place; ``coeffs[c]`` receives the flattened
            coefficients followed by the intercept.
        dG: If truthy, targets are transformed with ``norm_y`` before fitting
            and everything is mapped back with ``denorm_y`` afterwards.
        res_dict: Dict of lists to which one entry per metric is appended.
        axes: Array of matplotlib axes. Without ``conf_df`` it is used
            flattened (panel ``cnt`` for the parity plot, panel ``cnt + 8``
            for the Q-Q plot, i.e. the grid is assumed to have at least
            ``cnt + 9`` panels); with ``conf_df`` it is indexed as
            ``axes[cnt, 0..3]``.
        cnt: Panel / row index into ``axes``.
        eq: If truthy, the fitted equation is written onto the parity plot.
        conf_df: Optional dict with the DataFrames ``'dpx_aniso'``,
            ``'dpx_vbur'`` and ``'dpx_anisovbur'``; when given, each frame
            gets a ``'pred'`` column and its own scatter panel.
        algof: ``'Ridge'`` or any other value for PLS.
        labdp: Label stored under ``res_dict['target']``.

    Returns:
        ``(res_dict, coeffs, conf_df)`` when ``conf_df`` is truthy, otherwise
        ``(res_dict, coeffs)``.
    """
    use_conformers = bool(conf_df)
    is_ridge = algof == 'Ridge'
    # (conf_df key, axes column, panel title, extra scatterplot keywords)
    # NOTE(review): 'cmap' is forwarded to matplotlib; with hue= seaborn most
    # likely assigns the colours itself, so it may have no effect -- confirm.
    conformer_panels = (
        ('dpx_vbur', 1, 'conformer-VburC2', {}),
        ('dpx_aniso', 2, 'conformer-anisotropy', {}),
        ('dpx_anisovbur', 3, 'conformer analysis', {'cmap': 'CMRmap'}),
    )

    # ---- train / external-test split --------------------------------------
    held_out = dpxx.ID.isin(ext_cat)
    data_train, data_test = dpxx[~held_out], dpxx[held_out]
    data_test = data_test[~data_test.ID.isin(['comp_5','comp_6'])]

    X = data_train.loc[:, featcolsselected]
    y = data_train[c].values
    y_std = data_train[c_std]
    y_index = data_train.ID.values
    Xtst = data_test.loc[:, featcolsselected]
    ytst = data_test[c].values
    ytst_std = data_test[c_std]

    if use_conformers:
        # Snapshot the conformer features before 'pred' columns are added.
        conf_features = {key: conf_df[key].loc[:, featcolsselected]
                         for key, _, _, _ in conformer_panels}

    if dG:
        y, ytst, maxdg, mindg = norm_y(y, ytst)

    # ---- fit (the helpers run the cross-validation themselves) -------------
    if is_ridge:
        final_model, ypreds, ytrues, yloo_std, y_index_loo, _ = ridge_cv(X, y, y_std, y_index)
    else:
        final_model, ypreds, ytrues, yloo_std, y_index_loo, _ = pls_cv(X, y, y_std, y_index)

    coeffs[c] = list(final_model.coef_.flatten()) + [final_model.intercept_]

    # ---- predictions --------------------------------------------------------
    y_predsts = final_model.predict(Xtst)
    y_predstr = final_model.predict(X)
    if use_conformers:
        conf_preds = {key: final_model.predict(feats)
                      for key, feats in conf_features.items()}

    if dG:
        ypreds, ytrues, y_predsts, y_predstr, y, ytst = denorm_y(
            ypreds, ytrues, y_predsts, y_predstr, y, ytst, maxdg, mindg)
        if use_conformers:
            conf_preds = {
                key: np.array([denormalize_yield(i, maxdg=maxdg, mindg=mindg) for i in preds])
                for key, preds in conf_preds.items()
            }

    # Predictions are limited to the 0-100 range.
    ypreds = np.clip(ypreds, 0, 100)
    y_predsts = np.clip(y_predsts, 0, 100)
    y_predstr = np.clip(y_predstr, 0, 100)
    if use_conformers:
        for key, preds in conf_preds.items():
            conf_df[key]['pred'] = np.clip(preds, 0, 100)
    else:
        # Q-Q plot of the LOO residuals, drawn 8 panels after the parity plot.
        residuals = np.array(list(ytrues - ypreds))
        qq_ax = axes.flatten()[cnt+8]
        stats.probplot(residuals.flatten(), dist="norm", plot=qq_ax)
        # Title follows the model actually fitted (was hard-coded to Ridge).
        qq_ax.set_title(f"Q-Q Plot of Residuals for {algof} Regression")

    # LOO points that miss the measured value by more than 20 get labelled.
    seledf = pd.DataFrame({'ID': np.array(y_index_loo).flatten(),
                           'True': ytrues,
                           'Pred': ypreds})
    seledf['diff'] = np.abs(seledf['True'].values - seledf['Pred'].values)
    seledf = seledf[seledf['diff'] > 20]

    # ---- plotting -----------------------------------------------------------
    if use_conformers:
        for key, col, title, extra_kws in conformer_panels:
            panel = axes[cnt, col]
            frame = conf_df[key]
            sns.scatterplot(data=frame, x=c, y='pred', hue='ID',
                            ax=panel, legend=False, **extra_kws)
            panel.axline((0,0), slope = 1, c='k', linestyle=':')
            panel.set_title(title)
            panel.set_xlabel('Exp.Yield')
            panel.set_ylabel('Pred.Yield')
            # Label each ID once, just above its highest predicted point.
            for ll, group in frame.groupby('ID'):
                group = group.sort_values(by='pred', ascending = False)
                panel.text(group[c].values[0], group['pred'].values[0]+1,
                           ll.replace('comp_','Pd-'), fontsize=9)
        main_ax = axes[cnt,0]
        lab = f"{c.replace('Yield_','').replace('_',' ')}"
    else:
        main_ax = axes.flatten()[cnt]
        lab = f'{c} ({len(dpxx)})({len(featcolsselected)})'

    r2_tr, r2, r2_test, mae_tr, mae, mae_test = plotting_annot(main_ax,
                                                               X,
                                                               y,
                                                               y_predstr,
                                                               ytrues,
                                                               ypreds,
                                                               ytst,
                                                               y_predsts,
                                                               y_std,
                                                               np.array(yloo_std).flatten(),
                                                               ytst_std,
                                                               lab = lab)

    for idxrow, row in seledf.iterrows():
        main_ax.text(row['True'], row.Pred, row.ID.replace('comp_','Pd-'), fontsize=9)

    main_ax.set_xlabel('Exp.Yield')
    main_ax.set_ylabel('Pred.Yield')

    if eq:
        # The indexing mirrors the model chosen above: Ridge exposes a scalar
        # intercept and 1-D coefficients, the PLS path is indexed as
        # intercept_[0] / coef_[i][0].
        if is_ridge:
            intercept = final_model.intercept_
            weights = [final_model.coef_[i] for i in range(len(featcolsselected))]
        else:
            intercept = final_model.intercept_[0]
            weights = [final_model.coef_[i][0] for i in range(len(featcolsselected))]
        add_text = f'y = {intercept:.2f} + \n' + ''.join(
            f"{weight:.2f} * {feat.replace('para','trans').replace('ortho','cis')} + \n"
            for weight, feat in zip(weights, featcolsselected))
        # Anchor the box in the coordinates of the axes it is drawn on; the
        # conformer layout previously used another panel's transform.
        main_ax.text(0.05, 0.16, add_text,
                     transform=main_ax.transAxes,
                     fontsize=9, verticalalignment='top',
                     bbox=dict(facecolor='yellow', alpha=0.02))

    # ---- bookkeeping --------------------------------------------------------
    res_dict['target'].append(labdp)
    res_dict['algo'].append(algof)
    res_dict['dG'].append(dG)
    res_dict['condition'].append(c)
    res_dict['R2train'].append(r2_tr)
    res_dict['R2LOO'].append(r2)
    res_dict['R2test'].append(r2_test)
    res_dict['MAEtrain'].append(mae_tr)
    res_dict['MAELOO'].append(mae)
    res_dict['MAEtest'].append(mae_test)
    res_dict['feats'].append(featcolsselected)
    res_dict['num_feats'].append(len(featcolsselected))

    if use_conformers:
        return res_dict, coeffs, conf_df
    return res_dict, coeffs
    
    
def main_comp_logo(dpxx, featcolsselected, c, c_std, coeffs, dG, cnt, eq=True, conf_df = None, algof = 'Ridge', labdp= 'Yield', coluni='R'):
    """Leave-one-group-out evaluation of a regression model for condition ``c``.

    For every distinct value of ``dpxx[coluni]`` the rows carrying that value
    become the test set and the remaining rows the training set. A model is
    fitted with ``ridge_cv`` (``algof == 'Ridge'``) or ``pls_cv`` (any other
    value), and a parity plot (top row) plus a Q-Q plot of the LOO residuals
    (bottom row) is drawn in a new 2 x n_groups figure, which is shown before
    returning.

    Args:
        dpxx: DataFrame with an ``ID`` column, the grouping column, the
            feature columns and the target / target-std columns.
        featcolsselected: List of feature column names used as model input.
        c: Name of the target column.
        c_std: Name of the column holding the target's standard deviation.
        coeffs: Dict updated in place. ``coeffs[c]`` is overwritten for every
            group, so only the model of the last group remains.
        dG: If truthy, targets are transformed with ``norm_y`` before fitting
            and mapped back with ``denorm_y`` afterwards.
        cnt: Ignored; the panel index is the position of the group.
        eq: If truthy, the fitted equation is written onto the parity plot.
        conf_df: Unused.
        algof: ``'Ridge'`` or any other value for PLS.
        labdp: Label stored under ``'target'`` in each result.
        coluni: Name of the grouping column.

    Returns:
        ``(res_dict, coeffs)`` where ``res_dict`` maps each group value to a
        dict of scalar results (target, algo, dG, condition, R2/MAE for
        train, LOO and test, feats, num_feats).
    """
    groups = dpxx[coluni].unique()
    n_groups = len(groups)
    _fig, axes = plt.subplots(2, n_groups, figsize = (10*n_groups, 20))
    panels = axes.flatten()
    is_ridge = algof == 'Ridge'
    res_dict = {}

    for rnum, rgroup in enumerate(groups):
        main_ax = panels[rnum]            # top row: parity plot
        qq_ax = panels[rnum + n_groups]   # bottom row: Q-Q plot

        data_train, data_test = dpxx[dpxx[coluni]!=rgroup], dpxx[dpxx[coluni]==rgroup]
        X = data_train.loc[:, featcolsselected]
        y = data_train[c].values
        y_std = data_train[c_std]
        y_index = data_train.ID.values
        Xtst = data_test.loc[:, featcolsselected]
        ytst = data_test[c].values
        ytst_std = data_test[c_std]

        if dG:
            y, ytst, maxdg, mindg = norm_y(y, ytst)

        # The helpers run the cross-validation themselves.
        if is_ridge:
            final_model, ypreds, ytrues, yloo_std, y_index_loo, _ = ridge_cv(X, y, y_std, y_index)
        else:
            final_model, ypreds, ytrues, yloo_std, y_index_loo, _ = pls_cv(X, y, y_std, y_index)

        coeffs[c] = list(final_model.coef_.flatten()) + [final_model.intercept_]

        y_predsts = final_model.predict(Xtst)
        y_predstr = final_model.predict(X)

        if dG:
            ypreds, ytrues, y_predsts, y_predstr, y, ytst = denorm_y(
                ypreds, ytrues, y_predsts, y_predstr, y, ytst, maxdg, mindg)

        # Predictions are limited to the 0-100 range.
        ypreds = np.clip(ypreds, 0, 100)
        y_predsts = np.clip(y_predsts, 0, 100)
        y_predstr = np.clip(y_predstr, 0, 100)

        residuals = np.array(list(ytrues - ypreds))
        stats.probplot(residuals.flatten(), dist="norm", plot=qq_ax)
        # Title follows the model actually fitted (was hard-coded to Ridge).
        qq_ax.set_title(f"Q-Q Plot of Residuals for {algof} Regression")

        # LOO points that miss the measured value by more than 20 get labelled.
        seledf = pd.DataFrame({'ID': np.array(y_index_loo).flatten(),
                               'True': ytrues,
                               'Pred': ypreds})
        seledf['diff'] = np.abs(seledf['True'].values - seledf['Pred'].values)
        seledf = seledf[seledf['diff'] > 20]

        r2_tr, r2, r2_test, mae_tr, mae, mae_test = plotting_annot(main_ax,
                                                                   X,
                                                                   y,
                                                                   y_predstr,
                                                                   ytrues,
                                                                   ypreds,
                                                                   ytst,
                                                                   y_predsts,
                                                                   y_std,
                                                                   np.array(yloo_std).flatten(),
                                                                   ytst_std,
                                                                   lab = f'{c} ({len(dpxx)})({len(featcolsselected)})')

        for idxrow, row in seledf.iterrows():
            main_ax.text(row['True'], row.Pred, row.ID, fontsize=9)

        main_ax.set_xlabel('Exp.Yield')
        main_ax.set_ylabel('Pred.Yield')

        if eq:
            # The indexing mirrors the model chosen above: Ridge exposes a
            # scalar intercept and 1-D coefficients, the PLS path is indexed
            # as intercept_[0] / coef_[i][0].
            if is_ridge:
                intercept = final_model.intercept_
                weights = [final_model.coef_[i] for i in range(len(featcolsselected))]
            else:
                intercept = final_model.intercept_[0]
                weights = [final_model.coef_[i][0] for i in range(len(featcolsselected))]
            add_text = f'y = {intercept:.2f} + \n' + ''.join(
                f"{weight:.2f} * {feat.replace('para','trans').replace('ortho','cis')} + \n"
                for weight, feat in zip(weights, featcolsselected))
            main_ax.text(0.05, 0.16, add_text,
                         transform=main_ax.transAxes,
                         fontsize=9, verticalalignment='top',
                         bbox=dict(facecolor='yellow', alpha=0.02))

        # One scalar entry per metric for this group.
        res_dict[rgroup] = {
            'target': labdp,
            'algo': algof,
            'dG': dG,
            'condition': c,
            'R2train': r2_tr,
            'R2LOO': r2,
            'R2test': r2_test,
            'MAEtrain': mae_tr,
            'MAELOO': mae,
            'MAEtest': mae_test,
            'feats': featcolsselected,
            'num_feats': len(featcolsselected),
        }

    plt.tight_layout()
    plt.show()
    return res_dict, coeffs
            
        
def _plot_confusion_heatmap(y_true, y_pred, labels):
    """Show a 3x3-inch annotated confusion-matrix heatmap in a new figure."""
    plt.figure(figsize = (3,3))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    df_cm = pd.DataFrame(cm, index = list(labels), columns = list(labels))
    sns.heatmap(df_cm, annot=True)
    plt.show()


def main_comp_class(dpxx, ext_cat, featcolsselected, c, coeffs, res_dict, axes, cnt, eq=True, conf_df = None, algof = 'Logi', labdp= 'Yield'):
    """Fit a classifier for condition ``c``, show confusion matrices, log scores.

    Rows of ``dpxx`` whose ``ID`` is in ``ext_cat`` form the test set (minus
    the hard-coded IDs ``'comp_5'`` and ``'comp_6'``); all other rows are the
    training set. The model comes from the project helper ``logi_cv``.

    Args:
        dpxx: DataFrame with an ``ID`` column, the feature columns and the
            class-label column.
        ext_cat: IDs held out as the external test set.
        featcolsselected: List of feature column names used as model input.
        c: Name of the class-label column.
        coeffs: Dict updated in place; ``coeffs[c]`` receives the flattened
            coefficients followed by the first intercept.
        res_dict: Dict of lists to which one entry per metric is appended.
        axes: Unused (each heatmap opens its own figure).
        cnt: Unused.
        eq: Unused.
        conf_df: Optional dict with the DataFrames ``'dpx_aniso'``,
            ``'dpx_vbur'`` and ``'dpx_anisovbur'``; when given, each frame
            gets a ``'pred'`` column and only the ``'dpx_anisovbur'``
            confusion matrix is shown.
        algof: Must be ``'Logi'``.
        labdp: Label stored under ``res_dict['target']``.

    Returns:
        ``(res_dict, coeffs, conf_df)`` when ``conf_df`` is truthy, otherwise
        ``(res_dict, coeffs)``.

    Raises:
        ValueError: If ``algof`` is not ``'Logi'``. The former fallback branch
            referenced an undefined ``y_std`` and could never run.
    """
    if algof != 'Logi':
        raise ValueError(
            f"main_comp_class only supports algof='Logi'; got {algof!r}")

    use_conformers = bool(conf_df)
    conformer_keys = ('dpx_aniso', 'dpx_vbur', 'dpx_anisovbur')

    # ---- train / external-test split --------------------------------------
    held_out = dpxx.ID.isin(ext_cat)
    data_train, data_test = dpxx[~held_out], dpxx[held_out]
    data_test = data_test[~data_test.ID.isin(['comp_5','comp_6'])]

    X = data_train.loc[:, featcolsselected]
    y = data_train[c].values
    y_index = data_train.ID.values
    Xtst = data_test.loc[:, featcolsselected]
    ytst = data_test[c].values

    if use_conformers:
        # Snapshot the conformer features before 'pred' columns are added.
        conf_features = {key: conf_df[key].loc[:, featcolsselected]
                         for key in conformer_keys}

    # ---- fit (logi_cv runs the cross-validation itself) ---------------------
    final_model, ypreds, ytrues, _, _ = logi_cv(X, y, y_index)

    coeffs[c] = list(final_model.coef_.flatten()) + [final_model.intercept_[0]]

    y_predsts = final_model.predict(Xtst)
    y_predstr = final_model.predict(X)

    if use_conformers:
        for key, feats in conf_features.items():
            # NOTE(review): clipping predicted labels to [0, 100] is kept from
            # the original; it presumes numeric class labels -- confirm.
            conf_df[key]['pred'] = np.clip(final_model.predict(feats), 0, 100)

    # ---- confusion matrices -------------------------------------------------
    classes = final_model.classes_
    if use_conformers:
        _plot_confusion_heatmap(conf_df['dpx_anisovbur'][c].values,
                                conf_df['dpx_anisovbur']['pred'].values,
                                classes)
    else:
        print('TRAIN')
        _plot_confusion_heatmap(y, y_predstr, classes)
        print('LOOCV')
        _plot_confusion_heatmap(ytrues, ypreds, classes)
        print('TEST')
        _plot_confusion_heatmap(ytst, y_predsts, classes)

    # ---- bookkeeping --------------------------------------------------------
    # Compute every score first so a failure cannot leave res_dict with
    # lists of unequal length.
    ba_tr = balanced_accuracy_score(y, y_predstr)
    ba_loo = balanced_accuracy_score(ytrues, ypreds)
    ba_test = balanced_accuracy_score(ytst, y_predsts)

    res_dict['target'].append(labdp)
    res_dict['algo'].append(algof)
    res_dict['condition'].append(c)
    res_dict['BAtrain'].append(ba_tr)
    res_dict['BALOO'].append(ba_loo)
    res_dict['BAtest'].append(ba_test)
    res_dict['feats'].append(featcolsselected)
    res_dict['num_feats'].append(len(featcolsselected))

    if use_conformers:
        return res_dict, coeffs, conf_df
    return res_dict, coeffs
    
    
def main_comp_pred(dpxx, ext_cat, featcolsselected, c, c_std, coeffs, dG, res_dict, axes, cnt, eq=True, conf_df = None, algof = 'Ridge', labdp= 'Yield'):
    """Fit a regression model for condition ``c`` and return its predictions.

    Rows of ``dpxx`` whose ``ID`` is in ``ext_cat`` form the test set (minus
    the hard-coded IDs ``'comp_5'`` and ``'comp_6'``); all other rows are the
    training set. The model comes from ``ridge_cv`` when ``algof == 'Ridge'``
    and from ``pls_cv`` otherwise (project helpers).

    Args:
        dpxx: DataFrame with an ``ID`` column, the feature columns and the
            target / target-std columns.
        ext_cat: IDs held out as the external test set.
        featcolsselected: List of feature column names used as model input.
        c: Name of the target column.
        c_std: Name of the column holding the target's standard deviation.
        coeffs: Dict updated in place; ``coeffs[c]`` receives the flattened
            coefficients followed by the intercept.
        dG: If truthy, targets are transformed with ``norm_y`` before fitting
            and everything is mapped back with ``denorm_y`` afterwards.
        res_dict: Unused.
        axes: Array of matplotlib axes; without ``conf_df`` the Q-Q plot of
            the LOO residuals is drawn on flattened panel ``cnt + 8``.
        cnt: Panel index into ``axes``.
        eq: Unused.
        conf_df: Optional dict with the DataFrames ``'dpx_aniso'``,
            ``'dpx_vbur'`` and ``'dpx_anisovbur'``; when given, each frame
            gets a ``'pred'`` column (in place) and no Q-Q plot is drawn.
        algof: ``'Ridge'`` or any other value for PLS.
        labdp: Unused.

    Returns:
        ``(cvdf, testdf, traindf)``: DataFrames with ``True`` / ``Pred``
        columns for the LOO, test and training predictions; ``cvdf`` also
        carries the ``ID`` of each LOO point.
    """
    use_conformers = bool(conf_df)
    conformer_keys = ('dpx_aniso', 'dpx_vbur', 'dpx_anisovbur')

    # ---- train / external-test split --------------------------------------
    held_out = dpxx.ID.isin(ext_cat)
    data_train, data_test = dpxx[~held_out], dpxx[held_out]
    data_test = data_test[~data_test.ID.isin(['comp_5','comp_6'])]

    X = data_train.loc[:, featcolsselected]
    y = data_train[c].values
    y_std = data_train[c_std]
    y_index = data_train.ID.values
    Xtst = data_test.loc[:, featcolsselected]
    ytst = data_test[c].values

    if use_conformers:
        # Snapshot the conformer features before 'pred' columns are added.
        conf_features = {key: conf_df[key].loc[:, featcolsselected]
                         for key in conformer_keys}

    if dG:
        y, ytst, maxdg, mindg = norm_y(y, ytst)

    # ---- fit (the helpers run the cross-validation themselves) -------------
    if algof == 'Ridge':
        final_model, ypreds, ytrues, _, y_index_loo, _ = ridge_cv(X, y, y_std, y_index)
    else:
        final_model, ypreds, ytrues, _, y_index_loo, _ = pls_cv(X, y, y_std, y_index)

    coeffs[c] = list(final_model.coef_.flatten()) + [final_model.intercept_]

    # ---- predictions --------------------------------------------------------
    y_predsts = final_model.predict(Xtst)
    y_predstr = final_model.predict(X)
    if use_conformers:
        conf_preds = {key: final_model.predict(feats)
                      for key, feats in conf_features.items()}

    if dG:
        ypreds, ytrues, y_predsts, y_predstr, y, ytst = denorm_y(
            ypreds, ytrues, y_predsts, y_predstr, y, ytst, maxdg, mindg)
        if use_conformers:
            conf_preds = {
                key: np.array([denormalize_yield(i, maxdg=maxdg, mindg=mindg) for i in preds])
                for key, preds in conf_preds.items()
            }

    # Predictions are limited to the 0-100 range.
    ypreds = np.clip(ypreds, 0, 100)
    y_predsts = np.clip(y_predsts, 0, 100)
    y_predstr = np.clip(y_predstr, 0, 100)
    if use_conformers:
        for key, preds in conf_preds.items():
            conf_df[key]['pred'] = np.clip(preds, 0, 100)
    else:
        # Q-Q plot of the LOO residuals, drawn 8 panels after panel ``cnt``.
        residuals = np.array(list(ytrues - ypreds))
        qq_ax = axes.flatten()[cnt+8]
        stats.probplot(residuals.flatten(), dist="norm", plot=qq_ax)
        # Title follows the model actually fitted (was hard-coded to Ridge).
        qq_ax.set_title(f"Q-Q Plot of Residuals for {algof} Regression")

    cvdf = pd.DataFrame({'ID': np.array(y_index_loo).flatten(),
                         'True': ytrues,
                         'Pred': ypreds})
    testdf = pd.DataFrame({'True': ytst,
                           'Pred': y_predsts})
    traindf = pd.DataFrame({'True': y,
                            'Pred': y_predstr})

    return cvdf, testdf, traindf
    