Zygote.jl

Zygote gradients different from ForwardDiff/ReverseDiff on Julia 1.10-rc2

Open · SaremS opened this issue 1 year ago · 3 comments

Hi, hope this suffices as an MWE:

```julia
using Pkg
# Pin the package versions used for this report
packages = [
    Pkg.PackageSpec(; name="ForwardDiff", version="0.10.36"),
    Pkg.PackageSpec(; name="ReverseDiff", version="1.15.1"),
    Pkg.PackageSpec(; name="Zygote", version="0.6.67"),
    Pkg.PackageSpec(; name="KernelFunctions", version="0.10.60"),
    Pkg.PackageSpec(; name="Distributions", version="0.25.104"),
]
Pkg.add(packages)

using ForwardDiff, ReverseDiff, Zygote, Distributions, KernelFunctions

# Define kernel function (periodic + white noise)
kernel(l, s) = with_lengthscale(SqExponentialKernel(), l^2) ∘ PeriodicTransform(1/365) + ScaledKernel(WhiteKernel(), s^2)

# Create data deterministically: 20 observations at 20 input points
m = collect(-0.9:0.1:1)
t = collect(1:18:360)

# Log-likelihood of the data sample as a function of the kernel hyperparameters x = [l, s]
loglik(x) = logpdf(MvNormal(zeros(20), kernelmatrix(kernel(x[1], x[2]), t)), m)

# Differentiate the likelihood with respect to the kernel hyperparameters
println(ForwardDiff.gradient(loglik, [1.0, 0.1]))   # ForwardDiff
println(ReverseDiff.gradient(loglik, [1.0, 0.1]))   # ReverseDiff
println(Zygote.gradient(loglik, [1.0, 0.1]))        # Zygote
```

Outputs are as follows on my machine, using Julia 1.10-rc2:

```
ForwardDiff: [-52.862449903127434, 403.6043237529404]
ReverseDiff: [-52.86244990312515, 403.6043237529402]
Zygote:      ([23.812852743170346, -114.3874772277085],)
```

Let me know if you need anything else.
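Since ForwardDiff and ReverseDiff agree with each other but not with Zygote, an AD-free reference can help pin down which result is correct. The following is a minimal sketch that assumes FiniteDifferences.jl is installed (it is not part of the original MWE). It rebuilds the same log-likelihood as in the MWE and compares 5-point central finite differences against ForwardDiff and Zygote. The tolerance `rtol=1e-4` is an arbitrary choice.

```julia
using FiniteDifferences, ForwardDiff, Zygote, Distributions, KernelFunctions

# Same model as in the MWE (periodic + white-noise kernel)
kernel(l, s) = with_lengthscale(SqExponentialKernel(), l^2) ∘ PeriodicTransform(1/365) + ScaledKernel(WhiteKernel(), s^2)
m = collect(-0.9:0.1:1)
t = collect(1:18:360)
loglik(x) = logpdf(MvNormal(zeros(20), kernelmatrix(kernel(x[1], x[2]), t)), m)

x0 = [1.0, 0.1]

# AD-free reference: 5-point central finite differences (grad returns a 1-tuple)
g_fd  = FiniteDifferences.grad(central_fdm(5, 1), loglik, x0)[1]
g_fwd = ForwardDiff.gradient(loglik, x0)
g_zyg = Zygote.gradient(loglik, x0)[1]

println("FiniteDifferences:          ", g_fd)
println("ForwardDiff agrees with FD: ", isapprox(g_fwd, g_fd; rtol=1e-4))
println("Zygote agrees with FD:      ", isapprox(g_zyg, g_fd; rtol=1e-4))
```

If the finite-difference gradient matches only one group of ADs, that group's result is likely the trustworthy one.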

SaremS, Dec 12 '23 08:12

My understanding is that KernelFunctions defines its own ChainRules rules (`rrule`s), so much of the code in question will be on that side. Have you raised this issue with them? They may be able to offer a more informed opinion on what's going on.

ToucheSir, Dec 15 '23 01:12

Thank you, that would probably explain it. I had assumed that a ChainRules rule is used the same way by all autodiff packages, but apparently that is not the case.

Do you recommend closing this issue then, or shall I leave it open?

SaremS, Dec 15 '23 12:12

Of the three AD packages you tested, Zygote is likely the only one using any ChainRules rules here. Feel free to leave this open and link to it when you create issues in other repos, so we have a trail of breadcrumbs to follow.
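To illustrate why a rule can affect one AD but not another, here is a minimal sketch that assumes ChainRulesCore.jl is installed. It defines a custom `rrule` for a hypothetical function `sq`, with a deliberately wrong derivative so the effect is visible. Zygote consults ChainRules `rrule`s, so it picks up the rule. ForwardDiff propagates dual numbers through the primal code and never calls it. As far as I know, ReverseDiff only uses a ChainRules rule when it is imported explicitly, e.g. with `ReverseDiff.@grad_from_chainrules`.

```julia
using ChainRulesCore, Zygote, ForwardDiff

# Hypothetical example function: the true derivative of sq(x) is 2x
sq(x) = x^2

# Deliberately wrong reverse rule (3x instead of 2x), only to show which AD uses it.
# Define it before the first Zygote call on sq so Zygote sees it.
function ChainRulesCore.rrule(::typeof(sq), x)
    y = sq(x)
    sq_pullback(ȳ) = (NoTangent(), ȳ * 3x)
    return y, sq_pullback
end

g_zygote  = Zygote.gradient(sq, 2.0)[1]      # uses the rrule: 1.0 * 3 * 2.0 = 6.0
g_forward = ForwardDiff.derivative(sq, 2.0)  # ignores the rrule: 2 * 2.0 = 4.0

println("Zygote:      ", g_zygote)
println("ForwardDiff: ", g_forward)
```

The same mechanism applies to the rules KernelFunctions ships: if one of them is wrong for this kernel combination, only Zygote's result would be affected.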

ToucheSir, Dec 15 '23 14:12