from rdkit import Chem
import csv
from rdkit.ML.Descriptors import MoleculeDescriptors
from rdkit.Chem import Descriptors
import numpy as np
from sklearn import svm,preprocessing

#Generate the standard list of descriptors for each molecule in the file, then modify to uniquely identify important functional groups for Hunter's rule matching
def Generate_descriptors_from_file(filename):
    """Read molecules from a CSV file and compute their modified descriptors.

    Each non-blank row must hold the molecule name in column 0 and its SMILES
    string in column 1; any further columns are ignored.

    Returns a tuple ``(descs, names, smiles)`` where ``descs`` is a numpy array
    with one row per molecule (the output of Modify_desc) and ``names`` /
    ``smiles`` are lists of strings in file order.

    Raises ValueError if a SMILES string cannot be parsed by RDKit.
    """
    desc_names = Get_desc_names_for_descriptor_calculator()
    calc = MoleculeDescriptors.MolecularDescriptorCalculator(desc_names)
    names = []
    smiles = []
    with open(filename, 'r') as infile:
        for row in csv.reader(infile):
            # csv.reader yields an empty list for blank lines (for example a
            # trailing newline); indexing it would raise IndexError.
            if not row:
                continue
            names.append(row[0])
            smiles.append(row[1])

    #Generate molecular descriptors for each molecule in the list
    descs = []
    for name, smi in zip(names, smiles):
        mol = Chem.MolFromSmiles(smi)
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with the offending entry instead of deep inside the SMARTS matching.
        if mol is None:
            raise ValueError(
                'Could not parse SMILES %r for molecule %r in %s'
                % (smi, name, filename))
        descs.append(Modify_desc(calc.CalcDescriptors(mol), mol))

    #Create an array of the descriptors
    return np.array(descs), names, smiles

def Generate_labels_from_file(label_file):
    """Load an integer label table from a CSV file and return its transpose.

    Every cell of every row is converted with int(); the rows are stacked into
    a numpy array which is then transposed, so file column j becomes row j of
    the returned array.
    """
    with open(label_file, 'r') as infile:
        rows = [[int(cell) for cell in row] for row in csv.reader(infile)]
    return np.array(rows).transpose()

def Get_desc_names_for_descriptor_calculator():
    """Return the RDKit descriptor names used by the descriptor calculator.

    Starts from every name in Descriptors._descList and removes the four
    partial-charge descriptors and 'Ipc'. list.remove raises ValueError if
    one of those names is not present.
    """
    excluded = (
        'MinPartialCharge',
        'MaxPartialCharge',
        'MinAbsPartialCharge',
        'MaxAbsPartialCharge',
        'Ipc',
    )
    names = [entry[0] for entry in Descriptors._descList]
    for unwanted in excluded:
        names.remove(unwanted)
    return names

def Get_desc_names_for_Hunter():
    """Return the calculator descriptor names followed by the four extra names.

    The extra names ('fr_Ar_F', 'fr_Ar_Cl', 'fr_Ar_X', 'fr_tert_amide') are
    listed in the same order in which Modify_desc appends the corresponding
    counts to a descriptor vector.
    """
    appended_names = ['fr_Ar_F', 'fr_Ar_Cl', 'fr_Ar_X', 'fr_tert_amide']
    return Get_desc_names_for_descriptor_calculator() + appended_names

#Load Hunter's Table from file
def GetHuntersTable(hunters_filename):
    """Read Hunter's table from a CSV file into a dictionary.

    The first row is skipped. Every later row maps its first cell to a
    two-item list holding the second and third cells, kept as strings.
    """
    with open(hunters_filename, 'r') as infile:
        rows = csv.reader(infile)
        # Discard the first row; the default avoids StopIteration on an empty file
        next(rows, None)
        return {row[0]: [row[1], row[2]] for row in rows}

#Uniquely identify functional groups
def numtertamine(mol):
    """Count substructure matches for a tertiary amine in mol.

    The SMARTS query is a three-connected nitrogen with no hydrogens and three
    carbon neighbours, excluding nitrogens attached to C=O or to an aromatic
    carbon.
    """
    tert_amine_query = Chem.MolFromSmarts('[NX3;H0;!$(NC=O);!$(Nc)]([#6])([#6])([#6])')
    return len(mol.GetSubstructMatches(tert_amine_query))

def numsecondaryamine(mol):
    """Count substructure matches for a secondary amine in mol.

    The SMARTS query is a three-connected nitrogen with one hydrogen and two
    carbon neighbours, excluding nitrogens attached to C=O or to an aromatic
    carbon.
    """
    sec_amine_query = Chem.MolFromSmarts('[NX3;H1;!$(NC=O);!$(Nc)]([#6])([#6])')
    return len(mol.GetSubstructMatches(sec_amine_query))

def numprimaryamine(mol):
    """Count substructure matches for a primary amine in mol.

    The SMARTS query is a three-connected nitrogen with two hydrogens and one
    carbon neighbour, excluding nitrogens attached to C=O or to an aromatic
    carbon.
    """
    prim_amine_query = Chem.MolFromSmarts('[NX3;H2;!$(NC=O);!$(Nc)]([#6])')
    return len(mol.GetSubstructMatches(prim_amine_query))

def numamide(mol):
    """Count substructure matches for a primary or secondary amide in mol.

    The SMARTS query is a carbonyl carbon bearing a nitrogen with one or two
    hydrogens and a second substituent that is not nitrogen (so not a urea).
    """
    amide_query = Chem.MolFromSmarts('O=C([#7;H1,H2])([!#7])')
    return len(mol.GetSubstructMatches(amide_query))

def numtertamide(mol):
    """Count substructure matches for a tertiary amide in mol.

    The SMARTS query is a carbonyl carbon bearing a nitrogen with no hydrogens
    and a second substituent that is not nitrogen.
    """
    tert_amide_query = Chem.MolFromSmarts('O=C([#7;H0])([!#7])')
    return len(mol.GetSubstructMatches(tert_amide_query))

def numarylfluoride(mol):
    """Count substructure matches for fluorine on a three-connected aromatic carbon."""
    aryl_fluoride_query = Chem.MolFromSmarts('[cX3][F]')
    return len(mol.GetSubstructMatches(aryl_fluoride_query))

def numarylchloride(mol):
    """Count substructure matches for chlorine on a three-connected aromatic carbon."""
    aryl_chloride_query = Chem.MolFromSmarts('[cX3][Cl]')
    return len(mol.GetSubstructMatches(aryl_chloride_query))

def numether(mol):
    """Count substructure matches for an ether oxygen in mol.

    The SMARTS query is a two-connected oxygen between two carbons, excluding
    oxygens where one of those carbons is double-bonded to an oxygen.
    """
    ether_query = Chem.MolFromSmarts('[$([OD2]([#6])[#6]);!$([OD2]([#6]=O)[#6])]')
    return len(mol.GetSubstructMatches(ether_query))

def numarylhalide(mol):
    """Count substructure matches for bromine or iodine on a three-connected aromatic carbon."""
    aryl_br_i_query = Chem.MolFromSmarts('[cX3][Br,I]')
    return len(mol.GetSubstructMatches(aryl_br_i_query))


#Modify descriptor list to uniquely identify functional groups
def Modify_desc(desc, mol):
    """Return a copy of desc with selected fragment counts replaced and extended.

    desc is the descriptor sequence for mol, ordered as
    Get_desc_names_for_descriptor_calculator(). The 'fr_NH0', 'fr_NH1',
    'fr_NH2', 'fr_amide' and 'fr_ether' entries are overwritten with the counts
    from this file's SMARTS matchers, and four further counts (aryl fluoride,
    aryl chloride, aryl bromide/iodide, tertiary amide) are appended in that
    order. The result is a new numpy array.
    """
    names = Get_desc_names_for_descriptor_calculator()
    modified = np.array(desc)

    # Descriptor name -> matcher whose count replaces the RDKit value
    replacements = (
        ('fr_NH0', numtertamine),
        ('fr_NH1', numsecondaryamine),
        ('fr_NH2', numprimaryamine),
        ('fr_amide', numamide),
        ('fr_ether', numether),
    )
    for column_name, count_matches in replacements:
        modified[names.index(column_name)] = count_matches(mol)

    # Extra columns; order matches the names added by Get_desc_names_for_Hunter
    extra_counts = [
        numarylfluoride(mol),
        numarylchloride(mol),
        numarylhalide(mol),
        numtertamide(mol),
    ]
    return np.append(modified, extra_counts)

# Compare best donor/acceptor relationships
def CalcHunterPairs(acceptor_list, donor_list):
    """Sum the products of rank-matched acceptor and donor values.

    Both sequences are sorted in descending order (the inputs are not
    modified); the largest acceptor is multiplied by the largest donor, the
    second largest by the second largest, and so on. Values left over in the
    longer sequence are ignored. Returns 0 when either sequence is empty.
    """
    ranked_acceptors = sorted(acceptor_list, reverse=True)
    ranked_donors = sorted(donor_list, reverse=True)
    # zip stops at the shorter list, so surplus sites stay unpaired
    return sum(acceptor * donor
               for acceptor, donor in zip(ranked_acceptors, ranked_donors))

#Manually input donor and acceptor values for those coformers where matching is difficult.
#Maps coformer name -> (donor values, acceptor values).
_MANUAL_COFORMER_HUNTER_VALUES = {
    #assume caffeine has two tertiary amides and a pyridine-like acceptor (imidazole), no donors
    'Caffeine': ((), (8.3, 8.3, 7.0)),
    #assume theophylline has two tertiary amides and a pyridine-like acceptor (imidazole), imidazole donor
    'Theophylline': ((3.7,), (8.3, 8.3, 7.0)),
    #assume theobromine has one tertiary amide, one secondary amide and a pyridine-like acceptor (imidazole), secondary amide donor
    'Theobromine': ((2.9,), (8.3, 8.3, 7.0)),
    #assume riboflavin has 4 alcohols, one primary amide, one tertiary amide, one tertiary amine and a pyridine acceptor
    'Riboflavin': ((2.7, 2.7, 2.7, 2.7, 2.9),
                   (5.8, 5.8, 5.8, 5.8, 8.3, 8.3, 7.0, 7.8)),
    #assume pyrazine has two pyridine-like acceptors with half the strength
    'Pyrazine': ((), (3.5, 3.5)),
    #assume phenazine has two pyridine-like acceptors with half the strength
    'Phenazine': ((), (3.5, 3.5)),
    #assume melamine has three anilines and three pyridine-like acceptors with the same strength (balance of other Ns withdrawing and anilines donating)
    # NOTE(review): only two 2.1 donor and two 5.3 acceptor values are listed
    # although the assumption above mentions three anilines - confirm.
    'Melamine': ((2.1, 2.1), (7.0, 7.0, 7.0, 5.3, 5.3)),
}


def _hunter_donors_acceptors(desc, desc_names, Hunters_Table):
    """Expand the functional-group counts in desc into donor and acceptor value lists.

    For every key of Hunters_Table the count stored in desc at that key's
    position in desc_names is read; the key's donor and acceptor values are
    each added that many times, but only when the value exceeds 0.001.
    """
    donors = []
    acceptors = []
    for key in Hunters_Table:
        repeats = int(desc[desc_names.index(key)])
        if repeats <= 0:
            continue
        donor_value = float(Hunters_Table[key][0])
        acceptor_value = float(Hunters_Table[key][1])
        if donor_value > 0.001:
            donors.extend([donor_value] * repeats)
        if acceptor_value > 0.001:
            acceptors.extend([acceptor_value] * repeats)
    return donors, acceptors


def Generate_hunters_descs(API_descs, coformer_descs, API_names, coformer_names, hunters_filename):
    """Build the combined descriptor vector for every API/coformer pair.

    For each pair the vector is the API descriptors, followed by the coformer
    descriptors, followed by one Hunter's-rule value: the paired
    donor/acceptor score between the two molecules minus the sum of each
    molecule's score with itself (all scores from CalcHunterPairs).

    Donor and acceptor values come from the table in hunters_filename, except
    for coformers named in _MANUAL_COFORMER_HUNTER_VALUES, which use the
    hand-assigned values. API_names is accepted for interface compatibility
    but is not used.

    Returns a numpy array indexed as [API index, coformer index, feature].
    """
    Hunters_Table = GetHuntersTable(hunters_filename)
    desc_names = Get_desc_names_for_Hunter()

    # The coformer donor/acceptor lists and self-pairing score do not depend
    # on the API, so compute them once rather than once per API.
    coformer_sites = []
    for index2, coformer_desc in enumerate(coformer_descs):
        manual_values = _MANUAL_COFORMER_HUNTER_VALUES.get(coformer_names[index2])
        if manual_values is not None:
            coformer_donor, coformer_acceptor = manual_values
        else:
            coformer_donor, coformer_acceptor = _hunter_donors_acceptors(
                coformer_desc, desc_names, Hunters_Table)
        coformer_self = CalcHunterPairs(coformer_acceptor, coformer_donor)
        coformer_sites.append((coformer_desc, coformer_donor, coformer_acceptor, coformer_self))

    #Cycle through the coformers for each API and combine the coformer descriptor with the API descriptor, and calculate the value from Hunter's rules, append this to the end.
    cocrystal_descrs = []
    for API_desc in API_descs:
        API_donor, API_acceptor = _hunter_donors_acceptors(API_desc, desc_names, Hunters_Table)
        API_self = CalcHunterPairs(API_acceptor, API_donor)

        column_descr = []
        for coformer_desc, coformer_donor, coformer_acceptor, coformer_self in coformer_sites:
            # Compare best donor/acceptor relationships, self vs cocrystallisation:
            # subtract the self-pairing scores from the cross-pairing scores.
            intra = coformer_self + API_self
            inter = (CalcHunterPairs(coformer_acceptor, API_donor)
                     + CalcHunterPairs(API_acceptor, coformer_donor))
            new_hunters_desc = inter - intra

            combined_descr = np.append(API_desc, coformer_desc)
            combined_descr = np.append(combined_descr, new_hunters_desc)
            column_descr.append(combined_descr)
        cocrystal_descrs.append(column_descr)

    return np.array(cocrystal_descrs)

def Run_SVM_on_external_test(cocrystal_descrs, test_descrs, original_cocrystal_labels):
    """Train an SVM on the labelled pairs and predict the external test set.

    cocrystal_descrs is indexed as [row, column] to obtain a descriptor vector
    and original_cocrystal_labels holds the label for the same [row, column].
    Pairs labelled 2 (outcome unknown) are left out of the training set.
    Descriptors are standardised with a scaler fitted on the training data
    only, and the same scaler is applied to test_descrs.

    Returns ``(SVM_predictions, SVM_probs)``: the predicted labels for
    test_descrs and, for each test row, column 1 of predict_proba.
    """
    train_descrs = []
    train_labels = []

    #Add the descriptors to the training set if the outcome of the experiment is known (not 2)
    for row_index, label_row in enumerate(original_cocrystal_labels):
        for column_index, label in enumerate(label_row):
            if label != 2:
                train_descrs.append(cocrystal_descrs[row_index, column_index])
                train_labels.append(label)

    #Scale the descriptors before passing to the algorithm
    scaler = preprocessing.StandardScaler().fit(train_descrs)
    train_scaled = scaler.transform(train_descrs)
    test_scaled = scaler.transform(test_descrs)

    #Create and train the algorithm
    SVM_classifier = svm.SVC(gamma=0.001, C=10, probability=True, class_weight='balanced')
    SVM_classifier = SVM_classifier.fit(train_scaled, train_labels)

    #Predict on the external test set, obtain list of predictions and a list of probabilities
    SVM_predictions = SVM_classifier.predict(test_scaled)
    class_probabilities = SVM_classifier.predict_proba(test_scaled)

    #The second part of the probability is the probability of success
    # NOTE(review): presumably the training labels are 0 (failure) and
    # 1 (success), so that column 1 corresponds to class 1 - confirm.
    SVM_probs = [row[1] for row in class_probabilities]

    return SVM_predictions, SVM_probs


hunters_filename = 'Hunters Table improved.csv'

#Files to be used for training defined below
API_smiles_file = 'Acid and Amide Smiles protonated.csv'
coformer_smiles_file = 'Coformer Set 1 and 2 Smiles.csv'
main_training_file = 'Main Data final.csv'

#Build the training descriptors for every API/coformer pair and load the matching labels
API_descs, API_names, API_smiles = Generate_descriptors_from_file(API_smiles_file)
coformer_descs, coformer_names, coformer_smiles = Generate_descriptors_from_file(coformer_smiles_file)
cocrystal_descrs = Generate_hunters_descs(API_descs, coformer_descs, API_names, coformer_names, hunters_filename)
train_labels = Generate_labels_from_file(main_training_file)


#Files to be used for test defined below
test_API_smiles_file = 'Paracetamol Smiles protonated.csv'
test_coformer_smiles_file = 'Paracetamol Coformer Smiles.csv'

test_API_descs, test_API_names, test_API_smiles = Generate_descriptors_from_file(test_API_smiles_file)
test_coformer_descs, test_coformer_names, test_coformer_smiles = Generate_descriptors_from_file(test_coformer_smiles_file)
test_cocrystal_descrs = Generate_hunters_descs(test_API_descs, test_coformer_descs, test_API_names, test_coformer_names, hunters_filename)

#For each column of test APIs, add the descriptors and the 'API_coformer' pair names to flattened lists in the format for prediction using the SVM model
test_descrs_flattened = []
test_cocrystal_names_flattened = []
for x, column in enumerate(test_cocrystal_descrs):
    for y, desc in enumerate(column):
        test_descrs_flattened.append(desc)
        cocrystal_name = '{0}_{1}'.format(test_API_names[x], test_coformer_names[y])
        test_cocrystal_names_flattened.append(cocrystal_name)

test_descrs_flattened=np.array(test_descrs_flattened)

predictions, probs = Run_SVM_on_external_test(cocrystal_descrs, test_descrs_flattened, train_labels)

#Print each pair name with its predicted probability, separated by a space.
#A single pre-formatted string gives the same output under Python 2 and Python 3.
for cocrystal_name, prob in zip(test_cocrystal_names_flattened, probs):
    print('%s %s' % (cocrystal_name, prob))
