Bijectors.jl

Support ProductNamedTupleDistribution

Open: penelopeysm opened this issue

julia> using Distributions, Bijectors

julia> dist = product_distribution((a = Normal(), b = Normal()))
ProductNamedTupleDistribution{(:a, :b)}(
a: Normal{Float64}(μ=0.0, σ=1.0)
b: Normal{Float64}(μ=0.0, σ=1.0)
)

julia> bijector(dist)
ERROR: MethodError: no method matching bijector(::Distributions.ProductNamedTupleDistribution{(:a, :b), Tuple{Normal{…}, Normal{…}}, Continuous, Float64})
The function `bijector` exists, but no method is defined for this combination of argument types.

I think we need to add a method for this case here:

https://github.com/TuringLang/Bijectors.jl/blob/fbaf783c1e84540903d62179775936fd5169392a/src/transformed_distribution.jl#L75-L90

I also think StackedBijector already gives us most of the code we need, although we might need extra code to marshal the result to and from a NamedTuple (?).
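A minimal, hypothetical sketch of the marshalling step in plain Julia, assuming every field is a scalar. Each field gets its own forward/inverse pair, the transformed values are stacked into a flat vector in field order, and the inverse rebuilds the NamedTuple. `FieldTransform`, `NamedTupleStacked`, `transform_to_vector` and `vector_to_namedtuple` are illustrative names only, not Bijectors.jl API. A real implementation would presumably reuse the existing stacked bijector and also track the log-abs-det-Jacobian, which this sketch omits.

```julia
# Hypothetical sketch (not Bijectors.jl API): per-field transforms for a
# NamedTuple of scalars, marshalled to/from a flat vector in field order.

# One invertible scalar transform: forward maps the support to ℝ, inverse maps back.
struct FieldTransform{F,G}
    forward::F
    inverse::G
end

# A NamedTuple of FieldTransforms, one per field of the distribution.
struct NamedTupleStacked{names,T<:Tuple}
    fields::NamedTuple{names,T}
end

# NamedTuple -> Vector: apply each field's forward transform, stacking in field order.
function transform_to_vector(b::NamedTupleStacked{names}, x::NamedTuple{names}) where {names}
    return [b.fields[k].forward(x[k]) for k in names]
end

# Vector -> NamedTuple: apply each field's inverse transform to the matching entry.
function vector_to_namedtuple(b::NamedTupleStacked{names}, y::AbstractVector) where {names}
    vals = ntuple(i -> b.fields[names[i]].inverse(y[i]), length(names))
    return NamedTuple{names}(vals)
end

# Usage: `a` is unconstrained (identity), `b` is positive (log / exp).
nt_bij = NamedTupleStacked((a = FieldTransform(identity, identity),
                            b = FieldTransform(log, exp)))
y = transform_to_vector(nt_bij, (a = 0.5, b = 2.0))  # [0.5, log(2.0)]
x = vector_to_namedtuple(nt_bij, y)                   # ≈ (a = 0.5, b = 2.0)
```

Keeping the field names as a type parameter means the inverse can rebuild a NamedTuple with the same keys and order without storing any extra metadata.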

penelopeysm, May 27 '25 11:05