🚧 RSA vectorized prior achieved; L0 doesn't enumerate

Open jmuchovej opened this issue 3 years ago • 3 comments

structured_prior now enumerates correctly: it specifies individual sample sites rather than making a single sample site that holds a tensor over all enumerations.

listener0 doesn't enumerate over the support of structured_prior. It only ever visits the highest-probability sample from structured_prior.
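To make the expected behaviour concrete, here is a minimal, dependency-free sketch of what "enumerating over the support" means for a literal listener, as opposed to visiting only the most probable prior state. The toy prior, the `meaning` function, and the name `literal_listener` are hypothetical and not taken from the code in this issue.

```python
from typing import Callable, Dict, Hashable


def literal_listener(
    utterance: str,
    prior: Dict[Hashable, float],
    meaning: Callable[[str, Hashable], bool],
) -> Dict[Hashable, float]:
    """Exact L0: P(state | utterance) is proportional to prior(state) * meaning."""
    # Visit every state in the support of the prior, not just the most likely one.
    weights = {s: p * float(meaning(utterance, s)) for s, p in prior.items()}
    total = sum(weights.values())
    if total == 0.0:
        raise ValueError(f"utterance {utterance!r} is false in every state")
    return {s: w / total for s, w in weights.items()}


# Hypothetical toy domain: how many of 3 objects have some property.
prior = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}


def meaning(utterance: str, state: int) -> bool:
    if utterance == "some":
        return state > 0
    if utterance == "all":
        return state == 3
    if utterance == "none":
        return state == 0
    raise ValueError(f"unknown utterance {utterance!r}")


posterior = literal_listener("some", prior, meaning)

# The failure mode described above would only ever look at this single state.
map_state = max(prior, key=prior.get)
```

With full enumeration, `posterior` keeps mass on every state compatible with the utterance, whereas looking only at `map_state` collapses the listener onto one state.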

jmuchovej · Aug 01 '21 17:08

@eb8680 Just a bump on this. I can post an update with some more work I've done on this, but right now I've hit a snag where the following doesn't appear to work:

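import pyro
from pyro.infer import config_enumerate
from torch import Tensor

@config_enumerate
def listener0(utterance: Tensor, threshold: Tensor, States: "RSAMarginal") -> Tensor:
    # States is the custom RSAMarginal distribution described below.
    state = pyro.sample("state", States)
    # ...
    return state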

RSAMarginal is much like HashingMarginal, but it extracts the support from a TracePosterior (as TraceEnum_ELBO appears to do in TraceEnum_ELBO._traces).

I can push the changes so you can inspect them if it's helpful, but at the moment I'm not sure how to "signal" that RSAMarginal can be enumerated the way dist.Categorical-like classes can.

I've searched the forum pretty extensively and haven't seen any questions about this.
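For reference, distributions in torch.distributions (which Pyro's distributions build on) advertise enumerability through a `has_enumerate_support` attribute and an `enumerate_support(expand=True)` method, and as far as I can tell `config_enumerate` only annotates sample sites whose distribution reports that flag. The sketch below mirrors that interface in plain Python so the shape of the contract is visible. `FiniteMarginal` and `should_enumerate` are hypothetical names, and a real Pyro distribution would additionally need to subclass `TorchDistribution` and return tensors with the right batch shape.

```python
import math
from typing import Dict, Hashable, List


class FiniteMarginal:
    """Dependency-free stand-in for a marginal over a finite set of values."""

    # Flag that enumeration machinery is assumed to look for.
    has_enumerate_support = True

    def __init__(self, probs: Dict[Hashable, float]) -> None:
        total = sum(probs.values())
        self._values: List[Hashable] = list(probs)
        self._probs = [probs[v] / total for v in self._values]

    def enumerate_support(self, expand: bool = True) -> List[Hashable]:
        # Return every value in the support, in a fixed order.
        return list(self._values)

    def log_prob(self, value: Hashable) -> float:
        p = self._probs[self._values.index(value)]
        return math.log(p) if p > 0 else float("-inf")


def should_enumerate(fn: object) -> bool:
    # Assumed check: treat a site as enumerable only if its distribution says so.
    return bool(getattr(fn, "has_enumerate_support", False))


marginal = FiniteMarginal({"a": 1.0, "b": 3.0})
```

If RSAMarginal does not set the flag or does not implement `enumerate_support`, that could explain why the site is sampled rather than enumerated, though I have not verified this against the code in question.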

jmuchovej · Oct 04 '21 20:10

@jmuchovej sorry, I'll try to look at it this week.

eb8680 · Oct 12 '21 12:10

Also, as a side note, I've replicated everything (that is, the models we're looking to build) in WebPPL. The only way to get a reasonable run-time was to use their cache(..., N) function. Just running the model (no data fitting) takes <5min as a result, compared to hours without cache(..., N). (I didn't let the uncached run finish.)

There's a @memoize(...) decorator in the existing HashingMarginal, which I would assume should be analogous to WebPPL's cache function, but it doesn't seem to be... (If memory serves, I tried it with maxsize=100000 and saw no improvement.)

I'm not exactly sure how WebPPL's cache works under the hood (in terms of relating it to Pyro nomenclature), but I would imagine it works similarly to mem. mem uses the arguments to compute the hash for the lookup table. (This seems very similar to what _dist_and_values does, yet there seems to be little to no performance uplift from memoizing _dist_and_values.)
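As an illustration of argument-keyed memoization, here is a small sketch using Python's `functools.lru_cache`, which is roughly the stdlib analogue of a size-bounded cache. It also shows one possible (unverified) reason a memoizer can give no speed-up: if the arguments hash by object identity, as I believe `torch.Tensor` objects do, then two equal-valued arguments are different cache keys and every call misses. The function names and the `Opaque` class are hypothetical.

```python
from functools import lru_cache


@lru_cache(maxsize=1000)
def cached_by_value(utterance: str, threshold: float) -> float:
    # Stand-in for an expensive inference call keyed on plain Python values.
    return len(utterance) * threshold


cached_by_value("some", 0.5)
cached_by_value("some", 0.5)  # same key, so this call is served from the cache
value_info = cached_by_value.cache_info()


class Opaque:
    """Wrapper with default identity-based __hash__ and __eq__."""

    def __init__(self, value: float) -> None:
        self.value = value


@lru_cache(maxsize=1000)
def cached_by_identity(arg: Opaque) -> float:
    return arg.value * 2


cached_by_identity(Opaque(0.5))
cached_by_identity(Opaque(0.5))  # equal contents, different object: a miss
identity_info = cached_by_identity.cache_info()
```

If this is what is happening, keying the cache on a value-based representation of the arguments (for example a tuple of Python numbers) would be one thing to try.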

WebPPL's docs on cache and on mem.

(Later this week, or early next week, I can try replicating the results I recall if that would be helpful.)

jmuchovej · Oct 12 '21 15:10