ChainRules.jl
generic frule for constructors?
As with https://github.com/JuliaDiff/ChainRules.jl/issues/153, there is a generic frule for default constructors. It relies on the constructor arguments corresponding one-to-one to the fields, so the rule would have to detect that case.
It's basically:
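# Schematic only: Julia allows just one vararg, as the last positional argument.
function frule(::Type{P}, args..., _, dargs...) where P
    y = P(args...)  # primal: call the default constructor
    dy = Composite{P}(; zip(fieldnames(P), dargs)...)  # tangent: one entry per field
    return y, dy
end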
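A self-contained sketch of the same idea with no ChainRulesCore dependency: a `NamedTuple` keyed by `fieldnames(P)` stands in for `Composite{P}`, and the names `constructor_pushforward` and `Point` are hypothetical, for illustration only. It shows the intended pairing, where the i-th tangent belongs to the i-th field.

```julia
# Hypothetical example type with a default (field-wise) constructor.
struct Point
    x::Float64
    y::Float64
end

# Sketch: push tangents through a default-constructor call P(args...).
# A NamedTuple keyed by field name stands in for Composite{P}.
function constructor_pushforward(::Type{P}, args::Tuple, dargs::Tuple) where {P}
    y = P(args...)                          # primal result
    dy = NamedTuple{fieldnames(P)}(dargs)   # i-th tangent paired with i-th field
    return y, dy
end

y, dy = constructor_pushforward(Point, (1.0, 2.0), (0.5, -1.0))
```

Here `y` is `Point(1.0, 2.0)` and `dy` is `(x = 0.5, y = -1.0)`.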
A more serious implementation might be:
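import ChainRulesCore: frule, Composite  # Composite was later renamed Tangent
function frule(::Type{P}, all_args...) where P
    isconcretetype(P) || return nothing  # conservatively skip abstract and UnionAll types
    nargs = fieldcount(P)
    length(all_args) == 2nargs + 1 || return nothing  # (args..., dself, dargs...)
    args = all_args[1:nargs]
    # each argument must already be an instance of its field's declared type
    all(map(isa, args, fieldtypes(P))) || return nothing
    dargs = all_args[nargs+2:end]  # skip dself, as constructors are never functors
    y = P(args...)
    dy = Composite{P}(; zip(fieldnames(P), dargs)...)
    return y, dy
end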
But this may well need to be written as a generated function if it doesn't constant-fold well.
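A hedged sketch of what a generated-function version could look like: the concreteness check, the length check, the field-type check and the index arithmetic all run once per combination of argument types, so the code left at run time is only the constructor call plus the tangent construction. The name `generated_frule` and the `Sample` type are hypothetical, a `NamedTuple` stands in for `Composite{P}`, and the field check uses `<:` on argument types (rather than `typeof(...) ==`) so abstractly typed fields still match. The tangents start at index `nargs + 2`, right after `dself`.

```julia
# Hypothetical example type.
struct Sample
    first::Int
    second::Float64
end

# Sketch: all_args is (args..., dself, dargs...). Because this is @generated,
# the body below runs once per (P, argument types) combination, and only the
# returned expression is executed at call time.
@generated function generated_frule(::Type{P}, all_args...) where {P}
    # Inside the generator, all_args is bound to a tuple of argument *types*.
    isconcretetype(P) || return :(nothing)
    nargs = fieldcount(P)
    length(all_args) == 2nargs + 1 || return :(nothing)
    # Decide at compile time whether this looks like a default-constructor call.
    all(map(<:, all_args[1:nargs], fieldtypes(P))) || return :(nothing)
    names = fieldnames(P)
    # Explicit index expressions, so no tuple slicing happens at run time.
    argexprs = [:(all_args[$i]) for i in 1:nargs]
    dargexprs = [:(all_args[$(nargs + 1 + i)]) for i in 1:nargs]  # skips dself
    return quote
        y = $P($(argexprs...))
        dy = NamedTuple{$names}(($(dargexprs...),))  # stand-in for Composite{P}
        (y, dy)
    end
end
```

For example, `generated_frule(Sample, 1, 2.0, nothing, 0, 0.5)` returns `(Sample(1, 2.0), (first = 0, second = 0.5))`, while the same call with `1.5` as the first argument returns `nothing`, since `Float64` is not a subtype of the `Int` field type.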
Related:
- We need mutable composite types. https://github.com/JuliaDiff/ChainRulesCore.jl/issues/105
- We need a way to work out the differential type for a given primal type, so that we can declare differential arrays etc. (this can default to Any). https://github.com/JuliaDiff/ChainRulesCore.jl/issues/106
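For the second point, one possible shape is a trait function on primal types with a catch-all `Any` default, which is enough to allocate differential arrays up front. `differential_type` and `similar_differentials` are hypothetical names for illustration, not an existing ChainRulesCore API, and the specific methods chosen below are assumptions.

```julia
# Hypothetical trait: which type holds differentials for primal type T.
# The catch-all default is Any, as suggested above.
differential_type(::Type) = Any
differential_type(::Type{T}) where {T<:AbstractFloat} = T
differential_type(::Type{T}) where {T<:Complex} = T

# Allocate an uninitialised array of differentials shaped like a primal array.
function similar_differentials(x::AbstractArray{T}) where {T}
    return Array{differential_type(T)}(undef, size(x))
end

dx = similar_differentials([1.0, 2.0, 3.0])
ds = similar_differentials(["a", "b"])
```

Here `dx` is a `Vector{Float64}`, while `ds`, built from a `Vector{String}`, falls back to `Vector{Any}`.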