Zygote.jl
Zygote gradients different from ForwardDiff/ReverseDiff on Julia 1.10-rc2
Hi, I hope this suffices as an MWE:
using Pkg
packages = [
Pkg.PackageSpec(;name="ForwardDiff", version="0.10.36"),
Pkg.PackageSpec(;name="ReverseDiff", version="1.15.1"),
Pkg.PackageSpec(;name="Zygote", version="0.6.67"),
Pkg.PackageSpec(;name="KernelFunctions", version="0.10.60"),
Pkg.PackageSpec(;name="Distributions", version="0.25.104")
]
Pkg.add(packages)
using ForwardDiff, ReverseDiff, Zygote, Distributions, KernelFunctions
#Define kernel function (periodic + white noise)
kernel(l,s) = with_lengthscale(SqExponentialKernel(), l^2) ∘ PeriodicTransform(1/365) + ScaledKernel(WhiteKernel(),s^2)
#Create data deterministically
m = collect(-0.9:0.1:1)
#Differentiate the log-likelihood of the data sample with respect to the kernel hyperparameters
t = collect(1:18:360)
loglik(x) = logpdf(MvNormal(zeros(20), kernelmatrix(kernel(x[1], x[2]), t)), m)
#ForwardDiff
println(ForwardDiff.gradient(loglik, [1.0, 0.1]))
#ReverseDiff
println(ReverseDiff.gradient(loglik, [1.0, 0.1]))
#Zygote
println(Zygote.gradient(loglik, [1.0, 0.1]))
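For reference, this is what the MWE should be computing, written out by hand. It rests on our reading of the KernelFunctions documentation, not on its source: we assume `PeriodicTransform(f)` maps x to (sin 2πfx, cos 2πfx), `with_lengthscale(k, ℓ)` divides the inputs by ℓ (here ℓ = l²), `SqExponentialKernel` is exp(-d²/2), and `WhiteKernel` is δ(x, x'). The last line uses the fact that the 20 inputs `1:18:360` are distinct, so the white-noise term contributes s²I. All three ADs should agree with the standard Gaussian log-likelihood gradient below.

```latex
\begin{aligned}
k_{l,s}(x, x') &= \exp\!\left(-\frac{\lVert p(x) - p(x')\rVert^2}{2 l^4}\right) + s^2\,\delta_{x x'},
\qquad p(x) = \left(\sin\tfrac{2\pi x}{365},\ \cos\tfrac{2\pi x}{365}\right) \\
\lVert p(x) - p(x')\rVert^2 &= 4\sin^2\!\frac{\pi (x - x')}{365}
\;\Rightarrow\;
k_{l,s}(x, x') = \exp\!\left(-\frac{2}{l^4}\sin^2\!\frac{\pi (x - x')}{365}\right) + s^2\,\delta_{x x'} \\
\log p(m \mid l, s) &= -\tfrac12\, m^\top K^{-1} m - \tfrac12 \log\det K - \tfrac{n}{2}\log 2\pi, \qquad n = 20 \\
\frac{\partial \log p}{\partial \theta_i} &= \tfrac12\, \alpha^\top \frac{\partial K}{\partial \theta_i}\,\alpha
 - \tfrac12 \operatorname{tr}\!\left(K^{-1}\frac{\partial K}{\partial \theta_i}\right),
\qquad \alpha = K^{-1} m,\ \ \theta = (l, s) \\
\frac{\partial K}{\partial s} &= 2 s I
\end{aligned}
```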
Outputs are as follows on my machine, using Julia 1.10-rc2:
ForwardDiff: [-52.862449903127434, 403.6043237529404]
ReverseDiff: [-52.86244990312515, 403.6043237529402]
Zygote: ([23.812852743170346, -114.3874772277085],)
Let me know if you need anything else.
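To judge which of the disagreeing results is right without relying on any AD package, one option is to rebuild the same likelihood by hand and compare its analytic gradient against central finite differences. The sketch below uses only the `LinearAlgebra` standard library. `covariance`, `loglik`, `loglik_grad` and `fd_grad` are hypothetical helpers, and the closed form of the kernel is our assumption about what the KernelFunctions expression reduces to, not something confirmed in this thread.

```julia
using LinearAlgebra

# Hand-built version of the MWE's covariance. ASSUMPTION: the KernelFunctions kernel
# reduces to exp(-2 sin²(π(x - x')/365) / l⁴) + s² δ(x, x'); not checked against its source.
function covariance(θ, t)
    l, s = θ
    K = [exp(-2 * sin(π * (ti - tj) / 365)^2 / l^4) for ti in t, tj in t]
    return K + s^2 * I  # white-noise term on distinct inputs is s² times the identity
end

# Log-density of y under N(0, K(θ)), computed via a Cholesky factorisation
function loglik(θ, t, y)
    C = cholesky(Symmetric(covariance(θ, t)))
    α = C \ y
    return -0.5 * dot(y, α) - 0.5 * logdet(C) - 0.5 * length(y) * log(2π)
end

# Analytic gradient: ∂/∂θᵢ = ½ αᵀ (∂K/∂θᵢ) α - ½ tr(K⁻¹ ∂K/∂θᵢ), with α = K⁻¹ y
function loglik_grad(θ, t, y)
    l, s = θ
    C = cholesky(Symmetric(covariance(θ, t)))
    α = C \ y
    Kinv = inv(C)
    # ∂/∂l of exp(-2S / l⁴) is exp(-2S / l⁴) * 8S / l⁵, where S = sin²(π(x - x')/365)
    S = [sin(π * (ti - tj) / 365)^2 for ti in t, tj in t]
    dK_dl = exp.(-2 .* S ./ l^4) .* 8 .* S ./ l^5
    dK_ds = 2s * Matrix{Float64}(I, length(t), length(t))
    term(dK) = 0.5 * dot(α, dK * α) - 0.5 * tr(Kinv * dK)
    return [term(dK_dl), term(dK_ds)]
end

# Central finite differences: an AD-free reference gradient
function fd_grad(f, θ; h = 1e-6)
    g = similar(θ)
    for i in eachindex(θ)
        e = zeros(length(θ))
        e[i] = h
        g[i] = (f(θ + e) - f(θ - e)) / (2h)
    end
    return g
end

t = collect(1:18:360)    # same inputs as in the MWE
y = collect(-0.9:0.1:1)  # same data as `m` in the MWE
θ = [1.0, 0.1]           # (l, s), same evaluation point as in the MWE

g_analytic = loglik_grad(θ, t, y)
g_fd = fd_grad(p -> loglik(p, t, y), θ)
println(g_analytic)
println(g_fd)
```

If the kernel assumption holds, `g_analytic` should be comparable with the ForwardDiff and ReverseDiff numbers reported above. The same `fd_grad` helper can also be pointed directly at the KernelFunctions-based closure from the MWE, which avoids the assumption altogether.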
My understanding is that KernelFunctions defines its own ChainRules rules, so much of the code in question will be on that side. Have you raised this issue with them? They may be able to offer a more informed opinion on what's going on.
Thank you, that would probably explain it. I had assumed that a ChainRule is applied in the same way by all autodiff packages, but apparently that is not the case.
Do you recommend closing this issue then, or shall I leave it open?
Of the three ADs you tested, it's likely that only Zygote is using any ChainRules here. Feel free to leave this open and link to it when you create issues in other repos, so we have a trail of breadcrumbs to follow.
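The point that a ChainRules rule is not consulted by every AD can be seen with a toy example. In the sketch below, `mysquare` is a hypothetical function with a deliberately wrong `rrule`: ForwardDiff differentiates the code with dual numbers and ignores the rule, while Zygote picks it up, so a faulty rule would only show up in Zygote's result. It assumes ChainRulesCore, ForwardDiff and Zygote are installed.

```julia
using ChainRulesCore, ForwardDiff, Zygote

# Hypothetical toy function, used only to show which AD package consults ChainRules
mysquare(x) = x^2

# Deliberately WRONG reverse rule: the true derivative is 2x, this returns 3x
function ChainRulesCore.rrule(::typeof(mysquare), x)
    y = mysquare(x)
    mysquare_pullback(ȳ) = (NoTangent(), ȳ * 3x)
    return y, mysquare_pullback
end

# ForwardDiff propagates dual numbers through x^2 and never sees the rrule
d_forward = ForwardDiff.derivative(mysquare, 2.0)

# Zygote finds the rrule and uses it instead of differentiating mysquare itself
d_zygote = only(Zygote.gradient(mysquare, 2.0))

println((d_forward, d_zygote))
```

A mismatch that appears only in Zygote therefore makes a rule (here, possibly one defined in KernelFunctions) a natural suspect, though a bug in Zygote itself is not ruled out. ReverseDiff, to our knowledge, only uses ChainRules rules when they are explicitly imported, for example via its `ReverseDiff.@grad_from_chainrules` macro.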