Update checkpoints.py to add a `pool_size` param to `checkpoints.restore_checkpoint`
What does this PR do?
Adds `pool_size` as an optional parameter to `restore_checkpoint`. This gives the user control over the number of threads used for parallel reads, which helps in the GCS use case, especially on the new TPU VMs.
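The sketch below illustrates what a configurable `pool_size` controls. It is a minimal sketch that assumes the restore reads the checkpoint file in fixed-size chunks on a thread pool. The helper `read_file_parallel`, its `chunk_size` argument, and the fallback of 32 threads are hypothetical. They are not Flax's actual API or implementation, and the example uses a local file rather than GCS.

```python
import concurrent.futures
import os
from typing import Optional

DEFAULT_POOL_SIZE = 32  # assumed fallback, mirroring the "pool 32" default run


def read_file_parallel(path: str,
                       pool_size: Optional[int] = None,
                       chunk_size: int = 128 << 20) -> bytes:
    """Reads `path` in `chunk_size` pieces using up to `pool_size` threads.

    Hypothetical helper for illustration; not part of flax.training.checkpoints.
    """
    if pool_size is None:
        pool_size = DEFAULT_POOL_SIZE
    if pool_size < 1:
        raise ValueError(f"pool_size must be >= 1, got {pool_size}")

    size = os.path.getsize(path)
    contents = bytearray(size)  # preallocated output buffer
    num_chunks = (size + chunk_size - 1) // chunk_size  # ceiling division

    def read_chunk(i: int) -> int:
        # Each task opens its own file handle so concurrent seeks don't collide.
        offset = i * chunk_size
        with open(path, "rb") as f:
            f.seek(offset)
            buf = f.read(chunk_size)
        # Chunks cover disjoint byte ranges, so threads never overwrite each other.
        contents[offset:offset + len(buf)] = buf
        return len(buf)

    # pool_size caps how many chunks are being read at the same time.
    with concurrent.futures.ThreadPoolExecutor(max_workers=pool_size) as pool:
        total = sum(pool.map(read_chunk, range(num_chunks)))
    if total != size:
        raise IOError(f"read {total} bytes, expected {size}")
    return bytes(contents)
```

With a remote backend such as GCS, each chunk read is roughly an independent request. In that setting `pool_size` effectively bounds the number of concurrent requests. On a local disk, the gain from more threads is typically smaller.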
```python
import cProfile
import multiprocessing
from flax.training import checkpoints
cpus = multiprocessing.cpu_count()
pool_size = int(cpus * 0.2)  # on a 96-core TPU VM this ends up being 19 threads
cProfile.run('state = checkpoints.restore_checkpoint(gcs_path, state, parallel=True, pool_size=pool_size)')
```
yields:
```
Parallel GCS with pool 19 ==============================
5363 function calls (4608 primitive calls) in 0.563 seconds
```
and
```python
cpus = multiprocessing.cpu_count()
pool_size = int(cpus * 0.5)  # on a 96-core TPU VM this ends up being 48 threads
cProfile.run('state = checkpoints.restore_checkpoint(gcs_path, state, parallel=True, pool_size=pool_size)')
```
yields:
```
Parallel GCS with pool 48 ==============================
5363 function calls (4608 primitive calls) in 0.574 seconds
```
and
```python
cpus = multiprocessing.cpu_count()
pool_size = cpus  # on a 96-core TPU VM this ends up being 96 threads
cProfile.run('state = checkpoints.restore_checkpoint(gcs_path, state, parallel=True, pool_size=pool_size)')
```
yields:
```
Parallel GCS with pool 96 ==============================
8625 function calls (6460 primitive calls) in 0.651 seconds
```
By comparison, the default (no `pool_size`, i.e. a pool of 32 threads):
```python
cProfile.run('state = checkpoints.restore_checkpoint(gcs_path, state, parallel=True)')
```
yields:
```
Parallel GCS with pool 32 ==============================
8625 function calls (6460 primitive calls) in 0.734 seconds
```
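The sketch below is an alternative way to compare pool sizes. It is a minimal sketch rather than part of this PR, and the helper names `pool_size_from_fraction` and `time_pool_sizes` are hypothetical. Timing without cProfile avoids its per-call instrumentation overhead. Repeating each measurement and keeping the minimum also damps run-to-run noise, which matters when runs differ by only about ten milliseconds (0.563 s vs 0.574 s above). The pool-size helper reproduces the `int(cpus * fraction)` rule used in the runs above, with an assumed floor of one thread.

```python
import multiprocessing
import time
from typing import Callable, Dict, Iterable, Optional


def pool_size_from_fraction(fraction: float, cpus: Optional[int] = None) -> int:
    """Returns int(cpus * fraction), as in the runs above, but never less than 1."""
    if cpus is None:
        cpus = multiprocessing.cpu_count()
    return max(1, int(cpus * fraction))


def time_pool_sizes(restore: Callable[[int], object],
                    pool_sizes: Iterable[int],
                    repeats: int = 3) -> Dict[int, float]:
    """Best-of-`repeats` wall-clock seconds of `restore(pool_size)` per size."""
    timings: Dict[int, float] = {}
    for size in pool_sizes:
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            restore(size)
            best = min(best, time.perf_counter() - start)
        timings[size] = best
    return timings


if __name__ == "__main__":
    # Stand-in workload; swap in a real restore call to benchmark it.
    sizes = [pool_size_from_fraction(f, cpus=96) for f in (0.2, 0.5, 1.0)]
    print(sizes)  # [19, 48, 96]
    print(time_pool_sizes(lambda n: time.sleep(0.001), sizes))
```

To time the real restore, pass something like `lambda n: checkpoints.restore_checkpoint(gcs_path, state, parallel=True, pool_size=n)` as `restore`, with `gcs_path` and `state` defined as in the runs above.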
Checklist
- [ ] This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
- [ ] This change is discussed in a GitHub issue/discussion (please add a link).
- [ ] The documentation and docstrings adhere to the documentation guidelines.
- [x] This change includes necessary high-coverage tests. (No quality testing = no merge!)
Codecov Report
Merging #1533 (ef894eb) into main (e0a618a) will increase coverage by 0.01%. The diff coverage is n/a.
```diff
@@            Coverage Diff             @@
##             main    #1533      +/-   ##
==========================================
+ Coverage   82.61%   82.63%   +0.01%
==========================================
  Files          66       66
  Lines        5505     5522      +17
==========================================
+ Hits         4548     4563      +15
- Misses        957      959       +2
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| flax/training/checkpoints.py | 97.87% <ø> (-0.02%) | :arrow_down: |
| flax/linen/module.py | 92.98% <0.00%> (-0.34%) | :arrow_down: |
| flax/core/lift.py | 95.76% <0.00%> (+0.09%) | :arrow_up: |
| flax/linen/transforms.py | 94.15% <0.00%> (+0.10%) | :arrow_up: |
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update e0a618a...ef894eb.
@avital @levskaya, is anything needed from my side to merge this?
Assigning this to @levskaya, who added the pull ready label, so he probably knows the most about this PR.