`pyro.vectorized_markov` is slow for large time lengths because of `str(torch.tensor)` operation
To make the `pyro.markov` and `pyro.vectorized_markov` syntaxes compatible, and to keep Markov variable names readable, we introduced string manipulation in `VectorizedMarkovMessenger` that replaces the `str(torch.tensor)` suffix with `str(slice)`:
https://github.com/pyro-ppl/pyro/blob/5ecc9c0a2b8192b99ad386cc1d946e8a44f964ae/pyro/contrib/funsor/handlers/plate_messenger.py#L320
This allows `pyro.markov` and `pyro.vectorized_markov` to be used interchangeably:
-for i in pyro.markov(...):
+for i in pyro.vectorized_markov(...):
x_curr = pyro.sample(f"x_{i}", ...)
However, for long tensors calling `str(tensor)` is slow, and it becomes one of the slowest operations in `TraceMarkovEnum_ELBO`. Locally, I changed my code so that `pyro.vectorized_markov` returns a tuple of `torch.tensor` indices and a `slice` suffix:
for i, i_suffix in pyro.vectorized_markov(...):
# i - torch.tensor object
# i_suffix - slice object
x_curr = pyro.sample(f"x_{i_suffix}", ...)
This solves the speed problem; however, the syntax deviates from that of `pyro.markov`.
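For anyone who wants to check the cost difference locally, here is a rough micro-benchmark sketch. The time length `T`, the variable names, and the shape of the index tensor are assumptions made for illustration; this is not taken from the Pyro code base, and actual timings will depend on the machine and the PyTorch version.

```python
import timeit

import torch

T = 10_000  # hypothetical time length, chosen only for illustration

# Index tensor similar to what a vectorized Markov loop might yield (assumption).
i = torch.arange(T - 1)
# A slice describing the same contiguous range of time steps.
i_suffix = slice(0, T - 1)

# Time how long it takes to format each object into a site name.
tensor_time = timeit.timeit(lambda: f"x_{i}", number=20)
slice_time = timeit.timeit(lambda: f"x_{i_suffix}", number=20)

print(f"str(tensor): {tensor_time:.4f}s, str(slice): {slice_time:.6f}s")

# The slice-based site name is short and does not depend on tensor contents.
site_name = f"x_{i_suffix}"
```

Formatting the slice only touches three Python integers, whereas formatting the tensor goes through PyTorch's printing machinery, which is where the overhead described above would come from.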
- Does anyone have an idea how to solve this issue without breaking code compatibility between `pyro.markov` and `pyro.vectorized_markov`?
- If not, should speed be prioritized over the aforementioned syntax compatibility?
What about just making the indices `slice`s? Was there a reason we needed them to be tensors?
I guess we need tensors to do index arithmetic (e.g. `x[t-1]`). It seems like a better long-term solution would be to manage naming in markov contexts more automatically and avoid the need for manual formatting in the first place, since it's rare and difficult for users to refer to these sites manually anyway. In the meantime,

> should the speed be prioritized over aforementioned syntax compatibility?

No, let's just change the syntax; `contrib.funsor` is already pretty slow.
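To illustrate the index-arithmetic point raised in this thread, here is a small sketch of why tensor indices are convenient and what the slice-only alternative would require. The helper `shift_slice` is hypothetical and written only for this example; it is not part of Pyro or PyTorch.

```python
import torch

x = torch.arange(10.0)  # toy sequence of length 10

# Tensor indices support arithmetic directly, e.g. x[t - 1].
t = torch.arange(1, 10)
x_prev = x[t - 1]  # elements at positions 0..8

# Slices do not support arithmetic: `slice(1, 10) - 1` raises TypeError.
def shift_slice(s, offset):
    """Hypothetical helper: shift a slice with explicit bounds by `offset`."""
    return slice(s.start + offset, s.stop + offset, s.step)

s = slice(1, 10)
x_prev_from_slice = x[shift_slice(s, -1)]

# Both approaches select the same elements.
assert torch.equal(x_prev, x_prev_from_slice)
```

With slice-only indices, every `t - 1` in user code would need a helper like this, which is one reason to keep the tensor for indexing and use the slice only for naming.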