jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: Support for annotation groups with gaps doesn't exist yet
Right now, users can't reuse a jitted function that includes scheduling ids more than once in a JAX program.
Here is a stripped-down JAX example that showcases the issue.
import jax
import jax.numpy as jnp
from jax._src.xla_metadata import set_xla_metadata
from functools import partial
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((4,), ("i",))

@jax.jit
def f(x):
    c = jnp.zeros_like(x)
    for i in range(4):
        with set_xla_metadata(_scheduling_group_id=i):
            c += x
            x = jax.lax.pshuffle(x, "i", [3, 0, 1, 2])
    return c

@jax.jit
@partial(jax.shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
def main(x):
    y = f(x)
    y = jnp.sin(y)
    return f(y)

print(main(jnp.ones((4, 4))))
This fails with the following error:
Traceback (most recent call last):
  File "/tmp/scheduling_group_gap.py", line 31, in <module>
    print(main(jnp.ones((4, 4))))
          ^^^^^^^^^^^^^^^^^^^^^^
jaxlib._jax.XlaRuntimeError: UNIMPLEMENTED: Support for annotation groups with gaps doesn't exist yet, annotation: 1, instr: collective-permute-start.4 has the same annotation in its operand tree but has gaps on the way from that operand to itself.
--------------------
The issue is that the inner function f is jitted, so its scheduling group ids are the same every time it is called. When this gets lowered to XLA, we run into the error above.
One possible workaround is to remove the jit annotation from f and add a global offset to the ids that is advanced every time f is called.
offset = 0

# No jit!
def f(x):
    global offset
    c = jnp.zeros_like(x)
    for i in range(4):
        with set_xla_metadata(_scheduling_group_id=i + offset):  # Add global offset here.
            c += x
            x = jax.lax.pshuffle(x, "i", [3, 0, 1, 2])
    offset += 4
    return c
But this workaround is clunky because of the global, and it requires retracing what should be jitted logic on every call. It is also a problem if you have several different functions that include scheduling annotations, since they would all need unique ids.
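To make the id bookkeeping concrete, here is a minimal sketch of how the global offset could be replaced by a shared allocator that hands out disjoint id ranges. `SchedulingIdAllocator` and `reserve` are hypothetical names for illustration and are not part of JAX. This only tidies the workaround: f still has to stay un-jitted and is still retraced on every call.

```python
class SchedulingIdAllocator:
    """Hands out disjoint blocks of integer ids (hypothetical helper, not a JAX API)."""

    def __init__(self, start=0):
        self._next = start  # first id that has not been handed out yet

    def reserve(self, count):
        """Return a range of `count` fresh ids and advance the counter."""
        if count <= 0:
            raise ValueError("count must be positive")
        block = range(self._next, self._next + count)
        self._next += count
        return block


allocator = SchedulingIdAllocator()

# Two calls to the same annotated function, each needing 4 group ids.
first_call_ids = allocator.reserve(4)
second_call_ids = allocator.reserve(4)

# A different annotated function that needs 2 group ids.
other_function_ids = allocator.reserve(2)

print(list(first_call_ids), list(second_call_ids), list(other_function_ids))
```

Inside the un-jitted f, one would call `ids = allocator.reserve(4)` at the top and pass `ids[i]` as `_scheduling_group_id` instead of `i + offset`. Every annotated function sharing the same allocator then gets unique ids without having to coordinate offsets by hand.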