Add an OTT `check_grads` function

Open • Daniel-Packer opened this issue 2 years ago • 0 comments

Is your feature request related to a problem? Please describe.
When implementing new features, we provide gradient tests to demonstrate that the new feature is differentiable with JAX's automatic differentiation. Testing every gradient of a function means duplicating many near-identical code blocks. Even if I write a helper function to check the gradients and avoid that duplication, I still have to copy that helper's definition across feature implementations, which is essentially the same problem. We could instead consolidate this logic into a single function in `ott.tools` that every test can reuse.

Describe the solution you'd like
We could implement a version of `check_grads` in `ott.tools`, modeled on the one in JAX's internal test utilities (google/jax#2648).
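A minimal sketch of what such a shared helper could look like, assuming a scalar-valued function whose positional arguments are floating-point arrays. The name `check_grads_fd`, its tolerances, and `toy_loss` are hypothetical illustrations, not an existing OTT or JAX API. For each argument, the helper compares the `jax.grad` directional derivative with a central finite difference along a random direction, so a single call covers every input.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Finite differences need double precision to be trustworthy.
jax.config.update("jax_enable_x64", True)


def check_grads_fd(fn, args, *, eps=1e-5, rtol=1e-4, atol=1e-6, seed=0):
    """Hypothetical shared helper (not an existing OTT API).

    For each positional argument of the scalar-valued `fn`, compares the
    reverse-mode directional derivative <grad_i fn(args), d> with a central
    finite difference along a random direction d. Raises AssertionError on
    the first mismatch.
    """
    key = jax.random.PRNGKey(seed)
    for i in range(len(args)):
        key, subkey = jax.random.split(key)
        x = jnp.asarray(args[i])
        direction = jax.random.normal(subkey, x.shape, dtype=x.dtype)

        def along(t):
            # fn with argument i moved to x + t * direction, others fixed.
            shifted = list(args)
            shifted[i] = x + t * direction
            return fn(*shifted)

        # Autodiff estimate of the directional derivative.
        autodiff = jnp.sum(jax.grad(fn, argnums=i)(*args) * direction)
        # Central finite-difference estimate of the same quantity.
        numeric = (along(eps) - along(-eps)) / (2 * eps)
        if not jnp.allclose(autodiff, numeric, rtol=rtol, atol=atol):
            raise AssertionError(
                f"gradient mismatch for argument {i}: "
                f"autodiff={float(autodiff)}, finite-diff={float(numeric)}"
            )


def toy_loss(x, y, epsilon):
    """Illustrative stand-in for an OTT quantity (entropic soft-min cost)."""
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return -epsilon * logsumexp(-cost / epsilon)


key_x, key_y = jax.random.split(jax.random.PRNGKey(1))
x = jax.random.normal(key_x, (4, 2))
y = jax.random.normal(key_y, (3, 2))
# One call replaces a hand-written gradient check per argument.
check_grads_fd(toy_loss, (x, y, jnp.asarray(1.0)))
```

Checking a single random direction per argument keeps the helper cheap. A fuller version would also cover higher-order derivatives and forward mode, both of which JAX's own `check_grads` already supports through its `order` and `modes` parameters.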

Describe alternatives you've considered
We could also use JAX's `check_grads` function directly, although I don't know whether that would cause compatibility issues.
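For comparison, here is a sketch of calling `jax.test_util.check_grads` directly. It compares forward- and reverse-mode derivatives up to the requested `order` against finite differences along random directions and raises `AssertionError` on a mismatch. The function `toy_soft_min` is a hypothetical stand-in for an OTT quantity. Double precision is enabled because finite-difference checks are much tighter in float64.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.test_util import check_grads

# Finite-difference checks are far more reliable in double precision.
jax.config.update("jax_enable_x64", True)


def toy_soft_min(x, y, epsilon):
    """Hypothetical stand-in for an OTT quantity: entropic soft-min of
    squared Euclidean costs between two small point clouds."""
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return -epsilon * logsumexp(-cost / epsilon)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4, 2))
y = jax.random.normal(key_y, (3, 2))
epsilon = jnp.asarray(1.0)

# Checks first and second derivatives with respect to all three arguments,
# in both forward and reverse mode; raises AssertionError if any check fails.
check_grads(toy_soft_min, (x, y, epsilon), order=2, modes=("fwd", "rev"))
```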

Additional context
A docstring has since been added to `check_grads` (google/jax#2656), which suggests it may be intended for outside use, but I'm not sure.
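One way to judge whether `jax.test_util.check_grads` is usable for OTT's tests is to confirm that it actually catches a broken gradient. In this sketch, `buggy_sin` is a hypothetical function with a deliberately wrong `jax.custom_vjp` rule. Only reverse mode is checked, because functions defined with `custom_vjp` do not support forward-mode differentiation.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

jax.config.update("jax_enable_x64", True)


@jax.custom_vjp
def buggy_sin(x):
    return jnp.sin(x)


def buggy_sin_fwd(x):
    # Save the input as the residual for the backward pass.
    return jnp.sin(x), x


def buggy_sin_bwd(x, g):
    # Deliberate bug: the correct cotangent is g * cos(x).
    return (2.0 * g * jnp.cos(x),)


buggy_sin.defvjp(buggy_sin_fwd, buggy_sin_bwd)

x = jnp.array([0.1, 0.5, 1.0])

# The correct primitive passes the reverse-mode check.
check_grads(jnp.sin, (x,), order=1, modes=("rev",))

# The wrong VJP should be reported as an AssertionError.
try:
    check_grads(buggy_sin, (x,), order=1, modes=("rev",))
    caught = False
except AssertionError:
    caught = True
```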

Daniel-Packer • Nov 06 '23 13:11