
dynamic config scope under `jit` doesn't change partitionable threefry behavior

Status: open. Opened by froystig; 0 comments.

Description

import jax

def f(x):
  return x + jax.random.randint(jax.random.key(72), (), 0, 10)

def g(x):
  with jax.threefry_partitionable(True):  # False by default
    return x + jax.random.randint(jax.random.key(72), (), 0, 10)

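h = jax.jit(g)    # g under jit
print('f', f(1))  # eager, default setting
print('g', g(1))  # eager, partitionable threefry enabled inside g
print('h', h(1))  # jitted g, same scope entered inside the function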

This prints:

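f 4
g 8
h 4

`h(1)` matches `f(1)`, not `g(1)`: the `jax.threefry_partitionable(True)` scope entered inside the jitted function has no effect on the result.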
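A possible workaround, sketched here and not confirmed in this report: enter the `threefry_partitionable` scope around the call to the jitted function rather than inside it, so the setting is already active when `jit` traces. This assumes the flag is part of the state `jit` uses when tracing and caching; `draw` and `draw_jit` are hypothetical names.

```python
import jax

def draw(x):
  # Same computation as f above, with no config scope inside the function.
  return x + jax.random.randint(jax.random.key(72), (), 0, 10)

draw_jit = jax.jit(draw)

# Assumption: entering the scope outside jit makes it active at trace time,
# so the eager and jitted calls should agree.
with jax.threefry_partitionable(True):
  eager = int(draw(1))
  jitted = int(draw_jit(1))

print('eager', eager)
print('jit', jitted)
```

To apply the setting everywhere instead of per scope, `jax.config.update('jax_threefry_partitionable', True)` sets it globally.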

System info (python version, jaxlib version, accelerator, etc.)

jax 0.4.27.dev with jaxlib 0.4.26.

froystig, May 3, 2024, 19:05