optax
Fix: #1509 Merge duplicate DoG implementations and add layer-wise support
Fixes issue #1509. This PR merges the duplicate Distance over Gradients (DoG) implementations found in optax/contrib/_dog.py and optax/_src/transform.py into a single, unified implementation in optax/_src/dog.py.
- Created optax/_src/dog.py, which consolidates DoG and DoWG. The new scale_by_dog now supports a layer_wise argument.
- Re-implemented scale_by_distance_over_gradients in optax/_src/transform.py on top of the new scale_by_dog with layer_wise=True.
- Deprecated scale_by_distance_over_gradients in favor of scale_by_dog.
- Updated optax/contrib/_dog.py to be a compatibility shim that imports from optax/_src/dog.py.
- Added optax/_src/dog_test.py to verify the global and layer-wise behaviors, as well as legacy compatibility.
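For reviewers unfamiliar with the distinction, here is a minimal sketch of what a global versus layer-wise DoG step size could look like. It is not the code in this PR and does not call optax: the function name `dog_learning_rates`, its arguments, the `"__all__"` group key, and the defaults for `reps_rel` and `eps` are all hypothetical. It assumes the usual DoG rule, step size = (max distance from the initial point) / sqrt(sum of squared gradient norms), with norms taken over all parameters together (global) or per layer (layer-wise).

```python
import math


def _sq_norm(vec):
    """Squared Euclidean norm of a flat list of floats."""
    return sum(v * v for v in vec)


def dog_learning_rates(params, init_params, grads, state=None,
                       layer_wise=False, reps_rel=1e-6, eps=1e-8):
    """Hypothetical helper: one DoG statistics update.

    params, init_params, grads: dicts mapping a layer name to a flat list.
    Returns (rates, new_state), where rates maps each layer name to its
    step size.
    """
    # One group per layer, or a single group holding every layer.
    if layer_wise:
        groups = {name: [name] for name in params}
    else:
        groups = {"__all__": list(params)}

    if state is None:
        # Assumed initial distance: reps_rel * (1 + ||x_0||) per group.
        state = {
            "max_dist": {
                g: reps_rel * (1.0 + math.sqrt(
                    sum(_sq_norm(init_params[n]) for n in names)))
                for g, names in groups.items()
            },
            "grad_sq_sum": {g: 0.0 for g in groups},
        }

    rates = {}
    new_state = {"max_dist": {}, "grad_sq_sum": {}}
    for g, names in groups.items():
        # Distance of the group's parameters from their initial values.
        dist = math.sqrt(sum(
            _sq_norm([p - p0 for p, p0 in zip(params[n], init_params[n])])
            for n in names))
        max_dist = max(state["max_dist"][g], dist)
        grad_sq_sum = state["grad_sq_sum"][g] + sum(
            _sq_norm(grads[n]) for n in names)
        new_state["max_dist"][g] = max_dist
        new_state["grad_sq_sum"][g] = grad_sq_sum
        rate = max_dist / math.sqrt(grad_sq_sum + eps)
        for n in names:
            rates[n] = rate
    return rates, new_state


params = {"a": [3.0, 4.0], "b": [0.0]}
init_params = {"a": [0.0, 0.0], "b": [0.0]}
grads = {"a": [1.0, 0.0], "b": [2.0]}

global_rates, _ = dog_learning_rates(params, init_params, grads)
layer_rates, _ = dog_learning_rates(params, init_params, grads,
                                    layer_wise=True)
```

In the global case every layer shares one step size. In the layer-wise case a layer that has not moved from its initial value keeps a tiny step size regardless of how far the other layers have travelled, which is the behavioral difference the new tests are meant to cover.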
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
View this failed invocation of the CLA check for more information.
For the most up to date status, view the checks section at the bottom of the pull request.
@mtthss @rdyro @vroulet Could you review this?