Cannot save any torch tensor or `nn_module`

Open rdinnager opened this issue 1 year ago • 4 comments

Ever since I updated torch to the latest version, I have not been able to save any torch object using the default serialization method. I have instead had to fall back on `options(torch.serialization_version = 2)` to be able to save any tensors or model objects. Here is a reprex:

```r
library(torch)

x <- torch_rand(100, 1000)
torch_save(x, "test.pth")
#> Error in `FUN()`:
#> ! `metadata` must be a named list of scalar characters.
#> Backtrace:
#>      ▆
#>   1. ├─torch::torch_save(x, "test.pth")
#>   2. └─torch:::torch_save.torch_tensor(x, "test.pth")
#>   3.   └─torch:::torch_save_to_file(...)
#>   4.     └─safetensors::safe_save_file(state_dict, path = con, metadata = metadata)
#>   5.       └─safetensors:::write_safe(tensors, metadata, con)
#>   6.         └─safetensors:::make_meta(tensors, metadata)
#>   7.           └─safetensors:::validate_metadata(metadata)
#>   8.             └─base::lapply(...)
#>   9.               └─safetensors (local) FUN(X[[i]], ...)
#>  10.                 └─cli::cli_abort("{.arg metadata} must be a named list of scalar characters.")
#>  11.                   └─rlang::abort(...)

## trying to use safetensors directly gives a different error:
safetensors::safe_save_file(x, "test.pth")
#> Error in for (tensor in tensors) {: invalid for() loop sequence

## This works (but seems very slow for some reason)
options(torch.serialization_version = 2)
torch_save(x, "test.pth")
```

Created on 2023-10-06 with reprex v2.0.2

Any ideas what is going wrong here?

rdinnager • Oct 06 '23 15:10

Hi @rdinnager,

Thanks for reporting.

That's weird. I'd assume this is a mismatch between the torch and safetensors versions, as I think I saw a similar issue at some point.

Can you try updating your safetensors package? I just tried the latest commit of torch with safetensors (both the CRAN release and the latest commit), and they seem to work correctly.

dfalbel • Oct 06 '23 17:10

I am using the latest version of both torch and safetensors from CRAN:

```r
packageVersion("safetensors")
#> [1] '0.1.2'
packageVersion("torch")
#> [1] '0.11.0'
```

I'm thinking now that I actually need the development version of torch to work properly, after looking through the recent commit history. I will try that!

rdinnager • Oct 06 '23 18:10

Ohhh, I think that might be the case. You are right: you might need to downgrade safetensors or use the dev version of torch. I'm going to make a new torch release soon.

dfalbel • Oct 10 '23 14:10

Yes, I decided to wait until the new release and just use `options(torch.serialization_version = 2)` in the meantime. I find the new precompiled CUDA binary method of installation so convenient that I just don't want to bother trying to install from source at the moment, which would require me to install a compatible CUDA version locally (and I'm not up for that right now ;) ).

rdinnager • Dec 12 '23 21:12
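
For readers who hit the same error, the workaround discussed in this thread can be wrapped in a small helper. This is only a sketch, not part of torch: `save_with_fallback` is a hypothetical name, and it assumes, as reported above, that the default format fails while `options(torch.serialization_version = 2)` succeeds. It tries the default serialization first and retries with the version 2 format only if that raises an error, restoring the previous option afterwards.

```r
library(torch)

# Hypothetical helper (not part of torch): try the default serialization
# first and fall back to the version 2 format if it fails.
save_with_fallback <- function(obj, path) {
  tryCatch(
    torch_save(obj, path),
    error = function(e) {
      # Temporarily switch to the older serialization format reported
      # to work in this thread, then restore the previous setting.
      old <- options(torch.serialization_version = 2)
      on.exit(options(old), add = TRUE)
      torch_save(obj, path)
    }
  )
}

x <- torch_rand(100, 1000)
save_with_fallback(x, "test.pth")
```

Once the torch and safetensors versions match again, the first `torch_save()` call should succeed and the fallback branch is never reached, so the helper can stay in place or be removed without changing behavior.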