NNlib.jl
add Sparsemax activation
Source paper: http://arxiv.org/abs/1602.02068
PyTorch implementation: https://github.com/Qwicen/node/blob/master/lib/nn_utils.py
I started working on implementing sparsemax in Julia for TabNet. I thought that it would best fit in NNlib.jl. It should have the exact same interface as softmax.
If you have an implementation, then a PR would be welcome! We can iterate the design there.
Still trying to get the jacobian to work, but I have the initial forward pass
using NNlib
using LinearAlgebra
using Zygote
"""
    sparsemax(x; dims = 1)

Out-of-place sparsemax: allocates a fresh array with a floating-point
element type and delegates to `sparsemax!`. Mirrors the `softmax` API.
"""
sparsemax(x; dims = 1) = sparsemax!(similar(x, float(eltype(x))), x; dims = dims)
"""
    sparsemax!(x; dims = 1)

In-place sparsemax: uses `x` both as input and as the output buffer.
"""
sparsemax!(x; dims = 1) = sparsemax!(x, x; dims = dims)
"""
    sparsemax!(out, x; dims = 1)

Compute the sparsemax of `x` along `dims` (1 = column-wise, 2 = row-wise)
and store the result in `out`, returning `out`. Sparsemax (Martins &
Astudillo, 2016) projects each slice onto the probability simplex, yielding
sparse probability vectors (exact zeros), unlike softmax.

Only 2D inputs with `dims in (1, 2)` are supported. Throws an
`ArgumentError` otherwise. `x` is not modified.
"""
function sparsemax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
    # Only 2D tensors are supported.
    dims in (1, 2) || throw(ArgumentError("sparsemax! supports dims = 1 or 2, got dims = $dims"))

    # Shift by the per-slice maximum for numerical stability.
    # NOTE: the original version did `x .-= max_`, silently mutating the
    # caller's input even through the non-mutating `sparsemax` wrapper.
    z = x .- maximum(x; dims = dims)

    # rhos = [1, 2, ..., d], oriented along `dims` for broadcasting.
    d = size(x, dims)
    rhos = dims == 1 ? reshape(1:d, d, 1) : reshape(1:d, 1, d)

    # Sort each slice in decreasing order; the support of the projection is
    # the largest k with k * z_sorted[k] > cumsum(z_sorted)[k] - 1.
    z_sorted = sort(z; dims = dims, rev = true)
    z_cumsum = cumsum(z_sorted; dims = dims) .- one(T)  # one(T), not 1.0: keep Float32 stable
    support = rhos .* z_sorted .> z_cumsum
    support_size = vec(sum(support; dims = dims))  # k per slice, always >= 1

    # Threshold tau per slice: (cumsum at k - 1) / k, read by direct
    # indexing (replaces the NNlib.gather + diag round-trip).
    if dims == 1
        tau = [z_cumsum[support_size[j], j] / support_size[j] for j in axes(z, 2)]
        # tau is one value per column: broadcast as a row vector.
        out .= max.(z .- reshape(tau, 1, :), zero(T))
    else
        tau = [z_cumsum[i, support_size[i]] / support_size[i] for i in axes(z, 1)]
        # tau is one value per row: broadcast as a column vector.
        out .= max.(z .- tau, zero(T))
    end
    # NOTE: `out .=` actually fills the buffer; the original rebound a local
    # (`out = clamp.(...)`), leaving the preallocated output untouched.
    return out
end
# Demo: column-wise sparsemax on the 5×2 example from the PyTorch reference.
x = [
     0.3367 -0.1863
     0.1288  2.2082
     0.2345 -0.638
     0.2303  0.4617
    -1.1229  0.2674
]
println("Sparsemax probabilities")
sparsemax(x; dims = 1)
Sparsemax probabilities
2×5 Matrix{Float64} (note: for a 5×2 input, `sparsemax(x; dims=1)` should return a 5×2 matrix; the output pasted below appears to be the row-wise result, transposed):
0.7615 0.0 0.93625 0.3843 0.0
0.2385 1.0 0.06375 0.6157 1.0