torch_ecg
Safetensors
This PR changes the default behavior of the save method of the CkptMixin class: by default it now uses the save_file method from safetensors instead of torch.save. See the comparison of the model saving mechanisms. The save method now has the following signature:
def save(
    self,
    path: Union[str, bytes, os.PathLike],
    train_config: CFG,
    extra_items: Optional[dict] = None,
    use_safetensors: bool = True,
    safetensors_single_file: bool = True,
) -> None:
    """Save the model to disk.

    .. note::

        `safetensors` is used by default to save the model.
        If one wants to save the models in `.pth` or `.pt` format,
        he/she must explicitly set ``use_safetensors=False``.

    Parameters
    ----------
    path : `path-like`
        Path to save the model.
    train_config : CFG
        Config for training the model,
        used when one restores the model.
    extra_items : dict, optional
        Extra items to save along with the model.
        The values should be serializable: can be saved as a json file,
        or is a dict of torch tensors.

        .. versionadded:: 0.0.32
    use_safetensors : bool, default True
        Whether to use `safetensors` to save the model.
        This will be overridden by the suffix of `path`:
        if it is `.safetensors`, then `use_safetensors` is set to True;
        if it is `.pth` or `.pt`, then if `use_safetensors` is True,
        the suffix is changed to `.safetensors`, otherwise it is unchanged.

        .. versionadded:: 0.0.32
    safetensors_single_file : bool, default True
        Whether to save the metadata along with the state dict into one file.

        .. versionadded:: 0.0.32

    Returns
    -------
    None

    """
    ...
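The suffix rule described in the `use_safetensors` entry can be sketched as a small standalone helper. This is a minimal sketch and not the library's actual code: `resolve_save_path` is a hypothetical name, and the handling of suffixes other than `.safetensors`, `.pth`, and `.pt` (left unchanged here) is an assumption the docstring does not cover.

```python
import os
from pathlib import Path
from typing import Tuple, Union


def resolve_save_path(
    path: Union[str, bytes, os.PathLike], use_safetensors: bool = True
) -> Tuple[Path, bool]:
    """Hypothetical helper mirroring the documented suffix rule."""
    # os.fsdecode accepts str, bytes, and os.PathLike and always returns str
    path = Path(os.fsdecode(path))
    if path.suffix == ".safetensors":
        # A `.safetensors` suffix forces the safetensors format
        return path, True
    if path.suffix in (".pth", ".pt") and use_safetensors:
        # The flag wins over a legacy suffix, so the suffix is rewritten
        return path.with_suffix(".safetensors"), True
    # Assumption: every other case keeps both the path and the flag as given
    return path, use_safetensors


print(resolve_save_path("model.pth"))
print(resolve_save_path(b"model.pt", use_safetensors=False))
print(resolve_save_path("model.safetensors", use_safetensors=False))
```

Under this reading, the only way to end up with a `.pth` or `.pt` file is to pass `use_safetensors=False` together with one of those suffixes.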
This change is backward compatible. Models can still be saved in pth/pt format, as before, by explicitly setting use_safetensors=False, and the load method still loads pth/pt format models correctly.
:x: 1 Tests Failed:
| Tests completed | Failed | Passed | Skipped |
|---|---|---|---|
| 443 | 1 | 442 | 33 |
View the top 1 failed test(s) by shortest run time
test/test_utils/test_utils_nn.py::test_mixin_classes
Stack Traces | 0.005s run time
    def test_mixin_classes():
        model_1d = Model1D(12, CFG(out_channels=128))
        assert isinstance(model_1d.module_size, int)
        assert model_1d.module_size > 0
        assert isinstance(model_1d.sizeof, int)
        assert model_1d.sizeof > model_1d.module_size
        assert isinstance(model_1d.module_size_, str)
        assert isinstance(model_1d.sizeof_, str)
        assert isinstance(model_1d.dtype_, str)
        assert isinstance(model_1d.device_, str)
        # test pth/pt file
        save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_mixin.pth"
        # convert save_path to bytes to cover bytes path handling code
        save_path = str(save_path).encode()
        model_1d.save(save_path, CFG(dict(n_leads=12)), extra_items={"xxx": {"ones": torch.ones((2, 2))}}, use_safetensors=False)
>       assert save_path.is_file()
               ^^^^^^^^^^^^^^^^^
E       AttributeError: 'bytes' object has no attribute 'is_file'

test/test_utils/test_utils_nn.py:524: AttributeError
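The failure comes from the test itself rather than from `save`: `save_path` is rebound to a `bytes` object, and `bytes` has no `is_file` method. One possible way to keep the bytes-path coverage is to decode the bytes back into a `Path` before asserting. The snippet below is a standalone sketch using only the standard library; the empty file is a stand-in for the checkpoint that `model_1d.save(...)` would write, and this is not necessarily how the PR resolves the failure.

```python
import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp_dir:
    save_path = Path(tmp_dir) / "test_mixin.pth"
    # Stand-in for the checkpoint file that the real test writes via `save`
    save_path.write_bytes(b"")

    # Same conversion as in the failing test
    bytes_path = str(save_path).encode()

    # bytes objects have no `is_file`, which is what raised the AttributeError
    has_is_file = hasattr(bytes_path, "is_file")

    # Decode back to a Path before calling pathlib methods
    restored = Path(os.fsdecode(bytes_path))
    file_found = restored.is_file()

print(has_is_file, file_found)
```

An alternative that avoids the round trip is `os.path.isfile(bytes_path)`, which accepts bytes paths directly.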