Cannot save any torch tensor or `nn_module`
Ever since I updated `torch` to the latest version, I have not been able to save any `torch` object using the default serialization method. Instead, I have had to fall back on `options(torch.serialization_version = 2)` to be able to save any tensors or model objects. Here is a reprex:
``` r
library(torch)
x <- torch_rand(100, 1000)
torch_save(x, "test.pth")
#> Error in `FUN()`:
#> ! `metadata` must be a named list of scalar characters.
#> Backtrace:
#> ▆
#> 1. ├─torch::torch_save(x, "test.pth")
#> 2. └─torch:::torch_save.torch_tensor(x, "test.pth")
#> 3. └─torch:::torch_save_to_file(...)
#> 4. └─safetensors::safe_save_file(state_dict, path = con, metadata = metadata)
#> 5. └─safetensors:::write_safe(tensors, metadata, con)
#> 6. └─safetensors:::make_meta(tensors, metadata)
#> 7. └─safetensors:::validate_metadata(metadata)
#> 8. └─base::lapply(...)
#> 9. └─safetensors (local) FUN(X[[i]], ...)
#> 10. └─cli::cli_abort("{.arg metadata} must be a named list of scalar characters.")
#> 11. └─rlang::abort(...)
## trying to use safetensors directly gives a different error:
safetensors::safe_save_file(x, "test.pth")
#> Error in for (tensor in tensors) {: invalid for() loop sequence
## This works (but seems very slow for some reason)
options(torch.serialization_version = 2)
torch_save(x, "test.pth")
```
Created on 2023-10-06 with reprex v2.0.2
Any ideas what is going wrong here?
Hi @rdinnager,
Thanks for reporting.
That's weird. I'd assume this is a mismatch between the `torch` and `safetensors` versions, as I think I have seen a similar issue before.
Can you try updating your `safetensors` package? I just tried the latest commit of `torch` with `safetensors` (both the CRAN release and the latest commit), and they seem to work correctly.
I am using the latest version of both `torch` and `safetensors` from CRAN:
``` r
packageVersion("safetensors")
#> [1] '0.1.2'
packageVersion("torch")
#> [1] '0.11.0'
```
After looking through the recent commit history, I'm now thinking that I actually need the development version of `torch` for this to work properly. I will try that!
Ohhh, I think that might be the case. You are right: you might need to downgrade `safetensors` or use the dev version of `torch`. I'm going to make a new `torch` release soon.
Yes, I decided to wait for the new release and just use `options(torch.serialization_version = 2)` in the meantime. I find the new precompiled CUDA binary installation method so convenient that I don't want to bother installing from source at the moment, which would require installing a compatible CUDA toolkit locally (and I'm not up for that right now ;) ).
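As a possible refinement of that interim workaround, the legacy format can be scoped to a single save instead of being set globally for the whole session. This is only a sketch: `save_legacy()` is a hypothetical helper, not part of `torch`; it relies only on base R `options()`/`on.exit()` plus `torch_save()` and `torch_load()`.

``` r
library(torch)

# Hypothetical helper (not part of torch): save an object with the legacy
# serialization format (version 2), then restore the previous option value.
save_legacy <- function(obj, path) {
  # options() returns the old values, so they can be restored on exit
  old <- options(torch.serialization_version = 2)
  on.exit(options(old), add = TRUE)
  torch_save(obj, path)
}

x <- torch_rand(100, 1000)
save_legacy(x, "test.pth")
y <- torch_load("test.pth")
```

Because `on.exit()` runs even when `torch_save()` throws an error, the global option is restored either way, so later saves go back to the default format once a fixed `torch` release is installed.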