NNlib.jl

add sparsemax

Open · tylerjthomas9 opened this issue 3 years ago · 6 comments

Initial sparsemax implementation #354

TODO

  • [x] add support for vectors (similar to softmax)
  • [ ] add comprehensive testing
  • [x] compare gradients with PyTorch implementations
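
For context, sparsemax (Martins & Astudillo, 2016) is the Euclidean projection of the input onto the probability simplex, so its output is a probability vector that is typically sparse:

$$\operatorname{sparsemax}(z) = \operatorname{argmin}_{p \in \Delta^{K-1}} \lVert p - z \rVert_2^2, \qquad \Delta^{K-1} = \{\, p \in \mathbb{R}^K : p \ge 0,\ \textstyle\sum_i p_i = 1 \,\},$$

which has the closed form $\operatorname{sparsemax}(z)_i = \max(z_i - \tau(z),\, 0)$ for a threshold $\tau(z)$ chosen so that the entries sum to one. That thresholded form is what the implementations discussed below compute.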

tylerjthomas9, Nov 19 '21

I think we can avoid defining a mutating version sparsemax! and just go with sparsemax, since the implementation of sparsemax! does a lot of allocations in any case.

Let's also make sure that ∇sparsemax does not perform mutation internally, so that we can differentiate sparsemax twice.
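
A minimal sketch of such a non-mutating rule, assuming a ∇sparsemax(Δ, y; dims) signature (the names here are illustrative; the PR's actual rule may differ):

using ChainRulesCore

function ChainRulesCore.rrule(::typeof(sparsemax), x::AbstractArray; dims=1)
    y = sparsemax(x; dims)
    function sparsemax_pullback(Δ)
        # no in-place writes here, so the pullback itself can be differentiated again
        return NoTangent(), ∇sparsemax(unthunk(Δ), y; dims)
    end
    return y, sparsemax_pullback
end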

CarloLucibello, Nov 21 '21

Here's an attempt at a faster version:

function sm5(x::AbstractArray; dims::Integer=1)
    z = if x isa AbstractVector
        dims == 1 || return float(x)  # do-nothing case, same return type
        sort(float(x); rev=true)  # `sort([3,2,1]; dims=1)` is an error
    else
        # `float` is usually free, except on integers etc, needed to make re-use safe.
        sort(float(x); dims, rev=true)
    end
    mask = _sm4mask(z, dims)
    tausum = sum(z .= z .* mask; dims)  # no longer need `z`, re-use to save allocations
    kay = sum(mask; dims)
    z .= _relu.(x .- (tausum .- 1) ./ kay)
end

function _sm4mask(z::AbstractArray, dim::Integer)
    acc = cumsum(z; dims=dim)
    if dim == 1  
        # Treat common cases specially for speed -- this is a factor 3 in @btime sm5($x) below
        acc .= 1 .+ axes(z,1) .* z .> acc
    elseif dim == 2
        acc .= 1 .+ axes(z,2)' .* z .> acc
    else
        # This isn't type-stable. Writing into `acc` ensures the whole function still is:
        cnt = reshape(axes(z, dim), ntuple(_->1, dim-1)..., :)
        # cnt = reshape(collect(axes(z, dim)), ntuple(d -> d==dim ? (:) : 1, ndims(z)))
        acc .= 1 .+ cnt .* z .> acc
    end
    acc
end

_relu(x) = _ifelse(x>0, x, false)  # different gradient at zero

_ifelse(p, x, y) = ifelse(p, promote(x, y)...)

function ∇sm5(dy::AbstractArray, y::AbstractArray; dims::Integer=1)
    # On the support (y > 0), subtract the mean of dy over the support; off the support, return zero.
    vee = sum(dy .* (y .> 0); dims)
    kay = count(>(0), y; dims)  # could also keep this from the forward pass?
    _ifelse.(y .> 0, dy .- vee ./ kay, 0)
end

x = rand(1:5, 5,10) ./ 5
sparsemax(x)
sm5(x)

sparsemax(x) ≈ sm5(x)
all(≈(1), sum(sm5(x), dims=1))
julia> @btime sparsemax($x);
  min 4.470 μs, mean 4.970 μs (63 allocations, 7.11 KiB. GC mean 5.49%)  # a9f9d7f, when I wrote this
  min 2.176 μs, mean 2.438 μs (22 allocations, 2.83 KiB. GC mean 6.53%)  # cde15de, PR update
  min 3.271 μs, mean 3.610 μs (30 allocations, 3.08 KiB. GC mean 5.28%)  # 95ee6e4
 
julia> @btime sm1($x);  # simple version which is purely slices & loops, no broadcast
  min 1.188 μs, mean 1.311 μs (25 allocations, 2.06 KiB. GC mean 4.23%)

julia> @btime sm5($x);
  min 526.260 ns, mean 604.508 ns (6 allocations, 1.34 KiB. GC mean 10.30%)

For the gradient, nonzeros = x[x .!= 0.0] makes a flat vector (losing the array's shape), and plausible variations of that don't seem to lead to correct answers.

ForwardDiff.gradient(x -> sum(sparsemax(x)), x)  # all zero! 
ForwardDiff.jacobian(sparsemax, x)  # not zero, obviously

delta = 0 .* x .+ randn.();
ForwardDiff.gradient(x -> sum(delta .* sparsemax(x)), x)
∇sparsemax!(delta, x, sparsemax(x))
∇sm5(delta, sm5(x))

mcabbott, Nov 21 '21

> Current version has the correct implementation

Trying to check the gradient numerically:

julia> x = rand(1:5, 5,10) ./ 5;

julia> delta = 0 .* x .+ randn.();

julia> ∇sparsemax(delta, sparsemax(x))  # cde15de
5×10 Matrix{Float64}:
 -0.0307726  -0.0315741  -0.197077   -0.0307726   …   0.0445302  -0.0307726  -0.0307726
 -0.498682    0.390107   -0.0632101  -0.532938       -0.0307726  -0.0307726  -0.456715
 -0.0307726  -0.0307726   1.13294    -0.00631409     -0.109703    0.194811   -0.534492
  0.275757   -0.0307726   0.132314    0.438519       -0.166756    0.0616505  -0.0307726
 -0.191945   -0.0307726  -0.0307726  -0.0307726      -0.0984743   0.0760821  -0.0307726

julia> ∇sm5(delta, sm5(x))
5×10 Matrix{Float64}:
  0.0       -0.526767  -1.57823    0.0       …   0.0        1.35172    0.0         0.0
 -0.488485   0.526767  -0.685784  -1.17488       0.0        0.0        0.0         0.0777773
  0.0        0.0        1.6463     0.268066      0.579856  -0.470059  -0.0775806  -0.0777773
  1.30843    0.0        0.61771    0.906813     -0.579856  -0.45652   -0.22328     0.0
 -0.819949   0.0        0.0        0.0           0.0       -0.425142   0.300861    0.0

julia> using ForwardDiff

julia> ForwardDiff.gradient(x -> sum(delta .* sparsemax(x)), x)
5×10 Matrix{Float64}:
  0.0       -0.526767  -1.57823    0.0       …  -0.0475105   1.35172    0.0         0.0
 -0.488485   0.526767  -0.685784  -1.17488       0.242721    0.0        0.0         0.0777773
  0.0        0.0        1.6463     0.268066      0.967818   -0.470059  -0.0775806  -0.0777773
  1.30843    0.0        0.61771    0.906813     -0.191893   -0.45652   -0.22328     0.0
 -0.819949   0.0        0.0        0.0          -0.971135   -0.425142   0.300861    0.0

julia> using FiniteDifferences, ZChop

julia> zchop(grad(central_fdm(5, 1), x -> sum(delta .* sm4(x)), x)[1], 1e-10)
5×10 Matrix{Float64}:
  0.0       -0.16144   -1.57823    0.0       …  -0.145158    1.35172    0.0         0.0
 -0.488485   0.892093  -0.685784  -1.17488      -0.0484138   0.0        0.0         0.0777773
  0.0        0.0        1.6463     0.268066      0.773837   -0.470059  -0.0775806  -0.0777773
  1.30843   -0.730653   0.61771    0.906813     -0.385874   -0.45652   -0.22328     0.0
 -0.819949   0.0        0.0        0.0          -0.453033   -0.425142   0.300861    0.0

Is it clear that the pytorch code gets this right? It's not so obvious to me what it's doing. If you are set up to do so, can you run them all on the same (tricky) input & compare?

The mathematics is simple but, as mentioned above, there will quite often be issues of which subgradient to choose. My function _relu tries to ensure that ForwardDiff picks the same one that is easy to implement using y .> 0, without storing its arguments.
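
Concretely, with the definitions above, ForwardDiff takes the zero branch at exactly zero, which is the same subgradient choice as the y .> 0 mask in ∇sm5:

using ForwardDiff
ForwardDiff.derivative(_relu, 0.0) == 0.0  # true: zero slope at 0, matching the mask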

Finite differencing will tend to average out sub-gradients, but need not produce something which is a valid choice at all at such points. I'm not sure which class the 2nd column in this example is in. The other columns are something of a sanity-check. This is what test_rrule will do, which is called by gradtest, but because of such issues it may not be usable for this function.

Perhaps the test should just be that Zygote agrees with ForwardDiff. For such a test to be safe, I think it's best to have a function like _relu (it doesn't matter what it's called) so that this won't break if conventions change elsewhere. (No relu, no max. Also, max is weirdly slow.)
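
A sketch of such a test, assuming the rrule is hooked up so that Zygote ends up calling ∇sparsemax (same setup as the examples above):

using Test, Zygote, ForwardDiff

x = rand(1:5, 5, 10) ./ 5
delta = randn(size(x))
loss(x) = sum(delta .* sparsemax(x))
@test Zygote.gradient(loss, x)[1] ≈ ForwardDiff.gradient(loss, x)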

The current commit also omits many .= from sm5 above, which means that it is not type-stable, and allocates twice as much memory. Why? It also removes much of the detail of the mask function, which wasn't there by accident -- see the annotations now added above, and the benchmarks.

julia> @code_warntype sparsemax(x)
Body::Any
1 ─ %1 = Main.:(var"#sparsemax#33")(1, #self#, x)::Any
└──      return %1

julia> @code_warntype sm5(x)
Body::Matrix{Float32}
1 ─ %1 = Main.:(var"#sm4#25")(1, #self#, x)::Matrix{Float32}
└──      return %1

julia> @test @inferred(sparsemax([1,2,3])) == [0,0,1]  # some of these can be in tests
ERROR: return type Vector{Float64} does not match inferred return type Any

julia> @test sparsemax([1,2,3]; dims=3) == [1,2,3]
ERROR: UndefVarError: z not defined

mcabbott, Nov 23 '21

> I think it's best to have a function like _relu (it doesn't matter what it's called) so that this won't break if conventions change elsewhere. (No relu, no max. Also, max is weirdly slow.) My function _relu tries to ensure that ForwardDiff picks the same one that is easy to implement using y .> 0, without storing its arguments.

Great to know, thank you!

> The current commit also omits many .= from sm5 above, which means that it is not type-stable, and allocates twice as much memory. Why?

Reverted this; I am not sure why I removed it.

> Is it clear that the pytorch code gets this right? It's not so obvious to me what it's doing. If you are set up to do so, can you run them all on the same (tricky) input & compare?

I will test it vs the python results.

tylerjthomas9, Nov 23 '21

I think the gradient is working. I also went through the paper again, and we are calculating it exactly how the authors did. I have attached an image of the paper's derived Jacobian. We are using Δ instead of "p" for the sparsemax outputs, and nonzeros instead of "s" for the nonzero subset of "p". Additionally, we are using out instead of "v", and total instead of "$\hat{v}$".

[image: the paper's Jacobian derivation]
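
For reference, here is how I read the Jacobian in the paper, with $S(z)$ the support of the output and $s$ its 0/1 indicator vector:

$$\frac{\partial\,\operatorname{sparsemax}(z)}{\partial z} = \operatorname{Diag}(s) - \frac{s s^{\top}}{|S(z)|}, \qquad J v = s \odot \big(v - \hat{v}\,\mathbf{1}\big), \quad \hat{v} = \frac{1}{|S(z)|}\sum_{j \in S(z)} v_j,$$

i.e. the backward pass masks the incoming gradient to the support and subtracts its mean over the support, which is also what ∇sm5 above does (vee ./ kay is $\hat{v}$).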

tylerjthomas9, Nov 27 '21

No, Δ is the backward gradient we receive, y is the sparsemax output.

I'm not sure I ever decoded the paper's notation. But the question of whether the gradient is correct doesn't require that -- the paper may have clues but this is simple enough that deriving it might be quicker than decoding them. Or the paper may have mistakes. The python code might too.
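
For what it's worth, a short derivation (my sketch): away from points where the support changes, the output on the support $S$ is just a shift, $y_i = z_i - \tau(z)$ with $\tau(z) = \big(\sum_{j \in S} z_j - 1\big)/|S|$, and it is identically zero off the support, so

$$\frac{\partial y_i}{\partial z_j} = \begin{cases} \delta_{ij} - 1/|S| & i, j \in S, \\ 0 & \text{otherwise,} \end{cases}$$

which is the formula quoted from the paper above.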

We can check numerically. The easy case is small numbers, like rand(3,4)/10, where there should be no zeros in the output. I thought that large numbers (hence zeros) would be harder but am not sure today.
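
A quick sanity check in that regime, assuming the ∇sparsemax(Δ, y) call signature used above: with full support the per-column Jacobian is I - (1/n)*ones(n, n), so the pullback should just center Δ within each column.

x = rand(3, 4) ./ 10                    # small entries, so every column has full support
y = sparsemax(x)
all(>(0), y)                            # expect true: no zeros in the output
delta = randn(3, 4)
# a correct gradient should satisfy this; ∇sm5 above does
∇sparsemax(delta, y) ≈ delta .- sum(delta; dims=1) ./ 3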

I am surprised that tests pass here. Why doesn't it call test_rrule, and fail?

julia> using ChainRulesTestUtils

julia> test_rrule(sparsemax, rand(3,4))
nonzeros = Bool[1 1 1 1; 1 1 1 1; 1 1 1 1]
nonzeros = Bool[1 1 1 1; 1 1 1 1; 1 1 1 1]
test_rrule: sparsemax on Matrix{Float64}: Test Failed at /Users/me/.julia/packages/ChainRulesTestUtils/73Y9Q/src/check_result.jl:24
  Expression: isapprox(actual, expected; kwargs...)
   Evaluated: isapprox([-0.612288804854263 -1.3615411120031258 0.7652273269279669 -3.179013619758306; 0.11318318165635466 2.061264365438702 1.309074111596146 -0.9704853614354765; 0.3360509543675956 0.1273032261697271 -0.5543590617133406 1.9655847936080213], [-1.590000000000055 -5.723333333333301 -2.1000000000001666 -6.1266666666666945; 1.2299999999999827 4.456666666666624 2.0999999999998784 -1.6066666666666587; 0.36000000000001586 1.2666666666663948 -3.3757265396302482e-15 7.733333333333414]; rtol = 1.0e-9, atol = 1.0e-9)

mcabbott, Nov 27 '21