tch-rs

Getting gradients for intermediate tensors

bushuyev opened this issue 1 year ago (status: open, 0 comments)

Hi, is it possible to do this in Rust? In Python:

import torch

x = torch.tensor(2., requires_grad=True)
a = x ** 2
a.retain_grad()

b = x ** 3
b.retain_grad()

q = a + b

q.backward()

print(f"x.grad={x.grad} a.grad={a.grad} b.grad={b.grad}")

It prints:

x.grad=16.0 a.grad=1.0 b.grad=1.0
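These values follow from the chain rule applied to q = a + b with a = x^2 and b = x^3, evaluated at x = 2:

$$
\frac{\partial q}{\partial a} = 1, \qquad
\frac{\partial q}{\partial b} = 1, \qquad
\frac{dq}{dx} = \frac{\partial q}{\partial a}\frac{da}{dx} + \frac{\partial q}{\partial b}\frac{db}{dx} = 2x + 3x^2 = 4 + 12 = 16.
$$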

I tried this with tch-rs:

use tch::Tensor;

fn main() {
    let x = Tensor::from(2.0).set_requires_grad(true);
    let a = x.pow_tensor_scalar(2);
    let b = x.pow_tensor_scalar(3);
    // Is "retain_grad" excluded from tch-rs?
    let q = &a + &b;

    q.backward();

    println!("x.grad={} a.grad={} b.grad={}", x.grad(), a.grad(), b.grad());
}

But it prints:

x.grad=[16.]Tensor[[], Double] a.grad=Tensor[Undefined] b.grad=Tensor[Undefined]
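For context: in PyTorch, backward() populates .grad only on leaf tensors unless retain_grad() has been called on a non-leaf tensor. Assuming libtorch behaves the same way, that would explain the Tensor[Undefined] results for a and b. The sketch below is a hypothetical, dependency-free toy in plain Rust. It is not tch-rs, and the Tape type and its methods are invented for illustration. It shows the underlying idea: reverse-mode differentiation computes an adjoint for every node anyway, and keeping the adjoints of intermediate nodes is roughly what retain_grad asks autograd to do.

```rust
// Toy reverse-mode autodiff on a tape. NOT tch-rs: a hypothetical
// illustration of how intermediate nodes can report gradients when
// their adjoints are kept instead of discarded.

#[derive(Clone, Copy)]
enum Op {
    Leaf,
    Powi(usize, i32), // (input node index, integer exponent)
    Add(usize, usize),
}

struct Tape {
    ops: Vec<Op>,
    vals: Vec<f64>, // forward value of each node
}

impl Tape {
    fn new() -> Self {
        Tape { ops: Vec::new(), vals: Vec::new() }
    }

    fn push(&mut self, op: Op, val: f64) -> usize {
        self.ops.push(op);
        self.vals.push(val);
        self.vals.len() - 1
    }

    fn leaf(&mut self, v: f64) -> usize {
        self.push(Op::Leaf, v)
    }

    fn powi(&mut self, x: usize, n: i32) -> usize {
        let v = self.vals[x].powi(n);
        self.push(Op::Powi(x, n), v)
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.vals[a] + self.vals[b];
        self.push(Op::Add(a, b), v)
    }

    /// Returns d(out)/d(node) for every node on the tape, intermediates included.
    /// Nodes only reference earlier indices, so reverse tape order is a valid
    /// reverse topological order.
    fn backward(&self, out: usize) -> Vec<f64> {
        let mut grads = vec![0.0; self.vals.len()];
        grads[out] = 1.0;
        for i in (0..=out).rev() {
            let g = grads[i];
            match self.ops[i] {
                Op::Leaf => {}
                // d(x^n)/dx = n * x^(n-1)
                Op::Powi(x, n) => grads[x] += g * n as f64 * self.vals[x].powi(n - 1),
                // d(a+b)/da = d(a+b)/db = 1
                Op::Add(a, b) => {
                    grads[a] += g;
                    grads[b] += g;
                }
            }
        }
        grads
    }
}

fn main() {
    let mut t = Tape::new();
    let x = t.leaf(2.0);
    let a = t.powi(x, 2);
    let b = t.powi(x, 3);
    let q = t.add(a, b);
    let g = t.backward(q);
    // prints "x.grad=16 a.grad=1 b.grad=1"
    println!("x.grad={} a.grad={} b.grad={}", g[x], g[a], g[b]);
}
```

The toy returns every adjoint from backward() for simplicity. A real framework frees intermediate gradients by default to save memory, which is why it needs an explicit opt-in such as retain_grad.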

Thanks!

bushuyev, Jun 11 '24 16:06