ChainRulesCore.jl

`rrules` do not support chunked mode

Status: Open. oscardssmith opened this issue 3 years ago · 2 comments

This is a general issue, but here is one specific incarnation, the `Base.vect` rule in ChainRules.jl: https://github.com/JuliaDiff/ChainRules.jl/blob/8073c7c4638bdd46f4e822d2ab72423c051c5e4b/src/rulesets/Base/array.jl#L40

```julia
function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
    vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...)
    return Base.vect(X...), vect_pullback
end
```

This rule implicitly assumes that `ȳ` is a `Vector`, but if you are taking a Jacobian, it will be a `Matrix`, in which case the rule should be

```julia
function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
    vect_pullback(ȳ) = (NoTangent(), ȳ...)
    return Base.vect(X...), vect_pullback
end
```
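To make the shape question concrete, here is a minimal pure-Julia sketch of the per-argument cotangents of `Base.vect(X...)` for scalar arguments. It assumes a hypothetical chunked convention in which `ȳ` is an `N×k` matrix with one column per seed; `vect_arg_cotangents` is an illustrative helper, not part of ChainRules.jl, and it leaves out the `NoTangent()` for the function itself.

```julia
# Hypothetical helper, not part of ChainRules.jl: cotangents for each scalar
# argument of `Base.vect(X...)`, excluding the function's own NoTangent().

# Ordinary reverse mode: ȳ is a Vector with one entry per argument.
vect_arg_cotangents(ȳ::AbstractVector) = Tuple(ȳ)

# Assumed chunked convention: ȳ is an N×k Matrix, one column per seed.
# Since y[i] == X[i], the cotangent of argument i is row i of ȳ
# (one entry per seed in the chunk).
vect_arg_cotangents(ȳ::AbstractMatrix) = ntuple(i -> ȳ[i, :], size(ȳ, 1))

ȳ = [10.0, 20.0, 30.0]
vect_arg_cotangents(ȳ)   # (10.0, 20.0, 30.0)

Ȳ = [1.0 0.0; 0.0 1.0; 0.0 0.0]  # 3 outputs, chunk of 2 seeds
vect_arg_cotangents(Ȳ)   # ([1.0, 0.0], [0.0, 1.0], [0.0, 0.0])

# Splatting the matrix directly, (Ȳ...,), yields all 6 entries in
# column-major order rather than 3 per-argument cotangents.
length((Ȳ...,))          # 6
```

Under this assumed convention, each argument's cotangent gains a trailing chunk dimension instead of staying a scalar.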

Similar problems also exist for the `getindex` rrules, and I'm sure there are many other similar cases. Is there a good general solution to this?
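The `getindex` case can be sketched the same way. This is a hypothetical simplification for `y = x[i]` with a `Vector` `x` and an integer `i`, not ChainRules.jl's actual rule; `getindex_arg_cotangent` is an illustrative name, and the chunked method assumes `ȳ` carries one entry per seed.

```julia
# Hypothetical sketch, not ChainRules.jl's actual getindex rule: the cotangent
# of a Vector `x` for `y = x[i]` with an integer index `i`.

# Ordinary reverse mode: ȳ is a scalar, and x̄ has the same size as x.
function getindex_arg_cotangent(x::AbstractVector, i::Integer, ȳ::Number)
    x̄ = zeros(promote_type(eltype(x), typeof(ȳ)), length(x))
    x̄[i] = ȳ
    return x̄
end

# Assumed chunked convention: ȳ holds one entry per seed, so x̄ gains a
# trailing chunk dimension and no longer has the size of x.
function getindex_arg_cotangent(x::AbstractVector, i::Integer, ȳ::AbstractVector)
    x̄ = zeros(promote_type(eltype(x), eltype(ȳ)), length(x), length(ȳ))
    x̄[i, :] = ȳ
    return x̄
end

x = [1.0, 2.0, 3.0]
getindex_arg_cotangent(x, 2, 5.0)         # [0.0, 5.0, 0.0]
getindex_arg_cotangent(x, 2, [5.0, 7.0])  # 3×2 Matrix whose row 2 is [5.0, 7.0]
```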

oscardssmith · Jul 18 '22 20:07

I think you're asking whether there's a scheme for chunked reverse mode. There is not: at present, (co)tangents match the size of the primal. https://github.com/JuliaDiff/ChainRulesCore.jl/issues/92 has some discussion; see also https://github.com/JuliaDiff/Diffractor.jl/pull/54.
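As a rough illustration of what the lack of chunking means in practice: when every pullback call takes a single seed with the same size as the primal output, an `m×n` Jacobian needs `m` separate calls, one unit seed per output. The sketch below assumes a plain closure as the pullback; `jacobian_by_rows` is a hypothetical helper, not a ChainRulesCore function.

```julia
using LinearAlgebra  # for the adjoint-times-vector product below

# Hypothetical helper: without chunked reverse mode, each pullback call takes a
# single seed shaped like the primal output, so an m×n Jacobian needs m calls.
function jacobian_by_rows(pullback, m::Integer, n::Integer)
    J = zeros(m, n)
    for i in 1:m
        seed = zeros(m)
        seed[i] = 1.0              # unit seed for output i
        J[i, :] = pullback(seed)   # row i of the Jacobian
    end
    return J
end

# Example: y = A * x has pullback ȳ -> A' * ȳ.
A = [1.0 2.0; 3.0 4.0; 5.0 6.0]
linear_pullback(ȳ) = A' * ȳ
jacobian_by_rows(linear_pullback, 3, 2) == A   # true
```

A chunked scheme would instead pass all `m` seeds at once as a matrix, which is exactly the cotangent shape the current rules do not accept.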

Edit: most rules enforce this size matching via projection with `ProjectTo`:

```julia
julia> x = [1,2,3];  # vector primal

julia> ProjectTo(x)([4;5;6;;])  # allows 1-column matrix, converts to vector
3-element Vector{Float64}:
 4.0
 5.0
 6.0

julia> ProjectTo(x)([4 5 6])  # does not allow worse shapes
ERROR: DimensionMismatch: variable with size(x) == (3,) cannot have a gradient with size(dx) == (1, 3)
```
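For readers without ChainRulesCore loaded, the two results above can be mimicked by a small shape check for a `Vector` primal. This is a hypothetical stand-in named `project_vector_like`, not ChainRulesCore's `ProjectTo`, and it only reproduces the behavior shown in this transcript: trailing singleton dimensions are dropped, and any other shape is rejected.

```julia
# Hypothetical stand-in, NOT ChainRulesCore's `ProjectTo`: mimics the two REPL
# results above for a Vector primal.
function project_vector_like(x::AbstractVector, dx::AbstractArray)
    n = length(x)
    sz = size(dx)
    # Accept size (n,) or (n, 1, 1, ...), i.e. only trailing singleton dims.
    if sz[1] == n && all(==(1), sz[2:end])
        # Drop the singleton dims; convert to floats as in the output above.
        return vec(float.(dx))
    end
    throw(DimensionMismatch("variable with size(x) == $(size(x)) cannot have a gradient with size(dx) == $sz"))
end

x = [1, 2, 3]
project_vector_like(x, reshape([4, 5, 6], 3, 1))  # [4.0, 5.0, 6.0]
# project_vector_like(x, [4 5 6])  # would throw DimensionMismatch
```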

mcabbott · Jul 19 '22 13:07

For now, the current status should be clearly documented, perhaps on these pages:

https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/tangents.html

https://juliadiff.org/ChainRulesCore.jl/dev/maths/propagators.html

mcabbott · Jul 19 '22 17:07