tensordict
[BUG] Tensorclass.keys() doesn't list non-tensor data
Describe the bug
Tensorclasses (in contrast to tensordicts) don't list non-tensor data when iterating through them with .keys() or .items(). This also affects, e.g., .apply().
To Reproduce
```python
from tensordict import TensorDict
from tensordict.prototype import tensorclass
import torch

# Use a name that does not shadow the tensordict module.
td = TensorDict({"tensor": torch.ones(3), "string": "string"})
print(sorted(td.keys()))
# Output: ['string', 'tensor']

@tensorclass
class MyTensorClass:
    tensor: torch.Tensor
    string: str

my_tensor_class = MyTensorClass(tensor=torch.ones(3), string="string")
print(sorted(my_tensor_class.keys()))
# Output: ['tensor'], i.e. the non-tensor field "string" is ignored.

my_tensor_class.apply(lambda x: print(x))
# Output: tensor([1., 1., 1.]), i.e. apply() never visits "string".
```
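As a possible workaround until keys() covers non-tensor fields, the declared fields can be enumerated directly. This sketch assumes that @tensorclass builds on Python's dataclasses, so that dataclasses.fields() lists every declared field, tensor or not; the helper name all_field_items is hypothetical and not part of tensordict.

```python
import dataclasses
from typing import Any, Iterator, Tuple


def all_field_items(obj: Any) -> Iterator[Tuple[str, Any]]:
    """Yield (name, value) for every declared dataclass field of ``obj``.

    Hypothetical helper, not part of tensordict. It relies on the assumption
    that @tensorclass instances are dataclass instances, so non-tensor fields
    such as ``string`` are listed alongside tensor fields.
    """
    # is_dataclass() is also True for the class itself, so reject classes.
    if not dataclasses.is_dataclass(obj) or isinstance(obj, type):
        raise TypeError(f"expected a dataclass instance, got {type(obj).__name__}")
    for field in dataclasses.fields(obj):
        yield field.name, getattr(obj, field.name)
```

Under that assumption, sorted(name for name, _ in all_field_items(my_tensor_class)) should include "string" alongside "tensor"; whether the tensorclass machinery registers any extra fields is not checked here.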
Expected behavior
I'd expect keys() and items() to iterate over non-tensor data as well, as they do for TensorDict.
apply() should probably do the same, although whether it visits non-tensor data could also be controlled by an input flag (one might explicitly want to apply a function to tensors only).
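To make the expected semantics concrete, here is a toy pure-Python model. ToyTensorClass and the include_non_tensor flag are hypothetical names invented for illustration, not tensordict's API, and plain lists stand in for tensors so the sketch runs without torch. keys() and items() report both kinds of entries, while apply() maps tensors only unless the caller opts in.

```python
from typing import Any, Callable, Dict, Iterator, Tuple


class ToyTensorClass:
    """Hypothetical stand-in for a tensorclass, NOT the tensordict API.

    Tensor-like leaves and non-tensor leaves are stored in separate dicts,
    mirroring the split described in this issue.
    """

    def __init__(self, tensors: Dict[str, Any], non_tensors: Dict[str, Any]) -> None:
        self._tensors = dict(tensors)
        self._non_tensors = dict(non_tensors)

    def __getitem__(self, key: str) -> Any:
        if key in self._tensors:
            return self._tensors[key]
        return self._non_tensors[key]

    def keys(self) -> Iterator[str]:
        # Expected behavior: list tensor AND non-tensor entries, like TensorDict.keys().
        yield from self._tensors
        yield from self._non_tensors

    def items(self) -> Iterator[Tuple[str, Any]]:
        for key in self.keys():
            yield key, self[key]

    def apply(self, fn: Callable[[Any], Any], include_non_tensor: bool = False) -> "ToyTensorClass":
        # Tensor leaves are always mapped through fn.
        tensors = {k: fn(v) for k, v in self._tensors.items()}
        # Hypothetical opt-in flag: non-tensor leaves are mapped only on request,
        # otherwise they are copied through unchanged.
        if include_non_tensor:
            non_tensors = {k: fn(v) for k, v in self._non_tensors.items()}
        else:
            non_tensors = dict(self._non_tensors)
        return ToyTensorClass(tensors, non_tensors)


# Usage, with a list standing in for torch.ones(3):
tc = ToyTensorClass({"tensor": [1.0, 1.0, 1.0]}, {"string": "string"})
print(sorted(tc.keys()))  # ['string', 'tensor']
print(tc.apply(lambda x: x * 2)["string"])  # string
print(tc.apply(lambda x: x * 2, include_non_tensor=True)["string"])  # stringstring
```

Defaulting the flag to tensors-only keeps apply() safe for functions that only make sense on tensors, such as lambda x: x.float(), which would raise on a str.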
System info
tensordict version: '0.4.0+b4c91e8'
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)