NNlib.jl

Batched dot(x, A, y)

Open 3f6a opened this issue 5 months ago • 4 comments

Motivation and description

Say I have the arrays x[i,b], y[j,b] and A[i,j,b]. Is there an efficient way to do the following "batched dot" operation:

[sum(x[i,b] * A[i,j,b] * y[j,b] for i = axes(A,1) for j = axes(A,2)) for b = ...]

where b traverses the batch dimension. As usual, we could have size(x,2) == 1, size(A,3) == 1, ..., in which case the corresponding singleton batch dimension is broadcast.
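To make the intended semantics concrete, here is a minimal reference sketch in plain Julia (Base only). The name `batched_dot_ref` is hypothetical and is not an NNlib function. It assumes x is (I, B) or (I, 1), y is (J, B) or (J, 1), and A is (I, J, B) or (I, J, 1), and it materializes an I×J×B temporary, so it pins down the expected result rather than being the efficient version being asked for.

```julia
# Hypothetical reference implementation (not part of NNlib).
# Computes out[b] = sum_{i,j} x[i,b] * A[i,j,b] * y[j,b], with singleton
# batch dimensions broadcast against the others.
function batched_dot_ref(x::AbstractVecOrMat, A::AbstractArray{<:Any,3}, y::AbstractVecOrMat)
    # Line the arrays up as (I, 1, Bx), (I, J, BA) and (1, J, By)
    xr = reshape(x, size(x, 1), 1, size(x, 2))
    yr = reshape(y, 1, size(y, 1), size(y, 2))
    # Broadcasting expands any batch dimension of size 1;
    # summing over dims 1 and 2 leaves one value per batch element
    return vec(sum(xr .* A .* yr; dims=(1, 2)))
end

# Usage: full batch, and a case where x has a single broadcast column
x = rand(3, 4); A = rand(3, 5, 4); y = rand(5, 4)
out = batched_dot_ref(x, A, y)               # length-4 vector
out1 = batched_dot_ref(x[:, 1:1], A, y)      # x broadcast over the batch
```

An allocation-free version would fuse the product and the reduction instead of building the intermediate array.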

Apologies if there is already a way to do this efficiently with existing functions in NNlib; I could not figure it out.

Possible Implementation

No response

3f6a, Jul 12 '25 18:07