pyro
🚧 RSA vectorized prior achieved; L0 doesn't enumerate
`structured_prior` correctly enumerates (and specifies individual sample sites, rather than making a single sample site of a tensor over all enumerations).
`listener0` doesn't enumerate over the support of `structured_prior`; it only visits the highest-probability sample from `structured_prior`.
@eb8680 Just a bump on this. I can update with some more work I've done on this, but right now I've hit a snag where the following doesn't appear to work:
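```python
import pyro
from pyro.infer import config_enumerate
from torch import Tensor

@config_enumerate
def listener0(utterance: Tensor, threshold: Tensor, States: RSAMarginal) -> Tensor:
    # States is the custom RSAMarginal described below (defined elsewhere, not part of Pyro)
    state = pyro.sample("state", States)
    # ...
    return state
```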
`RSAMarginal` is much like `HashingMarginal`, but it extracts the support from a `TracePosterior` (as `TraceEnum_ELBO` appears to do in `TraceEnum_ELBO._traces`).
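Collapsing a set of weighted traces into a distribution over distinct return values roughly amounts to grouping equal values and log-sum-exp'ing their weights. A sketch with plain tensors, assuming each trace's return value and log weight are already in hand (how `RSAMarginal` pulls them out of the `TracePosterior` isn't shown here); `aggregate_support` is a hypothetical helper, not Pyro's or `HashingMarginal`'s actual code:

```python
import torch


def aggregate_support(values, log_weights):
    """Hypothetical helper: merge duplicate return values from weighted traces
    into a unique support tensor and normalized log-probabilities."""
    groups = {}
    for value, log_w in zip(values, log_weights):
        # Key on the tensor's contents, not its identity.
        key = tuple(value.reshape(-1).tolist())
        groups.setdefault(key, (value, []))[1].append(log_w)
    support = torch.stack([v for v, _ in groups.values()])
    logits = torch.stack(
        [torch.logsumexp(torch.stack(ws), dim=0) for _, ws in groups.values()]
    )
    # Normalize so the probabilities over the support sum to one.
    return support, logits - torch.logsumexp(logits, dim=0)


values = [torch.tensor(1), torch.tensor(0), torch.tensor(1)]
log_weights = torch.tensor([0.25, 0.5, 0.25]).log()
support, log_probs = aggregate_support(values, log_weights)
print(support.tolist())  # [1, 0]
print(log_probs.exp())   # probabilities 0.5 and 0.5
```

Keying on contents matters here: two traces returning equal tensors are different Python objects, so an identity-based key would never merge them.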
I can push the changes so you can inspect them if that's helpful, but at the moment I'm not sure how to "signal" that `RSAMarginal` can be enumerated the way `dist.Categorical`-like classes appear to be.
I've searched the forum pretty extensively and haven't seen any questions about this.
@jmuchovej sorry, I'll try to look at it this week.
Also, as a side note, I've replicated everything (that is, the models we're looking to build) in WebPPL. The only way to get a reasonable run-time was to use their `cache(..., N)` function. Just running the model (no data fitting) takes <5 min as a result, compared to hours without `cache(..., N)`. (I didn't let the model without `cache` run to completion.)
There's a `@memoize(...)` decorator in the existing `HashingMarginal`, which I would assume is analogous to WebPPL's `cache` function, but it doesn't seem to be. (If memory serves, I've tried `maxsize=100000` and no dice.)
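`maxsize` only bounds how many entries are kept; it can't help if lookups never hit. One thing worth checking: `torch.Tensor.__hash__` is identity-based (it returns `id(self)`), so a `functools.lru_cache`-style memo keyed on tensor arguments only hits when the very same tensor object is passed again. Whether `@memoize` here wraps `lru_cache` is an assumption on my part; the sketch below shows the effect with `lru_cache` directly, using a hypothetical `expensive` stand-in:

```python
import functools

import torch


@functools.lru_cache(maxsize=100000)
def expensive(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for an expensive computation keyed on a tensor argument.
    return x.exp().sum()


a = torch.tensor([0.0, 1.0])
b = torch.tensor([0.0, 1.0])  # same values as `a`, but a different object

expensive(a)
expensive(a)  # same object: cache hit
expensive(b)  # equal values, different object: cache miss
print(expensive.cache_info())  # CacheInfo(hits=1, misses=2, maxsize=100000, currsize=2)
```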
I'm not exactly sure how WebPPL's `cache` works under the hood (in terms of relating it to Pyro nomenclature), but I would imagine it works similarly to `mem`, which uses the arguments to compute the hash for the lookup table. (This seems very similar to what `_dist_and_values` does, yet memoizing `_dist_and_values` gives little to no performance uplift.)
WebPPL's docs on `cache` and on `mem`.
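As I read WebPPL's documentation, `cache` is backed by a table shared across all executions, while `mem` is local to the current execution; both compare arguments by value. A rough Python analog of `cache` for tensor arguments is a process-wide memo that first converts tensors to value-based keys, since tensors themselves hash by identity. A sketch; `memoize_by_value` and `_value_key` are hypothetical helpers, not part of Pyro:

```python
import functools

import torch


def _value_key(x):
    # torch.Tensor hashes by identity, so build a key from dtype, shape and
    # contents instead. Non-tensor arguments must already be hashable.
    if isinstance(x, torch.Tensor):
        t = x.detach().cpu()
        return ("tensor", str(t.dtype), tuple(t.shape), tuple(t.reshape(-1).tolist()))
    return x


def memoize_by_value(fn):
    """Hypothetical helper: unbounded, process-wide memo keyed on argument values."""
    cache = {}

    @functools.wraps(fn)
    def wrapper(*args):
        key = tuple(_value_key(a) for a in args)
        if key not in cache:
            cache[key] = fn(*args)
        return cache[key]

    return wrapper


calls = []


@memoize_by_value
def expensive(x):
    calls.append(x)  # record real (uncached) evaluations
    return x.exp().sum()


expensive(torch.tensor([0.0, 1.0]))
expensive(torch.tensor([0.0, 1.0]))  # equal values, new object: served from the cache
print(len(calls))  # 1
```

Unlike `lru_cache(maxsize=...)`, this never evicts, so memory grows with the number of distinct argument values; converting large tensors to tuples on every call also has a cost.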
(Later this week, or early next week, I can try replicating the results I recall if that would be helpful.)