CUDA array scalar getindex error
The following:
using TensorCast, CUDA; CUDA.allowscalar(false)
C = cu(ones(10,2))
L = cu(ones(10,3))
@reduce D[m,a] := sum(p) C[p,a] + L[p,m]
gives a "scalar getindex is disallowed" error. But either using @cast as an intermediate step or reordering the indices works fine:
@cast T[p,m,a] := C[p,a] + L[p,m]
D = reshape(sum(T, dims=1), (3,2))
or
C = cu(ones(2,10))
L = cu(ones(3,10))
@reduce D[m,a] := sum(p) C[a,p] + L[m,p]
both produce
3×2 CuArray{Float32,2,CuArray{Float32,3,Nothing}}:
20.0 20.0
20.0 20.0
20.0 20.0
This question was initially raised here: TensorCast & CUDA
It's pretty odd that this fails while the example in https://github.com/mcabbott/TensorCast.jl/pull/10#issue-359041016 does not. Both use orient to reshape a transposed matrix, but doing this to both arguments of the same broadcast seems to cause problems:
reshape(C,1,2,10) .+ reshape(L', 3,1,10) # ok
reshape(C',1,2,10) .+ reshape(L, 3,1,10) # ok
reshape(C',1,2,10) .+ reshape(L', 3,1,10) # ERROR: scalar getindex is disallowed
reshape(C',1,2,10) |> typeof
# Base.ReshapedArray{Float32,3,LinearAlgebra.Adjoint{Float32,CuArray{Float32,2,Nothing}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}
This seems the same on CuArrays v1.7.2, CuArrays v2.2.1, and CUDA v0.1.0.
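To make the wrapper nesting concrete, here is a minimal CPU-only sketch using plain Arrays in place of the CuArrays (so it runs without a GPU). It only shows that reshaping an adjoint gives a lazy ReshapedArray around a lazy Adjoint, and that broadcasting then sees a generic array style. My reading, which is an assumption rather than something checked against the CUDA internals, is that on the GPU the same double wrapping hides the CuArray from broadcast dispatch, so a broadcast whose arguments are all doubly wrapped falls back to the generic CPU path.

```julia
using LinearAlgebra  # for the Adjoint type

# Plain CPU arrays standing in for the CuArrays in the report
C = ones(Float32, 10, 2)
L = ones(Float32, 10, 3)

A = reshape(C', 1, 2, 10)   # lazy reshape of a lazy adjoint
B = reshape(L', 3, 1, 10)

# Two layers of wrapper sit between A and the underlying storage
@show A isa Base.ReshapedArray
@show parent(A) isa Adjoint
@show parent(parent(A)) === C

# Broadcasting dispatches on this style; for a generic wrapper it is the default one
style = Base.Broadcast.BroadcastStyle(typeof(A))
@show style

# On the CPU the broadcast itself is fine and gives the expected sums
D = dropdims(sum(A .+ B, dims = 3), dims = 3)
@show size(D)
```

This also fits the first two lines above being ok: there, one argument is a plain reshape of the array, so only one of the two is doubly wrapped.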
Here is one way to work around this, forcing the broadcast to be a CUDA one:
trick = cu(fill(false))
@reduce D[m,a] := sum(p) C[p,a] + L[p,m] + trick
I think this can be closed as fixed by #31; the current behaviour is:
julia> @reduce D[m,a] := sum(p) C[p,a] + L[p,m]
3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
20.0 20.0
20.0 20.0
20.0 20.0
julia> @pretty @reduce D[m,a] := sum(p) C[p,a] + L[p,m]
begin
    @boundscheck ndims(C) == 2 || throw(ArgumentError("expected a 2-tensor C[p, a]"))
    @boundscheck axes(C, 1) == axes(L, 1) || throw(DimensionMismatch("range of index p must agree"))
    @boundscheck ndims(L) == 2 || throw(ArgumentError("expected a 2-tensor L[p, m]"))
    local fox = transmute(C, Val((nothing, 2, 1)))
    local tiger = transmute(L, Val((2, nothing, 1)))
    D = dropdims(sum(@__dot__(fox + tiger), dims = 3), dims = 3)
end
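For readers unfamiliar with transmute, the following is a rough CPU-only sketch of what the generated code above computes, written with Base functions only. It assumes that transmute(C, Val((nothing, 2, 1))) behaves like an eager permutedims followed by a reshape to 1×2×10, and likewise for L with 3×1×10. This is a stand-in for illustration, not what TensorCast actually emits, and the test data and variable names here are my own.

```julia
# CPU arrays with distinct integer-valued entries, so index mistakes would show up
C = reshape(Float32.(1:20), 10, 2)   # C[p, a]
L = reshape(Float32.(1:30), 10, 3)   # L[p, m]

# Eager stand-ins for the two transmute calls (assumed equivalent in value):
fox   = reshape(permutedims(C), 1, 2, 10)   # fox[1, a, p]   == C[p, a]
tiger = reshape(permutedims(L), 3, 1, 10)   # tiger[m, 1, p] == L[p, m]

# Broadcast to a 3×2×10 array, then sum out p (the third dimension)
D = dropdims(sum(fox .+ tiger, dims = 3), dims = 3)

# Reference result straight from the index notation D[m,a] := sum(p) C[p,a] + L[p,m]
expected = [sum(C[p, a] + L[p, m] for p in 1:10) for m in 1:3, a in 1:2]

@show D == expected
```

Because permutedims copies, each argument here is an ordinary array rather than a nested wrapper, at the cost of one extra allocation per argument.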