tch-rs
Mutate the values of a Tensor in-place?
In PyTorch, I can set individual elements of a tensor in place, without creating a new tensor:
```python
import torch
tens = torch.Tensor([1, 2, 3, 4, 5])
tens[0] = 5  # overwrites the first element in place
```
Is there a way of doing something similar in tch (safely or not)?
I have this ugly code:
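```rust
use tch::{Device, Kind, Tensor};

let device = Device::Cpu;
let mut mask = Tensor::ones(&[3, 3], (Kind::Float, device));
// Set mask[1][2] = 0; one i64 (Long) index tensor per dimension.
mask.index_put_(
    &[Some(Tensor::of_slice(&[1i64])), Some(Tensor::of_slice(&[2i64]))],
    &Tensor::zeros(&[1], (Kind::Float, device)),
    false,
);
```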
It works as expected and produces:
```
Mask:
1 1 1
1 1 0
1 1 1
```
I hope there is a shorter way to do that.
Actually, you can do:
```rust
use tch::IndexOp;
// `i` returns a view sharing mask's storage, so `fill_` updates mask itself.
mask.i((0, 1)).fill_(-1.);
```
Some additional detail on indexing: https://docs.rs/tch/latest/tch/index/index.html