Offloading in grad(scan(remat(fn, policy=offload))) results in XlaRuntimeError
Description
The following code:
import jax
import jax.ad_checkpoint
from jax import numpy as jnp


@jax.jit
def apply(params, x):
  def step(y, i):
    y = jnp.sin(y)
    y = jax.ad_checkpoint.checkpoint_name(y, 'save_remat')
    y = jnp.sin(y)
    return y, ()

  step = jax.remat(
      step,
      policy=jax.checkpoint_policies.save_and_offload_only_these_names(
          names_which_can_be_saved=(),
          names_which_can_be_offloaded=('save_remat',),
          offload_src='device',
          offload_dst='pinned_host',
      ),
  )
  y = jax.grad(lambda p: jax.lax.scan(step, p, x)[0].sum())(params)
  return y


params = jnp.ones([3])
x = jnp.ones([2, 3])
apply(params, x)
reproduces the error:
XlaRuntimeError: UNIMPLEMENTED: Performing sub-chunk copy is not supported in async dynamic slice yet.
Error encountered while compiling %dynamic-slice-start = ((f32[2,3]{1,0:T(2,128)S(5)}, s32[]{:T(256)}, s32[]{:T(256)}), f32[1,3]{1,0:T(2,128)}, u32[]{:S(2)}, s32[]) dynamic-slice-start(f32[2,3]{1,0:T(2,128)S(5)} %get-tuple-element.237, s32[]{:T(256)} %select.6, s32[]{:T(256)} %constant.4..sunk.1), dynamic_slice_sizes={1,3}, metadata={op_name="jit(apply)/jit(main)/transpose(jvp(while))/body/dynamic_slice" source_file="<ipython-input-1-a56f72d33e52>" source_line=27}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[]}.
Error encountered while compiling %while.7 = (s32[]{:T(256)}, f32[3]{0:T(256)}, f32[2,3]{1,0:T(2,128)S(5)}, f32[2,3]{1,0:T(2,128)}, s32[]{:T(256)}, /*index=5*/s32[]{:T(256)}, s32[]{:T(256)}) while((s32[]{:T(256)}, f32[3]{0:T(256)}, f32[2,3]{1,0:T(2,128)S(5)}, f32[2,3]{1,0:T(2,128)}, s32[]{:T(256)}, /*index=5*/s32[]{:T(256)}, s32[]{:T(256)}) %tuple.50), condition=%wide.wide.wide.wide.region_3.91.clone.clone.clone, body=%wide.wide.wide.wide.region_2.67.clone.clone.clone.
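To help isolate whether the failure is specific to the offload path, a control experiment can be sketched as follows. It keeps the same grad(scan(remat(fn))) structure and the same checkpoint name, but swaps the policy for `jax.checkpoint_policies.save_only_these_names`, which keeps the residual on device instead of offloading it. The helper `make_grad_fn` is a hypothetical name introduced only for this sketch, and whether the control compiles on the TPU backend from this report is an assumption that has not been verified here.

```python
import jax
import jax.ad_checkpoint
from jax import numpy as jnp


def make_grad_fn(policy=None):
  """Hypothetical helper: builds grad(scan(step)), optionally with remat."""
  def step(y, i):
    y = jnp.sin(y)
    y = jax.ad_checkpoint.checkpoint_name(y, 'save_remat')
    y = jnp.sin(y)
    return y, ()

  if policy is not None:
    step = jax.remat(step, policy=policy)

  @jax.jit
  def apply(params, x):
    return jax.grad(lambda p: jax.lax.scan(step, p, x)[0].sum())(params)

  return apply


params = jnp.ones([3])
x = jnp.ones([2, 3])

# Reference gradient: no remat at all.
grad_reference = make_grad_fn()(params, x)

# Control: same checkpoint name, saved on device rather than offloaded.
save_policy = jax.checkpoint_policies.save_only_these_names('save_remat')
grad_saved = make_grad_fn(save_policy)(params, x)

print(grad_reference, grad_saved)
```

If the control compiles and matches the reference while the offload policy fails, that would point at the handling of the offloaded residual inside the scan loop rather than at remat itself.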
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.34
jaxlib: 0.4.34
numpy: 2.0.1
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (94c024adedcb53059c29d7c2d62982053b60e86a)]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', ..., release='5.10.0-smp-1104.53.0.0', version='#1 [v5.10.0-1104.53.0.0] SMP @1727505643', machine='x86_64')
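For anyone comparing against their own setup, a report in the same layout as the block above can be produced with `jax.print_environment_info`. This is a minimal sketch; the variable name `info` is introduced here for illustration, and the exact fields printed may vary between JAX versions.

```python
import jax

# return_string=True returns the report instead of printing it directly.
info = jax.print_environment_info(return_string=True)
print(info)
```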