NNlib.jl

add Sparsemax activation

Status: open · tylerjthomas9 opened this issue 3 years ago · 2 comments

Source paper: http://arxiv.org/abs/1602.02068

PyTorch implementation: https://github.com/Qwicen/node/blob/master/lib/nn_utils.py

I have started implementing sparsemax in Julia for TabNet, and I think it would fit best in NNlib.jl. It should have the same interface as softmax.
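For reference, sparsemax as defined in the source paper is the Euclidean projection of the input onto the probability simplex. For a single slice $z \in \mathbb{R}^d$ whose entries are sorted as $z_{(1)} \ge \dots \ge z_{(d)}$, it has the sort-based closed form below (a summary with notation chosen here, not a verbatim copy of the paper):

```latex
\begin{aligned}
\operatorname{sparsemax}(z) &= \arg\min_{p \in \Delta^{d-1}} \lVert p - z \rVert_2^2,
  \qquad \Delta^{d-1} = \{\, p \in \mathbb{R}^d : p \ge 0,\ \mathbf{1}^\top p = 1 \,\} \\
k(z) &= \max\Big\{\, k \in \{1, \dots, d\} : 1 + k\, z_{(k)} > \textstyle\sum_{j \le k} z_{(j)} \Big\} \\
\tau(z) &= \frac{\sum_{j \le k(z)} z_{(j)} - 1}{k(z)} \\
\operatorname{sparsemax}_i(z) &= \max\big(z_i - \tau(z),\ 0\big)
\end{aligned}
```

Unlike softmax, every entry with $z_i \le \tau(z)$ gets exactly zero probability, which is what makes the output sparse. In the forward pass of this thread, support_size plays the role of $k(z)$ and tau the role of $\tau(z)$.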

tylerjthomas9 · Oct 01 '21 19:10

If you have an implementation, a PR would be welcome! We can iterate on the design there.

darsnack · Oct 01 '21 20:10

I'm still trying to get the Jacobian (backward pass) to work, but here is the initial forward pass:

# Allocating version: the result goes into a new floating-point array
sparsemax(x; dims = 1) = sparsemax!(similar(x, (float ∘ eltype)(x)), x; dims = dims)

# In-place version: overwrites x with its sparsemax
sparsemax!(x; dims = 1) = sparsemax!(x, x; dims = dims)

function sparsemax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
    # only matrices (2D arrays) are supported for now
    @assert ndims(x) == 2 && dims in (1, 2)

    # sparsemax is shift-invariant, so subtract the maximum for numerical stability;
    # write into `out` so that the input `x` is not mutated
    out .= x .- maximum(x; dims = dims)

    # rhos = 1, 2, ..., d laid out along `dims`
    d = size(out, dims)
    rhos = dims == 1 ? reshape(1:d, d, 1) : reshape(1:d, 1, d)

    # compute threshold and support:
    # k is in the support iff 1 + k * x_(k) > sum of the k largest entries
    x_sorted = sort(out; dims = dims, rev = true)
    x_cumsum = cumsum(x_sorted; dims = dims) .- 1
    support = rhos .* x_sorted .> x_cumsum
    support_size = vec(sum(support; dims = dims))  # one support size per slice

    # tau = (sum of the support_size largest entries - 1) / support_size,
    # read out of x_cumsum with one CartesianIndex per slice (no n×n gather + diag)
    n = length(support_size)
    idx = dims == 1 ? CartesianIndex.(support_size, 1:n) : CartesianIndex.(1:n, support_size)
    tau = x_cumsum[idx] ./ support_size
    tau = dims == 1 ? reshape(tau, 1, n) : reshape(tau, n, 1)

    # clip everything below the threshold to zero and store the result in `out`
    out .= max.(out .- tau, 0)
    return out
end

x = [0.3367 0.1288 0.2345 0.2303 -1.1229; -0.1863 2.2082 -0.638 0.4617 0.2674]  # 2×5: each column is one sample
println("Sparsemax probabilities")
sparsemax(x; dims=1)
Sparsemax probabilities

2×5 Matrix{Float64}:
 0.7615  0.0  0.93625  0.3843  0.0
 0.2385  1.0  0.06375  0.6157  1.0

tylerjthomas9 · Oct 04 '21 21:10
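On the Jacobian: the source paper gives it in closed form. With $p = \operatorname{sparsemax}(z)$ and $s$ the 0/1 indicator of the support $S = \{i : p_i > 0\}$, $J = \operatorname{diag}(s) - s s^\top / |S|$. J is symmetric, so the vector-Jacobian product a reverse-mode rule needs is $s \odot (\bar p - (s^\top \bar p)/|S|)$, and J never has to be materialised. Below is a minimal single-vector sketch of that. The names sparsemax_vec, sparsemax_jacobian and sparsemax_vjp are hypothetical helpers (not NNlib API), and sparsemax_vec is a plain reference forward pass for checking, not the batched code above.

```julia
using LinearAlgebra

# Hypothetical reference forward pass for a single vector (sort-based closed form).
function sparsemax_vec(z::AbstractVector{<:Real})
    zs = sort(z; rev = true)
    cs = cumsum(zs)
    # k(z): largest k with 1 + k * z_(k) > sum of the k largest entries (k = 1 always qualifies)
    k = findlast(i -> 1 + i * zs[i] > cs[i], eachindex(zs))
    tau = (cs[k] - 1) / k
    return max.(z .- tau, 0)
end

# Hypothetical helper: dense Jacobian J = Diagonal(s) - s * s' / |S| at p = sparsemax(z),
# where s is the 0/1 indicator of the support S = {i : p_i > 0}.
function sparsemax_jacobian(p::AbstractVector{<:Real})
    s = float.(p .> 0)
    return Matrix(Diagonal(s)) .- (s * s') ./ sum(s)
end

# Hypothetical helper: vector-Jacobian product without building J.
# J is symmetric, so J' * dp == J * dp == s .* (dp .- sum(dp[S]) / |S|).
function sparsemax_vjp(p::AbstractVector{<:Real}, dp::AbstractVector{<:Real})
    s = p .> 0
    return s .* (dp .- sum(dp[s]) / count(s))
end

# Usage
z = [1.0, 0.5, -2.0]
p = sparsemax_vec(z)                      # [0.75, 0.25, 0.0]
J = sparsemax_jacobian(p)
dz = sparsemax_vjp(p, [1.0, 2.0, 3.0])    # same as J' * [1.0, 2.0, 3.0]
```

To make Zygote use this, one option is a ChainRulesCore.rrule for the batched sparsemax whose pullback applies the same formula to each slice along dims; that wiring is not shown or tested here.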