NNlib.jl

add topk

Open CarloLucibello opened this issue 3 years ago • 6 comments

Initial steps to fix #352

TODO

  • [ ] give more thought to the interface (which outputs do we expect?)
  • [ ] add rrule or write it in an AD friendly way
  • [ ] gpu support

CarloLucibello • Sep 30 '21 08:09

I think the obvious way to make this AD-able would be to just rely on the gradient of getindex, which stores the indices for the backward pass. Maybe it's as simple as this:

using Zygote: @nograd   # Zygote's macro for marking a function as non-differentiable

topk(x::AbstractArray, k::Integer; kw...) = x[topkperm(x, k; kw...)]

topkperm(x::AbstractVector, k::Integer; dims::Integer=1, rev=true, kw...) = (@assert dims==1; partialsortperm(x, 1:k; rev=rev, kw...))
topkperm(x::AbstractArray, k::Integer; dims::Integer=1, rev=true, kw...) = mapslices(y -> topkperm(y, k; rev=rev, kw...), x; dims=dims)
@nograd topkperm

Unlike the PR, this doesn't return the permutation; do you need it for anything else? It also defaults to the first dimension and doesn't accept dims=:.
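A minimal CPU sketch of why the getindex gradient is enough: the pullback of y = x[idx] only needs the stored idx to scatter dy back into a zero array. getindex_pullback is a hypothetical helper written for illustration, not the actual Zygote or ChainRules rule.

```julia
# Hypothetical helper (illustration only, not Zygote's or ChainRules' code):
# the pullback of y = x[idx] scatters dy back into an array of zeros.
function getindex_pullback(x::AbstractArray, idx::AbstractArray, dy::AbstractArray)
    dx = zero(x)                  # gradient w.r.t. every element of x starts at 0
    for (i, g) in zip(idx, dy)    # idx and dy have the same shape
        dx[i] += g                # accumulate, so repeated indices add up
    end
    return dx
end

x = [0.5, 3.0, -1.0, 2.0]
idx = partialsortperm(x, 1:2; rev=true)   # positions of the 2 largest values (2 and 4)
y = x[idx]                                # forward pass keeps only the top 2
dx = getindex_pullback(x, idx, ones(2))   # backward pass needs nothing but idx
```

Since idx is treated as a constant, marking the permutation function as non-differentiable loses nothing.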

Making this GPU-friendly is harder. The gradient should have no aliasing issues, but CUDA.jl has partialsort!(::CuVector, ...) with no sortperm and no mapslices. There is also sort!(::CuMatrix; dims). Both of these call the same quicksort!, so perhaps something can be built on that.
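Without a sortperm, one workaround is to sort (value, index) pairs by value and read off the indices. The sketch shows the idea on the CPU using only Base's sort; topk_positions is a hypothetical name, and whether CUDA.jl's sort kernels accept tuple elements is an assumption that still needs checking.

```julia
# Hypothetical helper: emulate partialsortperm(v, 1:k; rev=true) using only sort,
# by sorting (value, position) pairs on the value.
function topk_positions(v::AbstractVector, k::Integer)
    pairs = tuple.(v, eachindex(v))            # pair each value with its position
    sorted = sort(pairs; by=first, rev=true)   # compare values only, descending
    return last.(sorted[1:k])                  # keep the positions of the top k
end

v = [0.3, 0.9, 0.1, 0.7]
topk_positions(v, 2)
```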

mcabbott • Oct 10 '21 21:10

For GPU compatibility, TF's XLA implementation should be using only out-of-place ops: https://github.com/tensorflow/tensorflow/blob/8d72537c6abf5a44103b57b9c2e22c14f5f49698/tensorflow/core/tpu/kernels/topk_ops.cc. Of course we can't rely on any of XLA's optimizations, so a better implementation would probably look like the PyTorch one: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/TensorTopK.cu.
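As I understand it (an assumption, not checked line by line against the linked TensorTopK.cu), the PyTorch kernel works in two passes: it first selects the k-th largest value, then gathers elements relative to that threshold. A CPU sketch of that select-then-gather structure, with partialsort standing in for the GPU selection step; topk_select_gather is a hypothetical name.

```julia
# Hypothetical select-then-gather top-k on a vector (CPU sketch only).
function topk_select_gather(v::AbstractVector, k::Integer)
    t = partialsort(v, k; rev=true)    # pass 1: the k-th largest value
    above = findall(>(t), v)           # strictly larger than t: always in the top k
    ties = findall(==(t), v)           # equal to t: take only as many as still needed
    idx = vcat(above, ties[1:(k - length(above))])
    return v[idx], idx                 # pass 2: gather values and their positions
end

vals, idx = topk_select_gather([5, 1, 5, 3, 9], 3)
```

The result follows index order rather than value order; a final sort over just k elements would fix that if sorted output is required.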

ToucheSir • Oct 10 '21 21:10

What I hoped might work, reusing what CUDA.jl already has, doesn't seem to. Can this be fixed easily?

julia> sort(tuple.(cu(rand(10)), 1:10), by=first)
ERROR: InvalidIRError: compiling kernel qsort_kernel(CuDeviceVector{Tuple{Float32, Int64}, 1}, Int64, Int64, Bool, Val{true}, Int64, Nothing, typeof(isless), typeof(first), Val{1}) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to zero)
Stacktrace:
 [1] bitonic_median
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:217
 [2] qsort_kernel
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:405
 [3] qsort_kernel
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:378
Reason: unsupported dynamic function invocation (call to zero)
Stacktrace:
 [1] bubble_sort
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:274
 [2] qsort_kernel
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:401
 [3] qsort_kernel
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:378
Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{typeof(CUDA.Quicksort.qsort_kernel), Tuple{CuDeviceVector{Tuple{Float32, Int64}, 1}, Int64, Int64, Bool, Val{true}, Int64, Nothing, typeof(isless), typeof(first), Val{1}}}}, args::LLVM.Module)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/fG3xK/src/validation.jl:111
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/fG3xK/src/driver.jl:319 [inlined]

mcabbott • Oct 10 '21 22:10

OK, this seems to work:

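using CUDA

function topkperm(x::CuArray, k::Integer; dims::Integer=1, rev=true, lt=isless, by=identity)
    # pair every value with its position along `dims`
    tups = tuple.(x, reshape(axes(x,dims), fill(1, dims-1)..., :))
    # partial sort along `dims`; for rev, swap lt's arguments instead of negating it,
    # since !isless is not a strict weak order
    CUDA.quicksort!(tups; lt=(rev ? (a, b) -> lt(b, a) : lt), by=by∘first, dims=dims, partial_k=1:k)
    # keep the first k entries along `dims`
    tv = view(tups, ntuple(d -> d==dims ? (1:k) : (:), ndims(x))...)
    # combine each kept position with the coordinates of its slice
    broadcast(tv, CartesianIndices(ntuple(d -> d==dims ? Base.OneTo(1) : axes(x,d), ndims(x)))) do (_,i), J
        CartesianIndex(ntuple(d -> d==dims ? i : J[d], ndims(x)))
    end
end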

# piracy to make e.g. this work:  sort(tuple.(cu(rand(10)), 1:10), by=first)
@inline Base.zero(::Type{T}) where {T<:Tuple{Vararg{Any,N}}} where {N} = ntuple(i -> zero(T.parameters[i]), N)
@inline Base.one(::Type{T}) where {T<:Tuple{Vararg{Any,N}}} where {N} = ntuple(i -> one(T.parameters[i]), N)

julia> x = 100randn(5,6)
5×6 Matrix{Float64}:
 -123.513   -106.66      25.3997   41.3127   105.062    -20.8767
  161.838     49.7304   -44.2289  -44.0282  -227.478    -62.3863
  -99.9103    90.3985  -203.78     22.0575   -14.5563  -242.797
   50.9009   120.479   -213.53     53.8734   -33.2207  -118.205
   50.3431    54.3659    49.8969  204.863   -103.487    -41.3977

julia> topk(cu(x), 2)
2×6 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 161.838   120.479   49.8969  204.863   105.062   -20.8767
  50.9009   90.3985  25.3997   53.8734  -14.5563  -41.3977
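The same (value, position) trick can be checked on the CPU, where sort(A; dims, by, rev) already sorts along one dimension. This sketch mirrors the GPU version above but does a full rather than partial sort; topkperm_sorted is a hypothetical name.

```julia
# Hypothetical CPU analogue of the tuple-based GPU topkperm (full sort, not partial).
function topkperm_sorted(x::AbstractArray, k::Integer; dims::Integer=1, rev::Bool=true)
    N = ndims(x)
    pos = reshape(axes(x, dims), fill(1, dims - 1)..., :)  # positions along `dims`
    tups = tuple.(x, pos)                                  # (value, position) pairs
    sorted = sort(tups; dims=dims, by=first, rev=rev)      # sort each slice by value
    tv = view(sorted, ntuple(d -> d == dims ? (1:k) : (:), N)...)
    # one CartesianIndex per slice, with a placeholder 1 along `dims`
    J = CartesianIndices(ntuple(d -> d == dims ? Base.OneTo(1) : axes(x, d), N))
    return broadcast(tv, J) do t, Jc
        CartesianIndex(ntuple(d -> d == dims ? t[2] : Jc[d], N))
    end
end

x = [3 1; 1 4; 2 5]
x[topkperm_sorted(x, 2)]
```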

My CPU version above is not correct: it uses indices within one slice as linear indices into the whole array. It would also need to build CartesianIndices everywhere, after which it's no longer a one-liner. Is there a tidier way?

function topkperm(x::AbstractArray, k::Integer; dims::Integer=1, rev=true, kw...)
    # output has size k along `dims` and matches x elsewhere
    out = similar(CartesianIndices(ntuple(d -> d==dims ? Base.OneTo(k) : axes(x,d), ndims(x))))
    # iterate over all slices: Colon along `dims`, every index elsewhere
    iters = ntuple(d -> d==dims ? (Colon(),) : axes(x,d), ndims(x))
    for J in Iterators.product(iters...)
        # positions of the top k within this slice
        p = partialsortperm(view(x, J...), 1:k; rev=rev, kw...)
        for i in 1:k
            I = ntuple(d -> d==dims ? i : J[d], ndims(x))       # where to store in out
            PI = ntuple(d -> d==dims ? p[i] : J[d], ndims(x))   # where the value lives in x
            out[I...] = CartesianIndex(PI)
        end
    end
    out
end
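One possibly tidier CPU route, sketched under the assumption that a mapslices pass is acceptable: compute integer positions per slice with mapslices, then broadcast them against a CartesianIndices array of length 1 along dims, the same trick the GPU version uses. topkperm_slices is a hypothetical name.

```julia
# Hypothetical tidier CPU topkperm: per-slice positions, then one broadcast.
function topkperm_slices(x::AbstractArray, k::Integer; dims::Integer=1, rev::Bool=true)
    N = ndims(x)
    # integer positions within each slice; the result has size k along `dims`
    p = mapslices(v -> collect(partialsortperm(v, 1:k; rev=rev)), x; dims=dims)
    # one CartesianIndex per slice, with a placeholder 1 along `dims`
    J = CartesianIndices(ntuple(d -> d == dims ? Base.OneTo(1) : axes(x, d), N))
    # J broadcasts along `dims`; swap each stored position into its coordinates
    return broadcast((i, Jc) -> CartesianIndex(ntuple(d -> d == dims ? i : Jc[d], N)), p, J)
end

x = [3 1; 1 4; 2 5]
x[topkperm_slices(x, 2)]
```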

mcabbott • Oct 10 '21 22:10

@mcabbott thanks for all these suggestions; feel free to push to this branch.

I like the topk/topkperm decoupling, since it mimics Base's sort/sortperm functions.

Representing the (partial) permutation as an array of CartesianIndex is redundant, since returning an array of integers would be enough:

topkout[i1, ..., iN, j] = x[i1, ..., iN, perm[i1, ..., iN, j]]   for j = 1, ..., k (top k taken along the last dimension)

We could return a custom type for perm that stores the integer permutation together with dims, and overload getindex so that x[perm] works. But this may add complexity that turns out not to be very useful in the end, so I would be totally fine with returning CartesianIndex arrays and revisiting later if needed.
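A minimal sketch of the custom-type idea, assuming 1-based arrays: TopKPerm (a hypothetical name, not an NNlib API) stores only the integer positions plus dims, and a getindex method expands them into full coordinates on the fly.

```julia
# Hypothetical type: integer positions along `dims`, plus `dims` itself.
struct TopKPerm{N}
    perm::Array{Int,N}   # size k along `dims`, same as x elsewhere
    dims::Int
end

# Build the integer permutation slice by slice (1-based axes assumed).
function topkperm_int(x::AbstractArray, k::Integer; dims::Integer=1, rev::Bool=true)
    p = mapslices(v -> collect(partialsortperm(v, 1:k; rev=rev)), x; dims=dims)
    return TopKPerm(p, Int(dims))
end

# x[p]: place each stored position into the coordinates of its slice and read x there.
function Base.getindex(x::AbstractArray{T,N}, p::TopKPerm{N}) where {T,N}
    J = CartesianIndices(ntuple(d -> d == p.dims ? Base.OneTo(1) : axes(x, d), N))
    return broadcast(p.perm, J) do i, Jc
        x[CartesianIndex(ntuple(d -> d == p.dims ? i : Jc[d], N))]
    end
end

x = [3 1; 1 4; 2 5]
p = topkperm_int(x, 2)
x[p]
```

Since TopKPerm is a new type, the getindex method is not type piracy; whether AD handles it without an extra rrule is untested here.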

CarloLucibello • Oct 11 '21 07:10

Yes, I don't like the redundancy of returning CartesianIndices either, but I do like the simplicity of getindex. Maybe returning an array of linear indices instead would be nicer?
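Linear indices can be computed from the same per-slice integer positions with plain broadcasting: the linear index of each slice's first element, plus (position - 1) times the column-major stride along dims. A sketch assuming 1-based axes, with the hypothetical name topkperm_linear:

```julia
# Hypothetical topkperm returning linear indices into x (1-based axes assumed).
function topkperm_linear(x::AbstractArray, k::Integer; dims::Integer=1, rev::Bool=true)
    N = ndims(x)
    # integer positions within each slice, size k along `dims`
    p = mapslices(v -> collect(partialsortperm(v, 1:k; rev=rev)), x; dims=dims)
    # linear index of the first element of every slice (size 1 along `dims`)
    J = CartesianIndices(ntuple(d -> d == dims ? Base.OneTo(1) : axes(x, d), N))
    offsets = LinearIndices(x)[J]
    # column-major distance between neighbours along `dims`
    s = prod((size(x, d) for d in 1:dims-1); init=1)
    return offsets .+ (p .- 1) .* s
end

x = [3 1; 1 4; 2 5]
lin = topkperm_linear(x, 2)
x[lin]
```

Because the result is an ordinary Int array, x[lin] should go through the existing getindex and its gradient unchanged.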

mcabbott • Oct 11 '21 13:10