import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr

from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist

from numpy import mean, std
import numpy as np
from sklearn.model_selection import LeaveOneOut
from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import LinearRegression, Ridge, Lasso, LassoCV
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import r2_score, mean_absolute_error

from itertools import chain, combinations

def powerset(iterable, max_size=2):
    """Return an iterator over the non-empty subsets of *iterable* of size <= *max_size*.

    Despite the name, this is not the full power set: the empty tuple is never
    produced, and by default only subsets of size 1 and 2 are generated
    (``main`` relies on this default to enumerate 1- and 2-feature candidate
    sets). Subsets come out in order of increasing size, and in
    ``itertools.combinations`` order within each size.

    >>> list(powerset([1, 2, 3]))
    [(1,), (2,), (3,), (1, 2), (1, 3), (2, 3)]

    Args:
        iterable: Items to draw from. Items are combined by position, not
            value, so duplicate items yield duplicate-looking subsets.
        max_size: Largest subset size to generate (default 2, the historical
            behaviour). Pass ``None`` to get every non-empty subset.

    Returns:
        An iterator of tuples.
    """
    s = list(iterable)
    upper = len(s) if max_size is None else min(max_size, len(s))
    return chain.from_iterable(combinations(s, r) for r in range(1, upper + 1))

import sys
sys.path.append("/mnt/") # add /mnt/ to the module search path (presumably where utils_data / utils_sele live)

# NOTE(review): these star imports presumably supply the helpers used below that
# are not defined in this file (normalize_yield, denormalize_yield, id_to_id,
# read_experimental_data_DMS_new, calc_polyfeat) — confirm and prefer explicit imports.
from utils_data import *
from utils_sele import *

###############################################################

from multiprocessing import Pool, cpu_count

def process_feature(args):
    """Fit and score one linear model for a single (feature subset, target) pair.

    Rows whose ``ID`` is in ``ext_cat`` form the external test set. All other
    rows with a non-NaN target form the training set. An ordinary
    least-squares ``LinearRegression`` is scored three ways:

    - on the training rows themselves;
    - by leave-one-out cross-validation on the training rows;
    - on the external test set.

    All predictions are clipped to [0, 100] before scoring.

    Args:
        args: One tuple, so the function can be mapped over by a process pool:
            ``(featcolsselected, c, c_std, dp, ext_cat, dG, labdp, labscal)``.

            - featcolsselected: list of feature column names used as X.
            - c: target column name.
            - c_std: companion column; it is selected but not used in the fit.
            - dp: DataFrame holding ``ID``, ``c``, ``c_std`` and the features.
            - ext_cat: IDs held out as the external test set.
            - dG: if true, targets are passed through ``normalize_yield`` and
              min-max scaled to [0, 100] using the training range before
              fitting. Every value is mapped back with ``denormalize_yield``
              before scoring.
            - labdp, labscal: labels copied verbatim into the result.

    Returns:
        dict with the run labels plus R2/MAE on train, LOO and test.
        ``R2test`` is NaN when the test set has fewer than 3 rows.
        ``MAEtest`` is NaN when the test set is empty.
    """
    featcolsselected, c, c_std, dp, ext_cat, dG, labdp, labscal = args
    dpxx = dp.loc[:, ['ID', c, c_std] + featcolsselected]
    dpxx = dpxx.dropna(subset=[c])
    is_test = dpxx.ID.isin(ext_cat)
    data_train = dpxx[~is_test]
    data_test = dpxx[is_test]

    X = data_train[featcolsselected].values
    y = data_train[c].values
    Xtst = data_test[featcolsselected].values
    ytst = data_test[c].values

    if dG:
        y = np.array([normalize_yield(i) for i in y])
        maxdg = max(y)
        mindg = min(y)
        # Scale with the *training* range only, so test targets get the same transform.
        y = ((y - mindg) / (maxdg - mindg)) * 100
        ytst = np.array([normalize_yield(i) for i in ytst])
        ytst = ((ytst - mindg) / (maxdg - mindg)) * 100

    model = LinearRegression().fit(X, y)
    y_predstr = model.predict(X)
    # sklearn refuses zero-row input. The external set is empty when every
    # held-out ID has a NaN target for this condition.
    y_predsts = model.predict(Xtst) if len(Xtst) else np.empty(0)

    # Leave-one-out CV on the training rows. Scalars are collected so that the
    # resulting arrays are 1-D rather than (n, 1).
    ypreds, ytrues = [], []
    for train_index, test_index in LeaveOneOut().split(X):
        mod = LinearRegression().fit(X[train_index], y[train_index])
        ypreds.append(mod.predict(X[test_index])[0])
        ytrues.append(y[test_index][0])
    ypreds = np.asarray(ypreds)
    ytrues = np.asarray(ytrues)

    if dG:
        # Map model-space values back to the original scale before scoring.
        def _back(values):
            return np.array([denormalize_yield(i, maxdg=maxdg, mindg=mindg) for i in values])

        ypreds = _back(ypreds)
        ytrues = _back(ytrues)
        y_predsts = _back(y_predsts)
        y_predstr = _back(y_predstr)
        y = _back(y)
        ytst = _back(ytst)

    y_predsts = np.clip(y_predsts, 0, 100)
    y_predstr = np.clip(y_predstr, 0, 100)
    ypreds = np.clip(ypreds, 0, 100)

    r2_train = r2_score(y, y_predstr)
    mae_train = mean_absolute_error(y, y_predstr)

    r2_loo = r2_score(ytrues, ypreds)
    mae_loo = mean_absolute_error(ytrues, ypreds)

    n_test = len(ytst)
    r2_test = r2_score(ytst, y_predsts) if n_test > 2 else np.nan
    mae_test = mean_absolute_error(ytst, y_predsts) if n_test > 0 else np.nan

    return {
        'target': labdp,
        'scaling': labscal,
        'dG': dG,
        'condition': c,
        'feats': featcolsselected,
        'num_feats': len(featcolsselected),
        'R2train': r2_train,
        'R2LOO': r2_loo,
        'R2test': r2_test,
        'MAEtrain': mae_train,
        'MAEtest': mae_test,
        'MAELOO': mae_loo
    }

# Per-worker copy of the data shared by every task. It is filled once per
# worker by _init_worker, so the DataFrame is not re-pickled for every task.
_WORKER_STATE = {}


def _init_worker(dp, ext_cat, dG, labdp, labscal):
    """Pool initializer: store the task-invariant arguments in this worker process."""
    _WORKER_STATE.update(dp=dp, ext_cat=ext_cat, dG=dG, labdp=labdp, labscal=labscal)


def _run_task(task):
    """Expand a small ``(feats, c, c_std)`` task into ``process_feature``'s full argument tuple."""
    featcolsselected, c, c_std = task
    st = _WORKER_STATE
    return process_feature(
        (featcolsselected, c, c_std, st['dp'], st['ext_cat'], st['dG'], st['labdp'], st['labscal'])
    )


def main(dG=True, polyfeat=False):
    """Brute-force screen of small linear models for each yield condition.

    Steps:

    1. Load the experimental yields and the descriptor table.
    2. Append a squared copy of every descriptor and drop descriptors whose
       name contains 'BP86'.
    3. Enumerate every 1- and 2-feature subset (``powerset`` default). Skip
       subsets whose names contain both tokens of any excluded pair.
    4. Standard-scale the features, fitting on non-held-out rows only.
    5. Run ``process_feature`` for every (condition, subset) pair in a
       process pool.
    6. Write all results to ``/mnt/results/brute_force_up2.csv``.

    Args:
        dG: Forwarded to ``process_feature`` (normalise targets before fitting).
        polyfeat: If true, build features with ``calc_polyfeat`` instead of
            using the raw descriptor columns.
    """
    max_subsets = 5000000  # cap on feature subsets evaluated per condition

    ext_cat = [id_to_id[i] for i in ['IPr-Pd-DMS', 'BIAN-IMes-Pd-DMS', 'IPentCl-Pd-DMS']]
    ext_cat = ext_cat + ['comp_5', 'comp_6']

    data, _ = read_experimental_data_DMS_new()

    df = pd.read_csv('/mnt/2_descriptors/df_pruned2_no_cis_pdcl3.csv', index_col=0)
    df = df.dropna(axis=1)
    # Add all squared columns in one concat. Inserting them one at a time
    # fragments the frame. Names follow the '<col>**2' scheme.
    df = pd.concat([df, df.pow(2).add_suffix('**2')], axis=1)

    expcols = data.columns.tolist()
    featcols = df.columns.to_list()

    df = df.reset_index()
    df.columns = ['ID'] + featcols
    data = data.merge(df, on='ID')

    featcols = [i for i in featcols if 'BP86' not in i]

    # Drop subsets whose names contain both tokens of an excluded pair. The
    # filter is order-preserving; the previous set difference made the order,
    # and hence the max_subsets cut, vary with string-hash randomisation.
    excluded_pairs = [('NBO', 'IBO'), ('VP_COSMO', 'VP_GAS'), ('_SVP_', '_TZVP_')]
    ared = [
        s for s in powerset(featcols)
        if not any(t1 in str(s) and t2 in str(s) for t1, t2 in excluded_pairs)
    ]

    print(len(featcols))
    print(len(ared))

    excluded_conditions = ['Toluene_K2CO3', 'Toluene_K3PO4', 'Me-THF_K2CO3',
                           'iPrOAc_K2CO3', 'EtOH_Cs2CO3']
    excluded_yield = {f'Yield_{s}' for s in excluded_conditions}
    excluded_std = {f'Yield_std_{s}' for s in excluded_conditions}

    # Target columns and their std companions. They are paired positionally
    # below, so both lists follow the column order of `data`.
    ecol = [i for i in expcols
            if 'Yield' in i and 'std' not in i and i not in excluded_yield]
    ecol_std = [i for i in expcols
                if 'Yield' in i and 'std' in i and i not in excluded_std]

    dp = data.loc[:, ['ID'] + ecol + ecol_std + featcols].copy()
    labdp = 'Yield'

    scal = StandardScaler()
    labscal = 'standard'
    varthr = 8

    if polyfeat:
        # NOTE(review): this call used to pass [c], [c_std]. c_std is undefined
        # here (NameError), and c was only the last descriptor column left
        # over from a loop. The target column lists are passed instead —
        # confirm this matches calc_polyfeat's contract.
        dp, feature_names = calc_polyfeat(dp, scal, ext_cat, featcols, varthr, ecol, ecol_std)
    else:
        feature_names = featcols

    # Fit the scaler on training rows only; apply the same transform to held-out rows.
    is_test = dp.ID.isin(ext_cat)
    dp.loc[~is_test, feature_names] = scal.fit_transform(dp.loc[~is_test, feature_names])
    dp.loc[is_test, feature_names] = scal.transform(dp.loc[is_test, feature_names])

    # Stream lightweight tasks. The shared data goes to each worker once via
    # the initializer, and imap keeps result order identical to task order.
    tasks = (
        (list(featcolsselected), c, c_std)
        for c, c_std in zip(ecol, ecol_std)
        for featcolsselected in ared[:max_subsets]
    )
    with Pool(processes=cpu_count(), initializer=_init_worker,
              initargs=(dp, ext_cat, dG, labdp, labscal)) as pool:
        res_dict_list = list(pool.imap(_run_task, tasks, chunksize=256))

    res_df = pd.DataFrame(res_dict_list)
    res_df.to_csv('/mnt/results/brute_force_up2.csv')

if __name__ == "__main__":
    # Script entry point: run the screen on raw targets, i.e. with the
    # normalize_yield / denormalize_yield branch of process_feature switched off.
    main(polyfeat=False, dG=False)