MART
Add support for extended transforms in Lambda and SplitLambda
Right now it is not possible to use extended transforms with `Lambda` and `SplitLambda`. For example, it would be useful to do something like:
```yaml
_target_: mart.transforms.SplitLambda
lambd:
  _target_: mart.transforms.Compose
  transforms:
    - _target_: mart.transforms.RandomHorizontalFlip
split_size_or_sections: 3
lambd_section: -1
dim: 0
```
Here is a failing example in Python:
```python
import torch

import mart

transform = mart.transforms.SplitLambda(
    lambd=mart.transforms.Compose(
        transforms=[mart.transforms.RandomHorizontalFlip()]
    ),
    split_size_or_sections=3,
    lambd_section=-1,
    dim=0,
)
# Fails when `target` is passed.
transform(tensor=torch.zeros((6, 320, 240)), target={})
```
If you don't pass `target={}`, it works as expected.
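One possible shape for the requested behavior is for `SplitLambda` to forward `target` to `lambd` and return the `(tensor, target)` pair that extended transforms produce, while keeping today's tensor-only path when no target is given. The sketch below is hypothetical: `TargetAwareSplitLambda` and `reverse_with_target` are illustrative names, not part of MART. It assumes extended transforms are called as `lambd(tensor, target)` and return `(tensor, target)`, and, to stay dependency-free, it uses a Python list as a stand-in for a tensor split along dim 0 (a real version would use `torch.split` and `torch.cat`).

```python
from typing import Any, Callable, List, Optional, Sequence


class TargetAwareSplitLambda:
    """Hypothetical sketch (not MART's API): apply `lambd` to one split section.

    A Python list stands in for a tensor split along dim 0; a torch version
    would use torch.split / torch.cat instead of slicing and flattening.
    """

    def __init__(self, lambd: Callable, split_size: int, lambd_section: int = 0):
        self.lambd = lambd
        self.split_size = split_size
        self.lambd_section = lambd_section

    def __call__(self, tensor: Sequence[Any], target: Optional[dict] = None):
        # Consecutive chunks of `split_size` items, like torch.split with an int.
        sections: List[List[Any]] = [
            list(tensor[i : i + self.split_size])
            for i in range(0, len(tensor), self.split_size)
        ]
        idx = self.lambd_section
        if target is None:
            # Plain callable: section in, section out (the path that works today).
            sections[idx] = self.lambd(sections[idx])
            return [x for s in sections for x in s]
        # Assumed extended-transform convention: (tensor, target) in and out.
        sections[idx], target = self.lambd(sections[idx], target)
        return [x for s in sections for x in s], target


def reverse_with_target(section, target):
    # Toy stand-in for an extended transform such as a horizontal flip.
    return section[::-1], {**target, "flipped": True}


split = TargetAwareSplitLambda(reverse_with_target, split_size=3, lambd_section=-1)
out, tgt = split([0, 1, 2, 3, 4, 5], target={})
# out == [0, 1, 2, 5, 4, 3], tgt == {"flipped": True}
```

Keying the branch on whether `target` was passed mirrors the behavior described above: without a target the call stays tensor-only, and with one the pair is threaded through.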
It would also be nice to support original torchvision transforms:
```yaml
_target_: mart.transforms.SplitLambda
lambd:
  _target_: mart.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Normalize
      mean: 0
      std: 255
split_size_or_sections: 3
lambd_section: -1
dim: 0
```
Here's a failing example in Python:
```python
import torch
import torchvision

import mart

transform = mart.transforms.SplitLambda(
    lambd=mart.transforms.Compose(
        transforms=[torchvision.transforms.Normalize(mean=0, std=255)]
    ),
    split_size_or_sections=3,
    lambd_section=-1,
    dim=0,
)
# Fails: Normalize only accepts a tensor, not a (tensor, target) pair.
transform(tensor=torch.zeros((6, 320, 240)), target={})
```
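`torchvision.transforms.Normalize` is called with a single tensor, so a compose that calls every transform as `t(tensor, target)` cannot use it directly. A minimal sketch of one way around this, assuming that calling convention for the extended `Compose`, is to wrap tensor-only transforms in an adapter that passes the target through unchanged. `TensorOnly`, `ComposeWithTarget`, and `scale` are hypothetical names, not part of MART, and a Python list with a toy `scale` function stands in for a tensor and `Normalize(mean=0, std=255)`.

```python
from typing import Any, Callable, Iterable, Tuple


class TensorOnly:
    """Hypothetical adapter (not MART's API): lift a tensor-only transform.

    The wrapped callable sees only the tensor; the target passes through unchanged.
    """

    def __init__(self, transform: Callable[[Any], Any]):
        self.transform = transform

    def __call__(self, tensor: Any, target: Any = None) -> Tuple[Any, Any]:
        return self.transform(tensor), target


class ComposeWithTarget:
    """Hypothetical stand-in for an extended Compose that threads the target."""

    def __init__(self, transforms: Iterable[Callable]):
        self.transforms = list(transforms)

    def __call__(self, tensor: Any, target: Any = None) -> Tuple[Any, Any]:
        for t in self.transforms:
            tensor, target = t(tensor, target)
        return tensor, target


def scale(xs):
    # Toy stand-in for Normalize(mean=0, std=255), i.e. (x - 0) / 255.
    return [(x - 0) / 255 for x in xs]


pipeline = ComposeWithTarget([TensorOnly(scale)])
out, tgt = pipeline([0, 255], {"id": 1})
# out == [0.0, 1.0], tgt == {"id": 1}
```

An alternative is to detect automatically, for example with `inspect.signature`, whether a transform accepts `target` and wrap it only when it does not; explicit wrapping avoids guessing from signatures at the cost of a little extra config.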