RxInfer.jl
`PointMassFormConstraint` on a `MatrixDirichlet` marginal
Whenever we try to put a `PointMassFormConstraint` on a variable whose marginal is a `MatrixDirichlet`, we run into an error.
MWE:
using RxInfer

@model function foo(y)
    p ~ MatrixDirichlet(ones(5, 5))
    old_s ~ Categorical([0.2, 0.2, 0.2, 0.2, 0.2])
    new_s ~ Transition(old_s, p)
    y ~ Transition(new_s, diageye(5))
end

constraints = @constraints begin
    q(p, old_s, new_s) = q(p)q(old_s, new_s)
    q(p) :: PointMassFormConstraint()
end

init = @initialization begin
    q(p) = MatrixDirichlet(ones(5, 5))
end

infer(model = foo(), data = (y = [0, 1, 0, 0, 0],), constraints = constraints, initialization = init)
This tries to call `mode` on the resulting marginal, which is not implemented for a `MatrixDirichlet`.
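For discussion, here is a minimal sketch of what a `mode` for this marginal might compute. It assumes that each column of the concentration matrix parameterizes an independent Dirichlet distribution, and it works on a plain matrix rather than on the `MatrixDirichlet` type. The helper name `columnwise_dirichlet_mode` is hypothetical and is not part of RxInfer.jl or its dependencies.

```julia
# Hypothetical helper for illustration only; not an RxInfer.jl API.
# Assumption: each column of `a` holds the concentration parameters
# of an independent Dirichlet distribution.
function columnwise_dirichlet_mode(a::AbstractMatrix{<:Real})
    # The closed-form Dirichlet mode (a_i - 1) / (sum(a) - K) is only
    # valid when every concentration parameter is strictly greater than 1.
    all(>(1), a) || throw(ArgumentError("closed-form mode requires all entries of `a` to be > 1"))
    k = size(a, 1)  # number of categories per column
    # Broadcast the per-column normalizer (a 1×N row) over the K×N matrix.
    return (a .- 1) ./ (sum(a, dims = 1) .- k)
end

a = fill(2.0, 3, 3)
m = columnwise_dirichlet_mode(a)  # each column of `m` sums to 1
```

Under this closed form, concentrations equal to 1, as in the `ones(5, 5)` prior from the MWE, give 0/0, so an actual `mode` method would also need to decide how to handle entries that are not greater than 1.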