burn
add support for safetensors in pytorch reader
Pull Request Template
Checklist
- [x] Confirmed that the `run-checks all` script has been executed.
- [x] Made sure the book is up to date with changes in this PR.
Related Issues/PRs
https://github.com/tracel-ai/burn/issues/626
Changes
Simple addition to the already implemented reader.rs, supporting the safetensors format using candle with CPU device import.
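For context, the reading step described above can be sketched roughly as follows. This is a minimal sketch assuming candle's `candle_core::safetensors::load` helper; the actual reader.rs change additionally maps the loaded tensors into burn-import's record types.

```rust
// Minimal sketch of reading a safetensors file with candle on the CPU device.
// The real reader.rs integration goes further and converts these tensors into
// burn-import's intermediate representation.
use std::collections::HashMap;

use candle_core::{safetensors, Device, Tensor};

fn read_safetensors(path: &str) -> candle_core::Result<HashMap<String, Tensor>> {
    // Load every tensor in the file onto the CPU device.
    safetensors::load(path, &Device::Cpu)
}
```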
Testing
In the examples/pytorch-import directory, there is a mnist.safetensors file that is successfully imported.
Codecov Report
Attention: Patch coverage is 80.56680% with 48 lines in your changes missing coverage. Please review.
Project coverage is 81.36%. Comparing base (1f92ec1) to head (a22aa59). Report is 1 commit behind head on main.
Additional details and impacted files
@@ Coverage Diff @@
## main #2721 +/- ##
==========================================
- Coverage 81.37% 81.36% -0.02%
==========================================
Files 818 821 +3
Lines 117643 117791 +148
==========================================
+ Hits 95736 95835 +99
- Misses 21907 21956 +49
IMO maybe we can have something like `pub mod safetensors;` under a new feature gate in crate `burn-import` so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.
I think this is a good point, and it also builds the scaffolding for potentially rewriting it to remove the Candle dependency.
I agree that the format is not strongly related to pytorch, but I think most models available in safetensor format are pytorch models 😅
Unless you mean supporting the safetensor format as another recorder to load and save modules. In this case, not sure that this is a meaningful addition.
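For reference, the feature gating suggested above would amount to roughly the following in burn-import's lib.rs. This is a sketch only; the existing `pytorch` gate is shown for comparison, and the exact module layout may differ.

```rust
// Sketch: gating safetensors support behind its own Cargo feature in burn-import,
// alongside the existing pytorch module, so users can enable only what they need.
#[cfg(feature = "pytorch")]
pub mod pytorch;

#[cfg(feature = "safetensors")]
pub mod safetensors;
```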
I suggest creating a dedicated SafeTensorFileRecorder to handle SafeTensor files independently from PyTorch's .pt files. This approach ensures a clear separation between different file formats and supports framework-specific transformations during the import process.
Additionally, I propose providing configurable options (via LoadArgs, similar to PyTorch's recorder) within the recorder to specify the appropriate transformation adapter. By default, this could use the PyTorchAdapter but allow customization for other frameworks, such as TensorFlow. This design enhances flexibility and decouples the handling of different tensor file formats. Moreover, it might be beneficial to support passing a user-defined implementation of BurnModuleAdapter when needed.
Lastly, we should replicate PyTorch import tests to ensure comprehensive coverage. Over time, we can expand these tests to include SafeTensor files exported from TensorFlow.
One more thing: we should introduce a new feature flag, safetensors.
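To make the proposal concrete, usage could look something like the sketch below. The module path, type names, and adapter option (`SafeTensorFileRecorder`, `LoadArgs`, `AdapterType`) mirror this discussion or are placeholders for illustration; they are not a confirmed burn-import API.

```rust
// Illustrative sketch of the proposed recorder; the module path, type names,
// and adapter option are hypothetical stand-ins for the design discussed above.
use burn::record::{FullPrecisionSettings, Record, Recorder};
use burn::tensor::backend::Backend;
use burn_import::safetensors::{AdapterType, LoadArgs, SafeTensorFileRecorder};

// Decode a safetensors file into a module's generated record type `R`,
// mirroring how PyTorchFileRecorder is used today.
fn load_safetensors_record<B: Backend, R: Record<B>>(device: &B::Device) -> R {
    // Configurable load arguments, analogous to the PyTorch recorder's LoadArgs;
    // the adapter would default to PyTorch but could be swapped for other frameworks.
    let args = LoadArgs::new("model.safetensors".into())
        .with_adapter_type(AdapterType::PyTorch);

    SafeTensorFileRecorder::<FullPrecisionSettings>::default()
        .load(args, device)
        .expect("failed to decode safetensors state")
}
```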
Hi all, I went through and essentially copied over the implementation of the pytorch recorder and created the safetensors recorder. It's a lot of new files that are essentially copied code with small adjustments. I think this gives a good base for the future, when we'd like to remove the candle dependency and add further safetensors support.
I would prioritize Rust code de-duplication first. We can leave the example and test duplicated for now (because it will take time).
Agreed for the bold part.
This PR has been marked as stale because it has not been updated for over a month
Just a feature request, can ignore this if you think it is unnecessary 😂: since safetensors now supports `no_std`, could we also have a `SafeTensorBytesRecorder`?
A bit hesitant on adding new record formats tbh, not sure of the value. To import safetensor format (like the initial target of this PR), sure. But maybe you could provide a bit more info to justify? 🙂
We don't need a new type. We can provide this via an arg option.
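For background on why a separate bytes recorder may not be needed: the upstream safetensors crate already parses from an in-memory buffer, so exposing that path through an argument on the existing recorder could be enough. A rough sketch of the underlying capability, using the real `safetensors` crate rather than burn-import's API:

```rust
// Rough sketch: parsing safetensors data from an in-memory buffer with the
// upstream `safetensors` crate (no file I/O, and the crate supports no_std).
// How this would be surfaced through burn-import's recorder is the open
// design question discussed above.
use safetensors::SafeTensors;

fn list_tensor_names(bytes: &[u8]) -> Result<Vec<String>, safetensors::SafeTensorError> {
    // Zero-copy view over the buffer.
    let st = SafeTensors::deserialize(bytes)?;
    Ok(st.names().into_iter().map(str::to_string).collect())
}
```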
This PR has been marked as stale because it has not been updated for over a month
@laggui This has been a lot of work. What can be done to complete the review?
The PR still needs to be refactored.
Do you have a use case for this feature? We could complete this ourselves.
Didn't want to rush anything, just saw the 2 year old feature request and your bot adding the stale label. I don't fully understand the described limitations, but importing/integrating HuggingFace or Kaggle sounds like a very big feature to boost burn adoption 😅
It's all about priorities. It's been requested but there wasn't real demand in use cases =)
Not difficult to implement because we can rely on PytorchRecorder's foundation.
Hi all, apologies for the wait on this one; I got tied up with other projects and work. I can finalize this soon as it's relatively simple.
Pushed a new commit; I fixed the adapter as requested. I wanted to remove the duplicate code, but frankly I feel a little icky depending so heavily on the pytorch feature implementation. Do we want to abstract some of the logic in the pytorch feature into a more generic part of burn-import?
If you're okay, I can make these changes directly in this PR. It might be a faster turnaround considering how many files are in this PR. It's mainly about organizing files. I think you got the functionality working.
Yes! 100% okay with this. I know safetensors is a highly sought-after feature.
Refactored to remove duplication, updated the import example, and made the adapter type an option.
Still remaining: a new section in the book and updating the root README.md to mention safetensors.
All changes are in place. Waiting for @laggui.
If you have a chance, @wandbrandon, please review as well.
Looks great, appreciate the fast turnaround!
@laggui Done updating. Hopefully no issues.
Thanks for reviewing the long PR. It takes a lot of context switching.
Of course! Longer PRs take more time to review carefully so time-to-merge is usually longer as well 😅
There has been an uptick in activity since the 0.17 release too so I have to balance it out. Sorry if your other PRs are not entirely reviewed yet, should come soon!