Zygote.jl

Different RNG used during first gradient call on Julia 1.9

Open gaurav-arya opened this issue 2 years ago • 6 comments

I've run into a very strange bug on Julia 1.9: Zygote gives a different result (for both the primal and the gradient) the first time it takes the gradient of a random function from Distributions.jl. (The same seed is used in all cases.) This bug led to an error when I tried to run a package's test suite on 1.9.

Here is the MWE:

julia> using Zygote, Distributions, Random

julia> function f(x) 
           out = rand(Normal(x, x))
           @show out
           return out
       end
f (generic function with 1 method)

julia> Random.seed!(123)
TaskLocalRNG()

julia> Zygote.gradient(f, .5) # Take the gradient for the first time
out = 0.7461228432625914
(1.4922456865251827,)

julia> Random.seed!(123)
TaskLocalRNG()

julia> Zygote.gradient(f, .5) # Take the gradient subsequent times
out = 0.17713466394801164
(0.3542693278960233,)

julia> Random.seed!(123)
TaskLocalRNG()

julia> Zygote.gradient(f, .5)
out = 0.17713466394801164
(0.3542693278960233,)

The version info where the above occurred is:

julia> versioninfo()
Julia Version 1.9.0-beta2
Commit 7daffeecb8c (2022-12-29 07:45 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × 11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, tigerlake)
  Threads: 1 on 8 virtual cores

and my environment was

  [31c24e10] Distributions v0.25.79
  [e88e6eb3] Zygote v0.6.52

This does not occur on Julia 1.8.1 for me, with the same environment. I also haven't been able to reproduce it without Distributions.jl, which muddies the waters further.

gaurav-arya, Jan 10 '23 11:01

I assume the change to the primal does not occur outside of a gradient call (i.e. just f(0.5))?

darsnack, Jan 10 '23 12:01

That's right. The first call to the primal, as well as subsequent calls, all agree with the result Zygote gives on its second and later gradient calls (i.e. out = 0.177...).

gaurav-arya, Jan 10 '23 12:01

But if I make five calls to the primal without resetting the seed, I get what Zygote's first call gives us!

julia> using Distributions, Random

julia> function f(x) 
           out = rand(Normal(x, x))
           @show out
           return out
       end
f (generic function with 1 method)

julia> Random.seed!(123)
TaskLocalRNG()

julia> f(.5)
out = 0.17713466394801164
0.17713466394801164

julia> f(.5)
out = -0.23162568944446071
-0.23162568944446071

julia> f(.5)
out = -0.3118018727930403
-0.3118018727930403

julia> f(.5)
out = 0.3911674466082269
0.3911674466082269

julia> f(.5)
out = 0.7461228432625914
0.7461228432625914

gaurav-arya, Jan 10 '23 12:01

I am able to reproduce similar behaviour without Distributions:

julia> using Zygote, Random

julia> function f(x) 
           #=
           The sqrt actually makes this not equivalent to the previous sampling procedure. 
           But with it, we can reproduce the same bug.
           =#
           out = sqrt(x) * randn() + x
           @show out
           return out
       end
f (generic function with 1 method)

julia> Random.seed!(123)
TaskLocalRNG()

julia> Zygote.gradient(f, .5) # Take the gradient for the first time
out = 1.193657477360091
(1.693657477360091,)

julia> Random.seed!(123)
TaskLocalRNG()

julia> Zygote.gradient(f, .5) # Take the gradient subsequent times
out = 0.04339946293513103
(0.5433994629351311,)

julia> Random.seed!(123)
TaskLocalRNG()

julia> Zygote.gradient(f, .5)
out = 0.04339946293513103
(0.5433994629351311,)

In this case, Zygote's RNG state during the first gradient evaluation seems to match the RNG state at the sixth primal call when the seed is not reset.
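
For anyone who wants to check this kind of offset without counting calls by hand, here is a minimal sketch. The helper name `matching_draw` and its `maxdraws` keyword are made up for illustration (they are not part of Zygote or Random); it reseeds, calls the sampler repeatedly, and reports which call first reproduces a given value.

```julia
using Random

# Hypothetical helper: after seeding with `seed`, call `sampler()` up to
# `maxdraws` times and return the index of the first call whose result is
# approximately `target`, or `nothing` if no call matches.
function matching_draw(sampler, seed, target; maxdraws=20)
    Random.seed!(seed)
    for i in 1:maxdraws
        sampler() ≈ target && return i
    end
    return nothing
end

# Usage with the sampler from the MWE above (without the @show),
# searching for the value printed by the first gradient call:
matching_draw(() -> sqrt(0.5) * randn() + 0.5, 123, 1.193657477360091)
```

The result depends on the Julia version's RNG stream, so the index it returns should be treated as an observation for that setup rather than a fixed number.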

gaurav-arya, Jan 10 '23 13:01

Is the compiler possibly pulling values from the RNG on 1.9? I know there were some sorting algorithm changes made internally recently.
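
One way to probe this might be to snapshot the task-local RNG before and after a call and compare the two states. The sketch below assumes Julia 1.7 or later, where `copy(Random.default_rng())` returns a `Xoshiro` snapshot that supports `==`; the helper name `advances_rng` is hypothetical.

```julia
using Random

# Hypothetical helper: returns true if calling `thunk` changes the state
# of the task-local default RNG.
function advances_rng(thunk)
    before = copy(Random.default_rng())  # snapshot of the state before the call
    thunk()
    after = copy(Random.default_rng())   # snapshot of the state after the call
    return before != after
end

Random.seed!(123)
advances_rng(() -> randn())  # draws a sample, so the state changes
advances_rng(() -> 1 + 1)    # no explicit draw; true here would hint at hidden RNG use
```

Wrapping the first `Zygote.gradient(f, .5)` call in a fresh session this way would not separate compilation from the sample itself, so comparing the state after the first and second gradient calls from the same seed is probably more telling.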

ToucheSir, Jan 11 '23 03:01

Good guess! I ran a bisect and the bug is caused by https://github.com/JuliaLang/julia/pull/45222.

gaurav-arya, Jan 11 '23 10:01