
Mutate the values of a Tensor in-place?

Open • iahuang opened this issue 2 years ago • 3 comments

In PyTorch, I can set individual elements of a tensor in place, without creating a new tensor:

import torch

tens = torch.Tensor([1, 2, 3, 4, 5])
tens[0] = 5  # overwrites the first element in place

Is there a way to do something similar in tch (safely or not)?

iahuang • Apr 05 '23 21:04

I have this ugly code, which sets the element at row 1, column 2 to zero:

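    use tch::{Device, Kind, Tensor};

    let device = Device::Cpu;
    let mut mask = Tensor::ones(&[3, 3], (Kind::Float, device));
    // One index tensor per dimension: row 1, column 2 (i64 indices are always accepted).
    mask.index_put_(
        &[Some(Tensor::of_slice(&[1i64])), Some(Tensor::of_slice(&[2i64]))],
        &Tensor::zeros(&[1], (Kind::Float, device)),
        false, // accumulate = false: overwrite instead of adding
    );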

Which works as expected:

Mask:
 1  1  1
 1  1  0
 1  1  1

I hope there is a shorter way to do that.

Butanium • May 04 '23 23:05

Actually, you can do:

use tch::IndexOp;
// `i` returns a view that shares mask's storage, so `fill_` updates mask in place.
mask.i((0, 1)).fill_(-1.);

Butanium • May 04 '23 23:05

Some additional details on indexing: https://docs.rs/tch/latest/tch/index/index.html

Nic-Gould • Jun 16 '23 01:06