
Update checkpoints.py to add a `pool_size` param to `checkpoints.restore_checkpoint`

Open mjsML opened this issue 4 years ago • 3 comments

What does this PR do?

Adds `pool_size` as an optional parameter to `restore_checkpoint`. This lets the user control the number of threads used for a parallel restore. That control is helpful when restoring from GCS, especially on the new TPU VMs.
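To make the idea concrete, here is a minimal sketch of what a thread-pool size controls during a parallel restore. It is not the code in this PR. A file is read in fixed-size byte ranges by a `concurrent.futures.ThreadPoolExecutor` whose `max_workers` is the caller-supplied `pool_size`. The function name `parallel_read`, the `chunk_size` default, and reading from a local path rather than from GCS are all assumptions made for illustration.

```python
import os
from concurrent.futures import ThreadPoolExecutor


def parallel_read(path: str, pool_size: int = 32, chunk_size: int = 1 << 20) -> bytes:
    """Hypothetical helper: read `path` in byte-range chunks using `pool_size` threads."""
    file_size = os.path.getsize(path)

    def read_range(offset: int) -> bytes:
        # Each worker opens its own handle so seeks do not race between threads.
        with open(path, "rb") as f:
            f.seek(offset)
            return f.read(min(chunk_size, file_size - offset))

    offsets = range(0, file_size, chunk_size)
    # pool_size bounds how many byte ranges are being read at the same time,
    # which is the knob this PR exposes for latency-bound storage such as GCS.
    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        # map() yields results in input order, so chunks rejoin in file order.
        return b"".join(pool.map(read_range, offsets))
```

Because `map()` preserves order, the pool size only changes how many reads are in flight, not the bytes that come back.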

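```python
import cProfile, multiprocessing
from flax.training import checkpoints  # gcs_path and state are assumed to be defined
cpus = multiprocessing.cpu_count()
pool_size = int(cpus * 0.2)  # on a 96-core TPU VM this is 19 threads
cProfile.run('state = checkpoints.restore_checkpoint(gcs_path, state, parallel=True, pool_size=pool_size)')
```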

yields:

Parallel GCS with pool 19 ==============================
         5363 function calls (4608 primitive calls) in 0.563 seconds

and

```python
cpus = multiprocessing.cpu_count()
pool_size = int(cpus * 0.5)  # on a 96-core TPU VM this is 48 threads
cProfile.run('state = checkpoints.restore_checkpoint(gcs_path, state, parallel=True, pool_size=pool_size)')
```

yields:

Parallel GCS with pool 48 ==============================
         5363 function calls (4608 primitive calls) in 0.574 seconds

and

```python
cpus = multiprocessing.cpu_count()
pool_size = cpus  # on a 96-core TPU VM this is 96 threads
cProfile.run('state = checkpoints.restore_checkpoint(gcs_path, state, parallel=True, pool_size=pool_size)')
```

yields:

Parallel GCS with pool 96 ==============================
         8625 function calls (6460 primitive calls) in 0.651 seconds
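The thread counts in these runs (19, 48 and 96 on a 96-core TPU VM) come from truncating `cpus * fraction` with `int()`. The helper below is hypothetical and is not part of Flax or of this PR. It reproduces that arithmetic and adds a floor of one worker, because on a small host the truncation can reach zero (for example `int(4 * 0.2) == 0`), and `concurrent.futures.ThreadPoolExecutor` raises `ValueError` when `max_workers` is 0 or less.

```python
import multiprocessing
from typing import Optional


def pool_size_for(fraction: float, cpus: Optional[int] = None) -> int:
    """Hypothetical helper: a thread count as a fraction of the CPU count."""
    if cpus is None:
        cpus = multiprocessing.cpu_count()
    # int() truncates toward zero, as in the snippets above; clamp to at least
    # one worker because ThreadPoolExecutor rejects max_workers <= 0.
    return max(1, int(cpus * fraction))


# With 96 CPUs this reproduces the pool sizes benchmarked in this PR.
print([pool_size_for(f, cpus=96) for f in (0.2, 0.5, 1.0)])  # [19, 48, 96]
```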

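By comparison, the default call without `pool_size` (the output reports a pool of 32 threads):

```python
cProfile.run('state = checkpoints.restore_checkpoint(gcs_path, state, parallel=True)')
```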

yields:

Parallel GCS with pool 32 ==============================
         8625 function calls (6460 primitive calls) in 0.734 seconds
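Each number above comes from a single `cProfile.run` call, and profiling adds per-call overhead to the reported time. One alternative way to compare pool sizes is sketched below. It is an assumption and not what the PR used: a hypothetical harness sweeps several pool sizes and keeps the best wall-clock time of a few repeats, measured with `time.perf_counter`. The function `fake_fetch` stands in for a latency-bound GCS read by sleeping. In practice you would pass the real read or restore call as `fetch`.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Iterable


def fake_fetch(chunk_id: int, latency_s: float = 0.005) -> int:
    """Hypothetical stand-in for one remote chunk read: waits, then returns."""
    time.sleep(latency_s)
    return chunk_id


def sweep_pool_sizes(pool_sizes: Iterable[int], n_chunks: int = 32, repeats: int = 3,
                     fetch: Callable[[int], int] = fake_fetch) -> Dict[int, float]:
    """Return the best-of-`repeats` wall-clock seconds for each pool size."""
    results = {}
    for size in pool_sizes:
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            with ThreadPoolExecutor(max_workers=size) as pool:
                list(pool.map(fetch, range(n_chunks)))  # consume to wait for every read
            best = min(best, time.perf_counter() - start)
        results[size] = best
    return results


for size, secs in sweep_pool_sizes([1, 8, 32]).items():
    print(f"pool {size:>3}: {secs:.3f} s")
```

Taking the minimum of several repeats dampens noise from network jitter and cold caches. The pool size that minimizes wall-clock time will likely depend on the host and on how many chunks the checkpoint is split into.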

Checklist

  • [ ] This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • [ ] This change is discussed in a GitHub issue/discussion (please add a link).
  • [ ] The documentation and docstrings adhere to the documentation guidelines.
  • [x] This change includes necessary high-coverage tests. (No quality testing = no merge!)

mjsML · Sep 09 '21 23:09

Codecov Report

Merging #1533 (ef894eb) into main (e0a618a) will increase coverage by 0.01%. The diff coverage is n/a.


```diff
@@            Coverage Diff             @@
##             main    #1533      +/-   ##
==========================================
+ Coverage   82.61%   82.63%   +0.01%     
==========================================
  Files          66       66              
  Lines        5505     5522      +17     
==========================================
+ Hits         4548     4563      +15     
- Misses        957      959       +2     
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| flax/training/checkpoints.py | 97.87% <ø> (-0.02%) | :arrow_down: |
| flax/linen/module.py | 92.98% <0.00%> (-0.34%) | :arrow_down: |
| flax/core/lift.py | 95.76% <0.00%> (+0.09%) | :arrow_up: |
| flax/linen/transforms.py | 94.15% <0.00%> (+0.10%) | :arrow_up: |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update e0a618a...ef894eb.

codecov-commenter · Sep 11 '21 08:09

@avital @levskaya, is anything else needed from my side to merge this?

mjsML · May 07 '22 10:05

Assigning this to @levskaya, who added the "pull ready" label and so probably knows the most about this PR.

marcvanzee · May 07 '22 17:05