
# coding: utf-8

# In[1]:


import re
import sys

import numpy as np
from scipy.linalg import eigh




def get_ground_exc_dipole_dif(file, state, transition_moment):
    """
    Extract the ground/excited dipole-moment difference and project the transition moment onto it.

    The file is scanned for header lines containing
    ``'@ Transition moment <B | A - <A> | C>'``. The three lines after each
    header are read:

    * the first line's whitespace token 6 names the component
      (``XDIPLEN``, ``YDIPLEN`` or ``ZDIPLEN``);
    * the second and third lines' token 7 must both equal ``state``.

    When both labels match, one further line is skipped. Token 9 of the line
    after that is parsed as the value, with a Fortran ``D`` exponent converted
    to ``E``. It is stored in the corresponding x/y/z slot of the
    dipole-difference vector.

    Parameters
    ----------
    file : str or path-like
        Path of the output file to parse.
    state : int or str
        Excited-state label; converted with ``int``.
    transition_moment : array-like
        Vector dotted with the 3-component dipole difference. It must
        therefore have length 3. Presumably these are the x/y/z
        components, in that order — TODO confirm against the caller.

    Returns
    -------
    dipole_diff_len : numpy.float64
        Euclidean length of the dipole-difference vector.
    transition_moment_len : numpy.float64
        ``|d . transition_moment| / |d|``, where ``d`` is the dipole-difference
        vector.

    Raises
    ------
    ValueError
        If the extracted dipole-difference vector is zero. This happens, for
        example, when no block for ``state`` exists, in which case the
        projection is undefined.
    """
    state = int(state)
    # Map the operator label on the first line after the header to an x/y/z slot.
    component_index = {'XDIPLEN': 0, 'YDIPLEN': 1, 'ZDIPLEN': 2}
    dipole_diff_xyz = np.zeros(3)

    with open(file, 'r') as handle:
        line = handle.readline()
        while line:
            if '@ Transition moment <B | A - <A> | C>' in line:
                line1 = handle.readline()
                line2 = handle.readline()
                line3 = handle.readline()
                if int(line2.split()[7]) == state and int(line3.split()[7]) == state:
                    handle.readline()  # one line sits between the labels and the value
                    value_line = handle.readline()
                    # Fortran-style exponents (1.0D+00) are not understood by float().
                    dipole_diff = float(value_line.split()[9].replace('D', 'E'))
                    axis = component_index.get(line1.split()[6])
                    if axis is not None:
                        dipole_diff_xyz[axis] = dipole_diff
            line = handle.readline()

    print(dipole_diff_xyz)
    dipole_diff_len = np.sqrt(np.dot(dipole_diff_xyz, dipole_diff_xyz))
    if dipole_diff_len == 0.0:
        # Without a nonzero direction the projection below would be inf/nan.
        raise ValueError(
            f'no nonzero dipole-moment difference found for state {state} in {file}'
        )

    transition_moment_len = 1 / dipole_diff_len * abs(np.dot(dipole_diff_xyz, transition_moment))

    return dipole_diff_len, transition_moment_len


def get_nonzero_transmoment_lr(file, state):
    """
    Extract the ground-to-excited-state transition moments for ``state``.

    Every line in which ``state`` is immediately followed by
    ``' *TRANSITION MOMENT:'`` is used. On such a line, whitespace token 6 is
    read as a transition moment and token 8 as an excitation energy. Each
    matching line contributes one entry, in file order.

    Parameters
    ----------
    file : str or path-like
        Path of the output file to parse.
    state : int or str
        State label that must precede ``*TRANSITION MOMENT:``. It must not be
        the tail of a longer number: state 1 does not match ``11``.

    Returns
    -------
    excitation_energy : numpy.float64
        Smallest excitation energy among the matching lines.
    transition_moment : numpy.ndarray
        All matching transition moments, in file order.

    Raises
    ------
    ValueError
        If no line in ``file`` matches ``state``.
    """
    # The digit lookbehind stops e.g. state 1 from also matching "11 *TRANSITION MOMENT:".
    marker = re.compile(r'(?<!\d)' + re.escape(str(state)) + r' \*TRANSITION MOMENT:')

    excitation_energies = []
    transition_moments = []
    with open(file, 'r') as handle:
        for line in handle:
            if marker.search(line):
                fields = line.split()
                transition_moments.append(float(fields[6]))
                excitation_energies.append(float(fields[8]))

    if not excitation_energies:
        raise ValueError(
            f'no "*TRANSITION MOMENT:" lines found for state {state} in {file}'
        )

    excitation_energies = np.array(excitation_energies)
    index_min_exc_energy = np.argmin(excitation_energies)
    excitation_energy = excitation_energies[index_min_exc_energy]
    # Diagnostic output: lowest and (when present) second-lowest excitation energy.
    if excitation_energies.size > 1:
        print(excitation_energy, np.sort(excitation_energies)[1])
    else:
        print(excitation_energy)

    transition_moment = np.array(transition_moments)
    print(transition_moment)

    return excitation_energy, transition_moment



def rotate_to_diabatic_matrix(dipole_mat, hamiltonian):
    """
    Transform ``hamiltonian`` into the basis that diagonalises ``dipole_mat``.

    ``scipy.linalg.eigh`` diagonalises the adiabatic-basis dipole matrix. Its
    eigenvalues, in ascending order, are returned as the diabatic dipoles.
    Its eigenvector matrix ``U`` rotates the Hamiltonian as
    ``U.T @ hamiltonian @ U``. The difference between the first two
    eigenvalues is printed as a diagnostic.

    Parameters
    ----------
    dipole_mat : array-like, shape (n, n)
        Symmetric dipole matrix in the adiabatic basis, with n >= 2.
    hamiltonian : array-like, shape (n, n)
        Hamiltonian in the same adiabatic basis.

    Returns
    -------
    diabatic_dipole : numpy.ndarray
        Eigenvalues of ``dipole_mat`` in ascending order.
    diabatic_hamiltonian : numpy.ndarray
        The rotated Hamiltonian ``U.T @ hamiltonian @ U``.
    """
    eigvals, eigvecs = eigh(dipole_mat)

    print('Difference in diabatic dipole moment =')
    print(eigvals[1] - eigvals[0])

    # Right-multiply first, then apply U^T on the left.
    h_times_u = hamiltonian @ eigvecs
    rotated = eigvecs.T @ h_times_u

    return eigvals, rotated
    




if __name__ == "__main__":
    # Command line: <script> TRANSMOMENT_FILE DIPOLE_FILE STATE
    #   TRANSMOMENT_FILE  parsed by get_nonzero_transmoment_lr
    #   DIPOLE_FILE       parsed by get_ground_exc_dipole_dif
    #   STATE             integer excited-state label
    # The guard keeps importing this module free of side effects. The names
    # below remain module globals when the file is run as a script.
    if len(sys.argv) < 4:
        sys.exit(f'usage: {sys.argv[0]} TRANSMOMENT_FILE DIPOLE_FILE STATE')

    file = sys.argv[1]
    file2 = sys.argv[2]
    state = int(sys.argv[3])

    # Transition moments and the lowest excitation energy for `state`.
    excitation_energy, transition_moment = get_nonzero_transmoment_lr(file, state)

    # Dipole-difference length and the transition moment projected onto it.
    dipole_dif_len, transition_moment_len = get_ground_exc_dipole_dif(file2, state, transition_moment)

    # Two-state model in the adiabatic basis.
    dipole_mat = np.zeros((2, 2))
    hamiltonian = np.zeros((2, 2))

    # Energies: [0, 0] stays zero, [1, 1] is the excitation energy.
    hamiltonian[1, 1] = excitation_energy

    # Dipoles: [0, 0] is zero and [1, 1] holds the dipole difference; the
    # symmetric off-diagonal holds the projected transition moment.
    dipole_mat[0, 0] = 0.
    dipole_mat[1, 1] = dipole_dif_len
    dipole_mat[0, 1] = transition_moment_len
    dipole_mat[1, 0] = transition_moment_len

    print(hamiltonian)
    print(dipole_mat)

    diabatic_dipole, diabatic_hamiltonian = rotate_to_diabatic_matrix(dipole_mat, hamiltonian)
    print('Diabatic Hamiltonian =')
    print(diabatic_hamiltonian)
    print('Diabatic energy difference = ')
    print(diabatic_hamiltonian[1, 1] - diabatic_hamiltonian[0, 0])

