
[hv]cat not compatible with non-number types

Open ntausend opened this issue 3 years ago • 2 comments

Hey, consider the following example:

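using Zygote

struct Point
    x
    y
end

# Build a 2×2 matrix from four elements; the matrix literal lowers to a call to hvcat.
function tst(p)
    pp = [p[1] p[2]; p[3] p[4]]
    return pp
end

pointar = [Point(1,2), Point(3,4), Point(5,6), Point(7,8)]

# Pull a cotangent x back through tst and keep the gradient with respect to p.
back_x = x -> Zygote.pullback(tst, pointar)[2](x)[1]

display(back_x(pointar))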

which throws an error stating that Mutating arrays is not supported -- called setindex!(::Matrix{Point}, _...). The problem can be traced back to the call to hvcat that builds the right-hand side in the tst function. After some research I found that this is due to the definition of the adjoint of hvcat in array.jl, lines 98-100:

@adjoint function hvcat(rows::Tuple{Vararg{Int}}, xs::Number...)
  hvcat(rows, xs...), ȳ -> (nothing, permutedims(ȳ)...)
end

The adjoint's signature restricts xs to type Number. This seems overly restrictive compared with the definition of the function in Base/abstractarray.jl:

hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractVecOrMat...) = ...

which seems to accept abstract vectors or matrices, so the solution could be simply to allow the same types.
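For reference, here is a small sketch in plain Julia (no Zygote required; the numeric values are made up for illustration) of why the pullback in the adjoint quoted above uses permutedims. hvcat fills its result row by row, while splatting a matrix iterates it column by column, so transposing the cotangent first hands the entries back in argument order. This only illustrates the scalar case with equal-length rows.

```julia
# hvcat places its arguments row by row.
rows = (2, 2)
xs = (1.0, 2.0, 3.0, 4.0)
y = hvcat(rows, xs...)           # same as [1.0 2.0; 3.0 4.0]

# A made-up cotangent with the same shape as y.
ȳ = [10.0 20.0; 30.0 40.0]

# Splatting iterates column-major, so transpose first to recover
# one cotangent per argument, in the order the arguments were passed.
grads = (permutedims(ȳ)...,)     # (10.0, 20.0, 30.0, 40.0)
```

Each grads[k] lines up with xs[k], which is what the scalar-only rule relies on.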

ntausend, Apr 11 '22 15:04

There are slightly more general rules for these in ChainRules (CR), but they are still restricted to numbers or arrays: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/array.jl#L387

They could be widened to accept anything. Maybe you'd just have to replace size (which isn't guaranteed to be defined) with some size_if_array which gives () not just for numbers, but also for unknown types.

And then Zygote's rules should be deleted -- #1217.
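A minimal sketch of what the size_if_array idea mentioned above might look like. The helper is hypothetical: it takes its name from the suggestion, does not exist in ChainRules or Zygote, and the Point struct is only there to stand in for an "unknown type".

```julia
# Hypothetical helper, named after the suggestion above; not part of ChainRules or Zygote.
size_if_array(x::AbstractArray) = size(x)  # arrays report their real size
size_if_array(x) = ()                      # numbers and unknown types count as scalars

# Stand-in for an arbitrary non-number, non-array element type.
struct Point
    x
    y
end

size_if_array([1 2; 3 4])    # (2, 2)
size_if_array(1.0)           # ()
size_if_array(Point(1, 2))   # (), whereas size(Point(1, 2)) throws a MethodError
```

The fallback method is what lets a rule treat any non-array argument as a single element instead of failing on a missing size method.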

mcabbott, Apr 12 '22 03:04

I've folded https://github.com/FluxML/Zygote.jl/issues/1223 into this issue and updated the title accordingly. There's no reason for us not to track all these *cat methods in one place.

ToucheSir, May 10 '22 21:05