tensordict
[BUG] Tensorclass.keys() doesn't list non-tensor data
Describe the bug
Unlike tensordicts, tensorclasses don't list non-tensor data when you iterate over them with .keys() or .items(). This also affects other methods, e.g. apply().
To Reproduce
```python
from tensordict import TensorDict
from tensordict.prototype import tensorclass
import torch

# A TensorDict lists both tensor and non-tensor entries.
tensordict = TensorDict({"tensor": torch.ones(3), "string": "string"})
print(sorted(tensordict.keys()))
# Output: ['string', 'tensor']

@tensorclass
class MyTensorClass:
    tensor: torch.Tensor
    string: str

my_tensor_class = MyTensorClass(tensor=torch.ones(3), string="string")
print(sorted(my_tensor_class.keys()))
# Output: ['tensor'], i.e. it ignores non-tensor data.

my_tensor_class.apply(lambda x: print(x))
# Prints only: tensor([1., 1., 1.]); the string field is never visited.
```
Expected behavior
I'd expect keys() and items() to include non-tensor data as well.
The same probably goes for apply(), although whether it visits non-tensor data could also depend on an input flag (one might explicitly want to apply a function to tensors only).
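The flag idea for apply() could look like the minimal sketch below. It does not use tensordict at all: apply_fields and skip_non_tensor are hypothetical names, a plain dataclass stands in for a tensorclass, and the standard-library array.array stands in for torch.Tensor so the sketch runs without PyTorch.

```python
from array import array
from dataclasses import dataclass, fields
from typing import Any, Callable


@dataclass
class MyRecord:
    tensor: array  # array.array stands in for torch.Tensor (assumption)
    string: str


def apply_fields(obj: Any, fn: Callable[[Any], Any], skip_non_tensor: bool = True) -> Any:
    """Hypothetical apply(): returns a copy of obj with fn mapped over its fields.

    skip_non_tensor=True visits tensor-like fields only and copies the other
    fields through unchanged; skip_non_tensor=False visits every field.
    """
    out = {}
    for f in fields(obj):
        value = getattr(obj, f.name)
        if skip_non_tensor and not isinstance(value, array):
            out[f.name] = value  # non-tensor field left untouched
        else:
            out[f.name] = fn(value)
    return type(obj)(**out)


record = MyRecord(tensor=array("f", [1.0, 1.0, 1.0]), string="string")

visited = []
apply_fields(record, lambda v: visited.append(type(v).__name__) or v)
print(visited)  # ['array']

visited.clear()
apply_fields(record, lambda v: visited.append(type(v).__name__) or v, skip_non_tensor=False)
print(visited)  # ['array', 'str']
```

Defaulting the flag to tensors only keeps functions written for tensors safe, while the opt-out lets a user deliberately touch non-tensor fields.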
System info
tensordict version: '0.4.0+b4c91e8'
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)
This was a design choice made before we introduced NonTensorData.
Context
We initially thought it would be useful to let tensorclasses carry non-tensor data, but decided to exclude that data from the keys: apply, or any other op that iterates over the keys (reshape, gather, ...), would have been meaningless on non-tensor data.
Then we introduced NonTensorData, a simple tensorclass subclass that can only carry non-tensor data. If you call tensordict.keys(include_nested=True), a NonTensorData node appears as if it were a nested node with no leaves, although in reality it is a leaf. Calling apply over NonTensorData is fine because apply does not access the data field (it is not part of the keys). But because we hack through __getitem__ and __setitem__ to access NonTensorData.data, the situation is less clear: since the key is used to set or access the value, one would expect the non-tensor data to be part of the keys.
Solution
We introduced an is_leaf function in keys, items, values and apply to quickly check whether a node is a leaf. This prevents ops like reshape from treating the non-tensor data as a leaf and accessing it, while still presenting non-tensor data as leaves to the user.
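The effect of swapping the leaf predicate can be illustrated with a toy traversal over nested dicts. This is a minimal sketch, not tensordict's implementation: NonTensorNode, iter_keys, is_tensor_leaf and is_user_leaf are hypothetical names, nested dicts stand in for nested tensordicts, and array.array stands in for torch.Tensor.

```python
from array import array
from typing import Any, Callable, Iterator, Tuple


class NonTensorNode:
    """Hypothetical stand-in for NonTensorData: wraps one non-tensor value.

    Like NonTensorData, its data field is not exposed as a key, so a
    traversal that does not recognise it as a leaf sees a node with no children.
    """

    def __init__(self, data: Any) -> None:
        self.data = data


def is_tensor_leaf(value: Any) -> bool:
    # Internal view: only tensors are leaves (what reshape, gather, ... should touch).
    return isinstance(value, array)  # array.array stands in for torch.Tensor


def is_user_leaf(value: Any) -> bool:
    # User-facing view: wrapped non-tensor data is also presented as a leaf.
    return isinstance(value, (array, NonTensorNode))


def iter_keys(node: dict, is_leaf: Callable[[Any], bool], leaves_only: bool = True,
              prefix: Tuple[str, ...] = ()) -> Iterator[Tuple[str, ...]]:
    """Yield nested keys as tuples; is_leaf decides where recursion stops."""
    for key, value in node.items():
        path = prefix + (key,)
        if is_leaf(value):
            yield path
            continue
        if not leaves_only:
            yield path  # inner node
        if isinstance(value, dict):
            yield from iter_keys(value, is_leaf, leaves_only, path)


td = {"tensor": array("f", [1.0, 1.0, 1.0]), "nested": {"string": NonTensorNode("string")}}
print(sorted(iter_keys(td, is_tensor_leaf)))  # [('tensor',)]
print(sorted(iter_keys(td, is_user_leaf)))    # [('nested', 'string'), ('tensor',)]
```

With is_tensor_leaf, the wrapped string only shows up as an inner node when leaves_only=False and vanishes from the leaf listing; with is_user_leaf, it is listed as a leaf without changing the underlying structure.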
We should now do something similar for all tensorclasses: by default, keys() will return all the data (tensor and non-tensor), but internally, operations will only be applied to tensor data.
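The proposed split (keys() lists everything, operations touch tensors only) can be sketched with a toy container. ToyTensorClass, _tensor_keys and get are hypothetical names chosen for illustration, not tensordict APIs, slicing stands in for shape-changing ops like reshape or gather, and array.array again stands in for torch.Tensor.

```python
from array import array
from typing import Any, Dict, List


class ToyTensorClass:
    """Hypothetical container: keys() lists everything, ops touch tensors only."""

    def __init__(self, **entries: Any) -> None:
        self._entries: Dict[str, Any] = dict(entries)

    @staticmethod
    def _is_tensor(value: Any) -> bool:
        return isinstance(value, array)  # stand-in for torch.Tensor

    def keys(self) -> List[str]:
        # User-facing view: tensor and non-tensor entries alike.
        return list(self._entries)

    def _tensor_keys(self) -> List[str]:
        # Internal view used by shape-changing ops.
        return [k for k, v in self._entries.items() if self._is_tensor(v)]

    def get(self, key: str) -> Any:
        return self._entries[key]

    def __getitem__(self, index: slice) -> "ToyTensorClass":
        # Index tensor entries only; non-tensor entries are carried over as-is.
        out = dict(self._entries)
        for key in self._tensor_keys():
            out[key] = self._entries[key][index]
        return ToyTensorClass(**out)


obj = ToyTensorClass(tensor=array("f", [1.0, 2.0, 3.0]), string="string")
print(sorted(obj.keys()))          # ['string', 'tensor']
sub = obj[0:2]
print(sub.get("tensor").tolist())  # [1.0, 2.0]
print(sub.get("string"))           # string
```

Without the internal filter, slicing would turn "string" into "st", which is the kind of meaningless result that originally motivated excluding non-tensor data from the keys.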
Sounds good to me!
If I understand correctly, the "real" leaves are tensors, and non-tensor values are not leaves because they are wrapped in NonTensorData, which is a dataclass subclass?
As a user, I'd find that slightly confusing, since to me non-tensor data does look like a leaf. But I guess is_leaf is only an internal implementation detail?