#!/usr/bin/env python3
# ##########################################################
# A graphical Python program (using PySimpleGUI) that rolls up
# a sheet lying in the x-y, x-z or y-z plane.
# The rolling is made around the x, y or z axis (an axis lying
# in the sheet plane), following an Archimedean spiral.
# Version 1.0 - Oct/26/2023
# Ricardo Paupitz - ricardo.paupitz@unesp.br
# ---------------------------------------------------------- 
# ##########################################################


import PySimpleGUI as sg
import sys
import readline
import numpy as np
import pprint
import os
from ase.build import graphene_nanoribbon
from ase.visualize import view
from ase.io import read,write


#########################################################
#    +++ Numerical functions +++
#########################################################
# Arc length of the Archimedean spiral r = b*angle
def spirallength(angle,b):
    """Return the arc length of the Archimedean spiral ``r = b*angle``.

    The length is measured along the curve from ``angle = 0`` up to
    ``angle`` using the closed form

        L = (b/2) * (angle*sqrt(1 + angle**2) + ln(angle + sqrt(1 + angle**2)))

    Parameters
    ----------
    angle : float or numpy.ndarray
        Polar angle(s) in radians; NumPy arrays are evaluated element-wise.
    b : float
        Spiral parameter (radial growth per radian).

    Returns
    -------
    float or numpy.ndarray
        Arc length(s), in the same length unit as ``b``.
    """
    root = np.sqrt(1.0 + angle**2.0)
    log_term = np.log(angle + root)
    half_b = 0.5*b
    return half_b*(angle*root + log_term)


###############################################
# Inverse of spirallength(): arc length -> angle
# Archimedean spiral: r=b*angle
def findangle(length,b):
    """Return the spiral angle whose arc length from the origin is ``length``.

    Numerically inverts ``spirallength(angle, b)``:

    1. walk the angle upward in steps of 0.01 rad until the arc length
       exceeds ``length`` (this brackets the solution), then
    2. refine by bisection, halving the step each iteration, until the arc
       length is within 0.001 of ``length``.

    Parameters
    ----------
    length : float
        Target arc length (same unit as ``b``). Negative values return the
        mirrored (negative) angle.
    b : float
        Spiral parameter of ``r = b*angle``; must be positive.

    Returns
    -------
    float
        Angle in radians.

    Raises
    ------
    ValueError
        If ``b`` is not positive or ``length`` is not finite; in both cases
        the search could never terminate.
    """
    if not b > 0:
        raise ValueError("spiral parameter b must be positive, got %r" % (b,))
    if not np.isfinite(length):
        raise ValueError("length must be finite, got %r" % (length,))
    # spirallength() is odd in the angle, so a negative length maps to the
    # mirrored angle. The bisection below can only move a bounded distance
    # (< 0.01 rad) below its start, so it would never converge for lengths
    # much below zero.
    if length < 0:
        return -findangle(-length, b)

    x = 0
    l = 0
    dx = 0.01
    # coarse scan: afterwards the root lies in (x - dx, x]
    while (l <= length):
        x = x + dx
        l = spirallength(x,b)
    # bisection: halve the step and move towards the target each iteration
    tol = 0.001
    while (np.abs(l-length) > tol):
        dx = 0.5*dx
        if (l > length):
            x = x - dx
        else:
            x = x + dx
        l = spirallength(x,b)
    return(x)
###############################################

#######################################
# counts the number of lines of a file
def linecount(filename):
    """Return the number of lines in the text file ``filename``.

    Lines are counted as produced by iterating the file object, so a final
    line without a trailing newline still counts. The file is closed before
    returning (the previous version leaked the handle, and its
    ``filename.close()`` after ``return`` was unreachable).
    """
    with open(filename) as handle:
        return sum(1 for _ in handle)


###########################################################
def verify(pl,f):
    """Shift a sheet away from negative in-plane coordinates and stage it for map().

    Reads the XYZ file ``f`` (line 1: atom count, line 2: comment, then one
    ``element x y z`` line per atom; blank lines are ignored). For each of
    the two coordinates lying in the sheet plane ``pl`` whose minimum is
    negative, every atom is translated along that coordinate by
    ``|min| + 5.0``, so the smallest value becomes 5.0. The out-of-plane
    coordinate is never shifted.

    The result is written to ``"<f>_aux.xyz"``, which map() reads. If any
    coordinate (in-plane or not) is negative, the two header lines are
    copied and the atoms are rewritten as ``element  x  y  z``. Otherwise
    ``f`` is copied verbatim.

    Parameters
    ----------
    pl : str
        Plane of the sheet: 'xy', 'xz' or 'yz'. Any other value disables
        the shifting.
    f : str
        Path of the input XYZ file.

    Returns
    -------
    tuple of float
        ``(xwidth, ywidth, zwidth)``: the extent (max - min) of the atoms
        along x, y and z. The same values are printed to stdout.

    Raises
    ------
    ValueError
        If the file has no atom lines or a coordinate is not numeric.
    IndexError
        If an atom line has fewer than four fields.
    """
    disl = 5.0  # clearance added beyond |min| when a coordinate is shifted
    with open(f, 'r') as fin:
        lines = fin.readlines()
    header = lines[:2]
    # split every non-blank atom line; a stray blank line (e.g. at the end
    # of the file) would otherwise raise IndexError on d[0]
    atoms = [ln.split() for ln in lines[2:] if ln.strip()]
    if not atoms:
        raise ValueError("no atomic coordinates found in %s" % f)
    elem = [d[0] for d in atoms]
    # pos[0], pos[1] and pos[2] hold every atom's x, y and z respectively
    pos = [[float(d[col]) for d in atoms] for col in (1, 2, 3)]
    mins = [min(c) for c in pos]
    xwidth, ywidth, zwidth = (max(c) - lo for c, lo in zip(pos, mins))

    # indices (0=x, 1=y, 2=z) of the coordinates lying in each sheet plane
    in_plane = {'xy': (0, 1), 'xz': (0, 2), 'yz': (1, 2)}.get(pl, ())
    for k in in_plane:
        if mins[k] < 0:
            corr = np.abs(mins[k]) + disl
            pos[k] = [v + corr for v in pos[k]]

    # copy the (possibly shifted) structure to the auxiliary file used by map()
    auxname = "%s_aux.xyz" % f
    with open(auxname, 'w') as out:
        if min(mins) < 0:
            out.writelines(header)
            for row in zip(elem, *pos):
                out.write("%s  %s  %s  %s  \n" % row)
        else:
            out.writelines(lines)
    print ("dx: %f ::: dy: %f :::  dz: %f " % (xwidth,ywidth,zwidth))
    return xwidth,ywidth,zwidth
###########################################################
# map nanoribbon positions to a spiral configuration
def map(f,plane,axis,b,positions):
    """Roll the sheet stored in the XYZ file ``f`` onto an Archimedean spiral.

    For every atom:

    - the in-plane coordinate perpendicular to the rolling axis is taken as
      an arc length along the spiral ``r = b*theta`` and converted to
      ``theta`` with findangle();
    - the out-of-plane coordinate is applied as an offset along the spiral
      normal at that point;
    - the coordinate along the rolling axis is copied unchanged.

    Parameters
    ----------
    f : str
        Input XYZ file (two header lines, then ``element x y z`` lines;
        blank lines are ignored).
    plane : str
        Sheet plane: 'xy', 'xz' or 'yz'.
    axis : str
        Rolling axis; must lie in ``plane``.
    b : float
        Spiral parameter passed to findangle().
    positions : str
        Output XYZ path; it is overwritten.

    Unsupported plane/axis combinations print "unknown option" and write
    only the two header lines. A progress meter is shown while mapping.
    The name shadows the builtin map(); it is kept for compatibility.
    """
    # (plane, axis) -> (arc column, offset column, cos column, sin column).
    # Columns index a split XYZ atom line: 0=element, 1=x, 2=y, 3=z.
    # The cos/sin columns receive b*theta*cos/sin(theta) plus the normal
    # offset; the remaining coordinate column is copied as-is.
    layouts = {
        ('xy', 'x'): (2, 3, 2, 3),
        ('xy', 'y'): (1, 3, 1, 3),
        ('xz', 'x'): (3, 2, 2, 3),
        ('xz', 'z'): (1, 2, 1, 2),
        ('yz', 'y'): (3, 1, 1, 3),
        ('yz', 'z'): (2, 1, 1, 2),
    }

    # best-effort removal of a leftover temp.txt in the working directory
    # (replaces `rm -f temp.txt`: portable and does not spawn a shell)
    try:
        os.remove("temp.txt")
    except OSError:
        pass

    with open(f, 'r') as fin:
        listaux = fin.readlines()
    a = listaux[:2]  # atom-count and comment lines are copied unchanged
    atoms = [ln.split() for ln in listaux[2:] if ln.strip()]

    cols = layouts.get((plane, axis))
    if cols is None:
        print("unknown option")
    else:
        arc_col, off_col, cos_col, sin_col = cols
        lnumber = len(atoms)
        for i, d in enumerate(atoms):
            sg.one_line_progress_meter('Mapping progress', i+1, lnumber, key = 'progressbar', orientation = 'h',bar_color=('green','gray'))
            theta = findangle(float(d[arc_col]),b)
            # unit vector tangent to the spiral at theta, pointing towards
            # increasing theta: derivative of (b*theta*cos, b*theta*sin)
            # divided by its norm b*sqrt(1 + theta**2)
            mod = float(b*np.sqrt(1.0 + (theta*theta)))
            t_cos = float((b*np.cos(theta) - b*theta*np.sin(theta))/mod)
            t_sin = float((b*np.sin(theta) + b*theta*np.cos(theta))/mod)
            # displacement along the normal (tangent rotated by -90 degrees),
            # scaled by the atom's out-of-plane coordinate
            offset = float(d[off_col])
            p_cos = 1.0*t_sin*offset
            p_sin = -1.0*t_cos*offset
            row = list(d[:4])
            row[cos_col] = b*theta*np.cos(theta) + p_cos
            row[sin_col] = b*theta*np.sin(theta) + p_sin
            a.append("%s  %s  %s  %s  \n" % tuple(row))

    with open(positions, 'w') as out:
        out.writelines(a)


# +++ End of the numerical functions' definitions  +++
#########################################################

#--------------------------------------------------------
# +++ GUI definitions and behavior  +++
sg.theme('DarkBlue 3')  # please make your windows colorful

# NOTE(review): these menu entries have no handlers in the event loop below.
menu_def = [['Edit', ['Paste', ['Special', 'Normal',], 'Undo'],],
['Help', 'About...'],]

layout = [[sg.Menu(menu_def)],
          [sg.Text('Orientation of the sheet', size=(25, 1))],
          [sg.Radio('X-Y plane', group_id= 'P', default=True, size=(10,1), key = 'xy', enable_events=True),
           sg.Radio('X-Z plane', group_id= 'P', size=(10, 1), key = 'xz', enable_events=True),
           sg.Radio('Y-Z plane', group_id= 'P', size=(10, 1), key = 'yz', enable_events=True),
          ],
          [sg.Text('Choose the rolling axis', size=(25, 1))],
          [sg.Radio('X axis ', group_id= 'R', default=True, size=(10,1), key = 'x', enable_events=True),
           sg.Radio('Y axis ', group_id= 'R', size=(10, 1), key = 'y', enable_events=True),
           sg.Radio('Z axis ', group_id= 'R', size=(10, 1), key = 'z', enable_events=True),
          ],
          [sg.Text('Define separation between turns')],
          [sg.Slider(range=(0,15), orientation='h', size=(20,20), default_value=3.5, resolution=0.1,
                     enable_events=True, key='slide')],
          [sg.Text('Filename')],
          [sg.Input(key = 'wff',enable_events=False), sg.FileBrowse(key = 'ff', enable_events=True)],
          [sg.Button('Roll'), sg.Button('Exit')],
]

# rolling axes allowed for each sheet plane: the axis must lie in the plane
AXES_FOR_PLANE = {'xy': ('x', 'y'), 'xz': ('x', 'z'), 'yz': ('y', 'z')}
# smallest allowed distance between consecutive spiral turns
MIN_SEPARATION = 3.5


def _effective_separation(requested, thickness):
    """Return the turn separation actually used for rolling.

    ``requested`` is the slider value and ``thickness`` the sheet's extent
    perpendicular to its plane. A request not exceeding the thickness
    becomes MIN_SEPARATION for thin sheets, else ``thickness + 0.5``. A
    larger request is raised to MIN_SEPARATION or padded by 0.5.
    """
    if requested <= thickness:
        # '<' for every plane: a sheet exactly MIN_SEPARATION thick gets
        # 0.5 of clearance (before, only x-z / y-z sheets did)
        if thickness < MIN_SEPARATION:
            return MIN_SEPARATION
        return thickness + 0.5
    if requested <= MIN_SEPARATION:
        return MIN_SEPARATION
    return requested + 0.5


window = sg.Window("Change  Values", layout, default_element_size=(12,1), text_justification='r',
                   auto_size_text=False, auto_size_buttons=False,
                   default_button_element_size=(12,1))
plane, axis = 'xy', 'x'  # the layout's default radio selections
while True:  # Event Loop
    event, values = window.read()
    if event == sg.WIN_CLOSED or event == 'Exit':
        break

    # Enable only the axes lying in the chosen plane and make sure a valid
    # axis is selected. A disabled radio can remain selected, which used to
    # leave `axis` stale or undefined (NameError when pressing Roll).
    plane = next((p for p in AXES_FOR_PLANE if values[p]), plane)
    allowed = AXES_FOR_PLANE[plane]
    for ax in ('x', 'y', 'z'):
        window[ax].update(disabled=ax not in allowed)
    axis = next((ax for ax in allowed if values[ax]), None)
    if axis is None:
        axis = allowed[0]
        window[axis].update(value=True)

    if event == 'Roll':
        input_file = values['wff']
        if not input_file:
            sg.popup_error('Please choose an input file first.')
            continue
        input_aux = "%s_aux.xyz" % input_file
        output_file = 'output.xyz'
        try:
            dx,dy,dz = verify(plane,input_file)
            # sheet thickness = extent perpendicular to the sheet plane
            thickness = {'xy': dz, 'xz': dy, 'yz': dx}[plane]
            separation = _effective_separation(float(values['slide']), thickness)
            dist = separation/(2*np.pi)
            map(input_aux,plane,axis,dist,output_file)
        except (OSError, ValueError, IndexError) as err:
            # keep the window alive on unreadable / malformed input files
            sg.popup_error('Rolling failed: %s' % err)
        finally:
            # delete the staging file without building a shell command from
            # the (possibly space-containing) file name
            try:
                os.remove(input_aux)
            except FileNotFoundError:
                pass
window.close()

 
