Zygote.jl

accum_sum for AbstractArray{<:AbstractArray{<:Number}}

Open: grero opened this issue 1 year ago · 5 comments

The following line seems a bit puzzling:

https://github.com/FluxML/Zygote.jl/blob/e2d7839a476fc1cda407500d9a2ca125bd6e2aa5/src/lib/broadcast.jl#L48

using Zygote
x = [[0.1, 0.3, 0.4], [0.2,0.4]]
Zygote.accum_sum(x)
ERROR: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(3),), b has dims (Base.OneTo(2),), mismatch at 1
Stacktrace:
  [1] promote_shape
    @ ./indices.jl:178 [inlined]
  [2] promote_shape
    @ ./indices.jl:169 [inlined]
  [3] +(A::Vector{Float64}, Bs::Vector{Float64})
    @ Base ./arraymath.jl:14
  [4] add_sum
    @ ./reduce.jl:24 [inlined]
  [5] _mapreduce(f::typeof(identity), op::typeof(Base.add_sum), #unused#::IndexLinear, A::Vector{Vector{Float64}})
    @ Base ./reduce.jl:435
  [6] _mapreduce_dim
    @ ./reducedim.jl:365 [inlined]
  [7] #mapreduce#765
    @ ./reducedim.jl:357 [inlined]
  [8] mapreduce
    @ ./reducedim.jl:357 [inlined]
  [9] #_sum#775
    @ ./reducedim.jl:999 [inlined]
 [10] _sum
    @ ./reducedim.jl:999 [inlined]
 [11] #_sum#774
    @ ./reducedim.jl:998 [inlined]
 [12] _sum
    @ ./reducedim.jl:998 [inlined]
 [13] #sum#772
    @ ./reducedim.jl:994 [inlined]
 [14] #accum_sum#1175
    @ ~/.julia/packages/Zygote/HTsWj/src/lib/broadcast.jl:48 [inlined]
 [15] accum_sum(xs::Vector{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/broadcast.jl:48
 [16] top-level scope
    @ REPL[5]:1

Wouldn't this be a better implementation?

Zygote.accum_sum(xs::AbstractVector{<:AbstractArray{<:Number}};dims=:) = sum(sum.(xs,dims=dims))

I'm not sure how to interpret the dims argument when the outer container is an AbstractArray, though.

grero, May 30 '23 13:05

The purpose is to reduce an array of arrays down to a single array, not to a scalar as your suggestion does. Given that this is 100% internal-only functionality, can you elaborate on what you're trying to do with it?
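To make that distinction concrete, here is a minimal sketch using plain Base `sum` only. It does not call Zygote's internal `accum_sum`, whose exact definition is not reproduced here. Reducing an outer container of equal-length arrays yields one array, the proposed `sum(sum.(xs))` collapses everything to a scalar, and `dims` refers to the outer container, not to the inner arrays.

```julia
# Outer vector holding two inner vectors of the SAME length.
xs = [[1.0, 2.0], [3.0, 4.0]]

# Reducing the outer container adds the inner arrays elementwise,
# giving a single array (the shape a gradient accumulator needs).
acc = sum(xs)            # [4.0, 6.0]

# The proposal in the issue sums each inner array first,
# so the result collapses to a scalar.
scalar = sum(sum.(xs))   # 10.0

# `dims` applies to the OUTER matrix: entries in the same row are added
# together, and the inner vectors are added elementwise.
M = reshape([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], 2, 2)
rows = sum(M; dims=2)    # 2×1 Matrix holding [6.0, 8.0] and [10.0, 12.0]
```

Note that every inner vector here has the same length; the rest of the thread is about what happens when they do not.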

ToucheSir, May 30 '23 13:05

I'm trying to find the gradient of a kernel built with KernelFunctions.jl. The kernel works on pairs of vectors whose lengths can differ. When I try to take the gradient, I get an error that I traced to the line quoted above.

This is a minimal working example of what I'm trying to achieve:

using Zygote
func(x::Vector{Float64}, y::Vector{Float64}, a) = sum(broadcast(-, x, permutedims(y)) .^ a)
func(x::Vector{Vector{Float64}}, y::Vector{Vector{Float64}}, a) = sum(func.(x, permutedims(y), a))

x = [[0.1, 0.3, 0.4], [0.2,0.4]]
y = [[0.3, 0.5, 0.7], [0.5,0.7]]

func(x[1], y[1], 2.0) # works
func(x,y, 2.0) # works

grads = only(Zygote.gradient(a->func(x,y,a), [1.0]))
ERROR: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(3),), b has dims (Base.OneTo(2),), mismatch at 1
Stacktrace:
  [1] promote_shape
    @ ./indices.jl:178 [inlined]
  [2] promote_shape
    @ ./indices.jl:169 [inlined]
  [3] +(A::Vector{Float64}, Bs::Vector{Float64})
    @ Base ./arraymath.jl:14
  [4] add_sum
    @ ./reduce.jl:24 [inlined]
  [5] _mapreduce(f::typeof(identity), op::typeof(Base.add_sum), #unused#::IndexLinear, A::Matrix{Vector{Float64}})
    @ Base ./reduce.jl:435
  [6] _mapreduce_dim
    @ ./reducedim.jl:365 [inlined]
  [7] #mapreduce#765
    @ ./reducedim.jl:357 [inlined]
  [8] mapreduce
    @ ./reducedim.jl:357 [inlined]
  [9] #_sum#775
    @ ./reducedim.jl:999 [inlined]
 [10] _sum
    @ ./reducedim.jl:999 [inlined]
 [11] #sum#773
    @ ./reducedim.jl:995 [inlined]
 [12] sum
    @ ./reducedim.jl:995 [inlined]
 [13] _reducedim_init(f::typeof(identity), op::typeof(Base.add_sum), fv::typeof(zero), fop::typeof(sum), A::Matrix{Vector{Float64}}, region::Tuple{Int64, Int64})
    @ Base ./reducedim.jl:120
 [14] reducedim_init
    @ ./reducedim.jl:108 [inlined]
 [15] _mapreduce_dim
    @ ./reducedim.jl:371 [inlined]
 [16] #mapreduce#765
    @ ./reducedim.jl:357 [inlined]
 [17] #_sum#799
    @ ./reducedim.jl:1023 [inlined]
 [18] _sum
    @ ./reducedim.jl:1023 [inlined]
 [19] #_sum#798
    @ ./reducedim.jl:1022 [inlined]
 [20] _sum
    @ ./reducedim.jl:1022 [inlined]
 [21] #sum#772
    @ ./reducedim.jl:994 [inlined]
 [22] #accum_sum#1175
    @ ~/.julia/packages/Zygote/HTsWj/src/lib/broadcast.jl:48 [inlined]
 [23] unbroadcast(x::Vector{Vector{Float64}}, x̄::Matrix{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/broadcast.jl:62
 [24] map(f::typeof(Zygote.unbroadcast), t::Tuple{Vector{Vector{Float64}}, Matrix{Vector{Float64}}, Vector{Float64}}, s::Tuple{Matrix{Vector{Float64}}, Matrix{Vector{Float64}}, Matrix{Float64}})

... longer stack trace.

grero, May 30 '23 13:05

Some investigation... some chance this is a bug in Base?

julia> Zygote.accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims = :) = sum(@show xs; dims = @show dims)

julia> Zygote.unbroadcast([1:2, 3:4], reshape([1:2, 3:4, 5:6, 7:8],2,2))  # similar signature to error case
xs = UnitRange{Int64}[1:2 5:6; 3:4 7:8]
dims = (3, 2)
2-element Vector{Vector{Float64}}:
 [6.0, 8.0]
 [10.0, 12.0]

julia> sum(reshape([1:2, 3:4, 5:6, 7:8],2,2); dims=2)  # dims=2 is what matters, 3 is a type-stability hack
2×1 Matrix{Vector{Int64}}:
 [6, 8]
 [10, 12]

julia> x = [[0.1, 0.3, 0.4, 0.4], [0.2,0.4,0.3]];  # changed size so that inner & outer sizes never agree

julia> y = [[0.3, 0.5, 0.7, 0.7], [0.5,0.7,0.6]];

julia> func(x,y,1.0) isa Real  # still runs with this size

julia> Zygote.gradient(a->func(x,y,a), 1.0)  # note that I leave a as just a scalar
xs = FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}[Fill(4.0, 4) Fill(3.0, 4); Fill(4.0, 3) Fill(3.0, 3)]
dims = (3, 2)
ERROR: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(4),), b has dims (Base.OneTo(3),), mismatch at 1
Stacktrace:
...
  [3] +(a::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, b::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ FillArrays ~/.julia/packages/FillArrays/KpKaA/src/fillalgebra.jl:255 [inlined]
...
 [24] accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims::Any)
    @ Main ./REPL[48]:1 [inlined]
...
 [26] unbroadcast(x::Vector{Vector{Float64}}, x̄::Matrix{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/broadcast.jl:62
...

julia> xfill = reshape([Fill(4.0, 4), Fill(4.0, 3), Fill(3.0, 4), Fill(3.0, 3)], 2,2)  # reconstruct argument
2×2 Matrix{Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}:
 Fill(4.0, 4)  Fill(3.0, 4)
 Fill(4.0, 3)  Fill(3.0, 3)

julia> @show xfill;  # printing of this is bad!
xfill = Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}[Fill(4.0, 4) Fill(3.0, 4); Fill(4.0, 3) Fill(3.0, 3)]

julia> Zygote.unbroadcast([1:4, 1:4], xfill)
xs = Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}[Fill(4.0, 4) Fill(3.0, 3); Fill(4.0, 4) Fill(3.0, 3)]
dims = (3, 2)
ERROR: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(4),), b has dims (Base.OneTo(3),), mismatch at 1

julia> sum(xfill; dims=2)  # also fails, no Zygote
ERROR: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(4),), b has dims (Base.OneTo(3),), mismatch at 1

julia> sum(collect.(xfill); dims=2)  # also fails, just Base
ERROR: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(4),), b has dims (Base.OneTo(3),), mismatch at 1

julia> collect.(xfill)
2×2 Matrix{Vector{Float64}}:
 [4.0, 4.0, 4.0, 4.0]  [3.0, 3.0, 3.0, 3.0]
 [4.0, 4.0, 4.0]       [3.0, 3.0, 3.0]

julia> sum(length.(xfill); dims=2)
2×1 Matrix{Int64}:
 8
 6
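The `dims = (3, 2)` seen above relies on Base accepting a reduction over a dimension beyond `ndims(A)`, which is treated as a trailing singleton, so on a matrix it reduces exactly like `dims = 2`. The sketch below checks this with plain integers, so the ragged-array problem cannot arise; why Zygote passes the extra `3` is mcabbott's explanation, and the sketch only demonstrates the Base behaviour.

```julia
A = [1 3; 2 4]

# Dimension 3 of a 2-D array is a trailing singleton, so reducing over it
# changes nothing; only dimension 2 is actually summed.
r32 = sum(A; dims=(3, 2))
r2  = sum(A; dims=2)
r32 == r2                 # true: both are the 2×1 matrix [4; 6]

# Same reduction on the inner lengths from the failing example: each row of
# the gradient matrix holds arrays of one common length (4, then 3).
lens = [4 4; 3 3]
sum(lens; dims=(3, 2))    # [8; 6], matching sum(length.(xfill); dims=2)
```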

Without Zygote at all:

julia> mat = reshape(vcat([1:5 for _ in 1:6]), 3, 2)
3×2 Matrix{UnitRange{Int64}}:
 1:5  1:5
 1:5  1:5
 1:5  1:5

julia> sum(mat; dims=1)
1×2 Matrix{Vector{Int64}}:
 [3, 6, 9, 12, 15]  [3, 6, 9, 12, 15]

julia> mat2 = reshape(vcat([1:5 for _ in 1:3], [1:7 for _ in 1:3]), 3, 2)
3×2 Matrix{UnitRange{Int64}}:
 1:5  1:7
 1:5  1:7
 1:5  1:7

julia> sum(mat2; dims=1)
ERROR: DimensionMismatch: argument dimensions must match: length of r1 is 5, length of r2 is 7
Stacktrace:
  [1] +(r1::StepRangeLen{Int64, Int64, Int64, Int64}, r2::StepRangeLen{Int64, Int64, Int64, Int64})
    @ Base ./range.jl:1459
  [2] +(r1::StepRangeLen{Int64, Int64, Int64, Int64}, r2::UnitRange{Int64})
    @ Base ./range.jl:1450
...
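In `mat2`, every column holds ranges of a single length, so `sum(mat2; dims=1)` should only ever add ranges of equal length; the error therefore comes from Base, not from the data. That precondition can be checked directly. `slices_compatible` below is a hypothetical helper written for this illustration, not part of Base or Zygote, and it assumes Julia 1.9 or later (`eachslice` with a tuple of dims, `allequal`).

```julia
# Hypothetical helper (not part of Base or Zygote): returns true when every
# group of inner arrays that `sum(A; dims)` would add together shares one size.
function slices_compatible(A::AbstractArray{<:AbstractArray}; dims)
    # Dimensions that survive the reduction; each slice along them is one group.
    kept = Tuple(setdiff(1:ndims(A), dims))
    return all(eachslice(A; dims=kept)) do group
        allequal(size.(group))
    end
end

mat2 = reshape(vcat([1:5 for _ in 1:3], [1:7 for _ in 1:3]), 3, 2)

slices_compatible(mat2; dims=1)  # true: each column holds ranges of one length
slices_compatible(mat2; dims=2)  # false: each row mixes lengths 5 and 7
```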

mcabbott, May 30 '23 22:05

This is not at all important, but I'm curious: what is the intended result of running sum(x; dims=1) when x is a Vector{Vector{Float64}} whose inner vectors have different lengths?

grero, Jun 01 '23 04:06

That one cannot work: you are asking to add things of different lengths.

But what AD is doing, and what my mat2 example does, should only add things of the same length.

julia> sum.(eachslice(mat; dims=2, drop=false))
1×2 Matrix{StepRangeLen{Int64, Int64, Int64, Int64}}:
 3:3:15  3:3:15

julia> collect.(ans)  # same as above
1×2 Matrix{Vector{Int64}}:
 [3, 6, 9, 12, 15]  [3, 6, 9, 12, 15]

julia> sum.(eachslice(mat2; dims=2, drop=false))
1×2 Matrix{StepRangeLen{Int64, Int64, Int64, Int64}}:
 3:3:15  3:3:21

julia> collect.(ans)  # what sum(mat2; dims=1) should have returned
1×2 Matrix{Vector{Int64}}:
 [3, 6, 9, 12, 15]  [3, 6, 9, 12, 15, 18, 21]
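The `eachslice` trick above generalises to any `dims`: sum each slice on its own, so inner arrays from different slices are never combined. `sum_slices` below is a hypothetical helper, not part of Base or Zygote, and assumes Julia 1.9 or later for `eachslice(...; drop=false)` with a tuple of dims. It is a sketch of a workaround, not a proposed fix for Zygote's `unbroadcast`.

```julia
# Hypothetical helper (not part of Base or Zygote).
# Reduces an array of arrays over `dims` by summing each slice separately.
function sum_slices(A::AbstractArray{<:AbstractArray}; dims)
    kept = Tuple(setdiff(1:ndims(A), dims))
    # drop=false keeps size 1 along the reduced dimensions, matching the
    # shape that sum(A; dims) would return.
    return sum.(eachslice(A; dims=kept, drop=false))
end

mat2 = reshape(vcat([1:5 for _ in 1:3], [1:7 for _ in 1:3]), 3, 2)
r = sum_slices(mat2; dims=1)
size(r)            # (1, 2)
collect(r[1, 2])   # [3, 6, 9, 12, 15, 18, 21]

# A ragged case shaped like the gradient in the issue: rows share a length,
# columns do not, and the reduction is over dims=2.
xbar = reshape([[1.0, 2.0, 3.0], [1.0, 2.0], [4.0, 5.0, 6.0], [3.0, 4.0]], 2, 2)
s = sum_slices(xbar; dims=2)   # 2×1 Matrix holding [5.0, 7.0, 9.0] and [4.0, 6.0]
```

An empty slice would make `sum` throw, since there is no zero for a vector of unknown length, so this only covers non-empty slices.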

mcabbott, Jun 13 '23 01:06