Zygote.jl
Different RNG used during first gradient call on Julia 1.9
I've run into a very strange bug on Julia 1.9: the first time Zygote takes the gradient of a random function from Distributions.jl, it gives a different result (for both the primal and the gradient) than on later calls. (The same seed is used in all cases.) This bug led to an error when I tried to run a package's test suite with 1.9.
Here is the MWE:
julia> using Zygote, Distributions, Random
julia> function f(x)
           out = rand(Normal(x, x))
           @show out
           return out
end
f (generic function with 1 method)
julia> Random.seed!(123)
TaskLocalRNG()
julia> Zygote.gradient(f, .5) # Take the gradient for the first time
out = 0.7461228432625914
(1.4922456865251827,)
julia> Random.seed!(123)
TaskLocalRNG()
julia> Zygote.gradient(f, .5) # Take the gradient subsequent times
out = 0.17713466394801164
(0.3542693278960233,)
julia> Random.seed!(123)
TaskLocalRNG()
julia> Zygote.gradient(f, .5)
out = 0.17713466394801164
(0.3542693278960233,)
Here is the version info for the session where the above occurred:
julia> versioninfo()
Julia Version 1.9.0-beta2
Commit 7daffeecb8c (2022-12-29 07:45 UTC)
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 8 × 11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-14.0.6 (ORCJIT, tigerlake)
Threads: 1 on 8 virtual cores
and my environment was:
[31c24e10] Distributions v0.25.79
[e88e6eb3] Zygote v0.6.52
This does not occur for me on Julia 1.8.1 with the same environment. I also haven't been able to reproduce it without Distributions.jl, which muddies the waters further.
I assume the change to the primal does not occur outside of a gradient call (i.e. just f(0.5))?
That's right. The first call to the primal, as well as subsequent calls (with the seed reset each time), all agree with the result of Zygote's subsequent calls (i.e. out = 0.177...).
But if I make five calls to the primal without resetting the seed, the fifth call gives exactly what Zygote's first call gives us!
julia> using Distributions, Random
julia> function f(x)
           out = rand(Normal(x, x))
           @show out
           return out
end
f (generic function with 1 method)
julia> Random.seed!(123)
TaskLocalRNG()
julia> f(.5)
out = 0.17713466394801164
0.17713466394801164
julia> f(.5)
out = -0.23162568944446071
-0.23162568944446071
julia> f(.5)
out = -0.3118018727930403
-0.3118018727930403
julia> f(.5)
out = 0.3911674466082269
0.3911674466082269
julia> f(.5)
out = 0.7461228432625914
0.7461228432625914
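Counting calls by hand gets tedious, so the session above can be wrapped in a loop that reseeds once and reports which call reproduces a given value. This is a minimal sketch: `calls_before_match` is a hypothetical helper (not part of Random, Zygote, or Distributions), and the stand-in model `g` is an assumption chosen so the block runs with only the standard library. Exact `==` on floats is intentional, since reseeding reproduces the same draws bit for bit on a given Julia build.

```julia
using Random

# Hypothetical stdlib-only stand-in for the Distributions model:
# out = x + sqrt(x) * z with z ~ N(0, 1).
g(x) = sqrt(x) * randn() + x

# Hypothetical helper: reseed the task-local RNG, call `f(x)` repeatedly, and
# return how many calls came before the one that reproduces `observed`
# (0 means the very first call matched). Returns `nothing` if none of the
# first `maxcalls` calls match.
function calls_before_match(f, x, observed; seed = 123, maxcalls = 20)
    Random.seed!(seed)
    for k in 0:(maxcalls - 1)
        f(x) == observed && return k
    end
    return nothing
end
```

With the `f` from the session above, `calls_before_match(f, 0.5, 0.7461228432625914)` would be expected to return 4 on the same setup, i.e. Zygote's first-call value is the fifth primal draw.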
I am able to reproduce similar behaviour without Distributions:
julia> using Zygote, Random
julia> function f(x)
           #=
           The sqrt actually makes this not equivalent to the previous sampling procedure.
           But with it, we can reproduce the same bug.
           =#
           out = sqrt(x) * randn() + x
           @show out
           return out
end
f (generic function with 1 method)
julia> Random.seed!(123)
TaskLocalRNG()
julia> Zygote.gradient(f, .5) # Take the gradient for the first time
out = 1.193657477360091
(1.693657477360091,)
julia> Random.seed!(123)
TaskLocalRNG()
julia> Zygote.gradient(f, .5) # Take the gradient subsequent times
out = 0.04339946293513103
(0.5433994629351311,)
julia> Random.seed!(123)
TaskLocalRNG()
julia> Zygote.gradient(f, .5)
out = 0.04339946293513103
(0.5433994629351311,)
In this case, the RNG state during Zygote's first gradient evaluation seems to match the state that the sixth primal call would see if the seed were not reset.
Is the compiler possibly pulling values from the RNG on 1.9? I know there were some sorting algorithm changes made internally recently.
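One way to test that hypothesis directly is to snapshot the task-local RNG before a call and count how many raw 64-bit draws its state advanced by afterwards. The sketch below is a hypothetical diagnostic, not anything provided by Zygote or Random: `rng_draws_consumed` and `lookahead` are made-up names, it assumes `copy(Random.default_rng())` returns an independent `Xoshiro` snapshot of the task-local state (Julia 1.7 and later), and it counts in units of `rand(UInt64)`, so a single `randn()` usually, but not always, shows up as 1.

```julia
using Random

# Hypothetical diagnostic: run `f()` and report how many raw 64-bit draws it
# took from the task-local RNG. Returns `nothing` if more than `maxdraws`.
function rng_draws_consumed(f; maxdraws = 10_000)
    # Next few outputs of a state, computed on a copy so nothing is consumed.
    lookahead(r) = (c = copy(r); ntuple(_ -> rand(c, UInt64), 4))

    before = copy(Random.default_rng())  # Xoshiro snapshot of the task-local state
    # `invokelatest` keeps `f()` a dynamic call, so any compilation of `f`
    # triggered by this first call happens after the snapshot is taken.
    Base.invokelatest(f)
    after = lookahead(Random.default_rng())

    for k in 0:maxdraws
        lookahead(before) == after && return k
        rand(before, UInt64)  # advance the snapshot by one raw draw
    end
    return nothing
end
```

In a fresh session one could compare `rng_draws_consumed(() -> Zygote.gradient(f, 0.5))` on the first and second calls: if compilation is drawing from the RNG, the first count should come out larger than the second.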
Good guess! I ran a bisect and the bug is caused by https://github.com/JuliaLang/julia/pull/45222.