DistributionsAD.jl
ForwardDiff.jl + derivative of parameters of a truncated distribution = NaNs everywhere
The following is currently the case:
```julia
julia> using Distributions, ForwardDiff

julia> h(θ) = logpdf(truncated(Normal(θ, 1), 0, Inf), 1.0)
h (generic function with 1 method)

julia> ForwardDiff.derivative(h, rand())
NaN

julia> g(θ) = logpdf(truncated(Normal(θ, 1), 1e-6, 1000), 1.0)
g (generic function with 1 method)

julia> ForwardDiff.derivative(g, rand())
-0.2321285832954859
```
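IIRC, this has come up before? It comes down to the use of the cdf in the computation of the truncated log-pdf: the log-pdf subtracts the log of the probability mass between the bounds, which is computed from the cdf evaluated at the bounds, and with an infinite bound this yields a NaN derivative.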
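To illustrate how a constant infinite bound can turn a derivative into NaN, here is a package-free sketch built on a hypothetical `MiniDual` type; it is not ForwardDiff's implementation, and the exact code path inside Distributions/StatsFuns may differ. It assumes that the bound and the scale of `Normal(θ, 1)` get promoted to dual numbers with zero partials, and that division uses the quotient rule in the form (a' - (a/b)·b')/b. Standardizing the infinite bound then multiplies an infinite quotient by a zero partial, and `Inf * 0.0` is `NaN` in IEEE arithmetic.

```julia
# Hypothetical minimal forward-mode dual number (value + one partial),
# written only to illustrate the 0 * Inf problem; this is NOT ForwardDiff's code.
struct MiniDual
    val::Float64
    der::Float64
end

# Scale a partial by a value. With nansafe=true, a partial that is exactly zero
# stays zero even if the value is infinite (the idea behind a NaN-safe mode).
scale(der, x; nansafe::Bool=false) = (nansafe && iszero(der)) ? zero(der) : der * x

Base.:-(a::MiniDual, b::MiniDual) = MiniDual(a.val - b.val, a.der - b.der)

# Quotient rule written as (a' - (a/b) * b') / b.
function divide(a::MiniDual, b::MiniDual; nansafe::Bool=false)
    q = a.val / b.val
    return MiniDual(q, (a.der - scale(b.der, q; nansafe=nansafe)) / b.val)
end

θ = MiniDual(0.3, 1.0)      # parameter we differentiate with respect to
upper = MiniDual(Inf, 0.0)  # constant upper bound promoted to a dual: zero partial
σ = MiniDual(1.0, 0.0)      # constant scale promoted to a dual: zero partial

# Standardized bound z = (upper - θ) / σ
z_default = divide(upper - θ, σ)                # partial: (-1 - Inf * 0) / 1 = NaN
z_nansafe = divide(upper - θ, σ; nansafe=true)  # partial: (-1 - 0) / 1 = -1
println(z_default)
println(z_nansafe)
```

Only the partial is affected: the value of `z` is `Inf` in both cases, which is why the log-pdf itself looks fine while its derivative is NaN.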
@sethaxen did we talk about this over Slack at some point? I feel like there was a thread about this issue.
This has come up quite a few times, but fortunately the solution is easy: either use NaN-safe mode in ForwardDiff (by default, it returns incorrect results for infinite values with zero partials), or use the keyword argument syntax of `truncated` (`truncated(Normal(...); lower=0)`). The latter has the additional advantages that it avoids undesired promotions and, in the future, that you can dispatch on left- and right-truncated distributions (I just opened a PR to Distributions a few days ago).
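A minimal sketch of the keyword-argument workaround, assuming a Distributions.jl version whose `truncated` accepts `lower`/`upper` keywords. For x = 1 and truncation to [0, ∞), the log-pdf is log φ(1 - θ) - log Φ(θ), so its derivative is (1 - θ) - φ(θ)/Φ(θ); the helper `h_exact` (our name, not part of any package) evaluates that closed form so the ForwardDiff result can be checked against it. For the other workaround, recent ForwardDiff versions document enabling NaN-safe mode through Preferences.jl with `set_preferences!(ForwardDiff, "nansafe_mode" => true)` followed by a Julia restart; check the documentation of the installed version.

```julia
using Distributions, ForwardDiff

# Keyword form: only the lower bound is given, so no infinite upper bound
# is stored (and promoted to a dual number) inside the truncated distribution.
h_kw(θ) = logpdf(truncated(Normal(θ, 1); lower=0), 1.0)

# Closed form for comparison:
# d/dθ [log φ(1 - θ) - log Φ(θ)] = (1 - θ) - φ(θ) / Φ(θ)
h_exact(θ) = (1 - θ) - pdf(Normal(), θ) / cdf(Normal(), θ)

θ = 0.5
println(ForwardDiff.derivative(h_kw, θ))
println(h_exact(θ))
```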