beanmachine
Simulate calls queried functions
Issue Description
When using the function simulate, the graphical network is evaluated N times per sample, where N is the number of @bm.random_variable functions in the network.
Steps to Reproduce
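# Assumed imports; the original snippet omitted them.
import beanmachine.ppl as bm
import torch.distributions as dist

@bm.random_variable
def A():
    return dist.Normal(1, 1)

@bm.random_variable
def B():
    return dist.Normal(1, 1)

@bm.random_variable
def C():
    print('C')  # makes each evaluation of C visible
    return dist.Normal(A() + B(), 1)

obs_queries = [C()]
predictives = bm.simulate(obs_queries, num_samples=1)  # -> 3 calls

(This can also be observed with a profiler.)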
Expected Behavior
One call per sample.
System Info
- PyTorch version: 1.12.1
- Python version: 3.9
Additional Context
This happens because simulate uses "inference = SingleSiteAncestralMetropolisHastings()" for each sample step, which is exactly what causes the N evaluations.
A solution could be to replace the "next" call in predictive.py with a function that just uses a random proposer to generate a new world and returns this world (an adapted send method from "sampler.py").
Is this solution viable, or would it break something else down the line?
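To make the expected behavior concrete, here is a minimal pure-Python sketch of a single ancestral pass over the same three-node model. It does not use Bean Machine at all: `ancestral_sample`, `simulate_toy`, `MODEL`, and the `call_counts` bookkeeping are hypothetical names made up for this illustration, and `random.gauss` stands in for `dist.Normal`. It only shows that drawing one fresh world in topological order touches each node once per sample.

```python
import random

# Hypothetical bookkeeping: how often each node's sampler is evaluated.
call_counts = {"A": 0, "B": 0, "C": 0}

def sample_A(world):
    call_counts["A"] += 1
    return random.gauss(1.0, 1.0)

def sample_B(world):
    call_counts["B"] += 1
    return random.gauss(1.0, 1.0)

def sample_C(world):
    call_counts["C"] += 1
    # C depends on the values already drawn for its parents A and B.
    return random.gauss(world["A"] + world["B"], 1.0)

# Nodes listed in topological order (parents before children).
MODEL = [("A", sample_A), ("B", sample_B), ("C", sample_C)]

def ancestral_sample(model):
    """Draw one complete world with a single pass over the nodes."""
    world = {}
    for name, sampler in model:
        world[name] = sampler(world)
    return world

def simulate_toy(model, num_samples):
    """Toy stand-in for a prior predictive simulation."""
    return [ancestral_sample(model) for _ in range(num_samples)]

random.seed(0)
samples = simulate_toy(MODEL, num_samples=5)
print(call_counts)  # {'A': 5, 'B': 5, 'C': 5}
```

Each node is evaluated once per sample here, which is the call count the report expects. Whether the real proposer in sampler.py can be reused this way without side effects is the open question above.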