pytorch-lightning
[Fabric Lightning] Named barriers
Description & Motivation
To prevent ranks from losing alignment due to user error, it would be beneficial for Lightning to offer named barriers, which let processes move forward only once all of them have reached a barrier with the same name.
Pitch
For example:

    if fabric.global_rank == 0:
        fabric.barrier("rank_0")
    else:
        fabric.barrier("not_rank_0")

With named barriers, this code would fail, and upon timeout each rank would raise an error naming the barrier at which it is held up.
Without names, by contrast, a logic error can send the ranks down different paths where they reach some other barrier, which releases them all and lets the whole flow continue out of alignment.
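The proposed behaviour can be sketched without a cluster by letting threads stand in for ranks. This is a minimal sketch under stated assumptions: NamedBarrier, BarrierNameMismatch and run_ranks are hypothetical helpers for illustration, not part of Lightning or torch.distributed. Each rank publishes the name it is waiting at, and once every rank has arrived they all compare names and raise if they differ.

```python
import threading
from typing import Callable, List, Optional


class BarrierNameMismatch(RuntimeError):
    """Hypothetical error raised on every rank when barrier names differ."""


class NamedBarrier:
    """Hypothetical thread-based stand-in for a named barrier across ranks."""

    def __init__(self, world_size: int, timeout: Optional[float] = None) -> None:
        self._names: List[Optional[str]] = [None] * world_size
        # threading.Barrier raises BrokenBarrierError if `timeout` expires,
        # i.e. when some rank never reaches any barrier at all.
        self._barrier = threading.Barrier(world_size, timeout=timeout)

    def wait(self, rank: int, name: str) -> None:
        self._names[rank] = name
        self._barrier.wait()       # every rank has published its name
        names = list(self._names)  # snapshot before anyone can overwrite it
        self._barrier.wait()       # every rank has taken its snapshot
        if len(set(names)) > 1:
            raise BarrierNameMismatch(
                f"rank {rank} is held at barrier {name!r}; ranks reached {names}"
            )


def run_ranks(
    world_size: int, body: Callable[[int, NamedBarrier], None]
) -> List[Optional[Exception]]:
    """Run body(rank, barrier) on one thread per rank and collect any errors."""
    barrier = NamedBarrier(world_size, timeout=5.0)
    errors: List[Optional[Exception]] = [None] * world_size

    def target(rank: int) -> None:
        try:
            body(rank, barrier)
        except Exception as exc:
            errors[rank] = exc

    threads = [threading.Thread(target=target, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return errors


def pitch_example(rank: int, barrier: NamedBarrier) -> None:
    # Mirrors the example above: rank 0 and the other ranks use different names.
    if rank == 0:
        barrier.wait(rank, "rank_0")
    else:
        barrier.wait(rank, "not_rank_0")


errors = run_ranks(2, pitch_example)
for rank, err in enumerate(errors):
    print(rank, type(err).__name__, err)
```

One design choice differs from the pitch: this sketch reports the mismatch as soon as all ranks have arrived instead of waiting for the timeout. The timeout then only covers ranks that never reach a barrier. In a real multi-process setting the name exchange would have to go through a collective rather than shared memory.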
A mistake that will likely recur involves fabric.save. It is not obvious to new users (who don't dig into the documentation) that it must be called on all ranks, because it makes its own internal barrier call.
A typical mistake would be to write:

    if fabric.global_rank == 0:
        fabric.save(...)
    fabric.barrier()
    do_training_stuff()
    fabric.barrier()

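In this case, rank 0 falls one barrier behind the other ranks because it makes an extra barrier call inside fabric.save. Nothing fails at that point, though: each of its barriers is silently matched with a different barrier on the other ranks.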
If fabric.save called fabric.barrier("save") internally, the program above would instead exit with an error reporting the alignment issue.
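The difference can be modelled without any distributed setup by describing each rank as the ordered list of barrier names it calls. In this simplified sketch, first_misalignment is a hypothetical helper, None stands for an unnamed fabric.barrier() call, and the barrier inside fabric.save is assumed to be named "save" as proposed. Timing is ignored: only the order of barrier calls matters.

```python
from itertools import zip_longest
from typing import List, Optional, Sequence

# Hypothetical model: each rank is the ordered sequence of barrier names it calls.
# None stands for an unnamed fabric.barrier() call.
Trace = Sequence[Optional[str]]

_NO_CALL = object()  # marks a rank that has no further barrier calls


def first_misalignment(traces: List[Trace]) -> Optional[int]:
    """Return the index of the first barrier call on which the ranks disagree.

    A rank that has already run out of barrier calls also counts as a
    disagreement: in a real job the other ranks would hang there until timeout.
    Returns None when every barrier call lines up across ranks.
    """
    for index, names in enumerate(zip_longest(*traces, fillvalue=_NO_CALL)):
        if len(set(names)) > 1:
            return index
    return None


# The buggy loop above with the barrier inside fabric.save left unnamed:
# rank 0 makes three barrier calls, every other rank makes only two.
unnamed = [[None, None, None], [None, None]]

# The same loop if fabric.save called fabric.barrier("save") internally.
named = [["save", None, None], [None, None]]

print(first_misalignment(unnamed))  # 2: only noticed at the end, as a hang
print(first_misalignment(named))    # 0: caught at the very first barrier
```

With unnamed barriers, the first two calls pair up silently even though they belong to different places in the code. Naming the barrier inside fabric.save moves the detection to the first barrier, before any training work runs out of step.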
Alternatives
No response
Additional context
https://github.com/Lightning-AI/pytorch-lightning/issues/19780
cc @borda @awaelchli