
Feature request: add a parameter that allows Predictive to propagate gradients

Open · pwsiegel opened this issue 4 months ago · 0 comments

Issue Description

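When generating samples with pyro.infer.predictive.Predictive, gradients are dropped: the returned samples no longer require grad, even when the model's and guide's outputs do. Judging from the code, this looks like an intentional design choice, but I'm not sure why. If there's a good reason for it, feel free to ignore this request.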
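One common way code deliberately drops gradients is to run it inside a `torch.no_grad()` context. The plain-PyTorch sketch below only illustrates that effect, assuming (not confirming) that this is how `Predictive` behaves internally: the same reparameterized draw keeps its autograd history normally and loses it under `torch.no_grad()`. The function `reparameterized_draw` is a hypothetical stand-in for a model.

```python
import torch


def reparameterized_draw(x):
    # Hypothetical stand-in for a model: y = 2x + standard normal noise,
    # written so that y stays differentiable with respect to x.
    return 2.0 * x + torch.randn_like(x)


x = torch.ones(3, requires_grad=True)

# Normal execution: autograd records the graph from y back to x.
y_tracked = reparameterized_draw(x)

# Same computation inside torch.no_grad(): no graph is recorded,
# so the result has requires_grad=False even though x requires grad.
with torch.no_grad():
    y_untracked = reparameterized_draw(x)
```

Calling `.backward()` on a scalar derived from `y_untracked` raises a `RuntimeError`, since nothing connects it back to `x`.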

Code Snippet

Create a model and fit a guide such that model(x).requires_grad and guide(x)['some_site'].requires_grad both return True when x.requires_grad is True. Then run:

```python
from pyro.infer import Predictive

predictive = Predictive(model, guide=guide, num_samples=100)
posterior_samples = predictive(x)
```

Then posterior_samples['some_site'].requires_grad is False.

pwsiegel, Oct 18 '24 15:10