
`frule_via_ad` should accept several arguments

Open gdalle opened this issue 2 years ago • 5 comments

According to the official API specification of ChainRulesCore.jl, `frule_via_ad` should accept all the arguments of the function as a destructured tuple (that is, splatted as `args...` after the function): https://juliadiff.org/ChainRulesCore.jl/stable/api.html#ChainRulesCore.frule_via_ad

However, it seems that ForwardDiffChainRules.jl only accepts one argument: https://github.com/ThummeTo/ForwardDiffChainRules.jl/blob/609201f3ebaaaffd341037132aaff8ab744f92f1/src/ForwardDiffChainRules.jl#L46-L52

I think this is the reason for a bug in my code. Do you think it is fixable?

gdalle · Mar 13 '23 09:03

@mohamed82008 any idea?

gdalle · Mar 14 '23 07:03

Not a fan of supporting multiple arguments with ForwardDiff. ForwardDiff's API does not support that.

mohdibntarek · Mar 15 '23 18:03

hmm on second thought, I changed my mind

mohdibntarek · Mar 15 '23 18:03

Given the metaprogramming-heavy nature of this package, I'm not sure how to write a PR implementing this.

gdalle · Mar 16 '23 08:03

NonconvexUtils already supports this: https://github.com/JuliaNonconvex/NonconvexUtils.jl/blob/main/src/forwarddiff_frule.jl#L1. You can simplify the implementation by assuming the inputs are all reals or arrays, not recursive containers. I use flatten/unflatten, but only because I wanted to be very generic. There is no need to be that generic here.

mohdibntarek · Mar 16 '23 09:03
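
For readers following along, here is a minimal sketch of what the simplified multi-argument approach might look like under the assumption suggested above, restricted further to real scalar inputs and a real scalar output. It is not the code from ForwardDiffChainRules.jl or NonconvexUtils.jl. The names `SketchTag` and `multiarg_pushforward` are hypothetical, and `nothing` stands in for the tangent of the function itself. The calling convention loosely mirrors the `(ȧrgs, f, args...)` shape of `frule`, without the `RuleConfig` argument.

```julia
using ForwardDiff: Dual, value, partials

# Hypothetical tag type for this sketch; not part of any package.
struct SketchTag end

# Hypothetical helper with an `frule`-like calling convention:
# `ȧrgs[1]` is the tangent of `f` itself (ignored here), and the
# remaining entries are the tangents of `args`, one per argument.
# Assumption: every input and tangent is a real scalar, and `f`
# returns a real scalar.
function multiarg_pushforward(ȧrgs::Tuple, f, args::Real...)
    ẋs = Base.tail(ȧrgs)  # drop the tangent of `f`
    length(ẋs) == length(args) ||
        throw(ArgumentError("need exactly one tangent per argument"))
    # Seed each argument with its own tangent as a single partial.
    duals = map((x, ẋ) -> Dual{SketchTag}(x, ẋ), args, ẋs)
    out = f(duals...)
    # Primal output and directional derivative along (ẋ₁, ẋ₂, ...).
    return value(out), partials(out, 1)
end

# Usage: a two-argument function, perturbed along (1.0, 0.5).
g(x, y) = x * y + sin(x)
out, dout = multiarg_pushforward((nothing, 1.0, 0.5), g, 2.0, 3.0)
```

Because all inputs carry a single partial seeded with their own tangent, one call to `g` yields the full Jacobian-vector product. Array inputs would need elementwise seeding and an array of partials on the way out, which this sketch deliberately leaves out.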