flaxmodels
Use `safetensors` to store tensors instead of `pickle`
Hi @matthias-wright, I've been playing around with your project for a couple of days and it's so cool, thanks for building some pure `flax` models here 👍🏻
Don't know if you're aware, but @huggingface developed a new format for storing tensors named `safetensors`, as most serialized PyTorch models use `pickle` to store the tensors, which is not very efficient and has some known security issues. So I'd like to know whether you'd consider porting the current tensors to `safetensors` instead.
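For context on the security concern, here is a minimal and deliberately harmless sketch of why loading an untrusted pickle is risky: the file can name any importable callable, and `pickle.loads` will call it. The `Payload` class is hypothetical and `sorted` stands in for whatever an attacker would actually choose.

```python
import pickle


class Payload:
    """Hypothetical object whose pickle asks the loader to call a function."""

    def __reduce__(self):
        # pickle stores "call sorted([3, 1, 2])" instead of the object's state.
        # A malicious file could reference any importable callable here.
        return (sorted, ([3, 1, 2],))


blob = pickle.dumps(Payload())

# Loading runs the stored call; no Payload instance is ever rebuilt.
result = pickle.loads(blob)
print(type(result).__name__, result)
```

A format that only stores tensor bytes plus a header describing them has no equivalent of this call-on-load step.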
I've recently built `safejax` to make that easy, which means the storage is more efficient and safer! If this is something you'd consider to improve `flaxmodels`, please let me know and I can try to help if applicable!
P.S. Did you consider publishing the Python package to PyPI, tracking it through GitHub Releases, so that it attracts more users thanks to the ease of installing it with `pip` from PyPI instead of from source as in the `README.md`?
Hi @alvarobartt, thank you for reaching out! I'm glad the project was of use to you :)
`safetensors` sounds very interesting. I love that it's written in Rust.
I'm interested in incorporating this into `flaxmodels` as long as it doesn't complicate things.
I'm very busy right now but if you want to make a PR, I'm willing to support you.
`flaxmodels` is published on PyPI, see here: https://pypi.org/project/flaxmodels/. Currently, though, publishing is not done through GitHub Actions.
Cool @matthias-wright! I'm also busy these days (there are also some bank holidays in Spain next week), but I'll try to submit a PR in case you're interested. Your project is indeed the most similar to `pytorch-image-models` (aka `timm`) that I could find for Flax 👍🏻 I'm also working on `ResNet v2` in Flax, as well as porting the weights from PyTorch.
I'll keep you updated anyway!
Also, regarding how "complicated" that could be in terms of code, I'd say that it's just unpickling the weights, converting them to `safetensors`, uploading the file to S3/Dropbox or whatever storage you prefer (could also be GitHub Releases, but we have a quota restriction there), and then replacing the call/s to `utils.download` with `safejax.deserialize`. Anyway, I'll submit a PR and we can continue the discussion there over the written code, which may be easier :hugs:
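As a rough sketch of the conversion step: `safetensors` files hold a flat mapping from string names to tensors, while Flax parameters are usually nested dicts, so the keys need flattening before saving and unflattening after loading. The snippet below assumes the unpickled weights are a nested dict. The helper names and layer names are hypothetical, and plain lists stand in for arrays so it runs without extra dependencies.

```python
from typing import Any, Dict


def flatten_params(params: Dict[str, Any], sep: str = ".", prefix: str = "") -> Dict[str, Any]:
    """Turn a nested parameter dict into a flat {"a.b.c": leaf} mapping."""
    flat: Dict[str, Any] = {}
    for key, value in params.items():
        name = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten_params(value, sep=sep, prefix=name))
        else:
            flat[name] = value
    return flat


def unflatten_params(flat: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Rebuild the nested parameter dict from the flat mapping."""
    nested: Dict[str, Any] = {}
    for name, value in flat.items():
        *parents, leaf = name.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


# Hypothetical weights, standing in for whatever the pickle file contains.
params = {
    "conv1": {"kernel": [[1.0, 2.0]], "bias": [0.0]},
    "fc": {"kernel": [[3.0]]},
}

flat = flatten_params(params)       # keys such as "conv1.kernel"
restored = unflatten_params(flat)   # same nesting as the original
print(sorted(flat))
```

With real arrays, a flat dict like this is the shape that `safetensors.flax.save_file(tensors, filename)` takes, and `flax.traverse_util.flatten_dict(params, sep=".")` covers the same flattening without a custom helper.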