ChainRules.jl

generic frule for constructors?

Open · oxinabox opened this issue 5 years ago · 1 comment

As with https://github.com/JuliaDiff/ChainRules.jl/issues/153, there is a generic frule for default constructors. It relies on the arguments being the same as the fields, so it would have to detect that case.
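The detection step could look something like the helper below. This is a minimal sketch: `matches_default_constructor` is a hypothetical name (not part of ChainRules), it only handles concrete types, and it is stricter than Julia's default constructor, which also accepts arguments that can be `convert`ed to the field types.

```julia
# Hypothetical helper (not ChainRules API): does `args` look like a call to the
# default constructor of `P`, i.e. one positional argument per field, in order?
function matches_default_constructor(::Type{P}, args::Tuple) where {P}
    isconcretetype(P) || return false             # assumption: concrete types only
    length(args) == fieldcount(P) || return false # one argument per field
    # each argument must already be an instance of the declared field type
    return all(map(isa, args, fieldtypes(P)))
end

struct Point
    x::Float64
    y::Float64
end

matches_default_constructor(Point, (1.0, 2.0))  # true
matches_default_constructor(Point, (1.0,))      # false: wrong arity
matches_default_constructor(Point, (1, 2))      # false: Int is not a Float64
```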

It's basically:

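# Schematic: Julia only allows a vararg as the last argument, which is why
# the fuller version below splits a single `all_args...` by hand.
function frule(::Type{P}, args..., _, dargs...) where P
    y  = P(args...)
    dy = Composite{P}(; zip(fieldnames(P), dargs)...)
    return y, dy
end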

A more serious implementation might be:

function frule(::Type{P}, all_args...) where P
    nargs = fieldcount(P)
    # layout: nargs primal args, then dself, then nargs tangents
    length(all_args) == 2nargs + 1 || return nothing
    args = @inbounds all_args[1:nargs]
    # `isa` rather than exact `typeof` equality, so abstract field types match too
    all(map(isa, args, fieldtypes(P))) || return nothing
    dargs = @inbounds all_args[(nargs + 2):end]  # skip dself, as constructors are never functors
    y  = P(args...)
    dy = Composite{P}(; zip(fieldnames(P), dargs)...)
    return y, dy
end
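A self-contained, runnable version of the rule above might look like this. It is a sketch under stated assumptions: a `NamedTuple` stands in for `Composite{P}` (so no ChainRulesCore dependency), `nothing` means "no rule applies" as in the issue, `nothing` is also passed as a placeholder `dself`, and `constructor_frule` is a hypothetical name.

```julia
# Hypothetical stand-in for the generic constructor frule; a NamedTuple plays
# the role of Composite{P}, and returning `nothing` means "no rule applies".
function constructor_frule(::Type{P}, all_args...) where {P}
    isconcretetype(P) || return nothing
    nargs = fieldcount(P)
    # layout: nargs primal args, then dself, then one tangent per arg
    length(all_args) == 2nargs + 1 || return nothing
    args = all_args[1:nargs]
    all(map(isa, args, fieldtypes(P))) || return nothing
    dargs = all_args[(nargs + 2):end]       # skip dself: constructors are never functors
    y  = P(args...)
    dy = NamedTuple{fieldnames(P)}(dargs)   # field name => tangent
    return y, dy
end

struct Point
    x::Float64
    y::Float64
end

p, dp = constructor_frule(Point, 1.0, 2.0, nothing, 0.5, -1.0)
# p  == Point(1.0, 2.0)
# dp == (x = 0.5, y = -1.0)
```

Because the tangent is built by zipping field names with the per-argument tangents, this only makes sense when the arguments map one-to-one onto the fields, which is exactly what the arity and `isa` checks guard.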

But this may well need to be written as a generated function if it doesn't constant-fold well.
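A generated-function variant could move the arity and type checks to code-generation time, so each call signature gets a method body containing only the constructor call and the tangent construction. This is a hedged sketch with the same assumptions as a NamedTuple-based stand-in (NamedTuple in place of `Composite{P}`, `nothing` for "no rule", `nothing` as a placeholder `dself`); `constructor_frule_gen` is a hypothetical name.

```julia
# Hypothetical generated version: the checks run once per signature when the
# generator is invoked, and the emitted code just indexes at constant positions.
@generated function constructor_frule_gen(::Type{P}, all_args...) where {P}
    # In the generator body, `all_args` is a tuple of argument *types*;
    # in the returned expression it refers to the runtime values.
    isconcretetype(P) || return :(return nothing)
    nargs = fieldcount(P)
    length(all_args) == 2nargs + 1 || return :(return nothing)
    for (T, F) in zip(all_args[1:nargs], fieldtypes(P))
        T <: F || return :(return nothing)  # argument type must fit the field type
    end
    primal  = [:(all_args[$i]) for i in 1:nargs]
    tangent = [:(all_args[$(nargs + 1 + i)]) for i in 1:nargs]  # skip dself
    names   = fieldnames(P)
    return quote
        y  = P($(primal...))
        dy = NamedTuple{$names}(($(tangent...),))
        return (y, dy)
    end
end

struct Point
    x::Float64
    y::Float64
end

p, dp = constructor_frule_gen(Point, 1.0, 2.0, nothing, 0.5, -1.0)
```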

oxinabox, Jan 18 '20 12:01

Related:

  1. We need to have mutable composite types. https://github.com/JuliaDiff/ChainRulesCore.jl/issues/105
  2. We need a way to work out the differential type for a given primal type, so we can declare the differential arrays etc. (this can default to Any) https://github.com/JuliaDiff/ChainRulesCore.jl/issues/106
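For point 2, one possible shape is a trait that maps a primal type to its differential type, falling back to `Any`. This is only an illustration: `differential_type` and `differential_array` are hypothetical names, and the specific methods chosen are assumptions, not ChainRulesCore API.

```julia
# Hypothetical trait: primal type -> differential type, defaulting to Any.
differential_type(::Type) = Any
differential_type(::Type{T}) where {T<:AbstractFloat} = T
differential_type(::Type{Complex{T}}) where {T<:AbstractFloat} = Complex{T}

# With it, an uninitialized array of differentials can be declared for a primal array.
differential_array(x::AbstractArray) = similar(x, differential_type(eltype(x)))

differential_type(Float64)           # Float64
differential_type(String)            # Any (fallback)
eltype(differential_array(zeros(3))) # Float64
```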

oxinabox, Jan 19 '20 12:01