
Fix record nested value de/serialization

laggui opened this pull request 1 year ago

While working on the Llama-3 implementation I stumbled upon a memory issue when importing pytorch weights with PyTorchFileRecorder.

When I profiled the memory usage for ResNet-152 (checkpoint is 252MB on disk), I saw a huge peak in memory usage for what should be a relatively small model: up to ~5GB, as shown in the heaptrack traces below.
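(A rough back-of-envelope, not taken from the trace itself: ResNet-152 has on the order of 60M parameters, and at 56 bytes per NestedValue, as discussed below, the scalar wrappers alone account for roughly 60M × 56 B ≈ 3.4 GB before any intermediate copies, which is consistent with the observed peak.)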

[Before: heaptrack trace image]

[After: heaptrack trace image]

Checklist

  • [x] Confirmed that the run-checks all script has been executed.

Changes

Added U16s and F32s variants to NestedValue so weights can be parsed as a vector of primitive types instead of Vec<NestedValue>. For example, a vec of f32s is now represented as Vec[v, v, v, ...] instead of Vec[NestedValue::F32(v), NestedValue::F32(v), ...]. The NestedValue enum has a size of 56 bytes, so the wrapped form grows very rapidly (just imagine a model with as many parameters as Llama 3 8B 🤯). The idea is sketched after the list below.

  • Handle different vec types in Serializer based on the input element type
  • Make VecSeqAccess's iter generic and add concrete implementations for vecs of NestedValue, u16, and f32
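For illustration, here is a minimal Rust sketch of the idea. This is not the actual burn-core definition: the real NestedValue has many more variants (maps, strings, other scalars) and is 56 bytes, while this trimmed stand-in will report a smaller size.

```rust
use std::mem::size_of;

// Trimmed stand-in for burn-core's NestedValue; the real enum has
// many more variants (maps, strings, other scalars, ...).
#[allow(dead_code)]
#[derive(Debug)]
enum NestedValue {
    F32(f32),
    U16(u16),
    Vec(Vec<NestedValue>), // old path: every scalar individually wrapped
    // New variants from this PR: primitives stored directly.
    U16s(Vec<u16>),
    F32s(Vec<f32>),
}

// Serializer-side dispatch, sketched: pick the compact variant when
// the element type allows it.
fn serialize_flat(weights: Vec<f32>) -> NestedValue {
    NestedValue::F32s(weights) // 4 bytes per element
}

fn serialize_wrapped(weights: Vec<f32>) -> NestedValue {
    // 56 bytes per element with burn-core's full enum.
    NestedValue::Vec(weights.into_iter().map(NestedValue::F32).collect())
}

fn main() {
    // This trimmed enum is smaller than burn-core's 56 bytes, but the
    // per-element overhead is the point.
    println!("sketch enum size: {} bytes", size_of::<NestedValue>());
    println!("flat:    {:?}", serialize_flat(vec![1.0; 3]));
    println!("wrapped: {:?}", serialize_wrapped(vec![1.0; 3]));

    // Back-of-envelope for an 8B-parameter model like Llama 3 8B.
    let params: u64 = 8_000_000_000;
    println!(
        "wrapped: ~{} GB vs flat f32: ~{} GB",
        params * 56 / 1_000_000_000,
        params * 4 / 1_000_000_000
    );
}
```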

Testing

All unit tests pass, including half precision record tests in burn-import/pytorch-tests.

laggui avatar May 09 '24 19:05 laggui

Was running the checks locally and the test record::serde::ser::tests::test_param_serde just failed... will investigate and fix.

/edit:

Previously the fmt::Debug output captured by the test had the vector as Vec([F32(1.0), F32(1.0), F32(1.0), ...] len=4), but now the values are no longer wrapped in NestedValue::F32, so it is just Vec([1.0, 1.0, 1.0, ...] len=4) instead.

That means the characters F32() are removed from each of the 3 printed values (5 characters per value, 15 total), which comes down to a new expected length of 149 - 15 = 134 ✅
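As a sanity check of that arithmetic, a tiny self-contained example. The strings here are hypothetical stand-ins for the Debug fragments; the real test compares a full 149-character string.

```rust
// Hypothetical stand-ins for the captured fmt::Debug fragments; the
// real test compares a full 149-character string.
fn main() {
    let before = "Vec([F32(1.0), F32(1.0), F32(1.0), ...] len=4)";
    let after = "Vec([1.0, 1.0, 1.0, ...] len=4)";
    // Each of the 3 printed values loses its "F32(" + ")" wrapper:
    // 5 characters per value, 15 in total.
    assert_eq!(before.len() - after.len(), 3 * 5);
    println!("removed: {} characters", before.len() - after.len());
}
```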

laggui avatar May 09 '24 19:05 laggui

Codecov Report

Attention: Patch coverage is 87.61905% with 13 lines in your changes missing coverage. Please review.

Project coverage is 86.61%. Comparing base (5bbc5ea) to head (40b3c71). Report is 1 commit behind head on main.

| Files | Patch % | Lines |
|-------|---------|-------|
| crates/burn-core/src/record/serde/de.rs | 88.73% | 8 Missing :warning: |
| crates/burn-core/src/record/serde/data.rs | 83.33% | 3 Missing :warning: |
| crates/burn-core/src/record/serde/ser.rs | 87.50% | 2 Missing :warning: |
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1751   +/-   ##
=======================================
  Coverage   86.61%   86.61%           
=======================================
  Files         700      700           
  Lines       83427    83509   +82     
=======================================
+ Hits        72257    72329   +72     
- Misses      11170    11180   +10     


codecov[bot] avatar May 09 '24 19:05 codecov[bot]

Before merging, please file issues for any uncompleted refactors or fixes.

antimora avatar May 16 '24 03:05 antimora

To close this PR I'll handle the other tensor element types, but I've opened a new issue regarding the other improvements suggested in previous discussions.

laggui avatar May 16 '24 16:05 laggui

If we do #1773 alone, then we can deprecate this serialization path, so you don't need to handle the remaining item types.

antimora avatar May 16 '24 20:05 antimora

> If we do #1773 alone, then we can deprecate this serialization path, so you don't need to handle the remaining item types.

Sure, we can limit the current PR to Vec<u16> (for f16 and bf16) and Vec<f32> (for pretty much all other parameter weights).
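For context, a minimal sketch of why a single Vec<u16> variant can back both f16 and bf16: both are 16-bit formats, so their raw bit patterns fit in u16 and can be reinterpreted on load. This uses the half crate; the weights and conversions here are illustrative, not burn's actual record path.

```rust
use half::{bf16, f16};

fn main() {
    let weights = [0.5f32, -1.25, 3.0];

    // f16 weights stored as their raw 16-bit patterns in a Vec<u16>.
    let f16_bits: Vec<u16> = weights
        .iter()
        .map(|&w| f16::from_f32(w).to_bits())
        .collect();

    // Recovered later by reinterpreting the bits.
    let restored: Vec<f32> = f16_bits
        .iter()
        .map(|&bits| f16::from_bits(bits).to_f32())
        .collect();
    println!("{:?} -> {:?}", f16_bits, restored);

    // The same Vec<u16> storage also fits bf16 bit patterns.
    let bf16_bits: Vec<u16> = weights
        .iter()
        .map(|&w| bf16::from_f32(w).to_bits())
        .collect();
    println!("bf16 bits: {:?}", bf16_bits);
}
```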

The linked issue should capture all element types when we tackle it.

laggui avatar May 21 '24 12:05 laggui