Update sparse.grad() to support re-packing gradients into PyTrees

Open Blair-Johnson opened this issue 1 year ago • 1 comments

Fixes #16582

This PR modifies the postprocessing step of `sparse.grad()` so that it reconstructs the input PyTrees that were indexed for autodiff. Previously, only the gradient corresponding to the first element of a PyTree was returned.

This PR includes several test cases verifying that the behavior of `sparse.grad()` matches that of `jax.grad()` when gradients are taken with respect to PyTrees.
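For readers unfamiliar with the re-packing idea, here is a minimal sketch of it using only `jax.tree_util` and `jax.grad`. It is not the code from this PR and does not touch `sparse.grad()` internals; `loss`, `grad_via_flat_leaves`, `params`, and `x` are hypothetical names chosen for illustration. It differentiates with respect to the flat leaves of a PyTree, rebuilds the original structure from the flat gradients, and compares the result against `jax.grad()` in the spirit of the tests described above.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def loss(params, x):
    # `params` is a PyTree: a dict holding a weight vector and a scalar bias.
    return jnp.sum((x @ params["w"] + params["b"]) ** 2)


def grad_via_flat_leaves(fun, params, *args):
    """Hypothetical helper: differentiate w.r.t. flat leaves, then re-pack."""
    leaves, treedef = tree_util.tree_flatten(params)

    def flat_fun(*flat_leaves):
        # Rebuild the PyTree so `fun` sees its original argument structure.
        return fun(tree_util.tree_unflatten(treedef, flat_leaves), *args)

    # One gradient per leaf, returned as a flat tuple.
    flat_grads = jax.grad(flat_fun, argnums=tuple(range(len(leaves))))(*leaves)

    # Returning only flat_grads[0] would drop every leaf but the first.
    # Re-packing restores the structure of `params`.
    return tree_util.tree_unflatten(treedef, flat_grads)


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])

repacked = grad_via_flat_leaves(loss, params, x)
reference = jax.grad(loss)(params, x)

# Both results should share the structure of `params` and agree leaf by leaf.
same_structure = (
    tree_util.tree_structure(repacked) == tree_util.tree_structure(reference)
)
leaves_match = all(
    bool(jnp.allclose(a, b))
    for a, b in zip(tree_util.tree_leaves(repacked), tree_util.tree_leaves(reference))
)
print(same_structure, leaves_match)
```

The comparison against `jax.grad()` is the same kind of check the PR's tests are described as making, applied here to a dense function so the sketch does not depend on whether this change has landed.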

Blair-Johnson, Feb 12 '24 21:02

@jakevdp This PR is ready for review if you get the chance.

Blair-Johnson, Apr 16 '24 18:04

@jakevdp Finally getting back around to this, apologies for the glacial pace. I adopted your suggestions with some minor fixes.

Blair-Johnson, Jul 10 '24 22:07

Hi - we're seeing a couple test failures here that have been fixed on the main branch. Can you rebase against the most recent main branch commit so we can try again?

jakevdp, Jul 29 '24 17:07