CLIP compatibility with MPS/Arm64 (Mac)
I've been experimenting with CLIP on a Mac Studio with an M1 Ultra (48-core GPU) now that there's a compatible PyTorch nightly. When running the sample in the README, changing line 5 to device = "cpu" lets it run with the expected output, but if I switch it to device = "mps" it errors out:
```
Traceback (most recent call last):
  File "/Users/jasonhoffman/Documents/python/mlv/clip_test.py", line 13, in <module>
    text_features = model.encode_text(text)
  File "/Users/jasonhoffman/Documents/python/mlv/env/lib/python3.10/site-packages/clip/model.py", line 349, in encode_text
    x = self.transformer(x)
  File "/Users/jasonhoffman/Documents/python/mlv/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1131, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/jasonhoffman/Documents/python/mlv/env/lib/python3.10/site-packages/clip/model.py", line 204, in forward
    return self.resblocks(x)
  File "/Users/jasonhoffman/Documents/python/mlv/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1131, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/jasonhoffman/Documents/python/mlv/env/lib/python3.10/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/Users/jasonhoffman/Documents/python/mlv/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1131, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/jasonhoffman/Documents/python/mlv/env/lib/python3.10/site-packages/clip/model.py", line 191, in forward
    x = x + self.attention(self.ln_1(x))
  File "/Users/jasonhoffman/Documents/python/mlv/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1131, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/jasonhoffman/Documents/python/mlv/env/lib/python3.10/site-packages/clip/model.py", line 163, in forward
    ret = super().forward(x.type(torch.float32))
  File "/Users/jasonhoffman/Documents/python/mlv/env/lib/python3.10/site-packages/torch/nn/modules/normalization.py", line 189, in forward
    return F.layer_norm(
  File "/Users/jasonhoffman/Documents/python/mlv/env/lib/python3.10/site-packages/torch/nn/functional.py", line 2501, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```
I realize I might be doing something obscure here, but is there any reason why I should expect this error? At this point is it more likely to be with the nightly build of torch or the CLIP model trying to use it?
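For reference, the script is basically the README example with only the device line changed; roughly this (a sketch, my actual file differs a little in paths and line numbers):

```python
import torch
import clip
from PIL import Image

# line 5 of my script: "cpu" runs fine, "mps" produces the traceback above
device = "mps"

model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)  # this is the call the traceback points at
```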
My understanding is that MPS support isn't totally there yet. Here's the general tracking issue on the pytorch repo covering the various operations that are still being implemented (it's quite the list!):
https://github.com/pytorch/pytorch/issues/77764
hmm but this problem looks different from the category of problem tracked by the linked issue.
@Maxhirez I'd encourage you to submit this as an issue to pytorch, especially as a minimal repro. there are definitely people working on improving MPS support, but they can only fix problems that get reported.
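if it helps, here's a guess at a CLIP-free minimal repro (untested; it just mimics the permute-then-LayerNorm pattern visible in the traceback, since encode_text permutes the tensor before the transformer, so LayerNorm sees a non-contiguous input):

```python
import torch

# hypothetical minimal repro: LayerNorm over a non-contiguous tensor on the mps device,
# mimicking the permute that CLIP's encode_text does before the transformer
x = torch.randn(1, 77, 512, device="mps").permute(1, 0, 2)  # non-contiguous after permute
ln = torch.nn.LayerNorm(512).to("mps")
out = ln(x)  # if the guess is right, this hits the same view/stride RuntimeError
print(out.shape)
```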
> is it more likely to be with the nightly build of torch or the CLIP model trying to use it?
mm my guess would be torch. the dream of swappable backends is "if you know torch, you don't need to know CUDA", so I think if non-CUDA backends are differently constrained, that's probably a sign of missing functionality (basic usage implemented, advanced usage unsupported).
from the sound of the error message, the CLIP model could work around the error by (pun intended) reshaping the code. but I think if the backend supported more functionality, a workaround wouldn't be needed.
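e.g. a workaround could look something like this untested sketch of the LayerNorm subclass in clip/model.py (the one the traceback passes through), though ideally the backend would just handle non-contiguous input:

```python
import torch
import torch.nn as nn


class LayerNorm(nn.LayerNorm):
    """Untested workaround sketch for CLIP's fp16-handling LayerNorm subclass."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        # .contiguous() sidesteps the view-on-a-non-contiguous-tensor path
        # that the mps backend currently trips over
        ret = super().forward(x.contiguous().type(torch.float32))
        return ret.type(orig_type)
```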
> hmm but this problem looks different from the category of problem tracked by the linked issue.
> @Maxhirez I'd encourage you to submit this as an issue to pytorch, especially as a minimal repro. there are definitely people working on improving MPS support, but they can only fix problems that get reported.
Fair point. I guess all I meant was that when they released an mps version of pytorch, it wasn't exactly done yet. But they do seem to be making daily progress, at least!
There's been a lot of progress on the pytorch side... any hope for CLIP and M1 Apple Silicon?
https://github.com/pytorch/pytorch/issues/77764