
`batched_transpose` with multiple batch dimensions

Open AntonOresten opened this issue 1 year ago • 3 comments

Motivation and description

There exists a method for batched_mul that reshapes arrays to allow for an arbitrary number of batch dimensions:

function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    x2 = reshape(x, size(x, 1), size(x, 2), :)
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end
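
A quick sanity check of that method (assuming it is in scope; the sizes here are arbitrary and just for illustration):

using NNlib
x = rand(3, 4, 2, 5)
y = rand(4, 6, 2, 5)
z = batched_mul(x, y)                            # size (3, 6, 2, 5)
z[:, :, 1, 1] ≈ x[:, :, 1, 1] * y[:, :, 1, 1]    # true: each batch slice is an ordinary matmul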

It would be useful to have support for this with batched_transpose and batched_adjoint as well.

Possible Implementation

The existing code is quite sophisticated and "lazy", so something like this wouldn't fly:

batched_transpose(A::AbstractArray{T, N}) where {T <: Real, N} = permutedims(A, (2, 1, 3:N...))

I imagine it would be possible to generalize the code beyond three dimensions though. Indexing methods are currently hard-coded. Things like the strides would also need to be generalized:

function Base.strides(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}})
    sp = strides(A.parent)
    (sp[2], sp[1], sp[3:end]...)
end
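
Indexing would need the same treatment. A rough sketch of what it might look like, assuming the wrappers were generalized to hold N-dimensional parents (today they are strictly 3-dimensional):

# sketch only: assumes a hypothetical N-dimensional BatchedTranspose
Base.getindex(A::BatchedTranspose, i::Int, j::Int, ks::Int...) = A.parent[j, i, ks...]
Base.setindex!(A::BatchedTranspose, v, i::Int, j::Int, ks::Int...) = setindex!(A.parent, v, j, i, ks...)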

Is it better to just use PermutedDimsArray?

AntonOresten · May 27 '24 14:05

After some thinking and tinkering, I've concluded that PermutedDimsArray works fine.

For my use case, however, where I use it to define a custom chain rule, I needed to use the inner constructor with all the type parameters, like so:

# permutation needs to be passed as type parameters directly so the type can be inferred
function _batched_transpose(A::AbstractArray{T, N}) where {T, N}
    perm = (2, 1, 3:N...)
    PermutedDimsArray{T, N, perm, perm, typeof(A)}(A)
end
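
With this version the return type is concrete, so an @inferred check passes and the result has the transposed shape:

using Test
B = @inferred _batched_transpose(rand(4, 5, 6))
size(B) == (5, 4, 6)   # true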

or else I would get an error:

function _batched_transpose(A::AbstractArray{T, N}) where {T, N}
    perm = (2, 1, 3:N...)
    PermutedDimsArray(A, perm)
end

using Test
@inferred _batched_transpose(rand(4, 5, 6))

# output:
ERROR: return type PermutedDimsArray{Float64, 3, (2, 1, 3), (2, 1, 3), Array{Float64, 3}} does not match inferred return type PermutedDimsArray{Float64, 3, _A, _B, Array{Float64, 3}} where {_A, _B}

I suspect this is because of the splat in the regular constructor:

function PermutedDimsArray(data::AbstractArray{T,N}, perm) where {T,N}
    length(perm) == N || throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
    iperm = invperm(perm)
    PermutedDimsArray{T,N,(perm...,),(iperm...,),typeof(data)}(data)
end

This isn't really related to the issue, but I figured I'd include it for documentation purposes.😄

EDIT: it's probably not the splatting itself; rather, the inner-constructor version infers because the permutation is derived from the type parameter N, so it's essentially a constant.

EDIT 2: somewhat expectedly, CUDA doesn't like this, as it ends up wanting to do scalar indexing.

AntonOresten · May 28 '24 23:05

When an array with multiple batch dimensions needs to be transposed for use in batched_mul, I found this to work alright:

function batched_mul_transpose1(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    x2 = reshape(x, size(x, 1), size(x, 2), :) |> batched_transpose # call batched_transpose after flattening batch dimensions
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end
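
With that definition in scope, a quick check that each batch slice matches an ordinary transposed multiplication (sizes are arbitrary):

x = rand(3, 4, 2, 5)
y = rand(3, 6, 2, 5)
z = batched_mul_transpose1(x, y)                            # size (4, 6, 2, 5)
z[:, :, 1, 1] ≈ transpose(x[:, :, 1, 1]) * y[:, :, 1, 1]    # true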

This would be the same as batched_mul(batched_transpose(x), y).

AntonOresten · Jun 07 '24 15:06

It's tricky. Perhaps there need to be methods of batched_mul accepting these >3-dimensional BatchedAdjoint types, so that the reshape affects the wrapped Array (or CuArray) rather than composing another wrapper (which CUDA doesn't like, as you saw).

Or perhaps the reshaping to 3D should be done by a utility function which knows about BatchedAdjoint, not just reshape.
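
Something along these lines, as a very rough sketch (the name _reshape_batch is made up, and it assumes the hypothetical N-dimensional wrappers discussed above, not the current strictly 3-dimensional ones):

using NNlib: BatchedTranspose, BatchedAdjoint, batched_transpose, batched_adjoint

# flatten the trailing batch dims of the wrapped array, then re-wrap, so the
# reshape hits the plain Array/CuArray instead of stacking another wrapper on top
_reshape_batch(x::AbstractArray) = reshape(x, size(x, 1), size(x, 2), :)
_reshape_batch(x::BatchedTranspose) = batched_transpose(_reshape_batch(x.parent))
_reshape_batch(x::BatchedAdjoint)   = batched_adjoint(_reshape_batch(x.parent))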

Xref https://github.com/FluxML/NNlib.jl/issues/391 for other open questions about batched_mul accepting >3 dimensions.

(Also, some regret that we didn't go with an interface like batched_mul(A, adjoint, B), instead of array wrappers!)

mcabbott · Jun 11 '24 18:06