algorithmic-efficiency
[WIP] Migrate JAX workloads from pmap to jit
Purpose
The goal of this PR is to enable sharding of model parameters and optimizer state by migrating the JAX code from jax.pmap to jax.jit.
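The shape of the migration can be sketched as follows. This is an illustrative toy step, not code from the repo: under pmap, parameters were replicated by hand and the batch split over a leading device axis, while under jit a single program carries explicit shardings.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy step function; names and shapes here are illustrative only.
def train_step(params, batch):
    preds = batch @ params
    return jnp.mean(preds ** 2)

# Old style (pmap): params replicated by hand, batch split over a leading
# device axis, cross-device means via jax.lax.pmean inside the step.
#
# New style (jit): one program, with placement described by shardings.
mesh = Mesh(jax.devices(), axis_names=('batch',))
replicated = NamedSharding(mesh, P())            # same copy on every device
data_parallel = NamedSharding(mesh, P('batch'))  # split along axis 0

params = jax.device_put(jnp.ones((4, 2)), replicated)
batch = jax.device_put(jnp.ones((8, 4)), data_parallel)

jitted_step = jax.jit(train_step, in_shardings=(replicated, data_parallel))
loss = jitted_step(params, batch)  # scalar, no leading device axis
```

Note that the jitted step returns a plain scalar: there is no per-device axis to unreplicate afterwards.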
TODOs:
- [ ] Migrate reference optimizers to use jax.jit
  - [x] Nesterov
  - [x] AdamW
- [ ] Others
- [ ] Migrate workloads to use jax.jit
  - [x] (Test workload) MNIST
  - [x] (Test workload) CIFAR
- [x] WMT
- [x] Criteo1TB
- [x] FastMRI
- [ ] Librispeech
- [ ] OGBG
- [x] ImageNet
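For the migrated optimizers the arithmetic is unchanged; the difference is that the update runs under jax.jit rather than jax.pmap, so optimizer state no longer carries a leading device axis and no explicit pmean is needed inside the update. A minimal hand-rolled Nesterov sketch (the hyperparameter values are illustrative, not the reference submission's):

```python
import jax
import jax.numpy as jnp

# Hand-rolled Nesterov momentum; lr / momentum values are illustrative,
# not the reference submission's hyperparameters.
@jax.jit
def nesterov_update(params, velocity, grads, lr=0.1, momentum=0.9):
    velocity = jax.tree_util.tree_map(
        lambda v, g: momentum * v + g, velocity, grads)
    # Nesterov look-ahead: step along grad plus momentum-scaled velocity.
    params = jax.tree_util.tree_map(
        lambda p, v, g: p - lr * (g + momentum * v), params, velocity, grads)
    return params, velocity

params = {'w': jnp.ones((3,))}
velocity = {'w': jnp.zeros((3,))}
grads = {'w': jnp.full((3,), 2.0)}
params, velocity = nesterov_update(params, velocity, grads)
```

Because the state trees keep their unreplicated shapes, the same update works on one device or many; cross-device reduction comes from the sharding of the gradients rather than from collectives written into the optimizer.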
Changelog
- Added some sharding utilities to handle data distribution across devices
- Replaced pmap code for CIFAR/MNIST with jit
- Modified AdamW and Nesterov accordingly
- Updated checkpoint and data_utils to support the new approach (mostly removing explicit jax_utils.replicate calls).
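The jax_utils.replicate removal can be illustrated as follows; the restored-params dict below is a stand-in for illustration, not the actual checkpoint format:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Stand-in for params restored from a checkpoint.
params = {'w': jnp.ones((4, 2)), 'b': jnp.zeros((2,))}

# Old pmap style: flax.jax_utils.replicate(params) stacked every leaf with a
# leading device axis, which then had to be stripped again before saving.
#
# New jit style: give each leaf an explicit sharding; leaves keep their
# original shapes, so checkpointing needs no replicate/unreplicate pair.
mesh = Mesh(jax.devices(), axis_names=('devices',))
replicated = NamedSharding(mesh, P())
params = jax.tree_util.tree_map(
    lambda x: jax.device_put(x, replicated), params)
```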
Issues
- Prefetching functionality in CIFAR is temporarily disabled (marked with FIXME); I'm not sure how best to support it under jit.
- I haven't edited any of the PyTorch code; we will need to make sure the PyTorch and JAX workloads still perform comparably.
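On the prefetching question, one pmap-free option is to exploit the fact that jax.device_put is asynchronous and keep a couple of sharded batches in flight. The `jit_prefetch` helper below is a hypothetical sketch, not existing repo code:

```python
import collections
import jax

def jit_prefetch(iterator, sharding, size=2):
    """Keep up to `size` batches in flight by placing them on device eagerly.

    Sketch of pmap-free prefetching: jax.device_put is asynchronous, so
    enqueueing transfers for upcoming batches overlaps the host-to-device
    copies with the currently running training step.
    """
    queue = collections.deque()
    for batch in iterator:
        queue.append(jax.device_put(batch, sharding))
        if len(queue) >= size:
            yield queue.popleft()
    # Drain whatever is still in flight at the end of the epoch.
    while queue:
        yield queue.popleft()
```

This mirrors the shape of flax's pmap-era `jax_utils.prefetch_to_device`, but yields arrays with a sharding instead of a stacked per-device axis.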
MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅
We didn't migrate all the optimizers. Should we look into migrating the following?
- adafactor
- lamb
- sam
- shampoo
The above algorithms are buggy and we have no plans to fix them, so we should delete them rather than continue maintaining them.