
[AIR] Add `TorchVisionPreprocessor`

Open bveeramani opened this issue 2 years ago • 10 comments

Signed-off-by: Balaji [email protected]

Depends on:

  • [x] #30448
  • [x] #30514

Why are these changes needed?

tl;dr: Users can't figure out how to apply TorchVision transforms. This PR introduces an abstraction that makes it easy.

To apply a TorchVision transform, you need to do something like this:

def preprocess(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    # Apply the transform to each image individually, then re-stack into one ndarray.
    return {"image": np.array([transform(image).numpy() for image in batch["image"]])}

BatchMapper(preprocess, batch_format="numpy")

There are several non-obvious things:

  • You need to set batch_format="numpy" (batch_format is a required argument, but it's not obvious if you should use "numpy" or "pandas").
  • You need to individually apply the transformation to each image in the batch.
  • You need to convert transformed images back to ndarrays.
  • You need to wrap the batch of transformed images in an ndarray.
  • You need to return a dictionary.

This PR introduces `TorchVisionPreprocessor`, a preprocessor that abstracts away this complexity:

TorchVisionPreprocessor(columns=["image"], transform=transform)
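
As a rough end-to-end sketch of how this might look in practice (the dataset path, column name, and normalization constants below are illustrative assumptions, not part of this PR):

import ray
from torchvision import transforms
from ray.data.preprocessors import TorchVisionPreprocessor

# Compose the usual per-image TorchVision transforms.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.CenterCrop(224),
    # ImageNet normalization constants, used here purely as an example.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

preprocessor = TorchVisionPreprocessor(columns=["image"], transform=transform)

# Any Dataset with an "image" column of ndarrays works; the path is a placeholder.
ds = ray.data.read_images("path/to/images")
ds = preprocessor.transform(ds)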

Related issue number

Closes #30403

Checks

  • [ ] I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • [ ] I've run scripts/format.sh to lint the changes in this PR.
  • [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
  • [ ] I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • [ ] Unit tests
    • [ ] Release tests
    • [ ] This PR is not tested :(

bveeramani avatar Nov 22 '22 08:11 bveeramani

@bveeramani What are the next steps for this PR? I think this was blocked by a tensor extension bug, but the fix should be merged into master.

clarkzinzow avatar Dec 05 '22 21:12 clarkzinzow

@bveeramani What are the next steps for this PR? I think this was blocked by a tensor extension bug, but the fix should be merged into master.

@clarkzinzow I think I just need to resolve merge conflicts and maybe add a couple more tests.

bveeramani avatar Dec 05 '22 22:12 bveeramani

Some torchvision transforms work on a batched input. Should we allow users to also specify a separate batched_transform for better performance?

See our batch prediction benchmarks for an example

amogkam avatar Dec 07 '22 21:12 amogkam

Some torchvision transforms work on a batched input. Should we allow users to also specify a separate batched_transform for better performance?

See our batch prediction benchmarks for an example

How do you imagine this would work?

In pure Torch, you'd typically use torchvision.transforms.Compose. So, I imagined people would do something like:

transform = transforms.Compose([
    ToTensor(),
    CenterCrop(224),
    Normalize(mean=..., std=...),
])

preprocessor = TorchVisionPreprocessor(transform)

To support batched transforms, would you need to do something like:

preprocessor = Chain(
    TorchVisionPreprocessor(ToTensor()),
    TorchVisionPreprocessor(CenterCrop(224), batched=True),
    TorchVisionPreprocessor(Normalize(mean=..., std=...), batched=True)
)

bveeramani avatar Dec 07 '22 22:12 bveeramani

@bveeramani Why couldn't you give a single transforms.Compose batch transform to a single TorchVisionPreprocessor?

clarkzinzow avatar Dec 07 '22 22:12 clarkzinzow

Another note: if you're using TorchVision datasets, you pass transforms to the dataset constructor. In this case, transforms aren't batched.

See https://github.com/pytorch/vision/blob/23d3f78aeea9329a8257e17b90c37f6f2016c171/torchvision/datasets/folder.py#L230-L233

bveeramani avatar Dec 07 '22 22:12 bveeramani

@bveeramani Why couldn't you give a single transforms.Compose batch transform to a single TorchVisionPreprocessor?

Not all transforms support batches. For example, you could compose some batch-supporting transforms with transforms that only work on individual images.

How would we implement this? If we straightforwardly pass a batch of data to Compose, the program would error if the Compose has something like ToTensor.
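
For example (a hypothetical illustration, with made-up shapes), ToTensor only accepts a single HWC image, so naively applying it to a whole NHWC batch raises an error:

import numpy as np
from torchvision import transforms

batch = np.zeros((32, 224, 224, 3), dtype=np.uint8)  # NHWC batch of images

try:
    # ToTensor expects one PIL image or one HWC ndarray, not a 4-D batch.
    transforms.ToTensor()(batch)
except ValueError as err:
    print(err)  # complains about the number of dimensions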

bveeramani avatar Dec 07 '22 22:12 bveeramani

I was thinking that we could make it opt-in at the TorchVisionPreprocessor constructor, something like

transform = transforms.Compose([
    CenterCrop(224, batched=True),
    Normalize(mean=..., std=..., batched=True),
])
preprocessor = TorchVisionPreprocessor(transform, batched=True)

clarkzinzow avatar Dec 07 '22 22:12 clarkzinzow

API could be something like this:

single_transform = transforms.ToTensor()
batch_transform = transforms.Compose([CenterCrop(...), Normalize(...)])
preprocessor = TorchVisionPreprocessor(transform=single_transform, batch_transform=batch_transform)

A single preprocessor that can accept both options; the batch_transform argument is optional.

amogkam avatar Dec 07 '22 22:12 amogkam

The failing tests are due to FileNotFoundError: Couldn't find file at https://www.dropbox.com/s/1pzkadrvffbqw6o/train.txt?dl=1 in test_huggingface.

bveeramani avatar Dec 21 '22 21:12 bveeramani