NNlib.jl
add sparsemax
Initial sparsemax implementation #354
TODO
- [x] add support for vectors (similar to softmax)
- [ ] add comprehensive testing
- [x] compare gradients with PyTorch implementations
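For context, sparsemax (Martins & Astudillo, 2016) is the Euclidean projection of the input onto the probability simplex, applied along `dims`:

$$\mathrm{sparsemax}(x) = \operatorname*{argmin}_{p \in \Delta^{K-1}} \lVert p - x \rVert_2^2, \qquad \mathrm{sparsemax}(x)_i = \big[x_i - \tau(x)\big]_+,$$

where $\tau(x)$ is the unique threshold making the outputs sum to one; unlike softmax, entries below the threshold come out exactly zero.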
I think we can avoid defining a mutating version `sparsemax!` and just go with `sparsemax`, since the implementation of `sparsemax!` does a lot of allocations in any case. Let's also make sure that `∇sparsemax` does not perform mutation internally, so that we can differentiate `sparsemax` twice.
Here's an attempt at a faster version:
function sm5(x::AbstractArray; dims::Integer=1)
    z = if x isa AbstractVector
        dims == 1 || return float(x)  # do-nothing case, same return type
        sort(float(x); rev=true)  # `sort([3,2,1]; dims=1)` is an error
    else
        # `float` is usually free, except on integers etc, needed to make re-use safe.
        sort(float(x); dims, rev=true)
    end
    mask = _sm4mask(z, dims)
    tausum = sum(z .= z .* mask; dims)  # no longer need `z`, re-use to save allocations
    kay = sum(mask; dims)
    z .= _relu.(x .- (tausum .- 1) ./ kay)
end
function _sm4mask(z::AbstractArray, dim::Integer)
    acc = cumsum(z; dims=dim)
    if dim == 1
        # Treat common cases specially for speed -- this is a factor 3 in @btime sm5($x) below
        acc .= 1 .+ axes(z,1) .* z .> acc
    elseif dim == 2
        acc .= 1 .+ axes(z,2)' .* z .> acc
    else
        # This isn't type-stable. Writing into `acc` ensures the whole function still is:
        cnt = reshape(axes(z, dim), ntuple(_->1, dim-1)..., :)
        # cnt = reshape(collect(axes(z, dim)), ntuple(d -> d==dim ? (:) : 1, ndims(z)))
        acc .= 1 .+ cnt .* z .> acc
    end
    acc
end
_relu(x) = _ifelse(x>0, x, false) # different gradient at zero
_ifelse(p, x, y) = ifelse(p, promote(x, y)...)
function ∇sm5(dy::AbstractArray, y::AbstractArray; dims::Integer=1)
    vee = sum(dy .* (y .> 0); dims)
    kay = count(>(0), y; dims)  # could also keep from forward pass?
    _ifelse.(y .> 0, dy .- vee ./ kay, 0)
end
x = rand(1:5, 5,10) ./ 5
sparsemax(x)
sm5(x)
sparsemax(x) ≈ sm5(x)
all(≈(1), sum(sm5(x), dims=1))
julia> @btime sparsemax($x);
min 4.470 μs, mean 4.970 μs (63 allocations, 7.11 KiB. GC mean 5.49%) # a9f9d7f, when I wrote this
min 2.176 μs, mean 2.438 μs (22 allocations, 2.83 KiB. GC mean 6.53%) # cde15de, PR update
min 3.271 μs, mean 3.610 μs (30 allocations, 3.08 KiB. GC mean 5.28%) # 95ee6e4
julia> @btime sm1($x); # simple version which is purely slices & loops, no broadcast
min 1.188 μs, mean 1.311 μs (25 allocations, 2.06 KiB. GC mean 4.23%)
julia> @btime sm5($x);
min 526.260 ns, mean 604.508 ns (6 allocations, 1.34 KiB. GC mean 10.30%)
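For anyone following the sort/cumsum/threshold logic in `sm5` and `_sm4mask`, here is a minimal, unoptimized sketch of the same computation for a single vector (`sparsemax_ref` is a made-up name, not part of the PR):

# Reference sketch: support size k, threshold tau, then clamp at zero.
function sparsemax_ref(x::AbstractVector)
    z = sort(float(x); rev=true)                             # z[1] ≥ z[2] ≥ …
    acc = cumsum(z)
    k = findlast(i -> 1 + i * z[i] > acc[i], 1:length(z))    # largest valid support size (always ≥ 1)
    tau = (acc[k] - 1) / k                                    # threshold so the output sums to 1
    max.(x .- tau, 0)                                         # entries below tau become exactly 0
end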
For the gradient, `nonzeros = x[x .!= 0.0]` makes a vector. Plausible variations of that don't seem to lead to correct answers.
ForwardDiff.gradient(x -> sum(sparsemax(x)), x) # all zero!
ForwardDiff.jacobian(sparsemax, x) # not zero, obviously
delta = 0 .* x .+ randn.();
ForwardDiff.gradient(x -> sum(delta .* sparsemax(x)), x)
∇sparsemax!(delta, x, sparsemax(x))
∇sm5(delta, sm5(x))
The current version has the correct implementation.
Trying to check the gradient numerically:
julia> x = rand(1:5, 5,10) ./ 5;
julia> delta = 0 .* x .+ randn.();
julia> ∇sparsemax(delta, sparsemax(x)) # cde15de
5×10 Matrix{Float64}:
-0.0307726 -0.0315741 -0.197077 -0.0307726 … 0.0445302 -0.0307726 -0.0307726
-0.498682 0.390107 -0.0632101 -0.532938 -0.0307726 -0.0307726 -0.456715
-0.0307726 -0.0307726 1.13294 -0.00631409 -0.109703 0.194811 -0.534492
0.275757 -0.0307726 0.132314 0.438519 -0.166756 0.0616505 -0.0307726
-0.191945 -0.0307726 -0.0307726 -0.0307726 -0.0984743 0.0760821 -0.0307726
julia> ∇sm5(delta, sm5(x))
5×10 Matrix{Float64}:
0.0 -0.526767 -1.57823 0.0 … 0.0 1.35172 0.0 0.0
-0.488485 0.526767 -0.685784 -1.17488 0.0 0.0 0.0 0.0777773
0.0 0.0 1.6463 0.268066 0.579856 -0.470059 -0.0775806 -0.0777773
1.30843 0.0 0.61771 0.906813 -0.579856 -0.45652 -0.22328 0.0
-0.819949 0.0 0.0 0.0 0.0 -0.425142 0.300861 0.0
julia> using ForwardDiff
julia> ForwardDiff.gradient(x -> sum(delta .* sparsemax(x)), x)
5×10 Matrix{Float64}:
0.0 -0.526767 -1.57823 0.0 … -0.0475105 1.35172 0.0 0.0
-0.488485 0.526767 -0.685784 -1.17488 0.242721 0.0 0.0 0.0777773
0.0 0.0 1.6463 0.268066 0.967818 -0.470059 -0.0775806 -0.0777773
1.30843 0.0 0.61771 0.906813 -0.191893 -0.45652 -0.22328 0.0
-0.819949 0.0 0.0 0.0 -0.971135 -0.425142 0.300861 0.0
julia> using FiniteDifferences, ZChop
julia> zchop(grad(central_fdm(5, 1), x -> sum(delta .* sm4(x)), x)[1], 1e-10)
5×10 Matrix{Float64}:
0.0 -0.16144 -1.57823 0.0 … -0.145158 1.35172 0.0 0.0
-0.488485 0.892093 -0.685784 -1.17488 -0.0484138 0.0 0.0 0.0777773
0.0 0.0 1.6463 0.268066 0.773837 -0.470059 -0.0775806 -0.0777773
1.30843 -0.730653 0.61771 0.906813 -0.385874 -0.45652 -0.22328 0.0
-0.819949 0.0 0.0 0.0 -0.453033 -0.425142 0.300861 0.0
Is it clear that the pytorch code gets this right? It's not so obvious to me what it's doing. If you are set up to do so, can you run them all on the same (tricky) input & compare?
The mathematics is simple but, as mentioned above, there will quite often be issues of which subgradient to choose. My function `_relu` tries to ensure that ForwardDiff picks the same one that is easy to implement using `y .> 0`, without storing its arguments.
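A tiny check of that convention, assuming `_relu` from above is in scope:

using ForwardDiff
ForwardDiff.derivative(_relu, 0.0)  # 0.0 -- at the tie the `false` branch is taken, matching `y .> 0`
ForwardDiff.derivative(_relu, 1.0)  # 1.0 -- the active branch passes the input through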
Finite differencing will tend to average out sub-gradients, but need not produce something which is a valid choice at all at such points. I'm not sure which class the 2nd column in this example is in. The other columns are something of a sanity-check. This is what `test_rrule` will do, which is called by `gradtest`, but because of such issues it may not be usable for this function.
Perhaps the test should just be that Zygote agrees with ForwardDiff. For such a test to be safe, I think it's best to have a function like `_relu` (it doesn't matter what it's called) so that this won't break if conventions change elsewhere. (No `relu`, no `max`. Also, `max` is weirdly slow.)
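A sketch of what such a test could look like, assuming `sparsemax` and its rrule are in scope (the testset name is a placeholder):

using Test, Zygote, ForwardDiff
@testset "sparsemax: Zygote vs ForwardDiff" begin
    x = rand(3, 4) ./ 10         # small entries keep the output dense, away from ties
    delta = 0 .* x .+ randn.()
    g_zy = Zygote.gradient(x -> sum(delta .* sparsemax(x)), x)[1]
    g_fd = ForwardDiff.gradient(x -> sum(delta .* sparsemax(x)), x)
    @test g_zy ≈ g_fd
end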
The current commit also omits many `.=` from `sm5` above, which means that it is not type-stable and allocates twice as much memory. Why? It also now removes much of the detail of the mask function, which wasn't there by accident -- see the annotations now added above, and the benchmarks.
julia> @code_warntype sparsemax(x)
Body::Any
1 ─ %1 = Main.:(var"#sparsemax#33")(1, #self#, x)::Any
└── return %1
julia> @code_warntype sm5(x)
Body::Matrix{Float32}
1 ─ %1 = Main.:(var"#sm4#25")(1, #self#, x)::Matrix{Float32}
└── return %1
julia> @test @inferred(sparsemax([1,2,3])) == [0,0,1] # some of these can be in tests
ERROR: return type Vector{Float64} does not match inferred return type Any
julia> @test sparsemax([1,2,3]; dims=3) == [1,2,3]
ERROR: UndefVarError: z not defined
> I think it's best to have a function like `_relu` (it doesn't matter what it's called) so that this won't break if conventions change elsewhere. (No `relu`, no `max`. Also, `max` is weirdly slow.) My function `_relu` tries to ensure that ForwardDiff picks the same one that is easy to implement using `y .> 0`, without storing its arguments.
Great to know, thank you!
> The current commit also omits many `.=` from `sm5` above, which means that it is not type-stable, and allocates twice as much memory. Why?
Reverted this; I am not sure why I removed it.
> Is it clear that the pytorch code gets this right? It's not so obvious to me what it's doing. If you are set up to do so, can you run them all on the same (tricky) input & compare?
I will test it against the Python results.
I think the gradient is working. I also went through the paper again, and we are calculating it exactly how the authors did. I have attached an image of the paper's derived Jacobian. We are using `Δ` instead of "p" for the sparsemax outputs, and `nonzeros` instead of "s" for the nonzero subset of "p". Additionally, we are using `out` instead of "v", and `total` instead of "$\hat{v}$".
No, `Δ` is the backward gradient we receive; `y` is the sparsemax output.
I'm not sure I ever decoded the paper's notation. But the question of whether the gradient is correct doesn't require that -- the paper may have clues but this is simple enough that deriving it might be quicker than decoding them. Or the paper may have mistakes. The python code might too.
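For the record, differentiating the threshold form gives the standard expression, which is also what `∇sm5` above computes: writing $S(x)$ for the support (the indices with nonzero output) and $s$ for its indicator vector,

$$\frac{\partial\,\mathrm{sparsemax}(x)}{\partial x} = \mathrm{Diag}(s) - \frac{s\,s^\top}{|S(x)|}, \qquad \big(J^\top \Delta\big)_i = s_i\Big(\Delta_i - \frac{1}{|S(x)|}\sum_{j\in S(x)}\Delta_j\Big).$$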
We can check numerically. The easy case is small numbers, like `rand(3,4)/10`, where there should be no zeros in the output. I thought that large numbers (hence zeros) would be harder but am not sure today.
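A sketch of that check, using `sm5` and `∇sm5` from above; with a dense output there are no subgradient choices, so finite differences should agree:

using FiniteDifferences
x = rand(3, 4) ./ 10
delta = 0 .* x .+ randn.()
g_num = grad(central_fdm(5, 1), x -> sum(delta .* sm5(x)), x)[1]
g_num ≈ ∇sm5(delta, sm5(x))   # expected to hold when no outputs are exactly zero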
I am surprised that tests pass here. Why doesn't it call `test_rrule`, and fail?
julia> using ChainRulesTestUtils
julia> test_rrule(sparsemax, rand(3,4))
nonzeros = Bool[1 1 1 1; 1 1 1 1; 1 1 1 1]
nonzeros = Bool[1 1 1 1; 1 1 1 1; 1 1 1 1]
test_rrule: sparsemax on Matrix{Float64}: Test Failed at /Users/me/.julia/packages/ChainRulesTestUtils/73Y9Q/src/check_result.jl:24
Expression: isapprox(actual, expected; kwargs...)
Evaluated: isapprox([-0.612288804854263 -1.3615411120031258 0.7652273269279669 -3.179013619758306; 0.11318318165635466 2.061264365438702 1.309074111596146 -0.9704853614354765; 0.3360509543675956 0.1273032261697271 -0.5543590617133406 1.9655847936080213], [-1.590000000000055 -5.723333333333301 -2.1000000000001666 -6.1266666666666945; 1.2299999999999827 4.456666666666624 2.0999999999998784 -1.6066666666666587; 0.36000000000001586 1.2666666666663948 -3.3757265396302482e-15 7.733333333333414]; rtol = 1.0e-9, atol = 1.0e-9)