Record in Declarations.yaml which return tensors are differentiable
`max_pool2d` returns an `output` tensor and an `indices` tensor; only the `output` tensor is differentiable. `prelu_backward`, on the other hand, returns a `grad_input` and a `grad_weight`; both are differentiable.
We need to know which outputs are differentiable, because those are the only ones we should call `set_flags` on. However, `Declarations.yaml` does not currently record this information. In fact, I am not sure we have this information anywhere, except implicitly: no non-NN ATen method returns two differentiable outputs, and this assumption is hardcoded into `gen_variable_type.py`. I'm not sure how we should add the information to the source cwrap/yaml files.
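As a rough illustration of what recording this per return could look like, here is a minimal Python sketch. The entry layout and the `differentiable` field are hypothetical, not the actual `Declarations.yaml` schema, and `emit_set_flags` produces placeholder text rather than the real `set_flags` call that `gen_variable_type.py` generates.

```python
# Hypothetical declaration entries: each return carries an explicit
# "differentiable" flag. Field names are assumptions, not the real schema.
DECLARATIONS = [
    {
        "name": "max_pool2d",
        "returns": [
            {"name": "output", "type": "Tensor", "differentiable": True},
            {"name": "indices", "type": "IndexTensor", "differentiable": False},
        ],
    },
    {
        "name": "prelu_backward",
        "returns": [
            {"name": "grad_input", "type": "Tensor", "differentiable": True},
            {"name": "grad_weight", "type": "Tensor", "differentiable": True},
        ],
    },
]


def differentiable_outputs(decl):
    """Return the names of the outputs that should receive set_flags."""
    return [r["name"] for r in decl["returns"] if r["differentiable"]]


def emit_set_flags(decl):
    """Emit one placeholder set_flags line per differentiable output.

    The emitted text is illustrative only; it does not reproduce the
    call signature actually used by gen_variable_type.py.
    """
    return [f"set_flags({name}, ...);" for name in differentiable_outputs(decl)]


if __name__ == "__main__":
    for decl in DECLARATIONS:
        print(decl["name"], differentiable_outputs(decl))
```

The generator then filters on a per-return flag instead of relying on the implicit assumption that non-NN methods never return two differentiable outputs.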
Edit: It might be OK to just assume that `Tensor` is differentiable but `IndexTensor` is not.
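A minimal sketch of that type-based heuristic, assuming returns are available as `(name, type)` pairs (a hypothetical representation). It also assumes that any return type other than `Tensor`, not just `IndexTensor`, is non-differentiable; that extension is not stated above.

```python
def is_differentiable(return_type: str) -> bool:
    # Heuristic from the edit: Tensor is differentiable, IndexTensor is not.
    # Assumption: every other return type is also treated as non-differentiable.
    return return_type == "Tensor"


def infer_flags(returns):
    """Pair each (name, type) return with its inferred differentiability."""
    return [(name, is_differentiable(ty)) for name, ty in returns]


# The two examples from the top of this issue.
max_pool2d_returns = [("output", "Tensor"), ("indices", "IndexTensor")]
prelu_backward_returns = [("grad_input", "Tensor"), ("grad_weight", "Tensor")]

if __name__ == "__main__":
    print(infer_flags(max_pool2d_returns))
    print(infer_flags(prelu_backward_returns))
```

This reproduces both examples without any change to the yaml files, but it cannot express an op whose plain `Tensor` output should not be differentiable; that case would still need an explicit per-return annotation.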