torchsnapshot
[Feature] Load partially instantiated state-dict
Summary:
Allows loading state_dicts from disk when the saved copy contains more entries than the pre-populated copy.
Test plan: TODO
Closes #101
Codecov Report
Merging #103 (5501d83) into main (a11364c) will decrease coverage by 6.18%. The diff coverage is 80.85%.
```diff
@@            Coverage Diff             @@
##             main     #103      +/-   ##
==========================================
- Coverage   90.40%   84.22%   -6.19%
==========================================
  Files          25       25
  Lines        2346     2384      +38
==========================================
- Hits         2121     2008     -113
- Misses        225      376     +151
```
| Impacted Files | Coverage Δ | Trend |
|---|---|---|
| torchsnapshot/flatten.py | 86.73% <66.66%> (-3.63%) | :arrow_down: |
| torchsnapshot/io_preparer.py | 92.55% <72.72%> (-1.07%) | :arrow_down: |
| torchsnapshot/snapshot.py | 92.75% <95.23%> (-3.03%) | :arrow_down: |
| torchsnapshot/storage_plugins/gcs.py | 0.00% <0.00%> (-78.90%) | :arrow_down: |
| torchsnapshot/storage_plugins/s3.py | 25.00% <0.00%> (-65.00%) | :arrow_down: |
| torchsnapshot/storage_plugin.py | 43.75% <0.00%> (-21.88%) | :arrow_down: |
| torchsnapshot/memoryview_stream.py | 63.33% <0.00%> (-6.67%) | :arrow_down: |
| torchsnapshot/io_types.py | 79.68% <0.00%> (-3.13%) | :arrow_down: |
| torchsnapshot/scheduler.py | 91.11% <0.00%> (-2.23%) | :arrow_down: |
| ... and 2 more | | |
Hi @vmoens, thank you so much for raising the issue and drafting a proposal!
Previously, to reduce the memory footprint on load, we would find the tensors in the target module's state dict, restore them in place, and use them to build a state dict for .load_state_dict().
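The previous, target-driven behavior can be sketched as below. This is a framework-free illustration, not torchsnapshot code: the names `restore_target_driven` and `StateDict` are hypothetical, a "tensor" is modeled as a mutable list of floats, and slice assignment stands in for an in-place copy such as `Tensor.copy_()`. Because only the target's keys are visited, any entry that exists only in the saved copy is silently dropped.

```python
from typing import Dict, List

# Hypothetical stand-in types: a "tensor" is a mutable list of floats.
StateDict = Dict[str, List[float]]


def restore_target_driven(saved: StateDict, target: StateDict) -> StateDict:
    """Sketch of the previous behavior: iterate over the *target* keys only."""
    restored: StateDict = {}
    for key, buf in target.items():
        # Slice assignment overwrites the target's list in place,
        # standing in for tensor.copy_() (no new buffer is allocated).
        buf[:] = saved[key]
        restored[key] = buf
    # Keys present only in `saved` are never visited, so they are lost here.
    return restored


saved = {"weight": [1.0, 2.0], "step": [10.0]}
target = {"weight": [0.0, 0.0]}
print(restore_target_driven(saved, target))  # "step" is missing from the result
```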
Your proposal made us realize that this is not the right approach. Instead, we should always rebuild the original state dict and, in addition, reuse the tensors in the target module's state dict. This way we achieve the same memory savings while also handling saved state dicts that contain entries the target does not have.
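A minimal sketch of the rebuild-and-reuse idea, under the same stand-in assumptions (lists of floats for tensors, slice assignment for an in-place copy); `rebuild_with_reuse` is a hypothetical name, and the length comparison is an assumed substitute for a shape-compatibility check. Every saved entry ends up in the result; an existing target buffer is reused when one is available, and a fresh copy is allocated only for entries the target lacks.

```python
from typing import Dict, List

# Hypothetical stand-in types: a "tensor" is a mutable list of floats.
StateDict = Dict[str, List[float]]


def rebuild_with_reuse(saved: StateDict, target: StateDict) -> StateDict:
    """Rebuild every saved entry, reusing the target's buffer when possible."""
    rebuilt: StateDict = {}
    for key, value in saved.items():
        buf = target.get(key)
        # Assumed shape check: only reuse a buffer that can hold the value.
        if buf is not None and len(buf) == len(value):
            buf[:] = value  # in-place restore, no extra allocation
            rebuilt[key] = buf
        else:
            # No reusable buffer: allocate a fresh copy of the saved value.
            rebuilt[key] = list(value)
    return rebuilt


saved = {"weight": [1.0, 2.0], "step": [10.0]}
w = [0.0, 0.0]
out = rebuild_with_reuse(saved, {"weight": w})
print(out, out["weight"] is w)
```

In the PyTorch setting described above, the rebuilt dict would then be handed to .load_state_dict(); entries that already existed in the target keep their original storage, so memory use for them matches the previous approach.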