pytorch-image-models

[FEATURE] Add TinyCLIP

Open • seefun opened this issue 2 years ago • 1 comment

Is your feature request related to a problem? Please describe.
In the paper “TinyCLIP: CLIP Distillation via Affinity Mimicking and Weight Inheritance” from Microsoft, the authors propose a novel cross-modal distillation method, TinyCLIP, that unleashes the capacity of small CLIP models. It looks like solid work, and small CLIP models are more practical than large ones for mobile devices and local deployment.

repo: https://github.com/microsoft/Cream/tree/main/TinyCLIP


Describe the solution you'd like
The code is based on the great work of open-clip, so I think it should not be difficult to add these models to the official timm and open-clip implementations.

seefun, Nov 09 '23

Hey @seefun, I'm very interested in adding this feature.

Do you have any specific requirements or guidelines that I should follow? Also, I am open to collaborating with other contributors if there's already a team working on this.

engichang1467, Nov 21 '23

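This seems to work for me:

import torch
from torch import nn
from timm.models.vision_transformer import VisionTransformer, _convert_openai_clip

# TinyCLIP checkpoint: ViT-8M image tower, patch size 16, trained on YFCC15M
url = "https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M.pt"

# Standard CLIP ViT layout with TinyCLIP's reduced width/depth/heads.
# num_classes=512 makes the head act as CLIP's 512-d image projection,
# pre_norm=True adds CLIP's ln_pre, and plain nn.LayerNorm (eps=1e-5) matches
# CLIP instead of timm's default eps=1e-6.
model = VisionTransformer(num_classes=512, embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm)

# The .pt file wraps the weights in a dict under "state_dict"
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")["state_dict"]
# Keep only the image-tower keys and remap OpenAI CLIP names to timm names
state_dict = _convert_openai_clip(state_dict, model, prefix="_image_encoder.module.visual.")

model.load_state_dict(state_dict)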
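The `prefix` argument is what lets timm's OpenAI-CLIP converter read this checkpoint: the TinyCLIP file stores the image tower under `_image_encoder.module.visual.` instead of the plain `visual.` that `_convert_openai_clip` defaults to. The converter skips keys outside the prefix, strips it, and then renames the rest to timm's layout. A minimal sketch of only the select-and-strip step, using a hypothetical helper `select_prefixed` and made-up example keys (it does not do timm's renaming):

```python
from typing import Dict, TypeVar

T = TypeVar("T")


def select_prefixed(state_dict: Dict[str, T], prefix: str) -> Dict[str, T]:
    """Keep only keys starting with `prefix` and strip that prefix off.

    Hypothetical illustration helper; timm's converter additionally renames
    the stripped keys (e.g. conv1 -> patch_embed.proj), which is not done here.
    """
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


# Made-up checkpoint keys: two image-tower entries and one text-tower entry
ckpt = {
    "_image_encoder.module.visual.conv1.weight": "W_conv",
    "_image_encoder.module.visual.proj": "W_proj",
    "_text_encoder.module.token_embedding.weight": "W_tok",
}
print(select_prefixed(ckpt, "_image_encoder.module.visual."))
# {'conv1.weight': 'W_conv', 'proj': 'W_proj'}
```

Passing the wrong prefix would silently select nothing, which then shows up as missing keys in `load_state_dict`.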

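Checking against the official HF release (https://huggingface.co/wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M), I get identical results:

import torch
from transformers import CLIPModel

model_ref = CLIPModel.from_pretrained("wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M")

x = torch.randn(1, 3, 224, 224)
model.eval()  # `model` is the timm VisionTransformer loaded above
model_ref.eval()

with torch.no_grad():
    # Both produce the projected 512-d image embedding, shape (1, 512)
    diff = model_ref.get_image_features(pixel_values=x) - model(x)
print(diff.abs().max())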
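Printing the raw difference works, but "identical" for two independent implementations realistically means equal up to float32 round-off. A small pure-Python sketch (hypothetical helper `compare_embeddings`, assuming both outputs are first converted with something like `emb.flatten().tolist()`) that reduces the check to a maximum absolute difference and a cosine similarity:

```python
import math
from typing import Sequence, Tuple


def compare_embeddings(a: Sequence[float], b: Sequence[float]) -> Tuple[float, float]:
    """Return (max absolute difference, cosine similarity) between two embeddings."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same length")
    max_abs_diff = max(abs(x - y) for x, y in zip(a, b))
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return max_abs_diff, dot / (norm_a * norm_b)


# Made-up numbers standing in for the HF and timm embeddings
ref = [0.12, -0.53, 0.88]
ours = [0.12, -0.53, 0.8800001]
max_abs_diff, cos_sim = compare_embeddings(ref, ours)
print(f"max |diff| = {max_abs_diff:.2e}, cosine = {cos_sim:.6f}")
```

For tensors directly, `torch.testing.assert_close(actual, expected)` performs a similar tolerance-based check without the list conversion.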

@rwightman Do you want me to open a PR to add these TinyCLIP ViT models? From what I see, since TinyCLIP uses non-standard ViT variants (custom widths/depths/heads), I will need to create a separate model function for each. Something like this:

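from torch import nn
from timm.models import register_model
from timm.models.vision_transformer import VisionTransformer, _create_vision_transformer

# Standard CLIP ViT layout; only width, depth and heads are TinyCLIP-specific.
# pretrained=True also needs a matching pretrained_cfg entry for this name.
@register_model
def vit_8m_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    model_args = dict(patch_size=16, embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm)
    model = _create_vision_transformer(
        'vit_8m_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model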
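As a sanity check on the `vit_8m` name: a standard pre-norm ViT block with a 4x MLP has 12d² + 13d parameters (qkv, attention output, two MLP layers, two LayerNorms), so embed_dim=256 and depth=10 land at roughly 8.3M once the patch embedding, position embedding and 512-d projection head are included. A back-of-the-envelope sketch; the hypothetical `vit_param_count` is not a timm function and assumes a bias-free patch embedding (as in CLIP) and a projection head with bias:

```python
def vit_param_count(embed_dim: int, depth: int, patch_size: int = 16, img_size: int = 224,
                    in_chans: int = 3, mlp_ratio: int = 4, proj_dim: int = 512) -> int:
    """Rough parameter count of a CLIP-style pre-norm ViT image tower."""
    d = embed_dim
    hidden = mlp_ratio * d
    num_patches = (img_size // patch_size) ** 2
    patch_embed = in_chans * patch_size * patch_size * d  # conv weights, no bias
    cls_and_pos = d + (num_patches + 1) * d               # class token + position embedding
    per_block = (
        2 * 2 * d                   # two LayerNorms (weight + bias)
        + 3 * d * d + 3 * d         # fused qkv projection
        + d * d + d                 # attention output projection
        + d * hidden + hidden       # MLP fc1
        + hidden * d + d            # MLP fc2
    )
    norms = 2 * 2 * d               # norm_pre (ln_pre) + final norm (ln_post)
    head = d * proj_dim + proj_dim  # 512-d projection as a Linear with bias
    return patch_embed + cls_and_pos + depth * per_block + norms + head


print(vit_param_count(embed_dim=256, depth=10))  # 8277504
```

In practice `sum(p.numel() for p in model.parameters())` on the instantiated timm model is the authoritative count.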

An alternative is to convert the model weights and host them on the HF Hub. With #2039, the model config can also be specified on the HF Hub.

gau-nernst, Mar 18 '24

@gau-nernst yeah, if it's just the standard CLIP ViT arch but with some custom widths/depths/heads, that approach will work fine. If there's a PR that adds these and points to the URL weights, I can push them to the HF Hub from there and change the pretrained_cfgs to reflect that...

rwightman, Mar 18 '24