pytorch-image-models

[BUG] Transform not properly working with batching and Grayscale Images in ViT

Open asusdisciple opened this issue 7 months ago • 0 comments

Describe the bug The out-of-the-box image transforms for TinyViT do not work with grayscale images, because they expect 2/3 channel images (see code below). It would also be nice if the transforms supported batching right away, since that improves interoperability with other workflows and frameworks (Lightning, for example).

To Reproduce Steps to reproduce the behavior:

import timm
import torch
from torchvision.transforms import v2

# Single-channel model with its default data config and training transform
model = timm.create_model('tiny_vit_21m_512.dist_in22k_ft_in1k', pretrained=True, in_chans=1, num_classes=4)
data_config = timm.data.resolve_model_data_config(model)
train_transform = timm.data.create_transform(**data_config, is_training=True)
train_transform.transforms.insert(0, v2.ToPILImage())
# Batch of 4 grayscale images
x = torch.randn(4, 1, 512, 512)
model(x)

Expected behavior The transform should automatically detect a 4D tensor and apply batching. For grayscale images, a low-effort implementation could simply copy the single channel to all 3 channels before the transformation. There is also the dependency on PIL images; it may make sense to drop that in favor of torch.Tensor.
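To illustrate what I mean, here is a rough sketch of the requested behavior in plain torch. The helper names gray_to_rgb and apply_per_sample are hypothetical and not part of timm; the sketch assumes the per-sample transform accepts a (C, H, W) tensor and returns tensors of equal shape so they can be stacked.

```python
import torch


def gray_to_rgb(images: torch.Tensor) -> torch.Tensor:
    """Copy a single grayscale channel to 3 channels.

    Hypothetical helper (not a timm function). Accepts (C, H, W) or
    (N, C, H, W); inputs that are not single-channel are returned unchanged.
    """
    if images.ndim not in (3, 4):
        raise ValueError(f"expected a 3D or 4D tensor, got {images.ndim}D")
    channel_dim = images.ndim - 3  # 0 for (C, H, W), 1 for (N, C, H, W)
    if images.shape[channel_dim] != 1:
        return images
    repeats = [1] * images.ndim
    repeats[channel_dim] = 3
    return images.repeat(*repeats)


def apply_per_sample(transform, images: torch.Tensor) -> torch.Tensor:
    """Apply a per-image transform to a single image or to every image in a batch.

    Hypothetical helper: a 4D input is treated as a batch and the transformed
    samples are stacked back into one tensor.
    """
    if images.ndim == 3:
        return transform(images)
    return torch.stack([transform(img) for img in images])


# Usage: a batch of 4 grayscale images, with an identity stand-in for the real transform
x = torch.rand(4, 1, 64, 64)
out = apply_per_sample(lambda img: img, gray_to_rgb(x))
print(out.shape)
```

If the model is created with in_chans=1 as in the snippet above, the output would have to be reduced back to one channel before the forward pass; how best to do that is left open here.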

Screenshots If applicable, add screenshots to help explain your problem.

Desktop (please complete the following information):

  • OS: Ubuntu 22.04
  • Python 3.10
  • Cuda 12.5

asusdisciple, Jul 01 '24 08:07