ChainRulesCore.jl

Add 3rd argument to InplaceableThunk

Open mcabbott opened this issue 4 years ago • 4 comments

An approach to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/411

This lets you specify when it is safe to apply the in-place rule. is_inplaceable_destination alone is not sufficient: it must accept Vector{Float64}, but I can't write complex numbers into that.
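To make the failure concrete, here is a minimal sketch using only Base (no ChainRulesCore calls; the variable names are made up for illustration). Accumulating a complex contribution into a Vector{Float64} in place throws, while the out-of-place sum works but produces a complex array:

```julia
# Base-only illustration; dA and increment are illustrative names.
dA = zeros(Float64, 2)                    # a real gradient buffer
increment = [1.0 + 2.0im, 3.0 + 0.0im]    # a complex contribution

# In-place accumulation has to convert every element to Float64,
# which fails as soon as an imaginary part is nonzero.
threw = try
    dA .+= increment
    false
catch err
    err isa InexactError
end

# The out-of-place sum is fine, but it allocates a complex array.
dA_out = zeros(Float64, 2) + increment
```

So a check that only asks whether the destination is mutable cannot tell whether this particular in-place rule is safe for it.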

The safer rule would then look like this. If A is real and B is complex, and dA is real, then this ought to reject mul!:

function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    Y = A * B
    TY = eltype(Y)
    function times_pullback(ȳ)
        Ȳ = unthunk(ȳ)
        return (
            NoTangent(),
            InplaceableThunk(
                @thunk(Ȳ * B'),
                X̄ -> mul!(X̄, Ȳ, B', true, true),
                StridedArray{TY}  # proposed 3rd argument: use the in-place rule only if X̄ isa StridedArray{TY}
            ),
            InplaceableThunk(
                @thunk(A' * Ȳ),
                X̄ -> mul!(X̄, A', Ȳ, true, true),
                StridedArray{TY}
            )
        )
    end
    return Y, times_pullback
end
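As a rough check of which destinations the StridedArray{TY} guard in this rule would admit, the sketch below evaluates the isa test directly with Base types only. The arrays and variable names are illustrative assumptions, not part of the proposal:

```julia
A = [1.0 2.0; 3.0 4.0]          # real input
B = A .* (1.0 + 1.0im)          # complex input
TY = eltype(A * B)              # element type of the primal output

real_dA = zeros(Float64, 2, 2)
complex_dA = zeros(ComplexF64, 2, 2)

# A real destination does not match, so the in-place rule would be skipped.
real_ok = real_dA isa StridedArray{TY}

# A complex array, or a strided view of one, does match.
complex_ok = complex_dA isa StridedArray{TY}
view_ok = view(complex_dA, :, 1:1) isa StridedArray{TY}
```

With real A and complex B, TY is complex, so the real dA is rejected and only the complex destinations would take the mul! path.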

mcabbott, Jul 09 '21 20:07

I think we can take some time to experiment with this and try different things. I don't think we want to rush this.

ChainRulesCore 1.0 will have the current InplaceableThunk (with the arguments flipped), along with its condition that you had better be sure it isn't going to error if you use it.


This really feels like a problem for multiple dispatch, though I suspect that to make it nice we would need a macro. One that does something like adding the fallback to out-of-place as an extra method.

idk, I want to try a bunch of things and see how they all look.

oxinabox, Jul 10 '21 15:07

Yes there may be other solutions. This one is non-breaking now, and could equally well be an optional 2nd argument if you flip the order.

It's more or less dispatch, except that writing x isa S && ... seemed like slightly less code than explicitly adding dispatch to add!!. This seems a less mysterious level to dispatch at than duplicating this functionality inside a generated version of the add! argument.
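A minimal sketch of that isa check, assuming a stand-in struct rather than the real InplaceableThunk. GuardedThunk and accumulate!! are hypothetical names invented here, not ChainRulesCore API, and the real add!! may well be structured differently:

```julia
using LinearAlgebra

# Hypothetical stand-in for the proposed 3-argument InplaceableThunk.
struct GuardedThunk
    val::Function    # zero-argument function returning the out-of-place value
    add!::Function   # accumulates in place into its argument and returns it
    dest::Type       # destination types for which add! is known to be safe
end

# Hypothetical accumulate step: use add! only when the destination matches,
# otherwise fall back to the out-of-place sum.
accumulate!!(x, t::GuardedThunk) = x isa t.dest ? t.add!(x) : x + t.val()

Ȳ = ComplexF64[1 2; 3 4]
B = ComplexF64[1 0; 0 1im]
t = GuardedThunk(
    () -> Ȳ * B',
    X̄ -> mul!(X̄, Ȳ, B', true, true),
    StridedArray{ComplexF64},
)

real_dA = zeros(Float64, 2, 2)
out_real = accumulate!!(real_dA, t)        # falls back, real_dA is untouched

complex_dA = zeros(ComplexF64, 2, 2)
out_complex = accumulate!!(complex_dA, t)  # mutates and returns complex_dA
```

The guard lives in one place here, which is the point being made: the check is a single isa test rather than an extra method per destination type.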

What might be good is to see if we can invent examples for which this does not easily capture what you want. Maybe mul! is misleadingly simple.

mcabbott, Jul 10 '21 16:07

Nice thing about this is one can pass in a union of types.

A possible future I can imagine, with a macro, is something like

@inplacable_thunk((Ȳ * B')) begin
    (X̄ :: StridedArray{TY})-> mul!(X̄, Ȳ, B', true, true),
    (X̄ :: Diagonal{TY})-> # Specialized method that only computes the diagonal of Ȳ * B'
end

which would then scoop up Union{StridedArray{TY}, Diagonal{TY}} into the type argument?

Related to that example: something this still doesn't handle is how to integrate with projectors. I wonder if we can use the types? But even without worrying about projecting, it provides the chance to easily write methods that specialize on us only needing to compute some of the output (the thing that ProjectTo tends to waste: computing excess things that we then throw away).
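To illustrate what such a specialized method might look like, here is a sketch of an accumulation that only computes the diagonal of Ȳ * B'. The helper name add_diag_of_product! is hypothetical, and how it would be wired into a thunk or macro is left open:

```julia
using LinearAlgebra

# Hypothetical helper: accumulate only diag(Ȳ * B') into a Diagonal tangent,
# without forming the full product.
function add_diag_of_product!(X̄::Diagonal, Ȳ::AbstractMatrix, B::AbstractMatrix)
    d = X̄.diag
    for i in eachindex(d)
        # (Ȳ * B')[i, i] == sum over k of Ȳ[i, k] * conj(B[i, k])
        d[i] += sum(Ȳ[i, k] * conj(B[i, k]) for k in axes(Ȳ, 2))
    end
    return X̄
end

Ȳ = [1.0 2.0; 3.0 4.0]
B = [5.0 6.0; 7.0 8.0]
X̄ = Diagonal(zeros(2))
add_diag_of_product!(X̄, Ȳ, B)

# A union of destination types can be tested with a single isa.
matches_union = X̄ isa Union{StridedArray{Float64}, Diagonal{Float64}}
```

For an n×n Diagonal this does O(n·m) work instead of the O(n²·m) needed for the full product that projection would then mostly discard.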

In general I want to spend some time thinking about this holistically, alongside all our other thoughts about InplaceableThunks.

oxinabox, Jul 14 '21 19:07

Might be worth writing out the whole rrule for this 2-option case you envisage. I think the not-inplace thunk may also want to look pretty different, rather than rely on projection afterwards, to be efficient. At which point you may have two whole rrules.

For many rules which end up broadcasting dx -> dx .+= ..., when projection is only about eltype it would be nice to insert it. I guess you can just write dx -> dx .+= projector.element.(...) by hand.
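A small sketch of the hand-written version, assuming the destination is real and the contribution is complex. Here project_element is a hypothetical stand-in for whatever element projector the rule has available; it is defined locally so the snippet runs with Base alone:

```julia
# Hypothetical stand-in for an element projector onto a real eltype.
project_element(z) = real(z)

ȳ = [1.0 + 2.0im, 3.0 - 1.0im, 0.5im]   # complex cotangent
w = [2.0, 1.0, 4.0]                      # some real factor from the primal

# In-place rule with the eltype projection fused into the broadcast.
function add_projected!(dx)
    dx .+= project_element.(ȳ .* w)
    return dx
end

dx = add_projected!(zeros(3))
```

Without the project_element call, the same broadcast into a Float64 buffer would throw an InexactError, so fusing the projection is what makes the in-place path usable here.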

mcabbott, Jul 15 '21 20:07