`rrules` do not support chunked mode
This is a general issue, but one specific incarnation is the `rrule` for `Base.vect`: https://github.com/JuliaDiff/ChainRules.jl/blob/8073c7c4638bdd46f4e822d2ab72423c051c5e4b/src/rulesets/Base/array.jl#L40
```julia
function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
    vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...)
    return Base.vect(X...), vect_pullback
end
```
This rule implicitly assumes that `ȳ` is a `Vector`, but if you are taking a Jacobian, it will be a `Matrix`, in which case it should be:
```julia
function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
    vect_pullback(ȳ) = (NoTangent(), ȳ...)
    return Base.vect(X...), vect_pullback
end
```
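To make the chunked case concrete, here is a minimal sketch in plain Julia, without ChainRulesCore. It assumes a chunk of `k` cotangent seeds is stored as the columns of an `N×k` matrix, so that argument `X[i]` should receive row `i`, i.e. one cotangent entry per seed. The name `vect_pullback_chunked` is hypothetical, and `nothing` stands in for `NoTangent()` to keep the example self-contained.

```julia
# Hypothetical pullback for `Base.vect` that accepts either a single cotangent
# (a length-N Vector) or a chunk of k cotangents (an N×k Matrix, one seed per column).
# `nothing` is a stand-in for ChainRulesCore's `NoTangent()`.

# Single cotangent: argument i receives the scalar dy[i].
vect_pullback_chunked(dy::AbstractVector) = (nothing, dy...)

# Chunk of cotangents: argument i receives row i, i.e. its cotangent under each seed.
vect_pullback_chunked(dy::AbstractMatrix) =
    (nothing, (collect(r) for r in eachrow(dy))...)

dy_single = [1.0, 2.0, 3.0]
dy_chunk = [1.0 0.0; 0.0 1.0; 2.0 3.0]   # N = 3 outputs, k = 2 seeds

vect_pullback_chunked(dy_single)  # (nothing, 1.0, 2.0, 3.0)
vect_pullback_chunked(dy_chunk)   # (nothing, [1.0, 0.0], [0.0, 1.0], [2.0, 3.0])
```

Dispatching on `AbstractVector` vs `AbstractMatrix` is just one way to separate the two cases; a real chunked-mode design would need an agreed convention for where the chunk dimension lives.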
Similar problems also exist for the `getindex` rrules, and I'm sure there are many other such cases.
Is there a good general solution to this?
I think you're asking whether there's a scheme for chunked reverse mode. There is not: at present, (co)tangents match the size of the primal. https://github.com/JuliaDiff/ChainRulesCore.jl/issues/92 has some discussion; see also https://github.com/JuliaDiff/Diffractor.jl/pull/54.
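Because a cotangent must match the size of the primal, one way to get a full Jacobian from reverse mode without chunking is to call the pullback once per output, seeding it with each standard basis vector. A minimal sketch in plain Julia; the function `f` and its hand-written pullback `f_rrule` are hypothetical and not part of the ChainRules API:

```julia
# Hypothetical primal f: R^3 -> R^2.
f(x) = [x[1] * x[2], x[2] + x[3]]

# Hand-written rule: returns the primal result and a pullback that takes ONE
# cotangent dy (same size as y) and returns dx (same size as x).
function f_rrule(x)
    y = f(x)
    f_pullback(dy::AbstractVector) = [dy[1] * x[2], dy[1] * x[1] + dy[2], dy[2]]
    return y, f_pullback
end

# Without chunked mode, build the Jacobian one row at a time by pulling back
# each standard-basis cotangent e_i separately (m pullback calls in total).
function jacobian_by_rows(x)
    y, back = f_rrule(x)
    m, n = length(y), length(x)
    J = zeros(m, n)
    for i in 1:m
        e_i = zeros(m)
        e_i[i] = 1.0
        J[i, :] = back(e_i)   # row i of the Jacobian
    end
    return J
end

jacobian_by_rows([2.0, 3.0, 4.0])  # == [3.0 2.0 0.0; 0.0 1.0 1.0]
```

Chunked mode would instead pass all `m` seeds in one call, as an `m×m` matrix, which is the shape the rules in the question are not written to accept.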
Edit: most rules enforce this size matching via projection:
```julia
julia> x = [1,2,3]; # vector primal

julia> ProjectTo(x)([4;5;6;;]) # allows 1-column matrix, converts to vector
3-element Vector{Float64}:
 4.0
 5.0
 6.0

julia> ProjectTo(x)([4 5 6]) # does not allow worse shapes
ERROR: DimensionMismatch: variable with size(x) == (3,) cannot have a gradient with size(dx) == (1, 3)
```
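The shape rule in that session can be mimicked in a few lines. This is a hypothetical helper written for illustration, not ChainRulesCore's implementation of `ProjectTo`; it only reproduces the two cases shown: a length-`n` vector or an `n×1` matrix is accepted and converted to a vector, anything else throws `DimensionMismatch`.

```julia
# Hypothetical helper mimicking the shape check of projecting a gradient onto a
# vector primal. NOT the actual ChainRulesCore `ProjectTo` implementation.
function project_vector_gradient(x::AbstractVector, dx::AbstractArray)
    # Accept a gradient of the same size as x, or a single-column matrix.
    if size(dx) == size(x) || size(dx) == (length(x), 1)
        return float.(vec(dx))  # flatten to a vector; float to match the output above
    end
    throw(DimensionMismatch(
        "variable with size(x) == $(size(x)) cannot have a gradient with size(dx) == $(size(dx))"))
end

x = [1, 2, 3]
project_vector_gradient(x, reshape([4, 5, 6], 3, 1))  # [4.0, 5.0, 6.0]

# A 1×3 row matrix is rejected.
rejected = try
    project_vector_gradient(x, [4 5 6])
    false
catch err
    err isa DimensionMismatch
end
```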
For now, the current status should be clearly documented, perhaps on these pages:

- https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/tangents.html
- https://juliadiff.org/ChainRulesCore.jl/dev/maths/propagators.html