
# coding: utf-8


import numpy as np
import sys
from scipy.linalg import eigh



def getnumberofexcitedstates(file):

    """
    Extract the number of excited states included in the calculation.

    The file is scanned for the header
    ``** CALCULATION OF TRANSITION MOMENTS BETWEEN EXCITED STATES **``.
    The line directly after the header is skipped. The state count is the
    first whitespace-separated token of the line after that.

    Parameters
    ----------
    file : str or path-like
        Path to the output file.

    Returns
    -------
    int
        Number of excited states, taken from the first occurrence of the header.

    Raises
    ------
    ValueError
        If the header is absent, or the file ends (or the count line is blank)
        before the count can be read.
    """

    marker = '** CALCULATION OF TRANSITION MOMENTS BETWEEN EXCITED STATES **'
    with open(file, 'r') as handle:
        for line in handle:
            if marker not in line:
                continue
            next(handle, '')  # the line directly after the header is not used
            tokens = next(handle, '').split()
            if not tokens:
                raise ValueError(
                    'no excited-state count found after %r in %s' % (marker, file))
            return int(tokens[0])

    # Reaching EOF without the header used to hang forever; fail loudly instead.
    raise ValueError('%r not found in %s' % (marker, file))



def getexcitationenergies(file, N_excited_states, state_i, state_j):

    """
    Build the 2x2 adiabatic Hamiltonian (in eV) for the two states of interest.

    The first ``N_excited_states`` lines containing ``@ Singlet excitation energy``
    are parsed. From each line the 5th whitespace-separated token is read;
    Fortran-style ``D`` exponents are accepted. The values are treated as Hartree
    and converted to eV. The full list of converted energies is printed.

    Parameters
    ----------
    file : str or path-like
        Path to the output file.
    N_excited_states : int
        Number of excitation energies to read.
    state_i, state_j : int
        1-based indices of the two states of interest.

    Returns
    -------
    numpy.ndarray
        2x2 diagonal matrix with the excitation energies of ``state_i`` and
        ``state_j`` (in that order) on the diagonal.

    Raises
    ------
    ValueError
        If the file holds fewer than ``N_excited_states`` energies, or a state
        index falls outside ``1..N``.
    """

    # Rounded eV/hartree factor; kept as-is so results match earlier runs.
    hartree_to_eV = 27.212
    marker = '@ Singlet excitation energy'

    excitation_energies = []
    with open(file, 'r') as handle:
        for line in handle:
            if len(excitation_energies) >= N_excited_states:
                break
            if marker in line:
                energy_hartree = float(line.split()[4].replace('D', 'E'))
                excitation_energies.append(energy_hartree * hartree_to_eV)

    # Stopping at EOF replaces the former infinite loop on short files.
    if len(excitation_energies) < N_excited_states:
        raise ValueError(
            'expected %d excitation energies in %s, found %d'
            % (N_excited_states, file, len(excitation_energies)))

    # Guard against 0 or negative indices silently wrapping to the end of the list.
    for state in (state_i, state_j):
        if not 1 <= state <= len(excitation_energies):
            raise ValueError(
                'state index %d out of range 1..%d' % (state, len(excitation_energies)))

    # Make the 2-level Hamiltonian.
    hamiltonian = np.zeros((2, 2))
    hamiltonian[0, 0] = excitation_energies[state_i - 1]
    hamiltonian[1, 1] = excitation_energies[state_j - 1]
    print(excitation_energies)

    return hamiltonian


def _read_pair_moments(file, pairs):
    """Collect, in file order, the transition-moment values for each (B, C) state pair in *pairs*."""
    header = '@ Transition moment <B | A - <A> | C>'
    moments = {pair: [] for pair in pairs}
    first_states = {b for b, _ in pairs}
    with open(file, 'r') as handle:
        line = handle.readline()
        while line:
            if header in line:
                handle.readline()  # line after the header is not used
                line_b = handle.readline()
                line_c = handle.readline()
                state_b = int(line_b.split()[7])
                # The C line is parsed only when B is one of the states of interest.
                if state_b in first_states:
                    pair = (state_b, int(line_c.split()[7]))
                    if pair in moments:
                        handle.readline()  # skipped
                        value_line = handle.readline()
                        moments[pair].append(float(value_line.split()[9].replace('D', 'E')))
            line = handle.readline()
    return moments


def gettransitionmoments(file, state_i, state_j):
    """
    Build the 2x2 dipole matrix of the two excited states of interest.

    Three transition-moment values are read for each of the state pairs
    (i, i), (i, j) and (j, j). They are presumably the x, y, z components, in
    the order they appear in the file. Values are grouped by pair, so the
    result does not depend on whether the file lists them pair by pair or
    operator by operator.

    The matrix is projected onto the charge-transfer axis defined by the
    difference of the two state dipoles, Δμ = μ_ii − μ_jj:

        [[0,         |μ_ij·Δμ|/|Δμ|],
         [|μ_ij·Δμ|/|Δμ|,  |Δμ|    ]]

    Shifting the diagonal by a constant does not change the eigenvectors.
    The sign of the off-diagonal element is discarded via ``abs``.

    Parameters
    ----------
    file : str or path-like
        Path to the output file.
    state_i, state_j : int
        Indices of the two states; ``state_i`` is expected to be the smaller.
        A warning is printed otherwise.

    Returns
    -------
    numpy.ndarray
        Symmetric 2x2 projected dipole matrix.

    Raises
    ------
    ValueError
        If ``state_i == state_j``, if fewer than three values are found for any
        of the three pairs, or if the two state dipoles are identical (no
        charge-transfer axis).
    """

    if state_i == state_j:
        raise ValueError('state_i and state_j must differ, got %d for both' % state_i)
    if state_i > state_j:
        print('state_1 should be less than state_2')

    pair_ii = (state_i, state_i)
    pair_ij = (state_i, state_j)
    pair_jj = (state_j, state_j)
    moments = _read_pair_moments(file, (pair_ii, pair_ij, pair_jj))

    for (b, c), values in moments.items():
        if len(values) < 3:
            raise ValueError(
                'expected 3 transition-moment components for states (%d, %d) in %s, found %d'
                % (b, c, file, len(values)))

    dipole_vector_ii = np.array(moments[pair_ii][:3])
    dipole_vector_ij = np.array(moments[pair_ij][:3])
    dipole_vector_jj = np.array(moments[pair_jj][:3])

    dipole_difference_vector = dipole_vector_ii - dipole_vector_jj
    norm_dipole_difference_vector = np.sqrt(
        np.dot(dipole_difference_vector, dipole_difference_vector))
    if norm_dipole_difference_vector == 0:
        # Without a dipole difference there is no projection axis (would divide by zero).
        raise ValueError('state dipoles of %d and %d are identical; projection axis undefined'
                         % (state_i, state_j))

    # Magnitude of the transition dipole along the charge-transfer axis.
    norm_dipole_vector_ij = (abs(np.dot(dipole_difference_vector, dipole_vector_ij))
                             / norm_dipole_difference_vector)

    dipole_mat = np.zeros((2, 2))
    dipole_mat[0, 0] = 0.
    dipole_mat[1, 1] = norm_dipole_difference_vector
    dipole_mat[0, 1] = norm_dipole_vector_ij
    dipole_mat[1, 0] = norm_dipole_vector_ij

    return dipole_mat


def rotate_to_diabatic_matrix(dipole_mat, hamiltonian):

    """
    Transform the adiabatic Hamiltonian into the diabatic basis.

    The diabatic states are taken to be the eigenvectors of the adiabatic
    dipole matrix. ``scipy.linalg.eigh`` returns its eigenvalues in ascending
    order, with the eigenvectors as columns. The spread between the two
    eigenvalues is printed. The Hamiltonian is then rotated as U^T H U.

    Returns
    -------
    tuple
        (dipole eigenvalues, rotated Hamiltonian, eigenvector matrix U).
    """

    eigenvalues, eigenvectors = eigh(dipole_mat)

    spread = eigenvalues[1] - eigenvalues[0]
    print('Difference in diabatic dipole moment =')
    print(spread)

    # Right-multiply first, then left-multiply by U^T (same association as U^T (H U)).
    h_times_u = np.dot(hamiltonian, eigenvectors)
    rotated_hamiltonian = np.dot(eigenvectors.T, h_times_u)

    return eigenvalues, rotated_hamiltonian, eigenvectors
    


def main(argv=None):
    """
    Command-line entry point.

    Reads the output file and the two 1-based state indices, then prints:
    the adiabatic Hamiltonian, the projected dipole matrix, the diabatic
    Hamiltonian, the diabatic energy difference and the rotation matrix.

    Parameters
    ----------
    argv : list of str, optional
        Argument vector; defaults to ``sys.argv``. ``argv[1]`` is the output
        file; ``argv[2]`` and ``argv[3]`` are the state indices. Extra
        arguments are ignored.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 4:
        prog = argv[0] if argv else 'script'
        sys.exit('usage: %s OUTPUT_FILE STATE_1 STATE_2' % prog)

    output_path = argv[1]
    state_1 = int(argv[2])
    state_2 = int(argv[3])

    n_excited_states = getnumberofexcitedstates(output_path)

    hamiltonian = getexcitationenergies(output_path, n_excited_states, state_1, state_2)
    print(hamiltonian)

    dipole_mat = gettransitionmoments(output_path, state_1, state_2)
    print(dipole_mat)

    _diabatic_dipole, diabatic_hamiltonian, rot_mat = rotate_to_diabatic_matrix(
        dipole_mat, hamiltonian)
    print('Diabatic Hamiltonian =')
    print(diabatic_hamiltonian)
    print('Diabatic energy difference = ')
    print(diabatic_hamiltonian[1, 1] - diabatic_hamiltonian[0, 0])
    print('Rotationmatrix:')
    print(rot_mat)


# Guarded so importing this module no longer parses sys.argv or reads files.
if __name__ == '__main__':
    main()

