axlearn
[JAX API] Updating `TransferToMemoryKind` and `jax.experimental.pallas.triton`
@matthew-e-hopkins Hey everyone, this is a large update that allows us to use JAX > 0.5.3 (we're currently testing AXLearn with JAX 0.7.2). I've implemented the following changes:
- I've created a backward-compatibility flag, `_JAX_MEMORY_SPACE_SUPPORT`, so that all these changes work with different versions of JAX.
- In `utils.py`, JAX's `from jax._src.sharding_impls import TransferToMemoryKind` has been replaced with its JAX 0.7 counterpart (`jax.memory.Space.*`). I am preserving the previous code path by checking the JAX version:
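  ```python
  from typing import Literal

  import jax

  if _JAX_MEMORY_SPACE_SUPPORT:
      # JAX 0.7+: memory spaces are exposed through the public jax.memory.Space enum.
      MemoryKind = [jax.memory.Space.Device, jax.memory.Space.Host]
      DEVICE_MEMORY = jax.memory.Space.Device
      HOST_MEMORY = jax.memory.Space.Host

      def transfer_to_memory_kind(tensor: Tensor, memory_kind: MemoryKind) -> Tensor:
          # device_put accepts the target memory space directly.
          return jax.device_put(tensor, memory_kind)

  else:
      # Older JAX: keep using the private TransferToMemoryKind helper.
      from jax._src.sharding_impls import TransferToMemoryKind  # pylint: disable=ungrouped-imports

      MemoryKind = Literal["device", "pinned_host"]
      DEVICE_MEMORY = "device"
      HOST_MEMORY = "pinned_host"
  ```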
- These changes have been propagated to `optimizers_test.py` and `optimizers.py`.
- `jax.experimental.pallas.triton.TritonCompilerParams` has been renamed to `CompilerParams` (in the same `jax.experimental.pallas.triton` module), so `gpu_attention.py`, `gpu_decoding.py`, `gpu_paged_attention.py` and `paged_kv_cache_gpu_kernel.py` have been updated accordingly. As before, I import `_JAX_MEMORY_SPACE_SUPPORT` to check the JAX version and preserve the previous code path.
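The version gating described above could be factored into small helpers along these lines. This is a minimal sketch, not the code in this PR: `parse_version`, `supports_memory_space`, `resolve_attr` and the `(0, 7, 0)` threshold are all hypothetical, since the PR only states that it was tested with JAX 0.7.2. It combines a version check with a feature check for `jax.memory.Space`, and picks whichever Pallas Triton compiler-params class the installed JAX exposes.

```python
from typing import Any, Tuple

# Hypothetical threshold: the first JAX version assumed to expose jax.memory.Space.
# Adjust to whatever version the gating should actually use.
_MIN_MEMORY_SPACE_VERSION = (0, 7, 0)


def parse_version(version: str) -> Tuple[int, ...]:
    """Turns '0.7.2', '0.5.3rc1' or '0.7.2.dev20250101' into a 3-tuple of ints."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # Drop suffixes such as 'rc1' or 'dev...'.
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)  # Treat '0.7' as (0, 7, 0).
    return tuple(parts)


def supports_memory_space(jax_module: Any) -> bool:
    """Version check combined with a feature check for jax.memory.Space."""
    if parse_version(jax_module.__version__) < _MIN_MEMORY_SPACE_VERSION:
        return False
    memory = getattr(jax_module, "memory", None)
    return memory is not None and hasattr(memory, "Space")


def resolve_attr(module: Any, *names: str) -> Any:
    """Returns the first attribute in `names` that exists on `module`."""
    for name in names:
        if hasattr(module, name):
            return getattr(module, name)
    raise AttributeError(f"{module!r} has none of {names}")


# Possible usage (requires JAX; the num_warps/num_stages fields are assumed):
# import jax
# from jax.experimental.pallas import triton as pltriton
# _JAX_MEMORY_SPACE_SUPPORT = supports_memory_space(jax)
# CompilerParams = resolve_attr(pltriton, "CompilerParams", "TritonCompilerParams")
# compiler_params = CompilerParams(num_warps=4, num_stages=2)
```

The `hasattr` checks avoid depending on the exact release in which `jax.memory.Space` or the `CompilerParams` rename appeared; if an extra dependency is acceptable, `packaging.version.parse` is a more robust replacement for `parse_version`.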
I've tested these changes with the Fuji models. It would be great to converge on an optimal solution, as we'd like to support AXLearn in JAX-Toolbox against newer JAX versions. Please let me know if you'd like any changes. Thank you!