
clip.model.build_model does not work if device is cpu

Open · Toekan opened this issue 3 years ago · 4 comments

Hi,

Thanks for providing this really convenient package to use the CLIP model!

I've run into a problem with build_model when trying to reconstruct the model from a state_dict on my local machine, which has no GPU.

Code to reproduce

First I download one of the built-in models and save the state_dict:

import clip
import torch

model, preprocess = clip.load("ViT-B/32", jit=False, device="cpu")
torch.save(model.state_dict(), 'clip_off_the_shelve.pt')

Then I load the model using your function and try to use it to compute a text embedding:

model = clip.model.build_model(torch.load('clip_off_the_shelve.pt'))
text_tokens = clip.tokenize(["door"])
with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

This unfortunately results in:

~/Envs/test_env/lib/python3.8/site-packages/torch/nn/functional.py in softmax(input, dim, _stacklevel, dtype)
   1678         dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
   1679     if dtype is None:
-> 1680         ret = input.softmax(dim)
   1681     else:
   1682         ret = input.softmax(dim, dtype=dtype)

RuntimeError: "softmax_lastdim_kernel_impl" not implemented for 'Half'
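One quick way to confirm that the rebuilt model really holds fp16 weights is to collect the dtypes of its parameters. This is a minimal sketch in plain PyTorch; parameter_dtypes is a hypothetical helper, and the toy nn.Linear stands in for the CLIP model returned by build_model.

```python
from typing import Set

import torch
from torch import nn


def parameter_dtypes(model: nn.Module) -> Set[torch.dtype]:
    """Return the set of dtypes used by a module's parameters."""
    return {p.dtype for p in model.parameters()}


# Toy stand-in for a model whose weights have been cast to fp16.
toy = nn.Linear(4, 4).half()
print(parameter_dtypes(toy))  # {torch.float16}
```

On a model rebuilt with clip.model.build_model I would expect a mix of torch.float16 and torch.float32, since only some layers are converted; any torch.float16 entry on a CPU model is a hint that this error is coming.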

Potential fix (?)

From reading around, it seems the culprit is convert_weights in build_model, which converts weights to fp16 regardless of the device being used. PyTorch doesn't implement some fp16 operations on the CPU (softmax among them, judging by the error above), which seems to cause this error. Would it be possible to make build_model conditional on the device?
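To illustrate the proposal, here is a toy sketch of a device-conditional conversion. convert_weights_fp16 and maybe_convert_weights are hypothetical names; the conversion is a simplified stand-in for CLIP's convert_weights, which as far as I know also casts nn.MultiheadAttention projection tensors and a couple of named projection parameters that this sketch leaves out.

```python
import torch
from torch import nn


def convert_weights_fp16(model: nn.Module) -> None:
    """Simplified stand-in for CLIP's convert_weights: cast the weights and
    biases of conv and linear layers to fp16, in place."""
    for module in model.modules():
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()


def maybe_convert_weights(model: nn.Module, device) -> nn.Module:
    """Cast to fp16 only when the target device is not the CPU."""
    if torch.device(str(device)).type != "cpu":
        convert_weights_fp16(model)
    return model


cpu_model = maybe_convert_weights(nn.Sequential(nn.Linear(3, 3)), "cpu")
print(cpu_model[0].weight.dtype)  # torch.float32
```

Accepting either a string or a torch.device (via str(device)) keeps the check working for both ways callers usually pass the device.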

Thanks!

Toekan · Jan 17 '22 17:01

The device-conditional code is in clip.load():

https://github.com/openai/CLIP/blob/3482bb6ed319f70542094d1ed224c0db0b88c3a5/clip/clip.py#L138-L141

So clip.load("clip_off_the_shelve.pt") should work; please let me know if it doesn't.

jongwook · Apr 11 '22 02:04
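My reading of the linked clip.load() lines is that, after building the model, they move it to the device and call model.float() when the device is the CPU. Assuming that is right, the same cast can be applied by hand to a model returned by build_model. A minimal sketch; to_device_safely is a hypothetical helper, not part of the clip package.

```python
import torch
from torch import nn


def to_device_safely(model: nn.Module, device) -> nn.Module:
    """Move `model` to `device`; on the CPU, also cast every floating-point
    parameter and buffer back to fp32, so no fp16 kernels are needed."""
    model = model.to(device)
    if torch.device(str(device)).type == "cpu":
        model = model.float()
    return model.eval()


# Toy stand-in for a model rebuilt with fp16 weights.
half_model = nn.Linear(3, 3).half()
safe_model = to_device_safely(half_model, "cpu")
print(safe_model.weight.dtype)  # torch.float32
```

With the CLIP package this would presumably be used as to_device_safely(clip.model.build_model(torch.load('clip_off_the_shelve.pt', map_location="cpu")), "cpu"), which avoids touching build_model itself.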

I'm facing the same error, and @jongwook, clip.load("clip_off_the_shelve.pt") doesn't work either.

w1redch4d · Aug 05 '22 10:08

By clip_off_the_shelve.pt I meant the model files downloaded under ~/.cache/clip. If loading one of those with clip.load() fails, please share the stack trace.

jongwook · Aug 05 '22 22:08
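To find the checkpoints jongwook means, one can list the .pt files in the download directory and pass one of those paths to clip.load(path, device="cpu"). A small sketch assuming the default download root is ~/.cache/clip, as stated above; cached_clip_checkpoints is a hypothetical helper.

```python
from pathlib import Path
from typing import List


def cached_clip_checkpoints(root: str = "~/.cache/clip") -> List[str]:
    """Return the .pt files directly under `root`, sorted by path."""
    cache_dir = Path(root).expanduser()
    if not cache_dir.is_dir():
        return []
    return sorted(str(p) for p in cache_dir.glob("*.pt"))


for path in cached_clip_checkpoints():
    print(path)
```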

The stack trace:

Traceback (most recent call last):
  File "D:\Projects\Python\Github\VQ-Diffusion\run.py", line 109, in <module>
    Predictor().predict()
  File "D:\Projects\Python\Github\VQ-Diffusion\run.py", line 25, in predict
    images = VQ_Diffusion_model.generate_sample_with_condition(
  File "D:\Projects\Python\Github\VQ-Diffusion\run.py", line 90, in generate_sample_with_condition
    model_out = self.model.generate_content(
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\models\dalle.py", line 215, in generate_content
    trans_out = self.transformer.sample(condition_token=condition['condition_token'],
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\transformers\diffusion_transformer.py", line 578, in sample
    cond_emb = self.condition_emb(input['condition_token']) # B x Ld x D   #256*1024
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\embeddings\clip_text_embedding.py", line 72, in forward
    text_feature = self.encode_text(index)
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\embeddings\clip_text_embedding.py", line 52, in encode_text
    x = self.transformer(x)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\modules\clip\model.py", line 198, in forward
    return self.resblocks(x)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
    input = module(input)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\modules\clip\model.py", line 185, in forward
    x = x + self.attention(self.ln_1(x))
  File "D:\Projects\Python\Github\VQ-Diffusion\image_synthesis\modeling\modules\clip\model.py", line 182, in attention
    return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\modules\activation.py", line 1038, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\functional.py", line 5358, in multi_head_attention_forward
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\functional.py", line 5037, in _scaled_dot_product_attention
    attn = softmax(attn, dim=-1)
  File "D:\Projects\Python\Github\VQ-Diffusion\venv\lib\site-packages\torch\nn\functional.py", line 1818, in softmax
    ret = input.softmax(dim)
RuntimeError: "softmax_lastdim_kernel_impl" not implemented for 'Half'

Anyhow, the error goes away if I add a simple check to build_model:

def build_model(state_dict: dict, device: str):
    .....
    # Cast to fp16 only off the CPU; some fp16 ops (e.g. softmax) lack CPU kernels
    if str(device) != "cpu":
        convert_weights(model)

    model.load_state_dict(state_dict)
    return model.eval()

w1redch4d · Aug 07 '22 13:08