pytorch-image-models
[FEATURE] Add TinyCLIP
Is your feature request related to a problem? Please describe. In the paper "TinyCLIP: CLIP Distillation via Affinity Mimicking and Weight Inheritance" from Microsoft, the authors propose a novel cross-modal distillation method named TinyCLIP that unleashes the capacity of small CLIP models. It looks like solid work. A small CLIP model would be much more practical than a large one for mobile devices or local deployment.
repo: https://github.com/microsoft/Cream/tree/main/TinyCLIP
Describe the solution you'd like The code is based on the great open-clip work, so I feel it should not be difficult to add these models to the official implementations in timm and open-clip.
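For context, TinyCLIP ships its checkpoints through that open-clip codebase, so loading a released model looks roughly like the sketch below. The exact model and pretrained tag strings are assumptions patterned after the model-zoo naming, and it assumes the TinyCLIP fork of open_clip is installed.

import open_clip

# model / pretrained tag names are assumptions based on the TinyCLIP model zoo naming
model, _, preprocess = open_clip.create_model_and_transforms(
    'TinyCLIP-ViT-8M-16-Text-3M', pretrained='YFCC15M')
tokenizer = open_clip.get_tokenizer('TinyCLIP-ViT-8M-16-Text-3M')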
Hey @seefun, I'm very interested in adding this feature.
Do you have any specific requirements or guidelines that I should follow? Also, I am open to collaborating with other contributors if there's already a team working on this.
This seems to work for me
from timm.models.vision_transformer import _convert_openai_clip, VisionTransformer
import torch
from torch import nn

# TinyCLIP-ViT-8M/16 checkpoint from the official model zoo (full two-tower model)
url = "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M.pt"

# image tower: width 256, depth 10, 4 heads, pre-norm, 512-dim output projection
model = VisionTransformer(num_classes=512, embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm)

# remap the visual-tower keys from the TinyCLIP checkpoint to timm's naming
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")["state_dict"]
state_dict = _convert_openai_clip(state_dict, model, prefix="_image_encoder.module.visual.")
model.load_state_dict(state_dict)
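To actually run the converted image tower, a minimal inference sketch follows. The 224x224 input size and the OpenAI CLIP normalization constants are assumptions based on standard CLIP preprocessing, and example.jpg is a placeholder path.

from PIL import Image
from torchvision import transforms

# CLIP-style preprocessing; mean/std are the OpenAI CLIP constants (assumption)
preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])

img = preprocess(Image.open("example.jpg")).unsqueeze(0)  # placeholder image path
model.eval()
with torch.no_grad():
    feats = model(img)                                # (1, 512) image embedding
    feats = feats / feats.norm(dim=-1, keepdim=True)  # L2-normalize as in CLIP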
Checking against the official HF release (https://huggingface.co/wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M), I obtain identical results:
from transformers import CLIPModel

# reference implementation of the same checkpoint from the HF Hub
model_ref = CLIPModel.from_pretrained("wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M")

x = torch.randn(1, 3, 224, 224)
model.eval()
model_ref.eval()

# element-wise difference between the two image embeddings; should be ~0
model_ref.get_image_features(x)[0] - model(x)
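A slightly stricter version of the same check, run under no_grad and with an explicit tolerance (the tolerance value is an arbitrary choice, not something stated in the thread):

with torch.no_grad():
    ref = model_ref.get_image_features(x)  # (1, 512) from the HF reference model
    out = model(x)                         # (1, 512) from the converted timm model
    print(f"max abs diff: {(ref - out).abs().max().item():.3e}")
    assert torch.allclose(ref, out, atol=1e-5)  # arbitrary tolerance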
@rwightman Do you want me to open a PR to add these TinyCLIP models? From what I see, since TinyCLIP uses non-standard ViT variants, I will need to create separate model functions. Something like this:
@register_model
def vit_8m_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-8M/16 CLIP image tower (TinyCLIP-ViT-8M-16-Text-3M) """
    model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm)
    model = _create_vision_transformer(
        'vit_8m_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
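For completeness, once such an entry point (and a matching pretrained_cfg) lands in timm, creating the model would look like any other timm model; num_classes=512 is passed explicitly here since the sketch above does not set it:

import timm

# hypothetical usage once the entry point above is merged
img_encoder = timm.create_model('vit_8m_patch16_clip_224', pretrained=False, num_classes=512)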
An alternative is to convert the model weights and host them on the HF Hub. With #2039, the model config can be specified on HF.
@gau-nernst yeah, if it's just standard clip vit arch but w/ some custom widths/depths/heads that approach will work fine. If there's a PR with these added that points to url weights, I can push to HF hub from there and change the pretrained_cfgs to reflect that...
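For reference, a matching pretrained_cfgs entry pointing at the release url could look roughly like the sketch below. It is meant to sit in timm/models/vision_transformer.py, where _cfg, generate_default_cfgs and the OPENAI_CLIP_MEAN/STD constants already exist; the tag name is an assumption, and it assumes the weights are re-exported (or the checkpoint filter extended) so that the url loads the converted image-tower state dict directly.

# sketch of a default_cfgs entry in timm/models/vision_transformer.py
default_cfgs = generate_default_cfgs({
    'vit_8m_patch16_clip_224.tinyclip_yfcc15m': _cfg(
        # could be swapped for an hf_hub_id once the weights are pushed to the hub
        url='https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M.pt',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        num_classes=512),
})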