Fix record nested value de/serialization
While working on the Llama-3 implementation I stumbled upon a memory issue when importing PyTorch weights with `PyTorchFileRecorder`.
When I profiled the memory usage for ResNet-152 (the checkpoint is 252MB on disk), I saw a huge peak memory usage for what is supposed to be a relatively small model: up to ~5GB, as shown in the heaptrack trace below.
Before
After
Checklist
- [x] Confirmed that the `run-checks all` script has been executed.
Changes
Added `U16s` and `F32s` variants for `NestedValue` so weights can be parsed as a vector of primitive types instead of `Vec<NestedValue>`. For example, a vec of f32s is now represented as `Vec[v, v, v, ...]` instead of `Vec[NestedValue::F32(v), NestedValue::F32(v), ...]`. The `NestedValue` enum has a size of 56 bytes, so memory usage can grow very rapidly (just imagine a very large number of parameters, as in Llama 8B 🤯).
- Handle different vec types in `Serializer` based on the input element type
- Make `VecSeqAccess`'s `iter` generic and add concrete implementations for vec of `NestedValue`, `u16` and `f32`
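The size blow-up described above can be sketched with a simplified stand-in enum (illustrative only; the real `NestedValue` has more variants, and its 56-byte size will differ from this sketch):

```rust
use std::mem::size_of;

// Illustrative stand-in for burn's NestedValue, not the real enum. The point:
// wrapping every scalar in an enum variant costs the full enum size per
// element, while the new F32s/U16s variants store the primitives contiguously.
#[allow(dead_code)]
enum NestedValue {
    F32(f32),
    U16(u16),
    // New variants holding primitive vectors directly:
    F32s(Vec<f32>),
    U16s(Vec<u16>),
}

fn main() {
    let n = 1_000_000;

    // Before: one enum value per scalar.
    let wrapped: Vec<NestedValue> = (0..n).map(|_| NestedValue::F32(1.0)).collect();
    let before = wrapped.len() * size_of::<NestedValue>();

    // After: all scalars live contiguously inside a single variant.
    let flat = NestedValue::F32s(vec![1.0_f32; n]);
    let after = n * size_of::<f32>();

    println!("enum size: {} bytes per element", size_of::<NestedValue>());
    println!("before: ~{} MB vs after: ~{} MB", before / 1_000_000, after / 1_000_000);
    drop((wrapped, flat));
}
```

With the real 56-byte enum, a million scalars cost ~56MB wrapped versus ~4MB flat, which is where the multi-gigabyte peaks for large checkpoints come from.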
Testing
All unit tests pass, including half precision record tests in burn-import/pytorch-tests.
Was running the checks locally and the test `record::serde::ser::tests::test_param_serde` just failed; will investigate and fix.
/edit:
Previously, the `fmt::Debug` output captured by the test showed the vector as `Vec([F32(1.0), F32(1.0), F32(1.0), ...] len=4)`, but now the values are no longer wrapped in `NestedValue::F32`, so it is just `Vec([1.0, 1.0, 1.0, ...] len=4)` instead.
That means the characters `F32(` and `)` are dropped for each of the three shown elements (15 characters total), which brings the expected length to 149 - 15 = 134 ✅
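The character arithmetic can be checked mechanically (a trivial sketch, not part of the actual test suite):

```rust
fn main() {
    // Each shown element loses the "F32(" prefix and ")" suffix: 5 characters.
    let removed_per_element = "F32()".len();
    assert_eq!(removed_per_element, 5);
    // Three elements are printed, so 15 characters disappear in total...
    assert_eq!(removed_per_element * 3, 15);
    // ...bringing the expected Debug string length from 149 down to 134.
    assert_eq!(149 - 15, 134);
    println!("ok");
}
```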
Codecov Report
Attention: Patch coverage is 87.61905%, with 13 lines in your changes missing coverage. Please review.
Project coverage is 86.61%. Comparing base (5bbc5ea) to head (40b3c71). Report is 1 commit behind head on main.
Additional details and impacted files
@@ Coverage Diff @@
## main #1751 +/- ##
=======================================
Coverage 86.61% 86.61%
=======================================
Files 700 700
Lines 83427 83509 +82
=======================================
+ Hits 72257 72329 +72
- Misses 11170 11180 +10
Before merging, please file issues for any uncompleted refactors or fixes.
To close this PR I'll handle the other tensor element types, but I've opened a new issue regarding the other improvements suggested in previous discussions.
If we do #1773 alone, then we can deprecate serialization. So you don't need to do other accumulating item types.
Sure, we can limit the current PR to Vec<u16> (for f16 and bf16) and Vec<f32> (for pretty much all other parameter weights).
The linked issue should capture all element types when we tackle it.
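As a side note on why `u16` is a workable carrier for both half-precision formats: a raw `u16` bit pattern loses no information, and bf16 in particular converts to f32 with a single shift. A minimal sketch (illustrative only, not burn's actual conversion code, which goes through the `half` crate):

```rust
// bf16 is the top 16 bits of an f32, so widening and shifting recovers the
// (truncated-precision) f32 value exactly for representable numbers.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

fn main() {
    // 0x3F80 is the bf16 bit pattern for 1.0 (top 16 bits of f32 0x3F800000).
    assert_eq!(bf16_bits_to_f32(0x3F80), 1.0);
    // 0x4000 is the bf16 bit pattern for 2.0.
    assert_eq!(bf16_bits_to_f32(0x4000), 2.0);
    println!("ok");
}
```

f16 needs a real conversion routine rather than a shift, but the same `Vec<u16>` storage still carries its bits losslessly.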