Metalhead.jl
Issues with `ViT` on the GPU
After some experimentation, several issues pop up when using `ViT` on the GPU. Documenting these here so that they can be tracked down and solved:
- [x] Class tokens don't work on the GPU for now because `fill` doesn't automatically allocate on the CPU/GPU as per the model (should be solved by #166)
- [ ] `MLUtils.chunk` isn't GPU-friendly yet (a hotfix landed with #169, so ViTs should work on GPUs for now, but a long-term fix is pending)
- [x] A scalar indexing warning comes up when I run the model (regression introduced in #162 because of `selectdim`; indexing is unavoidable since the data is not contiguous, so `selectdim` returns a view) (should be solved by #166)
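For context, a minimal sketch of what `MLUtils.chunk` does (a hypothetical helper, not the actual MLUtils.jl implementation): it splits an array along a dimension into `selectdim` views, and those views are where the GPU trouble comes from.

```julia
# Hypothetical sketch of `MLUtils.chunk`: each piece is a `selectdim`
# view into the parent array, not a copy.
function chunk_sketch(x::AbstractArray, n::Integer; dims::Integer=ndims(x))
    step = cld(size(x, dims), n)
    return [selectdim(x, dims, i:min(i + step - 1, size(x, dims)))
            for i in 1:step:size(x, dims)]
end

x = reshape(collect(1:12), 3, 4)
parts = chunk_sketch(x, 2; dims=2)
@show length(parts)                      # 2
@show typeof(first(parts)) <: SubArray   # true — a view, not a copy
```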
For the reshape issue, have you looked at what the output of `chunk` is? I suspect it isn't a `CuArray` but a view instead; `permutedims` just turns it back into a `CuArray`. So then the issue will be to make `chunk` GPU friendly.
Oh, you're right! That is indeed the issue. I'll see how I can make `chunk` work on the GPU.
Typically, CUDA.jl will return `CuArray`s instead of `SubArray`s for views, as long as the values are contiguous. The use of `selectdim` in `chunk` will try to take a view. If the "chunk" you are selecting corresponds to multiple indices (e.g. `2:4`) in a middle dimension, then this is a non-contiguous access and CUDA.jl will fall back to returning a `SubArray`. At that point almost any downstream operation on the GPU will complain. Basically, this is to say that the CPU and GPU code paths have to diverge, because you cannot make views work in this case.
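The contiguity argument can be illustrated on the CPU (plain `Array`s shown below; for a `CuArray` the analogous non-contiguous case is what forces CUDA.jl to fall back to a `SubArray`):

```julia
# A view into a middle dimension is non-contiguous (strided), while a
# view into the trailing dimension is one contiguous block of memory.
x = reshape(collect(1.0:24.0), 2, 3, 4)

mid  = selectdim(x, 2, 2:3)   # slice a middle dimension
tail = selectdim(x, 3, 2:3)   # slice the trailing dimension

@show Base.iscontiguous(mid)   # false — strided access
@show Base.iscontiguous(tail)  # true  — contiguous block
```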
Not sure exactly how we want to fix this. I feel that any change will be slower for the CPU path, and MLUtils.jl currently doesn't depend on CUDA.jl (we want to keep it this way).
cc @CarloLucibello and @ToucheSir for some more eyes
If I'm not mistaken, https://github.com/FluxML/Metalhead.jl/blob/edf83e0932a8ebf8c642db35dbf8dcbcb293ad38/src/layers/attention.jl#L47 accesses the last dimension and not a middle one?
Oops, so then `selectdim` just always returns a `SubArray`? Taking a contiguous view directly should return a `CuArray` if I'm not mistaken.
Looks like it does for `CuArray`s. My understanding from testing with Cthulhu is that https://github.com/JuliaLang/julia/blob/v1.7.3/base/abstractarraymath.jl#L136 transforms what should be a `:` or a `UnitRange{Int64}` for the non-selected dimensions into a `Slice{OneTo{Int64}}`. `view` on `CuArray`s can recognize that `view(x, :)` is contiguous, but not `view(x, Base.Slice(axes(x, 1)))`.
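The type distinction can be seen without a GPU: `selectdim` stores `Base.Slice` wrappers where a hand-written view would pass `Colon()` or a `UnitRange`, so a method dispatching on the latter never fires (the method signature mentioned in the comments is paraphrased here, not CUDA.jl's actual code):

```julia
x = reshape(collect(1.0:24.0), 2, 3, 4)
v = selectdim(x, 3, 2:3)

# The non-selected dimensions are stored as `Base.Slice{Base.OneTo{Int}}`,
# not as `Colon` or a plain `UnitRange`:
@show typeof(parentindices(v))

# So a fast path like `view(::CuArray, ::Union{Colon,UnitRange{Int}}...)`
# would not match what `selectdim` passes:
@show Base.Slice(Base.OneTo(2)) isa Colon           # false
@show Base.Slice(Base.OneTo(2)) isa UnitRange{Int}  # false
```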
For our purposes, I think replacing the `selectdim` on https://github.com/JuliaML/MLUtils.jl/blob/348dbdd04065c650cf4a8fb2979bdaa4aad0b84a/src/utils.jl#L165 with an equivalent function that fills in the non-chunked dims with `Colon` would suffice. It might also be nice to support `dims` as a `Val` for type stability, but as long as we can avoid triggering https://github.com/JuliaGPU/CUDA.jl/blob/0a46a13467f8522cc84ca490bc4948e25bf95fd3/src/array.jl#L631 we should be fine.
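A hedged sketch of that replacement (the name and details below are illustrative, not the eventual MLUtils.jl change): fill the non-chunked dimensions with `Colon()` instead of `Base.Slice`, taking `dims` as a `Val` for type stability.

```julia
# Hypothetical `selectdim` replacement: Colon() in the unselected
# dimensions, so a contiguous-view fast path can apply when the
# selected range itself is contiguous.
function colonselectdim(x::AbstractArray{T,N}, ::Val{d}, i) where {T,N,d}
    inds = ntuple(k -> k == d ? i : Colon(), Val(N))
    return view(x, inds...)
end

x = reshape(collect(1:24), 2, 3, 4)
v = colonselectdim(x, Val(3), 2:3)
@show size(v)   # (2, 3, 2)
```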
Can this be closed?
I think so, we can re-open if anything was missed.