Add --megascale_abort_on_hangs flag for multi-slice TPU jobs
- Introduce flag to terminate jobs on MegaScale Runtime Errors
- Enable auto-restart of jax process when errors occur
- Prevent silent hangs in multi-slice TPU configurations
- Reduce time to recovery for failed jobs
- ref: https://github.com/apple/axlearn/pull/716
- co-authored by Nick Stogner and Kyle
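As a minimal sketch of how a launcher could attach the flag, assuming libtpu flags are passed through the `LIBTPU_INIT_ARGS` environment variable and that the variable is set before JAX initializes the TPU backend. The helper name `with_abort_on_hangs` is hypothetical and is not taken from the axlearn PR.

```python
FLAG = "--megascale_abort_on_hangs"


def with_abort_on_hangs(env: dict, enabled: bool = True) -> dict:
    """Return a copy of `env` whose LIBTPU_INIT_ARGS carries the flag exactly once.

    Hypothetical helper for illustration; not the axlearn implementation.
    """
    new_env = dict(env)
    existing = new_env.get("LIBTPU_INIT_ARGS", "").split()
    # Drop any earlier setting of the same flag so the new value wins.
    kept = [arg for arg in existing if arg.split("=", 1)[0] != FLAG]
    kept.append(f"{FLAG}={'true' if enabled else 'false'}")
    new_env["LIBTPU_INIT_ARGS"] = " ".join(kept)
    return new_env


# Usage (before importing jax), assuming the process environment is the target:
#   import os
#   os.environ.update(with_abort_on_hangs(os.environ))
```

Rewriting the argument string rather than blindly appending keeps the setting idempotent, so a restarted process does not accumulate duplicate copies of the flag.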
@Ethanlm @markblee PTAL
Please don't merge yet. Kyle is helping us test this.
Tested internally by scheduling a multi-slice v5p job in the internal test area. The job made progress with the flag set, using Isaack's axlearn branch.
Also, do we know if there is a list of libtpu-only (non-XLA) flags, ideally with a brief description of what each one does?
BTW, thanks a lot for working on this! Getting the hanging situation improved is super valuable.
Based on recent discussions (https://chat.google.com/room/AAAAE7IGW88/3qZf4tP48RU/m5MskM43z4o?cls=10 and https://chat.google.com/room/AAAAE7IGW88/3qZf4tP48RU/AXtt8F5CztM?cls=10), it looks like we should not enable this flag.
With jax 0.4.33, we built error aggregation into the coordinator to help identify, for example, a bad TPU host. However, it only works with megascale_abort_on_hangs=false, because all workers need to report the error to the coordinator.
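The constraint described above can be sketched as a small pre-launch check. This is an illustration only: the function names are hypothetical, and the version threshold is taken from the comment rather than from any JAX or libtpu API.

```python
def parse_version(version: str) -> tuple:
    """Parse the leading numeric release segment, e.g. "0.4.33.dev1" -> (0, 4, 33)."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break  # Stop at the first non-numeric component such as "dev1".
        parts.append(int(piece))
    return tuple(parts)


def coordinator_aggregation_usable(jax_version: str, abort_on_hangs: bool) -> bool:
    """Hypothetical check mirroring the comment above.

    Error aggregation is described as available from jax 0.4.33 and as working
    only when megascale_abort_on_hangs is false, since every worker must stay
    alive long enough to report its error to the coordinator.
    """
    return parse_version(jax_version) >= (0, 4, 33) and not abort_on_hangs
```

A launcher could log a warning when this returns `False` for a job that expects bad-host diagnostics.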
Replaced by https://github.com/apple/axlearn/pull/1010 according to latest recommendations
I closed this PR since https://github.com/apple/axlearn/pull/1010 has been merged