pytorch-image-models
[BUG] Transform not properly working with batching and Grayscale Images in ViT
Describe the bug If you use the out-of-the-box image transforms for TinyViT, they do not work with grayscale images, because they expect 2- or 3-channel images (see code below). It would also be nice if batching were supported right away, since that improves interoperability with other workflows and frameworks (Lightning, for example).
To Reproduce Steps to reproduce the behavior:
import timm
import torch
from torchvision.transforms import v2

model = timm.create_model('tiny_vit_21m_512.dist_in22k_ft_in1k', pretrained=True, in_chans=1, num_classes=4)
data_config = timm.data.resolve_model_data_config(model)
train_transform = timm.data.create_transform(**data_config, is_training=True)
train_transform.transforms.insert(0, v2.ToPILImage())
x = torch.randn(4, 1, 512, 512)  # batched single-channel (grayscale) input
model(x)
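Until batch support exists, one low-effort workaround is to apply the single-image transform to each image in the batch and restack the results. This is a minimal sketch, not timm's API: `apply_per_image` and the lambda "transform" are hypothetical stand-ins for the real pipeline.

```python
import torch

def apply_per_image(transform, batch: torch.Tensor) -> torch.Tensor:
    # Apply a single-image transform to every image in a (B, C, H, W) batch
    # and restack the outputs into one tensor.
    return torch.stack([transform(img) for img in batch])

# Toy stand-in for the timm transform: per-image standardization.
batch = torch.randn(4, 1, 512, 512)
out = apply_per_image(lambda img: (img - img.mean()) / img.std(), batch)
print(tuple(out.shape))  # (4, 1, 512, 512)
```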
Expected behavior The transform should automatically detect a 4D tensor and handle batching. For grayscale images, a low-effort implementation could simply copy the single channel to all 3 channels before the transformation. There is also the PIL Image dependency; it may make sense to drop that in favor of torch.Tensor.
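The channel-copying idea above can be sketched in a few lines of plain torch; `to_three_channels` is a hypothetical helper, not part of timm.

```python
import torch

def to_three_channels(x: torch.Tensor) -> torch.Tensor:
    # Replicate the single grayscale channel across 3 channels, so the batch
    # matches what RGB-only transforms expect: (B, 1, H, W) -> (B, 3, H, W).
    if x.dim() != 4 or x.size(1) != 1:
        raise ValueError("expected a (B, 1, H, W) batch of grayscale images")
    return x.repeat(1, 3, 1, 1)

gray = torch.randn(4, 1, 512, 512)
rgb = to_three_channels(gray)
print(tuple(rgb.shape))  # (4, 3, 512, 512)
```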
Desktop (please complete the following information):
- OS: Ubuntu 22.04
- Python 3.10
- CUDA 12.5