Add support for extended transforms in Lambda and SplitLambda

Open dxoigmn opened this issue 2 years ago • 0 comments

Right now it is not possible to use extended transforms with Lambda and SplitLambda. For example, it would be useful to do something like:

_target_: mart.transforms.SplitLambda
lambd:
  _target_: mart.transforms.Compose
  transforms:
    - _target_: mart.transforms.RandomHorizontalFlip
split_size_or_sections: 3
lambd_section: -1
dim: 0

Here is a failing example in Python:

import torch
import mart

transform = mart.transforms.SplitLambda(lambd=mart.transforms.Compose(transforms=[mart.transforms.RandomHorizontalFlip()]),
                                        split_size_or_sections=3,
                                        lambd_section=-1,
                                        dim=0)
transform(tensor=torch.zeros((6, 320, 240)), target={})

If you don't pass target={}, the same call works as expected.

It would also be nice to support original torchvision transforms:

_target_: mart.transforms.SplitLambda
lambd:
  _target_: mart.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Normalize
      mean: 0
      std: 255
split_size_or_sections: 3
lambd_section: -1
dim: 0

Here's a failing example in Python:

import torch
import mart
import torchvision

transform = mart.transforms.SplitLambda(lambd=mart.transforms.Compose(transforms=[torchvision.transforms.Normalize(mean=0, std=255)]),
                                        split_size_or_sections=3,
                                        lambd_section=-1,
                                        dim=0)
transform(tensor=torch.zeros((6, 320, 240)), target={})
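
For reference, here is a minimal sketch of the behavior being asked for. It is not MART's implementation. It assumes that an extended transform is called as (tensor, target) and returns both, while a torchvision-style transform only takes the tensor. Plain lists stand in for tensors so the sketch runs without torch or MART, and the names TensorOnly and split_lambda are hypothetical.

```python
# Illustrative only: plain lists stand in for tensors, and every name here
# (TensorOnly, split_lambda) is hypothetical, not part of MART.


class TensorOnly:
    """Adapt a tensor-only callable to an assumed (tensor, target) convention."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, tensor, target=None):
        out = self.fn(tensor)
        # Hand the target back untouched only when the caller supplied one.
        return out if target is None else (out, target)


def split_lambda(tensor, lambd, lambd_section, split_size, target=None):
    """Apply `lambd` to one chunk of `tensor`, passing `target` through if given."""
    has_target = target is not None
    # Cut the sequence into consecutive chunks of `split_size` items.
    sections = [tensor[i:i + split_size] for i in range(0, len(tensor), split_size)]
    if has_target:
        # Extended-style call: the transform sees and returns the target.
        sections[lambd_section], target = lambd(sections[lambd_section], target)
    else:
        # Tensor-only call: same as the case that already works.
        sections[lambd_section] = lambd(sections[lambd_section])
    merged = [x for section in sections for x in section]
    return (merged, target) if has_target else merged


# Extended-style transform: reverses the chunk and returns the target as-is.
flip = lambda t, target: (t[::-1], target)
print(split_lambda([1, 2, 3, 4, 5, 6], flip, -1, 3, target={}))

# Tensor-only transform wrapped so it also tolerates a target argument.
halve = TensorOnly(lambda t: [x / 2 for x in t])
print(split_lambda([2, 4, 6, 8], halve, 0, 2, target={"label": 1}))
```

The point of the adapter is that the splitting wrapper never needs to know which kind of transform it holds: tensor-only transforms get lifted once, and the target is simply threaded through.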

dxoigmn, Jan 24 '23 19:01