
How to speed up pullbacks when iterating over arrays?

Open sethaxen opened this issue 5 years ago • 15 comments

Consider the following example from base Julia:

evalpoly(x, p::AbstractVector) = _evalpoly(x, p)

function _evalpoly(x, p)
    N = length(p)
    ex = p[end]
    for i in N-1:-1:1
        ex = muladd(x, ex, p[i])
    end
    ex
end

This function is simple and looks at each element of p once. Ideally, its gradient would be fast. However:

julia> using Zygote, BenchmarkTools

julia> x, p = rand(), randn(10000);

julia> @btime evalpoly(x, p);
  17.323 μs (1 allocation: 16 bytes)

julia> @btime Zygote.gradient(evalpoly, x, p);
  679.433 ms (490088 allocations: 1.51 GiB)

To work around this, I'm adding a custom rule for evalpoly to ChainRules.jl (https://github.com/JuliaDiff/ChainRules.jl/pull/190), which will speed things up dramatically once https://github.com/FluxML/Zygote.jl/pull/366 is merged:

julia> @btime Zygote.gradient(evalpoly, x, p); #10,000x faster than before!!
  43.338 μs (21 allocations: 156.81 KiB)

But this is a band-aid. How can we improve this in the general case? Is the main problem likely the adjoint for getindex (e.g. https://github.com/FluxML/Zygote.jl/issues/365)? But even with a great adjoint for getindex, Zygote won't know that each element is used only once, and therefore that it could do something fast like allocating a single gradient vector and filling it efficiently.
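To make the cost concrete, here's a schematic Python sketch (not Zygote code; `grad_onehot` and `grad_inplace` are hypothetical names) contrasting what a naive getindex adjoint does — a fresh dense one-hot cotangent per index access — with filling a single preallocated gradient vector:

```python
import numpy as np

def grad_onehot(p, indices, cotangents):
    """Naive accumulation: each index access contributes a full dense
    one-hot array. For N accesses into a length-N array this is
    O(N^2) work and allocation."""
    dp = np.zeros_like(p)
    for i, c in zip(indices, cotangents):
        onehot = np.zeros_like(p)  # fresh N-element array per access
        onehot[i] = c
        dp = dp + onehot           # full N-element add per access
    return dp

def grad_inplace(p, indices, cotangents):
    """What we'd like: one gradient buffer, filled in place, O(N)."""
    dp = np.zeros_like(p)
    for i, c in zip(indices, cotangents):
        dp[i] += c
    return dp
```

Both return the same gradient; only the asymptotic cost differs.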

For reference, JAX achieves performance comparable to the custom rule without one (see https://github.com/google/jax/issues/3047#issuecomment-627764538).

sethaxen avatar May 13 '20 07:05 sethaxen

JAX solves this problem in two ways:

  • The XLA compiler can convert functionally pure accumulation at a single index into in-place updates (assuming the original array values are not reused).
  • JAX has a special autodiff primitive lax.scan for loops with a statically known number of iterations. This simplifies memory allocation for the backward pass. It's also required for XLA, since XLA needs to be able to allocate all memory before running anything.
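As a rough illustration of why a static trip count helps, here's a minimal Python sketch (plain NumPy, not JAX; `horner_scan` and `horner_vjp` are hypothetical names) of Horner's rule written scan-style: because the number of iterations is known up front, both the saved carries and the gradient buffer can be allocated once:

```python
import numpy as np

def horner_scan(x, p):
    """Forward pass: Horner's rule over coefficients (highest degree
    first), saving the carry at each step for the backward pass."""
    carries = np.empty(len(p))  # preallocated: trip count is known
    y = 0.0
    for i, pi in enumerate(p):
        carries[i] = y
        y = y * x + pi          # muladd(x, y, pi)
    return y, carries

def horner_vjp(x, p, carries, dy):
    """Backward pass: y_{k+1} = y_k * x + p_k gives
    dp_k = dy_{k+1}, dx += dy_{k+1} * y_k, dy_k = dy_{k+1} * x."""
    dp = np.empty_like(p)       # single dense gradient buffer
    dx = 0.0
    for i in range(len(p) - 1, -1, -1):
        dp[i] = dy
        dx += dy * carries[i]
        dy = dy * x
    return dx, dp
```

For p = [3, 4, 5] and x = 2, the polynomial is 3x² + 4x + 5 = 25, with dx = 6x + 4 = 16 and dp = [x², x, 1] = [4, 2, 1].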

shoyer avatar May 13 '20 17:05 shoyer

Could this be done as some foldl type operation? That's one way to convey that you are iterating once; I don't know whether that information is at present used well by Zygote but it could surely be. Is this what lax.scan is doing?

Edit, with code. The gradient of this is actually much slower, right now:

julia> _evalpoly3(x,p) = foldr((a,b) -> muladd(x,b,a), p);

julia> @btime _evalpoly($x, $p)
  11.156 μs (0 allocations: 0 bytes)
-0.5897460380339452

julia> @btime _evalpoly3($x, $p)
  11.162 μs (0 allocations: 0 bytes)
-0.5897460380339452

Scalar getindex is (as I'm sure you know) making a complete zero(p) at every iteration, writing one number, and adding them up. Which is crazy here. Making some kind of sparse array instead would be great.

mcabbott avatar May 14 '20 10:05 mcabbott

S4tf uses a sparse array for this.

AriMKatz avatar May 14 '20 11:05 AriMKatz

Julia's evalpoly implementation is equivalent to one that uses foldr, but that actually has worse performance for some reason.

julia> _evalpoly(x, p) = foldr((pi, y) -> muladd(x, y, pi), p; init = zero(x))

julia> _evalpoly(x, p) === evalpoly(x, p)
true

julia> @btime Zygote.gradient(_evalpoly, x, p)
  1.081 s (1390414 allocations: 1.54 GiB)

Internally, foldr calls foldl, which in turn calls mapfoldl, but Zygote doesn't currently have a custom adjoint for these functions. But I'm more concerned about the general case. Loops like these are really common in Julia packages.

Scalar getindex is (as I'm sure you know) making a complete zero(p) at every iteration, writing one number, and adding them up. Which is crazy here. Making some kind of sparse array instead would be great.

Yeah it's really bad. I think I implemented an adjoint for getindex once using SparseVector, and it took 100ms instead of 600ms, which is still not great. I think that's because you avoid allocating the zeros, but you still need to allocate a new sparse vector at each iteration. As @willtebbutt noted in the linked issue, the right way to handle this is probably in-place accumulation. I don't know what it would take to support that though.

S4tf uses a sparse array for this.

Do you have a link to how they do it? Do they use sparse arrays for everything? Or, how do they signal conversion from a sparse array to a dense array?

sethaxen avatar May 14 '20 14:05 sethaxen

Yes, I hadn't tried but individual sparse vectors still seem like a lot of overhead, especially if (as is common) you are covering the whole array. Did you literally use spzeros or some special struct?

julia> @btime spzeros(10000)
  63.615 ns (2 allocations: 160 bytes)

It would be worth fixing mapfoldl regardless, but I'd be interested to hear whether there's a more general solution. Could you do something as crude as detecting for loops which contain indexing?

mcabbott avatar May 14 '20 16:05 mcabbott

JAX's lax.scan is a generalization of mapfoldl that also has optional direct outputs. Here's the Python pseudocode version:

def scan(f, init, xs):
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

shoyer avatar May 14 '20 16:05 shoyer

Yes, I hadn't tried but individual sparse vectors still seem like a lot of overhead, especially if (as is common) you are covering the whole array. Did you literally use spzeros or some special struct?

Just SparseVector (all benchmarks are running slower on my machine right now, but this shows the speed-up):

julia> @btime Zygote.gradient(evalpoly, x, p);
  1.235 s (490088 allocations: 1.51 GiB)

julia> using SparseArrays

julia> Zygote.∇getindex(x::AbstractVector, i::Tuple{T}) where {T<:Integer} = dy -> begin
         dx = SparseArrays.SparseVector{eltype(dy),T}(length(x), [i[1]], [dy])
         return (dx, nothing)
       end

julia> @btime Zygote.gradient(evalpoly, x, p);
  378.754 ms (510086 allocations: 120.61 MiB)

A truly one-hot representation would be nice, but after one accumulation has happened, I think a fully sparse vector is already needed. If no elements of p are set in the loop, and it's instead only read (which is all Zygote supports for AbstractVector p anyway), then perhaps there's a way to delay accumulation until the end. You would end up with something like Zygote.accum(grads::OneHot...); then we could allocate the vector, call setindex! internally, and get a lot of efficiency.
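A minimal Python sketch of that delayed-accumulation idea (the `OneHot`, `accum`, and `materialize` names are mine, not Zygote's): each getindex cotangent is just an (index, value) pair, accumulation concatenates pairs, and the dense vector is allocated exactly once at the end:

```python
import numpy as np

class OneHot:
    """Hypothetical lazy cotangent for getindex: the array length plus
    a list of (index, value) pairs, with no dense allocation."""
    def __init__(self, n, pairs):
        self.n = n
        self.pairs = pairs

def accum(a, b):
    """Accumulation concatenates pairs instead of adding N-element arrays."""
    assert a.n == b.n
    return OneHot(a.n, a.pairs + b.pairs)

def materialize(oh):
    """One dense allocation plus one setindex!-style pass at the very end."""
    dx = np.zeros(oh.n)
    for i, v in oh.pairs:
        dx[i] += v
    return dx
```

Accumulating k one-hots this way costs O(k) until materialization, rather than O(k·N) for dense adds.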

It would be worth fixing mapfoldl regardless, but would be interested to hear if there's a more general solution. Could you do something as crude as detecting for loops which contain indexing?

Could we? I don't see how we could do that with a custom adjoint, but perhaps something in Zygote's internals could (I have no idea how Zygote's internals work)?

JAX's lax.scan is a generalization of mapfoldl that also has optional direct outputs. Here's the Python pseudocode version:

def scan(f, init, xs):
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

Nice! This is very similar to an implementation I have been testing. Note that the function and the pullback use the same pattern as lax.scan, namely, accumulating a carry and storing intermediates.

myfoldl(op, init, itr) = foldl(op, itr; init = init)

myfoldr(op, init, itr) = myfoldl((a, b) -> op(b, a), init, reverse(itr))

Zygote.@adjoint function myfoldl(op, init, itr)
    y, back = Zygote._pullback(op, init, itr[1])
    backs = similar(itr, typeof(back))
    backs[1] = back
    for i in eachindex(itr)[2:end]
        y, backs[i] = Zygote._pullback(op, y, itr[i])
    end
    function myfoldl_pullback(Δy)
        (∂op, ∂y, ∂itri) = backs[end](Δy)
        ∂itr = similar(itr, typeof(∂itri))
        ∂itr[end] = ∂itri
        for i in eachindex(backs)[(end - 1):-1:1]
            ∂opi, ∂y, ∂itr[i] = backs[i](∂y)
            ∂op = Zygote.accum(∂op, ∂opi)
        end
        return ∂op, ∂y, ∂itr
    end
    return y, myfoldl_pullback
end

_evalpoly(x, p) = myfoldr((pi, y) -> muladd(x, y, pi), zero(x), p)

julia> @btime Zygote.gradient(_evalpoly, x, p);
  963.264 μs (90029 allocations: 2.29 MiB)

Which gets us below a millisecond.

sethaxen avatar May 14 '20 16:05 sethaxen

I had a go... delaying when to make the sparse vector does help a little:

julia> Zygote.∇getindex(x::AbstractVector, i::Tuple{T}) where {T<:Integer} = dy -> (GotOne(x, i[1], dy), nothing)
julia> struct GotOne{T,I,V}
       arr::T
       ind::I
       val::V
       end
julia> Zygote.accum(dx::GotOne, dy::GotOne) = SparseVector(length(dx.arr), [dx.ind, dy.ind], [dx.val, dy.val])
julia> Base.:+(s::SparseVector, g::GotOne) = (s[g.ind] += g.val; s)
julia> Base.:+(n::Number, g::GotOne) = GotOne(g.arr, g.ind, g.val + n)

julia> @btime Zygote.gradient(_evalpoly, $x, $p);
  52.410 ms (357409 allocations: 15.52 MiB)

Compared to about 100ms for the sparse solution, on my computer, and 118.672 μs for the foldl one! I'm not precisely sure when this gets accumulated, and whether this could be delayed further?

but perhaps something in Zygote's internals could (I have no idea how Zygote's internals work)?

I'm not certain, but I think this IR tracing stuff can see all the code which is executed inside gradient(). Whether it is still in a form which has for loops at that point I don't actually know. But if it is, it seems conceivable that you could detect getindex, inside a loop, which varies its index, and give that whole block one _zero(p)?

It sounds like such analysis is not what Jax is doing, as the definition was changed to explicitly call lax.scan:

@jit
def _polyval(p, x):
  ...
  y, _ = lax.scan(f, y, p_main.reshape(loop_steps, unrolled_steps))

mcabbott avatar May 14 '20 17:05 mcabbott

The original fixed version of JAX's polyval is probably most informative:

@jit
def polyval(p, x):
  shape = lax.broadcast_shapes(p.shape[1:], x.shape)
  dtype = result_type(p, x)
  y = lax.full_like(x, 0, shape=shape, dtype=dtype)
  y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p)
  return y

This has good performance on CPUs.

(We had to add some manual loop unrolling to improve performance on GPUs)

shoyer avatar May 14 '20 18:05 shoyer

S4tf uses a sparse array for this.

Do you have a link to how they do it? Do they use sparse arrays for everything? Or, how do they signal conversion from a sparse array to a dense array?

@sethaxen see below. @oxinabox might also be interested

https://groups.google.com/a/tensorflow.org/d/msg/swift/az0-tV15_rA/kNBBDOOXAAAJ

https://docs.google.com/document/d/1epIhG-3znJZtBhF9nYpTfXXmumlqa7Nc-2SQ3sJPKgg/edit#

https://groups.google.com/a/tensorflow.org/forum/#!searchin/swift/subscript|sort:date/swift/GqoxzKajbYg/L5sBLznQBgAJ

I'm not exactly sure how it's implemented; I haven't dug through all the available material.

AriMKatz avatar May 14 '20 18:05 AriMKatz

Compared to about 100ms for the sparse solution, on my computer, and 118.672 μs for the foldl one! I'm not precisely sure when this gets accumulated, and whether this could be delayed further?

Which one is giving you 118.672 μs?

sethaxen avatar May 15 '20 03:05 sethaxen

I'm curious why/if Buffer doesn't do the job here. It seems like it should solve the main problem – huge number of array allocations – quite neatly, even if there's still some dispatch overhead.

The other option is just to wait for our own XLA support, which we know will be able to handle this case. We could actually pull some neat tricks here, e.g. having the compiler decide whether a loop is worth unrolling or not. That's probably a way off, but even without that it would be possible to implement our own scan and customise how it gets differentiated and lowered to XLA.

MikeInnes avatar May 19 '20 14:05 MikeInnes

I'm curious why/if Buffer doesn't do the job here. It seems like it should solve the main problem – huge number of array allocations – quite neatly, even if there's still some dispatch overhead.

I benchmarked it. It's 1 order of magnitude faster, but it's still 3 orders of magnitude slower than what can be achieved by hardcoding the same derivatives.

julia> using Zygote, BenchmarkTools

julia> Base.lastindex(b::Zygote.Buffer) = lastindex(b.data) # this definition is missing

julia> x, p = rand(), randn(10000);

julia> b = Zygote.bufferfrom(p);

julia> @btime Zygote.gradient(Base.Math._evalpoly, x, p); # evalpoly requires a vector
  581.570 ms (490085 allocations: 1.51 GiB)
(2.481946129508864, [1.0, 0.8956556841020316, 0.8021991044642781, 0.7184941876949901, 0.6435234032032898, 0.57637539393171, 0.5162338977514837, 0.4623678248472633, 0.4141223704703439, 0.3709110550255708  …  2.0e-323, 2.0e-323, 2.0e-323, 2.0e-323, 2.0e-323, 2.0e-323, 2.0e-323, 2.0e-323, 2.0e-323, 2.0e-323])

julia> @btime Zygote.gradient(Base.Math._evalpoly, x, b)
  73.611 ms (479072 allocations: 16.98 MiB)
(2.481946129508864, nothing)

But as you can see, its pullback doesn't return the buffer's gradient (it's instead stored in the context and accessed using grad_mut). I could only get it out by creating my own Context:

julia> ctx = Zygote.Context()
Zygote.Context(nothing)

julia> Zygote._pullback(ctx, Base.Math._evalpoly, x, b)[2](1.0)
(nothing, 2.481946129508864, nothing)

julia> Zygote.grad_mut(ctx, b)
10000-element Array{Float64,1}:
 1.0
 0.8956556841020316
 0.8021991044642781
 ⋮
 2.0e-323
 2.0e-323

I think adding an adjoint for bufferfrom that returns the current adjoint on the Buffer would fix this; then one could call Zygote.gradient((x, p) -> Base.Math._evalpoly(x, bufferfrom(p)), x, p).
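For illustration, here's a hypothetical Python sketch (`GradBuffer` is my name, not Zygote's API) of what a Buffer-style adjoint buys you: every getindex pullback writes into one shared mutable gradient array, allocated once, instead of returning a fresh cotangent:

```python
import numpy as np

class GradBuffer:
    """Hypothetical Buffer-style wrapper: reads return a pullback that
    accumulates in place into one shared gradient array (analogous to
    what grad_mut holds for a Zygote.Buffer)."""
    def __init__(self, data):
        self.data = data
        self.grad = np.zeros_like(data)  # allocated once, up front

    def getindex(self, i):
        value = self.data[i]
        def pullback(c):
            self.grad[i] += c            # in-place accumulation
        return value, pullback
```

The trade-off is that the gradient now lives on the wrapper (or in the AD context) rather than flowing through return values, which is why an adjoint for bufferfrom is needed to surface it.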

sethaxen avatar May 19 '20 17:05 sethaxen

I'm curious why/if Buffer doesn't do the job here. It seems like it should solve the main problem – huge number of array allocations – quite neatly, even if there's still some dispatch overhead.

Maybe #652 is the problem: if there's a loop at all, even one that literally does nothing, performance is the same:
julia> function foo(x, p)
          N = length(p)
          ex = p[end]
          for i in N-1:-1:1
              ex = muladd(x, ex, one(x))
          end
          ex
       end

julia> @btime Zygote.gradient(foo, x, b);
  70.541 ms (410101 allocations: 15.77 MiB)

julia> function foo_nothing(x, p)
          N = length(p)
          ex = p[end]
          for i in N-1:-1:1
              nothing
          end
          ex
       end
foo_nothing (generic function with 1 method)

julia> @btime Zygote.gradient(foo_nothing, x, b);
  66.355 ms (320077 allocations: 13.13 MiB)

julia> function foo_noloop(x, p)
          N = length(p)
          ex = p[end]
          ex
       end
foo_noloop (generic function with 1 method)

julia> @btime Zygote.gradient(foo_noloop, x, b);
  10.875 μs (36 allocations: 79.23 KiB)

sethaxen avatar May 19 '20 20:05 sethaxen

Yeah, that's likely to be the issue if the work inside the body of the loop is small. Type inference in the presence of branches is effectively disabled pending compiler improvements, but one option might be to create something like scan just to get type stability.

Can you try the simplest possible example with XLA.jl and open an issue if/when it breaks? We might be able to get it up and running fairly quickly.

MikeInnes avatar May 21 '20 13:05 MikeInnes