SafeTensors serialization for PyTorch models
Adds SafeTensors-based serialization for PyTorch models (addresses #2532) and implements metadata-driven loading to integrate cleanly with the materializer workflow (per @bcdurak's feedback).
Changes
- ✅ Add `safetensors` optional extra in `pyproject.toml`
- ✅ Save `state_dict` to `.safetensors` when available; fall back to `.pt` with a warning (sketched below)
- ✅ Write minimal `metadata.json` (`class_path`, `serialization_format`)
- ✅ Use `TemporaryDirectory` + `copy_dir()` for remote stores
- ✅ `load()` always returns `nn.Module`
- ✅ Backward compat: supports `weights.pt`, `checkpoint.pt`, and legacy `entire_model.pt`
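To make the flow concrete for reviewers, the save path looks roughly like the following. This is a minimal sketch, assuming a `copy_dir(src, dst)` helper along the lines of the one mentioned above; `save_model` is an illustrative stand-in for the materializer's `save()`, not the exact code in this diff.

```python
import json
import os
import tempfile
import warnings

import torch


def save_model(model: torch.nn.Module, artifact_uri: str, copy_dir) -> None:
    # Stage everything in a local temporary directory first, then copy the
    # whole directory to the (possibly remote) artifact store in one go.
    with tempfile.TemporaryDirectory() as tmp:
        cls = type(model)
        metadata = {"class_path": f"{cls.__module__}.{cls.__qualname__}"}
        try:
            from safetensors.torch import save_file

            save_file(model.state_dict(), os.path.join(tmp, "weights.safetensors"))
            metadata["serialization_format"] = "safetensors"
        except ImportError:
            warnings.warn("safetensors not installed; falling back to pickle-based .pt")
            torch.save(model.state_dict(), os.path.join(tmp, "weights.pt"))
            metadata["serialization_format"] = "pickle"
        with open(os.path.join(tmp, "metadata.json"), "w") as f:
            json.dump(metadata, f)
        copy_dir(tmp, artifact_uri)
```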
New artifact layout
```
artifact_uri/
├── weights.safetensors   # or weights.pt on fallback
└── metadata.json         # class_path + format
```
Metadata
```json
{
  "class_path": "my_package.models.MyModel",
  "serialization_format": "safetensors",
  "init_args": [],
  "init_kwargs": {},
  "factory_path": null
}
```
Why SafeTensors?
- Security: Avoids pickle-based code execution risks
- Performance: Faster, memory-mapped weight loads (see the snippet below)
- Compatibility: Works with S3/GCS/Azure via artifact stores
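For context on the memory-mapping point, safetensors can open a file lazily and materialize only the tensors you ask for; the file path and tensor name below are illustrative:

```python
from safetensors import safe_open

# Lazy, memory-mapped access: only the requested tensor is loaded into memory.
with safe_open("weights.safetensors", framework="pt", device="cpu") as f:
    weight = f.get_tensor("layer.weight")
```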
Tests
Local run:
```
pytest tests/unit/integrations/pytorch/materializers/test_pytorch_module_materializer.py -v
# 4 passed in 1.88s
```
Coverage:
- Round-trip with safetensors (illustrated below)
- Pickle fallback path
- Metadata-driven load
- Legacy formats (`weights.pt`, `checkpoint.pt`, `entire_model.pt`)
- Clear error when the `safetensors` extra is missing at load time
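For illustration, the round-trip case amounts to a check along these lines (this exercises safetensors directly and is not the suite's actual code):

```python
import torch
from safetensors.torch import load_file, save_file


def test_safetensors_round_trip(tmp_path):
    model = torch.nn.Linear(4, 2)
    path = str(tmp_path / "weights.safetensors")
    save_file(model.state_dict(), path)

    restored = torch.nn.Linear(4, 2)
    restored.load_state_dict(load_file(path))
    for name, tensor in model.state_dict().items():
        assert torch.equal(tensor, restored.state_dict()[name])
```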
Known limitations (Phase 1)

- Zero-argument `__init__()` requirement: models needing configuration should use a factory method (planned for Phase 2); see the example after this list
- Legacy artifacts without metadata (`weights.pt` / `checkpoint.pt`) require an explicit `data_type`:

  ```python
  model = materializer.load(data_type=MyModel)
  ```

- Legacy `entire_model.pt` is loaded and returned as a `Module` directly (no `data_type` needed)
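For example, a hypothetical model that satisfies the zero-argument constraint by giving every constructor argument a default:

```python
import torch


class MyModel(torch.nn.Module):
    # Defaults make the class constructible with no arguments, which the
    # Phase 1 metadata-driven load requires.
    def __init__(self, hidden_size: int = 128) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)
```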
Documentation
Happy to add a short guide covering why/how/limits/troubleshooting. Which file should I update?
- `docs/book/component-guide/materializers/pytorch.md` (materializer behavior)?
- `docs/book/integration-guide/pytorch.md` (integration landing)?
Or would you prefer a new section?
Future work (separate PRs)
- Phase 2: Support `init_args` / `init_kwargs` / factory functions
- Phase 3: PyTorch Lightning materializer
- Phase 4: HuggingFace Transformers support
Checklist
- [x] Tests pass locally
- [x] Code formatted (`ruff check --fix` + `ruff format`)
- [x] Also ran project scripts: `bash scripts/format.sh` and `bash scripts/lint.sh`
- [x] Type hints added (mypy clean)
- [x] Backward compatibility maintained
- [x] Rebased on `develop`
- [ ] Documentation updated (pending guidance on location)
- [x] CLA signed
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
yusuke kunimitsu seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.
Hey @kunigori, thanks for the PR! Can you please base your changes on the develop branch and then also change the target of this PR.
⚠️ This PR has been inactive for 2 weeks and has been marked as stale. Timeline:

- Week 2 (now): First reminder - PR marked as stale
- Week 4: PR will be automatically closed if no activity

Please update this PR or leave a comment to keep it active. Any activity will reset the timer and remove the stale label.
Thanks for the update. I'm still actively working on this PR and will push revisions soon.