
[WIP] Add support for importing pytorch `.pt` files using `burn-import`

Open · antimora opened this pull request 1 year ago · 3 comments

Submitting a WIP PR for initial review. I still need to:

  1. Update the README and book with the new addition.
  2. Add documentation comments to methods.
  3. Add a permute feature that lets users swap dimensions during format translation.

Pull Request Template

Checklist

  • [ ] Confirmed that the `run-checks all` script has been executed.
  • [ ] Made sure the book is up to date with changes in this PR.


antimora · Dec 19 '23

@nathanielsimard @louisfd

> Regarding the linear conversion: I'm unsure we should "force" our modules to store the weights the same way as PyTorch.

Yes, I agree it's not ideal, as my solution would not scale to other formats and would constrain our design choices going forward.

> What I'm sure about is that it would be very cool to have a way to load records of different versions and apply migration. I'm going to propose something soon regarding that.

I agree, that would indeed be very cool. It would be even cooler if we could also keep build-time conversion.

I have two possible solutions (A and B):

"A" solution:

  1. A .pt file (PyTorch model file) is converted during the build, and the conversion does not have to be aware of the target model; only loading does. This design choice allows for build-time translation or a CLI tool, and it is what my current implementation does. Pre-converting, as opposed to converting on the fly, a) allows some work to be done in advance, b) eliminates a runtime dependency on .pt parsing, and c) allows loading a subset of weights (e.g., only the encoder and not the decoder).
  2. During record loading, there is an additional (minimal) translation step, because we can match tensors by name and location. This is the point at which we know the target modules. It would mean implementing a custom load function, something like the following:
let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
    .load_pytorch_translated(file_path)
    .expect("Failed to decode state");

The caveat of the minimal translation is that there is still run-time conversion (e.g., transposing).
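
For context on what that run-time conversion looks like: PyTorch's nn.Linear stores its weight as [out_features, in_features], whereas Burn's Linear keeps the transposed layout, which is what the linear-conversion discussion above is about. Here is a minimal, self-contained sketch of the transposition step (plain nested Vecs stand in for real tensors; this is an illustration, not the PR's actual code):

// Transpose a PyTorch-layout linear weight ([out_features, in_features])
// into the layout Burn's Linear expects ([in_features, out_features]).
fn transpose(w: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (rows, cols) = (w.len(), w[0].len());
    (0..cols)
        .map(|j| (0..rows).map(|i| w[i][j]).collect())
        .collect()
}

fn main() {
    // A 2x3 PyTorch-style weight: out_features = 2, in_features = 3.
    let pt_weight = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    let burn_weight = transpose(&pt_weight); // now 3x2: [in, out]
    assert_eq!(burn_weight, vec![vec![1.0, 4.0], vec![2.0, 5.0], vec![3.0, 6.0]]);
}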

"B" solution:

A model's record is loaded from a .pt file and can then be re-saved in Burn's file format. This would let us know which parts map to which target modules (e.g., Linear). However, I am not sure this is achievable from build.rs because of the circular dependency: the build script runs before the crate is compiled, so it cannot use the crate's own model types. A CLI tool is out of the question because the target model info would be missing.

antimora · Dec 20 '23

@nathanielsimard and I had an offline conversation, and here is the revised summary:

  1. We agreed that the primary goal is to effectively integrate PyTorch weights into the Burn framework while maintaining independence from PyTorch's structural constraints. This involves developing mechanisms for importing, patching, and handling weights and module structures in a way that aligns with Burn's unique architecture.

  2. The generated "Record" struct will provide essential information about the target module, including its hierarchical position, name, and module type (e.g., Linear or BatchNorm).

  3. For PyTorch integration, we will use a PyTorchFileRecorder that functions as follows:

    let record: MyModelRecord = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(file_path)
        .expect("Failed to load .pt file");
    let model = MyModel::<Backend>::new_with(record);
    
  4. MyModelRecord is a record type that can be saved in various Burn formats using the existing recorders, such as NamedMpkFileRecorder, PrettyJsonFileRecorder, BinFileRecorder, etc. (see the sketch after this list).
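
As a concrete illustration of item 4, here is a minimal sketch that loads a .pt file and then re-saves the record in Burn's named MessagePack format. It reuses the illustrative names from the example above (MyModelRecord stands in for a generated record type) and assumes the load/record recorder API shown in this thread:

    use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;

    // Load the PyTorch weights into the generated record type...
    let record: MyModelRecord = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("model.pt".into())
        .expect("Failed to load .pt file");

    // ...then persist them in Burn's own format for later use.
    NamedMpkFileRecorder::<FullPrecisionSettings>::default()
        .record(record, "model.mpk".into())
        .expect("Failed to save record");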

This solution achieves decoupling and offers the following advantages:

  1. It enables dynamic or build-time conversion.
  2. The same approach can be implemented for other source formats besides PyTorch.
  3. It reduces errors, since users are not required to tag module types manually.
  4. It remains flexible, allowing module names to change relative to the source.

antimora · Dec 20 '23

Just to update everyone: I have a solution that accomplishes what @nathanielsimard and I discussed. I researched serde extensively, and this is achievable only through a custom deserializer. No code changes are required in the core or derive crates.
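
For readers unfamiliar with the trick: serde can drive a record type's derived Deserialize impl from an arbitrary in-memory source, so all renaming and reshaping can happen inside a custom deserializer without touching the record structs themselves. Below is a toy sketch of the mechanism using serde's built-in value deserializers; LinearRecord and its fields are illustrative stand-ins, not the actual burn-core types (the real implementation in burn-core/src/record/serde/de.rs implements serde::Deserializer over its own nested value type):

use serde::de::value::{Error as DeError, MapDeserializer};
use serde::Deserialize;

// Stand-in for a generated record struct; its derived Deserialize
// impl stays untouched.
#[derive(Debug, Deserialize)]
struct LinearRecord {
    weight: Vec<f32>,
    bias: Vec<f32>,
}

fn main() {
    // Tensor data keyed by (already adapted) parameter names, as a
    // custom deserializer would present it after renaming/transposing.
    let entries = vec![
        ("weight", vec![1.0_f32, 2.0]),
        ("bias", vec![0.5_f32]),
    ];
    let de = MapDeserializer::<_, DeError>::new(entries.into_iter());
    let record = LinearRecord::deserialize(de).expect("Failed to deserialize");
    println!("{record:?}");
}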

antimora · Dec 26 '23

Codecov Report

Attention: 482 lines in your changes are missing coverage. Please review.

Comparison is base (3b7d9fe) 85.97% compared to head (008775d) 84.64%. Report is 2 commits behind head on main.

Files                                              Patch %   Lines
burn-core/src/record/serde/de.rs                   39.45%    287 Missing
burn-core/src/record/serde/ser.rs                  46.66%     96 Missing
burn-import/pytorch-tests/tests/boolean/mod.rs      0.00%     27 Missing
burn-core/src/record/serde/data.rs                 79.83%     24 Missing
burn-import/src/pytorch/recorder.rs                61.53%     15 Missing
burn-core/src/record/serde/adapter.rs              70.00%     12 Missing
burn-core/src/record/serde/error.rs                 0.00%     10 Missing
burn-import/src/pytorch/reader.rs                  91.22%      5 Missing
burn-import/src/pytorch/error.rs                    0.00%      4 Missing
burn-core/src/record/primitive.rs                  96.66%      2 Missing
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1085      +/-   ##
==========================================
- Coverage   85.97%   84.64%   -1.33%     
==========================================
  Files         522      545      +23     
  Lines       59179    61166    +1987     
==========================================
+ Hits        50879    51774     +895     
- Misses       8300     9392    +1092     


codecov[bot] · Jan 22 '24

> Fine for me too! Thanks a lot for your hard work! 😃
>
> Just a final question: will the TODOs in the first message be covered in a follow-up PR, or is this an intermediate review?

The documentation and the remaining TODOs will be handled next (in a new PR). The TODO comment you found regarding Conv group testing has been removed because I am testing similar aspects with kernel_size > 1.

antimora · Jan 25 '24