[WIP] Add support for importing pytorch `.pt` files using `burn-import`
Submitting a WIP PR for initial review. I still need to:
- Update the README and Book with the new addition.
- Add doc comments to methods.
- Add a permute feature that lets you swap dimensions during format translation.
Pull Request Template
Checklist
- [ ] Confirmed that the `run-checks all` script has been executed.
- [ ] Made sure the book is up to date with changes in this PR.
@nathanielsimard @louisfd
Regarding the linear conversion: I'm unsure we should "force" our modules to store the weights the same way as PyTorch.
Yes, I agree it's not ideal, as my solution would not scale to other formats and would constrain our design choices going forward.
What I'm sure about is that it would be very cool to have a way to load records of different versions and apply migration. I'm going to propose something soon regarding that.
I agree, that would indeed be very cool. It would be even cooler if we could still keep build-time conversion.
I have two possible solutions (A and B):
"A" solution:
- A `.pt` file (PyTorch model file) is converted during the build and does not have to be aware of the target model - only during loading. This design choice allows for build-time translation or for using a CLI tool, and is what my current implementation does. Pre-converting, as opposed to converting on the fly, a) allows doing some work in advance, b) eliminates a `.pt` runtime dependency, and c) allows loading a subset of weights (e.g., load only the encoder and not the decoder).
- During record loading, there is additional (minimal) translation because we can match the name & location of tensors. This is when we have the opportunity to know the target modules. This would mean we have to implement a custom load function, something like the following:
  ```rust
  let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
      .load_pytorch_translated(file_path)
      .expect("Failed to decode state");
  ```
The caveat of the minimal translation is that there is still run-time conversion (e.g., transposing).
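To make the transposing example concrete: PyTorch's `nn.Linear` keeps its weight as `[out_features, in_features]`, whereas Burn's `Linear` expects `[d_input, d_output]`. A minimal sketch of the load-time fix-up (the helper name is illustrative, not part of this PR):

```rust
use burn::tensor::{backend::Backend, Tensor};

// Illustrative helper (not part of this PR): a PyTorch linear weight of shape
// [out_features, in_features] is transposed into Burn's [d_input, d_output].
fn adapt_linear_weight<B: Backend>(pytorch_weight: Tensor<B, 2>) -> Tensor<B, 2> {
    pytorch_weight.transpose()
}
```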
"B" solution:
A model's record is loaded from a `.pt` file but could be re-saved to Burn's file format. This would let us know which parts go to which target modules (e.g., `Linear`). However, I am not sure this would be achievable using `build.rs` because of the circular dependency, and a CLI tool is out of the question because the target model info would be missing.
@nathanielsimard and I had an offline conversation, and here is the revised summary:
- We agreed that the primary goal is to effectively integrate PyTorch weights into the Burn framework while maintaining independence from PyTorch's structural constraints. This involves developing mechanisms for importing, patching, and handling weights and module structures in a way that aligns with Burn's unique architecture.
- The generated "Record" struct will provide essential information about the target module, including its hierarchical position, name, and module type (e.g., Linear or BatchNorm).
- For PyTorch integration, we will use a `PyTorchFileRecorder` that functions as follows:

  ```rust
  let record: MyModelRecord = PyTorchFileRecorder::<FullPrecisionSettings>::default()
      .load(file_path)
      .expect("Failed to load .pt file");

  let model = MyModel::<Backend>::new_with(record);
  ```

- `MyModelRecord` is a record type that can be saved in various Burn formats using the existing recorders, such as `NamedMpkFileRecorder`, `PrettyJsonFileRecorder`, `BinFileRecorder`, etc. (see the sketch below).
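As an illustration of that last point, re-saving could look roughly like the following. This is a sketch, assuming Burn's `Recorder` API (`record(record, path)`) and the `record` value from the snippet above; the file name is arbitrary:

```rust
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};

// Sketch: persist the imported record in Burn's named MessagePack format so
// later runs no longer depend on the original .pt file.
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
recorder
    .record(record, "my_model".into())
    .expect("Failed to save model record");
```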
This solution achieves decoupling and offers the following advantages:
- It enables dynamic or build-time conversion.
- It can be implemented for other formats in addition to PyTorch.
- It enhances accuracy as users are not required to tag module types manually.
- It remains flexible, allowing for changes in module names from the source.
Just to update everyone: I have a solution that accomplishes what @nathanielsimard and I discussed. I researched serde extensively, and it's possible to achieve this with only a custom deserializer - no code changes in the core or derive crates are required.
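To illustrate the principle (this is a toy example, not the actual burn-import code): serde lets any data source drive deserialization into an existing struct by supplying a `Deserializer`, so the generated record types need no modification. Here, `serde::de::value::MapDeserializer` plays the role of the custom deserializer:

```rust
use serde::de::value::{Error, MapDeserializer};
use serde::Deserialize;

// Stand-in for a generated Burn record struct; derives Deserialize as usual.
#[derive(Debug, Deserialize)]
struct LinearRecord {
    weight: f32,
    bias: f32,
}

fn main() -> Result<(), Error> {
    // Pretend these pairs were extracted from a .pt file, with source tensor
    // names already remapped to the target field names.
    let entries = vec![("weight", 0.5_f32), ("bias", 0.1_f32)];

    // MapDeserializer adapts an iterator of key/value pairs into a serde
    // Deserializer - no changes to the target struct are needed.
    let record = LinearRecord::deserialize(MapDeserializer::new(entries.into_iter()))?;
    println!("{record:?}");
    Ok(())
}
```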
Codecov Report
Attention: 482 lines in your changes are missing coverage. Please review.
Comparison is base (`3b7d9fe`) 85.97% compared to head (`008775d`) 84.64%. Report is 2 commits behind head on main.
```
@@            Coverage Diff             @@
##             main    #1085      +/-   ##
==========================================
- Coverage   85.97%   84.64%    -1.33%
==========================================
  Files         522      545      +23
  Lines       59179    61166    +1987
==========================================
+ Hits        50879    51774     +895
- Misses       8300     9392    +1092
```
Fine for me too! Thanks a lot for your hard work! 😃
Just a final question: will the `TODO`s in the first message be covered in a next PR, or is this an intermediate review?
The documentation and the remaining TODOs will be done next (in a new PR). The TODO comment you found regarding Conv group testing has been removed because I am testing similar aspects with `kernel_size > 1`.