#FTST CODE# 
include("parameters.jl")
include("FTST_functions.jl")
include("statistical_mechanics.jl")
include("fragment_and_potential_input.jl")
using .parameters
using .FTST_functions
using .statistical_mechanics
using .fragment_and_potential_input

using StaticArrays
using Printf
using Interpolations
using ScatteredInterpolation
using DelimitedFiles


##################################################Input Information#########################################################
const Tmax = 2500.0      #maximum temperature in K
const Tpoints = 1000     #Number of temperatures at which the canonical rate coefficient is calculated.
# Temperatures are spaced evenly in log T: T_i = exp((i - 1) * ΔT), running from 1 K (i = 1)
# to Tmax (i = Tpoints).
const ΔT = log(Tmax) / (Tpoints - 1)
# Inverse temperatures β_i = 1 / (kB * T_i) on that grid.
# A plain Vector rather than an SVector: with 1000 entries, an @SVector comprehension unrolls
# into a 1000-argument tuple and only adds compile time. In this script βGrid is only ever
# indexed element-by-element.
const βGrid = [1.0 / (kB * exp((i - 1) * ΔT)) for i = 1:Tpoints]

const Emax = 25 * kB * Tmax          #Upper energy limit for integration
const RCmin = 1.0Å                    #minimum reaction-coordinate distance
const RCmax = 25.0Å                   #maximum reaction-coordinate distance
############################################################################################################################

println("Prepare potential:")

# Rigid and relaxed minimum-energy-path scans, used for the 1D (relaxed - rigid) correction.
# The files are read in the order NH2OH rigid, NH2OH relaxed, NH3O rigid, NH3O relaxed.
# The raw `relaxed_*` tables stay available under their own names.
rigid_potential_NH2OH, relaxed_potential_NH2OH, rigid_potential_NH3O, relaxed_potential_NH3O =
    map(readdlm, ("rigid_potential_NH2OH.txt", "relaxed_potential_NH2OH.txt",
                  "rigid_potential_NH3O.txt", "relaxed_potential_NH3O.txt"))

# Turn each scan into a linear interpolant of column 2 against column 1 (the coordinate the
# interpolants are later evaluated at). Outside the scanned range the interpolant returns 0.0.
let scan_interpolant = scan -> linear_interpolation(scan[:, 1], scan[:, 2], extrapolation_bc=0.0)
    global rigid_potential_NH2OH = scan_interpolant(rigid_potential_NH2OH)
    global relaxation_potential_NH2OH = scan_interpolant(relaxed_potential_NH2OH)
    global rigid_potential_NH3O = scan_interpolant(rigid_potential_NH3O)
    global relaxation_potential_NH3O = scan_interpolant(relaxed_potential_NH3O)
end

#Read angle-dependent rigid-body potential
# Columns 1:3 are the interpolation coordinates, evaluated below as (ϕ, θ, r);
# column 4 is the potential value. Lookup is nearest-neighbour.
potential = readdlm("potential2.txt")
potential = ScatteredInterpolation.interpolate(NearestNeighbor(), potential[:, 1:3]', potential[:, 4])
@time begin
    # Tabulation grid: ϕ in 1° steps over [0, 2π], θ in 1° steps over [0, π],
    # and 501 evenly spaced r values over [RCmin, RCmax].
    ϕ_grid = [i * π / 180.0 for i = 0:360]
    θ_grid = [j * π / 180.0 for j = 0:180]
    r_grid = [RCmin + k * (RCmax - RCmin) / 500 for k = 0:500]
    potential_NH2OH = zeros(361, 181, 501)
    potential_NH3O = zeros(361, 181, 501)
    # Parallelise once over the last index (k), instead of launching a threaded loop for
    # each of the 361 × 181 angle pairs. Each thread writes only its own k-slab, and the
    # innermost loop runs over the first (contiguous, column-major) index.
    Threads.@threads for k = 1:501
        r = r_grid[k]
        # The 1D relaxed/rigid corrections depend only on r, so evaluate them once per k.
        relaxed_NH2OH, rigid_NH2OH = relaxation_potential_NH2OH(r), rigid_potential_NH2OH(r)
        relaxed_NH3O, rigid_NH3O = relaxation_potential_NH3O(r), rigid_potential_NH3O(r)
        for j = 1:181, i = 1:361
            V_rigid = evaluate(potential, [ϕ_grid[i]; θ_grid[j]; r])[1]
            # Same evaluation order as before: (V + relaxed) - rigid.
            potential_NH2OH[i, j, k] = V_rigid + relaxed_NH2OH - rigid_NH2OH
            potential_NH3O[i, j, k] = V_rigid + relaxed_NH3O - rigid_NH3O
        end
    end
end
# Replace the tabulated arrays with trilinear interpolants on the (ϕ, θ, r) grid.
# Queries outside the grid are clamped to the boundary value (Flat()).
potential_NH2OH = linear_interpolation((ϕ_grid, θ_grid, r_grid), potential_NH2OH, extrapolation_bc=Flat())
potential_NH3O = linear_interpolation((ϕ_grid, θ_grid, r_grid), potential_NH3O, extrapolation_bc=Flat())

# Number of points in each of the RC, E and J grids used by the state-sum scan.
grid_points = 100
# Report the actual grid size (the message previously hard-coded 300).
println("State-sum scan with a grid size of $grid_points:")

# RC_grid: grid_points evenly spaced values from RCmin to RCmax.
RC_grid = @SVector [RCmin + (k - 1) * (RCmax - RCmin) / (grid_points - 1) for k = 1:grid_points]
# E_grid: E_grid[1] = 0; E_grid[2:end] are spaced evenly in log E from kB (i = 2) to Emax (i = grid_points).
E_grid = [@SVector [0.0]; @SVector [kB * exp((i - 2) * (log(Emax) - log(kB)) / (grid_points - 2)) for i = 2:grid_points]]


println("Find maximum J:s to be considered:")
# find_max_J comes from one of the included modules. Its result is used as the upper end
# of the J grid in the state-sum scan. The second argument is sqrt(2 * Emax * μ) * RCmax
# divided by grid_points; presumably it is a J step size (TODO confirm against find_max_J).
@time let J_scale = (2 * Emax * μ * RCmax^2)^0.5 / grid_points
    global Jmax = find_max_J(Emax, J_scale, RC_grid, potential_NH2OH, potential_NH3O)
end
println("Jmax ", Jmax)


println("Determine state sum as a function of E, J, and RC")
# NOTE(review): `state_sum` (4 × grid_points³ Float64) is not read anywhere in this file;
# confirm it is unused elsewhere before removing it.
state_sum = zeros(4, grid_points, grid_points, grid_points)
# Per-channel TS tables indexed [field, j, i] for J_grid[j] and E_grid[i]:
#   field 1 = E, 2 = J,
#   3 = RC at which N_NH2OH + N_NH3O is smallest over RC_grid,
#   4 = this channel's N at that RC.
TS_state_sum_EJ_NH2OH = zeros(4, grid_points, grid_points)
TS_state_sum_EJ_NH3O = zeros(4, grid_points, grid_points)

@time begin
    # Energies run from high to low (E_grid[1] = 0 is skipped); at each energy, J runs from high to low.
    for i = grid_points:-1:2
        # Jmax may be lowered during the scan (see below). The J grid is rebuilt per energy,
        # so a lowered Jmax takes effect from the next (lower) energy onward.
        # J_grid[1] is 1.0e-6 rather than 0 (presumably to avoid J = 0 exactly).
        ΔJ = Jmax / (grid_points - 1)
        J_grid = [@SVector [1.0e-6]; @SVector [j * ΔJ for j = 1:(grid_points-1)]]
        # Per-RC state sums for the current (E, J). Each thread writes only slot k, so the
        # parallel loop shares no mutable state. (Previously all threads raced on N_TS and
        # on the same table entries, which could yield a non-minimal or mismatched RC/N pair.)
        N_RC_NH2OH = zeros(grid_points)
        N_RC_NH3O = zeros(grid_points)
        for j = grid_points:-1:1
            TS_state_sum_EJ_NH2OH[1, j, i], TS_state_sum_EJ_NH2OH[2, j, i] = E_grid[i], J_grid[j]
            TS_state_sum_EJ_NH3O[1, j, i], TS_state_sum_EJ_NH3O[2, j, i] = E_grid[i], J_grid[j]
            Threads.@threads for k = 1:grid_points
                N_RC_NH2OH[k], N_RC_NH3O[k] = monte_carlo_integral(E_grid[i], J_grid[j], RC_grid[k], potential_NH2OH, potential_NH3O)
            end
            # Serial minimisation over RC of the total N. The strict `<` keeps the first k that
            # attains the minimum. If no total is below 1.0e300, fields 3-4 stay at 0.
            N_TS = 1.0e300
            for k = 1:grid_points
                N_sum = N_RC_NH2OH[k] + N_RC_NH3O[k]
                if N_sum < N_TS
                    N_TS = N_sum
                    TS_state_sum_EJ_NH2OH[3, j, i], TS_state_sum_EJ_NH2OH[4, j, i] = RC_grid[k], N_RC_NH2OH[k]
                    TS_state_sum_EJ_NH3O[3, j, i], TS_state_sum_EJ_NH3O[4, j, i] = RC_grid[k], N_RC_NH3O[k]
                end
            end
            # Suppose the total N is below ϵ_0 at this J and at the next two (higher) J values
            # already stored, and this J is below the current Jmax. Then lower Jmax to
            # J_grid[j+2].
            if j < (grid_points - 1)
                if N_TS < ϵ_0 && (TS_state_sum_EJ_NH2OH[4, j+1, i] + TS_state_sum_EJ_NH3O[4, j+1, i]) < ϵ_0 && (TS_state_sum_EJ_NH2OH[4, j+2, i] + TS_state_sum_EJ_NH3O[4, j+2, i]) < ϵ_0 && J_grid[j] < Jmax
                    global Jmax = TS_state_sum_EJ_NH3O[2, j+2, i]
                end
            end
        end
    end
end

println("Save TS N and RC as a function of E and J")
@time begin
    # Each line holds E, J, RC and N_NH2OH + N_NH3O as four %16.8e fields. Entries with
    # E < ϵ_0 or J < ϵ_0 are skipped, and a blank line follows each energy index.
    # The do-block guarantees the file is closed even if a write throws.
    open("EJRCN.dat", "w") do file
        for i = 2:grid_points
            for j = 1:grid_points
                E = TS_state_sum_EJ_NH2OH[1, j, i]
                J = TS_state_sum_EJ_NH2OH[2, j, i]
                (E < ϵ_0 || J < ϵ_0) && continue
                RC = TS_state_sum_EJ_NH2OH[3, j, i]
                N_total = TS_state_sum_EJ_NH2OH[4, j, i] + TS_state_sum_EJ_NH3O[4, j, i]
                @printf(file, "%16.8e%16.8e%16.8e%16.8e\n", E, J, RC, N_total)
            end
            write(file, '\n')
        end
    end
end

println("Integrate over J")
@time begin
    # Reduce each fixed-energy slice of the (E, J) tables to one value per channel with
    # average_over_J. This yields one entry per E_grid index, including index 1.
    TS_state_sum_E_NH2OH = zeros(MVector{grid_points})
    TS_state_sum_E_NH3O = zeros(MVector{grid_points})
    for e in 1:grid_points
        # In slice e, row 2 holds the J values and row 4 holds the channel's state sum N.
        J_values = SVector{grid_points}(@view TS_state_sum_EJ_NH2OH[2, :, e])
        N_NH2OH_slice = SVector{grid_points}(@view TS_state_sum_EJ_NH2OH[4, :, e])
        N_NH3O_slice = SVector{grid_points}(@view TS_state_sum_EJ_NH3O[4, :, e])
        TS_state_sum_E_NH2OH[e] = average_over_J(J_values, N_NH2OH_slice)
        TS_state_sum_E_NH3O[e] = average_over_J(J_values, N_NH3O_slice)
    end
end

println("Compute canonical rate coefficients and save them to 'CanonicalRateCoefficients.dat'")
@time begin
    # Column i holds T = 1 / (kB β) in row 1, and the NH2OH / NH3O channel rate coefficients
    # in rows 2 and 3 (cm^3/s). A plain Matrix is used because a 3 × Tpoints static matrix
    # is far beyond the size where StaticArrays helps.
    Tkβ = zeros(3, Tpoints)
    # Convert the J-reduced state sums to SVectors once, not once per temperature.
    N_E_NH2OH = SVector(TS_state_sum_E_NH2OH)
    N_E_NH3O = SVector(TS_state_sum_E_NH3O)
    Threads.@threads for i = 1:Tpoints
        β = βGrid[i]
        Tkβ[1, i] = 1 / (kB * β)
        # Prefactor shared by both channels; computed once per temperature.
        prefactor = (ge / σ) / (h * Q_translational(μ, β) * Q_rotational(Ia, Ib, Ic, β) * Q_rotational(0.0, 0.0, 0.0, β))
        # (100 * a0)^3 / timeUnit converts from a.u. to cm3 / s.
        Tkβ[2, i] = prefactor * Q_laplace(E_grid, N_E_NH2OH, β) * (100 * a0)^3 / timeUnit
        Tkβ[3, i] = prefactor * Q_laplace(E_grid, N_E_NH3O, β) * (100 * a0)^3 / timeUnit
    end
    # One line per temperature: T, k_NH2OH, k_NH3O as %16.8e fields.
    # The do-block guarantees the file is closed even on error.
    open("CanonicalRateCoefficients.dat", "w") do file
        for i = 1:Tpoints
            @printf(file, "%16.8e%16.8e%16.8e\n", Tkβ[1, i], Tkβ[2, i], Tkβ[3, i])
        end
    end
end

