# Basic Libs
import os, shutil, glob, time #, re
import pandas as pd
import numpy as np
from functools import wraps

################################################################################ Text editing
# function to convert to superscript
def super_script(string):
    """Return *string* with its characters swapped for superscript look-alikes.

    Upper/lower-case ASCII letters, digits and ``+-=()`` are translated
    one-to-one (letters without a real superscript glyph use a look-alike or
    stay unchanged, e.g. 'Q'); every other character is left as it is.
    """
    plain = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+-=()"
    raised = "ᴬᴮᶜᴰᴱᶠᴳᴴᴵᴶᴷᴸᴹᴺᴼᴾQᴿˢᵀᵁⱽᵂˣʸᶻᵃᵇᶜᵈᵉᶠᵍʰᶦʲᵏˡᵐⁿᵒᵖ۹ʳˢᵗᵘᵛʷˣʸᶻ⁰¹²³⁴⁵⁶⁷⁸⁹⁺⁻⁼⁽⁾"
    # pair each plain character with its superscript counterpart
    table = str.maketrans(dict(zip(plain, raised)))
    return string.translate(table)

# function to convert to subscript
def sub_script(string):
    """Return *string* with its characters swapped for subscript look-alikes.

    Upper/lower-case ASCII letters, digits and ``+-=()`` are translated
    one-to-one; many letters have no true subscript glyph and are mapped to
    a look-alike or left unchanged (e.g. 'C', 'Q', 'Z'). Any other character
    is left as it is.
    """
    plain = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+-=()"
    lowered = "ₐ₈CDₑբGₕᵢⱼₖₗₘₙₒₚQᵣₛₜᵤᵥwₓᵧZₐ♭꜀ᑯₑբ₉ₕᵢⱼₖₗₘₙₒₚ૧ᵣₛₜᵤᵥwₓᵧ₂₀₁₂₃₄₅₆₇₈₉₊₋₌₍₎"
    # pair each plain character with its subscript counterpart
    table = str.maketrans(dict(zip(plain, lowered)))
    return string.translate(table)

# function to format the xyz line upon file conversion
def format_xyz(atom, coord, decimal=5):
    """Format one atom record as a line of an .xyz file.

    Parameters
    ----------
    atom : str
        Element symbol. A one-character symbol gets a trailing space so that
        one- and two-letter symbols occupy the same width.
    coord : sequence of float
        Cartesian coordinates; only the first three are written. The sequence
        is NOT modified, so lists, tuples and NumPy arrays are all accepted.
    decimal : int, optional
        Number of digits after the decimal point (default 5).

    Returns
    -------
    str
        ``' <atom>    <x>    <y>    <z>\\n'`` where non-negative coordinates
        carry a leading space so they line up with negative ones.

    Raises
    ------
    IndexError
        If *coord* has fewer than three elements.
    """
    # The ' ' sign option pads non-negative numbers with one space and keeps
    # '-' for negatives (including -0.0), so every column stays aligned.
    # Building a new list leaves the caller's coordinates untouched.
    fields = [f'{value: .{decimal}f}' for value in coord[:3]]
    if len(atom) == 1:
        atom = f' {atom} '
    else:
        atom = f' {atom}'
    return f'{atom}    {fields[0]}    {fields[1]}    {fields[2]}\n'



################################################################################ File & System
# function to retrieve the paths of the files and/or directories directly inside a given root (non-recursive)
def get_main(path, select):
    """Return the paths of the entries directly inside *path* (non-recursive).

    Parameters
    ----------
    path : str or os.PathLike
        Directory to list.
    select : {'dirs', 'files', 'all'}
        Which entries to return: sub-directories, files, or everything
        (the latter in ``os.listdir`` order).

    Returns
    -------
    list of str or None
        Paths joined onto *path*; None when *path* cannot be listed
        (missing, not a directory) or *select* is not one of the above.
    """
    # Only the top-level listing is needed, so take just the first os.walk
    # step instead of iterating (and scanning) the whole tree. os.walk also
    # normalises PathLike input to str, so Path objects work too.
    top = next(os.walk(path), None)
    if top is None:
        return None
    root, dirs, files = top
    if select == 'dirs':
        return [os.path.join(root, d) for d in dirs]
    if select == 'files':
        return [os.path.join(root, f) for f in files]
    if select == 'all':
        return [os.path.join(root, i) for i in os.listdir(root)]
    return None

# function to retrieve files and/or directories that match a given pattern string
def glob_files(path, pattern, select='all'):
    """Return the paths matching the glob *pattern* under *path*.

    Parameters
    ----------
    path : str or os.PathLike
        Directory to search.
    pattern : str
        Glob pattern matched against entry names (e.g. ``'*.xyz'``).
    select : {'all', 'root'}, optional
        ``'all'`` (default) searches *path* and every directory beneath it;
        ``'root'`` searches only *path* itself.

    Returns
    -------
    list of str or None
        Matching paths (unordered, as returned by ``glob.glob``); None when
        *select* is not recognised.
    """
    if select == 'all':  # loop through all inner dirs of the root
        # Directory names are escaped so '[', '*' or '?' in a real folder name
        # are matched literally; only *pattern* acts as a wildcard.
        return [match
                for root, _dirs, _files in os.walk(path)
                for match in glob.glob(os.path.join(glob.escape(root), pattern))]
    if select == 'root':  # only the root path
        return glob.glob(os.path.join(glob.escape(os.fspath(path)), pattern))
    return None

# makes sure the directory exists: creates it when missing, otherwise clears it (directly if clear=True, or after asking the user)
def secure_path(path, clear=False):
    """Make sure the directory *path* exists, optionally wiping its content.

    - If *path* does not exist it is created (including parent directories).
    - If it exists and *clear* is True it is deleted and recreated empty.
    - Otherwise the user is asked on stdin whether to overwrite it: answering
      yes deletes and recreates it, answering no leaves it untouched; any
      other answer prints 'Wrong input' and asks again.

    Parameters
    ----------
    path : str or os.PathLike
        Directory to secure.
    clear : bool, optional
        Wipe an existing directory without asking (default False).
    """
    def _recreate():
        # Delete the directory tree and recreate it empty.
        shutil.rmtree(path)
        os.makedirs(path)

    if not os.path.exists(path):
        os.makedirs(path)
        return
    if clear:
        _recreate()
        return
    while True:
        # Normalise the answer so 'Y', ' yes ' etc. are accepted as well.
        ans = input(f'<{path}> already exists, do you want to overwrite? [y/n]').strip().lower()
        if ans in ('y', 'yes'):
            _recreate()
            break
        if ans in ('n', 'no'):
            break
        print('Wrong input')

# extracts only file name from path string (removes extension)
def get_name(path):
    """Return the last component of *path* with its final extension removed.

    Examples: '/data/mol_1.xyz' -> 'mol_1', 'a.tar.gz' -> 'a.tar'.
    """
    stem, _extension = os.path.splitext(os.path.basename(path))
    return stem



################################################################################ DataFrame & Distributions
# Displays the full DataFrame without cutting off columns and rows
def show_all(pandas_df, col=None, row=None):
    """Display *pandas_df* without pandas truncating its columns or rows.

    Parameters
    ----------
    pandas_df : pandas.DataFrame (or anything displayable)
        Object to show.
    col, row : int or None, optional
        Temporary values for ``display.max_columns`` / ``display.max_rows``;
        None (default) means no limit.

    Returns
    -------
    None
        Whatever the display function returns (``display`` and ``print``
        both return None).
    """
    # Inside IPython/Jupyter `display` is available as a builtin; in a plain
    # interpreter it is undefined, so fall back to printing the repr.
    try:
        show = display  # noqa: F821 - injected by IPython when present
    except NameError:
        show = print
    with pd.option_context('display.max_columns', col, 'display.max_rows', row):
        return show(pandas_df)

# creates and updates a dataframe backup as a .csv file
def df_backup(df, backup_path, backread=True, wait=5):
    """Write *df* to a CSV backup and optionally reload it from that file.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to back up. An empty frame is not written, so an existing
        backup file is left as it is.
    backup_path : str or os.PathLike
        Target CSV file (the index is written as the first column).
    backread : bool, optional
        If True (default), return a new DataFrame read back from
        *backup_path* (``index_col=0``).
    wait : float, optional
        Seconds to pause after writing (default 5); 0 skips the pause.

    Returns
    -------
    pandas.DataFrame or None
        The reloaded frame when *backread* is True, otherwise None.

    Raises
    ------
    FileNotFoundError
        If *backread* is True, *df* is empty and no backup file exists yet.
    """
    if not df.empty:
        df.to_csv(backup_path)
        # NOTE(review): the reason for this pause is not visible here
        # (possibly to let slow/network filesystems settle) - confirm before
        # lowering the default.
        if wait > 0:
            time.sleep(wait)
    # Creates a new DataFrame from the backup file (solving the fragmentation issue)
    if backread:
        return pd.read_csv(backup_path, index_col=0)
    return None

# support function, returns a pandas dataframe containing the molecule name and conformer id split from each index label
def get_conformers(pd_index):
    """Split each index label into a molecule name and a conformer id.

    Every label is split on its LAST underscore, e.g. 'mol_a_3' ->
    ('mol_a', '3').

    Returns
    -------
    pandas.DataFrame
        Indexed by *pd_index*, with columns ``molecule_name`` and
        ``conformer_id``.
    """
    rows = []
    for label in pd_index:
        rows.append(label.rsplit('_', 1))
    return pd.DataFrame(rows, index=pd_index, columns=['molecule_name', 'conformer_id'])

# support function, returns a pandas series containing the Boltzmann weights
def boltz(pd_serie, R=0.008314463, T=298.15):
    """Return the Boltzmann weight of every conformer within its molecule.

    For each molecule (the part of the index label before the last '_'),
    the weights are ``exp(-(E - E_min) / (R*T))`` normalised to sum to 1,
    where E are the values of *pd_serie*. Subtracting the per-molecule
    minimum keeps the exponentials from underflowing.

    NOTE(review): the default R equals the gas constant in kJ/(mol*K), so
    the energies are presumably expected in kJ/mol - confirm with callers.

    Returns
    -------
    pandas.Series
        Weights concatenated molecule by molecule (in sorted molecule-name
        order), indexed by the original conformer labels.
    """
    thermal_energy = R * T
    per_molecule = get_conformers(pd_serie.index).groupby('molecule_name')
    weights = []
    for _, members in per_molecule:
        energies = pd_serie[members.index]
        population = np.exp(-(energies - energies.min()) / thermal_energy)
        weights.append(population / population.sum())
    return pd.concat(weights, axis=0)

# returns a dataframe containing the Boltzmann-weighted average of the features over each molecule's conformers
def dist_boltz(pd_dataframe, reference_column, name=None, R=0.008314463, T=298.15):
    """Boltzmann-weighted average of every feature over each molecule's conformers.

    Parameters
    ----------
    pd_dataframe : pandas.DataFrame
        One row per conformer, index labels of the form '<molecule>_<id>'.
    reference_column : str
        Column holding the energies used for the Boltzmann weights (via
        ``boltz``); it is dropped from the result.
    name : str, optional
        If given, every output column is prefixed with ``'<name>_'``.
    R, T : float, optional
        Gas constant and temperature forwarded to ``boltz`` (same defaults
        as ``boltz``), so populations at other temperatures can be computed.

    Returns
    -------
    pandas.DataFrame
        One row per molecule (unnamed index), one column per feature.
        Because the per-molecule weights sum to 1, the grouped sum is a
        weighted average. NaN feature values are skipped by the sum, so
        the remaining weights are not renormalised.
    """
    conformers = get_conformers(pd_dataframe.index)
    boltz_weights = boltz(pd_dataframe[reference_column], R=R, T=T)
    features = pd_dataframe.drop(columns=reference_column)
    dist = (features.multiply(boltz_weights, axis='index')
            .groupby(conformers['molecule_name'])
            .agg('sum'))
    dist.index = dist.index.rename(None)
    if name:
        dist.columns = [f'{name}_{col}' for col in dist.columns]
    return dist

# returns a dataframe containing a distribution for the features based on the "MODE" values of a reference column
# can be used to get LowestEnergyConformer or HighestEnergyConformer // mode=min, max (didn't test other functions)
def dist_select_conf(pd_dataframe, reference_column, mode, name=None):
    """Pick, for each molecule, the conformer whose reference value equals its group *mode*.

    Typical use: ``mode='min'`` for the lowest-energy conformer or
    ``mode='max'`` for the highest one. Any aggregation accepted by
    ``SeriesGroupBy.transform`` can be passed.

    Parameters
    ----------
    pd_dataframe : pandas.DataFrame
        One row per conformer, index labels of the form '<molecule>_<id>'
        (split on the last underscore; a label without '_' is its own
        molecule).
    reference_column : str
        Column used for the selection; it is dropped from the result.
    mode : str or callable
        Group statistic computed on *reference_column*.
    name : str, optional
        If given, every output column is prefixed with ``'<name>_'``.

    Returns
    -------
    pandas.DataFrame
        One row per molecule, indexed by molecule name, in input order.
        When several conformers tie for the selected value only the first
        one is kept, so the index has no duplicates. Molecules whose
        reference values are all NaN do not appear.
    """
    # Split the labels once; these names group the rows and become the new index.
    molecules = [idx.rsplit('_', 1)[0] for idx in pd_dataframe.index]
    reference = pd_dataframe[reference_column]
    # An ndarray key groups positionally (no index alignment needed).
    is_selected = reference.groupby(np.asarray(molecules, dtype=object)).transform(mode) == reference
    dist = pd_dataframe.drop(columns=reference_column).loc[is_selected]
    dist.index = [mol for mol, keep in zip(molecules, is_selected) if keep]
    # Ties would otherwise give the same molecule several rows.
    dist = dist[~dist.index.duplicated(keep='first')]
    if name:
        dist.columns = [f'{name}_{col}' for col in dist.columns]
    return dist

# returns a dataframe containing a distribution for the features based on the "MODE" values for each column
# can be used to get the minimum or maximum value of each feature // mode=min, max, average, sum... other statistics
def dist_stat(pd_dataframe, mode, name=None):
    """Aggregate every feature over each molecule's conformers with *mode*.

    Parameters
    ----------
    pd_dataframe : pandas.DataFrame
        One row per conformer, index labels of the form '<molecule>_<id>'.
    mode : str, callable or list
        Any aggregation accepted by ``DataFrameGroupBy.agg`` ('min', 'max',
        'mean', 'sum', ...).
    name : str, optional
        If given, every output column is prefixed with ``'<name>_'``.

    Returns
    -------
    pandas.DataFrame
        One row per molecule (unnamed index).
    """
    molecule_names = get_conformers(pd_dataframe.index)['molecule_name']
    dist = pd_dataframe.groupby(molecule_names).agg(mode).rename_axis(None)
    if name:
        dist.columns = [f'{name}_{column}' for column in dist.columns]
    return dist



################################################################################ Model Training
# Function to iterate through every combination of hyperparameter values
def hyper_tuning(hyper_dict):
    """Yield every combination of hyperparameter values as a fresh dict.

    Parameters
    ----------
    hyper_dict : dict[str, list]
        Maps each hyperparameter name to the list of values to try.

    Yields
    ------
    dict
        One value per hyperparameter. The first key varies slowest and the
        last key fastest. Each yielded dict is a new object. An empty
        *hyper_dict* yields a single empty dict; an empty value list yields
        nothing.

    Raises
    ------
    TypeError
        If a value list is not a ``list`` (raised lazily, on iteration).
    """
    if not hyper_dict:
        yield dict()
        return
    key_to_iterate = next(iter(hyper_dict))
    values = hyper_dict[key_to_iterate]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(values, list):
        raise TypeError(f'{key_to_iterate} is not a list.')
    next_round = {param: grid for param, grid in hyper_dict.items() if param != key_to_iterate}
    for val in values:
        # Each recursive generator builds new dicts, so filling one in place is safe.
        for combo in hyper_tuning(next_round):
            combo[key_to_iterate] = val
            yield combo



################################################################################ Decorators
# handles the application of a function to a single file or a list of files
def arg_file_handler(func):
    """Decorator: let *func* accept either one path string or a list of them.

    For a ``str`` first argument, *func* is called once and its result is
    returned. For a ``list``, *func* is called on each element (with the
    same extra arguments) and the results are returned as a list.

    Raises
    ------
    TypeError
        If the first argument is neither a ``str`` nor a ``list``.
    """
    @wraps(func)
    def wrapper(arg, *args, **kwargs):
        if isinstance(arg, list):  # List of file paths
            return [func(path, *args, **kwargs) for path in arg]
        if not isinstance(arg, str):
            raise TypeError("Invalid argument type. Expected string or list.")
        return func(arg, *args, **kwargs)  # Single file path
    return wrapper

# converts every positional argument to a NumPy array before calling the wrapped function
def np_array_decorator(func):
    """Decorator: convert every positional argument to a NumPy array first.

    Each positional argument is passed through ``np.array`` (which copies it)
    before *func* is called; keyword arguments are forwarded unchanged.
    ``functools.wraps`` keeps the wrapped function's name and docstring.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*(np.array(arg) for arg in args), **kwargs)
    return wrapper



################################################################################ to implement later
def pack():
    """Placeholder: not implemented yet; prints a notice and returns None."""
    print('incomplet function...')
    return None

def unpack():
    """Placeholder: not implemented yet; prints a notice and returns None.

    A draft of the intended logic is kept below as comments (never
    executed). The draft moves the contents of every sub-folder of ``path``
    up into ``path`` and then removes the emptied sub-folders.
    """
    print('incomplet function...')
    return None
    # Draft (not executed):
    #     args = (path, select='dirs')
    #     pack_folds = get_main(path)
    #     for fold in pack_folds:
    #         transfer = get_main(fold, select=select)
    #         for old in transfer:
    #             new = os.path.join(path, os.path.split(old)[1])
    #             shutil.move(old, new)
    #         os.rmdir(fold)
