flax
fix: restore replicated tensors
What does this PR do?
Fixes restoring checkpoints where only some of the parameters are fully replicated.
I was not able to add a test (it requires sharding over multiple devices, and there is no similar existing example), but I created a simple reproduction here.
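For readers unfamiliar with the setup, the following is a minimal sketch of a parameter tree in which only some leaves are fully replicated. It is not the author's reproduction and does not touch the checkpointing code. The names `kernel`, `bias`, and the mesh axis `data` are hypothetical, and the `XLA_FLAGS` setting assumes a CPU backend where virtual devices can stand in for real ones.

```python
import os

# Assumption: request 8 virtual CPU devices so the sketch runs without
# accelerators. This must happen before JAX is imported and only affects
# the CPU backend.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over every visible device; "data" is a hypothetical axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n = devices.size

params = {
    # Split along the first axis across all devices in the mesh.
    "kernel": jax.device_put(jnp.ones((n * 2, 4)), NamedSharding(mesh, P("data"))),
    # Empty PartitionSpec: every device holds a full copy of the array.
    "bias": jax.device_put(jnp.zeros((4,)), NamedSharding(mesh, P())),
}

# Per-leaf flag showing which parameters are fully replicated.
replicated = jax.tree_util.tree_map(lambda x: x.is_fully_replicated, params)
print(replicated)
```

With more than one device, `kernel` reports as not fully replicated while `bias` does, which is the mixed case this PR describes. On a single device every array is trivially fully replicated, which is consistent with the note above that a test needs multiple devices.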
Checklist
- [x] This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
- [ ] This change is discussed in a GitHub issue/discussion (please add a link).
- [x] The documentation and docstrings adhere to the documentation guidelines.
- [ ] This change includes necessary high-coverage tests. (No quality testing = no merge!)
Codecov Report
Merging #3217 (535ed6f) into main (c8bb930) will not change coverage. The diff coverage is 0.00%.
Coverage Diff

| | main | #3217 | +/- |
|---|---|---|---|
| Coverage | 82.32% | 82.32% | |
| Files | 54 | 54 | |
| Lines | 6071 | 6071 | |
| Hits | 4998 | 4998 | |
| Misses | 1073 | 1073 | |

| Impacted Files | Coverage Δ | |
|---|---|---|
| flax/training/orbax_utils.py | 69.44% <0.00%> (ø) |
https://github.com/google/flax/pull/3229 includes your change and fixed another issue that appeared with your change. Feel free to give it a try!
Thanks, I tested the main branch and my bug example works now.