# Import libraries
import os
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import rdMolDraw2D
# Silence all RDKit log output (warnings and errors)
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
from IPython.display import SVG, display
from openbabel import openbabel
from tqdm.auto import tqdm

## Reviewed functions:
def read_sdf(file_path):
    """Load every record of an SD file into a list of RDKit molecules.

    The list has one entry per record yielded by ``Chem.SDMolSupplier``, in
    file order; RDKit yields ``None`` for records it cannot parse.
    """
    supplier = Chem.SDMolSupplier(file_path)
    molecules = []
    for record in supplier:
        molecules.append(record)
    return molecules

def read_txt(file_path):
    """Parse a text file holding one SMILES string per line.

    Each line is stripped of surrounding whitespace and passed to
    ``Chem.MolFromSmiles``. The result has one entry per line, in file order
    (RDKit returns ``None`` for SMILES it rejects).
    """
    with open(file_path) as smiles_file:
        lines = smiles_file.readlines()
    return [Chem.MolFromSmiles(line.strip()) for line in lines]

def create_filter(mols_series, smarts_patt):
    """Flag which molecules contain a SMARTS substructure.

    Parameters
    ----------
    mols_series : iterable of rdkit.Chem.Mol
        Molecules to test (the name suggests a pandas Series; any iterable
        of molecules works).
    smarts_patt : str
        SMARTS pattern to search for.

    Returns
    -------
    list of bool
        One entry per molecule, True when the pattern matches it.

    Raises
    ------
    ValueError
        If ``smarts_patt`` cannot be parsed as SMARTS.
    """
    patt = Chem.MolFromSmarts(smarts_patt)
    # MolFromSmarts returns None instead of raising on a bad pattern; without
    # this check the failure surfaces later as an opaque Boost.Python
    # argument-mismatch error from HasSubstructMatch.
    if patt is None:
        raise ValueError(f'Invalid SMARTS pattern: {smarts_patt!r}')
    return [mol.HasSubstructMatch(patt) for mol in mols_series]

def merge_filters(filters_df):
    """Combine per-column boolean filters with a row-wise logical OR.

    Parameters
    ----------
    filters_df : pandas.DataFrame
        Expected to hold one boolean column per filter and one row per item.

    Returns
    -------
    list of bool
        For each row, True if any value in that row is truthy.
    """
    # Vectorised row-wise ``any``: same result as calling ``.any()`` on each
    # row from iterrows(), without building a Series object per row.
    return filters_df.any(axis=1).tolist()

def display_molecules(mols_series, legend=None, prt=None, n_cols=5, mol_height=340, mol_width=240):
    """Render molecules as an SVG grid and show it inline or save it to disk.

    Parameters
    ----------
    mols_series : pandas.Series of rdkit.Chem.Mol
        Molecules to draw. Its index labels are used as legends when
        ``legend`` is None.
    legend : pandas.Series, optional
        Text drawn under each molecule, in the same order as ``mols_series``.
        Values are converted with ``str``.
    prt : str, optional
        Output file path. When given, the SVG is written there (UTF-8)
        instead of being displayed.
    n_cols : int
        Number of molecules per grid row.
    mol_height, mol_width : int
        Size in pixels of each molecule cell.

    Returns
    -------
    None when ``prt`` is given; otherwise the return value of IPython's
    ``display``.
    """
    assert isinstance(legend, pd.Series) or legend is None, 'Legend should be a pd.series'
    assert isinstance(prt, str) or prt is None, 'prt should be a path string'

    # Page config: enough grid rows to fit every molecule at n_cols per row.
    n_mols = len(mols_series.index)
    n_lines = int(np.ceil(n_mols / n_cols))
    page_height = mol_height * n_lines
    page_width = mol_width * n_cols

    # Display config. drawOptions() hands back the drawer's own options
    # object, so attributes set on this reference affect the drawing.
    d2d = rdMolDraw2D.MolDraw2DSVG(page_width, page_height, mol_width, mol_height)
    options = d2d.drawOptions()
    options.legendFraction = 0.25
    options.legendFontSize = 16
    options.minFontSize = 14
    options.drawMolsSameScale = False
    options.fixedBondLength = 20
    options.prepareMolsBeforeDrawing = True

    # Draw. Legends must be strings for DrawMolecules, so convert explicitly
    # (numeric legends previously made the call fail).
    if legend is not None:
        legends = [str(text) for text in legend]
    else:
        legends = [str(i) for i in mols_series.index]
    d2d.DrawMolecules(list(mols_series), legends=legends)
    d2d.FinishDrawing()
    svg_text = d2d.GetDrawingText()

    # Save or show.
    if prt is not None:
        # Explicit UTF-8 so non-ASCII legend text is written the same way on
        # every platform, independent of the locale's default encoding.
        with open(prt, mode='w', encoding='utf-8') as file:
            file.write(svg_text)
        return None
    return display(SVG(svg_text))

def convert_3d(mol_series, legend_series, embed=False):
    """Add explicit hydrogens, name each molecule and generate coordinates.

    Parameters
    ----------
    mol_series : pandas.Series of rdkit.Chem.Mol
        Input molecules. ``Chem.AddHs`` returns new molecules, so the
        originals are left untouched.
    legend_series : pandas.Series
        Names to store in each molecule's ``_Name`` property, looked up with
        the same index labels as ``mol_series``. Values are converted with
        ``str``.
    embed : bool
        If False, only 2D depiction coordinates are computed. If True, a 3D
        conformer is embedded and then optimised with MMFF.

    Returns
    -------
    pandas.Series of rdkit.Chem.Mol
        New molecules with explicit hydrogens, indexed like ``mol_series``.

    Raises
    ------
    ValueError
        If ``embed`` is True and no 3D conformer can be generated for a
        molecule.
    """
    mol_series = mol_series.apply(Chem.AddHs)
    for i in tqdm(mol_series.index):
        mol = mol_series[i]
        # SetProp only accepts strings; numeric legends would otherwise fail.
        name = str(legend_series[i])
        mol.SetProp('_Name', name)
        AllChem.Compute2DCoords(mol)
        if embed:
            # EmbedMolecule signals failure by returning -1. Optimising
            # afterwards then failed with an unhelpful conformer error, so
            # retry once from random coordinates (a common fallback) and
            # report clearly if that fails too.
            status = AllChem.EmbedMolecule(mol)
            if status == -1:
                status = AllChem.EmbedMolecule(mol, useRandomCoords=True)
            if status == -1:
                raise ValueError(f'Could not embed a 3D conformer for molecule {name!r}')
            AllChem.MMFFOptimizeMolecule(mol)
    return mol_series

def obabel(input_string, input_format, output_format):
    """Convert a molecule between file formats with Open Babel.

    Parameters
    ----------
    input_string : str
        Molecule text in ``input_format`` (e.g. a MOL/SDF block).
    input_format, output_format : str
        Open Babel format codes such as ``'sdf'`` or ``'xyz'``.

    Returns
    -------
    str
        The molecule written in ``output_format``.

    Raises
    ------
    ValueError
        If Open Babel does not recognise a format code or cannot parse
        ``input_string``.
    """
    ob_conversion = openbabel.OBConversion()
    # Both calls below report failure through a False return value rather
    # than an exception. Previously these were ignored, so the conversion
    # silently continued with an empty OBMol.
    if not ob_conversion.SetInAndOutFormats(input_format, output_format):
        raise ValueError(
            f'Unsupported Open Babel format(s): {input_format!r} -> {output_format!r}')
    ob_mol = openbabel.OBMol()
    if not ob_conversion.ReadString(ob_mol, input_string):
        raise ValueError(f'Open Babel could not parse the input as {input_format!r}')
    return ob_conversion.WriteString(ob_mol)

def convert_MolToXYZ(mol_series, out_path):
    """Write each molecule to ``<out_path>/<_Name>.xyz`` via Open Babel.

    Each molecule is serialised as an RDKit MOL block and converted to XYZ
    text by ``obabel``. It is then saved under a file name taken from its
    ``_Name`` property, so every molecule must carry that property.
    """
    for molecule in mol_series:
        mol_block = Chem.MolToMolBlock(molecule)
        xyz_text = obabel(mol_block, 'sdf', 'xyz')
        name = molecule.GetProp("_Name")
        target = os.path.join(out_path, f'{name}.xyz')
        with open(target, 'w') as handle:
            handle.write(xyz_text)

def get_atoms_reference(mol_series, smarts, atoms_id):
    """Tabulate the atom indices matched by a SMARTS pattern in each molecule.

    Parameters
    ----------
    mol_series : pandas.Series of rdkit.Chem.Mol
        Molecules to search; its index becomes the result's index.
    smarts : str
        SMARTS pattern. Only the first match per molecule is used.
    atoms_id : sequence of str
        Column labels, one per pattern atom, in pattern atom order.

    Returns
    -------
    pandas.DataFrame
        Indexed like ``mol_series``. Cell (row, label) is the index of the
        molecule atom mapped onto that pattern atom.

    Raises
    ------
    ValueError
        If ``smarts`` cannot be parsed as SMARTS.
    """
    pattern = Chem.MolFromSmarts(smarts)
    # MolFromSmarts returns None on a bad pattern; fail here with a clear
    # message instead of inside GetSubstructMatch.
    if pattern is None:
        raise ValueError(f'Invalid SMARTS pattern: {smarts!r}')
    # NOTE(review): a molecule with no match yields an empty tuple. Confirm
    # how pandas pads such rows (presumably NaN) and whether callers expect it.
    matches = [mol.GetSubstructMatch(pattern) for mol in mol_series]
    return pd.DataFrame(matches, index=mol_series.index, columns=atoms_id)

def get_prop(pdseries, prop):
    """Collect one property value from every molecule.

    Parameters
    ----------
    pdseries : iterable of rdkit.Chem.Mol
        Molecules to read. Entries that are not molecules (e.g. ``None``
        from a failed parse) are tolerated.
    prop : str
        Property name passed to ``GetProp``.

    Returns
    -------
    list
        One entry per molecule: the property value, or None when the entry
        has no ``GetProp`` method or lacks the property.
    """
    output = []
    for mol in pdseries:
        try:
            value = mol.GetProp(prop)
        except (AttributeError, KeyError, RuntimeError):
            # AttributeError: entry is not a molecule (e.g. None).
            # KeyError: the molecule lacks the property (RuntimeError is
            # caught defensively for the same case). A bare ``except`` here
            # used to swallow KeyboardInterrupt/SystemExit as well.
            value = None
        output.append(value)
    return output

def align_structures(mol_series, smiles_str):
    """Re-depict molecules so a shared core is drawn in the same 2D layout.

    A 2D layout is computed for the template built from ``smiles_str``. Each
    molecule's depiction is then regenerated in place to match it.

    Parameters
    ----------
    mol_series : pandas.Series of rdkit.Chem.Mol
        Molecules to re-depict; modified in place.
    smiles_str : str
        Template pattern. Despite the name, it is parsed with
        ``Chem.MolFromSmarts``, so it must be valid SMARTS (many plain SMILES
        strings also are).

    Raises
    ------
    ValueError
        If ``smiles_str`` cannot be parsed as SMARTS.
    """
    print('\nPlease wait while we align all structures:')
    patt = Chem.MolFromSmarts(smiles_str)
    # MolFromSmarts returns None on a bad pattern; Compute2DCoords(None) would
    # otherwise fail with an opaque Boost.Python argument error.
    if patt is None:
        raise ValueError(f'Invalid SMARTS template: {smiles_str!r}')
    AllChem.Compute2DCoords(patt)
    # NOTE(review): RDKit is expected to raise if a molecule does not contain
    # the template, aborting the loop part-way — confirm this is acceptable.
    for i in tqdm(mol_series.index):
        AllChem.GenerateDepictionMatching2DStructure(mol_series[i], patt)

##################################################################### to implement later
def display_match(df, smarts):
    """Placeholder: print the SMARTS substructure matches in ``df['RDKit']``.

    Not implemented yet. It only prints a notice and returns None; neither
    argument is used.
    """
    # Sketch of the intended implementation (not executed):
    #     filter_mol = Chem.MolFromSmarts(smarts)
    #     for mol in df['RDKit']:
    #         if mol.GetSubstructMatches(filter_mol):
    #             print(mol.GetSubstructMatches(filter_mol))
    print('incomplet function...')
    return None

def conformer_complexity(df, classify=True):
    """Placeholder for a conformational-complexity score of ``df['RDKit']``.

    Not implemented yet. It only prints a notice and returns None; neither
    argument is used. The sketch below would count matches of two SMARTS
    patterns into new ``sigma_bonds`` and ``ring_bonds`` columns. It would
    then optionally sort ``df`` in place by those columns.
    """
    # Sketch of the intended implementation (not executed):
    #     df['sigma_bonds'] = [len(mol.GetSubstructMatches(Chem.MolFromSmarts('[a,A!R]-[!H,!F,!Cl,!Br,!I,$(A#A)]'))) for mol in df['RDKit']]
    #     df['ring_bonds'] = [len(mol.GetSubstructMatches(Chem.MolFromSmarts('[AR!r3]-[AR!r3]'))) for mol in df['RDKit']]
    #     if classify:
    #         df.sort_values(by=['sigma_bonds', 'ring_bonds'], inplace=True)
    print('incomplet function...')
    return None
