import pandas as pd
import glob
import os
import argparse
import numpy as np
import pickle
from time import time
import pip

def calc_sterimol(file, lab, list_atom=(1, 2, 3, 4, 5), exclude=False):
    """Collect Sterimol descriptors for every ordered atom pair of one structure.

    For each ordered pair ``(i, j)`` of distinct entries of *list_atom*, each
    radius in ``[3.5, 4, 4.5, 5]`` and each ``norot`` setting (``False`` then
    ``True``), ``db.dbstep`` is called with ``sterimol=True`` and the ``L``,
    ``Bmin`` and ``Bmax`` attributes of the returned object are recorded.

    Args:
        file: First positional argument of ``db.dbstep`` (the script passes
            paths of ``.xyz`` files).
        lab: Label; its string form becomes the name of the single value column.
        list_atom: Iterable of atom indices forwarded as ``atom1``/``atom2``.
            It is iterated twice, so pass a re-iterable container.
        exclude: Forwarded unchanged as the ``exclude`` keyword of ``db.dbstep``.

    Returns:
        pandas.DataFrame indexed by ``var_name`` with one column named after
        *lab*. Index entries have the form
        ``{L|Bmin|Bmax}_{atom1}_{atom2}_{radius}_{rotate}``, where ``rotate``
        is the ``norot`` flag exactly as it was passed to ``db.dbstep``.

    Note:
        ``db`` is looked up as a module-level global; in this file it is only
        bound by the ``import dbstep.Dbstep as db`` in the ``__main__`` block.
    """
    # Ordered pairs: both (i, j) and (j, i) are evaluated.
    pairs = [(i, j) for i in list_atom for j in list_atom if i != j]

    # Parallel columns of the long-format table, one entry per dbstep call.
    atom1, atom2, Ls, Bmins, Bmaxs = [], [], [], [], []
    rotate_mol = []
    radii = []

    for i, j in pairs:
        for rad in [3.5, 4, 4.5, 5]:
            for nr in [False, True]:
                mol = db.dbstep(file,
                                atom1=i,
                                atom2=j,
                                commandline=True,    # set to False for pymol output
                                verbose=False,
                                sterimol=True,
                                measure='classic',
                                surface='vdw',
                                noH=True,
                                addmetals=True,
                                norot=nr,
                                r=rad,
                                exclude=exclude,
                                )
                atom1.append(i)
                atom2.append(j)
                Ls.append(mol.L)
                Bmins.append(mol.Bmin)
                Bmaxs.append(mol.Bmax)
                radii.append(rad)
                rotate_mol.append(nr)

    # Create dataframe to collect values. The names below are formatted from
    # the DataFrame columns (not the raw Python values), so the mixed
    # int/float radii are rendered the way pandas stores them.
    sterimol_df = pd.DataFrame({'Atom1': atom1, 'Atom2': atom2,
                                'L': Ls, 'Bmin': Bmins, 'Bmax': Bmaxs,
                                'radius': radii, 'rotate': rotate_mol})

    # Wide -> long: one row per (pair, radius, rotate, descriptor).
    sterimol_df = pd.melt(sterimol_df,
                          id_vars=['Atom1', 'Atom2', 'radius', 'rotate'],
                          value_vars=['L', 'Bmin', 'Bmax'])

    sterimol_df['var_name'] = [
        f'{v}_{a1}_{a2}_{rad}_{r}'
        for v, a1, a2, rad, r in zip(sterimol_df['variable'],
                                     sterimol_df['Atom1'],
                                     sterimol_df['Atom2'],
                                     sterimol_df['radius'],
                                     sterimol_df['rotate'])
    ]
    sterimol_df = sterimol_df.loc[:, ['var_name', 'value']]
    sterimol_df.columns = ['var_name', f'{lab}']
    sterimol_df = sterimol_df.set_index('var_name')

    return sterimol_df
    
def calc_vol(file, lab, list_atom=(1, 2, 3, 4, 5), exclude=False):
    """Collect buried-volume descriptors for every listed atom of one structure.

    For each entry of *list_atom*, each radius in ``[3.5, 4, 4.5, 5]`` and each
    ``norot`` setting (``False`` then ``True``), ``db.dbstep`` is called with
    ``volume=True`` and the ``bur_vol`` and ``bur_shell`` attributes of the
    returned object are recorded.

    Args:
        file: First positional argument of ``db.dbstep`` (the script passes
            paths of ``.xyz`` files).
        lab: Label; its string form becomes the name of the single value column.
        list_atom: Iterable of atom indices forwarded as ``atom1``.
        exclude: Forwarded unchanged as the ``exclude`` keyword of ``db.dbstep``.

    Returns:
        pandas.DataFrame indexed by ``var_name`` with one column named after
        *lab*. Index entries have the form
        ``{%V_Bur|%S_Bur}_{atom1}_{radius}_{rotate}``, where ``rotate`` is the
        ``norot`` flag exactly as it was passed to ``db.dbstep``.

    Note:
        ``db`` is looked up as a module-level global; in this file it is only
        bound by the ``import dbstep.Dbstep as db`` in the ``__main__`` block.
    """
    # Parallel columns of the long-format table, one entry per dbstep call.
    atom1 = []
    vbur = []
    vshell = []
    rotate_mol = []
    radii = []

    for i in list_atom:
        for rad in [3.5, 4, 4.5, 5]:
            for nr in [False, True]:
                mol = db.dbstep(file,
                                atom1=i,
                                commandline=True,
                                verbose=False,
                                volume=True,
                                measure='classic',
                                surface='vdw',
                                noH=True,
                                addmetals=True,
                                norot=nr,
                                r=rad,
                                exclude=exclude,
                                )
                atom1.append(i)
                vbur.append(mol.bur_vol)
                vshell.append(mol.bur_shell)
                radii.append(rad)
                rotate_mol.append(nr)

    # Create dataframe to collect values. The names below are formatted from
    # the DataFrame columns (not the raw Python values), so the mixed
    # int/float radii are rendered the way pandas stores them.
    vol_df = pd.DataFrame({'Atom1': atom1,
                           '%V_Bur': vbur, '%S_Bur': vshell,
                           'radius': radii, 'rotate': rotate_mol})

    # Wide -> long: one row per (atom, radius, rotate, descriptor).
    vol_df = pd.melt(vol_df, id_vars=['Atom1', 'radius', 'rotate'],
                     value_vars=['%V_Bur', '%S_Bur'])

    vol_df['var_name'] = [
        f'{v}_{a1}_{rad}_{r}'
        for v, a1, rad, r in zip(vol_df['variable'],
                                 vol_df['Atom1'],
                                 vol_df['radius'],
                                 vol_df['rotate'])
    ]
    vol_df = vol_df.loc[:, ['var_name', 'value']]
    vol_df.columns = ['var_name', f'{lab}']
    vol_df = vol_df.set_index('var_name')

    return vol_df

def install(package):
    """Install *package* with pip into the environment of the running interpreter.

    pip is run as a subprocess (``<sys.executable> -m pip install <package>``)
    instead of calling ``pip.main`` / ``pip._internal.main`` in-process, which
    pip does not support as a programmatic API.

    The call is best-effort: a non-zero pip exit status is not raised here, so
    a failed install surfaces at the subsequent ``import`` of the package.

    Args:
        package: Requirement string handed to ``pip install``.
    """
    import importlib
    import subprocess
    import sys

    # sys.executable guarantees the package lands in *this* interpreter's
    # environment, not whichever `pip` happens to be first on PATH.
    subprocess.run([sys.executable, '-m', 'pip', 'install', package], check=False)

    # Make modules installed after interpreter start-up visible to `import`.
    importlib.invalidate_caches()


def _collect_descriptors(xyz_files, atom_idx, exclude=False, label_swap=None,
                         drop_constant=True):
    """Build one descriptor table (one row per file) for a list of xyz files.

    Args:
        xyz_files: Paths handed to ``calc_sterimol`` and ``calc_vol``.
        atom_idx: Atom indices forwarded as ``list_atom``.
        exclude: Forwarded as ``exclude`` to both calculators.
        label_swap: Optional ``(old, new)`` pair applied with ``str.replace``
            to the row label after the path prefix and ``.xyz`` are stripped.
        drop_constant: When true, keep only columns in which at least one row
            differs from the first row.

    Returns:
        pandas.DataFrame with one row per file and one column per descriptor.
    """
    frames = []
    for file in xyz_files:
        lab = file.replace('/mnt/1_xyz_files/', '').replace('.xyz', '')
        if label_swap is not None:
            lab = lab.replace(*label_swap)
        sterimol_df = calc_sterimol(file, lab, list_atom=atom_idx, exclude=exclude)
        vol_df = calc_vol(file, lab, list_atom=atom_idx, exclude=exclude)
        # Stack both descriptor sets, then transpose to a single row per file.
        frames.append(pd.concat([sterimol_df, vol_df], axis=0).T)

    path_df = pd.concat(frames, axis=0)
    if drop_constant:
        path_df = path_df.loc[:, (path_df != path_df.iloc[0]).any()]
    return path_df


if __name__ == '__main__':
    install('dbstep')
    # Must stay a module-level import: calc_sterimol/calc_vol read `db` as a global.
    import dbstep.Dbstep as db
    parser = argparse.ArgumentParser()

    # Sub-directories per compound:
    #   ligand, cis_pdcl2, cis_pdcl2_sme2, trans_pdcl2, trans_pdcl2_sme2
    # nargs='+' keeps every option a list of directories; without it a path
    # given on the command line would be a plain string and the directory
    # loop below would iterate over its characters.
    parser.add_argument('--path_ligand_only', nargs='+',
                        default=[f'/mnt/1_xyz_files/comp_{i}/ligand/' for i in range(1,20)],
                        help='directories searched for ligand-only *.xyz files')
    parser.add_argument('--path_PdCl2_cis_complexes', nargs='+',
                        default=[f'/mnt/1_xyz_files/comp_{i}/cis_pdcl2/' for i in range(1,20)],
                        help='directories searched for cis PdCl2 complex *.xyz files')
    parser.add_argument('--path_PdCl2_trans_complexes', nargs='+',
                        default=[f'/mnt/1_xyz_files/comp_{i}/trans_pdcl2/' for i in range(1,20)],
                        help='directories searched for trans PdCl2 complex *.xyz files')
    parser.add_argument('--path_PdCl2_SMe2_cis_complexes', nargs='+',
                        default=[f'/mnt/1_xyz_files/comp_{i}/cis_pdcl2_sme2/' for i in range(1,20)],
                        help='directories searched for cis PdCl2/SMe2 complex *.xyz files')
    parser.add_argument('--path_PdCl2_SMe2_trans_complexes', nargs='+',
                        default=[f'/mnt/1_xyz_files/comp_{i}/trans_pdcl2_sme2/' for i in range(1,20)],
                        help='directories searched for trans PdCl2/SMe2 complex *.xyz files')

    parser.add_argument('--path_pickle', default=r'/mnt/2_descriptors/dbstep_descr_raw.pkl',
                        help='output pickle holding a dict of descriptor DataFrames')

    args = parser.parse_args()

    start = time()

    # (directories, atom indices to evaluate, key in the output dict)
    jobs = [
        (args.path_PdCl2_SMe2_trans_complexes, [1,2,3,4,5,6,7,8,11,12,13], 'pdcl2_sme2_trans'),
        (args.path_PdCl2_SMe2_cis_complexes, [1,2,3,4,5,6,7,8,11,12,13], 'pdcl2_sme2_cis'),
        (args.path_PdCl2_trans_complexes, [1,2,3,4,5,6,7,8], 'pdcl2_trans'),
        (args.path_PdCl2_cis_complexes, [1,2,3,4,5,6,7,8], 'pdcl2_cis'),
        (args.path_ligand_only, [1,2,3,4,5], 'ligand'),
    ]

    paths_dfs = {}
    for path_lst, atom_idx, keyname in jobs:
        xyz_files = []
        for path in path_lst:
            xyz_files.extend(glob.glob(os.path.join(path, '*.xyz')))
        xyz_files.sort()

        paths_dfs[keyname] = _collect_descriptors(xyz_files, atom_idx)

        # Same calculation except the 2 Cl atoms [7,8].
        # NOTE(review): none of the key names in `jobs` contains 'pdcl2_o' or
        # 'pdcl2_p', so this branch never runs as written - confirm the
        # intended condition.
        if ('pdcl2_o' in keyname) or ('pdcl2_p' in keyname):
            paths_dfs[keyname.replace('pdcl2_', 'pd_')] = _collect_descriptors(
                xyz_files, atom_idx, exclude="7,8", label_swap=('pdcl2', 'pd'))

        # Same calculation except the CH3 atoms [12,13]; constant columns are
        # kept for this variant.
        if 'sme2' in keyname:
            paths_dfs[keyname.replace('sme2', 's')] = _collect_descriptors(
                xyz_files, atom_idx, exclude="12,13", label_swap=('sme2', 's'),
                drop_constant=False)

        # Rewritten after every category, so the pickle always holds the
        # tables finished so far.
        with open(args.path_pickle, "wb") as output_file:
            pickle.dump(paths_dfs, output_file)
        print(time() - start)