import pandas as pd
import numpy as np
#import ppscore as pps
import glob
import os
from functools import reduce
import pickle
from itertools import combinations
from sklearn.metrics import r2_score
from scipy.stats import pearsonr
import seaborn as sns
import matplotlib.pyplot as plt

# Catalyst names as they appear in the 'Catalyst' column of the experimental
# spreadsheets; the list position (1-based) gives the short ``comp_<n>`` ID.
_CATALYST_NAMES = [
    'IPr-Pd-DMS',
    'SIPr-Pd-DMS',
    'IMes-Pd-DMS',
    'SIMes-Pd-DMS',
    'IDD-Pd-DMS',
    'SIDD-Pd-DMS',
    'IPr*-Pd-DMS',
    'IPr*OMe-Pd-DMS',
    'IPr#-Pd-DMS',
    'BIAN-IMes-Pd-DMS',
    'BIAN-IPr-Pd-DMS',
    'BIAN-IPr#-Pd-DMS',
    'IPrCl-Pd-DMS',
    'IPentCl-Pd-DMS',
    'INonCl-Pd-DMS',
    'IPriPr-Pd-DMS',
    'IPaul-Pd-DMS',
    'IHeptCl-Pd-DMS',
    'IPr*Tol-Pd-DMS',
    'IPent-Pd-DMS',
    'IHept-Pd-DMS',
]

# Catalyst name -> 'comp_<n>' identifier.
id_to_id = {name: f'comp_{n}' for n, name in enumerate(_CATALYST_NAMES, start=1)}

def scatterplot_with_corr(df, x, y, h = None, ax = None):
    """Scatter ``y`` against ``x`` with a y = x reference line and fit statistics.

    Both columns are coerced to numeric (unparseable entries become NaN) and rows
    with a missing value in either column are dropped before anything is drawn
    or computed. A text box in the lower right shows ``r2_score(x, y)`` (``x`` is
    treated as the reference) and the Pearson correlation.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data; it is copied and not modified.
    x, y : str
        Column names for the horizontal and vertical axes.
    h : str, optional
        Hue column forwarded to ``seaborn.scatterplot``.
    ax : matplotlib Axes, optional
        When given, draw into it without a legend, set square limits spanning
        both columns, and do not call ``plt.show()``. When ``None``, create a new
        6x6 figure and show it.
    """
    data_df = df.copy()
    data_df[x] = pd.to_numeric(data_df[x], errors='coerce')
    data_df[y] = pd.to_numeric(data_df[y], errors='coerce')
    data_df = data_df.dropna(subset=[x, y])

    own_figure = ax is None
    if own_figure:
        plt.figure(figsize=(6, 6))
        sns.scatterplot(data=data_df, x=x, y=y, hue=h)
        ax = plt.gca()
    else:
        sns.scatterplot(data=data_df, x=x, y=y, hue=h, ax=ax, legend=False)

    # Bounds come from the cleaned numeric data; the raw ``df`` may still hold
    # strings or NaN, which make min() fail or return NaN.
    lo = min(data_df[x].min(), data_df[y].min())
    hi = max(data_df[x].max(), data_df[y].max())
    ax.axline((lo, lo), slope=1, color='black')

    r2 = r2_score(data_df[x], data_df[y])
    pea = pearsonr(data_df[x], data_df[y]).statistic
    textstrtmp = f'R2={r2:.2f}\nPearson={pea:.2f}'
    ax.text(0.75, 0.08, textstrtmp, transform=ax.transAxes,
            fontsize=10,
            verticalalignment='top',
            bbox=dict(boxstyle='round,pad=0.1', alpha=0.7, color='lightgreen'))

    if own_figure:
        plt.show()
    else:
        ax.set_xlim([lo, hi])
        ax.set_ylim([lo, hi])

def _add_pair_stats(df, col_i, col_j, name_mean, name_min, name_max):
    """Replace two symmetry-related columns with their row-wise mean, min and max.

    The three new columns are appended in that order and ``col_i``/``col_j`` are
    dropped. A fresh copy is returned, which also defragments the frame after the
    repeated column inserts.
    """
    pair = df.loc[:, [col_i, col_j]]
    df[name_mean] = pair.mean(axis=1).values
    df[name_min] = pair.min(axis=1).values
    df[name_max] = pair.max(axis=1).values
    return df.drop([col_i, col_j], axis=1).copy()


def read_descriptor_files(path, verbose = 1, merge_symmetrical=True):
    """Read the four descriptor CSV files under ``path`` and outer-merge them.

    Files read are ``{path}cosmodescriptors.csv``, ``{path}electronic_descriptors.csv``,
    ``{path}geometric.csv`` and ``{path}solvents.csv``. ``path`` is used as a plain
    string prefix, so it must end with a separator. The first three columns of
    every file are the keys 'Catalyst', ' Complex', ' Conformer'. In the complex
    name, 'ortho' is renamed to 'cis' and 'para' to 'trans'.

    A file whose row count is not 766 is treated as long format: its 4th column is
    pivoted into the column names and columns 5+ become the values
    (new name = value-column name + pivot value).

    With ``merge_symmetrical``, pairs of symmetry-related descriptors are replaced
    by their mean plus ``MIN``/``MAX`` columns:

    * electronic: index pairs (1,3), (4,5), (7,8), (12,13) of every NBO/IBO/shield
      variant;
    * geometric: the distance/torsion/bend pairs listed below. All 'tors' columns
      are converted to absolute values first.

    After merging, non-key columns that are constant over all rows are dropped.
    Spaces are then stripped from cell values, empty strings become NaN, non-key
    columns are coerced to numeric, and all-NaN columns are dropped. The three key
    columns are never pruned.

    Parameters
    ----------
    path : str
        Directory prefix of the CSV files.
    verbose : int
        0 is silent. 1 prints dimensions and pruning counts. Values >1 also list
        the pruned columns and the symmetric column pairs being merged.
    merge_symmetrical : bool
        Whether to collapse symmetry-related descriptor pairs.

    Returns
    -------
    merged_df_ele : pandas.DataFrame
        Outer merge of all files on the three key columns.
    dict_cols : dict[str, list[str]]
        File label -> its descriptor columns that survived pruning.
    """
    keys = ['Catalyst', ' Complex', ' Conformer']
    subset_path = ['cosmodescriptors', 'electronic_descriptors', 'geometric', 'solvents']
    dfs = []
    dict_cols = {}
    for lab in subset_path:
        # The label is used directly; deriving it via str.replace(path, '') broke
        # whenever ``path`` also occurred inside the file name.
        tmp_df = pd.read_csv(f'{path}{lab}.csv', index_col=False)
        tmp_df[' Complex'] = [c.replace('ortho', 'cis').replace('para', 'trans')
                              for c in tmp_df[' Complex']]
        # NOTE(review): 766 presumably is the row count of files already stored
        # one-row-per-conformer; any other length is treated as long format and
        # pivoted. Confirm this magic number.
        if len(tmp_df) != 766:
            cols = list(tmp_df.columns)
            tmp_df = tmp_df.pivot(index=keys, columns=cols[3], values=cols[4:]).reset_index()
            tmp_df.columns = keys + [f'{a}{b}' for a, b in list(tmp_df.columns)[3:]]

        tmp_df[tmp_df.columns[3:]] = tmp_df.iloc[:, 3:].apply(pd.to_numeric, errors='coerce')

        if merge_symmetrical and lab == 'electronic_descriptors':
            pairs_symmetric = [(1, 3), (4, 5), (7, 8), (12, 13)]
            suff_ele = [' NBO(x) BP86_TZVP_COSMO',
                        ' NBO(x) BP86_TZVP_GAS',
                        ' NBO(x) TPSS_SVP_GAS',
                        ' NBO(x) TPSS_TZVP_COSMO',
                        ' NBO(x) TPSS_TZVP_GAS',
                        ' IBO(x) BP86_TZVP_COSMO',
                        ' IBO(x) BP86_TZVP_GAS',
                        ' IBO(x) TPSS_SVP_GAS',
                        ' IBO(x) TPSS_TZVP_COSMO',
                        ' IBO(x) TPSS_TZVP_GAS',
                        ' shield(x) TPSS_TZVP_COSMO']
            for i, j in pairs_symmetric:
                for suff in suff_ele:
                    suffi = suff.replace('x', str(i))
                    suffj = suff.replace('x', str(j))
                    # Debug output used to be printed unconditionally.
                    if verbose > 1:
                        print(suffi)
                        print(suffj)
                    tmp_df = _add_pair_stats(tmp_df, suffi, suffj,
                                             suff.replace('x', f'{i}&{j}'),
                                             suff.replace('x', f'MIN{i}&{j}'),
                                             suff.replace('x', f'MAX{i}&{j}'))

        if merge_symmetrical and lab == 'geometric':
            # Compare torsion magnitudes only, so mirror-image conformers coincide.
            torscs = [c for c in tmp_df.columns if 'tors' in c]
            for tors_c in torscs:
                tmp_df[tors_c] = tmp_df[tors_c].abs()

            pairs_geom = [(' dist1-2', ' dist2-3'),
                          (' dist3-4', ' dist5-1'),
                          (' dist6-7', ' dist6-8'),
                          (' dist3-10', ' dist1-9'),
                          (' tors2-1-9-22', ' tors2-1-9-23'),
                          (' tors2-3-10-30', ' tors2-3-10-31'),
                          (' bend2-6-7', ' bend2-6-8'),
                          (' bend7-6-11', ' bend8-6-11')]
            for i, j in pairs_geom:
                tmp_df = _add_pair_stats(tmp_df, i, j, f'{i}&{j}', f'MIN{i}&{j}', f'MAX{i}&{j}')

        dict_cols[lab] = list(tmp_df.columns)[3:]
        dfs.append(tmp_df)

    if verbose > 0:
        for df_i, name_i in zip(dfs, subset_path):
            print(f'{name_i} dimension: {df_i.iloc[:,3:].values.shape}')

    merged_df_ele = reduce(lambda left, right: pd.merge(left, right, on=keys, how='outer'), dfs)

    # Drop columns whose value never changes. The keys are always kept: dropping
    # one (e.g. a single-complex dataset) would misalign every iloc[:, 3:] below.
    set_col = set(merged_df_ele.columns)
    varying = (merged_df_ele != merged_df_ele.iloc[0]).any()
    varying[keys] = True
    merged_df_ele = merged_df_ele.loc[:, varying]
    constant_col = set_col - set(merged_df_ele.columns)

    if verbose > 0:
        print(f'\nFollowing {len(constant_col)} columns constant and thus pruned')
        if verbose > 1:
            print(f'{constant_col}')

    merged_df_ele = merged_df_ele.replace('  ', '', inplace=False)
    merged_df_ele = merged_df_ele.replace(' ', '', regex=True, inplace=False)
    merged_df_ele.replace(r'', np.nan, regex=False, inplace=True)
    merged_df_ele.iloc[:, 3:] = merged_df_ele.iloc[:, 3:].apply(pd.to_numeric, errors='coerce')

    set_col = set(merged_df_ele.columns)
    merged_df_ele = merged_df_ele.dropna(axis=1, how='all')
    allnan_col = set_col - set(merged_df_ele.columns)

    if verbose > 0:
        print(f'\nFollowing {len(allnan_col)} columns all nan and thus pruned')
        if verbose > 1:
            print(f'{allnan_col}')
        print(f'\nFinal dataset of length: {len(merged_df_ele)}')

    remaining = set(merged_df_ele.columns)
    dict_cols = {k: [c for c in v if c in remaining] for k, v in dict_cols.items()}

    return merged_df_ele, dict_cols

def read_steric_descriptors(cols_dict,
                            base_dir="../2_descriptors/descriptors/DBSTEP/DBSTEP_results"):
    """Load DBSTEP steric descriptors from pickled dicts and stack them into one frame.

    Each pickle in ``base_dir`` holds a dict mapping a key to a DataFrame whose
    index entries look like ``'catalyst/complex/conformer'`` (split on '/').
    Files are processed in the order of ``load_plan`` below:

    * ``'set'`` stores the frame under its key, overwriting any previous one;
    * ``'append'`` concatenates its rows onto the frame already stored under the
      key (a missing key raises ``KeyError``);
    * ``'from_trans'`` replaces 'trans_pd' in each index entry with the key, then
      *replaces* (does not append to) the stored frame.

    Columns sharing a name inside one frame are averaged. The complex name has
    'ortho' renamed to 'cis' and 'para' to 'trans'. Non-key columns that are
    constant over all rows are dropped; the three key columns are always kept.

    Parameters
    ----------
    cols_dict : dict
        Updated in place with key ``'steric'`` (descriptor column names) and
        returned.
    base_dir : str
        Directory containing the DBSTEP pickles.

    Returns
    -------
    final_df : pandas.DataFrame
        Columns 'Catalyst', ' Complex', ' Conformer' followed by descriptors.
    cols_dict : dict
        The same dict that was passed in.
    """
    # TODO: check that paths_dfs ends up with 9 keys.
    load_plan = [
        ('dbstep_descr_ligand_v0.pkl', 'set'),
        ('dbstep_pdcl2_ortho_v0.pkl', 'set'),
        ('dbstep_pdcl2_para_v0.pkl', 'set'),
        ('dbstep_sme2_ortho_v0.pkl', 'set'),
        ('dbstep_sme2_para_v0.pkl', 'set'),
        ('dbstep_descr_ligand_v0_20_21.pkl', 'append'),
        # NOTE(review): this replaces, rather than extends, the ligand frames
        # loaded above. Confirm that this is intended.
        ('dbstep_ligand_from_trans_v0_20_21.pkl', 'from_trans'),
        ('dbstep_pdcl2_ortho_v0_20_21.pkl', 'append'),
        ('dbstep_pdcl2_para_v0_20_21.pkl', 'append'),
        ('dbstep_sme2_ortho_v0_20_21.pkl', 'append'),
        ('dbstep_sme2_para_v0_20_21.pkl', 'append'),
    ]

    paths_dfs = {}
    for fname, mode in load_plan:
        # read_pickle can execute arbitrary code; only load trusted local files.
        loaded = pd.read_pickle(os.path.join(base_dir, fname))
        for k, v in loaded.items():
            print(k)
            if mode == 'append':
                paths_dfs[k] = pd.concat([paths_dfs[k], v])
            elif mode == 'from_trans':
                v.index = [i.replace('trans_pd', k) for i in v.index]
                paths_dfs[k] = v
            else:
                paths_dfs[k] = v

    # Average same-named columns. groupby(..., axis=1) is deprecated (pandas 2.1)
    # and removed in 3.0, so group the transposed rows instead.
    for k, v in paths_dfs.items():
        paths_dfs[k] = v.T.groupby(level=0).mean().T

    final_df = pd.concat(list(paths_dfs.values()), axis=0).reset_index()
    parts = [p.split('/') for p in final_df['index'].values]
    final_df.insert(0, ' Conformer', [p[2] for p in parts])
    final_df.insert(0, ' Complex', [p[1].replace('ortho', 'cis').replace('para', 'trans')
                                    for p in parts])
    final_df.insert(0, 'Catalyst', [p[0] for p in parts])
    final_df = final_df.drop('index', axis=1)

    # Drop constant columns, but never the key columns (cols [3:] below relies on them).
    keys = ['Catalyst', ' Complex', ' Conformer']
    varying = (final_df != final_df.iloc[0]).any()
    varying[keys] = True
    final_df = final_df.loc[:, varying]
    cols_dict['steric'] = list(final_df.columns)[3:]
    return final_df, cols_dict

def calc_variance_score(df, descr_col):
    """Score how much each descriptor varies across conformers relative to the dataset.

    For every (Catalyst, Complex) ensemble, the sample std of each descriptor is
    divided by that descriptor's std over the whole of ``df``. The ratios are
    spread to one column per ``'<descriptor>_<complex>'``. Columns that are
    entirely NaN are dropped, and the maximum ratio over catalysts is reported.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'Catalyst', ' Complex' and the descriptor columns. Other
        (possibly non-numeric) columns are ignored.
    descr_col : list-like of str, or str
        Descriptor column name(s).

    Returns
    -------
    pandas.DataFrame
        Columns ``['descriptors', 'max_stdconf_on_stdall']``.
    """
    keys = ['Catalyst', ' Complex']
    if isinstance(descr_col, str):
        descr_col = [descr_col]
    descr_col = list(descr_col)

    # Select the descriptor columns *before* aggregating, so non-numeric columns
    # (e.g. conformer names) never reach std(), which raises on them in pandas 2.
    group_std = df.groupby(keys, as_index=False)[descr_col].std()
    ratio = group_std[descr_col] / df[descr_col].std()

    scored = pd.concat([group_std[keys], ratio], axis=1)
    scored = scored.pivot(index=['Catalyst'], columns=' Complex').reset_index()
    scored.columns = ['Catalyst'] + [f'{d}_{c}' for d, c in list(scored.columns)[1:]]
    scored = scored.dropna(axis=1, how='all')

    pup_df = scored.iloc[:, 1:].max().reset_index()
    pup_df.columns = ['descriptors', 'max_stdconf_on_stdall']
    return pup_df

def add_G_values(merged_df_ele):
    """Add one free-energy column ``'G (kcal)_<solvent>'`` per solvent.

    G = E(TPSS_TZVP_GAS) + chemical potential (freerotor, TPSS_SVP_GAS) + Gsolv,
    summed in hartree (Gsolv is converted from kcal first), then expressed in kcal.
    The input frame is modified in place and also returned.
    """
    hartree_per_kcal = 0.0015936010974213599
    kcal_per_hartree = 627.509608030593

    energy = merged_df_ele[' Energy (hartree) TPSS_TZVP_GAS']
    chempot = merged_df_ele[' Chempot (freerotor hartree) TPSS_SVP_GAS']
    for solvent in ('2methf', 'ethanol', 'isopropylacetate', 'toluene'):
        gsolv_hartree = merged_df_ele[f' Gsolv (kcal) {solvent}'] * hartree_per_kcal
        merged_df_ele[f'G (kcal)_{solvent}'] = (energy + chempot + gsolv_hartree) * kcal_per_hartree
    return merged_df_ele


def subtract_min(group):
    """Shift ``group`` so that its minimum becomes zero.

    Intended as a ``groupby(...).transform`` callback. For a DataFrame the
    minimum is taken per column.
    """
    floor = group.min()
    return group - floor

def add_diffG_values(merged_df_ele):
    """Add ``'diff_G (kcal)_<solvent>'`` columns: G relative to its ensemble minimum.

    Within each (Catalyst, Complex) group, every ``'G (kcal)_<solvent>'`` value
    has the group's minimum for that column subtracted, so the lowest conformer
    is 0. The frame is modified in place and returned.
    """
    solvents = ['2methf', 'ethanol', 'isopropylacetate', 'toluene']
    g_cols = [f'G (kcal)_{s}' for s in solvents]
    diff_cols = [c.replace('G', 'diff_G') for c in g_cols]

    per_ensemble = merged_df_ele.groupby(['Catalyst', ' Complex'])[g_cols]
    merged_df_ele[diff_cols] = per_ensemble.transform(lambda grp: grp - grp.min())
    return merged_df_ele

def prune_conformers_thrG(merged_df_ele, thrG=4):
    """Drop high-energy conformers and keep paired complexes consistent.

    1. Rows in which any of the *last four* columns exceeds ``thrG`` are removed.
       NOTE(review): this is positional; presumably those are the four
       'diff_G (kcal)_*' columns appended by add_diffG_values. Confirm column order.
    2. For each (child, parent) complex pair below, a child row is removed when
       its (Catalyst, Conformer) does not occur among the parent's rows that
       survived step 1. Every pair is checked against the step-1 result, not
       against the result of earlier pairs.

    Catalyst and Conformer labels are assumed to be non-null.

    Returns
    -------
    pandas.DataFrame
        The pruned frame with a fresh 0..n-1 index.
    """
    mask = (merged_df_ele.iloc[:, -4:] > thrG).any(axis=1)
    pruned_df = merged_df_ele[~mask]

    # trans_pd, cis_pd, trans_pdcl2_s, cis_pdcl2_s, *_pd_sme2 must share their
    # conformers with the corresponding pdcl2 / pdcl2_sme2 complex.
    pairs = [('trans_pd', 'trans_pdcl2'),
             ('cis_pd', 'cis_pdcl2'),
             ('trans_pdcl2_s', 'trans_pdcl2_sme2'),
             ('cis_pdcl2_s', 'cis_pdcl2_sme2'),
             ('trans_pd_sme2', 'trans_pdcl2_sme2'),
             ('cis_pd_sme2', 'cis_pdcl2_sme2')]

    rows = list(zip(pruned_df['Catalyst'], pruned_df[' Complex'], pruned_df[' Conformer']))
    drop = np.zeros(len(rows), dtype=bool)
    for child, parent in pairs:
        parent_keys = {(cat, conf) for cat, cplx, conf in rows if cplx == parent}
        for pos, (cat, cplx, conf) in enumerate(rows):
            if cplx == child and (cat, conf) not in parent_keys:
                drop[pos] = True

    # Positional mask: dropping by index label would also remove unrelated rows
    # whenever the index contains duplicate labels.
    return pruned_df[~drop].reset_index(drop=True)

def groupby_conf(pruned_df, colname_except_G, agg_type='mean'):
    """Collapse conformers to one row per (Catalyst, Complex).

    Columns in ``colname_except_G`` are aggregated with ``agg_type``. The four
    columns at positions [-8:-4] are aggregated with 'min'; presumably these are
    the 'G (kcal)_*' columns, which precede the four diff_G columns (confirm
    column order). The 'min' rule wins if a column appears in both sets.
    TODO: Boltzmann-weighted average without the energy-window pruning.
    """
    g_cols = pruned_df.columns[-8:-4]
    agg_funcs = {**{c: agg_type for c in colname_except_G},
                 **{c: 'min' for c in g_cols}}
    per_ensemble = pruned_df.groupby(['Catalyst', ' Complex'])
    return per_ensemble.agg(agg_funcs).reset_index()

def pivot_complexes(grouped_df):
    """Reshape one-row-per-(Catalyst, Complex) data to one row per Catalyst.

    Every column after the first two becomes ``'<column>_<complex>'``.
    """
    value_cols = list(grouped_df.columns)[2:]
    wide = grouped_df.pivot(index=['Catalyst'], columns=' Complex', values=value_cols)
    wide = wide.reset_index()

    flat_names = ['Catalyst']
    for descriptor, complex_name in list(wide.columns)[1:]:
        flat_names.append(f'{descriptor}_{complex_name}')
    wide.columns = flat_names
    return wide


from itertools import combinations
def at_least_3_agree(values, thr = 20):
    """Return True if some three of ``values`` lie within ``thr`` of each other.

    A triple "agrees" when max - min <= thr. Inputs with fewer than three
    values always return True.
    """
    if len(values) < 3:
        return True
    for triple in combinations(values, 3):
        if max(triple) - min(triple) <= thr:
            return True
    return False
    
def who_agree(values,plates,thr = 20):
    """Return the plates whose values agree within ``thr``.

    ``values`` and ``plates`` are parallel sequences of equal length.

    * Fewer than three values: ``plates`` is returned unchanged.
    * Every triple has max - min <= thr: all plates, as a numpy array.
    * Otherwise: the plates of the first agreeing triple (in
      ``itertools.combinations`` order), as a numpy array.
    * No triple agrees: an empty numpy array.
    """
    if len(values) >= 3:
        # Materialise the combinations: all() stops at the first failing triple,
        # so reusing one generator made any() skip the triples already consumed
        # (e.g. [0, 10, 5, 100] wrongly returned no agreement).
        combos = list(combinations(values, 3))
        if all(max(combo) - min(combo) <= thr for combo in combos):
            return np.array(plates)

        elif any(max(combo) - min(combo) <= thr for combo in combos):
            for combo in combinations(zip(values, plates), 3):
                val = [i[0] for i in combo]
                pl = [i[1] for i in combo]
                if max(val) - min(val) <= thr:
                    return np.array(pl)
        else:
            return np.array([])
    else:
        return plates
    
def pairwise_discrepancies(values, thr = 20):
    """Flag, for every pair of values, whether they differ by at least ``thr``.

    This used to be defined as ``no_more_than_one_discrep``, but a later
    definition with that name in this module overwrote it at import time, so it
    was unreachable. It has been renamed so it can actually be called.

    Returns
    -------
    list[bool]
        One flag per pair, in ``itertools.combinations(values, 2)`` order. The list
        is empty when fewer than three values are given.
    """
    if len(values) < 3:
        return []
    return [max(pair) - min(pair) >= thr for pair in combinations(values, 2)]
    
def reject_outliers(data: np.ndarray,
                    m: float=1.9, # bigger values -> bigger threshold and vice versa
                    max_threshold: float=5.0,
                    ) -> np.ndarray:
    """Drop points far from the median, measured in median-absolute-deviation units.

    If the data contain a 0.0 and some value exceeds ``max_threshold``, all zeros
    are removed first (presumably failed measurements; confirm). A point is then
    kept when ``|x - median| / MAD < m``. When the MAD is 0 the scaled distance is
    undefined and every remaining point is kept.

    Parameters
    ----------
    data : array-like
        One-dimensional measurements; lists are accepted too.
    m : float
        Cut-off in MAD units.
    max_threshold : float
        Zeros are discarded only when some value exceeds this.

    Returns
    -------
    numpy.ndarray
        One-dimensional array of kept points.
    """
    data = np.asarray(data)
    if 0.0 in data and data.max() > max_threshold:
        data = np.delete(data, np.where(data == 0.0))

    d = np.abs(data - np.median(data))
    mdev = np.median(d)

    # NOTE(review): an alternative considered here was d / (mdev if mdev else 1.),
    # which would reject anything >= m away from the median when the MAD is 0.
    # Use an array (not a scalar 0.) for the zero-MAD case: indexing with a
    # scalar boolean returned shape (1, n) instead of (n,).
    s = d / mdev if mdev else np.zeros_like(d)

    return data[s < m]
    
def read_experimental_data_DMS_new(analysis = 'GC', plating='Method2', std_thr = None,
                                   path='../data/HTE_data.xlsx'):
    """Load HTE yields and summarise them per (Catalyst, Solvent, Base).

    Reads sheet 'final' of ``path`` and keeps rows with the requested ``Plating``
    and ``Analysis``. K2CO3 runs on plates 883/884, PEPPSI catalysts, and rows
    with any missing value are removed. Per condition, outliers are removed with
    ``reject_outliers``; the median of the remaining yields becomes 'Yield' and
    their population std becomes 'Yield_std'.

    Parameters
    ----------
    analysis, plating : str
        Filter values for the 'Analysis' and 'Plating' columns.
    std_thr : float, optional
        If not None, drop conditions whose Yield_std exceeds it (0 is honoured).
    path : str
        Excel workbook to read.

    Returns
    -------
    data_pivot : pandas.DataFrame
        Index 'ID' (via ``id_to_id``), columns ``'Yield_<solvent>_<base>'`` and
        ``'Yield_std_<solvent>_<base>'``. Five fixed conditions are excluded.
    data : pandas.DataFrame
        Long-format summary, one row per condition.
    """
    df = pd.read_excel(path, 'final')
    df = df[(df['Plating'] == plating) & (df['Analysis'] == analysis)]
    df = df[~((df['Base'] == 'K2CO3') & ((df['Plate'] == 883) | (df['Plate'] == 884)))]
    df = df[~df.Catalyst.str.contains('PEPPSI')]
    df = df.dropna()
    df['Plate'] = [str(i) for i in df['Plate']]

    records = []
    for (catalyst, solvent, base), group in df.groupby(['Catalyst', 'Solvent', 'Base']):
        kept = np.asarray(reject_outliers(group.Yield.values.flatten())).flatten()
        records.append({'Catalyst': catalyst,
                        'Solvent': solvent,
                        'Base': base,
                        'Yield_lst': list(kept),
                        'Yield': np.median(kept),
                        'Yield_std': kept.std()})
    # Building from records gives numeric Yield/Yield_std dtypes; the previous
    # transposed one-row frames produced object columns.
    data = pd.DataFrame(records)
    data['ID'] = [id_to_id[i] for i in data['Catalyst']]

    if std_thr is not None:
        data = data[data.Yield_std <= std_thr]

    data_pivot = data.pivot(index=['ID'], columns=['Solvent', 'Base'], values=['Yield', 'Yield_std'])
    data_pivot.columns = [f'{i}_{x}_{y}' for i, x, y in data_pivot.columns]
    # errors='ignore': a condition may already be missing after std filtering.
    data_pivot = data_pivot.drop(['Yield_Me-THF_K2CO3', 'Yield_std_Me-THF_K2CO3',
                                  'Yield_Toluene_K2CO3', 'Yield_std_Toluene_K2CO3',
                                  'Yield_EtOH_Cs2CO3', 'Yield_std_EtOH_Cs2CO3',
                                  'Yield_Toluene_K3PO4', 'Yield_std_Toluene_K3PO4',
                                  'Yield_iPrOAc_K2CO3', 'Yield_std_iPrOAc_K2CO3'],
                                 axis=1, errors='ignore')
    return data_pivot, data
    
def add_col_aggreg(df, pair1, pair2, aggtype='max', list_col=None):
    """Merge columns that differ only by the tags ``pair1``/``pair2`` into one aggregate.

    Every column containing ``pair1`` or ``pair2`` (and not 'MAX') is taken. The
    tags are removed from the names, and columns that now share a stem are
    combined row-wise with ``aggtype`` (any pandas groupby aggregation name;
    typically 'max' or 'min'). Each result is stored as
    ``'<AGG>_<pair1 head>_<pair1 tail>_<pair2 tail><stem>'``. The source columns
    are dropped.

    Parameters
    ----------
    df : pandas.DataFrame
        Not modified; a new frame is returned.
    pair1, pair2 : str
        Complex tags, e.g. 'cis_pd' / 'trans_pd'.
    aggtype : str
        Aggregation applied across the grouped columns.
    list_col : list[str], optional
        If not None (an empty list included), a new list is returned with the
        source columns removed and the new columns appended.

    Returns
    -------
    (pandas.DataFrame, list[str] or None)
    """
    df = df.copy()  # the caller's frame is never modified
    pair_cols = [c for c in df.columns if (pair1 in c or pair2 in c) and 'MAX' not in c]
    tmp_df = df.loc[:, pair_cols]
    col_to_drop = tmp_df.columns
    tmp_df.columns = [c.replace(pair1, '').replace(pair2, '') for c in tmp_df.columns]

    # Transpose + groupby(level=0) replaces the deprecated groupby(axis=1).
    tmp_df = tmp_df.T.groupby(level=0).agg(aggtype).T

    head1 = pair1.split("_")[0]
    tail1 = "_".join(pair1.split("_")[1:])
    tail2 = "_".join(pair2.split("_")[1:])
    tmp_df.columns = [f'{aggtype.upper()}_{head1}_{tail1}_{tail2}{c}' for c in tmp_df.columns]

    for c in tmp_df.columns:
        df[c] = tmp_df[c].values
    if list_col is not None:
        list_col = [c for c in list_col if c not in col_to_drop] + list(tmp_df.columns)
    df = df.drop(col_to_drop, axis=1)
    return df, list_col

def find_label_rank(column, labels):
    """Return the index of the first label in ``labels`` that is a substring of ``column``.

    Lower index means higher priority. Returns None when no label matches.
    """
    return next((rank for rank, label in enumerate(labels) if label in column), None)

def prune_corr(df, ordered_labels=None, dict_conformer_score = None, threshold=0.9, corr_matrix = None):
    """Greedily drop one column from every highly correlated column pair.

    For each pair i < j with ``corr[i, j] >= threshold``, one column is marked
    for dropping:

    1. if a label from ``ordered_labels`` (see find_label_rank) appears in
       exactly one name, or the two ranks differ, the better-ranked column
       (earlier label) is kept;
    2. else, if both columns have a conformer score, the one with the larger
       score is dropped;
    3. else the column with the lower variance is dropped.

    Decisions are not conditioned on earlier drops: an already-dropped column
    can still cause a partner to be dropped.

    Parameters
    ----------
    df : pandas.DataFrame
    ordered_labels : list[str], optional
        Substrings in priority order.
    dict_conformer_score : dict[str, float], optional
        Column -> score; lower is preferred.
    threshold : float
        Correlation cut-off.
    corr_matrix : array-like, optional
        Precomputed correlation matrix (ndarray or DataFrame), indexed
        positionally like ``df.columns``. It must already hold absolute values
        if negative correlations are to count. When None, ``|df.corr()|`` is
        used, with all-NaN rows/columns removed.

    Returns
    -------
    df_pruned : pandas.DataFrame
        Kept columns, each renamed ``'<name> (k)'`` where k is the number of later
        columns correlated with it.
    dict_correlated : dict[str, list[str]]
        Column -> later columns correlated above ``threshold``.
    """
    ordered_labels = [] if ordered_labels is None else ordered_labels
    dict_conformer_score = {} if dict_conformer_score is None else dict_conformer_score

    if corr_matrix is None:
        corr_df = df.corr().abs().dropna(how='all', axis=0).dropna(how='all', axis=1)
        tmp_df = df.loc[:, list(corr_df.columns)]
        corr = corr_df.to_numpy()
    else:
        tmp_df = df.copy()
        # Positional [i, j] indexing below needs an ndarray; a DataFrame failed here.
        corr = np.asarray(corr_matrix)
    columns_variance = tmp_df.var()
    cols = list(tmp_df.columns)  # built once; it was rebuilt inside the O(n^2) loop
    n_cols = len(cols)

    pairs_to_drop = set()
    dict_correlated = {}
    for i, col_i in enumerate(cols):
        if i % 100 == 0:
            print(i)
        dict_correlated[col_i] = []
        for j in range(i + 1, n_cols):
            if corr[i, j] >= threshold:
                col_j = cols[j]
                dict_correlated[col_i].append(col_j)
                col_i_rank = find_label_rank(col_i, ordered_labels)
                col_j_rank = find_label_rank(col_j, ordered_labels)

                if col_i_rank is not None and (col_j_rank is None or col_i_rank < col_j_rank):
                    col_to_drop = col_j
                elif col_j_rank is not None and (col_i_rank is None or col_j_rank < col_i_rank):
                    col_to_drop = col_i
                elif col_i in dict_conformer_score and col_j in dict_conformer_score:
                    # Keep the column with the lower conformer score.
                    col_to_drop = col_i if dict_conformer_score[col_i] > dict_conformer_score[col_j] else col_j
                else:
                    # No usable preference: keep the higher-variance column.
                    col_to_drop = col_i if columns_variance[col_i] < columns_variance[col_j] else col_j

                pairs_to_drop.add(col_to_drop)

    df_pruned = tmp_df.drop(columns=list(pairs_to_drop))
    df_pruned.columns = [f'{c} ({len(dict_correlated[c])})' for c in df_pruned.columns]
    return df_pruned, dict_correlated


def add_col_delta(df, name_suffix_col, pair_complexes_names, list_col = None):
    """Add difference columns between two complexes for one descriptor.

    Given descriptor prefix ``P`` and complexes ``(A, B)``, this adds:

    * ``'P' + 'delta_A-B'`` = ``P+A - P+B``;
    * a min-over-isomers delta, where each non-'ligand' side is replaced by the
      row-wise minimum of its cis and trans variants. Its name has 'cis'/'trans'
      replaced by 'min'. It is always first-minus-second.

    Both new names are appended to ``list_col``, which is mutated and returned.
    ``df`` is also modified in place before a copy is returned. If either
    ``P+A`` or ``P+B`` is missing, a message is printed and the inputs are
    returned unchanged.

    Parameters
    ----------
    df : pandas.DataFrame
    name_suffix_col : str
        Descriptor prefix that is concatenated with the complex name.
    pair_complexes_names : sequence of two str
        (first, second) complex names; 'ligand' has no cis/trans variants.
    list_col : list[str], optional
        Column list to extend; a new list is used when None.

    Returns
    -------
    (pandas.DataFrame, list[str])
    """
    if list_col is None:  # previously crashed on list_col.append
        list_col = []
    first, second = pair_complexes_names[0], pair_complexes_names[1]
    col_first = name_suffix_col + first
    col_second = name_suffix_col + second

    if col_first not in df.columns or col_second not in df.columns:
        # e.g. atom 6 has no ligand column
        print(f'{col_first} or {col_second} not in columns')
        return df, list_col

    delta_name = f'{name_suffix_col}delta_{first}-{second}'
    df[delta_name] = df[col_first] - df[col_second]
    list_col.append(delta_name)

    def _min_over_isomers(complex_name):
        # Row-wise min of the trans and cis variants of ``complex_name``'s column.
        return df.loc[:, [name_suffix_col + complex_name.replace('cis', 'trans'),
                          name_suffix_col + complex_name.replace('trans', 'cis')]].min(axis=1)

    if first == 'ligand':
        # Sign fixed: previously min(B) - ligand, contradicting the
        # 'delta_ligand-min' name and the first-minus-second convention.
        min_delta = df[col_first] - _min_over_isomers(second)
    elif second == 'ligand':
        min_delta = _min_over_isomers(first) - df[col_second]
    else:
        min_delta = _min_over_isomers(first) - _min_over_isomers(second)

    first_min = first.replace('cis', 'min').replace('trans', 'min')
    second_min = second.replace('cis', 'min').replace('trans', 'min')
    min_name = f"{name_suffix_col}delta_{first_min}-{second_min}"
    df[min_name] = min_delta
    list_col.append(min_name)
    return df.copy(), list_col

def no_more_than_one_discrep(values, *, tol=15):
    """Return True when no three of ``values`` lie strictly within ``tol`` of each other.

    A triple counts as agreeing when max - min < tol. With fewer than three
    values there are no triples, so the result is True. ``tol`` is keyword-only,
    so a stray positional threshold still raises TypeError instead of silently
    changing the meaning of a call.
    """
    return not any(max(triple) - min(triple) < tol for triple in combinations(values, 3))

def read_dehalo_data(thr=5):
    """Load dehalogenation outcomes as a binary table indexed by catalyst ID.

    Each row of '../data/dehalogenation.xlsx' is binarised as Yield > ``thr``
    (1) or not (0). Conditions (ID, Solvent, Base) whose replicates disagree are
    discarded. The rest are pivoted to columns ``'Yield_<solvent>_<base>'``, and
    conditions missing for a catalyst are filled with 0.
    """
    raw = pd.read_excel('../data/dehalogenation.xlsx')
    raw['ID'] = [id_to_id[name] for name in raw.Catalyst]
    raw['Yield'] = (raw['Yield'] > thr) * 1

    outcome_sets = raw.groupby(['ID', 'Solvent', 'Base'])['Yield'].agg(set).reset_index()
    is_consistent = outcome_sets.Yield.map(len) == 1
    consistent = outcome_sets[is_consistent].copy()
    consistent['Yield'] = [next(iter(outcomes)) for outcomes in consistent['Yield']]

    table = consistent.pivot(index='ID', columns=['Solvent', 'Base'], values='Yield')
    table.columns = [f'Yield_{solvent}_{base}' for solvent, base in table.columns]
    return table.fillna(0)