Explicitly managing the output buffer of a `jax.jit` function
I found that for the common training pattern in JAX:
```python
new_state, other_output = jitted_train_step_fn(old_state, other_input)
```
The current XLA runtime may assign different backing device memory buffers to old_state and new_state.
This behavior is strange, since users may observe that their model's state keeps moving between memory locations. It is also not perf friendly: it incurs command buffer update costs (because the memory pointers change across command buffer launches) and some cache misses.
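A minimal sketch of the pattern above that records the state's buffer pointer before and after each step, assuming single-device arrays; `train_step` and its update rule are hypothetical. Without donation, the old state is still alive while the step runs, so the new state has to land in a different buffer and consecutive pointers always differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical train step: reads the old state, returns a new one plus a loss.
@jax.jit
def train_step(state, batch):
    new_state = state - 0.1 * batch  # placeholder update rule
    return new_state, jnp.mean(batch)

state = jnp.ones(1024, dtype=jnp.float32)
batch = jnp.ones(1024, dtype=jnp.float32)

pointers = []
for _ in range(4):
    # Record where the current state lives before it is replaced.
    pointers.append(state.unsafe_buffer_pointer())
    state, loss = train_step(state, batch)
pointers.append(state.unsafe_buffer_pointer())
print(pointers)
```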
Sergei Lebedev from Google also mentioned that implicit output buffer allocation is an issue for people using jax.pure_callback: there are a few bug reports where people wanted the callback to be zero-copy.
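For context, a minimal jax.pure_callback call looks like the sketch below (`host_fn` is a hypothetical name). The callback receives NumPy arrays and returns a fresh NumPy array; the caller only describes the result's shape and dtype via jax.ShapeDtypeStruct and, as far as this API goes, has no way to hand in a destination buffer for the result.

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_fn(x):
    # Runs on the host; returns a newly allocated NumPy array.
    return np.sin(x)

@jax.jit
def f(x):
    # Only the result's shape/dtype is specified, not where it is stored.
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_fn, out_shape, x)

y = f(jnp.arange(4.0))
print(y)
```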
I think the reason old_state and new_state end up in different buffers is that the jax.jit API only has buffer donation parameters, not an input/output aliasing parameter.
Ideally, the user could specify through a jax.jit parameter that old_state and new_state are aliased, and XLA buffer allocation would simply assign new_state to old_state's buffer. That would be both more perf friendly and semantically natural.
You may be able to do what you want via the donate_argnums/donate_argnames parameters of jax.jit; see the jax.jit documentation for details.
For example:
```python
from functools import partial

import jax
import jax.numpy as jnp

state = (jnp.zeros(4), jnp.arange(4))

@partial(jax.jit, donate_argnums=0)
def f(state):
    x, y = state
    x += 1
    y *= 2
    return (x, y)

pointer0 = state[0].unsafe_buffer_pointer()
pointer1 = state[1].unsafe_buffer_pointer()
state = f(state)
assert state[0].unsafe_buffer_pointer() == pointer0  # same memory
assert state[1].unsafe_buffer_pointer() == pointer1  # same memory
```
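The same donation can be requested by argument name with donate_argnames, which maps naturally onto the training pattern from the original question. A sketch, assuming a hypothetical `train_step` whose new state has the same shapes and dtypes as the old one (a requirement for the donated buffers to be reusable). jax.jit resolves the name to the positional argument, and the result is bound straight back to `state` because donated arrays must not be used after the call.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical train step; `state` is donated by name rather than by position.
@partial(jax.jit, donate_argnames="state")
def train_step(state, batch):
    x, y = state
    new_state = (x + batch, y * 2)  # same shapes/dtypes as the input state
    return new_state, jnp.sum(batch)

state = (jnp.zeros(4), jnp.arange(4))
batch = jnp.ones(4)
# Rebind immediately: the donated input arrays are invalid after this call.
state, loss = train_step(state, batch)
print(state, loss)
```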
In my experience (not tested on the latest JAX, though), jit with donate_argnums does reuse the input state's buffers for the output state if the arrays are large enough, but not if they are small.
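One way to check this on a given backend is to compare buffer pointers across array sizes, as in the sketch below (`step` and `reused` are hypothetical helpers). Whether reuse actually happens depends on the backend and JAX version, so no particular output is implied; when a donated buffer cannot be used, JAX may emit a warning instead.

```python
import jax
import jax.numpy as jnp

# Elementwise update with the input donated; output matches input shape/dtype.
step = jax.jit(lambda s: s * 2.0 + 1.0, donate_argnums=0)

def reused(n):
    """Return True if the output of `step` landed in the donated input buffer."""
    s = jnp.ones(n, dtype=jnp.float32)
    ptr = s.unsafe_buffer_pointer()
    out = step(s)  # `s` is donated and must not be touched afterwards
    return out.unsafe_buffer_pointer() == ptr

for n in (4, 1 << 10, 1 << 20):
    print(n, reused(n))
```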
There is a new experimental ArrayRef object that may serve the purpose described in this issue: https://docs.jax.dev/en/latest/array_refs.html
This should be part of the v0.7.1 release.
Closing because I don't think there's any other action to take here. Thanks for raising the issue!