Implement `mapslices` without scalar iteration
When I call mapslices(f, a; dims) on a CuArray, a warning appears saying that scalar operations on the GPU are inefficient.
using CUDA
a = CUDA.rand(3, 4, 5)
b = CUDA.rand(2, 3)
mapslices(a, dims=[1, 2]) do t
    b * t
end
To avoid the warning, I had to write the operation differently:
c = map(eachslice(a, dims=3)) do t
    b * t
end
cat(c..., dims=3)
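As a side note on the workaround above, splatting many slices into cat can itself become slow when the batch dimension is large. A minimal sketch of an alternative, assuming Julia 1.9 or later where Base provides stack; it is shown on CPU arrays, and whether it performs well on a CuArray is not verified here:

```julia
# CPU arrays keep this sketch runnable without a GPU.
a = rand(Float32, 3, 4, 5)
b = rand(Float32, 2, 3)

# The workaround from above: one product per slice, then splat into cat.
c = map(t -> b * t, eachslice(a, dims=3))
splatted = cat(c..., dims=3)

# Same result without splatting: stack places each 2×4 result along a new last dimension.
stacked = stack(t -> b * t, eachslice(a, dims=3))

stacked ≈ splatted
```

This still performs one multiplication per slice, so it only removes the splatting, not the per-slice launch cost.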
In neural networks and machine learning, mini-batches are often used. When a sample is not a vector or matrix, each input to the model has several dimensions, for example size(x) == (100, 3, 4, batch_size). In this case, mapslices() seems very convenient.
However, when the model and the input are both CuArrays, this is very inefficient on the GPU because of the many scalar operations. Can the internals of mapslices() be optimized to make it more efficient?
Describe the solution you'd like
Is there a more elegant way to implement mapslices(f, a; dims) so that it uses vectorized operations instead of scalar operations?
There isn't an easy way to implement mapslices efficiently, as the inner function is still vectorized (i.e. it uses array operations that are only defined for CuArray) so we can't generate a kernel from it. And the alternative, launching a kernel for each slice, is often prohibitively expensive too.
One workaround that comes to mind is to compile the inner sequence of operations using the CUDA Graph APIs to lower the cost of launching it for each slice, but that's a lot of work.
Note that the particular operation here is what NNlib.batched_mul (⊠) computes, which dispatches to a CUDA kernel. It is also what TensorCore.boxdot (⊡) computes, which just reshapes and calls *. One may be faster than the other:
julia> using TensorCore, NNlib
julia> c ≈ b ⊠ a ≈ b ⊡ a
true
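The "reshape and call *" idea can also be written by hand with Base only. The following is a minimal sketch, not how TensorCore implements it; mul_leading is a hypothetical helper name, and CPU arrays are used so that it runs without a GPU:

```julia
# CPU arrays keep this sketch runnable anywhere; CUDA.rand could be swapped in to try it on a GPU.
a = rand(Float32, 3, 4, 5)
b = rand(Float32, 2, 3)

# Hypothetical helper: apply b to the first dimension of a with a single matrix product.
function mul_leading(b::AbstractMatrix, a::AbstractArray)
    flat = reshape(a, size(a, 1), :)   # fold all trailing dimensions into columns (3×20 here)
    out = b * flat                     # one matrix product instead of one per slice
    return reshape(out, size(b, 1), size(a)[2:end]...)  # restore trailing dimensions (2×4×5 here)
end

fast = mul_leading(b, a)
slow = mapslices(t -> b * t, a, dims=[1, 2])
fast ≈ slow
```

Because only the first dimension is contracted, the same helper would apply to a four-dimensional input such as size(x) == (100, 3, 4, batch_size) with a matrix whose second dimension is 100.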
Hi. Related to this issue, is there currently a way to compute the norm along a given dimension?
For example, for a (3, N) array, computing the norm of each column (reducing over the first dimension), resulting in a (1, N) array.
For norm you can just write it yourself:
julia> using LinearAlgebra
julia> A = randn(3,5);
julia> sqrt.(sum(abs2, A; dims=1)) ≈ mapslices(norm, A, dims=1)
true
julia> sum(abs, A; dims=1) ≈ mapslices(x -> norm(x,1), A, dims=1)
true
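The two identities above generalize to other p-norms using only a reduction and a broadcast. Here is a minimal sketch; pnorm_dims is a hypothetical helper name rather than a function from LinearAlgebra or CUDA.jl, and it is checked on CPU arrays only:

```julia
using LinearAlgebra

# Hypothetical helper: p-norm of each slice along `dims`, built from sum and broadcasting.
pnorm_dims(A; dims, p=2) = sum(x -> abs(x)^p, A; dims=dims) .^ (1 / p)

# The infinity norm is a plain reduction and needs no power.
infnorm_dims(A; dims) = maximum(abs, A; dims=dims)

A = randn(3, 5)
ok2   = pnorm_dims(A; dims=1) ≈ mapslices(norm, A, dims=1)
ok1   = pnorm_dims(A; dims=1, p=1) ≈ mapslices(x -> norm(x, 1), A, dims=1)
okinf = infnorm_dims(A; dims=1) ≈ mapslices(x -> norm(x, Inf), A, dims=1)
```

Each result has size (1, 5) here, matching the (1, N) shape asked about.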