import math
import os
import pprint
import sys

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import scipy.linalg as la
import scipy.ndimage
from matplotlib.ticker import FormatStrFormatter
from mpmath import *


# The spectral density is calculated as 
# $J(\omega)_{m,n} = \sum_\xi g_m(\xi) g_n(\xi) \delta(\omega-\omega_\xi)$

def lorentzian_broadening(x,sigma,gm_list,gn_list,freq_list):
    """
    Evaluate the Lorentzian-broadened stick spectrum at one frequency.

    Every mode i contributes gn_list[i]*gm_list[i] times a unit-area
    Lorentzian centred at freq_list[i]. sigma is used directly as the full
    width at half maximum (no Gaussian-equivalent rescaling is applied).
    Only the magnitude of x enters, so the result is even in x.

    x - frequency in cm^-1
    sigma - broadening sqrt(cm^-2)
    gm_list, gn_list - per-mode couplings, indexed like freq_list
    freq_list - centre frequency of every stick

    output unit = 1/cm^-1
    """
    position = -x if x < 0 else x
    half_width = 0.5*sigma

    def stick(mode):
        # weight of the stick times a normalised Lorentzian line shape
        weight = gn_list[mode]*gm_list[mode]/np.pi
        return weight*(half_width/((position-freq_list[mode])**2+half_width**2))

    return sum(stick(mode) for mode in range(len(freq_list)))


# The correlation function is $C(\omega) = 2\pi \omega^2 [1 + n(\omega)](J(\omega) - J(-\omega))$

def Cfunc(x,sigma,gm_list,gn_list,freq_list,T):
    """
    Bath correlation function in the frequency domain.

    Evaluates sign(w) * 2*pi * hbar^2 * w^2 * [1 + n(w)] * J(|x|), where
    w = 2*pi*c*x is the angular frequency, n(w) = 1/(exp(hbar*w/(kb*T)) - 1)
    and J is the value returned by lorentzian_broadening, rescaled from
    "per cm^-1" to "per s^-1".

    x - transition frequency in cm^-1 (may be negative or zero)
    sigma - broadening handed on to lorentzian_broadening
    gm_list, gn_list - per-mode couplings handed on to lorentzian_broadening
    freq_list - mode frequencies in cm^-1
    T - temperature in K

    output unit = eV^2*s
    """
    hbar = 6.58211899*10**(-16) #eV*s
    kb = 8.6173303*10**(-5) #eV/K
    c = 299792458*10**2 #cm/s

    s = np.sign(x)
    # Build a new value rather than scaling x in place, so an array-like
    # argument owned by the caller is never modified.
    omega = x*(2*np.pi*c) #1/s #the angular frequency

    # w^2*[1 + n(w)] -> 0 for w -> 0 (L'Hopital); the closed form below
    # would divide by zero there.
    if omega == 0:
        return s*0.0

    # The spectral density wants the wavenumber, so x is passed on unchanged
    # instead of converting to angular frequency and back.
    gb = lorentzian_broadening(x,sigma,gm_list,gn_list,freq_list)
    gb = gb*(1/(2*np.pi*c)) #change the spectral density from cm^-1 to Hz

    # 1 + n(w) = 1 + 1/(exp(a) - 1) = -1/expm1(-a) with a = hbar*w/(kb*T).
    # The expm1 form has no cancellation: for large negative a the naive
    # expression collapses to exactly 0, this one keeps the exp(-|a|) tail.
    # An overflow of expm1 simply yields the correct limit 0.
    with np.errstate(over="ignore"):
        bose = -1.0/np.expm1(-hbar*omega/(kb*T))

    cor = 2*np.pi*hbar**2*omega**2*bose*gb
    return s*float(cor)


def D_matrix(Cfunc,W,T,sigma,g_matrix,j,k,freq_list):
    """
    Tabulate half the correlation function at all eigenfrequency differences.

    Element [n,m] is 0.5*Cfunc(W[m,m]-W[n,n], sigma, g_matrix[j],
    g_matrix[k], freq_list, T).

    Cfunc - the correlation function
    W - matrix whose diagonal holds the hamiltonian eigenfrequencies
    T - the temperature in K
    sigma - broadening passed on to Cfunc
    g_matrix - collection with row vectors of all calculated couplings
    j, k - rows of g_matrix to use
    freq_list - frequencies of the normal_modes

    output unit = eV^2*s
    """
    size = len(W)
    values = np.zeros((size,size))
    # row-major walk: n is the row index, m the column index
    for n, m in np.ndindex(size,size):
        gap = W[m,m]-W[n,n]
        values[n,m] = float(Cfunc(gap,sigma,g_matrix[j],g_matrix[k],freq_list,T))
    return 0.5*values

def DT_matrix(Cfunc,W,T,sigma,g_matrix,j,k,freq_list):
    """
    Same table as D_matrix but with the sign of the frequency difference flipped.

    Element [n,m] is 0.5*Cfunc(W[n,n]-W[m,m], sigma, g_matrix[j],
    g_matrix[k], freq_list, T).

    output unit = eV^2*s
    """
    size = len(W)
    values = np.zeros((size,size))
    # row-major walk: n is the row index, m the column index
    for n, m in np.ndindex(size,size):
        gap = W[n,n]-W[m,m]
        values[n,m] = float(Cfunc(gap,sigma,g_matrix[j],g_matrix[k],freq_list,T))
    return 0.5*values


def q_matrix(W,V,basis,s,C,T,sigma,g_matrix,freq_list):
    """
    The q-matrix of the redfield tensor - the matrix with the correlation function and system coupling

    Returns a len(s) x len(s) object array of blocks. Block [j][j] is the
    system operator s[j] rotated with V (V^dagger s V), multiplied element
    by element with D_matrix(...) for the coupling pair (j, j). Every block
    with k != j is a zero matrix of size len(basis).

    output unit = eV^2*s
    """
    n_ops = len(s)
    n_basis = len(basis) #number of basis vectors
    blocks = np.zeros((n_ops,n_ops),dtype=object)

    for j in range(n_ops):
        for k in range(n_ops):
            if j == k:
                weights = D_matrix(C,W,T,sigma,g_matrix,j,k,freq_list)
                blocks[j][k] = (V.conj().T@s[k]@V)*weights
            else:
                #the use of delta function in correlation function
                blocks[j][k] = np.zeros((n_basis,n_basis))
    return blocks

def q_matrix_hat(W,V,basis,s,C,T,sigma,g_matrix,freq_list):
    """
    The q-hat-matrix of the redfield tensor - the matrix with the correlation function and system coupling

    Returns a len(s) x len(s) object array of blocks. Block [j][j] is the
    system operator s[j] rotated with V (V^dagger s V), multiplied element
    by element with DT_matrix(...) for the coupling pair (j, j). Every block
    with k != j is a zero matrix of size len(basis).

    W - matrix whose diagonal holds the eigenfrequencies
    V - matrix used to rotate the system operators
    basis - basis vectors (only their number is used)
    s - sequence of system operators
    C - the correlation function handed on to DT_matrix
    T - temperature in K
    sigma, g_matrix, freq_list - handed on to DT_matrix

    output unit = eV^2*s
    """
    num_vector = len(basis) #number of basis vectors
    num_ops = len(s)

    q_hat = np.zeros((num_ops,num_ops),dtype=object)

    for j in range(num_ops):
        for k in range(num_ops):
            if k != j: #the use of delta function in correlation function
                q_hat[j][k] = np.zeros((num_vector,num_vector))
            else:
                # Use the correlation function passed in as C, not the
                # module-level Cfunc, so the parameter is honoured.
                DT = DT_matrix(C,W,T,sigma,g_matrix,j,k,freq_list)
                q_hat[j][k] = (V.conj().T@s[k]@V)*DT

    return q_hat

def summation(eye,s,q,q_hat,V):
    """
    The summation of the coupling terms, the second term of the redfield tensor

    For every pair (j, k) four Kronecker products built from s[j], q[j][k],
    q_hat[j][k], V and the identity `eye` are combined and accumulated.
    Returns 0 when s is empty.
    """
    V_conj = V.conj()
    V_dagger = V.conj().T
    V_transposed = V.T

    total = 0
    for j, s_j in enumerate(s):
        for k in range(len(s)):
            q_jk = q[j][k]
            q_hat_jk_T = q_hat[j][k].T
            term_1 = np.kron(eye,s_j@V@q_jk@V_dagger)
            term_2 = np.kron(s_j.T,V@q_jk@V_dagger)
            term_3 = np.kron(s_j.T@V_conj@q_hat_jk_T@V_transposed,eye)
            term_4 = np.kron(V_conj@q_hat_jk_T@V_transposed,s_j)
            total += -term_1+term_2-term_3+term_4
    return total

def rho_derivative(H,s,Cfunc,T,basis,sigma,g_matrix,freq_list):
    """
    Build the Redfield tensor that acts on the vectorised density matrix.

    H; system Hamiltonian in eV
    V^dag H V = diag(e1,e2,-); V diagonalizes H
    s; system operator vector
    Cfunc; spectral (correlation) function handed on to q_matrix / q_matrix_hat
    T; temperature in K
    basis; basis vectors handed on to q_matrix / q_matrix_hat
    sigma, g_matrix, freq_list; handed on to q_matrix / q_matrix_hat

    Returns the complex matrix
    i/hbar*(kron(H^dag, 1) - kron(1, H)) + summation(...)/hbar^2.
    """
    hbar = 6.58211899*10**(-16) #eV*s

    energies,V = np.linalg.eigh(H)
    W = np.diag(energies)*8065.6 #eigen frequencies in cm^-1

    q = q_matrix(W,V,basis,s,Cfunc,T,sigma,g_matrix,freq_list)
    q_hat = q_matrix_hat(W,V,basis,s,Cfunc,T,sigma,g_matrix,freq_list)

    dim,_ = H.shape
    identity = np.identity(dim)

    bath_part = summation(identity,s,q,q_hat,V)
    commutator = np.kron(H.conj().T,identity) - np.kron(identity,H)

    return (1.0j)/hbar*commutator + 1/(hbar**2)*bath_part

def basis_state(n,m):
    """
    Return one vector of the site basis.

    n; dimension of system Hamiltonian (length of the vector)
    m; index of the single entry that is set to 1, all others are 0
    """
    unit_vector = np.zeros(n)
    unit_vector[m] = 1
    return unit_vector


def s_state_dia(m,n):
    """
    System part of system-bath coupling: the outer product of two basis vectors.

    The vectors are read from the module-level `basis` (entries m and n).
    m; basis state 1
    n; basis state 2
    """
    return np.array(np.outer(basis[m],basis[n]))

"""
In the following an example of a GMH and FMR Hamiltonian is given
The energies and couplings are calculated with other scripts, and manually added to build
the Hamiltonian. 
"""
def _rotate_couplings(g_a,g_b,U):
    """Return, for every mode i, the two diagonal elements of U.T @ diag(g_a[i], g_b[i]) @ U."""
    # (U.T D U)[p,p] = sum_q U[q,p]^2 * D[q,q] for a diagonal D
    rotated_a = U[0,0]**2*g_a + U[1,0]**2*g_b
    rotated_b = U[0,1]**2*g_a + U[1,1]**2*g_b
    return rotated_a, rotated_b


def gmh_ham(g_matrix):
    """
    Create the GMH Hamiltonian.
    The energies and couplings are added manually, calculated with other scripts.
    Furthermore, for the GMH formalism the couplings in g_matrix are transformed
    pairwise with the same 2x2 transformation matrices as those found with GMH:
    rows 1 and 4 are mixed with U_19, rows 2 and 3 with U_23.

    g_matrix; list or array with one row of per-mode couplings for each of
              the 5 states. It is NOT modified; a transformed copy is returned.
    eps; energies and couplings are given in eV

    Returns (H, g_transformed): the 5x5 Hamiltonian and the transformed couplings.
    """
    eps0 = 0.0
    eps1 = 2.97 #GMH
    eps2 = 3.42 #GMH
    eps3 = 3.25 #GMH
    eps4 = 3.70 #GMH

    g01 = 0.00528*10**-3 #GMH
    g03 = 7.92*10**-3 #GMH
    g04 = 14.78*10**-3 #GMH

    g24 = 1.509*10**-3 #GMH
    g23 = 3.74*10**-3 #GMH
    g13 = 0.732*10**-3 #GMH
    g14 = 15.8*10**-3 #GMH

    H = np.zeros((5,5))

    #diabatic energies
    H[0,0] = eps0 #ground state
    H[1,1] = eps1 #d+-an-ac-
    H[2,2] = eps2 #d-an*-ac
    H[3,3] = eps3 #d+-an--ac
    H[4,4] = eps4 #d-an+-ac-

    #electronic couplings
    H[0,1] = H[1,0] = g01
    H[0,3] = H[3,0] = g03
    H[0,4] = H[4,0] = g04
    H[2,4] = H[4,2] = g24
    H[2,3] = H[3,2] = g23
    H[1,3] = H[3,1] = g13
    H[1,4] = H[4,1] = g14

    #set 1
    U_23 = np.array([[-0.9986776, 0.05141058], [0.05141058, 0.9986776]])
    U_19 = np.array([[-0.99988399,0.01523174], [ 0.01523174, 0.99988399]])

    # number of modes, taken from the first row as before
    m = len(g_matrix[0])
    rows = [np.asarray(g_matrix[i],dtype=float)[:m] for i in range(1,5)]
    g1, g2, g3, g4 = rows

    # Work on a shallow copy so the caller's couplings stay untransformed
    # and can be reused (e.g. for the FMR Hamiltonian).
    g_out = g_matrix.copy()

    #create state d+-an-ac-(1) and d-an+-ac-(4) (g1, g4)
    g_out[1], g_out[4] = _rotate_couplings(g1,g4,U_19)
    #create state d-an*-ac (2) and d+-an--ac (3) (g2, g3)
    g_out[2], g_out[3] = _rotate_couplings(g2,g3,U_23)

    return np.array(H), np.array(g_out)
    

def fmr_ham(g_matrix):
    """
    FMR Hamiltonian.
    The g_matrix is not transformed with FMR; it is only converted with np.array.
    eps; energies and couplings are given in eV

    Returns (H, g_matrix) with H the 5x5 Hamiltonian.
    """
    # diabatic energies, indexed by state:
    # 0 ground state, 1 d+-an-ac-, 2 d-an*-ac, 3 d+-an--ac, 4 d-an+-ac-
    site_energies = (0.0, 2.97, 3.42, 3.25, 3.70)

    g24 = 0.793*10**-3 #FMR
    g23 = 34.9*10**-3 #FMR
    # electronic couplings; (1,3) reuses g24 and (1,4) reuses g23
    couplings = {
        (0,1): 0.00363*10**-3, #FMR
        (0,3): 0.396*10**-3, #FMR
        (0,4): 0.926*10**-3, #FMR
        (2,4): g24,
        (2,3): g23,
        (1,3): g24,
        (1,4): g23,
    }

    H = np.zeros((5,5))
    for state, energy in enumerate(site_energies):
        H[state,state] = energy
    for (a, b), coupling in couplings.items():
        H[a,b] = H[b,a] = coupling

    return np.array(H), np.array(g_matrix)

def plot_pop(N,populations,plot_name1):
    """
    Plot the populations from the Redfield tensor and save the figure.

    The density matrix is stored vectorized (Liouville space), so the
    population of state i sits in column i*N+i of `populations`. The time
    axis is read from the module-level `t_list`.

    N; Dimension of H/#of states (labels exist for 5 states)
    populations; the full vector of the density matrix, one row per time step
    plot_name1; name of the .pdf file created
    """
    state_labels = [r"D-An-Ac", r"D$^+$-An-Ac$^-$",r"D-An$^*$-Ac",r"D$^+$-An$^-$-Ac",r"D-An$^+$-Ac$^-$"]

    ### Plot1 All Populations ####
    fig = plt.figure()
    ax = plt.subplot(111)

    for state in range(N):
        diagonal = state*N+state
        ax.plot(t_list,np.abs(populations[:,diagonal]),'-',label=state_labels[state])

    ax.ticklabel_format(style='sci', axis='x', scilimits=(0,0))

    # Make room below the axes: lift the bottom edge by 20% of the height
    # and shrink the height to 80%.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + box.height * 0.2,
                 box.width, box.height * 0.8])

    # Put a legend below current axis
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15),
          fancybox=True, shadow=False, ncol=3)

    ax.set_title("Populations")
    ax.set_xlabel("[s]")
    ax.set_ylabel("Population")
    plt.savefig(plot_name1)

n=5 #dimension of Hs

# The command-line label is part of both output file names; fail with a
# usage message instead of an IndexError when it is missing.
if len(sys.argv) < 2:
    raise SystemExit("usage: <script> <label>  (plots are written to GMH_<label>.pdf and FMR_<label>.pdf)")

### make basis and Vs
basis = np.array([basis_state(n,state) for state in range(n)]).T
# Projectors |m><m| on every state; the ground-state projector is multiplied
# by zero, so state 0 contributes a zero operator.
s = np.array([0*s_state_dia(0,0)]+[s_state_dia(state,state) for state in range(1,n)])

### Read the coupling constant of the vibrations ###
files = sorted(filename for filename in os.listdir(".") if filename.startswith("coupling"))

g_matrix = np.zeros(len(basis),dtype=object)
freq = None
for i,filename in enumerate(files):
    # line 1: mode frequencies, line 2: couplings (both comma separated)
    with open(filename,"r") as f:
        freq_line = f.readline()
        coupling_line = f.readline()
    if i == 1:
        # the frequencies are taken from the second file only
        freq = np.array(freq_line.split(","),dtype=float)
    g_matrix[i+1] = np.array(coupling_line.split(","),dtype=float)

if freq is None:
    raise SystemExit("at least two 'coupling*' files are needed in the working directory; "
                     "the mode frequencies are read from the second one")

g_matrix[0] = np.zeros(len(freq))

### Simulation settings ###
temp = 298.15 # K
sigma = 10
# hbar is given in eV*s, so the Redfield tensor is in 1/s and the time grid in s
t_list = np.arange(1.5*10**-8,2.0*10**-8,10**(-10)) #s
n_steps = len(t_list)

plot_gmh = "GMH_"+str(sys.argv[1])+".pdf"
plot_fmr = "FMR_"+str(sys.argv[1])+".pdf"


def _propagate(R,rho0):
    """Return exp(R*t) @ rho0 for every t in t_list, one row per time step."""
    return np.array([la.expm(R*t)@rho0 for t in t_list])


### Create the GMH Hamiltonian ###
# gmh_ham gets a (shallow) copy: the untransformed couplings in g_matrix
# are needed again for the FMR run below.
H,g_gmh = gmh_ham(g_matrix.copy())

### Start in the excited state ###
row = H.shape[0]
excited = 2 # D-An*-Ac
rho_start = np.zeros(row**2)
rho_start[excited*row+excited]=1.0 # element (2,2) of the flattened density matrix

### Propagate the populations ####
R = rho_derivative(H,s,Cfunc,temp,basis,sigma,g_gmh,freq)
populations = _propagate(R,rho_start)
plot_pop(row,populations,plot_gmh)

### Create the FMR Hamiltonian ###
plt.clf()
# FMR uses the couplings as read from file, not the GMH-transformed ones
H,g_fmr = fmr_ham(g_matrix)
R = rho_derivative(H,s,Cfunc,temp,basis,sigma,g_fmr,freq)
populations = _propagate(R,rho_start)
plot_pop(row,populations,plot_fmr)
