DistributionsAD.jl
Automatic differentiation of Distributions using Tracker, Zygote, ForwardDiff and ReverseDiff
Currently the test coverage is pretty low. Some of the functions defined here are used in downstream packages, so we get indirect testing, but this is not good enough for...
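One cheap way to add direct tests is to compare a gradient against central finite differences. The minimal sketch below runs without any AD package: the hand-written `dlogpdf_dμ` stands in for what a real test would obtain from Tracker, Zygote, ForwardDiff or ReverseDiff. All function names here are hypothetical.

```julia
# Hypothetical helper: central finite-difference derivative of a scalar function f at x.
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

# Log-density of Normal(μ, σ) at x, written out so the sketch needs no packages.
normal_logpdf(μ, σ, x) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Stand-in for an AD result: the analytic derivative of normal_logpdf with respect to μ.
dlogpdf_dμ(μ, σ, x) = (x - μ) / σ^2

# Returns true if the candidate gradient matches finite differences at every test point.
function gradient_matches_fd(; atol = 1e-6)
    for (μ, σ, x) in ((0.0, 1.0, 0.5), (1.3, 0.7, -2.0), (-4.0, 2.5, 3.0))
        fd = central_diff(m -> normal_logpdf(m, σ, x), μ)
        isapprox(fd, dlogpdf_dμ(μ, σ, x); atol = atol) || return false
    end
    return true
end
```

In a real suite the same loop would run over each supported AD backend and over a grid of parameter values, including ones near the edge of the support.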
Sometimes it is useful to have a broadcasted version of a distribution that is treated as a multivariate distribution. This is critical for performance when doing reverse AD using Tracker...
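A minimal sketch of the idea, assuming a Normal family: the joint log-density of independent components is computed as one broadcast plus one sum, which presumably lets a tape-based reverse-mode AD such as Tracker record a few array operations instead of one node per scalar. A real implementation would likely subtype Distributions.jl's `MultivariateDistribution` and extend `logpdf`; the standalone type `BroadcastNormal` and the function `joint_logpdf` are hypothetical names that keep the sketch dependency-free.

```julia
# Hypothetical type: independent Normal(μ[i], σ[i]) components viewed as one
# multivariate distribution over a vector x.
struct BroadcastNormal{V<:AbstractVector{<:Real}}
    μ::V
    σ::V
end

Base.length(d::BroadcastNormal) = length(d.μ)

# Joint log-density as one broadcast plus one sum, rather than a Julia-level
# loop that calls a scalar logpdf once per element.
function joint_logpdf(d::BroadcastNormal, x::AbstractVector{<:Real})
    length(x) == length(d) || throw(DimensionMismatch("expected length $(length(d)), got $(length(x))"))
    μ, σ = d.μ, d.σ
    z = (x .- μ) ./ σ
    return sum(@. -0.5 * z^2 - log(σ)) - 0.5 * length(x) * log(2π)
end
```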
Currently, time-series models are very slow in Turing compared to Stan, mainly because of the dynamic dispatch of Tracker.jl. Loops are hard to avoid in such models, and loops currently...
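Loops cannot always be avoided, but when the recursion only involves observed data the log-likelihood can often be written as whole-array operations. A minimal sketch, assuming an AR(1) model conditioned on its first observation (the model and all names are illustrative): both functions compute the same value, but the vectorized one hands a tape-based reverse-mode AD a handful of array operations instead of one scalar operation per time step.

```julia
# Log-density of Normal(μ, σ) at x, written out so the sketch needs no packages.
normal_logpdf(μ, σ, x) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Loop form of an AR(1) log-likelihood, conditioned on x[1]:
# x[t] ~ Normal(ϕ * x[t-1], σ) for t = 2, ..., n.
# Every iteration is a separate scalar operation for the AD to record.
function ar1_loglik_loop(ϕ, σ, x::AbstractVector{<:Real})
    lp = 0.0
    for t in 2:length(x)
        lp += normal_logpdf(ϕ * x[t - 1], σ, x[t])
    end
    return lp
end

# Vectorized form of the same log-likelihood: a few array-level broadcasts
# followed by a single sum.
function ar1_loglik_vec(ϕ, σ, x::AbstractVector{<:Real})
    n = length(x)
    μ = ϕ .* view(x, 1:n-1)
    return sum(normal_logpdf.(μ, σ, view(x, 2:n)))
end
```

This rewrite does not apply when the state at time t depends on a latent (sampled) state at t-1, which is exactly the case where the loop is hard to avoid; how much it helps for a given model would need benchmarking.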
Sometimes it is useful to be able to define a multivariate distribution on iid variables by generating distributions on the fly which use different distribution parameters for each variable according to...
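A minimal sketch of one way to generate the component distributions on the fly, assuming Normal components: the type stores a function `params(i) -> (μ, σ)` and the length, and builds each component's parameters only when the joint log-density is evaluated. `MappedNormal`, `params` and `joint_logpdf` are hypothetical names, not part of DistributionsAD.

```julia
# Hypothetical type: n independent Normal components whose parameters are
# generated on the fly by `params(i) -> (μ, σ)` instead of stored in arrays.
struct MappedNormal{F}
    params::F
    n::Int
end

Base.length(d::MappedNormal) = d.n

# Log-density of Normal(μ, σ) at x, written out so the sketch needs no packages.
normal_logpdf(μ, σ, x) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Joint log-density: build component i's parameters, score x[i], and sum.
function joint_logpdf(d::MappedNormal, x::AbstractVector{<:Real})
    length(x) == d.n || throw(DimensionMismatch("expected length $(d.n), got $(length(x))"))
    return sum(i -> normal_logpdf(d.params(i)..., x[i]), 1:d.n)
end
```

Because `params` is an arbitrary closure, the parameters of component i can depend on i (or on covariates captured by the closure) without materializing parameter vectors first.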
# Overview Since a distribution has to be re-implemented here and the focus is on AD, I was wondering if it would be of any interest to add reparameterization to...
I tried running the Enzyme tests, and it seems that quite a few of them currently fail or error.
Implementation of Distributions.rand is not compatible with Zygote due to a mutating operation:
```julia
function Distributions.rand(rng::Random.AbstractRNG, d::TuringMvLogNormal, n::Int)
    x = rand(rng, d.normal, n)
    map!(exp, x, x)  # mutates x in place, which Zygote does not support
    return x
end
```
This works with...
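Assuming the in-place `map!(exp, x, x)` is the only Zygote incompatibility, the usual workaround is to return a freshly allocated array, i.e. `return exp.(x)` instead of mutating `x`. The sketch below shows both patterns side by side on a stdlib-only stand-in (the function names and the N(μ, I) sampler are hypothetical; it does not touch `TuringMvLogNormal` itself), with `x .= exp.(x)` playing the role of the `map!` call, and checks that they produce identical draws for the same seed.

```julia
using Random

# Mutating pattern: draw from N(μ, I) column-wise, then overwrite x in place,
# the kind of array mutation Zygote does not support.
function lognormal_draws_mutating(rng::AbstractRNG, μ::AbstractVector{<:Real}, n::Int)
    x = μ .+ randn(rng, length(μ), n)
    x .= exp.(x)
    return x
end

# Non-mutating pattern: exp.(x) allocates a new array, so nothing is written
# to an existing array after it is created.
function lognormal_draws(rng::AbstractRNG, μ::AbstractVector{<:Real}, n::Int)
    x = μ .+ randn(rng, length(μ), n)
    return exp.(x)
end
```

The price of the non-mutating version is one extra allocation per call.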
MWE:
```julia
using Distributions
using DistributionsAD
using LinearAlgebra  # provides `I`
using Random

test_mm = MixtureModel([TuringDenseMvNormal(randn(10), collect(I(10))) for comp in 1:10])
rand(test_mm)
```
This type of code comes up as I construct a mixture distribution...
The following is currently the case:
```julia
julia> h(θ) = logpdf(truncated(Normal(θ, 1), 0, Inf), 1.0)
h (generic function with 1 method)

julia> ForwardDiff.derivative(h, rand())
NaN

julia> g(θ) = logpdf(truncated(Normal(θ, 1),...
```
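For reference, the function `h` has a simple closed form when the lower bound is 0 and the upper bound is infinite, and its derivative is finite for every θ, so the `NaN` comes from how the truncated log-density is evaluated rather than from the mathematics. With φ and Φ the standard normal density and CDF, and using 1 − Φ(−θ) = Φ(θ):

```latex
h(\theta) = \log \frac{\varphi(1 - \theta)}{1 - \Phi(0 - \theta)}
          = -\tfrac{1}{2}\log(2\pi) - \tfrac{1}{2}(1 - \theta)^2 - \log \Phi(\theta),
\qquad
h'(\theta) = (1 - \theta) - \frac{\varphi(\theta)}{\Phi(\theta)}.
```

For example, at θ = 0 this gives h'(0) = 1 − φ(0)/0.5 = 1 − 2/√(2π) ≈ 0.2021, which is the value a correct ForwardDiff result should match.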
This PR is basically an accumulation of the discussion in https://github.com/TuringLang/Turing.jl/issues/1934 and https://github.com/TuringLang/DistributionsAD.jl/pull/230. It's a hack to make reverse-mode AD packages that use ForwardDiff for broadcasting _much_ faster when used...