import numpy as np
from openbabel import openbabel as ob
import subprocess
from gridData import Grid

class SimpleElectrostatics:
    """Electrostatic interaction energy of a ligand in a protein's APBS potential.

    Constructing an instance runs the whole pipeline immediately:

    1. delete hydrogens from ``protein_pdb`` with ``obabel -d``;
    2. convert the result to PQR with ``pdb2pqr`` (AMBER force field,
       PROPKA titration at pH 7.0, waters dropped);
    3. solve the linearised Poisson-Boltzmann equation (``lpbe``) with
       ``apbs``, writing the potential to ``map.dx``;
    4. load that map with gridData;
    5. read the ligand's atom coordinates and partial charges from
       ``ligand_pdbqt`` with OpenBabel.

    Every intermediate file (``protein_noH.pdb``, ``protein.pqr``,
    ``apbs.in``, ``map.dx``) uses a fixed name in the current working
    directory, so two instances running in the same directory overwrite
    each other's files. A failing external tool raises
    ``subprocess.CalledProcessError`` (all calls use ``check=True``).

    Call :meth:`calculate_result` afterwards for the energy in kJ/mol.
    """

    # Temperature (K) written into the APBS input. APBS reports the potential
    # in units of kT/e at this temperature, so calculate_result() must use the
    # same value when converting to kJ/mol.
    TEMPERATURE_K = 310
    # Molar gas constant R in kJ/(mol*K); R*T is one kT expressed per mole.
    GAS_CONSTANT_KJ_PER_MOL_K = 8.314462618e-3

    def __init__(self, protein_pdb, ligand_pdbqt):
        """Run the full pipeline for one protein/ligand pair.

        Args:
            protein_pdb: Path to the protein PDB file.
            ligand_pdbqt: Path to the ligand PDBQT file; atom coordinates and
                partial charges are read from it.
        """
        self.original_protein_pdb = protein_pdb
        self.cleaned_protein_pdb = "protein_noH.pdb"
        self.ligand_pdbqt = ligand_pdbqt

        self.remove_hydrogens()
        self.run_pdb2pqr()
        self.create_apbs_input()
        self.run_apbs()
        self.read_map()
        self.read_coords_from_pdbqt()
        self.read_charges()

    def remove_hydrogens(self):
        """Write ``protein_noH.pdb``: the input protein with hydrogens deleted.

        Presumably done so pdb2pqr rebuilds a consistent hydrogen set.
        """
        subprocess.run([
            "obabel",
            self.original_protein_pdb,
            "-O", self.cleaned_protein_pdb,
            "-d"  # delete hydrogens
        ], check=True)

    def run_pdb2pqr(self):
        """Convert ``protein_noH.pdb`` to ``protein.pqr`` with pdb2pqr.

        Uses the AMBER force field and PROPKA titration states at pH 7.0,
        and drops waters.
        """
        subprocess.run([
            "pdb2pqr",
            "--ff=AMBER",
            "--with-ph=7.0",
            "--titration-state-method=propka",
            "--drop-water",
            self.cleaned_protein_pdb,
            "protein.pqr"
        ], check=True)

    def create_apbs_input(self):
        """Write ``apbs.in``: one mg-auto LPBE run on ``protein.pqr``.

        The fine grid is an 80 A cube centred on the molecule with 65 points
        per axis (1.25 A spacing). The coarse grid is 100 A. The potential is
        written as OpenDX to ``map.dx``. The temperature comes from
        ``TEMPERATURE_K`` so that it matches the unit conversion in
        :meth:`calculate_result`.
        """
        with open("apbs.in", "w") as f:
            f.write(f"""\
read
    mol pqr protein.pqr
end

elec
    mg-auto
    dime 65 65 65
    fglen 80.0 80.0 80.0
    cglen 100.0 100.0 100.0
    fgcent mol 1
    cgcent mol 1
    mol 1
    lpbe
    bcfl sdh
    pdie 2.0
    sdie 78.0
    chgm spl2
    srfm smol
    srad 1.4
    swin 0.3
    sdens 10.0
    temp {self.TEMPERATURE_K}
    calcenergy no
    calcforce no
    write pot dx map
end
""")

    def run_apbs(self):
        """Run APBS on ``apbs.in``; with that input it writes ``map.dx``."""
        subprocess.run(["apbs", "apbs.in"], check=True)

    def read_map(self):
        """Load the APBS potential map (units of kT/e) into ``self.grid``."""
        self.grid = Grid("map.dx")

    def _load_ligand(self):
        """Parse ``self.ligand_pdbqt`` with OpenBabel and return the OBMol.

        Raises:
            ValueError: if the PDBQT format is unavailable, the file cannot be
                read, or it contains no atoms. Without this check, a bad path
                silently produced an empty ligand and a score of 0.
        """
        conv = ob.OBConversion()
        if not conv.SetInFormat("pdbqt"):
            raise ValueError("OpenBabel has no 'pdbqt' input format available")
        mol = ob.OBMol()
        if not conv.ReadFile(mol, self.ligand_pdbqt) or mol.NumAtoms() == 0:
            raise ValueError(
                f"Could not read any atoms from ligand file {self.ligand_pdbqt!r}"
            )
        return mol

    def read_coords_from_pdbqt(self):
        """Fill ``ligand_{x,y,z}_coords`` with per-atom ligand coordinates.

        The lists are in OpenBabel atom order, which is the same order
        :meth:`read_charges` uses.

        Raises:
            ValueError: if the ligand file cannot be read (see ``_load_ligand``).
        """
        mol = self._load_ligand()
        self.ligand_x_coords = []
        self.ligand_y_coords = []
        self.ligand_z_coords = []
        for atom in ob.OBMolAtomIter(mol):
            self.ligand_x_coords.append(atom.GetX())
            self.ligand_y_coords.append(atom.GetY())
            self.ligand_z_coords.append(atom.GetZ())

    def read_charges(self):
        """Fill ``self.charges`` with per-atom ligand partial charges.

        The list is in the same OpenBabel atom order as the coordinates.

        Raises:
            ValueError: if the ligand file cannot be read (see ``_load_ligand``).
        """
        mol = self._load_ligand()
        # NOTE(review): confirm OpenBabel's PDBQT reader keeps the charges
        # stored in the file rather than re-perceiving (e.g. Gasteiger) them.
        self.charges = [atom.GetPartialCharge() for atom in ob.OBMolAtomIter(mol)]

    def calculate_result(self):
        """Return the ligand's electrostatic energy in the protein potential.

        The energy is ``sum_i q_i * phi(r_i)``. Here ``phi`` is the APBS
        potential trilinearly/spline-interpolated at each ligand atom, in
        kT/e, so the sum is in kT. It is converted to kJ/mol by multiplying
        by R*T at ``TEMPERATURE_K``. (The previous factor, 96.485, is the
        Faraday constant. That factor would only be correct for a potential
        in volts.)

        NOTE(review): the fine grid is centred on the protein (``fgcent mol
        1``, 80 A wide). Atoms outside it receive gridData's out-of-box value,
        so confirm that the ligand lies inside the grid.

        Returns:
            float: interaction energy in kJ/mol (0.0 for an empty ligand).

        Raises:
            ValueError: if the number of interpolated potentials differs from
                the number of charges.
        """
        potentials = np.asarray(
            self.grid.interpolated(
                self.ligand_x_coords,
                self.ligand_y_coords,
                self.ligand_z_coords
            ),
            dtype=float,
        ).ravel()
        charges = np.asarray(self.charges, dtype=float).ravel()
        if potentials.shape != charges.shape:
            raise ValueError(
                f"Got {potentials.size} potentials for {charges.size} charges"
            )
        energy_kt = float(np.dot(potentials, charges))
        kt_in_kj_per_mol = self.GAS_CONSTANT_KJ_PER_MOL_K * self.TEMPERATURE_K
        return energy_kt * kt_in_kj_per_mol
