import pandas as pd
import numpy as np
from itertools import product
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import r2_score, accuracy_score, balanced_accuracy_score, confusion_matrix

from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.svm import SVR, SVC
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.neural_network import MLPRegressor, MLPClassifier
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
#import shap
from sklearn.metrics import confusion_matrix, r2_score, balanced_accuracy_score

# Module-level seed: seeds NumPy's global random number generator at import time.
seed = 42
np.random.seed(seed)

def smiles_to_ecfp(smiles, radius=2, nBits=512):
    """Convert a SMILES string into a Morgan fingerprint bit vector.

    Parameters
    ----------
    smiles : str
        SMILES representation of the molecule.
    radius : int, default 2
        Radius handed to RDKit's Morgan fingerprint routine.
    nBits : int, default 512
        Length of the returned bit vector.

    Returns
    -------
    list of int
        ``nBits`` values, each 0 or 1. An all-zero vector is returned when
        ``smiles`` is not a string (e.g. ``None`` or a NaN coming from an
        empty spreadsheet cell) or when RDKit cannot parse it.
    """
    # Missing cells arrive as None/NaN and RDKit raises on them; give them the
    # same all-zero fallback as an unparseable SMILES instead of aborting.
    if not isinstance(smiles, str):
        return [0] * nBits

    # Function-scope import: RDKit is only required when a fingerprint is
    # actually computed, not when this module is imported.
    from rdkit import Chem
    from rdkit.Chem import AllChem

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return [0] * nBits

    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
    return list(map(int, fp.ToBitString()))

def get_features(data, feature_type='OHE', catalysts_path='/mnt/data/catalysts.xlsx', sheet_name='Sheet1', train_col = None):
    """Join reaction data with the catalyst table and append model features.

    Parameters
    ----------
    data : pandas.DataFrame
        Reaction table. The code reads its 'Catalyst' (merge key), 'Solvent'
        and 'Base' columns.
    feature_type : str, default 'OHE'
        * 'OHE': one-hot encode Solvent, Base, Core and R.
        * 'ECFPs': one-hot encode Solvent/Base and add 512 fingerprint bits
          of the 'SMILES' column (``Bit_i``).
        * 'ECFPs_Core_R': one-hot encode Solvent/Base and add 512 bits each
          for 'SMILES Core' (``core_bit_i``) and 'SMILES R' (``r_bit_i``).
        * any other value: only the Solvent/Base one-hot columns are added.
    catalysts_path : str
        Excel workbook holding the catalyst table. Its first column is renamed
        to 'Catalyst'; 'Core', 'R', 'SMILES', 'SMILES Core' and 'SMILES R'
        must be present.
    sheet_name : str
        Sheet to read from the workbook.
    train_col : sequence of str, optional
        Column names to keep, typically the columns produced on the training
        data. ``None`` or an empty sequence means "not given".

    Returns
    -------
    pandas.DataFrame
        * ``train_col`` given, fingerprint feature types: exactly the
          ``train_col`` columns.
        * ``train_col`` given, other feature types: the full merged frame,
          SMILES columns included.
        * ``train_col`` not given: fingerprint feature types first drop every
          column whose values all equal the first row; the SMILES columns are
          then removed.

    Raises
    ------
    KeyError
        If ``train_col`` names a column that does not exist in the
        fingerprint-augmented frame.
    """
    n_bits = 512
    smiles_cols = ['SMILES', 'SMILES Core', 'SMILES R']
    # Explicit test: ``if train_col`` raises for a pandas Index or NumPy array.
    use_train_col = train_col is not None and len(train_col) > 0

    cat = pd.read_excel(catalysts_path, sheet_name=sheet_name)
    cat.columns = ['Catalyst'] + list(cat.columns[1:])
    cat = cat.loc[:, ['Catalyst', 'Core', 'R'] + smiles_cols]
    data = cat.merge(data, on='Catalyst')

    if feature_type == 'OHE':
        to_dummies = ['Solvent', 'Base', 'Core', 'R']
    else:
        to_dummies = ['Solvent', 'Base']
    data_dum = pd.get_dummies(data.loc[:, to_dummies], columns=to_dummies)
    data = pd.concat([data, data_dum], axis=1)

    # (source SMILES column, prefix of the generated bit columns)
    if feature_type == 'ECFPs':
        fingerprint_sources = [('SMILES', 'Bit_')]
    elif feature_type == 'ECFPs_Core_R':
        fingerprint_sources = [('SMILES Core', 'core_bit_'), ('SMILES R', 'r_bit_')]
    else:
        fingerprint_sources = []

    if fingerprint_sources:
        frames = [data]
        for smiles_col, prefix in fingerprint_sources:
            bits = data[smiles_col].apply(lambda x: smiles_to_ecfp(x, radius=2, nBits=n_bits))
            frames.append(pd.DataFrame(bits.tolist(), columns=[f'{prefix}{i}' for i in range(n_bits)]))
        data = pd.concat(frames, axis=1)

        if use_train_col:
            # Align with the training layout. This now also applies to
            # 'ECFPs_Core_R', which previously ignored train_col and dropped
            # constant columns instead.
            data = data.loc[:, train_col]
        else:
            # Keep only columns that differ from the first row somewhere.
            data = data.loc[:, (data != data.iloc[0]).any()]

    if use_train_col:
        return data
    # A SMILES column may already be gone (removed as constant above), so
    # missing labels must not raise.
    return data.drop(smiles_cols, axis=1, errors='ignore')


##########################

def train_test_preparation(train_df, test_df, feature_names, target):
    """Extract feature matrices and target arrays from the two data frames.

    Returns ``(X_train, y_train, X_test, y_test)`` where each element is the
    ``.values`` of the selected column(s).
    """
    def _features_and_target(frame):
        return frame[feature_names].values, frame[target].values

    X_train, y_train = _features_and_target(train_df)
    X_test, y_test = _features_and_target(test_df)
    return X_train, y_train, X_test, y_test


def do_grid_search_cv(modelname, config, X_train, y_train, scoring):
    """Tune ``config['model']`` over ``config['params']`` with a grid search.

    Parameters
    ----------
    modelname : str
        Label used only in the printed summary.
    config : dict
        Must provide 'model' (an estimator) and 'params' (its parameter grid).
    X_train, y_train
        Training features and targets passed to ``GridSearchCV.fit``.
    scoring : str
        Scoring name forwarded to ``GridSearchCV``.

    Returns
    -------
    The ``best_estimator_`` of the fitted search.
    """
    # One fold per five training samples, but never fewer than two folds:
    # with fewer than 10 samples the plain ratio gives 0 or 1, which
    # scikit-learn rejects.
    n_folds = max(2, len(X_train) // 5)

    grid_search = GridSearchCV(config['model'], config['params'], cv=n_folds, scoring=scoring, n_jobs=-1)
    grid_search.fit(X_train, y_train)

    # best_estimator_ has already been refit by the search.
    best_model = grid_search.best_estimator_
    print(f'{modelname} - Best Parameters: {grid_search.best_params_}')
    return best_model


def do_single_model(modality, model_name, featnames, config, X_train, y_train, X_test, y_test, target, test_df,
                    train_df):
    """Tune one model, predict on both splits and collect the results.

    ``featnames`` and ``target`` are accepted but not used by the active code.
    The plotting helpers and the SHAP explainer are not invoked here.

    Returns a dict with the keys 'best_model', 'test_df', 'train_df',
    'score_test' and 'score_train'. The two frames are copies of the inputs
    with an added 'pred' column. When the test split holds at most one sample,
    'score_test' is ``y_test - y_pred_test`` instead of the metric.
    """
    metric_name = 'r2' if modality == 'regression' else 'balanced_accuracy'
    fitted = do_grid_search_cv(model_name, config, X_train, y_train, scoring=metric_name)

    # Predictions for the held-out and the training samples.
    test_pred = fitted.predict(X_test)
    train_pred = fitted.predict(X_train)
    test_score, train_score = evaluate_model(modality, y_test, test_pred, y_train, train_pred)

    if len(X_test) <= 1:
        print(f'pred = {test_pred}, true = {y_test}')
        test_score = y_test - test_pred

    annotated_test = test_df.copy()
    annotated_test['pred'] = test_pred
    annotated_train = train_df.copy()
    annotated_train['pred'] = train_pred

    return {
        'best_model': fitted,
        'test_df': annotated_test,
        'train_df': annotated_train,
        'score_test': test_score,
        'score_train': train_score,
    }


def random_loop(models, datacopy, featnames, target, modality):
    """Evaluate every model on five random 80/20 splits.

    Returns ``{split_seed: {model_name: result_dict}}`` for the seeds
    42, 66, 10, 5 and 304, where ``result_dict`` comes from
    ``do_single_model``.
    """
    results = {}
    for split_seed in (42, 66, 10, 5, 304):
        train_df, test_df = train_test_split(datacopy, test_size=0.2, random_state=split_seed)
        X_train, y_train, X_test, y_test = train_test_preparation(train_df, test_df, featnames, target)
        results[split_seed] = {
            name: do_single_model(modality, name, featnames, cfg, X_train, y_train,
                                  X_test, y_test, target, test_df, train_df)
            for name, cfg in models.items()
        }
    return results


def stratifed_random_loop(models, datacopy, featnames, target, modality):
    """Evaluate every model on five 80/20 splits stratified by 'Catalyst'.

    Returns ``{split_seed: {model_name: result_dict}}`` for the seeds
    42, 66, 10, 5 and 304, where ``result_dict`` comes from
    ``do_single_model``.
    """
    results = {}
    for split_seed in (42, 66, 10, 5, 304):
        catalyst_labels = datacopy['Catalyst'].values
        train_df, test_df = train_test_split(datacopy, test_size=0.2, random_state=split_seed,
                                             stratify=catalyst_labels)
        X_train, y_train, X_test, y_test = train_test_preparation(train_df, test_df, featnames, target)
        results[split_seed] = {
            name: do_single_model(modality, name, featnames, cfg, X_train, y_train,
                                  X_test, y_test, target, test_df, train_df)
            for name, cfg in models.items()
        }
    return results


def loco_loop(models, datacopy, featnames, target, modality):
    """Leave-one-catalyst-out evaluation.

    For every distinct value of ``datacopy['Catalyst']`` the rows with that
    catalyst form the test set and all remaining rows the training set; each
    model in ``models`` is then tuned and scored via ``do_single_model``.

    Returns ``{catalyst: {model_name: result_dict}}`` with catalysts in order
    of first appearance in ``datacopy``.
    """
    models_dict = {}
    # unique() keeps first-appearance order, whereas a set of strings iterates
    # in a hash-dependent order that can change between interpreter sessions.
    # Missing catalyst labels are skipped: no row compares equal to NaN, so the
    # held-out set would be empty.
    for cat in datacopy['Catalyst'].dropna().unique():
        print(cat)
        models_dict[cat] = {}
        train_df = datacopy[datacopy['Catalyst'] != cat]
        test_df = datacopy[datacopy['Catalyst'] == cat]

        X_train, y_train, X_test, y_test = train_test_preparation(train_df, test_df, featnames, target)
        for model_name, config in models.items():
            models_dict[cat][model_name] = do_single_model(modality, model_name, featnames, config, X_train, y_train,
                                                           X_test, y_test, target, test_df, train_df)

    return models_dict


def _build_model_configs(modality):
    """Return ``{name: {'model': estimator, 'params': grid}}`` for ``modality``.

    'regression' yields the regressors; any other value yields the classifiers.
    """
    if modality == 'regression':
        return {
            'RandomForest': {
                'model': RandomForestRegressor(random_state=42),
                'params': {
                    'n_estimators': [10, 50],
                    'max_depth': [None, 2, 5],
                    'min_samples_split': [2, 5]
                }
            },
            'SVR': {
                'model': SVR(),
                'params': {
                    'C': [1, 10, 20, 30],
                    'kernel': ['linear', 'rbf']
                }
            },
            'LinearRegression': {
                'model': LinearRegression(),
                'params': {}
            },
            # Disabled MLP regressor configuration, kept for reference:
            #'MLP': {
            #    'model': MLPRegressor(random_state=42),
            #    'params': {
            #        'alpha': [0.001],
            #        'learning_rate_init': [0.001, 0.01, 0.1],
            #        'activation': ['logistic', 'tanh', 'relu'],
            #        'hidden_layer_sizes': [(10,), (5,)],
            #        'solver': ['sgd', 'adam'],
            #        'early_stopping': [True]
            #    }
            #}
        }

    return {
        'RandomForest': {
            'model': RandomForestClassifier(random_state=42),
            'params': {
                'n_estimators': [10, 50],
                'max_depth': [None, 2, 5],
                'min_samples_split': [2, 5]
            }
        },
        'SVM': {
            'model': SVC(),
            'params': {
                'C': [0.1, 1, 10, 20],
                'kernel': ['linear', 'rbf'],
                'probability': [True]
            }
        },
        'LogisticRegression': {
            'model': LogisticRegression(),
            # The default lbfgs solver does not support the l1 penalty, so
            # those candidates used to fail during the search. l2 keeps the
            # default solver; l1 is paired with liblinear, which supports it.
            'params': [
                {'C': [0.1, 1, 10], 'penalty': ['l2']},
                {'C': [0.1, 1, 10], 'penalty': ['l1'], 'solver': ['liblinear']},
            ]
        },
        'MLP': {
            'model': MLPClassifier(random_state=42),
            'params': {
                'alpha': [0.001],
                'learning_rate_init': [0.001, 0.01, 0.1],
                'activation': ['logistic', 'tanh', 'relu'],
                'hidden_layer_sizes': [(10,), (5,)],
                'solver': ['sgd', 'adam'],
                'early_stopping': [True]
            }
        }
    }


def do_models(data, target, featnames, split_strategy, modality):
    """Train and evaluate all configured models under one split strategy.

    Parameters
    ----------
    data : pandas.DataFrame
        Feature table containing ``target`` and the ``featnames`` columns.
    target : str
        Target column, e.g. 'Conversion' or 'Yield'. For 'Conversion' every
        row containing a missing value is dropped first.
    featnames : list of str
        Feature columns handed to the models.
    split_strategy : str
        'random', 'random_cat_stratified' or 'leave_one_cat_out'.
    modality : str
        'regression' or 'classification'. For 'classification' the target is
        binarised to 1 where it exceeds its median and 0 otherwise. Any value
        other than 'regression' selects the classifier set.

    Returns
    -------
    dict
        Nested results produced by the chosen split loop.

    Raises
    ------
    ValueError
        If ``split_strategy`` is not one of the supported names.
    """
    valid_strategies = ('random', 'random_cat_stratified', 'leave_one_cat_out')
    # Fail early with a clear message; previously an unknown strategy ran all
    # the preparation and then died with UnboundLocalError at the return.
    if split_strategy not in valid_strategies:
        raise ValueError(
            f'Unknown split_strategy {split_strategy!r}; expected one of {valid_strategies}'
        )

    datacopy = data.copy()
    if target == 'Conversion':
        datacopy = datacopy.dropna()

    if modality == 'classification':
        datacopy[target] = (datacopy[target].values > np.median(datacopy[target].values)) * 1

    print(f'target = {target}; # samples = {len(datacopy)}')

    # Models and their hyperparameter grids.
    models = _build_model_configs(modality)

    if split_strategy == 'random':
        return random_loop(models, datacopy, featnames, target, modality)
    if split_strategy == 'random_cat_stratified':
        return stratifed_random_loop(models, datacopy, featnames, target, modality)
    return loco_loop(models, datacopy, featnames, target, modality)

##########################################
def evaluate_model(modality, y_test, y_pred_test, y_train, y_pred_train):
    """Score test and train predictions and print a one-line summary.

    Uses ``r2_score`` when ``modality`` is 'regression' and
    ``balanced_accuracy_score`` otherwise.

    Returns
    -------
    tuple
        ``(score_test, score_train)``.
    """
    metric = r2_score if modality == 'regression' else balanced_accuracy_score
    score_test = metric(y_test, y_pred_test)
    score_train = metric(y_train, y_pred_train)
    # Round to three decimals. Slicing str(score) truncated instead of
    # rounding and mangled scientific notation (1.2345e-05 showed as "1.234").
    print(f'{modality.upper()} Score test: {score_test:.3f}, Score train: {score_train:.3f}')
    return score_test, score_train


def plot_predictions(y_train, y_pred_train, y_test, y_pred_test, title):
    """Scatter predicted against actual values for the train and test splits."""
    plt.figure(figsize=(3, 3))
    series = (
        (y_train, y_pred_train, 'Train'),
        (y_test, y_pred_test, 'Test'),
    )
    for actual, predicted, split_label in series:
        plt.scatter(actual, predicted, label=split_label)
    plt.title(title)
    plt.xlabel('Actual')
    plt.ylabel('Predicted')
    plt.legend()
    plt.show()


def plot_confusion_matrix(y_train, y_pred_train, y_test, y_pred_test, title):
    """Show train (left) and test (right) confusion matrices side by side.

    ``title`` is used as the figure-level title; each panel is labelled with
    its split name.
    """
    panels = (
        ('Train', y_train, y_pred_train),
        ('Test', y_test, y_pred_test),
    )

    fig, axes = plt.subplots(1, 2, figsize=(7, 3))
    for ax, (split_label, y_true, y_pred) in zip(axes, panels):
        cm = confusion_matrix(y_true, y_pred)
        # fmt='d' prints integer counts; the default format switches to
        # scientific notation for counts of 100 or more.
        sns.heatmap(cm, annot=True, fmt='d', annot_kws={"size": 10}, ax=ax)
        ax.set_title(split_label)

    # Figure-level title: plt.title only affected the right-hand panel.
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()


def shap_explainer(best_model, X_train, X_test, featnames, modality):
    """Build a SHAP KernelExplainer for ``best_model`` and plot a beeswarm.

    Parameters
    ----------
    best_model
        Fitted estimator; its ``predict`` method is what gets explained.
    X_train
        Background data handed to the explainer.
    X_test
        Samples to compute SHAP values for.
    featnames
        Feature names passed to the explainer.
    modality : str
        'regression' plots the SHAP values as they are; any other value plots
        the slice ``shap_values[:, :, 0]``.

    Returns
    -------
    The constructed ``shap.KernelExplainer``.
    """
    # The module-level ``import shap`` is commented out, so without this local
    # import every call failed with NameError.
    import shap

    explainer = shap.KernelExplainer(best_model.predict, X_train, feature_names=featnames)
    shap_values = explainer(X_test)
    if modality == 'regression':
        shap.plots.beeswarm(shap_values)
    else:
        # NOTE(review): this slice assumes a third (per-output) axis. Since
        # ``predict`` is explained rather than ``predict_proba``, confirm the
        # values really are three-dimensional for classifiers.
        shap.plots.beeswarm(shap_values[:, :, 0])

    return explainer
