[BUG] Tensorclass.keys() doesn't list non-tensor data.

Open maximilianigl opened this issue 1 year ago • 2 comments

Describe the bug

Tensorclasses (in contrast to tensordicts) don't list non-tensor data when iterating through them with .keys() or .items(). This also affects other methods, e.g. apply().

To Reproduce

from tensordict import TensorDict
from tensordict.prototype import tensorclass
import torch

tensordict = TensorDict({"tensor": torch.ones(3), "string": "string"})

print(sorted(tensordict.keys()))
# Output: ['string', 'tensor']

@tensorclass
class MyTensorClass:
    tensor: torch.Tensor
    string: str


my_tensor_class = MyTensorClass(tensor=torch.ones(3), string="string")

print(sorted(my_tensor_class.keys()))
# Output: ['tensor'], i.e. it ignores non-tensor data.

print(my_tensor_class.apply(lambda x: print(x)))
# Output: tensor([1., 1., 1.])

Expected behavior

For keys() and items() I'd expect it to iterate over non-tensor data. For apply() probably as well, but whether or not it iterates over non-tensor data could also be dependent on an input flag (as one might explicitly want to only apply something to tensors).

System info

tensordict version: '0.4.0+b4c91e8'

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

maximilianigl avatar Mar 21 '24 14:03 maximilianigl

This used to be a design choice before we introduced NonTensorData.

Context

We initially thought it could be interesting to let tensorclass carry non-tensor data, but decided it was better to exclude that data from the keys, since apply or any other op that iterates through the keys (reshape, gather, ...) would have been meaningless on non-tensor data.

Then we introduced NonTensorData, which is a simple subclass of a tensorclass that can only carry non-tensor data. If you call tensordict.keys(include_nested=True), a NonTensorData node will appear as if it had no leaves, but in reality it is a leaf. Calling apply over NonTensorData is fine because it does not access the data field (it's not part of the keys). But because we hack through __getitem__ and __setitem__ to access NonTensorData.data, the situation is less clear: one would imagine that the non-tensor data is now part of the keys, since we use the key to set or access the value.

Solution

We introduced an is_leaf function in keys, items, values and apply to quickly check whether a node is a leaf or not. This keeps ops like reshape from accessing the non-tensor data as if it were a leaf, while still presenting non-tensor data as leaves to the user. We should now apply something similar to all tensorclasses: by default, keys will return all the data (tensor and non-tensor), but internally operations will only be applied to tensor data.
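
To make the idea concrete, here is a minimal pure-Python sketch of the "two views of a leaf" approach. It deliberately does not use tensordict: the names NonTensor, user_is_leaf, tensor_is_leaf, leaf_keys and apply_to_tensors are hypothetical stand-ins, and plain lists of floats play the role of tensors. It only illustrates the design, not the library's actual implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterator, Tuple


@dataclass
class NonTensor:
    """Hypothetical wrapper around non-tensor data (not the tensordict class)."""
    data: Any


def user_is_leaf(value: Any) -> bool:
    # User-facing view: anything that is not a nested container is a leaf,
    # so wrapped non-tensor data shows up in the keys.
    return not isinstance(value, dict)


def tensor_is_leaf(value: Any) -> bool:
    # Internal view: only "tensors" (lists of floats in this toy) are leaves.
    return isinstance(value, list)


def leaf_keys(
    tree: Dict[str, Any],
    is_leaf: Callable[[Any], bool] = user_is_leaf,
    prefix: Tuple[str, ...] = (),
) -> Iterator[Tuple[str, ...]]:
    """Yield the key path of every node that is_leaf accepts."""
    for key, value in tree.items():
        path = prefix + (key,)
        if is_leaf(value):
            yield path
        elif isinstance(value, dict):
            yield from leaf_keys(value, is_leaf, path)


def apply_to_tensors(tree: Dict[str, Any], fn: Callable[[list], list]) -> Dict[str, Any]:
    """Apply fn to tensor leaves only; non-tensor data is carried over untouched."""
    out: Dict[str, Any] = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            out[key] = apply_to_tensors(value, fn)
        elif tensor_is_leaf(value):
            out[key] = fn(value)
        else:
            out[key] = value
    return out


tree = {
    "tensor": [1.0, 1.0, 1.0],
    "string": NonTensor("string"),
    "nested": {"other": [2.0]},
}

user_keys = sorted(leaf_keys(tree))                      # includes ("string",)
internal_keys = sorted(leaf_keys(tree, tensor_is_leaf))  # skips ("string",)
doubled = apply_to_tensors(tree, lambda t: [2 * x for x in t])
```

In this sketch the user sees all three keys by default, while an op that uses the tensor-only predicate never touches the wrapped string.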

vmoens avatar Mar 26 '24 08:03 vmoens

Sounds good to me! If I understand it correctly, then 'real' leaves are tensors, and non-tensors are not leaves because they're wrapped in NonTensorData, which is a dataclass-subclass? As a user, I'd find that slightly confusing, since to me non-tensor data does appear as leaves. But I guess is_leaf is only an internal implementation detail?

maximilianigl avatar Mar 26 '24 08:03 maximilianigl