KeyError: 'channels'
If I just want to use the model in a different domain, what should I do?
    import torch

    model = torch.hub.load('insitro/ChannelViT', 'cpjump_cellpaint_channelvit_small_p8_with_hcs_supervised', pretrained=True)
    model.eval()

    images = torch.randn(5, 3, 224, 224)
    out = model(images)
KeyError: 'channels'
What should I pass as the extra_tokens argument?
extra_tokens["channels"] should contain the channel indices for each sample in the batch, with shape batch_size x n_channels.
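A minimal sketch of building such a dictionary for the 5-image, 3-channel batch from the question. The helper name make_channel_tokens is hypothetical, and mapping the three input channels to indices 0, 1, 2 is an assumption: which indices make sense depends on the channel embeddings the checkpoint was trained with.

```python
import torch


def make_channel_tokens(batch_size: int, channel_ids: list[int]) -> dict[str, torch.Tensor]:
    """Hypothetical helper: build an extra_tokens dict whose "channels" entry
    has shape (batch_size, n_channels), one row of channel indices per sample."""
    channels = torch.tensor(channel_ids, dtype=torch.long)  # shape (n_channels,)
    # Copy the same channel indices for every sample in the batch.
    channels = channels.unsqueeze(0).repeat(batch_size, 1)  # (batch_size, n_channels)
    return {"channels": channels}


images = torch.randn(5, 3, 224, 224)
# Assumption: the 3 input channels correspond to channel indices 0, 1, 2.
extra_tokens = make_channel_tokens(images.shape[0], [0, 1, 2])
print(extra_tokens["channels"].shape)  # torch.Size([5, 3])
```

Assuming the model's forward pass accepts this dictionary as a keyword argument named extra_tokens (as the question implies), the call would then look like out = model(images, extra_tokens=extra_tokens).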
For example, our ImageNet dataset returns a dictionary containing the channels for each sample, which is then collated by PyTorch's default_collate function. default_collate maps Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])], so extra_tokens['channels'] ends up with shape batch_size x n_channels.
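The collation step described above can be illustrated with a small sketch. The per-sample keys ("image", "channels") and shapes are hypothetical, chosen to match the question's 3-channel 224 x 224 input; this is not the repository's actual ImageNet dataset code. It uses torch.utils.data.default_collate, which is exposed at that path in PyTorch 1.11 and newer.

```python
import torch
from torch.utils.data import default_collate

# Hypothetical per-sample outputs: each sample is a dict holding an image
# tensor and that sample's channel indices (3 channels: 0, 1, 2).
samples = [
    {"image": torch.randn(3, 224, 224), "channels": torch.tensor([0, 1, 2])}
    for _ in range(5)
]

# For dict samples, default_collate collates each key separately and stacks
# that key's tensors along a new leading batch dimension.
batch = default_collate(samples)
print(batch["image"].shape)     # torch.Size([5, 3, 224, 224])
print(batch["channels"].shape)  # torch.Size([5, 3])
```

A DataLoader with the default collate_fn performs the same step automatically, so a dataset that returns such dictionaries yields batches whose "channels" entry already has shape batch_size x n_channels.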
This is also discussed in https://github.com/insitro/ChannelViT/issues/3#issuecomment-2027716674