[BUG] Tensorclass.keys() doesn't list non-tensor data.

Open · maximilianigl opened this issue on Mar 21 '24 · 2 comments

Describe the bug

Tensorclasses (unlike TensorDicts) don't list non-tensor data when iterating over them with .keys() or .items(). This also affects other methods such as apply().

To Reproduce

from tensordict import TensorDict
from tensordict.prototype import tensorclass
import torch

tensordict = TensorDict({"tensor": torch.ones(3), "string": "string"})

print(sorted(tensordict.keys()))
# Output: ['string', 'tensor']

@tensorclass
class MyTensorClass:
    tensor: torch.Tensor
    string: str


my_tensor_class = MyTensorClass(tensor=torch.ones(3), string="string")

print(sorted(my_tensor_class.keys()))
# Output: ['tensor'], i.e. it ignores non-tensor data.

my_tensor_class.apply(lambda x: print(x))
# Output: tensor([1., 1., 1.]), i.e. the function is never called on the string field.
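A possible explanation for this output, stated as an assumption rather than a reading of the tensordict source: if a tensorclass keeps tensor fields and non-tensor fields in separate internal stores and forwards keys(), items() and apply() to the tensor store only, the string field disappears exactly as shown above. The toy model below reproduces that pattern in plain Python. `SplitStorageToy`, its two dicts, and the use of lists as stand-in tensors are all hypothetical.

```python
from typing import Any, Dict, Iterator


class SplitStorageToy:
    """Hypothetical model, not tensordict code: tensor-like values and all
    other values live in two separate dicts, and keys() only consults the
    first one."""

    def __init__(self, **values: Any) -> None:
        self._tensors: Dict[str, Any] = {}
        self._non_tensors: Dict[str, Any] = {}
        for name, value in values.items():
            # Stand-in "is a tensor" check: lists play the role of tensors.
            target = self._tensors if isinstance(value, list) else self._non_tensors
            target[name] = value

    def keys(self) -> Iterator[str]:
        # Only the tensor store is consulted, so non-tensor fields vanish.
        return iter(self._tensors)


toy = SplitStorageToy(tensor=[1.0, 1.0, 1.0], string="string")
print(sorted(toy.keys()))  # ['tensor']
```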

Expected behavior

I'd expect keys() and items() to include non-tensor data. The same probably holds for apply(), although whether apply() visits non-tensor data could instead be controlled by an input flag, since one might explicitly want to apply a function to tensors only.
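To make the proposal concrete, here is a minimal plain-Python sketch of the expected semantics. It is not tensordict code: `ToyTensorClass`, `FakeTensor` (a list subclass standing in for torch.Tensor so the example runs without torch) and the `include_non_tensor` flag are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable, Iterator, Tuple


class FakeTensor(list):
    """Hypothetical stand-in for torch.Tensor (keeps the sketch torch-free)."""


@dataclass
class ToyTensorClass:
    tensor: FakeTensor
    string: str

    def keys(self) -> Iterator[str]:
        # Proposed behaviour: every declared field, tensor or not.
        return (f.name for f in fields(self))

    def items(self) -> Iterator[Tuple[str, Any]]:
        return ((name, getattr(self, name)) for name in self.keys())

    def apply(
        self, fn: Callable[[Any], Any], include_non_tensor: bool = False
    ) -> "ToyTensorClass":
        # Hypothetical flag: by default only tensor fields are transformed;
        # non-tensor fields are copied into the result unchanged.
        updated = {}
        for name, value in self.items():
            if include_non_tensor or isinstance(value, FakeTensor):
                updated[name] = fn(value)
            else:
                updated[name] = value
        return ToyTensorClass(**updated)


obj = ToyTensorClass(tensor=FakeTensor([1.0, 1.0, 1.0]), string="string")
print(sorted(obj.keys()))  # ['string', 'tensor']
doubled = obj.apply(lambda x: FakeTensor(v * 2 for v in x))
print(doubled.string)  # string (left untouched)
```

Defaulting the flag to False keeps apply() from calling tensor-only functions on strings, which matches the use case of applying something to tensors only; passing include_non_tensor=True opts in to visiting every field.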

System info

tensordict version: '0.4.0+b4c91e8'

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

maximilianigl · Mar 21 '24 14:03