
Explain quorum

rualark opened this issue 7 months ago · 14 comments

There are a lot of explanations about quorum in the design doc and in the code comments, but I did not find a place that explains a quorum "of what" is meant. I assume it means the minimum number of healthy ranks (healthy meaning they reply to the lighthouse, or reply with a non-error code?).

I assume this means that min_replica_size is this minimum number, and that it must be no higher than num_replicas (which is the initial number of ranks in the replication group)?

The quorum ID seems to play a significant role in torchft - IIUC it is used to detect changes and to isolate recovered nodes from the previous quorum?

It may be beneficial to add answers to these questions to the design doc or README.

rualark · Apr 15 '25

Hi, I attached a diagram that I made to help me understand the quorum. It writes out the conditions for quorum in pseudocode.

[Diagram: conditions for quorum, written out in pseudocode]

In terms of quorum_id, here are the relevant code snippets:

In the lighthouse, the quorum_id is incremented when the quorum changes:

            // only increment quorum ID if something about the quorum
            // changed (members/addresses/etc)
            if state.prev_quorum.is_none()
                || quorum_changed(
                    &participants,
                    &state.prev_quorum.as_ref().unwrap().participants,
                )
            {
                state.quorum_id += 1;
                info!(
                    "Detected quorum change, bumping quorum_id to {}",
                    state.quorum_id
                );
            }

Note, however, that there is a subtlety to the quorum. There is a FastQuorum optimization that returns the quorum immediately if all participants of the previous quorum have joined the current one. This reduces the quorum overhead and avoids waiting for the join timeout on processes that are heartbeating but not sending quorum requests (potentially because the heartbeat thread is alive while the quorum thread is dead). The second point was pointed out to me by @d4l3k.
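
For illustration, here is the quorum decision (membership-change and fast-quorum conditions) in rough Python-style pseudocode. This is not the actual Rust implementation; the function and variable names (should_form_quorum, join_deadline_passed, etc.) are mine.

    # Rough pseudocode for the lighthouse's quorum decision, based on the
    # behavior described above. NOT the actual Rust implementation; names
    # here are illustrative.
    def should_form_quorum(participants, prev_quorum, min_replica_size, join_deadline_passed):
        if len(participants) < min_replica_size:
            return False  # never form a quorum below the configured minimum
        if prev_quorum is not None and set(prev_quorum) <= set(participants):
            # Fast quorum: every member of the previous quorum has rejoined,
            # so return immediately instead of waiting for the join timeout.
            return True
        # Otherwise wait for stragglers until the join timeout expires.
        return join_deadline_passed

    def next_quorum_id(participants, prev_quorum, quorum_id):
        # Mirrors the snippet above: only bump quorum_id if membership changed.
        if prev_quorum is None or set(participants) != set(prev_quorum):
            return quorum_id + 1
        return quorum_id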

When the quorum_id changes, manager.py uses a new store prefix to open up a new logical key space in which the process group can reconfigure itself.

Thus, the quorum_id is used to signify whether process group reconfiguration is needed by the processes in the quorum.

        if quorum_id != self._quorum_id:
            store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"

            self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
            # We use the replica rank and world as we want all replicas in the PG.
            # TODO: handle configure errors
            self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
            self._quorum_id = quorum_id

WarrenZhu050413 · Apr 23 '25

Thank you, Warren. I think quorum should be explained properly in the documentation (README or design doc). Right now there are pieces of the explanation scattered across different places, and it is difficult to construct a full picture from them.

rualark · Apr 24 '25

@d4l3k I would be happy to take a stab at this.

I think what may make the quorum protocol clearer (at least to a certain audience) is making explicit how it relates to the quorum-based commit protocols used in distributed databases.

My current best understanding of torchFT's distributed computing model is the following:

TorchFT's Commit Protocol

  1. We start a transaction at the beginning of the forward pass.

  2. We decide to commit the transaction if manager.allreduce succeeds, live recovery succeeds, and should_commit() returns True.

  • We are guaranteed that if any participant has an error, then manager.allreduce will not succeed in its DDP group due to the following logic:
        if self.errored():
            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
            fut.set_result(tensor)
            return fut
  • We are further guaranteed that others in their FSDP group will not succeed because should_commit() will not return true.
  3. We commit the transaction by calling optimizer.step().

Upon failure, we abort the transaction and restart our iteration (from the forward pass at 1).
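
To make the transaction framing concrete, here is a minimal per-step sketch. It assumes a torchft Manager exposing start_quorum(), allreduce(), and should_commit(); the names follow the snippets quoted in this thread, so treat it as illustrative rather than the exact torchft API.

    # Illustrative only: a per-step "transaction" using a torchft-style Manager.
    def train_step(manager, model, optimizer, criterion, batch, labels):
        # 1. Start the transaction: join/confirm the quorum for this step.
        manager.start_quorum()

        optimizer.zero_grad()
        loss = criterion(model(batch), labels)
        loss.backward()

        # 2. All-reduce gradients through the manager. If this replica has
        #    errored or is healing, the manager keeps it from poisoning the
        #    other replicas' results (see the allreduce snippet above).
        futs = [manager.allreduce(p.grad) for p in model.parameters() if p.grad is not None]
        for fut in futs:
            fut.wait()

        # 3. Commit only if every rank in this replica group succeeded;
        #    otherwise abort and retry the iteration.
        if manager.should_commit():
            optimizer.step()
            return loss.item()
        return None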

Explaining Quorum

What took me a while to understand at the beginning is that the quorum with the lighthouse does not determine whether we are committing or aborting the transaction. Instead, it makes it possible to recover from faults and also helps guarantee the correctness of the allreduce result by doing the following:

Fault Recovery: a) The quorum reconfigures the process group to exclude potentially failed replica groups.

Correctness: b) If we are healing, the quorum ensures that we do not contribute to the allreduce result (by calling tensor.zero_()). c) The all-reduced gradient is normalized correctly by the quorum size.
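
As a rough sketch of (b) and (c), assuming a plain torch.distributed process group and an is_healing flag (a simplification of what the manager does, not the actual manager.py code):

    # Simplified sketch of (b) and (c): a healing replica zeroes its
    # contribution, and the sum is normalized by the number of replicas
    # that actually participated in this quorum.
    import torch
    import torch.distributed as dist

    def fault_tolerant_allreduce(grad: torch.Tensor, is_healing: bool, num_participants: int) -> torch.Tensor:
        if is_healing:
            grad.zero_()              # (b) don't contribute stale gradients
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad /= num_participants      # (c) normalize by quorum size, not a fixed world size
        return grad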

Other things that could make quorum easier to understand

  • It took me a while to realize that the result of should_commit() is not determined by the Lighthouse but by the ManagerServer. I wonder how this could be made clearer.
  • The fact that the recovery logic is folded into the quorum computation logic in the manager.rs code and in _async_quorum in manager.py also made it a bit difficult for me to parse the purpose of the quorum. The quorum seems to have several separate tasks (though they share a lot of common work, so it makes sense to do them together): doing live checkpoint recovery, determining the process group members, and ensuring the correctness of the allreduce.

WarrenZhu050413 · Apr 26 '25

Here is a potentially clarifying diagram that draws out the dependencies between the different streams (here I don't mean CUDA streams, but generic sequential computation processes).

[Diagram: dependencies between the different streams in torchFT]

WarrenZhu050413 · Apr 29 '25

@WarrenZhu050413 that explanation + diagram makes sense to me. The documentation has been a bit sparse, so it would be great to have some more in-depth docs beyond just the design doc. It's a good callout that should_commit is local to the replica group whereas quorum is cross-group.

I have heard some feedback about the recovery logic; maybe we should move it into Python so it's easier for folks to find/discover, and to disentangle quorum from recovery.

d4l3k · Apr 29 '25

@d4l3k

Sounds great. Should I add a /concepts folder in the docs? I am considering having one file explaining the different concurrent streams in torchFT, and another explaining the quorum, similar to the writeup above. Does this make sense?

I have two pieces of feedback about live checkpoint recovery:

  1. Is there any previous research that applies similar techniques, or that compares other techniques with live checkpoint recovery?

One question I had when first encountering torchFT is how live checkpoint recovery relates to previous work.

The one work I found is SWIFT: Expedited Failure Recovery for Large-scale DNN Training, which "exploits the replicas of the model state in data parallelism for failure recovery". However, instead of waiting for all-reduce to finish before executing optimizer.step(), it reverses the step.

  2. How do we decouple the quorum and checkpoint logic?

It seems difficult to move the recovery logic into Python since it could require sending the recovery assignments down to the individual clients. I wonder whether one could expand the quorum RPC to a more general controller RPC, within which the quorum logic and live checkpoint logic are implemented (with the live checkpoint logic waiting for the quorum logic).

WarrenZhu050413 · Apr 30 '25

+1 to this. It would also be good to document what kinds of guarantees/invariants we're trying to ensure, and to explain why the protocol is correct and ensures they are not violated.

tushar00jain · May 03 '25

> +1 to this. It would also be good to document what kinds of guarantees/invariants we're trying to ensure, and to explain why the protocol is correct and ensures they are not violated.

What would be some of these guarantees/invariants?

I listed out some below. Would love to hear what I missed.

Collective Communications Guarantee

  1. Only the max_step (i.e. "up to date") participants contribute to the gradient all-reduce
  2. Gradient all-reduce is normalized correctly by the number of contributing participants

Commit guarantee

  1. All ranks in a replica group will do optimizer.step() together.

Dataloader/Global Batch Size guarantee

  1. torchFT currently does not guarantee the global batch size, or that all available data is loaded. See #186.

Derivative Invariants:

  1. All ranks in a replica group should have the same step() (Currently TODO in the code: https://github.com/pytorch/torchft/blob/main/src/manager.rs#L270)
  2. All participants in gradient all-reduce have the same step()
  3. All training processes with the same group_rank and step() have the same model weights and optimizer state
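
As a purely illustrative debugging aid for invariant 3 (not part of torchft), one could log a checksum of the model state at each step and compare, out-of-band, the logs of replicas that share a group_rank and step:

    # Illustrative helper: checksum the model state so that replicas with the
    # same group_rank and step can be compared offline. Not part of torchft.
    import hashlib
    import torch

    def param_checksum(model: torch.nn.Module) -> str:
        h = hashlib.sha256()
        for name, tensor in sorted(model.state_dict().items()):
            h.update(name.encode())
            # cast to float64 so dtypes numpy can't represent (e.g. bfloat16) still hash
            h.update(tensor.detach().cpu().to(torch.float64).flatten().numpy().tobytes())
        return h.hexdigest()[:16]

    # e.g. after optimizer.step():
    # print(f"group_rank={group_rank} step={step} params={param_checksum(model)}")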

WarrenZhu050413 · May 05 '25

> What would be some of these guarantees/invariants?

I'll have to think about this a bit 😅 The main invariant we'd want to prove is that all replicas have the same parameters (or zero) after every step. Any other invariant that's necessary for the main invariant could be useful. What's also important are the assumptions required for the invariant to hold, e.g. we need to have at least 1 replica in common between 2 consecutive steps, otherwise 2 groups of replicas can flip-flop their health and end up with different parameters.

tushar00jain · May 09 '25

> All replicas have the same parameters (or zero) after every step.

@tushar00jain I gave writing out the derivation a shot.

It seems to me that this invariant is derived from the following four guarantees:

  1. Synchronized Steps: All participating processes are at the same step before optimizer.step() is called.
  2. Synchronized State: All processes at the same step have the same model and optimizer states.
  3. Consistent Updates: Each optimizer.step() updates a process identically if they are at the same step.
  4. Identical Initial State: All processes begin with the same model state at step 0.

If these four guarantees hold, the main invariant is upheld: starting from an identical state (4), all participating processes reach the same step before an update (1) and have the same model states (2). Since the update operation affects each process identically (3), all participating processes will apply the same updates from the same starting state, resulting in identical parameters.
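
Written out as a very informal induction, using my own notation rather than anything from the torchft docs (W_r(t) is replica r's parameters after step t, g(t) the all-reduced gradient, and U the deterministic optimizer update):

    % Informal induction sketch (my own notation, not from the torchft docs).
    % W_r(t): parameters of replica r after step t; g(t): the all-reduced
    % gradient at step t; U: the deterministic optimizer update rule.
    \begin{align*}
      \text{Base case, guarantee (4):}\quad & W_r(0) = W_0 \ \text{for all replicas } r \\
      \text{Hypothesis, guarantees (1),(2):}\quad & W_r(t) = W(t) \ \text{for all participants } r \\
      \text{Same gradients (all-reduce):}\quad & g_r(t) = g(t) \\
      \text{Same update, guarantee (3):}\quad & W_r(t+1) = U\big(W(t),\, g(t)\big) = W(t+1)
    \end{align*}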

Current Implementation of Guarantees in torchFT:

  • (1) Synchronized Steps: This is currently guaranteed by the quorum process. Within this process, all processes that have reached the max_step perform live checkpoint recovery for all other processes, bringing them to a consistent step (since step info. is contained in the transported state_dict).

  • (2) Synchronized State: This is guaranteed by the live checkpoint recovery process which transports the optimizer and model states.

  • (3) Consistent Updates: This is guaranteed because all participating processes are part of the same process group. They execute allreduce on the same gradient tensors, so optimizer.step() applies the same modifications given identical model and optimizer states (which guarantees (1) and (2) provide) and identical gradients.

  • (4) Identical Initial State: This is achieved by first initializing one replica group and then propagating its state to other groups as they join. Alternatively, a common initialization method with a shared seed can be employed.

Note:

Currently none of these guarantees hold if silent data corruption happens.

WarrenZhu050413 · May 09 '25

@WarrenZhu050413 That looks good. I was more curious about the assumptions, e.g. this seems like a strong statement:

> Currently none of these guarantees hold if silent data corruption happens.

Particularly on:

> Synchronized Steps

This may not hold and we have to declare what we're assuming. Here's a simple example from @d4l3k. Assume 2 replicas, a and b.

  1. step 10: a, b commit
  2. a hangs
  3. step 11: b commits
  4. b hangs, a recovers
  5. step 11: a commits

I think it works if we assume there's at least 1 common replica between each consecutive step. This will change the proof of Synchronized Steps. Or maybe there's a weaker assumption we could make, that may or may not involve changing the protocol.
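
To make the flip-flop concrete, here is a toy simulation of the example above (illustrative only, no torchft code): with no replica in common between the two step-11 quorums, a and b both end up at step 11 with different parameters.

    # Toy simulation of the flip-flop counter-example (illustrative only).
    def sgd_step(w, grads, lr=0.5):
        # Quorum-normalized all-reduce (average) followed by a plain SGD update.
        return w - lr * sum(grads) / len(grads)

    # Replicas a and b see different data, so their local gradients differ.
    grad = {"a": 1.0, "b": 3.0}
    # State after step 10, where both committed together:
    state = {"a": {"step": 10, "w": 0.0}, "b": {"step": 10, "w": 0.0}}

    # a hangs; step 11 runs with quorum = {b}:
    state["b"]["w"] = sgd_step(state["b"]["w"], [grad["b"]])
    state["b"]["step"] += 1

    # b hangs, a recovers; quorum = {a} has no member ahead of a to heal from,
    # so a also runs "step 11", but only on its own data:
    state["a"]["w"] = sgd_step(state["a"]["w"], [grad["a"]])
    state["a"]["step"] += 1

    print(state)
    # {'a': {'step': 11, 'w': -0.5}, 'b': {'step': 11, 'w': -1.5}}
    # Both replicas report step 11, but their parameters have diverged.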

tushar00jain · May 12 '25

@tushar00jain thanks for the feedback!

Could you elaborate on the "changing the proof" part? I didn't quite get how the example and Synchronized Steps interact.

WarrenZhu050413 · May 17 '25

@WarrenZhu050413

With the current implementation, the example I mentioned is a counter-example to Synchronized Steps, i.e. Synchronized Steps doesn't hold: some replicas that are not participating in the quorum would not be brought to a consistent state.

So, we add the assumption that there's at least 1 common replica between each pair of consecutive steps. We update the proof as follows:

This is currently guaranteed by the quorum process. Within this process, all processes that have reached max_step perform live checkpoint recovery for all other participants, bringing them to a consistent step (since the step information is contained in the transported state_dict). If another membership is formed afterwards, it must include at least 1 participant from this set of participants. This common participant will be at the highest step and will perform recovery for all participants in the new membership, thereby bringing them to the same consistent state as the membership at the previous step.

Now, the example I mentioned is not valid since it doesn't adhere to the assumption.

tushar00jain · May 27 '25

@tushar00jain

Yes! This is totally right. I hadn't thought of this point before.

WarrenZhu050413 avatar Jun 01 '25 15:06 WarrenZhu050413