Update sparse.grad() to support re-packing gradients into PyTrees
Fixes #16582
This PR modifies the postprocessing step of sparse.grad() to reconstruct the input PyTrees that were indexed for autodiff. Previously, only the gradient corresponding to the first element of a PyTree was returned.
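The re-packing idea can be sketched with plain jax.tree_util and jax.grad: flatten the pytree argument, differentiate with respect to every leaf, then unflatten the leaf gradients back into the input's structure. This is a minimal illustration under that assumption only. grad_repacked, loss, and params are hypothetical names, and this is not the PR's actual implementation, which lives inside sparse.grad()'s postprocessing.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def grad_repacked(f):
    """Hypothetical helper: differentiate f w.r.t. every leaf of a pytree
    argument and re-pack the leaf gradients into the input's structure."""

    def wrapped(params):
        leaves, treedef = tree_util.tree_flatten(params)

        def f_of_leaves(*flat):
            # Rebuild the pytree so f sees its original argument structure.
            return f(tree_util.tree_unflatten(treedef, flat))

        # Take gradients w.r.t. all leaves, not just the first one.
        argnums = tuple(range(len(leaves)))
        leaf_grads = jax.grad(f_of_leaves, argnums=argnums)(*leaves)
        # Re-pack the flat tuple of gradients into the structure of params.
        return tree_util.tree_unflatten(treedef, leaf_grads)

    return wrapped


def loss(params):
    return jnp.sum(params["w"] ** 2) + 3.0 * params["b"]


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.float32(0.5)}
grads = grad_repacked(loss)(params)
# grads is a dict like params: grads["w"] is [2., 4.] and grads["b"] is 3.0
```

Returning the gradients in the same structure as the input is what jax.grad() already does for pytree arguments, so callers can use the result directly, e.g. with tree_util.tree_map for a parameter update.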
This PR includes several test cases verifying that the behavior of sparse.grad() matches that of jax.grad() when gradients are taken with respect to PyTrees.
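A comparison in the spirit of those tests might look like the sketch below: for several pytree inputs (a dict, a tuple, and a nested container), check that a candidate gradient function returns the same tree structure and leaf values as jax.grad(). All names here (grads_match, value_and_grad_based, first_leaf_only) are hypothetical and not taken from the PR's test suite. first_leaf_only is an illustrative stand-in mimicking the previous behavior, where only the first element's gradient came back; one could presumably pass sparse.grad in the candidate slot, but that is an assumption not exercised here.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def grads_match(candidate_grad, f, x, rtol=1e-6):
    """Return True if candidate_grad(f)(x) matches jax.grad(f)(x) in
    both pytree structure and leaf values."""
    expected = jax.grad(f)(x)
    actual = candidate_grad(f)(x)
    if tree_util.tree_structure(actual) != tree_util.tree_structure(expected):
        return False
    return all(
        bool(jnp.allclose(a, e, rtol=rtol))
        for a, e in zip(tree_util.tree_leaves(actual), tree_util.tree_leaves(expected))
    )


def value_and_grad_based(f):
    # Stand-in candidate: the gradient half of jax.value_and_grad.
    return lambda x: jax.value_and_grad(f)(x)[1]


def first_leaf_only(f):
    # Hypothetical buggy candidate: returns only the first leaf's gradient
    # instead of the full pytree.
    return lambda x: tree_util.tree_leaves(jax.grad(f)(x))[0]


# Pytree inputs of different shapes: dict, tuple, and nested containers.
cases = [
    (lambda p: jnp.sum(p["w"] ** 2) + p["b"],
     {"w": jnp.arange(3.0), "b": jnp.float32(1.0)}),
    (lambda t: jnp.sum(t[0] * t[1]),
     (jnp.ones(2), jnp.array([2.0, 3.0]))),
    (lambda n: jnp.sum(n["a"][0]) * n["a"][1],
     {"a": (jnp.ones(2), jnp.float32(4.0))}),
]

results_ok = [grads_match(value_and_grad_based, f, x) for f, x in cases]
results_buggy = [grads_match(first_leaf_only, f, x) for f, x in cases]
```

Checking the tree structure before the values is what catches the old failure mode: a bare array returned in place of a dict or tuple of gradients fails the structure check even if its values happen to be correct.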
@jakevdp This PR is ready for review if you get the chance.
@jakevdp Finally getting back around to this, apologies for the glacial pace. I adopted your suggestions with some minor fixes.
Hi - we're seeing a couple test failures here that have been fixed on the main branch. Can you rebase against the most recent main branch commit so we can try again?