
[WIP] Migrate JAX workloads from pmap to jit

Open priyakasimbeg opened this issue 8 months ago • 1 comments

Purpose

The goal of this PR is to enable sharding of model parameters and optimizer state by migrating the JAX code from jax.pmap to jax.jit.

TODOs:

  • [ ] Migrate reference optimizers to use jax.jit
    • [x] Nesterov
    • [x] AdamW
    • [ ] Others
  • [ ] Migrate workloads to use jax.jit
    • [x] (Test workload) MNIST
    • [x] (Test workload) CIFAR
    • [x] WMT
    • [x] Criteo1TB
    • [x] FastMRI
    • [ ] Librispeech
    • [ ] OGBG
    • [x] ImageNet

Changelog

  • Added sharding utilities to handle distributed data
  • Replaced pmap code for CIFAR/MNIST with jit
  • Modified AdamW and Nesterov accordingly
  • Updated checkpoint and data_utils to support the new approach (mostly removing explicit jax_utils.replicate calls).
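The pmap-to-jit migration described above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual utilities: the mesh axis name `"data"` and the `train_step` function are hypothetical, and it assumes the common pattern of replicating parameters while sharding the batch along a data axis via `NamedSharding`, which replaces explicit `jax_utils.replicate` and `pmap`/`psum` calls.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical sketch: a 1-D device mesh with a single "data" axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
replicated = NamedSharding(mesh, P())          # params / optimizer state: replicated
data_sharded = NamedSharding(mesh, P("data"))  # batch: split across devices

@jax.jit
def train_step(params, batch):
    # Under jit, sharding propagates from the inputs; no pmap axis or
    # explicit psum is needed for this simple replicated-params setup.
    preds = batch @ params
    return jnp.mean(preds ** 2)

# device_put with an explicit sharding replaces jax_utils.replicate.
params = jax.device_put(jnp.ones((4, 1)), replicated)
batch = jax.device_put(jnp.ones((8, 4)), data_sharded)
loss = train_step(params, batch)
```

With all-ones inputs each prediction is 4.0, so the mean squared value is 16.0; the same code runs unchanged on one device or many, which is a key practical advantage of jit over pmap.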

Issues

  • Prefetching in CIFAR is temporarily disabled (marked with FIXME); it is not yet clear how best to support it under jit.
  • I haven't edited any of the PyTorch code; we will need to verify that the PyTorch workloads still perform comparably.

priyakasimbeg avatar Mar 06 '25 21:03 priyakasimbeg

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

github-actions[bot] avatar Mar 06 '25 21:03 github-actions[bot]

We didn't migrate all the optimizers. Should we look into migrating the following?

  • adafactor
  • lamb
  • sam
  • shampoo

rka97 avatar Aug 20 '25 05:08 rka97

> We didn't migrate all the optimizers. Should we look into migrating the following?
>
>   • adafactor
>   • lamb
>   • sam
>   • shampoo

The above algorithms are buggy and we have no plans to fix them, so rather than maintaining them we should delete them.

priyakasimbeg avatar Aug 21 '25 00:08 priyakasimbeg