
Use `safetensors` to store tensors instead of `pickle`

Open alvarobartt opened this issue 2 years ago • 3 comments

Hi @matthias-wright, I've been playing around with your project for a couple of days and it's so cool. Thanks for building some pure Flax models here 👍🏻

I don't know if you're aware, but @huggingface developed a new format for storing tensors called safetensors. Most serialized PyTorch models use pickle to store their tensors, which seems to be not very efficient and has some known security issues. So I'd like to know whether you're considering porting the current tensors to safetensors instead.

I've recently built safejax to make that easy, so that the storage is more efficient and safer! If this is something you'd consider to improve flaxmodels, please let me know and I can try to help if applicable!

P.S. Did you consider publishing the Python package to PyPI and tracking it through GitHub Releases? Installing with pip from PyPI is easier than installing from source as described in the README.md, so it may attract more users.

alvarobartt, Dec 24 '22 13:12

Hi @alvarobartt, thank you for reaching out! I'm glad the project was of use to you :) safetensors sounds very interesting. I love that it's written in Rust. I'm interested in incorporating this into flaxmodels as long as it doesn't complicate things. I'm very busy right now, but if you want to make a PR, I'm willing to support you.

flaxmodels is published on PyPI, see here: https://pypi.org/project/flaxmodels/. Currently, though, the publishing is not done through GitHub Actions.

matthias-wright, Dec 27 '22 16:12

Cool @matthias-wright! I'm also busy these days (and there are some bank holidays in Spain next week), but I'll try to submit a PR in case you're interested. Your project is indeed the closest thing to pytorch-image-models (a.k.a. timm) that I could find for Flax 👍🏻 I'm also working on ResNet v2 in Flax, as well as porting the weights from PyTorch.

I'll keep you updated anyway!

alvarobartt, Dec 28 '22 08:12

Also, regarding how "complicated" that could be in terms of code, I'd say it's just unpickling the weights, converting them to safetensors, uploading the file to S3/Dropbox or whatever storage you prefer (it could also be GitHub Releases, but there's a quota restriction there), and then replacing the call(s) to utils.download so that they just use safejax.deserialize. Anyway, I'll submit a PR and we can continue the discussion there over the written code, which may be easier :hugs:
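A minimal sketch of the conversion step, assuming the existing checkpoint is a pickled, nested dict of arrays (the actual flaxmodels checkpoint layout may differ). safetensors stores a flat mapping from string names to tensors, so the nested dict is flattened first. The helper names `flatten_params`, `unflatten_params`, and `convert_pickle_to_safetensors`, and the `.` separator, are hypothetical and only for illustration; `safetensors.numpy.save_file` is the library's NumPy writer.

```python
import pickle
from typing import Any, Dict

SEP = "."  # assumed separator; it must not occur inside any parameter name


def flatten_params(params: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten a nested parameter dict into {"block.layer.kernel": leaf} form."""
    flat: Dict[str, Any] = {}
    for key, value in params.items():
        name = f"{prefix}{SEP}{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten_params(value, name))
        else:
            flat[name] = value
    return flat


def unflatten_params(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild the nested dict from the flat, separator-joined keys."""
    nested: Dict[str, Any] = {}
    for name, value in flat.items():
        *parents, leaf = name.split(SEP)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


def convert_pickle_to_safetensors(pickle_path: str, safetensors_path: str) -> None:
    """Hypothetical one-off conversion script for a trusted pickle file."""
    # Imported lazily so the helpers above work without these packages installed.
    import numpy as np
    from safetensors.numpy import save_file

    with open(pickle_path, "rb") as f:
        params = pickle.load(f)  # only unpickle files you trust
    flat = {name: np.asarray(leaf) for name, leaf in flatten_params(params).items()}
    save_file(flat, safetensors_path)
```

After loading the flat dict back from the safetensors file, `unflatten_params` restores the nested structure that Flax modules expect.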

alvarobartt, Dec 28 '22 08:12