Zygote.jl
Changing Distances adjoints to ChainRules syntax
Using `@adjoint` disables overloading through `rrule`; see: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/239#issuecomment-713560336
This makes it impossible to define new rules for different metrics.
I am aware that the long-term plan is to move the rules out to the respective packages, but in the meantime it would be nice to have this solved (it blocks our development in https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/pull/208).
LGTM for the most part.
`@adjoint` does take precedence over `rrule`, so you could also define the rule via the macro and change it later, although I would ask what the use case for that would be.
Sorry, I made a mistake: trying JuliaGaussianProcesses/KernelFunctions.jl#208 with this branch does not seem to solve the problem.
Also, precompiling with this patch produces a surprising warning:
┌ Warning: Error requiring `Distances` from `Zygote`
│ exception =
│ LoadError: error in method definition: function ChainRulesCore.rrule must be explicitly imported to be extended
...
Yeah, we usually move the adjoints over, but since ChainRules is already a dependency, you could add the import.
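To illustrate why the explicit import matters, here is a minimal, dependency-free sketch. `FakeChainRulesCore` and `DistancesGlue` are hypothetical stand-in modules (not the real ChainRulesCore or Zygote code), and `sqeuclid` stands in for a Distances.jl metric. The glue module adds a method to the generic `rrule` owned by another module, which only works if the name is imported with `import` (or the definition is qualified).

```julia
# Hypothetical stand-in for ChainRulesCore: a generic `rrule` plus a tangent type.
module FakeChainRulesCore
    # Fallback method: returning `nothing` means "no rule defined for this call".
    rrule(f, args...) = nothing
    # Stand-in for ChainRulesCore.NoTangent (tangent of a non-differentiable value).
    struct NoTangent end
end

# Hypothetical glue module playing the role of Zygote's Distances rules.
module DistancesGlue
    # `import` lets this module add methods to FakeChainRulesCore.rrule.
    # Had `rrule` been brought in with `using ..FakeChainRulesCore: rrule` instead,
    # the method definition below would fail with an error like
    # "function ... rrule must be explicitly imported to be extended".
    import ..FakeChainRulesCore: rrule
    using ..FakeChainRulesCore: NoTangent

    # Squared Euclidean distance, standing in for a Distances.jl metric.
    sqeuclid(x, y) = sum(abs2, x .- y)

    # rrule convention: return the primal value and a pullback that maps the
    # output cotangent Δ to one tangent per input, the function itself first.
    function rrule(::typeof(sqeuclid), x::AbstractVector, y::AbstractVector)
        δ = x .- y
        sqeuclid_pullback(Δ) = (NoTangent(), 2Δ .* δ, -2Δ .* δ)
        return sum(abs2, δ), sqeuclid_pullback
    end
end

# Usage: the new method is found through the generic function.
val, back = FakeChainRulesCore.rrule(DistancesGlue.sqeuclid, [1.0, 2.0], [0.0, 0.0])
# val == 5.0
# back(1.0) == (NoTangent(), [2.0, 4.0], [-2.0, -4.0])
```

Writing the definition with a qualified name, `function FakeChainRulesCore.rrule(...)`, is an equivalent alternative to the `import` line.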
@theogf I know you've resolved this on your end now, but would you be up for pushing this PR through anyway? Seems like something that we should be doing throughout Zygote when the opportunity arises.
It seems the adjoint definitions have to be updated to the ChainRules syntax? In particular, I guess `nothing` should be replaced with its ChainRules equivalents.
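As a rough sketch of what the syntax change involves, the two pullbacks below compute the same gradients under the two conventions. Everything here is hypothetical and dependency-free: `NoTangentStub` stands in for ChainRulesCore's `NoTangent()` (current ChainRulesCore also provides `ZeroTangent()` for tangents that are zero; older releases used different names), and `scaled_sqeuclid` is an illustrative metric whose integer argument `k` is treated as non-differentiable. The Zygote-style pullback returns one entry per argument with `nothing` for non-differentiable ones, while the `rrule`-style pullback adds a leading entry for the function itself and uses the tangent type instead of `nothing`.

```julia
# Hypothetical stand-in for ChainRulesCore.NoTangent (no package dependency here).
struct NoTangentStub end

# Illustrative metric with a non-differentiable integer argument `k`.
scaled_sqeuclid(x, y, k::Integer) = k * sum(abs2, x .- y)

# Zygote `@adjoint`-style: one entry per argument, `nothing` for
# non-differentiable ones, no entry for the function itself.
function zygote_style(x, y, k::Integer)
    δ = x .- y
    pullback(Δ) = (2k * Δ .* δ, -2k * Δ .* δ, nothing)
    return scaled_sqeuclid(x, y, k), pullback
end

# ChainRules `rrule`-style: a leading entry for the function itself,
# and a NoTangent stand-in in place of `nothing`.
function rrule_style(x, y, k::Integer)
    δ = x .- y
    pullback(Δ) = (NoTangentStub(), 2k * Δ .* δ, -2k * Δ .* δ, NoTangentStub())
    return scaled_sqeuclid(x, y, k), pullback
end
```

The gradients with respect to `x` and `y` are identical in both; only the shape of the returned tuple and the marker for "no tangent" change.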
~~I officially don't know how to import rrule correctly via Requires...~~
I am not sure the failing checks are related to this PR...
Doesn't look related to me. @theogf I reckon this is basically good to go after a patch bump. I'll bors it after that unless @DhairyaLGandhi has any objections?
There was a merge conflict; otherwise LGTM.
Needs a rebase; otherwise it seems a sensible change.
I wonder whether the maintainers of Distances, who didn't want to add a dependency on ChainRulesCore, might be fine with adding it as a weak dependency?
CI is complaining that `Δ` is not defined, but it seems to be?