TensorCore.jl
Support `boxdot` with n neighboring indices
Hi, @mcabbott.
I have implemented `boxdot` contracting n neighboring indices, as I mentioned on Discourse. I'm not sure whether it matches what you intended for the function, but I would appreciate any feedback.
julia> using TensorCore
julia> A = rand(3,3,3);
julia> B = rand(3,3,3);
julia> A ⊡₂ B ≈ reshape(A, 3,:) * reshape(B, :,3)
true
julia> boxdot(A, B, Val(3)) ≈ transpose(vec(A)) * vec(B)
true
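The identities above suggest the semantics: contracting the last N indices of `A` with the first N indices of `B` is a matrix product of suitably reshaped arrays. Here is a minimal stand-alone Julia sketch of that reading. `boxdot_sketch` is a hypothetical name, not the PR's implementation, and returning a plain scalar for a full contraction is an assumption chosen to match the `Val(3)` check.

```julia
using LinearAlgebra  # matrix `*`, transpose(vector) * vector

# Hypothetical stand-alone sketch, not the PR's code: contract the last N
# indices of A with the first N indices of B by reshaping both to matrices.
function boxdot_sketch(A::AbstractArray, B::AbstractArray, ::Val{N}) where {N}
    0 <= N <= min(ndims(A), ndims(B)) ||
        throw(ArgumentError("cannot contract $N indices"))
    inner = size(A)[end-N+1:end]              # contracted sizes in A
    inner == size(B)[1:N] ||
        throw(DimensionMismatch("contracted sizes $inner and $(size(B)[1:N]) differ"))
    m = prod(inner)                           # 1 when N == 0 (outer product)
    C = reshape(A, :, m) * reshape(B, m, :)   # (free A) x (free B) matrix
    free = (size(A)[1:end-N]..., size(B)[N+1:end]...)
    # Full contraction: return a scalar, matching transpose(vec(A)) * vec(B)
    return isempty(free) ? only(C) : reshape(C, free)
end

A = rand(3, 3, 3); B = rand(3, 3, 3);
boxdot_sketch(A, B, Val(2)) ≈ reshape(A, 3, :) * reshape(B, :, 3)  # true
boxdot_sketch(A, B, Val(3)) ≈ transpose(vec(A)) * vec(B)          # true
```

Reshaping keeps the heavy lifting in a single matrix multiply, which is the same equivalence the REPL checks above rely on.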
At first glance, this looks good!
I'd like to look closely at how adjoint vectors get handled, as that was the tricky case before.
I wonder whether `boxdot!(C, A, B)` can just infer `Val(2)` when necessary? It should be known from the types.
> I'd like to look closely at how adjoint vectors get handled, as that was the tricky case before.
Yes. I kept your existing single-contraction implementation for handling adjoint vectors, via a `Val{1}` specialization.
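The dispatch pattern behind a `Val{1}` specialization can be sketched in Julia as follows. `contract_sketch` is a hypothetical name, and the adjoint-vector method is an assumption about the kind of case involved, not a copy of TensorCore's code. Without it, contracting `x'` with `y` by reshaping yields a 1-element array instead of the scalar that `x' * y` returns.

```julia
using LinearAlgebra  # Adjoint, matrix `*`

# Hypothetical names; illustrates the dispatch pattern, not the PR's code.
# Generic path: contract the last N indices of A with the first N of B.
function contract_sketch(A::AbstractArray, B::AbstractArray, ::Val{N}) where {N}
    m = prod(size(A)[end-N+1:end])
    C = reshape(A, :, m) * reshape(B, m, :)
    return reshape(C, (size(A)[1:end-N]..., size(B)[N+1:end]...))
end

# More specific Val{1} method: an adjoint vector times a vector defers to `*`,
# which returns a scalar rather than a 1-element array.
contract_sketch(x::LinearAlgebra.Adjoint{<:Any,<:AbstractVector}, y::AbstractVector, ::Val{1}) = x * y

x, y = rand(3), rand(3);
contract_sketch(x', y, Val(1))          # scalar, via the Val{1} method
contract_sketch(collect(x'), y, Val(1)) # 1-element Vector, via the generic path
```

Because the specialized method is strictly more specific in every argument, Julia selects it without ambiguity, and the general `Val{N}` path stays untouched.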
> I wonder whether `boxdot!(C, A, B)` can just infer `Val(2)` when necessary? It should be known from the types.
I'm not sure I fully understand your suggestion. Currently, the implementation does not check the size of `C`, so any number of contracted indices works as long as `C` has the correct length. For example:
julia> boxdot!(similar(A, 81), A, B, Val(1)); # works
julia> boxdot!(similar(A, 9,9), A, B, Val(1)); # works
julia> boxdot!(similar(A, 9), A, B, Val(2)); # works
julia> boxdot!(similar(A, 1), A, B, Val(3)); # works
Are you suggesting that we check the length of `C` to apply the appropriate contraction automatically? Or should we instead check the size of `C` (or perhaps just its `ndims`) and select the contraction accordingly? In either case, I'm concerned that if someone passes an incorrect `C` by mistake, or simply forgets to pass `Val(N)`, the function might not throw an error.
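On inferring `N` from the types: contracting N indices gives `ndims(C) == ndims(A) + ndims(B) - 2N`, and `ndims` is a type parameter of `AbstractArray`, so N is fixed by the types alone. The length of `C`, by contrast, can be ambiguous once size-1 dimensions appear. The sketch uses a hypothetical helper `infer_contractions`, not part of TensorCore, and the size-1 example is constructed for illustration.

```julia
# Hypothetical helper, not part of TensorCore. Contracting N indices gives
# ndims(C) == ndims(A) + ndims(B) - 2N, so N follows from the ndims alone.
function infer_contractions(C::AbstractArray, A::AbstractArray, B::AbstractArray)
    d = ndims(A) + ndims(B) - ndims(C)
    (iseven(d) && 0 <= d ÷ 2 <= min(ndims(A), ndims(B))) ||
        throw(DimensionMismatch("ndims(C) = $(ndims(C)) does not fit ndims(A) = $(ndims(A)), ndims(B) = $(ndims(B))"))
    return Val(d ÷ 2)
end

# Contracting 1 index gives size (3, 1, 1, 3); contracting 2 gives (3, 3).
# Both outputs have length 9, so the length of C cannot pick N, but ndims can.
A = rand(3, 1, 1); B = rand(1, 1, 3);
infer_contractions(similar(A, 3, 1, 1, 3), A, B)  # Val{1}()
infer_contractions(similar(A, 3, 3), A, B)        # Val{2}()
# A flat similar(A, 9) gives an odd difference (5) and throws DimensionMismatch.
```

An `ndims`-based rule of this kind would also reject the flat `C` arguments in the examples above, which addresses the concern about silently accepting a wrong `C`.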
(Edit)
Given the examples above, I think we should at least check `ndims(C)`, though I would actually prefer to check `size(C)` strictly.
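A strict check could compare `size(C)` with the free sizes of `A` followed by those of `B`. This is a sketch under that assumption; `expected_size` and `check_output` are hypothetical helpers, not TensorCore API.

```julia
# Hypothetical helpers, not part of TensorCore: the exact size a strict
# boxdot!(C, A, B, Val(N)) could require, i.e. A's free sizes then B's.
function expected_size(A::AbstractArray, B::AbstractArray, ::Val{N}) where {N}
    0 <= N <= min(ndims(A), ndims(B)) ||
        throw(ArgumentError("cannot contract $N indices"))
    size(A)[end-N+1:end] == size(B)[1:N] ||
        throw(DimensionMismatch("contracted sizes of A and B differ"))
    return (size(A)[1:end-N]..., size(B)[N+1:end]...)
end

function check_output(C::AbstractArray, A::AbstractArray, B::AbstractArray, n::Val)
    sz = expected_size(A, B, n)
    size(C) == sz ||
        throw(DimensionMismatch("C has size $(size(C)), expected $sz"))
    return C
end

A = rand(3, 3, 3); B = rand(3, 3, 3);
expected_size(A, B, Val(1))                  # (3, 3, 3, 3)
check_output(similar(A, 3, 3), A, B, Val(2)) # passes, returns C
# Under this check, similar(A, 81) and similar(A, 9, 9) with Val(1), or
# similar(A, 9) with Val(2), would each throw DimensionMismatch.
```

One consequence of this sketch: a full contraction such as `Val(3)` here expects size `()`, so `similar(A, 1)` would be rejected. Whether a strict check should accept a length-1 `C` in that case is left open.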