DeCLIP
worked (simple) example of loading model and transforms?
Thank you for this exciting repository. Can you provide a simple example of how I might be able to load the models you provide in your model zoo?
Something along the lines of what is provided by the timm (pytorch-image-models) model repository:
```python
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

model_name = 'ghostnet_100'
model = timm.create_model(model_name, pretrained=True)
model.eval()

# Pass the model instance (not the name string) so its pretrained cfg is used
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
```
Ideally, this would allow us to use the models in a Jupyter notebook or other interactive context.
Thanks in advance!
By way of example, here's a little script I worked out. If this looks incorrect, let me know!
```python
import os, sys, torch
from PIL import Image
from torchvision import transforms

# Clone the repo if needed (the `!` syntax assumes a Jupyter/IPython session)
if not os.path.exists('DeCLIP'):
    !git clone https://github.com/Sense-GVT/DeCLIP/
sys.path.append('DeCLIP')

sample_image = Image.open('dog.jpg')

# Load the experiment config that matches the checkpoint
from prototype.utils.misc import parse_config
config_path = 'DeCLIP/experiments/declip_experiments/declip88m/declip88m_r50_declip/config.yaml'
config = parse_config(config_path)

from prototype.model.declip import declip_res50
bpe_path = 'DeCLIP/prototype/text_info/bpe_simple_vocab_16e6.txt.gz'
config['model']['kwargs']['text_encode']['bpe_path'] = bpe_path
config['model']['kwargs']['clip']['text_mask_type'] = None

# The checkpoint was saved from DataParallel, so keys carry a 'module.' prefix
weights = torch.load('DeCLIP/weights/declip_88m/r50.pth.tar', map_location='cpu')['model']
weights = {k.replace('module.', ''): v for k, v in weights.items()}
weights['logit_scale'] = weights['logit_scale'].unsqueeze(0)

model = declip_res50(**config['model']['kwargs'])
model.load_state_dict(weights, strict=False)  # strict=False: some training-only keys differ
model.eval()

# Standard ImageNet normalization; center-crop to 224, which the ResNet-50
# visual tower's fixed positional embedding expects
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

inputs = preprocess(sample_image).unsqueeze(0)
with torch.no_grad():
    image_features = model.visual(inputs)
```
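One step worth calling out: checkpoints saved from `torch.nn.DataParallel` prefix every parameter key with `module.`, so the dict comprehension above renames keys before `load_state_dict`. A standalone sketch of that renaming with a toy dict (the string values stand in for tensors):

```python
# Toy state dict as DataParallel would save it: every key prefixed with 'module.'
checkpoint = {
    'module.visual.conv1.weight': 'w1',
    'module.logit_scale': 's',
}

# Strip only a *leading* 'module.' so keys match a plain (non-wrapped) model;
# slicing is slightly safer than str.replace, which would also hit mid-key matches
stripped = {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in checkpoint.items()}

print(sorted(stripped))  # ['logit_scale', 'visual.conv1.weight']
```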
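If this loads correctly, the usual next step for a CLIP-style model is zero-shot classification: L2-normalize the image and text features, take cosine similarities scaled by the logit scale, and softmax over the class prompts. A NumPy sketch of just that arithmetic (the feature arrays below are hand-made stand-ins for `model.visual(...)` and whatever text-encoding path DeCLIP exposes, which I haven't verified):

```python
import numpy as np

def zero_shot_scores(image_feat, text_feats, logit_scale=100.0):
    """Cosine-similarity zero-shot scores: normalize, dot product, softmax."""
    img = image_feat / np.linalg.norm(image_feat)
    txt = text_feats / np.linalg.norm(text_feats, axis=1, keepdims=True)
    logits = logit_scale * txt @ img
    exp = np.exp(logits - logits.max())  # subtract max for numerical stability
    return exp / exp.sum()

# Stand-in features: the image feature is closest to the first prompt
image_feat = np.array([1.0, 0.0, 0.0])
text_feats = np.array([[0.9, 0.1, 0.0],   # e.g. "a photo of a dog"
                       [0.0, 1.0, 0.0],   # e.g. "a photo of a cat"
                       [0.0, 0.0, 1.0]])
probs = zero_shot_scores(image_feat, text_feats)
print(int(probs.argmax()))  # 0
```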