
Add safetensor option when saving and restoring models

Status: Open. stevehuang52 opened this issue 1 year ago • 19 comments

What does this PR do?

Re-does the changes from a previously closed PR, https://github.com/NVIDIA/NeMo/pull/7812, to add safetensors options when saving and restoring models, as requested by customers.

CC @galv

Collection: [core,asr,nlp,common]

Changelog

  • Re-do the changes from https://github.com/NVIDIA/NeMo/pull/7812 on top of the current main branch
  • Add nemo.utils.secure.torch_load and nemo.utils.secure.torch_save
  • Remove direct torch.load and torch.save calls from save_restore_connector and its alternate implementations
  • secure.torch_load and secure.torch_save use safetensors when available; however, if safe=False, torch.load and torch.save are used to preserve backwards compatibility
  • safe is a keyword parameter passed through from save_to and restore_from; it defaults to safe=False to preserve backwards compatibility
  • Add unit tests to verify that backwards compatibility is preserved and that the secure functions work correctly
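A minimal sketch of the save/load dispatch the changelog describes, assuming a single `safe` keyword that defaults to `False`. This is not the PR's code: `secure_save`, `secure_load`, and the file suffixes are hypothetical, and JSON stands in for safetensors (with pickle standing in for torch.save/torch.load) only so the example runs on the standard library. The save side follows the comments in the Usage section below: the data-only copy is always written, and the legacy pickle is added unless `safe=True`.

```python
import json
import os
import pickle
from typing import Any, Dict, List

# Stand-ins (assumptions): JSON plays the role of safetensors (a data-only format) and
# pickle plays the role of torch.save/torch.load. The suffixes are hypothetical.
SAFE_SUFFIX = ".safe.json"
LEGACY_SUFFIX = ".legacy.pkl"


def secure_save(state: Dict[str, List[float]], base_path: str, safe: bool = False) -> List[str]:
    """Always write the data-only copy; add the legacy pickle unless safe=True."""
    written = [base_path + SAFE_SUFFIX]
    with open(written[0], "w", encoding="utf-8") as f:
        json.dump(state, f)
    if not safe:
        # Backwards-compatible copy for readers that only understand the old format.
        legacy_path = base_path + LEGACY_SUFFIX
        with open(legacy_path, "wb") as f:
            pickle.dump(state, f)
        written.append(legacy_path)
    return written


def secure_load(base_path: str, safe: bool = False) -> Dict[str, Any]:
    """Prefer the data-only copy; unpickle only when it is missing and safe=False."""
    safe_path = base_path + SAFE_SUFFIX
    if os.path.exists(safe_path):
        with open(safe_path, encoding="utf-8") as f:
            return json.load(f)
    if safe:
        # Refuse rather than fall back: unpickling untrusted data can run code.
        raise FileNotFoundError(f"{safe_path} not found; refusing to unpickle with safe=True")
    with open(base_path + LEGACY_SUFFIX, "rb") as f:
        return pickle.load(f)
```

Routing every read and write through one pair of functions is what makes it possible to remove direct torch.load/torch.save calls from the connectors: the safe/unsafe decision then lives in a single place.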

Usage

The following usage is copied from https://github.com/NVIDIA/NeMo/pull/7812

When NeMo is used with untrusted .nemo files, this greatly reduces the risk of the user being attacked.
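The risk comes from how PyTorch checkpoints are serialized: torch.save and torch.load rely on Python's pickle, and unpickling can call any importable function that the file names. The standard-library sketch below (the `Payload` class is purely illustrative) builds a pickle that makes the loader call `eval`; a real attacker would substitute something harmful. safetensors avoids this because a safetensors file holds only a JSON header describing the tensors plus their raw bytes, with nothing executable.

```python
import pickle


class Payload:
    """Illustrative object whose pickled form tells the loader to call eval()."""

    def __reduce__(self):
        # pickle stores (callable, args); pickle.loads will call eval("6 * 7").
        return (eval, ("6 * 7",))


blob = pickle.dumps(Payload())
# Loading runs the expression chosen by whoever produced the bytes.
result = pickle.loads(blob)
print(result)  # prints 42
```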

# Existing usage is preserved, although the .nemo file now also contains a safetensors copy of the weights.
model.save_to(save_path=model_save_path)
model.save_to(save_path=model_save_path, safe=False)

# With safe=True, the .nemo file contains only the safetensors weights.
model.save_to(save_path=model_save_path, safe=True)

# Existing usage is preserved, although if the .nemo file contains a safetensors file, it is used instead of the PyTorch checkpoint.
# restore_from returns a new model instance; its path argument is named restore_path.
restored = model.restore_from(model_save_path)
restored = model.restore_from(restore_path=model_save_path, safe=False)

# With safe=True, the restore fails if no safetensors file is available, so an untrusted .nemo file cannot trigger an exploit.
restored = model.restore_from(restore_path=model_save_path, safe=True)
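Before calling restore_from(..., safe=True) on a third-party checkpoint, it can help to check which weight files the archive actually contains. A .nemo file is a tar archive, so the standard library's tarfile module can list it without unpickling anything. A sketch under stated assumptions: `list_weight_members` and `safe_restore_possible` are hypothetical helpers, and the suffixes they look for (.ckpt for the PyTorch weights, .safetensors for the new copy) are guesses; the PR may name its file differently.

```python
import tarfile
from typing import Dict, List

# Assumed suffixes: NeMo's PyTorch weights are usually stored as a .ckpt member;
# the safetensors member name used by this PR is not confirmed here.
WEIGHT_SUFFIXES = {"torch": ".ckpt", "safetensors": ".safetensors"}


def list_weight_members(nemo_path: str) -> Dict[str, List[str]]:
    """Group the archive's weight files by format without extracting or loading them."""
    found: Dict[str, List[str]] = {kind: [] for kind in WEIGHT_SUFFIXES}
    # "r:*" lets tarfile detect a plain or gzip-compressed archive transparently.
    with tarfile.open(nemo_path, "r:*") as archive:
        for name in archive.getnames():
            for kind, suffix in WEIGHT_SUFFIXES.items():
                if name.endswith(suffix):
                    found[kind].append(name)
    return found


def safe_restore_possible(nemo_path: str) -> bool:
    """True if a safe=True restore would have a safetensors file to read."""
    return bool(list_weight_members(nemo_path)["safetensors"])
```

This only inspects member names; it says nothing about whether the file contents are trustworthy.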

PR Type:

  • [x] New Feature
  • [ ] Bugfix
  • [ ] Documentation

stevehuang52, Dec 11 '24

I just wanted to chime in that if there is any way to get this pull request merged into the main branch and a release branch, that would be hugely helpful. I can hobble along using my clone of this branch, but that's obviously not ideal. Thank you all for your great work, and for this really helpful contribution in particular!

FredSRichardson, Dec 13 '24

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot], Dec 28 '24

This PR was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot], Jan 04 '25

Reopened. It got stale over the holidays.

galv, Jan 08 '25

I just wanted to say I fully endorse this pull request. I tested it and it works well for my use case. It will greatly enable my work if some form of this safetensor PR can make it into the main branch. Thank you!!!

FredSRichardson, Jan 09 '25

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot], Jan 25 '25

Hello! Any chance of resurrecting this pull request? Thank you!!! -Fred

FredSRichardson, Jan 31 '25

beep boop 🤖: 🙏 The following files have warnings. In case you are familiar with these, please try helping us to improve the code base.


Your code was analyzed with PyLint. The following annotations have been identified:

************* Module nemo.collections.asr.models.clustering_diarizer
nemo/collections/asr/models/clustering_diarizer.py:238:0: C0301: Line too long (135/119) (line-too-long)
nemo/collections/asr/models/clustering_diarizer.py:242:0: C0301: Line too long (164/119) (line-too-long)
nemo/collections/asr/models/clustering_diarizer.py:328:0: C0301: Line too long (121/119) (line-too-long)
nemo/collections/asr/models/clustering_diarizer.py:549:4: C0116: Missing function or method docstring (missing-function-docstring)
************* Module nemo.collections.asr.modules.audio_preprocessing
nemo/collections/asr/modules/audio_preprocessing.py:98:0: C0301: Line too long (667/119) (line-too-long)
nemo/collections/asr/modules/audio_preprocessing.py:95:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:106:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:304:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:545:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:616:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:667:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:724:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:757:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:776:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:792:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:798:0: C0115: Missing class docstring (missing-class-docstring)
************* Module nemo.collections.asr.modules.conv_asr
nemo/collections/asr/modules/conv_asr.py:197:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:239:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:399:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:459:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:503:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:507:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:603:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:677:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:689:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:758:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:858:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:881:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:900:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/conv_asr.py:945:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/conv_asr.py:969:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/conv_asr.py:983:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/conv_asr.py:992:0: C0115: Missing class docstring (missing-class-docstring)
************* Module nemo.collections.nlp.models.nlp_model
nemo/collections/nlp/models/nlp_model.py:187:0: C0301: Line too long (121/119) (line-too-long)
nemo/collections/nlp/models/nlp_model.py:191:0: C0301: Line too long (120/119) (line-too-long)
nemo/collections/nlp/models/nlp_model.py:196:0: C0301: Line too long (135/119) (line-too-long)
nemo/collections/nlp/models/nlp_model.py:286:0: C0301: Line too long (132/119) (line-too-long)
nemo/collections/nlp/models/nlp_model.py:423:0: C0301: Line too long (133/119) (line-too-long)
nemo/collections/nlp/models/nlp_model.py:306:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/models/nlp_model.py:454:4: C0116: Missing function or method docstring (missing-function-docstring)
************* Module nemo.collections.nlp.parts.nlp_overrides
nemo/collections/nlp/parts/nlp_overrides.py:219:0: C0301: Line too long (140/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:224:0: C0301: Line too long (149/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:240:0: C0301: Line too long (123/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:431:0: C0301: Line too long (136/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:517:0: C0301: Line too long (152/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:741:0: C0301: Line too long (140/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:746:0: C0301: Line too long (149/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:1012:0: C0301: Line too long (128/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:1016:0: C0301: Line too long (141/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:1020:0: C0301: Line too long (149/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:1077:0: C0301: Line too long (135/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:1204:0: C0301: Line too long (121/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:1789:0: C0301: Line too long (152/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:255:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:388:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:432:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:613:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:629:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:648:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:889:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:1008:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:1700:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:1788:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:18:0: W0611: Unused import re (unused-import)
nemo/collections/nlp/parts/nlp_overrides.py:107:4: W0611: Unused tensorstore imported from megatron.core.dist_checkpointing.strategies (unused-import)
************* Module nemo.core.classes.common
nemo/core/classes/common.py:693:0: C0301: Line too long (120/119) (line-too-long)
nemo/core/classes/common.py:819:0: C0301: Line too long (124/119) (line-too-long)
nemo/core/classes/common.py:926:0: C0301: Line too long (120/119) (line-too-long)
nemo/core/classes/common.py:471:0: C0115: Missing class docstring (missing-class-docstring)
nemo/core/classes/common.py:567:0: C0115: Missing class docstring (missing-class-docstring)
nemo/core/classes/common.py:647:0: C0115: Missing class docstring (missing-class-docstring)
nemo/core/classes/common.py:1026:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/common.py:1141:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/common.py:29:0: W0611: Unused Iterable imported from typing (unused-import)
************* Module nemo.core.classes.modelPT
nemo/core/classes/modelPT.py:82:0: C0301: Line too long (130/119) (line-too-long)
nemo/core/classes/modelPT.py:177:0: C0301: Line too long (121/119) (line-too-long)
nemo/core/classes/modelPT.py:184:0: C0301: Line too long (154/119) (line-too-long)
nemo/core/classes/modelPT.py:258:0: C0301: Line too long (131/119) (line-too-long)
nemo/core/classes/modelPT.py:262:0: C0301: Line too long (132/119) (line-too-long)
nemo/core/classes/modelPT.py:310:0: C0301: Line too long (160/119) (line-too-long)
nemo/core/classes/modelPT.py:390:0: C0301: Line too long (135/119) (line-too-long)
nemo/core/classes/modelPT.py:1250:0: C0301: Line too long (123/119) (line-too-long)
nemo/core/classes/modelPT.py:1487:0: C0301: Line too long (140/119) (line-too-long)
nemo/core/classes/modelPT.py:1698:0: C0301: Line too long (128/119) (line-too-long)
nemo/core/classes/modelPT.py:1717:0: C0301: Line too long (135/119) (line-too-long)
nemo/core/classes/modelPT.py:1727:0: C0301: Line too long (122/119) (line-too-long)
nemo/core/classes/modelPT.py:1853:0: C0301: Line too long (166/119) (line-too-long)
nemo/core/classes/modelPT.py:1922:0: C0301: Line too long (120/119) (line-too-long)
nemo/core/classes/modelPT.py:2076:0: C0301: Line too long (151/119) (line-too-long)
nemo/core/classes/modelPT.py:223:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:878:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:944:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:948:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:955:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:1223:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:1647:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:1765:4: C0116: Missing function or method docstring (missing-function-docstring)
************* Module nemo.core.connectors.save_restore_connector
nemo/core/connectors/save_restore_connector.py:51:0: C0301: Line too long (135/119) (line-too-long)
nemo/core/connectors/save_restore_connector.py:208:0: C0301: Line too long (141/119) (line-too-long)
nemo/core/connectors/save_restore_connector.py:315:0: C0301: Line too long (140/119) (line-too-long)
nemo/core/connectors/save_restore_connector.py:430:0: C0301: Line too long (141/119) (line-too-long)
nemo/core/connectors/save_restore_connector.py:435:0: C0301: Line too long (141/119) (line-too-long)
nemo/core/connectors/save_restore_connector.py:38:0: C0115: Missing class docstring (missing-class-docstring)
nemo/core/connectors/save_restore_connector.py:704:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/connectors/save_restore_connector.py:712:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/connectors/save_restore_connector.py:720:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/connectors/save_restore_connector.py:728:4: C0116: Missing function or method docstring (missing-function-docstring)

-----------------------------------
Your code has been rated at 9.75/10

Mitigation guide:

  • Add sensible and useful docstrings to functions and methods
  • For trivial methods like getter/setters, consider adding # pylint: disable=C0116 inside the function itself
  • To disable multiple functions/methods at once, put a # pylint: disable=C0116 before the first and a # pylint: enable=C0116 after the last.
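The block form from the last bullet looks like this in practice. A small sketch on a hypothetical class (not taken from NeMo): the disable/enable pair wraps two trivial property getters, while the method after the enable pragma carries a real docstring and is checked normally.

```python
class SampleRateConfig:
    """Hypothetical settings holder used only to illustrate the pylint pragmas."""

    def __init__(self, sample_rate: int, window_size: float) -> None:
        self._sample_rate = sample_rate
        self._window_size = window_size

    # pylint: disable=C0116
    # Trivial getters: C0116 (missing-function-docstring) is silenced from here...
    @property
    def sample_rate(self) -> int:
        return self._sample_rate

    @property
    def window_size(self) -> float:
        return self._window_size

    # ...to here.
    # pylint: enable=C0116

    def duration_samples(self, seconds: float) -> int:
        """Convert a duration in seconds to a whole number of samples."""
        return round(self._sample_rate * seconds)
```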

By applying these rules, we reduce the occurrence of this message in the future.

Thank you for improving NeMo's documentation!

github-actions[bot], Jan 31 '25

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot], Feb 15 '25

This PR was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot], Feb 23 '25

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot], Mar 11 '25

This PR was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot], Mar 18 '25

Can this pull request be resurrected? It would be great if this could make it into the released code some time in the near future.

FredSRichardson, Mar 28 '25


Thanks for your patience on this matter. This PR still needs substantial modifications before it can go into a future release, and we currently don't have the bandwidth to work on it. It is definitely on our list, and we will post an update once we have made progress.

stevehuang52, Mar 29 '25

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot], Apr 13 '25

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot], Apr 30 '25

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot], May 20 '25

This PR was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot], May 28 '25

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot], Jun 12 '25

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot], Jun 27 '25

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot], Jul 12 '25

Could you provide some feedback on when this pull request will get merged? We are at a point where we need to import NeMo with safetensor support into our client's environment. Thank you!

FredSRichardson, Jul 12 '25

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot], Jul 28 '25

@nithinraok Do you think someone else can take over and finish the PR? I think the most important part is to switch to using only safetensors when the user chooses to enable it.

stevehuang52, Jul 28 '25

@nithinraok @stevehuang52 - I'm a little confused by something. The HuggingFace Canary 1b Flash model at https://huggingface.co/nvidia/canary-1b-flash appears to be a safetensor model. How is that supported by the NeMo toolkit without this patch?

FredSRichardson, Jul 30 '25

I think this needs to be clarified.

The safetensors files currently in https://huggingface.co/nvidia/canary-1b-flash are there for Hugging Face Transformers support, added through https://github.com/huggingface/transformers/pull/39062.

@stevehuang52 what is the plan for this PR? How should we proceed now that FC support has been added to Transformers?

nithinraok, Jul 30 '25

@nithinraok I think the current PR needs several improvements, but the most important is to store weights as safetensors only when the user specifies safe=True. The current approach stores both the torch checkpoint and the safetensors file, which doubles the checkpoint size.

At a minimum, I think users should be able to convert an existing .nemo checkpoint to safetensors and then load safetensors checkpoints when running inference.
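Converting an existing checkpoint so it carries only safetensors weights mostly amounts to rewriting the .nemo tar archive with the weights member swapped out, which also avoids the doubled checkpoint size mentioned above. A sketch under stated assumptions: `repack_safetensors_only` and both member names are hypothetical, and the tensor conversion is passed in as a callable because doing it for real needs torch and safetensors (for example, reading the state dict with torch.load(..., weights_only=True) and serializing it with safetensors.torch.save).

```python
import io
import tarfile
from typing import Callable

# Hypothetical member names: NeMo usually stores PyTorch weights as model_weights.ckpt;
# the safetensors member name is an assumption, not taken from the PR.
TORCH_MEMBER = "model_weights.ckpt"
SAFE_MEMBER = "model_weights.safetensors"


def repack_safetensors_only(src_path: str, dst_path: str, convert: Callable[[bytes], bytes]) -> None:
    """Copy a .nemo archive, replacing the PyTorch weights member with a converted one."""
    # Output is a plain tar; "r:*" on the input accepts plain or gzip-compressed archives.
    with tarfile.open(src_path, "r:*") as src, tarfile.open(dst_path, "w") as dst:
        for member in src.getmembers():
            dirname, sep, base = member.name.rpartition("/")
            if member.isfile() and base == TORCH_MEMBER:
                # Convert the checkpoint bytes and store them under the new name,
                # keeping any directory prefix such as "./".
                data = convert(src.extractfile(member).read())
                info = tarfile.TarInfo(dirname + sep + SAFE_MEMBER)
                info.size = len(data)
                info.mode = member.mode
                info.mtime = member.mtime
                dst.addfile(info, io.BytesIO(data))
            elif member.isfile():
                # Every other file (config, tokenizer, ...) is copied unchanged.
                dst.addfile(member, src.extractfile(member))
            else:
                # Directories and other non-file entries carry no data.
                dst.addfile(member)
```

Keeping the conversion behind a callable means the archive handling can be tested without torch, and the real converter can be swapped in on a trusted machine.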

stevehuang52, Jul 30 '25

What is the benefit of the added safetensors support (safe=True) when running inference with NeMo?

nithinraok, Jul 30 '25

If we are running inference with the NeMo toolkit, why not just use the torch checkpoint as we do now? If the motivation is security, then we should drop the torch checkpoint and save only safetensors in the .nemo file, not both.

nithinraok, Jul 30 '25


I'm afraid I'm not the right person to answer.

@FredSRichardson could you please explain your use case?

stevehuang52, Jul 30 '25