Vector of inputs in `streaming_inference`
If we pass a vector (or tensor) of inputs to streaming inference and create a data variable for every entry, we get the following error:
MethodError: no method matching is_data(::Vector{RxInfer.GraphVariableRef})

Closest candidates are:
  is_data(!Matched::RxInfer.GraphVariableRef)
   @ RxInfer ~/.julia/packages/RxInfer/SROpQ/src/model/plugins/reactivemp_inference.jl:229
  is_data(!Matched::GraphPPL.VariableNodeProperties)
   @ GraphPPL ~/.julia/packages/GraphPPL/ke7hR/src/graph_engine.jl:696
MWE:
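using RxInfer

@model function test_model(x, y, mx, vx)
    # One random variable per entry of the input vector x
    for i in 1:3
        x[i] ~ NormalMeanVariance(mx, vx)
    end
    my ~ NormalMeanVariance(0, 1)
    y ~ NormalMeanVariance(my, 1.0)
end

# 10 observations, each with a length-3 vector x and a scalar y
d = [(x = rand(3), y = rand()) for i in 1:10]
# Wrap the array in an observable with a concrete NamedTuple element type
datastream = from(d) |> map(NamedTuple{(:x, :y), Tuple{Vector{Float64}, Float64}}, (d) -> d)

# Dummy update rule: ignores its argument and always returns 1.0
foo(x) = 1.0
# mx and vx are recomputed from the current posterior q(my) as new data arrive
autoupdates = @autoupdates begin
    mx = foo(q(my))
    vx = foo(q(my))
end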
Static (non-streaming) inference on the same model runs and returns a result:
infer(model = test_model(mx = 1.0, vx = 1.0), data=(x = rand(3), y = 0.0), iterations=10, showprogress=true)
Streaming inference, however, throws the error above:
infer(model = test_model(), datastream=datastream, autoupdates = autoupdates, initialization = @initialization begin q(my) = NormalMeanVariance(1.0, 1.0) end)
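Adding the following method fixes this, although it might not be the most rigorous fix. It defines `is_data` for a vector of variable references as true only when every element is a data variable: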
RxInfer.is_data(vector::Vector{RxInfer.GraphVariableRef}) = all(RxInfer.is_data.(vector))
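To illustrate why a one-line method like this resolves the `MethodError`, here is a minimal, self-contained sketch of the same multiple-dispatch pattern in plain Julia. It does not load RxInfer: `MockVarRef` is a hypothetical stand-in for `RxInfer.GraphVariableRef`, and `is_data` here is a local function, not RxInfer's own.

```julia
# Hypothetical stand-in for RxInfer.GraphVariableRef (NOT the real type):
# it only records whether the variable is a data variable.
struct MockVarRef
    isdata::Bool
end

# Scalar method, analogous to the existing is_data(::RxInfer.GraphVariableRef).
is_data(ref::MockVarRef) = ref.isdata

# Only the scalar method exists, so is_data([MockVarRef(true)]) would throw
# MethodError: no method matching is_data(::Vector{MockVarRef})
@show hasmethod(is_data, Tuple{Vector{MockVarRef}})  # false at this point

# Vector method mirroring the workaround: a vector counts as data only if
# every element does. all(f, xs) short-circuits and allocates no temporary array.
is_data(refs::AbstractVector{MockVarRef}) = all(is_data, refs)

refs = [MockVarRef(true), MockVarRef(true), MockVarRef(true)]
println(is_data(refs))                                   # true
println(is_data([MockVarRef(true), MockVarRef(false)]))  # false
```

This differs slightly from the workaround: `all(is_data, refs)` avoids the intermediate `Bool` array that `is_data.(vector)` builds, and the signature accepts any `AbstractVector`. Note that `all` returns `true` for an empty vector; whether that is the right answer inside RxInfer is not settled by this sketch.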