# Post-processing viewer for saved search runs: loads the pickled results of
# numbered runs ("1outp_path.pk", "1outp_graph.pk", "1outp_G.pk",
# "1outp_node.pk", ...), merges them (Viewall) or promotes one of them
# (Viewlast) to the unprefixed "outp_*.pk" files, and plots the graph, the
# property profile along the path and the path molecules.

# NOTE(review): `viewlast` is not referenced anywhere in this module;
# possibly read by an importer — confirm before removing.
viewlast=True
import networkx as nx
import matplotlib.pyplot as plt
import pickle
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
# Presumably turns off SVG output so that MolsToGridImage (called in drawpath
# with returnPNG=False) yields a raster image PIL can annotate — confirm.
IPythonConsole.ipython_useSVG = False
from PIL import Image, ImageDraw, ImageFont
import os
import json
import shutil
# Module-wide switch checked before every plt.show(); Viewlast and Viewall
# reset it to True.
showPLOTS=True

def count_files(file_extension="path.pk", directory="."):
    """Count the numbered result files in ``directory``.

    A directory entry is counted when it is a regular file whose name ends
    with ``file_extension`` and starts with a digit (e.g. "1outp_path.pk").
    Sub-directories are ignored even if their names match.

    Parameters:
        file_extension: required filename suffix (default "path.pk").
        directory: directory to scan (default: the current directory).

    Returns:
        The number of matching files.
    """
    return sum(
        1
        for name in os.listdir(directory)
        if name.endswith(file_extension)
        and name[:1].isdigit()
        and os.path.isfile(os.path.join(directory, name))
    )

def add_edges_from_dict(graph, edges_dict, color_edges):
    """Copy a ``{source: {target: weight}}`` adjacency map into ``graph``.

    For every (source, target) pair the endpoints are added as nodes first
    (source before target) if they are missing.  A new edge receives the
    attributes ``weight`` and ``color=color_edges``.  An edge that already
    exists is left untouched, and its endpoints are printed instead.
    Sources with no neighbours contribute nothing.
    """
    for src, targets in edges_dict.items():
        for dst, w in targets.items():
            for endpoint in (src, dst):
                if not graph.has_node(endpoint):
                    graph.add_node(endpoint)
            if graph.has_edge(src, dst):
                print(src, dst)
                continue
            graph.add_edge(src, dst, weight=w, color=color_edges)

def _format_prop_label(node_prop, node):
    """Return ``node_prop[node]`` rounded to three decimals as text, or "".

    "" is returned when the node has no entry, when the value cannot be
    compared or formatted as a number (e.g. a list), or when it is >= 1e8.
    """
    try:
        value = node_prop[node]
        if value < 100000000:
            return str(float("%.3f" % value))
    except (KeyError, TypeError, ValueError):
        pass
    return ""


def _path_edge_weight(graph, u, v):
    """Return the weight of edge u->v in ``graph``, else of v->u, else None."""
    forward = graph.get(u) or {}
    if v in forward:
        return forward[v]
    backward = graph.get(v) or {}
    if u in backward:
        return backward[u]
    return None


def makeGplot(graph, path, goal, node_prop):
    """Draw the search graph with ``path`` and ``goal`` highlighted.

    Parameters:
        graph: adjacency map ``{source: {target: weight}}``.  Every edge is
            drawn in blue.
        path: sequence of nodes.  Consecutive nodes joined by an edge (in
            either direction) are redrawn as a red edge labelled with the
            weight rounded to three decimals.  Pairs with no edge between
            them, such as the seams of concatenated runs, are skipped.
        goal: node coloured green.
        node_prop: mapping node -> property, used as the node label (three
            decimals).  Nodes without a usable value get an empty label.

    The start node ``path[0]`` is gray and all other nodes are red.  Drawing
    is best-effort: errors are printed rather than raised.  The figure is
    shown only when the module flag ``showPLOTS`` is true.
    """
    plt.figure()
    G = nx.DiGraph()
    edge_labels = {}
    node_label = {}
    node_colors = {}

    for node in graph:
        G.add_node(node)
        node_colors[node] = 'red'
        node_label[node] = _format_prop_label(node_prop, node)

    node_colors[goal] = 'green'
    if path:
        node_colors[path[0]] = 'gray'  # applied last: wins if goal == start

    add_edges_from_dict(G, graph, color_edges="blue")

    for prev, curr in zip(path, path[1:]):
        w = _path_edge_weight(graph, prev, curr)
        if w is None:
            continue
        G.add_edge(prev, curr, weight=w, color="red")
        try:
            edge_labels[(prev, curr)] = float("%.3f" % w)
        except (TypeError, ValueError):
            pass  # non-numeric weight: keep the red edge, skip the label

    edge_colors = [G[u][v]['color'] for u, v in G.edges()]
    # node_color must be a list aligned with G.nodes().  Nodes that appear
    # only as edge targets have no entry in node_colors and default to red.
    ordered_colors = [node_colors.get(n, 'red') for n in G.nodes()]

    try:
        try:
            pos = nx.planar_layout(G)
        except nx.NetworkXException:
            # planar_layout rejects non-planar graphs; use a force-directed
            # layout instead of losing the figure.
            pos = nx.spring_layout(G, seed=0)
        # Base layer; the styled node/edge/label layers below are drawn over it.
        nx.draw(G, pos, with_labels=False, arrows=True)
        nx.draw_networkx_nodes(G, pos, node_color=ordered_colors)
        arcs = nx.draw_networkx_edges(G, pos, edge_color=edge_colors)
        for arc in arcs:
            arc.set_alpha(1.)
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
        nx.draw_networkx_labels(G, pos, labels=node_label)
    except Exception as err:
        print("makeGplot: could not draw graph:", err)
        return

    if showPLOTS:
        try:
            plt.show()
        except Exception as err:
            print("makeGplot: plt.show() failed:", err)

def save_obj(filename, obj):
    """Pickle ``obj`` into ``filename``, overwriting any existing file."""
    with open(filename, mode="wb") as out_file:
        pickle.dump(obj, out_file)

def load_obj(filename):
    """Unpickle and return the object stored in ``filename``.

    Use only on files this pipeline wrote itself: unpickling a crafted file
    can execute arbitrary code.
    """
    with open(filename, mode="rb") as in_file:
        return pickle.load(in_file)

def drawpath(path, fechiupac=False, molsPerRow=4, subImgSize=(200, 200)):
    """Render the molecules along ``path`` as a captioned grid image.

    Parameters:
        path: sequence of SMILES strings, one grid cell per entry.
        fechiupac: when true, caption each cell with ``iupacname(smiles)``.
            The SMILES is used instead if that call fails for any reason.
        molsPerRow: number of cells per grid row (default 4).
        subImgSize: (width, height) of one cell in pixels (default 200x200).
            This is also used to place the captions.

    Side effects:
        - Writes the uncaptioned grid to ``outputmol.png``.
        - Draws captions in the top-left corner of each cell.
        - Opens the result with ``img.show()`` and draws it into a matplotlib
          figure, which is shown when ``showPLOTS`` is true.
    """
    figmols = []
    legim = []
    for ismiles in path:
        figmols.append(Chem.MolFromSmiles(ismiles))
        caption = ismiles
        if fechiupac:
            try:
                caption = iupacname(ismiles)
            except Exception:
                # `iupacname` is not defined or imported in this file as
                # shown; a NameError (or any failure) falls back to SMILES.
                caption = ismiles
        legim.append(caption)

    img = Chem.Draw.MolsToGridImage(figmols, molsPerRow=molsPerRow,
                                    subImgSize=subImgSize, returnPNG=False)
    # NOTE: saved before the captions are drawn, so the file carries no text.
    img.save("outputmol.png")

    # Skip captions entirely if no drawing context can be created.
    try:
        draw = ImageDraw.Draw(img)
    except Exception as err:
        print("drawpath: cannot annotate image:", err)
        draw = None
    if draw is not None:
        font = ImageFont.load_default()
        cell_w, cell_h = subImgSize
        for i, legend_text in enumerate(legim):
            x = (i % molsPerRow) * cell_w   # column offset in the grid
            y = (i // molsPerRow) * cell_h  # row offset in the grid
            draw.text((x, y), legend_text, fill="black", font=font)

    # Display is best-effort (e.g. headless sessions); report failures.
    try:
        img.show()
        plt.figure()
        plt.imshow(img)
        plt.axis('off')
        if showPLOTS:
            plt.show()
    except Exception as err:
        print("drawpath: cannot display image:", err)

def plot_path_LIST(Lpath=None, node_prop=None):
    """Plot the node-property profile of several paths on one figure.

    Each path in ``Lpath`` becomes one line.  Its y-values are
    ``node_prop[node]`` for the nodes of that path, in order.  Nodes without
    an entry are skipped, and the available keys are printed as a debugging
    aid.  The figure is saved to ``pathoutall.png`` and shown when
    ``showPLOTS`` is true.

    Defaults are ``None`` (meaning empty) to avoid shared mutable defaults.
    """
    if Lpath is None:
        Lpath = []
    if node_prop is None:
        node_prop = {}
    plt.figure()
    for ipath in Lpath:
        props = []
        for ist in ipath:
            try:
                props.append(node_prop[ist])
            except (KeyError, TypeError):
                print(node_prop.keys())
        plt.plot(props)
    plt.savefig("pathoutall.png")
    if showPLOTS:
        plt.show()



def plot_path_val(path=None, node_prop=None):
    """Plot the node-property profile along a single path.

    The y-values are ``node_prop[node]`` for each node of ``path``, in order.
    Nodes without an entry are skipped, and the available keys are printed
    as a debugging aid.  The figure is saved to ``pathout.png`` and shown
    when ``showPLOTS`` is true.

    Defaults are ``None`` (meaning empty) to avoid shared mutable defaults.
    """
    if path is None:
        path = []
    if node_prop is None:
        node_prop = {}
    plt.figure()
    props = []
    for ist in path:
        try:
            props.append(node_prop[ist])
        except (KeyError, TypeError):
            print(node_prop.keys())
    plt.plot(props)
    plt.savefig("pathout.png")
    if showPLOTS:
        plt.show()

def Viewlast(irun="", showg=True):
    """Reload one saved run, print its node properties and promote it.

    Steps, in order:
      1. Force ``showPLOTS`` to True.
      2. Load ``{irun}outp_graph.pk``, ``{irun}outp_path.pk`` and
         ``{irun}outp_node.pk``, and print every node property.
      3. When ``showg`` is true, draw the path molecules and the property
         profile.
      4. Copy the run's ``outp_G``, ``outp_path``, ``outp_node`` and
         ``outp_graph`` pickles onto the unprefixed names in the working
         directory.  With ``irun == ""`` these are the same files, and the
         copy is skipped.
      5. When ``showg`` is true and the path is non-empty, plot the graph
         with the path's last node as the goal.

    Returns:
        ``(graph, path, node_prop)`` as loaded.
    """
    global showPLOTS
    showPLOTS = True
    prefix = str(irun)
    oldgraph1 = load_obj(prefix + "outp_graph.pk")
    oldpath1 = load_obj(prefix + "outp_path.pk")
    oldnode_prop1 = load_obj(prefix + "outp_node.pk")
    for k, v in oldnode_prop1.items():
        print(k, v)
    if showg:
        drawpath(oldpath1)
        plot_path_val(path=oldpath1, node_prop=oldnode_prop1)
    # Promote the full set that Viewall also writes, so the current
    # outp_*.pk files all come from the same run.
    for name in ("outp_G.pk", "outp_path.pk", "outp_node.pk", "outp_graph.pk"):
        try:
            shutil.copy(prefix + name, name)
        except shutil.SameFileError:
            pass  # source already is the current file (e.g. irun == "")
    if showg and oldpath1:
        makeGplot(graph=oldgraph1, path=oldpath1, goal=oldpath1[-1],
                  node_prop=oldnode_prop1)

    return oldgraph1, oldpath1, oldnode_prop1

def _load_optional(filename, default_factory):
    """Unpickle ``filename``; on any failure report it and return ``default_factory()``."""
    try:
        return load_obj(filename)
    except Exception as err:
        print("Viewall: could not load", filename, "-", err,
              "(using an empty default)")
        return default_factory()


def _closest_to_target(node_prop, target):
    """Return ``(key, distance)`` of the property value closest to ``target``.

    List-valued properties are compared by their first element.  Ties keep
    the first key encountered.  If no distance is below the initial bound of
    10e10, the result is ``(None, 10e10)``.
    """
    kmin = None
    minv = 10e10
    for k, v in node_prop.items():
        value = v[0] if isinstance(v, list) else v
        dist = abs(value - target)
        if dist < minv:
            kmin, minv = k, dist
    return kmin, minv


def Viewall(nstates, Drawmore=True, target=290.):
    """Merge runs ``1..nstates``, save the merged set and optionally plot it.

    For each run ``i``:
      - ``{i}outp_path.pk`` and ``{i}outp_graph.pk`` are required.
      - ``{i}outp_G.pk`` and ``{i}outp_node.pk`` are optional and default to
        an empty DiGraph / dict.

    Merging rules:
      - Paths are concatenated into one list, and also kept per run.
      - Adjacency dicts and node properties are merged with ``dict.update``,
        so later runs win.
      - DiGraphs are combined with ``nx.compose``.

    The merged data are written to ``outp_path.pk``, ``outp_graph.pk``,
    ``outp_G.pk`` and ``outp_node.pk``.  The merged node properties are also
    stored in the module global ``S_node_prop``.  ``showPLOTS`` is forced to
    True.

    Every node property is printed, followed by the smallest distance
    ``|prop - target|`` (first element for list values).  The node achieving
    it is highlighted as the goal in the graph plot when ``Drawmore`` is
    true.  The goal is ``None`` when no property qualifies.

    Returns:
        ``(merged_graph, concatenated_paths, merged_node_prop)``.
    """
    global showPLOTS
    global S_node_prop
    showPLOTS = True
    S_paths = []
    LS_paths = []
    S_G = nx.DiGraph()
    S_graph = {}
    S_node_prop = {}
    for ibase in range(1, nstates + 1):
        prefix = str(ibase)
        path = load_obj(prefix + "outp_path.pk")
        graph = load_obj(prefix + "outp_graph.pk")
        # Load each optional file independently with an empty fallback, so a
        # missing file never reuses the previous run's G / node_prop (or
        # leaves the name unbound on the first run).
        G = _load_optional(prefix + "outp_G.pk", nx.DiGraph)
        node_prop = _load_optional(prefix + "outp_node.pk", dict)
        S_G = nx.compose(S_G, G)
        S_paths.extend(path)
        LS_paths.append(path)
        S_graph.update(graph)
        S_node_prop.update(node_prop)
    save_obj("outp_path.pk", S_paths)
    save_obj("outp_graph.pk", S_graph)
    save_obj("outp_G.pk", S_G)
    save_obj("outp_node.pk", S_node_prop)

    for k, v in S_node_prop.items():
        print(k, v)
    kmin, minv = _closest_to_target(S_node_prop, target)
    print(minv)

    if Drawmore:
        drawpath(S_paths)
        plot_path_val(path=S_paths, node_prop=S_node_prop)
        plot_path_LIST(Lpath=LS_paths, node_prop=S_node_prop)
        makeGplot(graph=S_graph, path=S_paths, goal=kmin,
                  node_prop=S_node_prop)
    return S_graph, S_paths, S_node_prop

def viewparam(fnam="variables.json"):
    """Load a JSON file of run variables, print it, and return it.

    Parameters:
        fnam: path of the JSON file (default "variables.json"), read as UTF-8.

    Returns:
        The decoded JSON value.  Callers that ignore the result behave as
        before.

    Raises:
        OSError: the file cannot be opened.
        ValueError: the content is not valid JSON (``json.JSONDecodeError``).
    """
    with open(fnam, "r", encoding="utf-8") as file:
        variables = json.load(file)
    print("Variables loaded successfully:")
    print(variables)
    return variables

if __name__ == "__main__":
    # Merge every numbered run in the working directory ("1outp_path.pk",
    # "2outp_path.pk", ...).  To promote a single run instead: Viewlast(1).
    nruns = count_files()
    if nruns > 0:
        Viewall(nruns)
    else:
        # Viewall(0) would overwrite the current outp_*.pk files with empty
        # containers, so do nothing when no numbered runs are present.
        print("No numbered *path.pk files found in the current directory.")
    try:
        viewparam("variables.json")
    except (OSError, ValueError):
        # variables.json is optional: skip if missing or not valid JSON.
        pass
