keras
DRAFT: Add custom variable updater.
Allows customization for how variables are updated by the optimizer. The base optimizer simply defers to the update handler to do the update, allowing full customization.
This can replace the existing overwrite_with_gradient attribute on variables, which is currently very application-specific.
It also skips creating optimizer variables for variables that have custom updaters (including overwrite_with_gradient variables), since that optimizer state is never used and can waste memory.
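A minimal pure-Python sketch of the dispatch described above, for illustration only. The names Variable, updater, OverwriteWithGradientUpdater and MomentumOptimizer are all hypothetical and are not the Keras API or the code in this PR. The sketch assumes that a variable carrying an updater is handed to that updater instead of the optimizer's own rule, and that no optimizer state is built for it.

```python
class Variable:
    """Toy stand-in for a framework variable (hypothetical, not keras.Variable)."""

    def __init__(self, value, updater=None):
        self.value = list(value)
        # Hypothetical attribute: a custom update handler, or None for the default rule.
        self.updater = updater


class OverwriteWithGradientUpdater:
    """Hypothetical updater that reproduces overwrite_with_gradient behavior."""

    def update(self, variable, gradient):
        # Replace the variable's value with the incoming "gradient".
        variable.value = list(gradient)


class MomentumOptimizer:
    """Toy SGD-with-momentum optimizer that defers to custom updaters."""

    def __init__(self, learning_rate=0.1, momentum=0.9):
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.slots = {}  # id(variable) -> velocity list

    def build(self, variables):
        for v in variables:
            if v.updater is not None:
                # No optimizer state for variables with a custom updater.
                continue
            self.slots[id(v)] = [0.0] * len(v.value)

    def apply(self, grads_and_vars):
        for grad, v in grads_and_vars:
            if v.updater is not None:
                # Defer entirely to the custom update handler.
                v.updater.update(v, grad)
                continue
            velocity = self.slots[id(v)]
            for i, g in enumerate(grad):
                velocity[i] = self.momentum * velocity[i] - self.learning_rate * g
                v.value[i] += velocity[i]


# Usage: one ordinary weight and one fp8-style scale that is overwritten.
w = Variable([1.0, 2.0])
scale = Variable([0.0], updater=OverwriteWithGradientUpdater())
opt = MomentumOptimizer(learning_rate=0.1, momentum=0.9)
opt.build([w, scale])
opt.apply([([1.0, 1.0], w), ([5.0], scale)])
```

Only the ordinary weight gets a velocity slot; the scale variable is updated solely by its handler.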
This is an alternative to #21196. It would allow us to add special handling for large embedding tables, where we do not want to pass around large gradients for tables that might span multiple devices. Instead, the tables are updated in place using a custom update rule.
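To illustrate what "updated in place using a custom update rule" could look like, here is a hypothetical sketch in plain Python. EmbeddingTable and SparseSGDUpdater are invented names, and the sparse-SGD rule is an assumption chosen for simplicity; the PR does not specify which rule the embedding updater uses. The point is that the updater receives only the looked-up row ids and their per-row gradients, never a dense gradient the size of the whole table.

```python
class EmbeddingTable:
    """Hypothetical embedding table stored as a list of rows."""

    def __init__(self, num_rows, dim, init=0.0):
        # Each row is a separate list so rows can be mutated independently.
        self.rows = [[init] * dim for _ in range(num_rows)]


class SparseSGDUpdater:
    """Hypothetical custom updater: applies SGD only to the touched rows."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def update(self, table, row_ids, row_grads):
        # Duplicate ids are applied one after another; for plain SGD this is
        # equivalent to summing their gradients first.
        for row_id, grad in zip(row_ids, row_grads):
            row = table.rows[row_id]
            for j, g in enumerate(grad):
                row[j] -= self.learning_rate * g


# Usage: a 1000-row table where only rows 3 and 7 were looked up.
table = EmbeddingTable(num_rows=1000, dim=2)
updater = SparseSGDUpdater(learning_rate=0.5)
updater.update(table, row_ids=[3, 7, 3], row_grads=[[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
```

In a real multi-device setup each shard would apply the rows it owns locally; this single-process sketch leaves that out.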
Codecov Report
Attention: Patch coverage is 94.44444% with 3 lines in your changes missing coverage. Please review.
Project coverage is 82.62%. Comparing base (37eacb0) to head (a53a152). Report is 91 commits behind head on master.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| keras/src/backend/common/variables.py | 83.33% | 1 Missing and 1 partial :warning: |
| keras/src/optimizers/optimizer.py | 96.00% | 1 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## master #21225 +/- ##
=======================================
Coverage 82.61% 82.62%
=======================================
Files 564 564
Lines 54476 54514 +38
Branches 8470 8475 +5
=======================================
+ Hits 45005 45040 +35
- Misses 7395 7397 +2
- Partials 2076 2077 +1
| Flag | Coverage Δ | |
|---|---|---|
| keras | 82.43% <94.44%> (+<0.01%) | :arrow_up: |
| keras-jax | 63.73% <94.44%> (+0.02%) | :arrow_up: |
| keras-numpy | 58.85% <85.18%> (+0.01%) | :arrow_up: |
| keras-openvino | 32.99% <35.18%> (+<0.01%) | :arrow_up: |
| keras-tensorflow | 64.13% <94.44%> (+0.01%) | :arrow_up: |
| keras-torch | 63.81% <94.44%> (+0.02%) | :arrow_up: |
Flags with carried forward coverage won't be shown.
@hertschuh @fchollet
Potential replacement for #21196. With this approach, layering is mostly preserved: in our case the custom updater would still contain an optimizer, but in the general case it doesn't need to.
This change also generalizes the mechanism and would allow us to remove the overwrite_with_gradient attribute entirely in favor of a custom updater. The existing overwrite_with_gradient attribute is very specific to scale factors in fp8 quantization, which use a max when accumulating gradients.
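As a hedged illustration of how the fp8 special case might be expressed as a custom updater, the sketch below keeps a running element-wise max of the incoming values across gradient-accumulation steps and writes it to the variable only on the final step. MaxOverwriteUpdater, accumulation_steps and Variable are hypothetical names. The exact point at which Keras applies the max is an assumption here, not taken from the PR.

```python
class Variable:
    """Toy stand-in for a framework variable (hypothetical)."""

    def __init__(self, value):
        self.value = list(value)


class MaxOverwriteUpdater:
    """Hypothetical per-variable updater: overwrite with the max over accumulated steps."""

    def __init__(self, accumulation_steps):
        self.accumulation_steps = accumulation_steps
        self._step = 0
        self._running_max = None  # element-wise max over the current window

    def update(self, variable, gradient):
        if self._running_max is None:
            self._running_max = list(gradient)
        else:
            self._running_max = [max(a, g) for a, g in zip(self._running_max, gradient)]
        self._step += 1
        if self._step % self.accumulation_steps == 0:
            # Final accumulation step: overwrite the variable, then reset.
            variable.value = self._running_max
            self._running_max = None


# Usage: three micro-batches accumulated into one update.
v = Variable([1.0, 1.0])
updater = MaxOverwriteUpdater(accumulation_steps=3)
for g in ([2.0, 0.5], [1.0, 4.0], [3.0, 0.0]):
    updater.update(v, g)
```

Because the state lives in the updater rather than in the optimizer, one updater instance is needed per variable in this sketch.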