NNlib.jl
add topk
Initial steps to fix #352
TODO
- [ ] give more thought to the interface (which outputs do we expect?)
- [ ] add rrule or write it in an AD friendly way
- [ ] gpu support
I think the obvious way to make this AD-able would be to just rely on the gradient for `getindex`, which will store the indices for the backwards pass. Maybe it's as simple as this:

```julia
topk(x::AbstractArray, k::Integer; kw...) = x[topkperm(x, k; kw...)]

topkperm(x::AbstractVector, k::Integer; dims::Integer=1, rev=true, kw...) =
    (@assert dims == 1; partialsortperm(x, 1:k; rev=rev, kw...))

topkperm(x::AbstractArray, k::Integer; dims::Integer=1, rev=true, kw...) =
    mapslices(y -> topkperm(y, k; rev=rev, kw...), x; dims=dims)

@nograd topkperm
```
Unlike the PR this doesn't return the permutation, but do you need it for anything else? It also defaults to the first dimension, and won't accept `dims=:`.
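As for the rrule TODO: if relying on the `getindex` gradient turns out to be too indirect, an explicit rule might look something like this. A rough sketch on my part, assuming the `topk` / `topkperm` split above; not part of the PR:

```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(topk), x::AbstractArray, k::Integer; kw...)
    p = topkperm(x, k; kw...)
    y = x[p]
    function topk_pullback(dy)
        dx = zero(x)
        dx[p] = unthunk(dy)  # scatter the cotangent back to the selected positions
        return (NoTangent(), dx, NoTangent())
    end
    return y, topk_pullback
end
```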
But to make this GPU-friendly... There should be no aliasing issues with the gradient. There is `partialsort!(::CuVector, ...)` but no `sortperm` and no `mapslices`. There is a `sort!(::CuMatrix; dims)`. Both of these call the same `quicksort!`, so perhaps something can be built on that.
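For concreteness, this is roughly what does and doesn't exist, per the above (a sketch, not verified against any particular CUDA.jl version):

```julia
using CUDA

v = cu(rand(Float32, 1024))
partialsort!(v, 1:5; rev=true)  # partial sort of a CuVector: available

M = cu(rand(Float32, 64, 64))
sort!(M; dims=1)                # column-wise sort of a CuMatrix: available

# sortperm(v)                   # no CuArray method
# mapslices(maximum, M; dims=1) # no CuArray method
```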
For GPU compat, TF's XLA implementation should be using all out-of-place ops: https://github.com/tensorflow/tensorflow/blob/8d72537c6abf5a44103b57b9c2e22c14f5f49698/tensorflow/core/tpu/kernels/topk_ops.cc. Of course we can't rely on any of the optimizations in XLA, so a more ideal implementation would probably look like the PyTorch one here: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/TensorTopK.cu.
What I hoped might work, re-using what CUDA.jl already has, doesn't seem to -- can this easily be fixed?
```julia
julia> sort(tuple.(cu(rand(10)), 1:10), by=first)
ERROR: InvalidIRError: compiling kernel qsort_kernel(CuDeviceVector{Tuple{Float32, Int64}, 1}, Int64, Int64, Bool, Val{true}, Int64, Nothing, typeof(isless), typeof(first), Val{1}) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to zero)
Stacktrace:
 [1] bitonic_median
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:217
 [2] qsort_kernel
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:405
 [3] qsort_kernel
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:378
Reason: unsupported dynamic function invocation (call to zero)
Stacktrace:
 [1] bubble_sort
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:274
 [2] qsort_kernel
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:401
 [3] qsort_kernel
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:378
Stacktrace:
 [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{typeof(CUDA.Quicksort.qsort_kernel), Tuple{CuDeviceVector{Tuple{Float32, Int64}, 1}, Int64, Int64, Bool, Val{true}, Int64, Nothing, typeof(isless), typeof(first), Val{1}}}}, args::LLVM.Module)
   @ GPUCompiler ~/.julia/packages/GPUCompiler/fG3xK/src/validation.jl:111
 [2] macro expansion
   @ ~/.julia/packages/GPUCompiler/fG3xK/src/driver.jl:319 [inlined]
```
OK, this seems to work:
```julia
function topkperm(x::CuArray, k::Integer; dims::Integer=1, rev=true, lt=isless, by=identity)
    # pair each element with its index along `dims`, so the sort carries indices along
    tups = tuple.(x, reshape(axes(x, dims), fill(1, dims-1)..., :))
    CUDA.quicksort!(tups; lt=(rev ? !lt : lt), by=by∘first, dims=dims, partial_k=1:k)
    # keep the first k entries along `dims`, then rebuild full CartesianIndexes
    tv = view(tups, ntuple(d -> d==dims ? (1:k) : (:), ndims(x))...)
    broadcast(tv, CartesianIndices(ntuple(d -> d==dims ? Base.OneTo(1) : axes(x, d), ndims(x)))) do (_, i), J
        CartesianIndex(ntuple(d -> d==dims ? i : J[d], ndims(x)))
    end
end
```
```julia
# piracy to make e.g. this work: sort(tuple.(cu(rand(10)), 1:10), by=first)
@inline Base.zero(::Type{T}) where {T<:Tuple{Vararg{Any,N}}} where {N} = ntuple(i -> zero(T.parameters[i]), N)
@inline Base.one(::Type{T}) where {T<:Tuple{Vararg{Any,N}}} where {N} = ntuple(i -> one(T.parameters[i]), N)
```
```julia
julia> x = 100randn(5,6)
5×6 Matrix{Float64}:
 -123.513   -106.66     25.3997   41.3127   105.062    -20.8767
  161.838     49.7304  -44.2289  -44.0282  -227.478    -62.3863
  -99.9103    90.3985 -203.78     22.0575   -14.5563  -242.797
   50.9009   120.479  -213.53     53.8734   -33.2207  -118.205
   50.3431    54.3659   49.8969  204.863   -103.487    -41.3977

julia> topk(cu(x), 2)
2×6 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 161.838  120.479   49.8969  204.863  105.062   -20.8767
  50.9009  90.3985  25.3997   53.8734  -14.5563  -41.3977
```
My CPU version above is not correct, since it uses indices along one slice as linear indices into the whole array. It would also need to construct `CartesianIndex`es everywhere. After which it's not just one line... is there a tidier way?
```julia
function topkperm(x::AbstractArray, k::Integer; dims::Integer=1, rev=true, kw...)
    # output has size k along `dims`, matching x elsewhere
    out = similar(CartesianIndices(ntuple(d -> d==dims ? Base.OneTo(k) : axes(x, d), ndims(x))))
    iters = ntuple(d -> d==dims ? (Colon(),) : axes(x, d), ndims(x))
    for J in Iterators.product(iters...)  # iterate over all 1D slices along `dims`
        p = partialsortperm(view(x, J...), 1:k; rev=rev, kw...)
        for i in 1:k
            I = ntuple(d -> d==dims ? i : J[d], ndims(x))
            PI = ntuple(d -> d==dims ? p[i] : J[d], ndims(x))
            out[I...] = CartesianIndex(PI)  # index into x, not into the slice
        end
    end
    out
end
```
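A quick sanity check of that version (my example, not from the thread):

```julia
x = [3 1; 4 1; 5 9]
p = topkperm(x, 2)           # 2×2 array of CartesianIndex, top-2 along dims=1
@assert x[p] == [5 9; 4 1]   # largest two entries of each column, in order
```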
@mcabbott thanks for all these suggestions, feel free to push on this branch.
I like the `topk` / `topkperm` decoupling, so that it mimics Base's sort functions.
Representing the (partial) permutation as an array of `CartesianIndex` is redundant, since it would be enough to return an array of integers such that

```julia
topkout[i1, ..., iN, k] = x[i1, ..., iN, perm[i1, ..., iN, k]]
```
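For what it's worth, the extra coordinates carry no information: the integer permutation can be recovered from the Cartesian one, e.g. (my sketch, `cartperm` being the hypothetical `CartesianIndex` array):

```julia
intperm = getindex.(cartperm, dims)  # strip all but the `dims` coordinate
```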
We could return a custom type for `perm` storing the integer permutation and `dims`. `getindex` could be overloaded so that we can do `x[perm]`. But maybe this adds complication that turns out to be not so useful in the end, so I would be totally fine with returning Cartesian indexes and revisiting later if needed.
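Something like this, perhaps (entirely hypothetical, the type and its name are my invention):

```julia
# Sketch of the custom type floated above: integer indices along `dims`,
# plus a getindex overload so that x[perm] gathers the top-k values.
struct TopKPerm{T<:AbstractArray{<:Integer}}
    inds::T    # size k along `dims`, matching x elsewhere
    dims::Int
end

function Base.getindex(x::AbstractArray, p::TopKPerm)
    map(CartesianIndices(p.inds)) do J
        x[ntuple(d -> d == p.dims ? p.inds[J] : J[d], ndims(x))...]
    end
end
```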
Yes, I don't like the redundancy of returning `CartesianIndex`es, but I do like the simplicity of `getindex`. Maybe returning an array of linear indices instead would be nicer?
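If so, the conversion is cheap, something like this (my sketch, reusing `topkperm` from above):

```julia
p = topkperm(x, k)          # array of CartesianIndex
lin = LinearIndices(x)[p]   # same shape, plain Int entries
x[lin] == x[p]              # gathers the same values
```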