
[JAX API] Updating `TransferToMemoryKind` and `jax.experimental.pallas.triton`

Open Steboss opened this issue 4 months ago • 0 comments

@matthew-e-hopkins Hi all, this is a large update that lets us use JAX > 0.5.3 (we're currently testing AXLearn with JAX 0.7.2). I've implemented the following changes:

  • I've added a backward-compatibility flag, `_JAX_MEMORY_SPACE_SUPPORT`, so that all these changes work across different versions of JAX.
  • In `utils.py`, the import `from jax._src.sharding_impls import TransferToMemoryKind` has been replaced with its JAX 0.7 equivalent (`jax.memory.Space.*`). I am preserving the previous behavior by checking the JAX version:
```python
if _JAX_MEMORY_SPACE_SUPPORT:
    MemoryKind = [jax.memory.Space.Device, jax.memory.Space.Host]
    DEVICE_MEMORY = jax.memory.Space.Device
    HOST_MEMORY = jax.memory.Space.Host

    def transfer_to_memory_kind(tensor: Tensor, memory_kind: MemoryKind) -> Tensor:
        return jax.device_put(tensor, memory_kind)

else:
    from jax._src.sharding_impls import TransferToMemoryKind  # pylint: disable=ungrouped-imports

    MemoryKind = Literal["device", "pinned_host"]
    DEVICE_MEMORY = "device"
    HOST_MEMORY = "pinned_host"
```
  • These changes have been propagated to `optimizers.py` and `optimizers_test.py`.
  • `jax.experimental.pallas.triton.TritonCompilerParams` has been renamed to `jax.experimental.pallas.triton.CompilerParams`, so `gpu_attention.py`, `gpu_decoding.py`, `gpu_paged_attention.py` and `paged_kv_cache_gpu_kernel.py` have been updated accordingly. As above, I import `_JAX_MEMORY_SPACE_SUPPORT` to check the JAX version and preserve the previous code path.
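For discussion, here is a minimal sketch of an alternative to a version check: detecting the new names by attribute lookup. This is not what the change above does. The helper names `has_memory_space` and `pick_compiler_params` are hypothetical, and the `SimpleNamespace` objects are stand-ins for the old and new module layouts so the snippet runs without JAX installed.

```python
from types import SimpleNamespace
from typing import Any


def has_memory_space(jax_module: Any) -> bool:
    """Return True if the module exposes `memory.Space` (the newer layout)."""
    memory = getattr(jax_module, "memory", None)
    return hasattr(memory, "Space")


def pick_compiler_params(triton_module: Any) -> Any:
    """Prefer the new `CompilerParams` name, else fall back to `TritonCompilerParams`."""
    new_cls = getattr(triton_module, "CompilerParams", None)
    if new_cls is not None:
        return new_cls
    return getattr(triton_module, "TritonCompilerParams")


# Hypothetical stand-ins for the old and new module layouts (not real JAX objects).
old_jax = SimpleNamespace()
new_jax = SimpleNamespace(
    memory=SimpleNamespace(Space=SimpleNamespace(Device="device", Host="host"))
)
old_triton = SimpleNamespace(TritonCompilerParams="old-class")
new_triton = SimpleNamespace(CompilerParams="new-class")

print(has_memory_space(old_jax), has_memory_space(new_jax))
print(pick_compiler_params(old_triton), pick_compiler_params(new_triton))
```

In real code the arguments would be `jax` and `jax.experimental.pallas.triton`. The possible advantage is that the flag no longer depends on knowing exactly which release introduced each name, though a single version-derived flag as described above is simpler to audit.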

I've tested the changes with the Fuji models. It would be great to find an optimal solution for this, as we'd like to support AXLearn in JAX-Toolbox against newer JAX versions. Please let me know if you'd like any changes. Thank you.

Steboss, Sep 10 '25 14:09