
Adding variables to a VarStore panics on vs.load().

GaryBoone opened this issue · 4 comments

I want to save some metadata in a model. Below is a complete example of adding tensors to a VarStore using var_copy(), then saving/restoring them.

use tch::{nn::VarStore, Device, Tensor};

fn main() {
    let vs = VarStore::new(Device::cuda_if_available());

    let epochs = Tensor::from(10).to_kind(tch::Kind::Int);
    println!("epochs: {}", epochs);
    let learning_rate = Tensor::from(0.001);
    println!("learning_rate: {}", learning_rate);

    // Assign the tensors to the VarStore.
    let e = vs.root().var_copy("epochs", &epochs);
    println!("e: {}", e);
    let lr = vs.root().var_copy("learning_rate", &learning_rate);
    println!("lr: {}", lr);

    // Save the VarStore to a file.
    vs.save("model.pt").unwrap();

    // Load the VarStore from the saved file.
    let mut loaded_var_store = VarStore::new(Device::cuda_if_available());
    loaded_var_store.load("model.pt").unwrap();

    // Access the loaded tensors.
    let loaded_epochs_var = loaded_var_store.root().get("epochs").unwrap();
    let loaded_learning_rate_var = loaded_var_store.root().get("learning_rate").unwrap();

    println!("Loaded epochs: {}", loaded_epochs_var);
    println!("Loaded learning_rate: {}", loaded_learning_rate_var);
}

And here's the output:

epochs: [10]
Tensor[[], Int]
learning_rate: [0.0010]
Tensor[[], Double]
e: [10.]
Tensor[[], Float]
lr: [0.0010]
Tensor[[], Float]
thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Torch("isGenericDict() INTERNAL ASSERT FAILED at \"/Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/ivalue_inl.h\":2046, please report a bug to PyTorch. Expected GenericDict but got Object
...

So, notes/questions/issues:

  1. Note that while epochs is an Int tensor, the output for e shows that it becomes a Float when copied into the VarStore. The learning_rate is similarly changed from a Double tensor to a Float. That seems like a bug: the VarStore shouldn't change the type when copying in a tensor.
  2. These values are really just scalars. Is there, or could there be, a way to store scalars in the VarStore without having to wrap them in tensors? (A workaround sketch follows this comment.)
  3. It panics on the loaded_var_store.load("model.pt").unwrap(); line. The underlying error is an internal Torch assertion failure, which the unwrap() then surfaces as a panic.

GaryBoone avatar Apr 26 '23 22:04 GaryBoone
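
Re 2, one common workaround is to keep scalars in 0-dim tensors and convert them back on read. A minimal sketch, assuming the store's default float32 casting described in the reply below and the Tensor::int64_value accessor:

use tch::{nn::VarStore, Device, Kind, Tensor};

fn main() {
    let vs = VarStore::new(Device::cuda_if_available());

    // Wrap the scalar in a 0-dim tensor; var_copy casts it to the
    // store's default kind (float32).
    let epochs = vs.root().var_copy("epochs", &Tensor::from(10));

    // Convert back to the expected kind before extracting the scalar.
    let epochs_value = epochs.to_kind(Kind::Int64).int64_value(&[]);
    assert_eq!(epochs_value, 10);
}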

Re 1, all the float-like variables in a var store use the same type (by default float32); this simplifies casting to lower precision at the expense of finer control over the actual type. You can find methods to convert your store to use double/... in the documentation.

Re 3, I don't have a computer at hand to try it out, but there is some magic applied to files with the .pt extension (this magic assumes that they were written from the Python PyTorch API). Could you try using another extension? The typical convention for this crate is .ot.

LaurentMazare avatar Apr 27 '23 06:04 LaurentMazare
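
For reference, a minimal sketch of that conversion, assuming the VarStore::double method (which casts every variable tracked by the store):

use tch::{nn::VarStore, Device, Tensor};

fn main() {
    let mut vs = VarStore::new(Device::cuda_if_available());
    let lr = vs.root().var_copy("learning_rate", &Tensor::from(0.001));
    println!("before: {:?}", lr.kind()); // Float, the store default

    // Cast all variables tracked by the store to double precision.
    vs.double();
    let lr = vs.root().get("learning_rate").unwrap();
    println!("after: {:?}", lr.kind()); // Double
}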

Re 3, success. Changing the model extension from .pt to .ot solved the problem. Maybe the code could issue a warning for incorrect extensions?

GaryBoone avatar Apr 27 '23 23:04 GaryBoone
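
For anyone landing here later, a sketch of the working round trip; it is the same code as above with only the file extension changed:

use tch::{nn::VarStore, Device, Tensor};

fn main() -> Result<(), tch::TchError> {
    let vs = VarStore::new(Device::cuda_if_available());
    let _epochs = vs.root().var_copy("epochs", &Tensor::from(10));

    // Saving with the .ot extension avoids the code path that expects
    // files written by the Python PyTorch API.
    vs.save("model.ot")?;

    let mut loaded = VarStore::new(Device::cuda_if_available());
    loaded.load("model.ot")?;
    println!("loaded epochs: {}", loaded.root().get("epochs").unwrap());
    Ok(())
}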

I wouldn't be super keen on having warnings emitted, as they are likely to clutter the process output (I think this crate currently doesn't emit any such warnings). Hopefully people running into this issue can google the error message and find this thread.

LaurentMazare avatar May 13 '23 08:05 LaurentMazare

I had exactly the same issue: I used the .pt extension and got this error. Maybe we can document this behaviour?

AcrylicShrimp avatar Nov 03 '23 09:11 AcrylicShrimp