ForwardDiffChainRules.jl
`frule_via_ad` should accept several arguments
According to the official ChainRulesCore.jl API documentation, `frule_via_ad` should accept all the arguments of the function, passed as a destructured (splatted) tuple: https://juliadiff.org/ChainRulesCore.jl/stable/api.html#ChainRulesCore.frule_via_ad
However, it seems that ForwardDiffChainRules.jl accepts only one argument: https://github.com/ThummeTo/ForwardDiffChainRules.jl/blob/609201f3ebaaaffd341037132aaff8ab744f92f1/src/ForwardDiffChainRules.jl#L46-L52 I think this is the cause of a bug in my code. Do you think it is fixable?
@mohamed82008 any idea?
Not a fan of supporting multiple arguments with ForwardDiff. ForwardDiff's API does not support that.
hmm on second thought, I changed my mind
Given the metaprogramming-heavy nature of this package, I'm not sure how to make a PR implementing this
NonconvexUtils already supports this (https://github.com/JuliaNonconvex/NonconvexUtils.jl/blob/main/src/forwarddiff_frule.jl#L1). You can just assume the inputs are all reals or arrays, not recursive containers, and simplify the implementation. I use flatten/unflatten, but only because I wanted to be very generic. No need to be that generic here.
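For illustration, here is a minimal sketch of how a multi-argument forward-mode derivative could look under the "reals or arrays only" assumption. It is not the NonconvexUtils implementation and not this package's code: instead of flatten/unflatten, it perturbs every input along its tangent with a single scalar `t` and differentiates with respect to `t` using `ForwardDiff.derivative`. The name `multiarg_jvp` is hypothetical, and the sketch assumes `f` is generic enough to accept `ForwardDiff.Dual` numbers.

```julia
using ForwardDiff

# Hypothetical helper (not part of any package): returns the primal value and
# the directional derivative of `f` at `xs` along the tangents `dxs`.
# Assumes every input is a Real or an array of Reals, and that each tangent
# has the same shape as its input.
function multiarg_jvp(f, xs::Tuple, dxs::Tuple)
    # Perturb all inputs at once along their tangents with one scalar t.
    g(t) = f(map((x, dx) -> x .+ t .* dx, xs, dxs)...)
    # d/dt g(t) at t = 0 is the Jacobian-vector product over all arguments.
    return f(xs...), ForwardDiff.derivative(g, 0.0)
end

# Usage: a two-argument function of arrays.
f(x, y) = sum(x .* y)
val, dval = multiarg_jvp(f, ([1.0, 2.0], [3.0, 4.0]), ([1.0, 0.0], [0.0, 1.0]))
```

An `frule` taking several arguments could wrap something like this, mapping the incoming tangents to `dxs`. Tangents that are not plain numbers or arrays (for example `NoTangent()` or `ZeroTangent()`) would need to be converted first, which this sketch does not handle.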