import os
import subprocess
import tempfile

import mdtraj as md
import numpy as np
import pandas as pd
from openpyxl import Workbook

from apbs_est import SimpleElectrostatics


def parse_pdbqt_for_flexibility(pdbqt_file):
    """Extract the flexibility index (torsions per heavy atom) from a PDBQT file.

    The index is the ``TORSDOF`` value divided by the number of non-hydrogen
    ``ATOM``/``HETATM`` records. Only the first model is read: a docking output
    that stores several poses repeats the same ligand once per ``MODEL``, and
    counting every copy would shrink the index by the number of poses.

    Args:
        pdbqt_file: Path to the PDBQT file to parse.

    Returns:
        ``torsions / heavy_atoms`` as a float, or ``None`` when the file
        contains no heavy atoms.

    Raises:
        ValueError: If no ``TORSDOF`` record is found (in the first model for
            multi-model files).
    """
    # Atom-type labels treated as hydrogens (read from columns 78-79).
    hydrogen_types = {"H", "HD", "HS"}
    torsions = None
    heavy_atoms = 0

    with open(pdbqt_file, "r") as handle:
        for line in handle:
            if line.startswith("ENDMDL"):
                # End of the first pose; later models are copies of the ligand.
                break
            if line.startswith("TORSDOF"):
                torsions = int(line.split()[1])
            elif line.startswith(("ATOM", "HETATM")):
                if line[77:79].strip() not in hydrogen_types:
                    heavy_atoms += 1

    if torsions is None:
        raise ValueError(f"TORSDOF not found in {pdbqt_file}")
    return torsions / heavy_atoms if heavy_atoms > 0 else None


def _read_smina_hbond_term(log_path, ligand_file):
    """Return the 6th field of the line following the term-values header, or None."""
    with open(log_path, "r") as log:
        for line in log:
            if "Term values, before weighting:" not in line:
                continue
            # Only the first line after the header is inspected.
            parts = next(log, "").strip().split()
            if len(parts) >= 6:
                try:
                    # NOTE(review): index 5 is assumed to be the H-bond term;
                    # confirm against the term header smina prints for this
                    # scoring function.
                    return float(parts[5])
                except ValueError:
                    print(f"Skipping invalid H-bond value in {ligand_file}: {parts[5]}")
            return None
    return None


def _read_nonzero_atom_terms(atom_terms_path):
    """Return (atom_id, atom_name, value) for rows whose last column is a nonzero float."""
    nonzero_atoms = []
    with open(atom_terms_path, "r") as atom_file:
        for line in atom_file:
            parts = line.split()
            if len(parts) < 6:
                continue
            try:
                # NOTE(review): the last column is assumed to be the H-bond term.
                contribution = float(parts[-1])
            except ValueError:
                continue  # Header or malformed row.
            if contribution != 0:
                nonzero_atoms.append((parts[0], parts[1], contribution))
    return nonzero_atoms


def calculate_hbonds(protein_file, ligand_file):
    """
    Run Smina in score-only mode and extract the hydrogen bond contribution.

    Smina's log and per-atom term files are written to a private temporary
    directory that is removed afterwards, so repeated or concurrent calls
    cannot read each other's (or stale) scratch files. Atoms with a nonzero
    contribution are saved to ``hbonds_atoms.txt`` next to the ligand file.

    Args:
        protein_file: Path of the receptor passed to ``smina -r``.
        ligand_file: Path of the ligand passed to ``smina -l``.

    Returns:
        - Total hydrogen bond contribution (float) or None if failed.
    """
    try:
        with tempfile.TemporaryDirectory() as scratch_dir:
            log_file = os.path.join(scratch_dir, "score.log")
            hbond_terms_file = os.path.join(scratch_dir, "hbonds.txt")

            command = [
                "smina", "-r", protein_file, "-l", ligand_file,
                "--score_only", "--log", log_file, "--atom_terms", hbond_terms_file
            ]
            result = subprocess.run(
                command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
            )
            if result.returncode != 0:
                print(f"Error running Smina for {ligand_file}: {result.stderr}")
                return None

            # Parse both files before the scratch directory is deleted.
            hbond_value = _read_smina_hbond_term(log_file, ligand_file)
            nonzero_atoms = _read_nonzero_atom_terms(hbond_terms_file)

        # Save nonzero hydrogen bond atoms next to the ligand file.
        output_file = os.path.join(os.path.dirname(ligand_file), "hbonds_atoms.txt")
        with open(output_file, "w") as output:
            output.write("Atom_ID Atom_Name Hbond_Contribution\n")
            for atom_id, atom_name, contribution in nonzero_atoms:
                output.write(f"{atom_id} {atom_name} {contribution:.6f}\n")

        return hbond_value

    except Exception as e:
        # Best-effort boundary: any failure (missing smina binary, unreadable
        # output, ...) is reported and mapped to None for the caller.
        print(f"Error processing {ligand_file}: {e}")
        return None


def calculate_buried_sasa(protein_path, ligand_path):
    """Return the solvent-accessible surface area buried on complex formation.

    The protein and ligand structures are loaded with mdtraj, merged into a
    single trajectory, and the Shrake-Rupley SASA of the merged system is
    subtracted from the summed SASA of the two isolated structures.

    Args:
        protein_path: Path of the protein structure file.
        ligand_path: Path of the ligand structure file.

    Returns:
        ``SASA(protein) + SASA(ligand) - SASA(complex)``, summed over all
        atoms and frames (presumably in nm^2, mdtraj's area unit).
    """
    receptor = md.load(protein_path)
    small_molecule = md.load(ligand_path)

    # Merge along the atom axis; both inputs must have the same frame count.
    complex_traj = md.Trajectory(
        xyz=np.concatenate((receptor.xyz, small_molecule.xyz), axis=1),
        topology=receptor.topology.join(small_molecule.topology),
    )

    receptor_area = md.shrake_rupley(receptor)
    ligand_area = md.shrake_rupley(small_molecule)
    complex_area = md.shrake_rupley(complex_traj)

    isolated_total = receptor_area.sum() + ligand_area.sum()
    return isolated_total - complex_area.sum()


def calculate_interaction_metrics(protein_file, ligand_file, distance=0.5):
    """Summarise contacts between hydrophobic protein residues and the ligand.

    Every atom of a hydrophobic protein residue is paired with every
    non-protein, non-water atom of the merged structure, and pairs whose
    distance in the first frame is at most ``distance`` are kept.

    Args:
        protein_file: Path of the protein structure file.
        ligand_file: Path of the ligand structure file.
        distance: Contact cutoff in mdtraj distance units (nanometres).

    Returns:
        A tuple ``(num_residues, avg_distance, atom_count)``:
        the number of distinct hydrophobic residues with at least one contact,
        the mean distance of the contacting pairs (``None`` if there are
        none), and the number of contacting atom pairs. When either atom
        selection is empty the result is ``(0, None, 0)``.
    """
    protein = md.load(protein_file)
    ligand = md.load(ligand_file)

    combined_xyz = np.concatenate((protein.xyz, ligand.xyz), axis=1)
    combined_top = protein.topology.join(ligand.topology)
    combined = md.Trajectory(xyz=combined_xyz, topology=combined_top)

    hydrophobic_residues = {'ALA', 'VAL', 'LEU', 'ILE', 'MET', 'PHE', 'TRP', 'PRO', 'TYR'}

    hydrophobic_idxs = [a.index for a in combined.topology.atoms
                        if a.residue.name in hydrophobic_residues and a.residue.is_protein]
    # NOTE(review): any non-protein, non-water atom counts as ligand, including
    # hetero atoms that come from the protein file - confirm this is intended.
    ligand_idxs = [a.index for a in combined.topology.atoms
                   if not a.residue.is_protein and a.residue.name != 'HOH']

    # Without atoms on both sides there are no pairs to measure; an empty pair
    # array would also break the 2-D indexing below.
    if not hydrophobic_idxs or not ligand_idxs:
        return 0, None, 0

    # All (hydrophobic, ligand) index pairs, hydrophobic index varying slowest.
    pairs = np.column_stack((
        np.repeat(hydrophobic_idxs, len(ligand_idxs)),
        np.tile(ligand_idxs, len(hydrophobic_idxs)),
    ))
    distances = md.compute_distances(combined, pairs)[0]  # first frame only
    within_cutoff = distances <= distance
    close_pairs = pairs[within_cutoff]
    close_distances = distances[within_cutoff]

    interacting_residues = set()
    for idx in close_pairs[:, 0]:
        residue = combined.topology.atom(int(idx)).residue
        interacting_residues.add((residue.name, residue.resSeq, residue.chain.index))

    num_residues = len(interacting_residues)
    avg_distance = float(np.mean(close_distances)) if len(close_distances) > 0 else None
    atom_count = len(close_pairs)

    return num_residues, avg_distance, atom_count

def _safe_diff(second, first):
    """Return ``second - first``, or None when either value is missing."""
    if second is None or first is None:
        return None
    return second - first


def perform_full_analysis(folder1, folder2, output_excel):
    """Perform full analysis, including hydrogen bond metrics, and compare results.

    Walks ``folder1`` for sub-directories containing ``output.pdbqt`` and
    ``protein.pdb``, looks up the sub-directory with the same relative path in
    ``folder2``, computes the metrics for both, and appends one row per pair
    (metrics of each folder followed by folder2 - folder1 differences) to an
    Excel sheet.

    A metric that could not be computed (``None``) leaves its own cell and the
    matching difference cell empty instead of discarding the whole row.

    Args:
        folder1: Root directory of the first result set.
        folder2: Root directory of the second result set.
        output_excel: Path of the ``.xlsx`` file to write.
    """
    required_files = ("output.pdbqt", "protein.pdb")

    wb = Workbook()
    ws = wb.active
    ws.title = "Full Analysis"
    ws.append([
        "File (Folder1)", "Flexibility Index (Folder1)", "Buried Area (Folder1)",
        "Hydrophobic Residues (Folder1)", "Avg Interaction Distance (Folder1)", "Atom Count (Folder1)",
        "Electrostatic Potential (Folder1)", "Hydrogen Bonds (Folder1)",
        "File (Folder2)", "Flexibility Index (Folder2)", "Buried Area (Folder2)",
        "Hydrophobic Residues (Folder2)", "Avg Interaction Distance (Folder2)", "Atom Count (Folder2)",
        "Electrostatic Potential (Folder2)", "Hydrogen Bonds (Folder2)",
        "Hydrogen Bonds Difference", "Electrostatic Potential Difference", "Hydrophobic Residues Difference",
        "Atom Count Difference", "Avg Distance Difference", "Buried Area Difference", "Flexibility Index Difference"
    ])

    for subdir, _, filenames in os.walk(folder1):
        if any(name not in filenames for name in required_files):
            continue

        relative_path = os.path.relpath(subdir, folder1)
        corresponding_subdir = os.path.join(folder2, relative_path)
        # isdir (not exists): a regular file at this path cannot be listed.
        if not os.path.isdir(corresponding_subdir):
            continue

        filenames2 = os.listdir(corresponding_subdir)
        if any(name not in filenames2 for name in required_files):
            continue

        # The ligand is read from output.pdb, which sits next to output.pdbqt;
        # if it is missing the loaders raise and the pair is reported below.
        ligand_file1 = os.path.join(subdir, "output.pdb")
        protein_file1 = os.path.join(subdir, "protein.pdb")
        pdbqt_file1 = os.path.join(subdir, "output.pdbqt")
        ligand_file2 = os.path.join(corresponding_subdir, "output.pdb")
        protein_file2 = os.path.join(corresponding_subdir, "protein.pdb")
        pdbqt_file2 = os.path.join(corresponding_subdir, "output.pdbqt")

        try:
            flex1 = parse_pdbqt_for_flexibility(pdbqt_file1)
            flex2 = parse_pdbqt_for_flexibility(pdbqt_file2)
            bsa1 = calculate_buried_sasa(protein_file1, ligand_file1)
            bsa2 = calculate_buried_sasa(protein_file2, ligand_file2)
            hbonds1 = calculate_hbonds(protein_file1, ligand_file1)
            hbonds2 = calculate_hbonds(protein_file2, ligand_file2)
            residues1, dist1, atoms1 = calculate_interaction_metrics(protein_file1, ligand_file1)
            residues2, dist2, atoms2 = calculate_interaction_metrics(protein_file2, ligand_file2)
            apbs1 = SimpleElectrostatics(protein_pdb=protein_file1, ligand_pdbqt=pdbqt_file1).calculate_result()
            apbs2 = SimpleElectrostatics(protein_pdb=protein_file2, ligand_pdbqt=pdbqt_file2).calculate_result()

            ws.append([
                pdbqt_file1, flex1, bsa1, residues1, dist1, atoms1, apbs1, hbonds1,
                pdbqt_file2, flex2, bsa2, residues2, dist2, atoms2, apbs2, hbonds2,
                _safe_diff(hbonds2, hbonds1),
                _safe_diff(apbs2, apbs1),
                _safe_diff(residues2, residues1),
                _safe_diff(atoms2, atoms1),
                _safe_diff(dist2, dist1),
                _safe_diff(bsa2, bsa1),
                _safe_diff(flex2, flex1),
            ])
        except Exception as e:
            # Best-effort batch run: report the failing pair and keep going.
            print(f"Error processing {subdir}: {e}")

    wb.save(output_excel)
    print(f"Analysis saved to {output_excel}")

# Example usage: compare the two docking result trees and write the report.
folder1 = "/home/kbaitsi/Ligands2_Vis_Ligands/"
folder2 = "/home/kbaitsi/Ligands2_Vis_Ligands2/"
output_excel_file = "/home/kbaitsi/full_analysis.xlsx"

# Run the analysis only when executed as a script, so importing this module
# for its functions does not start a full run against the hard-coded paths.
if __name__ == "__main__":
    perform_full_analysis(folder1, folder2, output_excel_file)


