
TensorDict X TransformsV2

Open NicolasHug opened this issue 2 years ago • 6 comments

https://github.com/pytorch-labs/tensordict https://pytorch.org/rl/tensordict/index.html

Some random notes after a chat I had with @vmoens


TensorDicts don't really work with our V2 transforms right now: they don't raise an error, but they are passed through without being transformed:

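import torch
from torchvision import datapoints
from torchvision.transforms import v2
from tensordict import TensorDict

img = torch.rand(3, 10, 10)
bbox1 = datapoints.BoundingBox(torch.rand(3, 4), format="XYXY", spatial_size=(10, 10))
bbox2 = datapoints.BoundingBox(torch.rand(12, 4), format="XYXY", spatial_size=(10, 10))

td1 = TensorDict({"img": img, "bbox": bbox1}, batch_size=[])
out = v2.Resize(20)(td1)
assert out is td1  # passed through untouched :'(
assert out["img"].shape == (3, 10, 10)  # the image was not resized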

It's because pytree.tree_flatten(TensorDict) returns [TensorDict], i.e. the whole TensorDict is treated as a single leaf, and our transforms pass unrecognized leaves through as per our convention.
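To make the pass-through convention concrete, here is a torch-only sketch. It uses torch's private `torch.utils._pytree` module (the one the v2 transforms flatten their inputs with); `Opaque` is a hypothetical stand-in for an unregistered container such as TensorDict, and `toy_transform` is a simplified model of the flatten, transform, unflatten loop, not the actual v2 transform code.

```python
import torch
from torch.utils import _pytree as pytree  # private module; its API may change between releases


class Opaque:
    """Hypothetical container that pytree knows nothing about (stand-in for an unregistered TensorDict)."""

    def __init__(self, data):
        self.data = data


def toy_transform(inputs):
    """Simplified model of a v2 transform: flatten, double every tensor leaf, pass everything else through."""
    leaves, spec = pytree.tree_flatten(inputs)
    out = [leaf * 2 if isinstance(leaf, torch.Tensor) else leaf for leaf in leaves]
    return pytree.tree_unflatten(out, spec)


x = torch.ones(2)

# A plain dict is a registered pytree node: its tensors become leaves and get transformed.
print(pytree.tree_flatten({"img": x})[0])  # [tensor([1., 1.])]
print(toy_transform({"img": x})["img"])    # tensor([2., 2.])

# An unregistered container is one opaque leaf, so it comes back untouched.
box = Opaque({"img": x})
print(pytree.tree_flatten(box)[0][0] is box)  # True
print(toy_transform(box) is box)              # True
```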


An interesting property of TensorDicts is that they can potentially stack() tensors with different shapes, which is particularly relevant for bounding boxes:

td2 = TensorDict({"img": img, "bbox": bbox2}, batch_size=[])
batch = torch.stack([td1, td2])

gives:

LazyStackedTensorDict(
    fields={
        bbox: BoundingBox(shape=torch.Size([2, -1, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        img: Tensor(shape=torch.Size([2, 3, 10, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False,
    stack_dim=0)

Note the -1 in the bbox dimension, which stands in for the mismatched sizes 3 and 12.
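Why a plain stack can't do this: torch.stack needs every item to have the same shape. The sketch below shows that failure, and uses a hypothetical helper, `lazy_stack_shape`, to spell out how the printed shape can be read: the stack dimension is prepended, dimensions on which all items agree keep their size, and any mismatched one is shown as -1. It only illustrates the printed shape; it is not tensordict's implementation.

```python
import torch


def lazy_stack_shape(shapes, dim=0):
    """Hypothetical helper: the shape a lazy stack would report for items of the given shapes.

    Dimensions on which the items agree keep their size; mismatched ones become -1.
    Assumes all items have the same number of dimensions.
    """
    first = tuple(shapes[0])
    if any(len(s) != len(first) for s in shapes):
        raise ValueError("items must have the same number of dimensions")
    merged = [size if all(s[i] == size for s in shapes) else -1 for i, size in enumerate(first)]
    merged.insert(dim, len(shapes))  # the new stack dimension
    return tuple(merged)


bbox1 = torch.rand(3, 4)
bbox2 = torch.rand(12, 4)

# A dense stack cannot hold 3 and 12 boxes side by side.
try:
    torch.stack([bbox1, bbox2])
except RuntimeError as err:
    print("dense stack failed:", err)

print(lazy_stack_shape([bbox1.shape, bbox2.shape]))              # (2, -1, 4)
print(lazy_stack_shape([torch.Size([3, 10, 10])] * 2))           # (2, 3, 10, 10)
```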


import torch
from torch.utils.data import DataLoader
from torchvision import datapoints
from tensordict import TensorDict

class MyDataset:
    def __getitem__(self, idx):
        img = torch.rand(3, 10, 10)
        num_bboxes = idx + 1  # a different number of boxes per sample
        bbox = datapoints.BoundingBox(torch.rand(num_bboxes, 4), format="XYXY", spatial_size=(10, 10))
        return TensorDict({"img": img, "bbox": bbox}, [])

    def __len__(self):
        return 100

ds = MyDataset()

dl = DataLoader(ds, batch_size=4, collate_fn=torch.stack)  # This works fine
dl = DataLoader(ds, batch_size=4)  # Iterating over this one fails

I suppose the default behaviour (i.e. not passing a custom collate_fn) could be supported by tweaking default_collate_fn_map https://github.com/pytorch/pytorch/blob/21ede4547aa6873971c990d527c4511bcebf390d/torch/utils/data/_utils/collate.py#L190, but it's private (CC @vmoens).
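A minimal sketch of that idea, relying on the private `torch.utils.data._utils.collate` module (it may change without notice). To keep it torch-only, `Sample` is a hypothetical stand-in for a TensorDict holding a variable number of boxes; for an actual TensorDict the registered function could presumably just return `torch.stack(batch)`, as the custom collate_fn above does. Registering the type in the module-level `default_collate_fn_map` is what lets a DataLoader without a collate_fn pick it up.

```python
import torch
from torch.utils.data import DataLoader
# Private module: default_collate looks up collate functions in this dict, keyed by element type.
from torch.utils.data._utils.collate import default_collate_fn_map


class Sample:
    """Hypothetical stand-in for a TensorDict: one image plus a variable number of boxes."""

    def __init__(self, img, bbox):
        self.img = img
        self.bbox = bbox


def collate_samples(batch, *, collate_fn_map=None):
    # Images share a shape and can be stacked; boxes differ in length, so keep them as a list.
    return {"img": torch.stack([s.img for s in batch]), "bbox": [s.bbox for s in batch]}


# Every function in the map is called as fn(batch, collate_fn_map=...).
default_collate_fn_map[Sample] = collate_samples


class MyDataset:
    def __getitem__(self, idx):
        return Sample(torch.rand(3, 10, 10), torch.rand(idx + 1, 4))

    def __len__(self):
        return 8


dl = DataLoader(MyDataset(), batch_size=4)  # no collate_fn: default_collate is used
batch = next(iter(dl))
print(batch["img"].shape)                   # torch.Size([4, 3, 10, 10])
print([b.shape[0] for b in batch["bbox"]])  # [1, 2, 3, 4]
```

Mutating the global map affects every DataLoader in the process, which is part of why relying on it is uncomfortable.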

NicolasHug, Jul 26 '23 17:07

> It's because pytree.tree_flatten(TensorDict) returns [TensorDict] and so our transforms just pass it through as per our convention.

Does it make sense to open an issue in core about this? This is the first time I've heard of TensorDict, but supporting it in pytree sounds reasonable.

Apart from that, we could always monkeypatch pytree to support our needs. Registering a new type is straightforward if we are comfortable using private APIs.
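For reference, registering a type with torch's private pytree module looks roughly like this. `Record` is a hypothetical dict-wrapping stand-in for TensorDict, and the flatten/unflatten pair is a sketch, not tensordict's actual registration. The registration function is `register_pytree_node` in recent torch releases and `_register_pytree_node` in older ones, hence the fallback.

```python
import torch
from torch.utils import _pytree as pytree  # private API

# Newer torch has register_pytree_node; older releases only have _register_pytree_node.
register = getattr(pytree, "register_pytree_node", None) or pytree._register_pytree_node


class Record:
    """Hypothetical dict-like container, standing in for TensorDict."""

    def __init__(self, data):
        self.data = dict(data)


def _flatten(rec):
    keys = tuple(sorted(rec.data))
    # children (leaves or subtrees) plus a context needed to rebuild the node
    return [rec.data[k] for k in keys], keys


def _unflatten(values, keys):
    return Record(zip(keys, values))


register(Record, _flatten, _unflatten)

rec = Record({"img": torch.ones(3, 4, 4), "label": torch.tensor(1)})
leaves, spec = pytree.tree_flatten(rec)
print(len(leaves))  # 2: the Record is no longer a single opaque leaf

# Any flatten / transform / unflatten pipeline now reaches the tensors inside.
out = pytree.tree_map(lambda t: t.float() * 0.5, rec)
print(type(out).__name__, out.data["img"].shape)  # Record torch.Size([3, 4, 4])
```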

pmeier, Jul 26 '23 20:07

pytree support should be added soon: https://github.com/pytorch-labs/tensordict/pull/501

NicolasHug, Jul 27 '23 08:07

This works fine with tensordict v0.2.1:

import torchvision
from torchvision.tv_tensors import BoundingBoxes, Image
from torchvision.transforms.v2 import Compose, Resize, Grayscale
from tensordict import TensorDict
import torch

image = Image(torch.randint(255, (3, 64, 64), dtype=torch.uint8))
box = BoundingBoxes(torch.randint(0, 64, size=(5, 4)), format="XYXY", canvas_size=(64, 64))
label = torch.randint(10, ())

td = TensorDict({"image": image, "label": label, "meta": {"box": box}}, [])

t = Compose([Resize((32, 32)), Grayscale()])

t(td)

which gives

TensorDict(
    fields={
        image: Image(shape=torch.Size([1, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=False),
        label: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        meta: TensorDict(
            fields={
                box: BoundingBoxes(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
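In that output the image goes from [3, 64, 64] to [1, 32, 32] (Resize, then Grayscale), while the boxes keep their [5, 4] shape because resizing only rescales their coordinates. For XYXY boxes that means multiplying x by new_w / old_w and y by new_h / old_h. The helper below is a hypothetical sketch of that arithmetic only; torchvision's own resize kernel for BoundingBoxes handles more cases and keeps the input dtype (int64 above), whereas this sketch returns floats.

```python
import torch


def resize_xyxy(boxes, old_hw, new_hw):
    """Hypothetical helper: rescale XYXY boxes from an old (H, W) canvas to a new one."""
    (old_h, old_w), (new_h, new_w) = old_hw, new_hw
    # x coordinates scale with the width ratio, y coordinates with the height ratio
    scale = torch.tensor([new_w / old_w, new_h / old_h, new_w / old_w, new_h / old_h])
    return boxes.float() * scale


boxes = torch.tensor([[0, 0, 64, 64], [10, 20, 30, 40]])
resized = resize_xyxy(boxes, (64, 64), (32, 32))
print(resized.shape)  # torch.Size([2, 4]): same number of boxes, new coordinates
```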

Happy to write a tutorial in tensordict (or here) to show how this can be used!

vmoens, Nov 20 '23 14:11

If I can brag a bit: check out this multiprocessed transform:

import torchvision
from torchvision.tv_tensors import BoundingBoxes, Image
from torchvision.transforms.v2 import Compose, Resize, Grayscale
from tensordict import TensorDict
import torch

if __name__ == "__main__":
    image = Image(torch.randint(255, (5, 3, 64, 64), dtype=torch.uint8))
    box = BoundingBoxes(torch.randint(0, 64, size=(5, 4)), format="XYXY", canvas_size=(64, 64))
    label = torch.randint(10, ())
    
    td = TensorDict({"image": image, "label": label, "meta": {"box": box}}, [], device="cpu")
    
    t = Compose([Resize((32, 32)), Grayscale()])
    
    tdt = t(td)
    print(tdt)
    # Makes a lazy stack of the tensordicts
    td = torch.stack([td] * 100)
    # Map the transform over all items on 2 separate procs
    tdt = td.map(t, dim=0, num_workers=2, chunksize=1)
    print(tdt)

This prints the first td (as in the previous comment), but also this:

TensorDict(
    fields={
        image: Tensor(shape=torch.Size([100, 5, 1, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=False),
        label: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.int64, is_shared=False),
        meta: TensorDict(
            fields={
                box: Tensor(shape=torch.Size([100, 5, 4]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([100]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([100]),
    device=cpu,
    is_shared=False)

The print shows that all the leaves are plain Tensors, which means that the Image and BoundingBoxes types are lost somewhere; let me check where. But it's pretty cool to see that this works almost out of the box!

vmoens, Nov 20 '23 14:11

This PR https://github.com/pytorch/tensordict/pull/589 will allow you to keep the tensor type after a call to TensorDict.map provided that you work with a lazy stack:

import torchvision
from torchvision.tv_tensors import BoundingBoxes, Image
from torchvision.transforms.v2 import Compose, Resize, Grayscale
from tensordict import TensorDict
import torch

if __name__ == "__main__":
    image = Image(torch.randint(255, (5, 3, 64, 64), dtype=torch.uint8))
    box = BoundingBoxes(
        torch.randint(0, 64, size=(5, 4)),
        format="XYXY",
        canvas_size=(64, 64)
        )
    label = torch.randint(10, ())

    td = TensorDict(
        {"image": image, "label": label, "meta": {"box": box}},
        [],
        device="cpu"
        )

    t = Compose([Resize((32, 32)), Grayscale()])

    tdt = t(td)
    # Makes a lazy stack of the tensordicts
    td = torch.stack([td.clone() for _ in range(100)])
    # Map the transform over all items on 2 separate procs
    print('calling map on', td)
    tdt = td.map(t, dim=0, num_workers=2, chunksize=0)
    print(tdt[0]) # the first tensordict of the lazy stack contains the original types!

This prints a TD with the original types:

TensorDict(
    fields={
        image: Image(shape=torch.Size([5, 1, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=True),
        label: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=True),
        meta: TensorDict(
            fields={
                box: BoundingBoxes(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=True)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
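A plausible explanation for why the lazy stack matters (an assumption, not something confirmed in this thread): a lazy stack keeps each item as it was, whereas a dense result has to be assembled with a stack-like op, and tv_tensors return plain torch.Tensor from most torch ops by design. The torch-only sketch below mimics that with a hypothetical `Tagged` subclass whose `__torch_function__` unwraps results; it relies on the private `torch._C.DisableTorchFunctionSubclass` context manager (torch 2.x).

```python
import torch


class Tagged(torch.Tensor):
    """Hypothetical tensor subclass that, like tv_tensors, returns plain tensors from torch ops."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Run the op with subclass dispatch disabled, so newly created outputs are plain torch.Tensor.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **(kwargs or {}))


items = [torch.rand(5, 4).as_subclass(Tagged) for _ in range(3)]

# "Lazy" batch: keep the items as-is, so each one keeps its type.
lazy = list(items)
print(type(lazy[0]).__name__)  # Tagged

# Dense batch: torch.stack builds a brand-new tensor, and the subclass type is gone.
dense = torch.stack(items)
print(type(dense).__name__, tuple(dense.shape))  # Tensor (3, 5, 4)
```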

vmoens, Dec 04 '23 12:12

+1: The following does not work:

from torchvision.io import read_image
from torchvision.transforms.v2.functional import to_image
from torchvision.transforms.v2 import RandomAffine
from tensordict import TensorDict

img = to_image(read_image("./astronaut.jpg"))
transform = RandomAffine(degrees=45)
out = transform(img) # This does work on a torchvision.tv_tensors.Image

td = TensorDict({"image1": img, "image2": img}, [])
out = transform(td)

This fails with:

TypeError: No image, video, mask or bounding box was found in the sample

while the following works correctly:

out = TensorDict.from_dict(transform(td.to_dict()))
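The workaround can be wrapped in a small helper. `apply_via_dict` is hypothetical; it only assumes the container has a `to_dict()` method and a `from_dict()` classmethod, as TensorDict does in the line above. Like that one-liner, it does not forward `td.batch_size`, so check the batch size of the result if it matters. `FakeTD` exists only so the sketch runs without tensordict installed.

```python
from typing import Any, Callable


def apply_via_dict(transform: Callable[[Any], Any], td: Any) -> Any:
    """Hypothetical helper: run a transform on a TensorDict-like container via a plain nested dict.

    Assumes `td` has a `to_dict()` method and its class has a `from_dict()` classmethod.
    """
    plain = td.to_dict()            # plain nested dict: a container pytree already understands
    out = transform(plain)          # the transform sees the dict's leaves directly
    return type(td).from_dict(out)  # rebuild a container of the same type


class FakeTD:
    """Minimal stand-in exposing the two methods the helper relies on (illustration only)."""

    def __init__(self, data):
        self.data = data

    def to_dict(self):
        return dict(self.data)

    @classmethod
    def from_dict(cls, d):
        return cls(d)


doubled = apply_via_dict(lambda d: {k: v * 2 for k, v in d.items()}, FakeTD({"image1": 1, "image2": 2}))
print(type(doubled).__name__, doubled.data)  # FakeTD {'image1': 2, 'image2': 4}
```

With the snippet above, `out = apply_via_dict(transform, td)` does the same thing as the one-liner.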

Mxbonn, Mar 08 '24 14:03