TensorDict X TransformsV2
https://github.com/pytorch-labs/tensordict https://pytorch.org/rl/tensordict/index.html
Some random notes after a chat I had with @vmoens
TensorDicts don't really work with our V2 transforms right now: they don't error, but they get passed through without being transformed:
import torch
from tensordict import TensorDict
from torchvision import datapoints
from torchvision.transforms import v2

img = torch.rand(3, 10, 10)
bbox1 = datapoints.BoundingBox(torch.rand(3, 4), format="XYXY", spatial_size=(10, 10))
bbox2 = datapoints.BoundingBox(torch.rand(12, 4), format="XYXY", spatial_size=(10, 10))
td1 = TensorDict({"img": img, "bbox": bbox1}, batch_size=[])
out = v2.Resize(20)(td1)
assert out["img"] is img  # passed through untouched :'(
It's because pytree.tree_flatten(TensorDict) returns [TensorDict] and so our transforms just pass it through as per our convention.
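For reference, a minimal sketch of that behaviour using the private torch.utils._pytree module (assuming TensorDict isn't registered with pytree):

import torch
from torch.utils import _pytree as pytree
from tensordict import TensorDict

td = TensorDict({"img": torch.rand(3, 10, 10)}, batch_size=[])
leaves, spec = pytree.tree_flatten(td)
# An unregistered type is treated as a single opaque leaf rather than being
# unpacked into its tensors, and the transforms leave unknown leaves untouched.
print(leaves)  # [TensorDict(...)]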
An interesting property of TensorDicts is that they can (lazily) stack() tensors with different shapes, which is particularly relevant for BBoxes:
td2 = TensorDict({"img": img, "bbox": bbox2}, batch_size=[])
batch = torch.stack([td1, td2])
gives:
LazyStackedTensorDict(
    fields={
        bbox: BoundingBox(shape=torch.Size([2, -1, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        img: Tensor(shape=torch.Size([2, 3, 10, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False,
    stack_dim=0)
Note the -1 in the BBox dim, which replaces 3 and 12.
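A quick sanity check (assuming the td1, td2 and batch from the snippets above): indexing the lazy stack gives back the original, differently-shaped boxes.

# Sketch: each entry of the lazy stack keeps its own bbox shape.
assert batch[0]["bbox"].shape == (3, 4)
assert batch[1]["bbox"].shape == (12, 4)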
class MyDataset:
    def __getitem__(self, idx):
        img = torch.rand(3, 10, 10)
        num_bboxes = idx + 1
        bbox = datapoints.BoundingBox(torch.rand(num_bboxes, 4), format="XYXY", spatial_size=(10, 10))
        return TensorDict({"img": img, "bbox": bbox}, [])

    def __len__(self):
        return 100
from torch.utils.data import DataLoader
ds = MyDataset()
dl = DataLoader(ds, batch_size=4, collate_fn=torch.stack) # This will work fine
dl = DataLoader(ds, batch_size=4) # This fails
I suppose the default behaviour (i.e. not passing a custom collate_fn) could be supported by tweaking default_collate_fn_map https://github.com/pytorch/pytorch/blob/21ede4547aa6873971c990d527c4511bcebf390d/torch/utils/data/_utils/collate.py#L190, but it's private (CC @vmoens )
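For illustration, a sketch of what that tweak could look like with the current private collate machinery (the module path and registration mechanism may change):

# Sketch only, relying on private APIs: register TensorDict with
# default_collate_fn_map so that the default collate stacks TensorDicts.
import torch
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate_fn_map
from tensordict import TensorDict

def collate_tensordict_fn(batch, *, collate_fn_map=None):
    # batch is the list of TensorDicts produced by the dataset
    return torch.stack(batch)

default_collate_fn_map[TensorDict] = collate_tensordict_fn

dl = DataLoader(MyDataset(), batch_size=4)  # the default collate_fn now works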
"It's because pytree.tree_flatten(TensorDict) returns [TensorDict] and so our transforms just pass it through as per our convention."
Does it make sense to open an issue in core about this? This is the first time I'm hearing about TensorDict, but supporting it in pytree sounds reasonable.
Apart from that, we could always monkeypatch it to support our needs. Registering a new type is straightforward if we are comfortable using private APIs.
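For example, a monkeypatch along these lines (a rough sketch using the private _register_pytree_node API, whose signature has changed across PyTorch versions; nested keys, device, etc. are glossed over):

# Rough sketch: register TensorDict as a pytree node via private APIs.
from torch.utils import _pytree as pytree
from tensordict import TensorDict

def _flatten_td(td):
    keys = list(td.keys())
    return [td[key] for key in keys], (keys, td.batch_size)

def _unflatten_td(values, context):
    keys, batch_size = context
    return TensorDict(dict(zip(keys, values)), batch_size)

pytree._register_pytree_node(TensorDict, _flatten_td, _unflatten_td)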
pytree support should be added soon: https://github.com/pytorch-labs/tensordict/pull/501
This works fine with v0.2.1
import torchvision
from torchvision.tv_tensors import BoundingBoxes, Image
from torchvision.transforms.v2 import Compose, Resize, Grayscale
from tensordict import TensorDict
import torch
image = Image(torch.randint(255, (3, 64, 64), dtype=torch.uint8))
box = BoundingBoxes(torch.randint(0, 64, size=(5, 4)), format="XYXY", canvas_size=(64, 64))
label = torch.randint(10, ())
td = TensorDict({"image": image, "label": label, "meta": {"box": box}}, [])
t = Compose([Resize((32, 32)), Grayscale()])
t(td)
which gives
TensorDict(
    fields={
        image: Image(shape=torch.Size([1, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=False),
        label: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        meta: TensorDict(
            fields={
                box: BoundingBoxes(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
Happy to write a tutorial in tensordict (or here) to show how that can be used!
If I can brag a bit: check out this multiprocessed transform:
import torchvision
from torchvision.tv_tensors import BoundingBoxes, Image
from torchvision.transforms.v2 import Compose, Resize, Grayscale
from tensordict import TensorDict
import torch

if __name__ == "__main__":
    image = Image(torch.randint(255, (5, 3, 64, 64), dtype=torch.uint8))
    box = BoundingBoxes(torch.randint(0, 64, size=(5, 4)), format="XYXY", canvas_size=(64, 64))
    label = torch.randint(10, ())
    td = TensorDict({"image": image, "label": label, "meta": {"box": box}}, [], device="cpu")
    t = Compose([Resize((32, 32)), Grayscale()])
    tdt = t(td)
    print(tdt)
    # Makes a lazy stack of the tensordicts
    td = torch.stack([td] * 100)
    # Map the transform over all items on 2 separate procs
    tdt = td.map(t, dim=0, num_workers=2, chunksize=1)
    print(tdt)
This prints the first td (like in the previous comment) but also this
TensorDict(
    fields={
        image: Tensor(shape=torch.Size([100, 5, 1, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=False),
        label: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.int64, is_shared=False),
        meta: TensorDict(
            fields={
                box: Tensor(shape=torch.Size([100, 5, 4]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([100]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([100]),
    device=cpu,
    is_shared=False)
The print shows that all items are plain tensors, which means the type is lost somewhere; let me check where. But it's pretty cool to see that this works almost out of the box!
This PR https://github.com/pytorch/tensordict/pull/589 will allow you to keep the tensor type after a call to TensorDict.map provided that you work with a lazy stack:
import torchvision
from torchvision.tv_tensors import BoundingBoxes, Image
from torchvision.transforms.v2 import Compose, Resize, Grayscale
from tensordict import TensorDict
import torch

if __name__ == "__main__":
    image = Image(torch.randint(255, (5, 3, 64, 64), dtype=torch.uint8))
    box = BoundingBoxes(
        torch.randint(0, 64, size=(5, 4)),
        format="XYXY",
        canvas_size=(64, 64)
    )
    label = torch.randint(10, ())
    td = TensorDict(
        {"image": image, "label": label, "meta": {"box": box}},
        [],
        device="cpu"
    )
    t = Compose([Resize((32, 32)), Grayscale()])
    tdt = t(td)
    # Makes a lazy stack of the tensordicts
    td = torch.stack([td.clone() for _ in range(100)])
    # Map the transform over all items on 2 separate procs
    print('calling map on', td)
    tdt = td.map(t, dim=0, num_workers=2, chunksize=0)
    print(tdt[0])  # the first tensordict of the lazy stack contains the original types!
This prints a TD with the original types:
TensorDict(
    fields={
        image: Image(shape=torch.Size([5, 1, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=True),
        label: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=True),
        meta: TensorDict(
            fields={
                box: BoundingBoxes(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=True)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
+1: The following does not work:
from torchvision.io import read_image
from torchvision.transforms.v2.functional import to_pil_image, to_image
from torchvision.transforms.v2 import RandomAffine
from tensordict import TensorDict
img = to_image(read_image("./astronaut.jpg"))
transform = RandomAffine(degrees=45)
out = transform(img) # This does work on a torchvision.tv_tensors.Image
td = TensorDict({"image1": img, "image2": img}, [])
out = transform(td)
TypeError: No image, video, mask or bounding box was found in the sample
While the following does work correctly (presumably because td.to_dict() produces a plain nested dict, which pytree knows how to flatten, so the image leaves actually get transformed):
out = TensorDict.from_dict(transform(td.to_dict()))