Fast-LLM
Elastic training
🎯 Goal (What & Why)
Enable elastic training in Fast-LLM:
Allow long-running training jobs to dynamically adjust to changing cluster resources (i.e., number of nodes) without manual reconfiguration or restart.
Why:
Clusters are often underutilized, and current training jobs assume static resource allocations. With elastic training, we could run background, preemptible jobs (e.g., ongoing pretraining) that scale up or down automatically depending on node availability. This supports:
- Better resource utilization during idle periods.
- Transparent fallback workloads during low cluster usage.
- A “cloud-native” development model with always-on, flexible training.
🚀 Execution Plan (Updated)
Elastic training is hard. It implies rethinking how model state, batch scheduling, and parallelization behave in the presence of shifting resources. There are multiple axes of complexity to address.
Step 1: What is the smallest working version?
The minimal viable product (MVP) for elastic training in Fast-LLM is a restart-based elasticity system, built around these principles:
- Training restarts automatically when cluster membership changes (node lost or added).
- Preemption handling:
  - Catch `SIGTERM` signals sent by the batch system (K8s, SLURM, NGC, etc.).
  - Save a checkpoint within the ~30 second timeout window before `SIGKILL`.
  - If checkpointing cannot complete synchronously, implement asynchronous checkpointing (e.g., spawn a background thread to flush shards) to maximize survival chances.
  - Cleanly exit to trigger `torchrun` elastic re-spawn (`--nnodes=min:max`).
- On re-spawn:
  - Load the latest checkpoint.
  - Inspect the new world size.
  - Rebuild the distributed setup (FSDP, TP, PP as needed).
  - Adapt the effective batch size:
    - Prefer keeping the global batch size roughly constant by adjusting gradient accumulation steps.
    - Tolerate minor deviations within a configurable tolerance band (default: ±10%).
  - Ensure memory constraints are respected:
    - Adjust micro-batch size and ZeRO stage if necessary to stay within per-GPU memory budget.
    - Prefer lower ZeRO stages to avoid overhead.
This setup is simple, compatible with existing batch systems, and robust against dynamic resource changes. It does not attempt live reconfiguration during training, accepts some downtime between preemption and full restart, and may incur a slight loss of training efficiency because the global batch size can only be approximately matched to the new world size.
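The batch-size adaptation in the list above can be sketched as follows, assuming the micro-batch size stays fixed and only the number of gradient accumulation steps changes. `adapt_grad_accum` is a hypothetical helper, not part of Fast-LLM; the tolerance is passed as an integer percentage so that the band check uses exact integer arithmetic.

```python
from typing import Optional


def adapt_grad_accum(
    target_global_batch: int,
    world_size: int,
    micro_batch_size: int,
    tolerance_pct: int = 10,
) -> Optional[int]:
    """Pick the gradient accumulation steps that keep the global batch closest
    to the target, or return None if nothing lands within ±tolerance_pct %."""
    samples_per_step = world_size * micro_batch_size
    # The closest multiple of samples_per_step is at floor(target / step) or one above.
    base = max(1, target_global_batch // samples_per_step)
    best: Optional[int] = None
    best_dev: Optional[int] = None
    for accum in (base, base + 1):
        dev = abs(accum * samples_per_step - target_global_batch)
        # Integer form of dev / target <= tolerance_pct / 100 (no float rounding).
        if dev * 100 > tolerance_pct * target_global_batch:
            continue
        if best_dev is None or dev < best_dev:
            best, best_dev = accum, dev
    return best
```

With the Apriel-5B numbers discussed further down (target 2,880, micro-batch 6), 8 GPUs give 60 accumulation steps, while 640 GPUs have no valid value at that micro-batch size. This is why a fuller version also needs to vary the micro-batch size.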
Step 2: What additional optimizations are possible (but optional)?
(1) Batch size stabilization and scaling mitigations
Even after adjusting gradient accumulation steps, small deviations in global batch size (±10%) are inevitable. To absorb these variations without disrupting training dynamics:
- Stay within the critical batch size (CBS): LLM training follows scaling laws where efficiency is preserved up to a data-dependent critical batch size. Staying below this threshold should ensure that convergence is unaffected by batch size changes.
- Apply learning rate scaling: If the batch size increases significantly but remains under CBS, scaling the learning rate linearly or with the square root of the batch-size ratio helps maintain stable update magnitudes. For small deviations (±10%), learning rate changes may not be needed. This needs to be confirmed, though.
- Use adaptive optimizers: Optimizers like Adam naturally handle moderate batch size shifts by adapting step sizes internally.
- Smooth transitions: When batch size changes sharply, gradually ramp learning rates or reset momentum buffers to prevent instability.
- Extend training if needed: If a batch size increase pushes beyond CBS, convergence can still be maintained by training longer (and feeding in proportionally more data).
These mitigations are easy to implement and help preserve convergence behaviour across node changes. They should be treated as the next logical step after the MVP.
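A minimal sketch of the learning-rate rules listed above: linear or square-root scaling by the ratio of new to old global batch size, plus a linear ramp from the old to the new value after a restart. The function names and the `ramp_steps` parameter are illustrative assumptions, not existing Fast-LLM options.

```python
import math


def scaled_lr(base_lr: float, old_batch: int, new_batch: int, rule: str = "sqrt") -> float:
    """Scale the learning rate by the ratio of new to old global batch size."""
    ratio = new_batch / old_batch
    if rule == "linear":
        return base_lr * ratio
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    if rule == "none":
        return base_lr
    raise ValueError(f"Unknown scaling rule: {rule!r}")


def ramped_lr(old_lr: float, new_lr: float, steps_since_restart: int, ramp_steps: int) -> float:
    """Linearly interpolate from the previous learning rate to the new one
    over `ramp_steps` optimizer steps to avoid an abrupt change."""
    if ramp_steps <= 0 or steps_since_restart >= ramp_steps:
        return new_lr
    frac = steps_since_restart / ramp_steps
    return old_lr + frac * (new_lr - old_lr)
```

For a ±10% deviation, square-root scaling changes the learning rate by about 5% at most (√1.1 ≈ 1.049), which fits the expectation above that small deviations may not need any adjustment.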
(2) Adaptive parallelization strategies
Currently, the MVP assumes a fixed parallel strategy (e.g., FSDP with a chosen ZeRO stage).
Longer term, it would be beneficial to select the parallelization strategy dynamically and holistically at restart based on available resources to optimize throughput.
- Small world sizes: prefer pure data parallelism or FSDP with ZeRO-1 to minimize sharding overhead.
- Medium world sizes: use FSDP with ZeRO-2 to balance memory and communication cost.
- Large world sizes: combine TP, PP, etc. with full FSDP to maintain throughput and avoid stragglers.
This would require decision logic to pick the best strategy with maximum throughput at load time.
Adaptive strategy selection improves training efficiency, especially under wide resource variability, but also introduces complexity.
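As a starting point, the decision logic could be a simple rule table keyed on world size. The sketch below is purely illustrative: the thresholds (`small_max`, `medium_max`) and the tensor/pipeline-parallel degrees are hypothetical placeholders, and a real implementation would rank candidates by measured or modelled throughput rather than use fixed cut-offs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParallelPlan:
    zero_stage: int
    tensor_parallel: int = 1
    pipeline_parallel: int = 1


def pick_plan(world_size: int, small_max: int = 64, medium_max: int = 512) -> ParallelPlan:
    """Map a world size to a parallelization plan.

    The thresholds and the TP/PP degrees below are hypothetical placeholders.
    """
    if world_size <= small_max:
        # Small: ZeRO-1 keeps sharding overhead to a minimum.
        return ParallelPlan(zero_stage=1)
    if world_size <= medium_max:
        # Medium: also shard gradients (ZeRO-2) to trade memory for communication.
        return ParallelPlan(zero_stage=2)
    large = ParallelPlan(zero_stage=3, tensor_parallel=2, pipeline_parallel=2)
    if world_size % (large.tensor_parallel * large.pipeline_parallel) != 0:
        # The model-parallel group must divide the world size; fall back to pure FSDP.
        return ParallelPlan(zero_stage=3)
    # Large: full sharding combined with tensor and pipeline parallelism.
    return large
```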
(3) Live reconfiguration
Instead of restarting, the system would detect node loss or addition and reconfigure process groups and model state in place:
- Tear down and rebuild distributed groups.
- Repartition model shards and optimizer states.
- Synchronize all ranks to a consistent view of world size changes.
This approach would minimize downtime but is likely harder to implement. Partial failures during reconfiguration (e.g., node drops mid-transition) are difficult to recover from cleanly.
Live reconfiguration should only be pursued after restart-based elasticity is fully stable.
(4) Streaming DiLoCo (Distributed Low-Communication training)
In this model, nodes train independently on different data shards, periodically averaging weights. There is no need to synchronize node membership changes mid-run.
While highly resilient to node churn, Streaming DiLoCo has major drawbacks:
- Reduced statistical efficiency due to delayed weight updates.
- Risk of model divergence if synchronization intervals are too long.
This method is mainly attractive for training on highly volatile, cloud preemptible resources. It is not recommended for stable, managed clusters unless absolutely necessary.
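To make the mechanism concrete, here is a toy simulation of periodic weight averaging: each worker runs plain SGD on its own shard of a 1-D least-squares problem, and the workers' weights are averaged every `sync_every` steps. This is a simplified local-SGD / model-averaging sketch, not the full (Streaming) DiLoCo algorithm, which as we understand it applies an outer optimizer to the averaged updates and synchronizes parameter fragments on a staggered schedule. All names are hypothetical.

```python
import random


def local_sgd(shards, sync_every, steps, lr=0.1, seed=0):
    """Fit y = w * x on per-worker data shards with periodic weight averaging.

    shards: list of per-worker lists of (x, y) pairs.
    Returns the mean of the workers' weights after the last step.
    """
    rng = random.Random(seed)
    weights = [0.0] * len(shards)
    for step in range(1, steps + 1):
        # Each worker takes an independent SGD step on its own shard.
        for i, shard in enumerate(shards):
            x, y = rng.choice(shard)
            grad = 2 * (weights[i] * x - y) * x  # d/dw of (w * x - y) ** 2
            weights[i] -= lr * grad
        # Periodic synchronization: replace every worker's weight by the average.
        if step % sync_every == 0:
            avg = sum(weights) / len(weights)
            weights = [avg] * len(weights)
    return sum(weights) / len(weights)
```

Workers can join or leave between synchronization points without any global restart, which is the resilience property described above; the price is that workers drift apart between synchronizations.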
📌 Acceptance Criteria (Must-Haves for Completion)
- MVP supports restart-based elasticity:
  - Detects new node count, resumes from checkpoint, reconfigures distributed config and batch settings.
- Dynamic world size -> batch size adaptation is implemented and validated.
- A worked-out plan exists for live reconfiguration (even if not in MVP).
- The tradeoffs of the strategy are documented.
- Tests confirm model convergence under elastic training.
- Training metrics include world size, batch size, and node churn for auditability.
- Documentation added to Fast-LLM user guide and tutorials.
🛠️ Project Management
- [x] Assign the project to the Fast-LLM project.
- [ ] Set the `Estimate` field (in days) in the GitHub project.
- [x] Use the `Size` field to categorize the PR size (Large).
- [x] Assign an owner when opening the issue.
Scaling the batch size efficiently is a big challenge because it's difficult to get consistent results. We'd have to scale much more than the learning rate, and opinions vary on how to scale things. We're better off keeping the batch size constant. A reasonable MVP would vary only the number of gradient accumulation steps, tolerating small differences (say 10%) to account for lack of exact divisibility. We can get away with just this if there are already 2-3 gradient accumulation steps at the maximum cluster size.
Another thing to consider is the frequency of reconfiguration vs downtime and lost progress. Restarting takes time, especially with distributed reconfigurations, and a naive implementation would just end up wasting most of the compute. At the very least we want to:
- Auto-save on shutdown (preemption) so we don't lose progress.
- Be careful when scaling up to ensure we aren't constantly scaling up and down. This could take the form of guarantees from the preemption service (e.g., a minimum runtime), a cooldown before scaling up, etc.
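One way to avoid thrashing is a cooldown on scale-up decisions. The sketch below is a hypothetical helper, not an existing torchrun or Fast-LLM feature. It only allows a scale-up restart once the current configuration has run for at least `min_runtime_s` seconds. Scale-downs (lost nodes) are always accepted because they cannot be refused.

```python
import time
from typing import Callable


class ScaleUpCooldown:
    """Gate scale-up restarts behind a minimum runtime since the last restart."""

    def __init__(self, min_runtime_s: float, clock: Callable[[], float] = time.monotonic):
        self.min_runtime_s = min_runtime_s
        self._clock = clock
        self._last_restart = clock()

    def record_restart(self) -> None:
        self._last_restart = self._clock()

    def should_restart(self, current_nodes: int, available_nodes: int) -> bool:
        if available_nodes < current_nodes:
            # Nodes were lost: a restart is unavoidable.
            return True
        if available_nodes == current_nodes:
            return False
        # Nodes were added: only rescale once the cooldown has elapsed.
        return self._clock() - self._last_restart >= self.min_runtime_s
```

The injectable `clock` makes the policy easy to test and to drive from a simulated timeline.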
(1) Live Reconfiguration:
This might not be that hard, but might not save that much time because we still need to re-create process groups and reorganize the training state, and added nodes will need to start from scratch either way.
(2) Streaming DiLoCo:
These kinds of approaches may be the most flexible, but they come with a huge efficiency cost and are only useful for scaling up to lots of remote servers. I'd stay clear unless we have a really good reason for it.
Appreciate you raising these concerns clearly, @jlamypoirier! The complexities and tradeoffs you're highlighting are worth keeping in mind as we implement. That said, I want to clarify our mindset here:
We're not looking for reasons not to do elastic training. Instead, we're looking for the simplest, most straightforward path to getting something useful up and running, even if initially limited, inefficient, or imperfect.
The MVP as scoped (restart-based elasticity, adaptive gradient accumulation with tolerance for minor variance, documented tradeoffs, etc.) is deliberately modest but already highly valuable. We fully accept upfront that it won't perfectly solve every scenario and might incur efficiency costs initially. That's entirely fine and expected. Right now, the priority is forward momentum and a solid foundation we can incrementally refine.
So yes, let's keep these risks documented and visible, but let's also clearly commit to making progress here. We won't let the perfect be the enemy of shipping a useful capability. Thanks.
Please re-read my comment. I understand the need perfectly well, and agree with it. But the MVP as proposed has little (if any) chance of working at all, especially on our own infrastructure.
I carefully re-read your comment. You are suggesting that instead of tolerating global batch size changes, we should aim to keep it constant by varying gradient accumulation steps, staying within a defined tolerance band. That's a fair point. In fact, it was already listed as option (3) in the original ticket.
To validate how viable this is, I ran a concrete simulation for a realistic scenario, Apriel-5B:
Model : Apriel-5B
Nodes : 60 (480 GPUs total)
Micro-batch per GPU : 6
Gradient accumulation : 1
Global batch size : 2,880
Memory constraint : 120 GiB (H200 with headroom)
Tolerance band : ±10 % (i.e., 2,592 - 3,168)
I simulated dynamic reconfiguration across 1–100 nodes (8–800 GPUs), trying to maintain global batch size inside the tolerance band by adjusting gradient accumulation steps, the micro-batch size, and (if needed) the ZeRO stage, without exceeding the memory constraint.
The reconfiguration strategy works cleanly across almost all tested node counts. Scaling down to very few GPUs is feasible. Scaling up to 100 nodes also mostly works, with minor exceptions at specific points (80 and 100 nodes, where no configuration falls within the tolerance band). ZeRO-1 suffices at every feasible point without needing to escalate to higher stages. In short, adaptive reconfiguration with ±10% tolerance is viable in practice for Apriel-5B.
I've attached the full simulation output and the code used (at the bottom).
Based on this, I am proposing the following MVP for restart-based elasticity:
- Restart-based elasticity: On restart, detect the new world size and reconfigure gradient accumulation steps, micro-batch size, and ZeRO stage to keep the global batch size within a tolerance band and the memory footprint within the hardware limit.
- Tolerance: Allow deviations within a configurable tolerance band (e.g., ±10%) for the global batch size.
- Auto-save on shutdown (preemption): Training must save the latest state cleanly when node changes are detected.
The main goal is to move forward pragmatically: shipping restart-based elastic training that works across a wide range of practical scenarios. We should try to simulate more scenarios based on the attached code, but I'm now feeling confident that this is a good MVP.
Simulation output:
| Nodes | GPUs | ZeRO | Micro-batch | Grad. accum. | Global batch | Δ global batch (%) | Memory (total) GiB | Memory (params/grad/opt) GiB | Memory (activations) GiB |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 8 | ZeRO-1 | 6 | 60 | 2880 | +0.0 | 115.5 | 31.5 | 84.0 |
| 2 | 16 | ZeRO-1 | 6 | 30 | 2880 | +0.0 | 113.3 | 29.3 | 84.0 |
| 3 | 24 | ZeRO-1 | 6 | 20 | 2880 | +0.0 | 112.5 | 28.5 | 84.0 |
| 4 | 32 | ZeRO-1 | 6 | 15 | 2880 | +0.0 | 112.1 | 28.1 | 84.0 |
| 5 | 40 | ZeRO-1 | 6 | 12 | 2880 | +0.0 | 111.9 | 27.9 | 84.0 |
| 6 | 48 | ZeRO-1 | 6 | 10 | 2880 | +0.0 | 111.8 | 27.8 | 84.0 |
| 7 | 56 | ZeRO-1 | 3 | 17 | 2856 | -0.8 | 69.6 | 27.6 | 42.0 |
| 8 | 64 | ZeRO-1 | 5 | 9 | 2880 | +0.0 | 97.6 | 27.6 | 70.0 |
| 9 | 72 | ZeRO-1 | 5 | 8 | 2880 | +0.0 | 97.5 | 27.5 | 70.0 |
| 10 | 80 | ZeRO-1 | 6 | 6 | 2880 | +0.0 | 111.5 | 27.5 | 84.0 |
| 11 | 88 | ZeRO-1 | 3 | 11 | 2904 | +0.8 | 69.4 | 27.4 | 42.0 |
| 12 | 96 | ZeRO-1 | 6 | 5 | 2880 | +0.0 | 111.4 | 27.4 | 84.0 |
| 13 | 104 | ZeRO-1 | 4 | 7 | 2912 | +1.1 | 83.3 | 27.3 | 56.0 |
| 14 | 112 | ZeRO-1 | 2 | 13 | 2912 | +1.1 | 55.3 | 27.3 | 28.0 |
| 15 | 120 | ZeRO-1 | 6 | 4 | 2880 | +0.0 | 111.3 | 27.3 | 84.0 |
| 16 | 128 | ZeRO-1 | 2 | 11 | 2816 | -2.2 | 55.3 | 27.3 | 28.0 |
| 17 | 136 | ZeRO-1 | 3 | 7 | 2856 | -0.8 | 69.3 | 27.3 | 42.0 |
| 18 | 144 | ZeRO-1 | 5 | 4 | 2880 | +0.0 | 97.3 | 27.3 | 70.0 |
| 19 | 152 | ZeRO-1 | 1 | 19 | 2888 | +0.3 | 41.2 | 27.2 | 14.0 |
| 20 | 160 | ZeRO-1 | 6 | 3 | 2880 | +0.0 | 111.2 | 27.2 | 84.0 |
| 21 | 168 | ZeRO-1 | 1 | 17 | 2856 | -0.8 | 41.2 | 27.2 | 14.0 |
| 22 | 176 | ZeRO-1 | 4 | 4 | 2816 | -2.2 | 83.2 | 27.2 | 56.0 |
| 23 | 184 | ZeRO-1 | 4 | 4 | 2944 | +2.2 | 83.2 | 27.2 | 56.0 |
| 24 | 192 | ZeRO-1 | 5 | 3 | 2880 | +0.0 | 97.2 | 27.2 | 70.0 |
| 25 | 200 | ZeRO-1 | 2 | 7 | 2800 | -2.8 | 55.2 | 27.2 | 28.0 |
| 26 | 208 | ZeRO-1 | 2 | 7 | 2912 | +1.1 | 55.2 | 27.2 | 28.0 |
| 27 | 216 | ZeRO-1 | 1 | 13 | 2808 | -2.5 | 41.2 | 27.2 | 14.0 |
| 28 | 224 | ZeRO-1 | 1 | 13 | 2912 | +1.1 | 41.2 | 27.2 | 14.0 |
| 29 | 232 | ZeRO-1 | 6 | 2 | 2784 | -3.3 | 111.2 | 27.2 | 84.0 |
| 30 | 240 | ZeRO-1 | 6 | 2 | 2880 | +0.0 | 111.2 | 27.2 | 84.0 |
| 31 | 248 | ZeRO-1 | 6 | 2 | 2976 | +3.3 | 111.1 | 27.1 | 84.0 |
| 32 | 256 | ZeRO-1 | 1 | 11 | 2816 | -2.2 | 41.1 | 27.1 | 14.0 |
| 33 | 264 | ZeRO-1 | 1 | 11 | 2904 | +0.8 | 41.1 | 27.1 | 14.0 |
| 34 | 272 | ZeRO-1 | 1 | 11 | 2992 | +3.9 | 41.1 | 27.1 | 14.0 |
| 35 | 280 | ZeRO-1 | 5 | 2 | 2800 | -2.8 | 97.1 | 27.1 | 70.0 |
| 36 | 288 | ZeRO-1 | 5 | 2 | 2880 | +0.0 | 97.1 | 27.1 | 70.0 |
| 37 | 296 | ZeRO-1 | 5 | 2 | 2960 | +2.8 | 97.1 | 27.1 | 70.0 |
| 38 | 304 | ZeRO-1 | 3 | 3 | 2736 | -5.0 | 69.1 | 27.1 | 42.0 |
| 39 | 312 | ZeRO-1 | 3 | 3 | 2808 | -2.5 | 69.1 | 27.1 | 42.0 |
| 40 | 320 | ZeRO-1 | 3 | 3 | 2880 | +0.0 | 69.1 | 27.1 | 42.0 |
| 41 | 328 | ZeRO-1 | 3 | 3 | 2952 | +2.5 | 69.1 | 27.1 | 42.0 |
| 42 | 336 | ZeRO-1 | 3 | 3 | 3024 | +5.0 | 69.1 | 27.1 | 42.0 |
| 43 | 344 | ZeRO-1 | 4 | 2 | 2752 | -4.4 | 83.1 | 27.1 | 56.0 |
| 44 | 352 | ZeRO-1 | 4 | 2 | 2816 | -2.2 | 83.1 | 27.1 | 56.0 |
| 45 | 360 | ZeRO-1 | 4 | 2 | 2880 | +0.0 | 83.1 | 27.1 | 56.0 |
| 46 | 368 | ZeRO-1 | 4 | 2 | 2944 | +2.2 | 83.1 | 27.1 | 56.0 |
| 47 | 376 | ZeRO-1 | 4 | 2 | 3008 | +4.4 | 83.1 | 27.1 | 56.0 |
| 48 | 384 | ZeRO-1 | 4 | 2 | 3072 | +6.7 | 83.1 | 27.1 | 56.0 |
| 49 | 392 | ZeRO-1 | 1 | 7 | 2744 | -4.7 | 41.1 | 27.1 | 14.0 |
| 50 | 400 | ZeRO-1 | 1 | 7 | 2800 | -2.8 | 41.1 | 27.1 | 14.0 |
| 51 | 408 | ZeRO-1 | 1 | 7 | 2856 | -0.8 | 41.1 | 27.1 | 14.0 |
| 52 | 416 | ZeRO-1 | 1 | 7 | 2912 | +1.1 | 41.1 | 27.1 | 14.0 |
| 53 | 424 | ZeRO-1 | 1 | 7 | 2968 | +3.1 | 41.1 | 27.1 | 14.0 |
| 54 | 432 | ZeRO-1 | 1 | 7 | 3024 | +5.0 | 41.1 | 27.1 | 14.0 |
| 55 | 440 | ZeRO-1 | 1 | 7 | 3080 | +6.9 | 41.1 | 27.1 | 14.0 |
| 56 | 448 | ZeRO-1 | 6 | 1 | 2688 | -6.7 | 111.1 | 27.1 | 84.0 |
| 57 | 456 | ZeRO-1 | 6 | 1 | 2736 | -5.0 | 111.1 | 27.1 | 84.0 |
| 58 | 464 | ZeRO-1 | 6 | 1 | 2784 | -3.3 | 111.1 | 27.1 | 84.0 |
| 59 | 472 | ZeRO-1 | 6 | 1 | 2832 | -1.7 | 111.1 | 27.1 | 84.0 |
| 60 | 480 | ZeRO-1 | 6 | 1 | 2880 | +0.0 | 111.1 | 27.1 | 84.0 |
| 61 | 488 | ZeRO-1 | 6 | 1 | 2928 | +1.7 | 111.1 | 27.1 | 84.0 |
| 62 | 496 | ZeRO-1 | 6 | 1 | 2976 | +3.3 | 111.1 | 27.1 | 84.0 |
| 63 | 504 | ZeRO-1 | 6 | 1 | 3024 | +5.0 | 111.1 | 27.1 | 84.0 |
| 64 | 512 | ZeRO-1 | 6 | 1 | 3072 | +6.7 | 111.1 | 27.1 | 84.0 |
| 65 | 520 | ZeRO-1 | 6 | 1 | 3120 | +8.3 | 111.1 | 27.1 | 84.0 |
| 66 | 528 | ZeRO-1 | 5 | 1 | 2640 | -8.3 | 97.1 | 27.1 | 70.0 |
| 67 | 536 | ZeRO-1 | 5 | 1 | 2680 | -6.9 | 97.1 | 27.1 | 70.0 |
| 68 | 544 | ZeRO-1 | 5 | 1 | 2720 | -5.6 | 97.1 | 27.1 | 70.0 |
| 69 | 552 | ZeRO-1 | 5 | 1 | 2760 | -4.2 | 97.1 | 27.1 | 70.0 |
| 70 | 560 | ZeRO-1 | 5 | 1 | 2800 | -2.8 | 97.1 | 27.1 | 70.0 |
| 71 | 568 | ZeRO-1 | 5 | 1 | 2840 | -1.4 | 97.1 | 27.1 | 70.0 |
| 72 | 576 | ZeRO-1 | 5 | 1 | 2880 | +0.0 | 97.1 | 27.1 | 70.0 |
| 73 | 584 | ZeRO-1 | 5 | 1 | 2920 | +1.4 | 97.1 | 27.1 | 70.0 |
| 74 | 592 | ZeRO-1 | 5 | 1 | 2960 | +2.8 | 97.1 | 27.1 | 70.0 |
| 75 | 600 | ZeRO-1 | 5 | 1 | 3000 | +4.2 | 97.1 | 27.1 | 70.0 |
| 76 | 608 | ZeRO-1 | 5 | 1 | 3040 | +5.6 | 97.1 | 27.1 | 70.0 |
| 77 | 616 | ZeRO-1 | 5 | 1 | 3080 | +6.9 | 97.1 | 27.1 | 70.0 |
| 78 | 624 | ZeRO-1 | 5 | 1 | 3120 | +8.3 | 97.1 | 27.1 | 70.0 |
| 79 | 632 | ZeRO-1 | 5 | 1 | 3160 | +9.7 | 97.1 | 27.1 | 70.0 |
| 80 | 640 | - | - | - | - | - | - | - | - |
| 81 | 648 | ZeRO-1 | 4 | 1 | 2592 | -10.0 | 83.1 | 27.1 | 56.0 |
| 82 | 656 | ZeRO-1 | 4 | 1 | 2624 | -8.9 | 83.1 | 27.1 | 56.0 |
| 83 | 664 | ZeRO-1 | 4 | 1 | 2656 | -7.8 | 83.1 | 27.1 | 56.0 |
| 84 | 672 | ZeRO-1 | 4 | 1 | 2688 | -6.7 | 83.1 | 27.1 | 56.0 |
| 85 | 680 | ZeRO-1 | 4 | 1 | 2720 | -5.6 | 83.1 | 27.1 | 56.0 |
| 86 | 688 | ZeRO-1 | 4 | 1 | 2752 | -4.4 | 83.1 | 27.1 | 56.0 |
| 87 | 696 | ZeRO-1 | 4 | 1 | 2784 | -3.3 | 83.1 | 27.1 | 56.0 |
| 88 | 704 | ZeRO-1 | 4 | 1 | 2816 | -2.2 | 83.1 | 27.1 | 56.0 |
| 89 | 712 | ZeRO-1 | 4 | 1 | 2848 | -1.1 | 83.1 | 27.1 | 56.0 |
| 90 | 720 | ZeRO-1 | 4 | 1 | 2880 | +0.0 | 83.1 | 27.1 | 56.0 |
| 91 | 728 | ZeRO-1 | 4 | 1 | 2912 | +1.1 | 83.1 | 27.1 | 56.0 |
| 92 | 736 | ZeRO-1 | 4 | 1 | 2944 | +2.2 | 83.1 | 27.1 | 56.0 |
| 93 | 744 | ZeRO-1 | 4 | 1 | 2976 | +3.3 | 83.0 | 27.0 | 56.0 |
| 94 | 752 | ZeRO-1 | 4 | 1 | 3008 | +4.4 | 83.0 | 27.0 | 56.0 |
| 95 | 760 | ZeRO-1 | 4 | 1 | 3040 | +5.6 | 83.0 | 27.0 | 56.0 |
| 96 | 768 | ZeRO-1 | 4 | 1 | 3072 | +6.7 | 83.0 | 27.0 | 56.0 |
| 97 | 776 | ZeRO-1 | 4 | 1 | 3104 | +7.8 | 83.0 | 27.0 | 56.0 |
| 98 | 784 | ZeRO-1 | 4 | 1 | 3136 | +8.9 | 83.0 | 27.0 | 56.0 |
| 99 | 792 | ZeRO-1 | 4 | 1 | 3168 | +10.0 | 83.0 | 27.0 | 56.0 |
| 100 | 800 | - | - | - | - | - | - | - | - |
Caveats:
- Tolerance band: The simulation uses a ±10 % tolerance band for the global batch size, but we don't know if this is acceptable in practice. It may be too wide or too narrow, depending on the model and training setup.
- Performance: The simulation doesn't directly account for the throughput implications of changing gradient accumulation steps or micro-batch size. It favours smaller ZeRO stages because they tend to be faster (especially with gradient accumulation), but the selected configuration is not necessarily the fastest one.
- Memory budget: The simulation assumes a fixed memory budget per GPU (e.g., 120 GiB for H200). This is realistic in our case because our clusters are homogeneous.
- Other parallelization strategies: The simulation only considers ZeRO-style training. It may be possible to achieve better results with other parallelization strategies (e.g., pipeline parallelism) or a combination of strategies.
Python simulation code:
"""
Memory-aware batch-reconfiguration for ZeRO-style training.
Use this **re-target a training run** when the number of GPUs changes:
* Estimate per-GPU memory usage (parameters, gradients, optimiser state,
activations) for a given *ModelSpec* + *TrainingSpec*.
* Search the space
`(ZeRO stage, micro-batch, grad-accum steps)`
for a configuration that
1. keeps the global batch size within *± tolerance* of an original run,
2. fits under a user-supplied **memory budget**,
3. prefers lower ZeRO stages when several configs tie.
"""
import dataclasses
import math
import typing
import pandas
_BYTES_IN_GIB = 1024**3
@dataclasses.dataclass
class ModelSpec:
"""
Minimal set of numbers needed for memory arithmetic.
"""
hidden_size: int
num_hidden_layers: int
num_parameters: int # (includes embeddings, biases, etc.)
dtype_bytes: int = 2 # 2 for bf16, 4 for fp32
@dataclasses.dataclass
class TrainingSpec:
zero_stage: typing.Literal[0, 1, 2, 3]
world_size: int # total data-parallel replicas (w)
micro_batch_size: int
gradient_accumulation_steps: int
sequence_length: int
# byte widths
param_bytes: int = 2 # bf16 weights
grad_bytes: int = 4 # fp32 grads
optim_bytes: int = 4 # fp32 Adam slots
optim_state_slots: int = 2 # Adam
activation_constant: float = 16.0 # empirical constant for activation footprint
@property
def global_batch_size(self) -> int:
return self.world_size * self.micro_batch_size * self.gradient_accumulation_steps
def _shard_factor(self, kind: typing.Literal["param", "grad", "optim"]) -> float:
"""
Scaling induced by ZeRO sharding.
"""
zs, w = self.zero_stage, self.world_size
if kind == "param": # Mp
return 1.0 if zs < 3 else 1 / w
if kind == "grad": # Mg
return 1.0 if zs < 2 else 1 / w
if kind == "optim": # Mo
return 1.0 if zs < 1 else 1 / w
raise ValueError(f"Unknown kind: {kind}. Must be one of 'param', 'grad', or 'optim'.")
@dataclasses.dataclass
class MemoryFootprint:
state_gib: float
activations_gib: float
@property
def total_gib(self) -> float:
"""
Returns the total memory footprint of parameters, gradients, optimizer state, and activations.
"""
return self.state_gib + self.activations_gib
@classmethod
def from_specs(
cls,
model_spec: ModelSpec,
training_spec: TrainingSpec,
) -> "MemoryFootprint":
return cls(
state_gib=cls._compute_state_gib(model_spec, training_spec),
activations_gib=cls._compute_activations_gib(model_spec, training_spec),
)
@staticmethod
def _compute_state_gib(model_spec: ModelSpec, training_spec: TrainingSpec) -> float:
"""
Returns the memory footprint of parameters, gradients, and optimizer state.
"""
Mp = model_spec.num_parameters * training_spec.param_bytes * training_spec._shard_factor("param")
Mg = model_spec.num_parameters * training_spec.grad_bytes * training_spec._shard_factor("grad")
Mo = (
model_spec.num_parameters
* training_spec.optim_bytes
* training_spec.optim_state_slots
* training_spec._shard_factor("optim")
)
return (Mp + Mg + Mo) / _BYTES_IN_GIB
@staticmethod
def _compute_activations_gib(model_spec: ModelSpec, training_spec: TrainingSpec) -> float:
"""
Returns the memory footprint of forward/backward activations.
This is a rough estimate based on the number of layers, hidden size, and sequence length.
The formula is based on the assumption that each layer has a constant number of activations
and that the activations are stored in a contiguous block of memory.
"""
bytes_act = (
training_spec.activation_constant
* training_spec.micro_batch_size
* training_spec.sequence_length
* model_spec.hidden_size
* model_spec.num_hidden_layers
* model_spec.dtype_bytes
)
return bytes_act / _BYTES_IN_GIB
def reconfigure(
*,
model_spec: ModelSpec,
old_training_spec: TrainingSpec,
new_world_size: int,
tolerance: float,
mem_limit_gib: float,
max_micro_batch_size: int,
zero_stage_candidates: typing.Iterable[int] = (1, 2, 3),
) -> tuple[TrainingSpec, MemoryFootprint] | None:
"""
Return the feasible (spec, mem) with
• minimal |Δ batch|
• tie-break by lower ZeRO stage
or None if no solution exists.
"""
target = old_training_spec.global_batch_size
lower = math.floor(target * (1 - tolerance))
upper = math.ceil(target * (1 + tolerance))
best_key: tuple[int, int] | None = None # (abs_dev, zero_stage)
best_pair: tuple[TrainingSpec, MemoryFootprint] | None = None
max_micro = min(max_micro_batch_size, upper // new_world_size)
for zs in zero_stage_candidates: # 1 → 2 → 3
for micro in range(max_micro, 0, -1): # big → small
min_accum = math.ceil(lower / (new_world_size * micro))
max_accum = math.floor(upper / (new_world_size * micro))
if min_accum > max_accum or max_accum == 0:
continue
for accum in range(min_accum, max_accum + 1):
spec = TrainingSpec(
zero_stage=zs,
world_size=new_world_size,
micro_batch_size=micro,
gradient_accumulation_steps=accum,
sequence_length=old_training_spec.sequence_length,
)
mem = MemoryFootprint.from_specs(model_spec, spec)
if mem.total_gib > mem_limit_gib:
continue
dev = abs(spec.global_batch_size - target)
key = (dev, zs) # **comparable tuple**
if best_key is None or key < best_key:
best_key = key
best_pair = (spec, mem)
return best_pair
apriel_model_spec = ModelSpec(
hidden_size=4096,
num_hidden_layers=28,
num_parameters=4_832_071_680,
)
apriel_training_spec = TrainingSpec(
zero_stage=1,
world_size=480,
micro_batch_size=6,
gradient_accumulation_steps=1,
sequence_length=4096,
)
# mem = MemoryFootprint.from_specs(apriel_model_spec, apriel_training_spec)
# print(
# f"ZeRO-1 (w=480): "
# f"total {mem.total_gib:6.2f} GiB | "
# f"{mem.state_gib:6.2f} GiB params/grad/opt | "
# f"{mem.activations_gib:6.2f} GiB activations"
# )
tolerance = 0.10 # ±10 % global batch deviation allowed compared to the original
mem_limit_gib = 120.0 # budget per H200 (141 GiB total, leave some headroom)
max_micro_batch_size = 32 # search upper bound for micro‑batch
records = []
for nodes in range(1, 101):
new_world_size = nodes * 8
res = reconfigure(
model_spec=apriel_model_spec,
old_training_spec=apriel_training_spec,
new_world_size=new_world_size,
tolerance=tolerance,
mem_limit_gib=mem_limit_gib,
max_micro_batch_size=max_micro_batch_size,
)
if res:
new_training_spec, mem = res
deviation = (
(new_training_spec.global_batch_size - apriel_training_spec.global_batch_size)
* 100
/ apriel_training_spec.global_batch_size
)
records.append(
{
"Nodes": nodes,
"GPUs": new_world_size,
"ZeRO": f"ZeRO-{new_training_spec.zero_stage}",
"Micro-batch": new_training_spec.micro_batch_size,
"Grad. accum.": new_training_spec.gradient_accumulation_steps,
"Global batch": new_training_spec.global_batch_size,
"Δ global batch (%)": f"{deviation:+.1f}",
"Memory (total) GiB": f"{mem.total_gib:.1f}",
"Memory (params/grad/opt) GiB": f"{mem.state_gib:.1f}",
"Memory (activations) GiB": f"{mem.activations_gib:.1f}",
}
)
else:
records.append(
{
"Nodes": nodes,
"GPUs": new_world_size,
"ZeRO": "-",
"Micro-batch": "-",
"Grad. accum.": "-",
"Global batch": "-",
"Δ global batch (%)": "-",
"Memory (total) GiB": "-",
"Memory (params/grad/opt) GiB": "-",
"Memory (activations) GiB": "-",
}
)
df = pandas.DataFrame(records)
with pandas.option_context("display.max_rows", None, "display.max_columns", None):
print(df.to_markdown(index=False))
I've also checked how we can handle pre-emption cleanly. Good news: every batch system I know (K8s, SLURM, NGC, etc.) sends a SIGTERM into the container before it reclaims the node. torchrun just forwards that signal to every rank, then escalates to SIGKILL after ~30 s. (Timeout for process termination upon receiving sigterm should be a ...). That's enough time to flush a checkpoint if we react immediately.
For this to work we need to launch in elastic mode:
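```shell
# --nnodes=MIN:MAX enables elastic mode (here 1 to 100 nodes).
# --no-python runs the fast-llm entry point directly instead of via `python`.
torchrun \
  --nnodes=1:100 \
  --nproc-per-node=gpu \
  --rdzv-backend=c10d \
  --rdzv-endpoint=$MASTER_ADDR:29400 \
  --max-restarts=1000 \
  --no-python \
  fast-llm ...
```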
torchrun will kill-and-respawn the whole worker set whenever membership changes (node lost or added), see torchrun (Elastic Launch).
Inside our code, we do something like this:
```python
import signal
import sys

import torch.distributed as dist


def _handler(signum, frame):
    save_checkpoint()  # provided by the training code; must finish in < 30 s
    dist.destroy_process_group()
    sys.exit(0)  # clean exit -> torchrun restarts


signal.signal(signal.SIGTERM, _handler)
```
And at start-up, we:
- load latest checkpoint,
- rebuild everything with new world size,
- resume training.
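On start-up, torchrun exposes the new topology to each worker through environment variables such as `WORLD_SIZE`, `RANK`, and `TORCHELASTIC_RESTART_COUNT`. A minimal sketch of the first start-up step, assuming a hypothetical checkpoint layout of one `step_<N>` directory per checkpoint (not necessarily Fast-LLM's actual layout); the helper names are also hypothetical:

```python
import os
import re
from pathlib import Path
from typing import Optional

# Hypothetical layout: <root>/step_<N>/ holds the checkpoint written at step N.
_STEP_DIR = re.compile(r"^step_(\d+)$")


def latest_checkpoint(root: Path) -> Optional[Path]:
    """Return the checkpoint directory with the highest step number, if any."""
    if not root.is_dir():
        return None
    best_step, best_path = -1, None
    for child in root.iterdir():
        match = _STEP_DIR.match(child.name)
        if child.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), child
    return best_path


def elastic_context() -> dict:
    """Read the topology that torchrun sets for this (re)spawned worker."""
    return {
        "world_size": int(os.environ.get("WORLD_SIZE", "1")),
        "rank": int(os.environ.get("RANK", "0")),
        "restart_count": int(os.environ.get("TORCHELASTIC_RESTART_COUNT", "0")),
    }
```

The world size from `elastic_context()` would then feed the batch-size and memory reconfiguration before the model and optimizer state are loaded from `latest_checkpoint(...)`.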
If 30 s is tight, we will need asynchronous checkpointing (spawn a background thread that streams shards). That's something we had on the roadmap anyway at some point.
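A minimal sketch of the asynchronous variant, assuming the state has already been copied to host memory (here a plain dict). The payload is serialized on the calling thread, standing in for the device-to-host copy, and only the file write runs in a background thread so training can continue. `AsyncCheckpointer` is hypothetical; a real version would write per-rank shards and must be joined before the process exits.

```python
import pickle
import threading
from pathlib import Path
from typing import Optional


class AsyncCheckpointer:
    """Write a snapshot of the training state in a background thread."""

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None

    def save(self, state: dict, path: Path) -> None:
        self.wait()  # never run two writes concurrently
        # Serialize now so later in-place updates to `state` don't leak into the file.
        payload = pickle.dumps(state)
        self._thread = threading.Thread(target=self._write, args=(payload, path))
        self._thread.start()

    @staticmethod
    def _write(payload: bytes, path: Path) -> None:
        tmp = path.parent / (path.name + ".tmp")
        tmp.write_bytes(payload)
        tmp.replace(path)  # rename so readers never see a partially written file

    def wait(self) -> None:
        """Block until the pending write (if any) has finished."""
        if self._thread is not None:
            self._thread.join()
            self._thread = None
```

On `SIGTERM`, the handler would call `wait()` before exiting so that the last snapshot is fully on disk.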
That's really all that's required for a restart-based elastic MVP: catch SIGTERM, dump state fast, rely on torchrun --nnodes=min:max to relaunch on the new set of nodes.
As a quick follow-up: I went back to #26 and also reviewed the Apriel-5B training logs to check checkpointing speeds under realistic conditions.

During Apriel-5B training, all ranks were saving to a shared cluster filesystem. The checkpoint payload per rank (weights + optimizer shard) was about 9.4 GiB. Saving the checkpoint completed in about 2 seconds across all ranks, implying an effective write bandwidth of roughly 4.7 GiB/s per rank, well within practical limits. Even assuming a much lower bandwidth of ~1 GiB/s per GPU, each shard would flush in under 10 seconds. The average step time was ~1.8 seconds, so even accounting for finishing the current step before checkpointing, we stay well within the ~30 second grace window after receiving SIGTERM.

Even with gradient accumulation, we can safely trigger checkpointing after any completed forward-backward pass (i.e., after any micro-batch), without waiting for a full optimizer step. Thus, longer accumulation (e.g., across many micro-batches) does not block or meaningfully delay checkpointing.

In short, checkpointing performance is not a blocker for the elastic MVP. Serialization overhead exists (as noted in #26), but in practice, Fast-LLM already saves fast enough to safely tolerate preemptions without requiring asynchronous checkpointing. We can proceed with synchronous checkpointing for the restart-based elasticity MVP. Async checkpointing remains a worthwhile future enhancement for larger models or slower filesystems, but is not a requirement for the MVP.
Thanks for the detailed analysis; it elaborates on my suggestions in great detail. Some comments:
- We can easily do a detailed memory analysis for specific models, but it's harder in the general case. To keep things simple, I suggest relying on a few manually set parameters: a minimum node count (1 for small models) and a maximum micro-batch size (6 in the example above) should be enough for most cases, and Fast-LLM can take it from there.
- The 10% threshold should be OK as long as we keep `micro_batch_size * grad_acc_steps >= some_threshold` (aka `batch_size/max_micro_batch_size/max_nodes >= some_threshold`), around 5 according to the analysis above (or, I guess, `(2 * tolerance)**-1`), assuming we only adjust these two parameters. Below that we might have to ignore the threshold or keep some nodes idle to stay within range. In following steps we could also use things like sequence-data-parallelism or micro-sequences to increase the effective `micro_batch_size`.
- I knew about the SIGKILL on our infrastructure (I think it gives 1 minute); good to know we have it elsewhere. 30 s should indeed almost always be enough. The only potential issue is with very large models trained with few nodes, but that's a mostly useless case anyway because bigger models need more nodes for other reasons.
- This all looks relatively easy to implement on the Fast-LLM side (torch elastic handles most of it). The reason we haven't done it already is the lack of support from the GPU allocator mechanism, which prevents us from making use of it. Did anything change on that side?
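The `(2 * tolerance)**-1` threshold in the comment above follows from the fact that, when only the micro-batch size and accumulation steps change, the global batch size must be a multiple of the world size `w`. The nearest achievable value can therefore be off by up to `w / 2` samples, i.e. a relative error of `w / (2 * batch_size)`. A quick check, with hypothetical function names, using integer arithmetic to avoid rounding at the band edges:

```python
def band_is_reachable(target: int, world_size: int, tolerance_pct: int = 10) -> bool:
    """True if some multiple of world_size lies within ±tolerance_pct % of target."""
    lower = -(-target * (100 - tolerance_pct) // 100)  # ceil of the lower bound
    upper = target * (100 + tolerance_pct) // 100  # floor of the upper bound
    first_multiple = -(-lower // world_size) * world_size  # smallest multiple >= lower
    return first_multiple <= upper


def min_samples_per_gpu(tolerance_pct: int = 10) -> int:
    """Sufficient micro_batch_size * grad_acc_steps, i.e. ceil((2 * tolerance) ** -1)."""
    return -(-100 // (2 * tolerance_pct))
```

For Apriel-5B at 80 nodes (640 GPUs), `2880 / 640 = 4.5` is below 5, and indeed no multiple of 640 falls in 2,592 to 3,168, matching the gap in the simulation table.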
Thanks @jlamypoirier, agreed. We will encode the logic with the proposed threshold-based tolerance, default micro-batch and gradient-accumulation-step caps, and fallback strategies if limits are exceeded. We will also make sure this degrades gracefully even under awkward configs. I'd like to keep SDP and other extensions out of scope for now, though. Elastic support with the current allocator mechanism remains limited, but we are exploring other runtime contexts where the proposed feature will be directly usable. Since torchrun handles the orchestration and SIGTERM checkpointing is easy to implement, we can go ahead and prototype this end-to-end.