dynamic config scope under `jit` doesn't change partitionable threefry behavior
Description
```python
import jax

def f(x):
    return x + jax.random.randint(jax.random.key(72), (), 0, 10)

def g(x):
    with jax.threefry_partitionable(True):  # False by default
        return x + jax.random.randint(jax.random.key(72), (), 0, 10)

h = jax.jit(g)

print('f', f(1))
print('g', g(1))
print('h', h(1))
```
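prints the following. Note that the jitted `h` agrees with `f` (the default, non-partitionable result) rather than with `g`: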
```
f 4
g 8
h 4
```
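A workaround worth trying (a sketch, not something confirmed in this report): enter `jax.threefry_partitionable(True)` around the *call* to the jitted function instead of inside the traced function. That keeps the flag set for tracing, lowering and compilation alike. The names `g_body` and `h2` are hypothetical. The sketch compares the jitted result against the eager result computed under the same scope.

```python
import jax

def g_body(x):
    # Same computation as g above, but with no config scope inside the traced function.
    return x + jax.random.randint(jax.random.key(72), (), 0, 10)

h2 = jax.jit(g_body)  # hypothetical name for this sketch

# Eager reference: the flag is active for the whole computation.
with jax.threefry_partitionable(True):
    eager = int(g_body(1))

# Workaround sketch: h2 is first called (traced, lowered, compiled) inside the scope.
with jax.threefry_partitionable(True):
    jitted = int(h2(1))

print('eager', eager)
print('jit  ', jitted)
print('match', eager == jitted)
```

If `match` prints `True`, the jitted path honors the flag when the scope encloses the call, which suggests that the in-function scope in `g` is being exited before some later stage that still consults the flag. This is an inference, not a diagnosis.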
System info (python version, jaxlib version, accelerator, etc.)
jax 0.4.27.dev with jaxlib 0.4.26.