from rdkit import Chem
from molpher.core import MolpherMol, MolpherAtom
from molpher.core.morphing.operators import MorphingOperator
import random
 
class AddFragment(MorphingOperator):
    """Morphing operator that grafts a fixed fragment onto the original molecule.

    The fragment is joined to the original by one new single bond between a
    randomly chosen "open" atom of the original and a randomly chosen
    attachment atom of the fragment.
    """

    def __init__(self, fragment, open_atoms_frag, oper_name):
        """
        :param fragment: RDKit ``Mol`` to attach (merged via ``Chem.CombineMols``).
        :param open_atoms_frag: atom indices *within* ``fragment`` that may
            receive the new bond.
        :param oper_name: name reported by :meth:`getName`.
        """
        super().__init__()
        self._name = oper_name
        self._fragment = fragment
        self._open_atoms_frag = open_atoms_frag
        # RDKit copy of the current original and the indices of its atoms that
        # may gain a bond; both are populated by setOriginal().
        self._orig_rdkit = None
        self._open_atoms = []

    def setOriginal(self, mol):
        """Set the molecule to morph and cache which of its atoms may gain a bond.

        An atom is "open" when RDKit reports an implicit valence of at least one
        and its Molpher locking mask does not contain ``MolpherAtom.NO_ADDITION``.
        Nothing is recomputed when the resulting ``self.original`` is falsy.
        """
        super().setOriginal(mol)
        if not self.original:
            return
        self._orig_rdkit = self.original.asRDMol()
        self._open_atoms = []
        # Pairs RDKit atoms with Molpher atoms positionally; assumes asRDMol()
        # keeps the atom order of self.original.atoms — TODO confirm.
        paired = zip(self._orig_rdkit.GetAtoms(), self.original.atoms)
        for rd_atom, molpher_atom in paired:
            if rd_atom.GetImplicitValence() < 1:
                continue
            if MolpherAtom.NO_ADDITION & molpher_atom.locking_mask:
                continue
            self._open_atoms.append(rd_atom.GetIdx())

    def morph(self):
        """Attach the fragment at random open atoms and return the new molecule.

        :returns: a ``MolpherMol`` whose atoms inherited from the original keep
            their locking masks (the fragment's atoms are left as constructed).
        :raises IndexError: if either the original or the fragment has no open atom.
        """
        editable = Chem.EditableMol(Chem.CombineMols(self._orig_rdkit, self._fragment))
        anchor_idx = random.choice(self._open_atoms)
        frag_local_idx = random.choice(self._open_atoms_frag)
        # In the combined molecule the fragment's atoms follow the original's,
        # so a fragment-local index is shifted by the original's atom count.
        frag_idx = len(self.original.atoms) + frag_local_idx
        editable.AddBond(anchor_idx, frag_idx, order=Chem.rdchem.BondType.SINGLE)
        product = editable.GetMol()
        Chem.SanitizeMol(product)

        result = MolpherMol(other=product)
        for new_atom, old_atom in zip(result.atoms, self.original.atoms):
            new_atom.locking_mask = old_atom.locking_mask
        return result

    def getName(self):
        """Return the operator name given at construction."""
        return self._name
    
class DelFragment(MorphingOperator):
    """Morphing operator that deletes a randomly chosen substructure.

    A ``(central_atom, radius)`` pair is drawn from ``onbits``. The bond
    environment of that radius around the central atom is converted into a
    SMILES query, and each match of that query in the original molecule is
    removed from its own copy, giving one candidate molecule per match.
    """

    def __init__(self, onbits, oper_name, max_attempts=1000):
        """
        :param onbits: mapping whose values are sequences of
            ``(central_atom, radius)`` pairs indexing atoms of the original
            molecule (looks like the ``bitInfo`` dict of RDKit's Morgan
            fingerprint — confirm against the caller).
        :param oper_name: name reported by :meth:`getName`.
        :param max_attempts: number of random picks :meth:`morph` tries before
            giving up; keeps inputs that never yield a usable pattern from
            looping forever.
        """
        super(DelFragment, self).__init__()
        self._name = oper_name
        self._onbits = onbits
        self._orig_rdkit = None
        self._max_attempts = max_attempts

    def setOriginal(self, mol):
        """Set the molecule to morph and cache its RDKit representation."""
        super(DelFragment, self).setOriginal(mol)
        if self.original:
            self._orig_rdkit = self.original.asRDMol()

    def _remove_atoms_by_nr(self, emol: Chem.rdchem.EditableMol, list_atomnrs) -> Chem.rdchem.Mol:
        """Remove the given atom indices from ``emol`` and return the remaining Mol.

        Atoms are removed from the highest index down, because removing an atom
        renumbers every atom after it; descending order keeps the indices still
        to be removed valid.
        """
        for atomnr in sorted(list_atomnrs, reverse=True):
            emol.RemoveAtom(atomnr)
        return emol.GetMol()

    def remove_substr_matches_return_frag(self, mol, pattern_mol):
        """Remove each substructure match of ``pattern_mol`` from ``mol``, one at a time.

        Every match is deleted from a fresh ``EditableMol`` built from ``mol``,
        so the result holds one molecule per match. The returned molecules are
        not re-sanitized.

        :param mol: RDKit ``Mol`` to cut.
        :param pattern_mol: RDKit query ``Mol`` (not a SMILES string).
        :returns: list of the remaining molecules, or ``[mol]`` when nothing matched.
        """
        list_frags = [
            self._remove_atoms_by_nr(Chem.EditableMol(mol), match)
            for match in mol.GetSubstructMatches(pattern_mol)
        ]
        print(f"{len(list_frags)} substructure match(es) found.")
        if not list_frags:
            print("No matches found, returning input molecule")
            return [mol]
        return list_frags

    def _pick_random_pattern(self, mol):
        """Make one random attempt at building a substructure query from ``onbits``.

        :returns: the query ``Mol``, or ``None`` when the chosen pair does not
            produce a usable pattern.
        Exceptions raised by the random choice, the ``onbits`` lookup or RDKit
        propagate to the caller, which treats them as a failed attempt.
        """
        random_fragment_key = random.choice(list(self._onbits))
        central_atom, radius = random.choice(self._onbits[random_fragment_key])

        # Bonds within `radius` of the central atom, then the submolecule they span.
        env = Chem.FindAtomEnvironmentOfRadiusN(mol, radius, central_atom)
        amap = {}
        submol = Chem.PathToSubmol(mol, env, atomMap=amap)

        if central_atom not in amap:
            # Happens when the environment contains no bonds (presumably radius 0).
            print(f"Error: Central atom {central_atom} not found in the atom map.")
            return None

        smiles_fragment = Chem.MolToSmiles(submol, rootedAtAtom=amap[central_atom], canonical=False)

        print(f"Randomly selected fragment key: {random_fragment_key}")
        print(f"Atom map: {amap}")
        print(f"Number of atoms in the randomly selected fragment: {submol.GetNumAtoms()}")
        print(f"SMILES form of the selected fragment: {smiles_fragment}")
        print(f"Central Atom Index: {central_atom}")
        print(f"Radius: {radius}")

        pattern_mol = Chem.MolFromSmiles(smiles_fragment)
        if pattern_mol is None:
            # Fallback for fragments that fail to parse (e.g. partial aromatic
            # rings): retry with all atoms upper-cased, i.e. non-aromatic.
            # NOTE(review): this also turns "Cl"/"Br" into invalid "CL"/"BR".
            pattern_mol = Chem.MolFromSmiles(smiles_fragment.upper())
        return pattern_mol

    def morph(self):
        """Delete a randomly chosen fragment from the original molecule.

        :returns: list of ``MolpherMol``, one per substructure match removed
            (or the unchanged original when the pattern has no match).
            NOTE(review): ``AddFragment.morph`` returns a single ``MolpherMol``;
            confirm the caller of this operator expects a list.
        :raises RuntimeError: if no original molecule has been set, or if no
            usable pattern is found within ``max_attempts`` random picks.
        """
        mol = self._orig_rdkit
        if mol is None:
            raise RuntimeError(
                f"{self._name}: setOriginal() must provide a molecule before morph()"
            )

        pattern_mol = None
        last_error = None
        for _ in range(self._max_attempts):
            try:
                pattern_mol = self._pick_random_pattern(mol)
            except Exception as err:  # RDKit, random and onbits lookups raise assorted types
                last_error = err
                print(f"Fragment selection failed ({err!r}), retrying.")
                continue
            if pattern_mol is not None:
                break
        else:
            raise RuntimeError(
                f"{self._name}: no usable fragment pattern found after "
                f"{self._max_attempts} attempts"
            ) from last_error

        matches = mol.GetSubstructMatches(pattern_mol)
        print(f"There are {len(matches)} substructure matche(s) with the following atom numbers: {matches}.")

        # Atom-map numbers are deliberately not stamped onto `mol`: it is the
        # cached original shared by every morph() call, and the numbers would
        # end up in the product SMILES built below.
        myfrags = self.remove_substr_matches_return_frag(mol, pattern_mol)
        return [MolpherMol(Chem.MolToSmiles(ifr)) for ifr in myfrags]

    def getName(self):
        """Return the operator name given at construction."""
        return self._name