AbstractGPs.jl
rrule for MeanFunction types
Without an rrule, AD fails for MeanFunction types. #16 resolves #14 for ZeroMean, but the same should be done for all mean function types.
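For illustration, a hand-written rrule for a single mean function type might look like the minimal sketch below. It assumes ChainRulesCore.jl. MyConstMean and mean_vector are hypothetical stand-ins, not the AbstractGPs.jl types or API. It shows why this has to be repeated per type: each rule encodes how that particular mean depends on its parameters.

```julia
using ChainRulesCore

# Hypothetical constant mean function (a stand-in, not AbstractGPs' own type).
struct MyConstMean{T<:Real}
    c::T
end

# Evaluate the mean at every input: a vector filled with the constant c.
mean_vector(m::MyConstMean, x::AbstractVector) = fill(m.c, length(x))

# Hand-written reverse rule, so AD does not need to differentiate through `fill`.
function ChainRulesCore.rrule(::typeof(mean_vector), m::MyConstMean, x::AbstractVector)
    y = mean_vector(m, x)
    function mean_vector_pullback(dy)
        # Every output entry equals c, so dc is the sum of the output cotangents.
        dc = sum(unthunk(dy))
        # No tangent for the function itself; the output does not depend on x.
        return NoTangent(), Tangent{typeof(m)}(; c = dc), ZeroTangent()
    end
    return y, mean_vector_pullback
end
```

A more general approach (a macro, or one rule on an abstract type) would need a rule like this for each concrete mean, or a way to call back into AD for the parts it cannot write by hand.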
Is there any chance of writing a macro that covers all mean functions, or of defining an abstract mean function type so that a single abstract rrule can be defined?
This wouldn't be viable as a general solution without calling back into AD, which ChainRules doesn't currently support (you still have to use an AD-specific solution for that, although it's something that we're working on).
The right solution here is for us to fix https://github.com/FluxML/Zygote.jl/issues/646 so that this doesn't happen in the first place.
@willtebbutt I think that the main issue still comes from using map everywhere (as was the case for KernelFunctions). Should we consider doing the same thing as in KernelFunctions and temporarily replacing it with the ersatz _map until the Zygote issue you mentioned is solved?
Eurgh, yeah, you're probably right. This is very sad, but probably necessary for now.
I'll make a PR.
This was resolved a while ago.