Zygote.jl
[hv]cat not compatible with non-number types
Hey, consider the following example:
using Zygote

struct Point
    x
    y
end

function tst(p)
    pp = [p[1] p[2]; p[3] p[4]]
    return pp
end

pointar = [Point(1,2), Point(3,4), Point(5,6), Point(7,8)]
back_x = x -> Zygote.pullback(tst, pointar)[2](x)[1]
display(back_x(pointar))
which throws an error stating that Mutating arrays is not supported -- called setindex!(::Matrix{Point}, _...). The problem can be traced back to the call to hvcat that builds the right-hand side in the tst function. After some research I found out that this is due to the definition of the adjoint of the hvcat function in array.jl, lines 98-100:
@adjoint function hvcat(rows::Tuple{Vararg{Int}}, xs::Number...)
    hvcat(rows, xs...), ȳ -> (nothing, permutedims(ȳ)...)
end
It restricts xs to type Number. However, this seems overly restrictive, and comparing with the definition of the function in Base/abstractarray.jl:
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractVecOrMat...) = ...
which seems to accept abstract vectors or matrices, the solution could be to simply allow the same types.
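For reference, here is a small sketch of what the pullback in the Zygote rule quoted above computes for scalar inputs. It uses only Base Julia; the values of xs and ȳ are made up for illustration and are not taken from the example in this issue.

```julia
# Base Julia only; no Zygote needed for this illustration.
rows = (2, 2)
xs = (1.0, 2.0, 3.0, 4.0)

# hvcat fills the result row by row: [1.0 2.0; 3.0 4.0]
y = hvcat(rows, xs...)

# An illustrative cotangent with the same shape as y.
ȳ = [10.0 20.0; 30.0 40.0]

# Splatting a matrix walks it in column-major order, so the rule transposes
# first to recover row-major (argument) order.
grads = (permutedims(ȳ)...,)   # (10.0, 20.0, 30.0, 40.0)
```

Each entry of grads lines up with the argument in the same position of xs, which is why the rule returns permutedims(ȳ)... after the nothing for rows.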
There are slightly more general rules for these in ChainRules (CR), but they are still restricted to numbers or arrays: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/array.jl#L387
They could be widened to accept anything. Maybe you'd just have to replace size (which isn't guaranteed to be defined) with some size_if_array which gives () not just for numbers, but for unknown types too.
And then Zygote's rules should be deleted -- #1217.
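The size_if_array idea mentioned above could look roughly like the following. This is only a sketch: size_if_array is a hypothetical helper named after the suggestion in this comment, not an existing ChainRules function, and the Point struct is repeated from the original report so the snippet runs on its own.

```julia
# Hypothetical helper, not part of ChainRules or Zygote.
size_if_array(x::AbstractArray) = size(x)
size_if_array(x) = ()   # numbers, structs, and any other non-array value

struct Point
    x
    y
end

size_if_array([1 2; 3 4])    # (2, 2)
size_if_array(1.0)           # ()
size_if_array(Point(1, 2))   # (), whereas size(Point(1, 2)) throws a MethodError
```

The fallback method treats every non-array value like a scalar, so a rule built on it would not need size to be defined for user types such as Point.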
I've folded https://github.com/FluxML/Zygote.jl/issues/1223 into this issue and updated the title accordingly. There's no reason for us not to track all these *cat methods in one place.