pyro
Feature request: add a parameter that allows Predictive to propagate gradients
Issue Description
When generating samples with pyro.infer.predictive.Predictive, gradients are dropped. From the code this looks like an intentional design choice, but I'm not sure why; if there's a good reason, feel free to ignore this request.
Code Snippet
Create a model and fit a guide such that model(x).requires_grad and guide(x)['some_site'].requires_grad both return True when x has gradients enabled. Then do:
from pyro.infer import Predictive
predictive = Predictive(model, guide=guide, num_samples=100)
posterior_samples = predictive(x)
Then posterior_samples['some_site'].requires_grad is False.