tch-rs
Getting gradients for intermediate tensors
Hi, is it possible to do this in Rust? In Python:
```python
import torch
x = torch.tensor(2., requires_grad=True)
a = x ** 2
a.retain_grad()
b = x ** 3
b.retain_grad()
q = a + b
q.backward()
print(f"x.grad={x.grad} a.grad={a.grad} b.grad={b.grad}")
```
It prints:
```
x.grad=16.0 a.grad=1.0 b.grad=1.0
```
I tried this:
```rust
use tch::Tensor;

fn main() {
    let x = Tensor::from(2.0).set_requires_grad(true);
    let a = x.pow_tensor_scalar(2);
    let b = x.pow_tensor_scalar(3);
    // retain_grad does not seem to be available in tch-rs?
    let q = &a + &b;
    q.backward();
    println!("x.grad={} a.grad={} b.grad={}", x.grad(), a.grad(), b.grad());
}
```
but it prints:
```
x.grad=[16.]Tensor[[], Double] a.grad=Tensor[Undefined] b.grad=Tensor[Undefined]
```
so the gradients of the intermediate tensors `a` and `b` are undefined.
thanks