Bijectors.jl
Support ProductNamedTupleDistribution
```julia
julia> using Distributions, Bijectors

julia> dist = product_distribution((a = Normal(), b = Normal()))
ProductNamedTupleDistribution{(:a, :b)}(
a: Normal{Float64}(μ=0.0, σ=1.0)
b: Normal{Float64}(μ=0.0, σ=1.0)
)

julia> bijector(dist)
ERROR: MethodError: no method matching bijector(::Distributions.ProductNamedTupleDistribution{(:a, :b), Tuple{Normal{…}, Normal{…}}, Continuous, Float64})
The function `bijector` exists, but no method is defined for this combination of argument types.
```
I think we need to add a case here:
https://github.com/TuringLang/Bijectors.jl/blob/fbaf783c1e84540903d62179775936fd5169392a/src/transformed_distribution.jl#L75-L90
I think the `Stacked` bijector already provides most of the code we need, although we would probably need some extra code to convert the result to and from a NamedTuple.
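To make the "per-component bijectors, but on NamedTuples" idea concrete, here is a minimal sketch. It deliberately does not use `Stacked`: instead of flattening to a vector and back, it maps each component's bijector over the matching NamedTuple field and sums the log-abs-det-Jacobians. The helper names `component_bijectors`, `nt_with_logabsdet_jacobian` and `nt_inverse` are hypothetical, not part of Bijectors.jl. The sketch assumes only `bijector`, `inverse` and `with_logabsdet_jacobian` as exported by Bijectors.jl, and it does not add a `bijector(::ProductNamedTupleDistribution)` method.

```julia
using Distributions, Bijectors

# Hypothetical helpers (NOT part of Bijectors.jl): treat a NamedTuple of
# per-component bijectors as a single transform acting on NamedTuple values.

# One bijector per named component, built from the component distributions.
component_bijectors(dists::NamedTuple) = map(bijector, dists)

# Apply each component bijector to the field of `x` with the same name and
# accumulate the log-abs-det-Jacobian over all components.
function nt_with_logabsdet_jacobian(bs::NamedTuple{names}, x::NamedTuple{names}) where {names}
    results = map(with_logabsdet_jacobian, bs, x)  # NamedTuple of (y_i, logjac_i)
    y = map(first, results)                        # transformed NamedTuple
    logjac = sum(last, values(results))            # total log|det J|
    return y, logjac
end

# The inverse is component-wise as well.
nt_inverse(bs::NamedTuple) = map(inverse, bs)

# Usage: an unconstrained component `a` and a positive component `b`.
dists = (a = Normal(), b = Exponential())
dist = product_distribution(dists)
bs = component_bijectors(dists)

x = rand(dist)                                   # NamedTuple (a = ..., b = ...)
y, logjac = nt_with_logabsdet_jacobian(bs, x)    # y.b == log(x.b), logjac == -log(x.b)
x_back, _ = nt_with_logabsdet_jacobian(nt_inverse(bs), y)
```

Because `map` over two NamedTuples requires identical field names in the same order, a mismatched input raises an error instead of silently pairing the wrong components. A `Stacked`-based version would instead need explicit index ranges into the flattened vector, plus the NamedTuple conversion mentioned above.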