
Offloading in grad(scan(remat(fn, policy=offload))) results in XlaRuntimeError

Open hbq1 opened this issue 4 months ago • 2 comments

Description

The following code

import jax
import jax.ad_checkpoint
from jax import numpy as jnp


@jax.jit
def apply(params, x):

  def step(y, i):
    y = jnp.sin(y)

    y = jax.ad_checkpoint.checkpoint_name(y, 'save_remat')

    y = jnp.sin(y)
    return y, ()

  step = jax.remat(
      step,
      policy=jax.checkpoint_policies.save_and_offload_only_these_names(
          names_which_can_be_saved=(),
          names_which_can_be_offloaded=('save_remat',),
          offload_src='device',
          offload_dst='pinned_host',
      ),
  )

  y = jax.grad(lambda p: jax.lax.scan(step, p, x)[0].sum())(params)
  return y


params = jnp.ones([3])
x = jnp.ones([2, 3])

apply(params, x)

reproduces the following error:

XlaRuntimeError: UNIMPLEMENTED: Performing sub-chunk copy is not supported in async dynamic slice yet.

Error encountered while compiling %dynamic-slice-start = ((f32[2,3]{1,0:T(2,128)S(5)}, s32[]{:T(256)}, s32[]{:T(256)}), f32[1,3]{1,0:T(2,128)}, u32[]{:S(2)}, s32[]) dynamic-slice-start(f32[2,3]{1,0:T(2,128)S(5)} %get-tuple-element.237, s32[]{:T(256)} %select.6, s32[]{:T(256)} %constant.4..sunk.1), dynamic_slice_sizes={1,3}, metadata={op_name="jit(apply)/jit(main)/transpose(jvp(while))/body/dynamic_slice" source_file="<ipython-input-1-a56f72d33e52>" source_line=27}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[]}.

Error encountered while compiling %while.7 = (s32[]{:T(256)}, f32[3]{0:T(256)}, f32[2,3]{1,0:T(2,128)S(5)}, f32[2,3]{1,0:T(2,128)}, s32[]{:T(256)}, /*index=5*/s32[]{:T(256)}, s32[]{:T(256)}) while((s32[]{:T(256)}, f32[3]{0:T(256)}, f32[2,3]{1,0:T(2,128)S(5)}, f32[2,3]{1,0:T(2,128)}, s32[]{:T(256)}, /*index=5*/s32[]{:T(256)}, s32[]{:T(256)}) %tuple.50), condition=%wide.wide.wide.wide.region_3.91.clone.clone.clone, body=%wide.wide.wide.wide.region_2.67.clone.clone.clone.
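A control run can help narrow this down. Keep the `save_remat` residual in device memory with `save_only_these_names` instead of offloading it with `save_and_offload_only_these_names`. If this variant compiles on the same TPU, the failure is presumably specific to the offload path. The failing op is a `dynamic-slice-start` that reads one `{1,3}` row from the stacked `f32[2,3]` residual annotated `S(5)`, which appears to be host memory. That slice is smaller than the buffer's `T(2,128)` tile, which is possibly what "sub-chunk copy" refers to. The sketch below is hypothetical: the names `make_step`, `grad_saved` and `grad_reference` are not from the report. It compares the remat gradient to a plain, non-remat gradient only as a numerical sanity check.

```python
import jax
import jax.ad_checkpoint
from jax import numpy as jnp


def make_step(policy):
  """Hypothetical helper: the scan body from the report, rematerialized under `policy`."""
  def step(y, i):
    y = jnp.sin(y)
    y = jax.ad_checkpoint.checkpoint_name(y, 'save_remat')
    y = jnp.sin(y)
    return y, ()
  return jax.remat(step, policy=policy)


# Control policy: save 'save_remat' in device memory rather than offloading it.
save_policy = jax.checkpoint_policies.save_only_these_names('save_remat')


@jax.jit
def grad_saved(params, x):
  step = make_step(save_policy)
  return jax.grad(lambda p: jax.lax.scan(step, p, x)[0].sum())(params)


@jax.jit
def grad_reference(params, x):
  # Same math with no remat at all, used only as a numerical reference.
  def step(y, i):
    return jnp.sin(jnp.sin(y)), ()
  return jax.grad(lambda p: jax.lax.scan(step, p, x)[0].sum())(params)


params = jnp.ones([3])
x = jnp.ones([2, 3])

g_saved = grad_saved(params, x)
g_ref = grad_reference(params, x)
print(g_saved)
print(g_ref)
```

Swapping only the policy keeps every other part of the repro identical, so any difference in behavior can be attributed to offloading rather than to `scan` or `remat` themselves.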

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.0.1
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (94c024adedcb53059c29d7c2d62982053b60e86a)]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', ..., release='5.10.0-smp-1104.53.0.0', version='#1 [v5.10.0-1104.53.0.0] SMP @1727505643', machine='x86_64')
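The system info above appears to follow the layout printed by `jax.print_environment_info()`. A minimal way to regenerate it on another machine is shown below, for example to compare a setup where the repro fails against one where it does not. The exact fields may differ between JAX versions.

```python
import jax

# return_string=True returns the report as a str instead of printing it,
# which makes it easy to paste into an issue or write to a log.
info = jax.print_environment_info(return_string=True)
print(info)
```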

hbq1, Oct 04 '24 10:10