NNlib.jl
`batched_transpose` with multiple batch dimensions
Motivation and description
There exists a method for batched_mul that reshapes arrays to allow for an arbitrary number of batch dimensions:
function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    x2 = reshape(x, size(x, 1), size(x, 2), :)
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end
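As a sanity check, the reshape trick can be verified against an explicit per-batch multiplication. To keep the snippet self-contained, a minimal loop-based stand-in (here called _bmm3, not part of NNlib) plays the role of the 3-D batched_mul:

```julia
# Minimal stand-in for NNlib.batched_mul on 3-D arrays: loop over the batch dim.
function _bmm3(x::AbstractArray{<:Any,3}, y::AbstractArray{<:Any,3})
    z = similar(x, size(x, 1), size(y, 2), size(x, 3))
    for k in axes(x, 3)
        z[:, :, k] = @view(x[:, :, k]) * @view(y[:, :, k])
    end
    return z
end

# The reshape trick applied to arrays with two batch dimensions.
x = rand(2, 3, 4, 5)
y = rand(3, 6, 4, 5)
batch_size = size(x)[3:end]
z = reshape(_bmm3(reshape(x, 2, 3, :), reshape(y, 3, 6, :)), 2, 6, batch_size...)

# Each batch slice matches an ordinary matrix product.
@assert z[:, :, 2, 3] ≈ x[:, :, 2, 3] * y[:, :, 2, 3]
```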
It would be useful to have support for this with batched_transpose and batched_adjoint as well.
Possible Implementation
The existing code is quite sophisticated and "lazy", so something eager like this wouldn't fly:
batched_transpose(A::AbstractArray{T, N}) where {T <: Real, N} = permutedims(A, (2, 1, 3:N...))
I imagine it would be possible to generalize the code beyond three dimensions though. Indexing methods are currently hard-coded. Things like the strides would also need to be generalized:
function Base.strides(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}})
    sp = strides(A.parent)
    (sp[2], sp[1], sp[3:end]...)
end
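As a sanity check on that stride permutation, it matches what Base already computes for a PermutedDimsArray swapping the first two dimensions (a Base-only sketch):

```julia
A = rand(4, 5, 6, 7)
sp = strides(A)  # contiguous Array: (1, 4, 20, 120)

# Swapping dims 1 and 2 should swap the first two strides and keep the rest.
@assert strides(PermutedDimsArray(A, (2, 1, 3, 4))) == (sp[2], sp[1], sp[3:end]...)
```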
Is it better to just use PermutedDimsArray?
After some thinking and tinkering, I've concluded that PermutedDimsArray works fine.
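For instance, it gives a lazy batched transpose over any number of trailing batch dimensions (a minimal Base-only check):

```julia
A = rand(2, 3, 4, 5)
At = PermutedDimsArray(A, (2, 1, 3, 4))  # lazy: no copy is made

# First two dims are swapped; batch dims are untouched.
@assert size(At) == (3, 2, 4, 5)
@assert At[3, 1, 2, 4] == A[1, 3, 2, 4]
```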
For my use case, however, where I use it to define a custom chain rule, I needed to use the inner constructor with all the type parameters spelled out, like so:
# the permutation needs to be passed as a type parameter directly so the type can be inferred
function _batched_transpose(A::AbstractArray{T, N}) where {T, N}
    perm = (2, 1, 3:N...)
    PermutedDimsArray{T, N, perm, perm, typeof(A)}(A)
end
or else I would get an error:
function _batched_transpose(A::AbstractArray{T, N}) where {T, N}
    perm = (2, 1, 3:N...)
    PermutedDimsArray(A, perm)
end
using Test
@inferred _batched_transpose(rand(4, 5, 6))
# output:
ERROR: return type PermutedDimsArray{Float64, 3, (2, 1, 3), (2, 1, 3), Array{Float64, 3}} does not match inferred return type PermutedDimsArray{Float64, 3, _A, _B, Array{Float64, 3}} where {_A, _B}
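For completeness, the inner-constructor version does pass the same check; here is a standalone sketch restating it (note that (2, 1, 3, ...) is its own inverse, so the same tuple serves as both perm and iperm):

```julia
using Test

# perm is derived from the type parameter N, so inference can treat it as a constant
function _batched_transpose(A::AbstractArray{T, N}) where {T, N}
    perm = (2, 1, 3:N...)
    PermutedDimsArray{T, N, perm, perm, typeof(A)}(A)
end

A = rand(4, 5, 6)
B = @inferred _batched_transpose(A)  # no inference error with the inner constructor
@assert B[1, 2, 3] == A[2, 1, 3]
```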
I suspect this is because of the splat in the regular constructor:
function PermutedDimsArray(data::AbstractArray{T,N}, perm) where {T,N}
    length(perm) == N || throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
    iperm = invperm(perm)
    PermutedDimsArray{T,N,(perm...,),(iperm...,),typeof(data)}(data)
end
This isn't really related to the issue, but I figured I'd include it for documentation purposes. 😄
EDIT: it's probably not the splatting itself, but the fact that the permutation is derived from the type parameter N, so it's essentially a constant.
EDIT 2: somewhat expectedly, CUDA doesn't like this, as it ends up wanting to do scalar indexing.
When an array with multiple batch dimensions needs to be transposed for use in batched_mul, I found this to work alright:
function batched_mul_transpose1(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    x2 = reshape(x, size(x, 1), size(x, 2), :) |> batched_transpose # call batched_transpose after flattening batch dimensions
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end
This would be the same as batched_mul(batched_transpose(x), y).
It's tricky. Perhaps there need to be methods of batched_mul accepting these >3-dimensional BatchedAdjoint types, so that the reshape affects the wrapped Array (or CuArray) rather than composing another wrapper (which CUDA doesn't like, as you saw).
Or perhaps the reshaping to 3D should be done by a utility function which knows about BatchedAdjoint, not just reshape.
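One shape such a utility could take, sketched with a minimal local wrapper standing in for NNlib's lazy BatchedTranspose so the snippet is self-contained (MyBatchedTranspose and _flatten_batch are hypothetical names, not NNlib API):

```julia
# Minimal stand-in for NNlib's lazy transpose wrapper.
struct MyBatchedTranspose{T, A <: AbstractArray{T}}
    parent::A
end

# Flatten trailing batch dims to one. For a wrapped array, reshape the *wrapped*
# parent and re-apply the wrapper, instead of stacking a reshape on top of it.
_flatten_batch(A::AbstractArray) = reshape(A, size(A, 1), size(A, 2), :)
_flatten_batch(A::MyBatchedTranspose) = MyBatchedTranspose(_flatten_batch(A.parent))

x = rand(2, 3, 4, 5)
flat = _flatten_batch(MyBatchedTranspose(x))

# The wrapper survives, and the parent is now a plain reshaped 3-D array.
@assert size(flat.parent) == (2, 3, 20)
```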
Xref https://github.com/FluxML/NNlib.jl/issues/391 for other questions about batched_mul accepting >3 dimensions.
(Also, some regret that we didn't go with an interface like batched_mul(A, adjoint, B), instead of array wrappers!)