ChainRulesCore.jl
Add 3rd argument to InplaceableThunk
An approach to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/411
This lets you specify when it will be safe to apply the in-place rule. is_inplaceable_destination is not sufficient, because it must accept Vector{Float64}, but I can't write complex numbers into that.
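To make the failure concrete, here is a minimal sketch using only Base Julia. The names `dA` and `contribution` are illustrative, not taken from the package. It shows that accumulating a complex value into a `Vector{Float64}` throws, while the out-of-place sum is free to widen the element type.

```julia
# A real-valued gradient buffer, as one would have for a real input A.
dA = zeros(Float64, 2)

# A complex-valued contribution, as arises when B is complex.
contribution = [1.0 + 2.0im, 3.0 - 1.0im]

# The in-place update must convert each complex value to Float64, which fails
# as soon as an imaginary part is nonzero.
err = try
    dA .+= contribution
    nothing
catch e
    e
end

println(err isa InexactError)  # prints true

# The out-of-place sum succeeds because it can allocate a complex result.
widened = dA + contribution
println(eltype(widened))
```
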
The safer rule would then look like this. If A is real and B is complex, and dA is real, then this ought to reject mul!:
function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    Y = A * B
    TY = eltype(Y)
    function times_pullback(ȳ)
        Ȳ = unthunk(ȳ)
        return (
            NoTangent(),
            InplaceableThunk(
                @thunk(Ȳ * B'),
                X̄ -> mul!(X̄, Ȳ, B', true, true),
                StridedArray{TY}
            ),
            InplaceableThunk(
                @thunk(A' * Ȳ),
                X̄ -> mul!(X̄, A', Ȳ, true, true),
                StridedArray{TY}
            )
        )
    end
    return Y, times_pullback
end
I think we can take some time to experiment with this and try different things. I don't think we want to rush this.
ChainRulesCore 1.0 will have the current InplaceableThunk (with the arguments flipped), with its condition that you had better be sure it isn't going to error if you use it.
This really feels like a problem for multiple dispatch, though I suspect that to make it nice we would need a macro. Something like having the fallback to out-of-place as an extra method.
I don't know, I want to try a bunch of things and see how they all look.
Yes there may be other solutions. This one is non-breaking now, and could equally well be an optional 2nd argument if you flip the order.
It's more or less dispatch, except that writing x isa S && seemed slightly less code than explicitly adding dispatch to add!!. This seems a less mysterious level to dispatch at than duplicating this functionality inside a generated version of the add! argument.
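A rough sketch of that `x isa S &&` check, under stated assumptions: `GuardedThunk` and `toy_add!!` are hypothetical stand-ins written for this comment, not the ChainRulesCore types or functions. The idea is only that the third field names the destinations for which the in-place function is declared safe, and everything else falls back to the out-of-place value.

```julia
using LinearAlgebra

# Hypothetical stand-in for a three-argument InplaceableThunk (not the real type).
struct GuardedThunk{F,G}
    val::F      # zero-argument function computing the out-of-place value
    add!::G     # function that accumulates in place into its argument
    S::Type     # destinations for which add! is declared safe
end

# Hypothetical accumulation function, named after add!! in the discussion.
function toy_add!!(x, t::GuardedThunk)
    # Use the in-place rule only when x matches the declared type;
    # otherwise fall back to an out-of-place sum.
    return x isa t.S ? t.add!(x) : x + t.val()
end

Ybar = [1.0+1.0im 2.0+0.0im; 0.0+0.0im 0.0+1.0im]
B = [0.0+1.0im 0.0+0.0im; 1.0+0.0im 1.0+0.0im]

t = GuardedThunk(
    () -> Ybar * B',
    X -> mul!(X, Ybar, B', true, true),
    StridedArray{ComplexF64},
)

dreal = zeros(Float64, 2, 2)         # real buffer: in-place is rejected
dcomplex = zeros(ComplexF64, 2, 2)   # complex buffer: in-place is accepted

r1 = toy_add!!(dreal, t)
r2 = toy_add!!(dcomplex, t)

println(r1 === dreal)     # prints false: a new complex array was allocated
println(r2 === dcomplex)  # prints true: the buffer was mutated
```
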
What might be good is to see if we can invent examples for which this does not easily capture what you want. Maybe mul! is misleadingly simple.
Nice thing about this is one can pass in a union of types.
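For instance, a small sketch of what such a union accepts. The alias `Dest` is made up for illustration, and the element type is fixed to `Float64` for brevity.

```julia
using LinearAlgebra

# Illustrative alias: the destinations for which an in-place rule is declared safe.
const Dest = Union{StridedArray{Float64},Diagonal{Float64}}

println(zeros(2, 2) isa Dest)              # prints true
println(Diagonal([1.0, 2.0]) isa Dest)     # prints true
println(zeros(ComplexF64, 2, 2) isa Dest)  # prints false: wrong element type
```
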
A possible future I can imagine, with a macro, is something like
@inplacable_thunk((Ȳ * B')) begin
    (X̄ :: StridedArray{TY})-> mul!(X̄, Ȳ, B', true, true),
    (X̄ :: Diagonal{TY})-> # Specialized method that only computes the diagonal of Ȳ * B'
end
which would then scoop up Union{StridedArray{TY}, Diagonal{TY}} into the type argument?
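Without any macro, the same idea can be sketched with ordinary methods. This is only an illustration: `accumulate_dA!` is a hypothetical helper, the element-type constraint `TY` is dropped for brevity, and the `Diagonal` method is one possible way to compute just the diagonal of the product.

```julia
using LinearAlgebra

# Hypothetical helper: accumulate Ybar * B' into Xbar, with the method chosen by dispatch.

# Dense, strided destination: accumulate the full product.
accumulate_dA!(Xbar::StridedMatrix, Ybar, B) = mul!(Xbar, Ybar, B', true, true)

# Diagonal destination: only the diagonal of Ybar * B' is needed.
function accumulate_dA!(Xbar::Diagonal, Ybar, B)
    # (Ybar * B')[i, i] == sum over k of Ybar[i, k] * conj(B[i, k])
    Xbar.diag .+= vec(sum(Ybar .* conj.(B); dims=2))
    return Xbar
end

Ybar = [1.0 2.0; 3.0 4.0]
B = [5.0 6.0; 7.0 8.0]

dense = zeros(2, 2)
diagonal = Diagonal(zeros(2))
accumulate_dA!(dense, Ybar, B)
accumulate_dA!(diagonal, Ybar, B)

println(diag(dense) == diagonal.diag)  # prints true

# A destination with no matching method would have to take the out-of-place path.
println(applicable(accumulate_dA!, UpperTriangular(zeros(2, 2)), Ybar, B))  # prints false
```

The set of destinations that have a method here plays the role of the union in the type argument.
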
Related to that example: something this still doesn't handle is how to integrate with projectors. I wonder if we can use the types? But even without worrying about projecting, it provides the chance to easily write methods that specialize on us only needing to compute some of the output (the thing that ProjectTo tends to waste: computing excess things that we then throw away).
In general I want to spend some time thinking about this holistically, alongside all our other thoughts about InplaceableThunks.
Might be worth writing out the whole rrule for this 2-option case you envisage. I think the not-inplace thunk may also want to look pretty different, rather than rely on projection afterwards, to be efficient. At which point you may have two whole rrules.
For many rules which end up broadcasting dx -> dx .+= ..., when projection is only about eltype it would be nice to insert it. I guess you can just write dx -> dx .+= projector.element.(...) by hand.
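A minimal sketch of that by-hand version, using only Base Julia. Here `project_element` is a hypothetical stand-in for `projector.element` in the case where the primal has real elements, and `toy_add!` is an illustrative name for the in-place function.

```julia
# Hypothetical stand-in for projector.element when the primal is real:
# it maps each incoming element onto the reals.
project_element = real

x = [2.0, 2.0, 2.0]
dy = [1.0 + 2.0im, 3.0 + 0.5im, -1.0 - 1.0im]  # complex contribution
dx = zeros(Float64, 3)                          # real gradient buffer

# In-place rule with the element projection applied inside the broadcast,
# so nothing complex is ever written into the real buffer.
function toy_add!(dx)
    dx .+= project_element.(dy .* x)
    return dx
end

toy_add!(dx)
println(dx)  # prints [2.0, 6.0, -2.0]
```
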