module FTST_functions

include("parameters.jl")
include("statistical_mechanics.jl")
include("fragment_and_potential_input.jl")
using .parameters
using .statistical_mechanics
using .fragment_and_potential_input

using Interpolations
using LinearAlgebra
using StaticArrays
using ForwardDiff

export monte_carlo_integral, Trapezoid, find_max_J

"""
    find_max_J(E, ΔJ, RC_grid, potential_NH2OH, potential_NH3O)

Scan J upward in steps of `ΔJ` (J = ΔJ, 2ΔJ, …, accumulated by repeated addition) and
return the first J for which, at one or more reaction-coordinate values in `RC_grid`,
both channel state sums returned by `monte_carlo_integral` are below `ϵ_0`.

The grid points of each J step are evaluated in parallel with `Threads.@threads`.
There is no upper bound on the number of J steps.

Throws `ArgumentError` if `ΔJ` is not positive, because the scan could never terminate.
"""
function find_max_J(E::Real, ΔJ::Real, RC_grid::AbstractVector, potential_NH2OH, potential_NH3O)
    ΔJ > 0 || throw(ArgumentError("ΔJ must be positive, got ΔJ = $ΔJ"))
    grid = collect(RC_grid)   # plain indexable vector for Threads.@threads
    J = 0.0
    while true
        J += ΔJ
        # J is passed as an argument (not captured by a threaded closure), so it stays
        # an unboxed, type-stable local.
        _state_sum_vanishes(E, J, grid, potential_NH2OH, potential_NH3O) && return J
    end
end

# True if, at any RC in `grid`, both channel state sums at (E, J) fall below ϵ_0.
# Grid points are evaluated in parallel; the shared result flag is an atomic so that
# concurrent writes from different threads are well-defined.
function _state_sum_vanishes(E, J, grid, potential_NH2OH, potential_NH3O)
    vanished = Threads.Atomic{Bool}(false)
    Threads.@threads for RC in grid
        N_NH2OH, N_NH3O = monte_carlo_integral(E, J, RC, potential_NH2OH, potential_NH3O)
        if N_NH2OH < ϵ_0 && N_NH3O < ϵ_0
            vanished[] = true
        end
    end
    return vanished[]
end

"""
    monte_carlo_integral(E, J, RC, potential_NH2OH, potential_NH3O;
                         loop_size = 1000, relative_change = 1.0e-2, max_batches = 1000)

Monte Carlo estimate of the number of states N(E, J) at reaction coordinate `RC`,
split by channel.

Returns the tuple `(N_NH2OH, N_NH3O)`. Both entries are `0.0` when `E` or `J` is below
`ϵ_0`.

Orientation angles are drawn in batches of `loop_size` points. For each point,
`F(ϕ, θ, RC)` yields `RCcm` and a `species` flag.
- `RCcm` sets the external moment of inertia `μ * RCcm^2`.
- Points with `species == 1` are scored with `potential_NH2OH` and added to the NH2OH
  sum. All other points use `potential_NH3O` and go to the NH3O sum.
- Points where `E - Erot - Vpot < 0` contribute nothing (Heaviside step).

Sampling stops when both channels have converged, or after `max_batches` batches,
whichever comes first. A channel has converged when its running mean changed by less
than `relative_change` (relative) since the previous batch. A channel whose mean is
below `ϵ_0` both now and at the previous check (initially 0) counts as zero, and
therefore as converged.
"""
function monte_carlo_integral(E::Real, J::Real, RC::Real, potential_NH2OH, potential_NH3O;
                              loop_size::Integer = 1000,        # points per batch (~3% accuracy)
                              relative_change::Real = 1.0e-2,   # convergence tolerance
                              max_batches::Integer = 1000)      # hard cap on batches
    # The state sum is zero if E or J is zero. Return one value per channel so callers
    # can always destructure the result as `N_NH2OH, N_NH3O = ...`.
    if E < ϵ_0 || J < ϵ_0
        return 0.0, 0.0
    end

    Σ_NH2OH = 0.0          # Monte Carlo sums, one per channel
    Σ_NH3O = 0.0
    previous_NH2OH = 0.0   # running means after the previous batch
    previous_NH3O = 0.0
    N = 0                  # accepted Monte Carlo points

    for _ in 1:max_batches
        for _ in 1:loop_size
            # Both angle pairs are sampled uniformly on the unit sphere: azimuth uniform in
            # [0, 2π), polar angle = acos of a uniform draw in [-1, 1]. This is consistent
            # with the 16π² = 4π · 4π volume factor applied to the mean below.
            ϕ = 2π * rand()
            θ = acos(2 * rand() - 1.0)
            ν = 2π * rand()
            η = acos(2 * rand() - 1.0)
            if θ <= 1.0e-5 || ν <= 1.0e-5   # To avoid division by zero issues.
                continue
            end
            N += 1

            (RCcm, species) = F(ϕ, θ, RC)
            Vpot = species == 1 ? potential_NH2OH(ϕ, θ, RCcm) : potential_NH3O(ϕ, θ, RCcm)
            Iex = μ * RCcm^2   # external moment of inertia

            detB, Erot = ComputeBMatrixDeterminant(J, ν, η, ϕ, θ, Iex)

            ΔE = E - Erot - Vpot
            ΔE < 0.0 && continue   # Heaviside step-function condition: no contribution

            # Determinant of the inverse G matrix (A).
            detA = ComputeAMatrixDeterminant(ϕ, θ, RC, Iex)
            weight = (detA / detB)^0.5 * ΔE^((n - 3.0) / 2.0)
            if species == 1
                Σ_NH2OH += weight
            else
                Σ_NH3O += weight
            end
        end

        N == 0 && continue   # every point so far was rejected; nothing to average yet

        mean_NH2OH = Σ_NH2OH / N
        mean_NH3O = Σ_NH3O / N
        if _mc_converged(mean_NH2OH, previous_NH2OH, relative_change) &&
           _mc_converged(mean_NH3O, previous_NH3O, relative_change)
            break
        end
        previous_NH2OH = mean_NH2OH
        previous_NH3O = mean_NH3O
    end

    # This gives N(E, J), not the flux f, hence the additional factor h in front.
    prefactor = h * 2J^2 * (2π / h^2)^((n + 1) / 2) / Γ((n - 1.0) / 2.0)
    return prefactor * Σ_NH2OH * 16π^2 / N, prefactor * Σ_NH3O * 16π^2 / N
end

# Per-channel convergence test for the Monte Carlo running mean.
# - A mean below ϵ_0 is treated as zero. It counts as converged only if the previous
#   mean was also below ϵ_0. This also avoids the 0/0 of the relative change.
# - Otherwise the relative change since the previous batch must be below `tol`.
function _mc_converged(current::Real, previous::Real, tol::Real)
    current < ϵ_0 && return previous < ϵ_0
    return abs((current - previous) / current) < tol
end

"""
    ComputeBMatrixDeterminant(J, ν, η, ϕ, θ, Iex) -> (detB, Erot)

Compute the eigenvalues of `ComputeBMatrix(ϕ, θ, Iex)` with `eigvals` and return:

- `detB`: the product of its three eigenvalues (i.e. its determinant);
- `Erot = (J²/2) * ((sin²ν / λ₁ + cos²ν / λ₂) * sin²η + cos²η / λ₃)`, where
  `λ₁, λ₂, λ₃` are the eigenvalues in the order `eigvals` returns them.
"""
function ComputeBMatrixDeterminant(J::Real, ν::Real, η::Real, ϕ::Real, θ::Real, Iex::Real)
    λ₁, λ₂, λ₃ = eigvals(ComputeBMatrix(ϕ, θ, Iex))
    determinant = λ₁ * λ₂ * λ₃

    # Term mixing the first two eigenvalues through the angle ν.
    azimuthal = sin(ν)^2 / λ₁ + cos(ν)^2 / λ₂
    energy = (J^2 / 2) * (azimuthal * sin(η)^2 + cos(η)^2 / λ₃)

    return determinant, energy
end

"""
    ComputeBMatrix(ϕ, θ, Iex)

Build the symmetric 3×3 matrix used by `ComputeBMatrixDeterminant`. It combines the
module-level moments `Ia`, `Ib`, `Ic` with the angles `ϕ` and `θ`, and adds `Iex` to the
xx and yy diagonal entries only. The result is a static `SMatrix`.
"""
function ComputeBMatrix(ϕ, θ, Iex)
    cϕ, sϕ = cos(ϕ), sin(ϕ)
    cθ, sθ = cos(θ), sin(θ)

    # Sub-expressions shared by several matrix elements.
    Iϕ = Ic * cϕ^2 + Ib * sϕ^2   # used in xx, xz and zz
    skew = (Ic - Ib) * cϕ        # used in xy and yz

    xx = Iϕ * cθ^2 + Ia * sθ^2 + Iex
    yy = Ic * sϕ^2 + Ib * cϕ^2 + Iex
    zz = Iϕ * sθ^2 + Ia * cθ^2
    xy = skew * cθ * sϕ
    xz = (Iϕ - Ia) * cθ * sθ
    yz = skew * sθ * sϕ

    return @SMatrix [xx xy xz;
                     xy yy yz;
                     xz yz zz]
end

"""
    ComputeAMatrixDeterminant(ϕ, θ, RC, Iex)

Return the determinant of the A matrix (called the inverse G matrix at the call site):

    Iex² · Ia · Ib · Ic · (1 + μ · (F_ϕ² g44 + F_θ² g55 + 2 F_ϕ F_θ g45))

Here `F_ϕ` and `F_θ` are the partial derivatives of `F(ϕ, θ, RC)[1]` with respect to ϕ
and θ, obtained with ForwardDiff. `g44`, `g45` and `g55` are the angular elements built
from `Ia`, `Ib`, `Ic` and `Iex`. Several terms divide by `sin θ`, so this is singular
at θ = 0 and θ = π.
"""
function ComputeAMatrixDeterminant(ϕ::Real, θ::Real, RC::Real, Iex::Real)
    # Gradient of F(...)[1] in (ϕ, θ, RC); only the two angular components are used.
    gradF = ForwardDiff.gradient(q::SVector{3} -> F(q[1], q[2], q[3])[1], SVector(ϕ, θ, RC))
    F_ϕ = gradF[1]
    F_θ = gradF[2]

    cosϕ, sinϕ = cos(ϕ), sin(ϕ)
    cosθ, sinθ = cos(θ), sin(θ)

    g44 = 1.0 / (Iex * sinθ^2) + 1.0 / Ia + (cosθ^2 / sinθ^2) * (sinϕ^2 / Ib + cosϕ^2 / Ic)
    g45 = (Ic - Ib) * cosϕ * cosθ * sinϕ / (Ib * Ic * sinθ)
    g55 = 1.0 / Iex + cosϕ^2 / Ib + sinϕ^2 / Ic

    coupling = F_ϕ^2 * g44 + F_θ^2 * g55 + 2 * (F_ϕ * F_θ * g45)
    base_determinant = Iex^2 * Ia * Ib * Ic
    return base_determinant * (1.0 + μ * coupling)
end

"""
    Trapezoid(X, Y, β)

Trapezoidal-rule integral of `Y(x) * exp(-β x)` over the sample points `X`, i.e.
`Σᵢ (f(Xᵢ) + f(Xᵢ₊₁)) * (Xᵢ₊₁ - Xᵢ) / 2` with `f(x) = y * exp(-x * β)`.

`X` and `Y` may be any vectors of equal length (e.g. `SVector`, `Vector`, ranges).
Returns `0.0` for fewer than two points.

Throws `DimensionMismatch` if `X` and `Y` differ in length.
"""
function Trapezoid(X::AbstractVector, Y::AbstractVector, β::Real)
    length(X) == length(Y) || throw(DimensionMismatch(
        "X and Y must have equal lengths, got $(length(X)) and $(length(Y))"))

    integral = 0.0
    isempty(X) && return integral

    # Walk the points pairwise, evaluating each Boltzmann-weighted value exactly once.
    x_prev = first(X)
    f_prev = first(Y) * exp(-x_prev * β)
    for (x, y) in Iterators.drop(zip(X, Y), 1)
        f = y * exp(-x * β)
        integral += (f_prev + f) * (x - x_prev)
        x_prev, f_prev = x, f
    end
    return integral / 2
end





end