TensorCore.jl

Support `boxdot` with n neighboring indices

Open KeitaNakamura opened this issue 1 year ago • 2 comments

Hi, @mcabbott. I have implemented `boxdot` with n neighboring indices, as I mentioned on Discourse. I'm not sure whether this matches what you had in mind for the function, but I'd appreciate any feedback.

```julia
julia> using TensorCore

julia> A = rand(3,3,3);

julia> B = rand(3,3,3);

julia> A ⊡₂ B ≈ reshape(A, 3,:) * reshape(B, :,3)
true

julia> boxdot(A, B, Val(3)) ≈ transpose(vec(A)) * vec(B)
true
```
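For readers without the PR branch, here is a minimal sketch of the same idea in plain Julia: contract the last `N` indices of `A` with the first `N` indices of `B` by reshaping both to matrices and calling `*`, which is exactly what the two checks above compare against. The name `contract_n` is hypothetical, and this is not the PR's implementation.

```julia
# Hypothetical helper, not TensorCore's API: contract the last N indices of A
# with the first N indices of B via reshape + matrix multiplication.
function contract_n(A::AbstractArray, B::AbstractArray, ::Val{N}) where {N}
    0 <= N <= min(ndims(A), ndims(B)) ||
        throw(ArgumentError("cannot contract $N indices"))
    inner = size(A)[end-N+1:end]          # sizes of the contracted indices of A
    inner == size(B)[1:N] ||
        throw(DimensionMismatch("contracted sizes $inner and $(size(B)[1:N]) differ"))
    outer_A = size(A)[1:end-N]            # free indices of A
    outer_B = size(B)[N+1:end]            # free indices of B
    k = prod(inner)
    M = reshape(A, prod(outer_A), k) * reshape(B, k, prod(outer_B))
    dims = (outer_A..., outer_B...)
    # A full contraction (no free indices left) returns a scalar.
    return isempty(dims) ? M[1] : reshape(M, dims)
end

A = rand(3, 4, 5); B = rand(4, 5, 2)
C = contract_n(A, B, Val(2))   # 3×2 result
s = contract_n(A, A, Val(3))   # scalar, equal to sum(A .* A)
```

Using non-cubic sizes, as here, is a quick way to catch index-order mistakes that a 3×3×3 test can hide.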

KeitaNakamura · Oct 25 '24 06:10

At first glance, this looks good!

I'd like to look closely at how adjoint vectors get handled, as that was the tricky case before.

I wonder whether boxdot!(C, A, B) can just infer Val(2) when necessary? It should be known from the types.

mcabbott · Oct 25 '24 12:10

> I'd like to look closely at how adjoint vectors get handled, as that was the tricky case before.

Yes. I kept your single-contraction implementation for handling adjoint vectors, via a `Val{1}` specialization.
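Julia's dispatch makes this kind of special-casing cheap: a method on `::Val{1}` is more specific than one on `::Val{N} where N`, and a method restricted to an adjoint vector is more specific still. The sketch below uses a hypothetical `contract` function (not the PR's code) and assumes the desired behaviour is that `x'` contracted with a vector returns a scalar, as `x' * y` does, rather than the 1-element vector the generic reshape path would produce.

```julia
using LinearAlgebra

# Generic N-index contraction (hypothetical sketch): reshape to matrices,
# multiply, then reshape back to the free indices of A followed by those of B.
function contract(A::AbstractArray, B::AbstractArray, ::Val{N}) where {N}
    outer_A, outer_B = size(A)[1:end-N], size(B)[N+1:end]
    M = reshape(A, prod(outer_A), :) * reshape(B, :, prod(outer_B))
    dims = (outer_A..., outer_B...)
    return isempty(dims) ? M[1] : reshape(M, dims)
end

# Val{1} specialization for an adjoint vector: return a scalar, as `x' * y` does.
# Without it, the generic method above would return a 1-element vector here.
contract(x::Adjoint{<:Any,<:AbstractVector}, y::AbstractVector, ::Val{1}) = x * y

x = rand(4); y = rand(4)
contract(x', y, Val(1))          # scalar, via the specialized method
contract(rand(2, 4), y, Val(1))  # length-2 vector, via the generic method
```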

> I wonder whether boxdot!(C, A, B) can just infer Val(2) when necessary? It should be known from the types.

I'm not sure I fully understand your suggestion, but the current implementation does not check the size of the `C` tensor, so any number of contracted indices `N` works as long as `C` has the correct length. For example:

```julia
julia> boxdot!(similar(A, 81), A, B, Val(1)); # works

julia> boxdot!(similar(A, 9,9), A, B, Val(1)); # works

julia> boxdot!(similar(A, 9), A, B, Val(2)); # works

julia> boxdot!(similar(A, 1), A, B, Val(3)); # works
```
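To see why all four calls succeed, here is a minimal sketch of an in-place contraction that, like the behaviour described above, checks only `length(C)`: it reshapes `C` to the matrix shape of the product and writes into it with `mul!`. The name `contract_len!` is hypothetical; the PR's internals may differ.

```julia
using LinearAlgebra: mul!

# Hypothetical in-place contraction that validates only the element count of C.
function contract_len!(C::AbstractArray, A::AbstractArray, B::AbstractArray, ::Val{N}) where {N}
    size(A)[end-N+1:end] == size(B)[1:N] ||
        throw(DimensionMismatch("contracted index sizes differ"))
    k = prod(size(B)[1:N])
    Am = reshape(A, :, k)        # free indices of A × contracted indices
    Bm = reshape(B, k, :)        # contracted indices × free indices of B
    m, n = size(Am, 1), size(Bm, 2)
    length(C) == m * n ||
        throw(DimensionMismatch("C has length $(length(C)), expected $(m * n)"))
    mul!(reshape(C, m, n), Am, Bm)   # any shape of C with m*n elements is accepted
    return C
end

A = rand(3, 3, 3); B = rand(3, 3, 3)
contract_len!(similar(A, 81), A, B, Val(1))    # accepted
contract_len!(similar(A, 9, 9), A, B, Val(1))  # also accepted: same length, different shape
contract_len!(similar(A, 9), A, B, Val(2))     # accepted
contract_len!(similar(A, 1), A, B, Val(3))     # accepted
```

This permissiveness is what makes a wrong or forgotten `Val(N)` hard to detect: only a length mismatch is caught.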

Are you suggesting that we check the length of `C` to select the appropriate contraction automatically? Or should we check the size of `C` (or perhaps just its `ndims`) and select the contraction accordingly? In both cases, I'm concerned that if someone mistakenly passes an incorrect `C` tensor, or simply forgets to pass `Val(N)`, the function might not throw an error.
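A sketch of the `ndims` route: since `ndims(C) == ndims(A) + ndims(B) - 2N`, the number of contracted indices can be solved for from the arrays' type parameters, so the compiler can usually fold it to a constant. `inferred_contraction` is a hypothetical name. The example also shows the limitation raised here: a 9×9 `C` maps to `N = 2` even though its 81 elements do not match a double contraction of these arrays, so a length or size check would still be needed.

```julia
# Hypothetical helper: infer the number of contracted indices N from the
# dimensionalities (type parameters) using ndims(C) == ndims(A) + ndims(B) - 2N.
function inferred_contraction(::AbstractArray{<:Any,NC}, ::AbstractArray{<:Any,NA},
                              ::AbstractArray{<:Any,NB}) where {NC,NA,NB}
    twoN = NA + NB - NC
    (iseven(twoN) && 0 <= twoN ÷ 2 <= min(NA, NB)) ||
        throw(ArgumentError("no contraction of $NA-dim and $NB-dim arrays has a $NC-dim result"))
    return Val(twoN ÷ 2)
end

A = rand(3, 3, 3); B = rand(3, 3, 3)
inferred_contraction(zeros(3, 3), A, B)   # Val{2}()
inferred_contraction(zeros(9, 9), A, B)   # also Val{2}(): ndims alone cannot see the shape
```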

(Edit) Given the examples above, I think we should at least check `ndims(C)`. I'd actually prefer to check the size of `C` strictly, though.
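The stricter option could look like the following sketch: compute the exact expected shape, `(size(A)[1:end-N]..., size(B)[N+1:end]...)`, and throw a `DimensionMismatch` otherwise. `check_output_size` is a hypothetical name. One design consequence: under this rule a full contraction expects a 0-dimensional `C` (such as `fill(0.0)`), so `similar(A, 1)` with `Val(3)` would be rejected unless that case is special-cased.

```julia
# Hypothetical strict check: C must have exactly the shape of A's free indices
# followed by B's free indices.
function check_output_size(C::AbstractArray, A::AbstractArray, B::AbstractArray, ::Val{N}) where {N}
    expected = (size(A)[1:end-N]..., size(B)[N+1:end]...)
    size(C) == expected ||
        throw(DimensionMismatch("expected C of size $expected for Val($N), got $(size(C))"))
    return nothing
end

A = rand(3, 3, 3); B = rand(3, 3, 3)
check_output_size(zeros(3, 3), A, B, Val(2))        # passes
check_output_size(zeros(3, 3, 3, 3), A, B, Val(1))  # passes
# check_output_size(zeros(9, 9), A, B, Val(1))      # would throw DimensionMismatch
```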

KeitaNakamura · Oct 27 '24 01:10