DistributionsAD.jl
Unwanted type promotion from Float32 to Float64 while calculating logpdf of TuringDiagMvNormal
When computing the logpdf of multivariate normal distributions, the result is of type Float64 even when the mean, the covariance matrix and the observation all have eltype Float32. MWE for TuringDiagMvNormal (TuringScalMvNormal and TuringDenseMvNormal suffer from the same problem):
julia> using Distributions, DistributionsAD
julia> d = TuringDiagMvNormal(zeros(Float32, 2), ones(Float32, 2))
TuringDiagMvNormal{Vector{Float32}, Vector{Float32}}(m=Float32[0.0, 0.0], σ=Float32[1.0, 1.0])
julia> l = logpdf(d, [1f0, 2f0])
-4.337877066409345
julia> typeof(l)
Float64  # expected Float32
After a quick glance at the source, the problem seems to be that the constant 2π is a Float64, which then promotes the rest of the expression to Float64.
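The promotion described above can be reproduced without DistributionsAD at all. This is a minimal sketch in plain Julia, not the package's actual logpdf code: it only shows that the literal `2π` is a Float64, that mixing it into Float32 arithmetic widens the result, and that converting the constant first avoids this.

```julia
# The literal 2π multiplies an Int by the Irrational π, which yields a Float64.
c = log(2π)
println(typeof(c))                     # Float64

# Mixing that Float64 constant into Float32 arithmetic promotes the result.
x = 1f0
println(typeof(x + c))                 # Float64

# Converting the constant to the working precision first keeps Float32.
println(typeof(x + log(2 * Float32(π))))   # Float32
```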
I'm experiencing the same problem with DistributionsAD v0.6.28.
Edit: Thinking about this a bit longer, I'm not sure it is reasonable to expect the output type of logpdf to mirror the type of either the distribution's parameters or the value whose log probability is being computed.
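One possible answer to that question is to let the result follow the promoted float type of all inputs, the same rule Julia's own arithmetic uses. The sketch below assumes that convention; `diag_normal_logpdf` is a hypothetical helper written for illustration, not part of the DistributionsAD API, and it computes the log-density of a normal with mean `m` and diagonal standard deviations `σ`.

```julia
# Hypothetical helper (not part of DistributionsAD): log-density of a normal
# with mean m and diagonal standard deviations σ, evaluated at x.
# The working precision T is the promoted float type of all three inputs,
# so all-Float32 inputs give a Float32 result and mixed inputs widen.
function diag_normal_logpdf(m::AbstractVector, σ::AbstractVector, x::AbstractVector)
    T = float(promote_type(eltype(m), eltype(σ), eltype(x)))
    n = length(x)
    # Standardized residuals.
    z = (x .- m) ./ σ
    # Build the 2π constant in precision T so it cannot force a promotion.
    return -(sum(abs2, z) + n * log(2 * T(π))) / 2 - sum(log, σ)
end

# All Float32 inputs: the result stays Float32.
lp32 = diag_normal_logpdf(zeros(Float32, 2), ones(Float32, 2), [1f0, 2f0])
# Float64 observation with Float32 parameters: the result widens to Float64.
lp64 = diag_normal_logpdf(zeros(Float32, 2), ones(Float32, 2), [1.0, 2.0])
```

Promoting over all inputs means a Float64 observation still yields a Float64 result, which sidesteps having to decide whether the parameters or the observation should "win".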