torchft
Added proactive heartbeat timeout failure propagation (#164) (#188)
Overview
This PR improves the failure detection speed of torchFT through proactive failure recovery. The Manager now listens to Lighthouse failure notifications and aborts hanging collectives immediately instead of waiting for NCCL/Gloo timeouts.
Basic demonstration
You can experiment with proactive failure recovery mode by:
export TORCHFT_PROACTIVE_RECOVERY=1
With this enabled, the manager will listen to the Lighthouse server for heartbeat failures of other replica groups and break from a hanging allreduce.
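The flag can also be set from Python before the Manager is created. The snippet below is only an illustration: the environment variable comes from this PR, while the `proactive_recovery=True` keyword shown in the comment is an assumption about the exact constructor argument name mentioned later in this description.

```python
import os

# Equivalent to the shell export above; the Manager checks this environment
# variable when proactive recovery support is initialized.
os.environ["TORCHFT_PROACTIVE_RECOVERY"] = "1"

# Per this PR, the behavior can also be enabled via a Manager constructor
# argument; the keyword name below is assumed from the PR description.
# manager = Manager(..., proactive_recovery=True)
```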
You can test this out by running train_ddp_proactive.py
On shell 1 (one replica group starts initial training):
export REPLICA_GROUP_ID=0
export NUM_REPLICA_GROUPS=2
export TORCHFT_PROACTIVE_RECOVERY=1
CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
On shell 2 (a second replica group joins):
export REPLICA_GROUP_ID=1
export NUM_REPLICA_GROUPS=2
export TORCHFT_PROACTIVE_RECOVERY=1
CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
You should observe that the process with replica group id 1 will exit early, and the process with replica group id 0 will quickly resume training. If the same script is run after setting export TORCHFT_PROACTIVE_RECOVERY=0, you should observe that the process with replica group id 1 will hang for dozens of seconds before continuing. On shutdown, the manager logs for replica group 0 show the error processor thread and failure listener process being stopped cleanly:
INFO:torchft.manager:[train_ddp_0:81a52ce4-d803-4f22-a0c3-54f3b4a88c89/0 - step 10] Setting error processor thread stop event
INFO:torchft.manager:[train_ddp_0:81a52ce4-d803-4f22-a0c3-54f3b4a88c89/0 - step 10] Waiting for error processor thread to complete
INFO:torchft.manager:[train_ddp_0:81a52ce4-d803-4f22-a0c3-54f3b4a88c89/0 - step 10] Error processor thread shutdown completed.
INFO:torchft.manager:[train_ddp_0:81a52ce4-d803-4f22-a0c3-54f3b4a88c89/0 - step 10] Setting failure listener stop event for process
INFO:torchft.manager:[train_ddp_0:81a52ce4-d803-4f22-a0c3-54f3b4a88c89/0 - step 10] Waiting for failure listener process to complete
INFO:torchft.manager:[train_ddp_0:81a52ce4-d803-4f22-a0c3-54f3b4a88c89/0 - step 10] Failure listener process shutdown completed
And in the Lighthouse you will observe:
2025-05-20T22:29:30.029 [INFO] [torchft::lighthouse] - Replica train_ddp_1:a581dae2-1ebc-4f93-b882-6477832fef6b timed out (last heartbeat: Instant { tv_sec: 5200692, tv_nsec: 955240591 }), sending failure notification.
2025-05-20T22:29:30.029 [INFO] [torchft::lighthouse] - Removed replica train_ddp_1:a581dae2-1ebc-4f93-b882-6477832fef6b from heartbeats and participants due to timeout.
2025-05-20T22:29:30.029 [INFO] [torchft::lighthouse] - New failure detected, resetting all participants for quorum formation.
2025-05-20T22:29:30.029 [INFO] [torchft::lighthouse] - Healthy replicas received failure notification for train_ddp_1:a581dae2-1ebc-4f93-b882-6477832fef6b
Implementation
The proactive failure recovery mechanism involves changes in both the Rust backend and the Python Manager:
Rust:
- src/lighthouse.rs:
  - The Lighthouse server now includes a failure_channel (a Tokio broadcast channel).
  - When _failure_tick detects a timed-out replica, it broadcasts a FailureNotification on this channel.
  - A new gRPC method, subscribe_failures, is added to LighthouseService. Clients can call this to receive a stream of FailureNotifications.
  - The inject_failure method has been added to the LighthouseServer (Python-exposed) and Lighthouse (Rust struct) to facilitate testing by manually triggering failure notifications.
- src/lib.rs:
  - A FailureStream class is introduced, wrapping the tonic::Streaming<ProtoFailureNotification>. Its __next__ method allows Python to iterate over failure notifications. This method uses py.allow_threads around a blocking runtime.block_on(fut) call to fetch the next notification, allowing the GIL to be released. A small Python usage sketch follows this list.
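For illustration, here is a minimal sketch of how Python code could consume this stream. The LighthouseClient, subscribe_failures, and the blocking iteration behavior are taken from this PR's description, but the import path, constructor arguments, timeout handling, and the shape of each notification are assumptions.

```python
import os

# Assumed import path; in the actual repo the client may live in a submodule.
from torchft import LighthouseClient

# Connect to the same Lighthouse the training jobs use.
lighthouse_addr = os.environ.get("TORCHFT_LIGHTHOUSE", "http://localhost:29510")
client = LighthouseClient(lighthouse_addr)  # constructor arguments are an assumption

# subscribe_failures returns a stream (FailureStream) whose __next__ blocks,
# with the GIL released, until the Lighthouse broadcasts a failure.
for notification in client.subscribe_failures():
    print("received failure notification:", notification)
    break  # handle just the first notification in this sketch
```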
Python (Manager):
- torchft/manager.py:
  - When proactive_recovery is enabled (via constructor argument or the TORCHFT_PROACTIVE_RECOVERY=1 environment variable), the Manager spawns a separate daemon process (_failure_listener_process_main).
  - Subprocess-based subscription: This process creates a LighthouseClient and calls subscribe_failures, then iterates over the received failure notifications.
  - Inter-Process Communication (IPC): _ManagedPipe is used by the listener process to send the errors it receives from the Lighthouse (via the stream returned by subscribe_failures) back to the main Manager process. This mimics the IPC implementation in BabyProcessGroup. A simplified sketch of this pattern is shown after this list.
  - Error Listening: A new thread within the main Manager process continuously polls the _error_pipe.
  - Error Response: If an exception is received, it calls self.report_error() and aborts the underlying process group (self._pg.abort()).
  - Error Reporting: self.report_error() is now also used to flag the manager as errored when a proactive failure is detected.
  - Shutdown: Manager.shutdown() is enhanced to gracefully stop the _error_processor_thread and the _failure_listener_process. The subscribe_timeout parameter for subscribe_failures in _failure_listener_process_main allows the listener process to be interrupted for a clean shutdown.
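To make the listener-process/pipe pattern concrete, here is a simplified, self-contained sketch using only the standard library. It is not the torchft implementation: _failure_listener_process_main, _ManagedPipe, LighthouseClient.subscribe_failures, and report_error are replaced with stand-ins, and a failure is simulated instead of received from a Lighthouse.

```python
import multiprocessing as mp
import threading
import time


def failure_listener_main(pipe_conn, stop_event):
    """Stand-in for _failure_listener_process_main: blocks waiting for a
    failure and forwards it to the parent process over a pipe."""
    while not stop_event.is_set():
        # The real listener blocks on LighthouseClient.subscribe_failures(...)
        # with a timeout so it stays interruptible; here we simulate a failure.
        time.sleep(1.0)
        pipe_conn.send(RuntimeError("replica train_ddp_1 timed out"))
        return


def error_processor(pipe_conn, on_error, stop_event):
    """Stand-in for the manager's error-processor thread: polls the pipe and
    hands any received exception to a callback."""
    while not stop_event.is_set():
        if pipe_conn.poll(0.1):
            exc = pipe_conn.recv()
            on_error(exc)  # the real Manager calls report_error() and pg.abort()


if __name__ == "__main__":
    parent_conn, child_conn = mp.Pipe()
    stop_listener = mp.Event()
    stop_thread = threading.Event()

    listener = mp.Process(
        target=failure_listener_main, args=(child_conn, stop_listener), daemon=True
    )
    listener.start()

    thread = threading.Thread(
        target=error_processor,
        args=(parent_conn, lambda e: print(f"proactive failure detected: {e}"), stop_thread),
        daemon=True,
    )
    thread.start()

    time.sleep(2.0)  # the main "training loop" keeps running while errors are watched

    # Mirrors the enhanced Manager.shutdown(): stop the thread, then the process.
    stop_thread.set()
    thread.join()
    stop_listener.set()
    listener.join()
```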
Design Rationale
I decided to use a separate process to subscribe to the failure notifications because waiting on the failure stream is a blocking call. Because of the GIL, waiting in a Python thread would block the main thread from functioning.
While implementing this, I considered three approaches:
- GIL Release in Rust Stream Iteration: Decouple the Python logic from the tokio streaming logic so that the GIL can be released in lib.rs.
- Asyncio: Use pyo3-asyncio to create an async iterator from tokio-stream.
- Multiprocessing: Use a separate process to subscribe to the failure notifications.
Approaches 1 and 2 are more elegant and should be more efficient, as they do not involve spawning a separate process. However, I was limited by my Rust language understanding and was unable to implement them.
Tests
I introduced the following tests:
- Rust (src/lighthouse.rs):
  - test_subscribe_failures_delivers_notifications: Verifies that inject_failure correctly sends a notification that is received by a subscriber.
  - test_failure_tick_single_notification_and_cleanup: Ensures _failure_tick correctly identifies timeouts, broadcasts notifications once, and cleans up state.
- Python (torchft/lighthouse_test.py):
  - test_subscribe_failures_notification: Python-level test ensuring LighthouseClient.subscribe_failures receives notifications triggered by LighthouseServer.inject_failure.
  - test_inject_failure: Confirms that server.inject_failure() leads to a notification being received by client.subscribe_failures().
- Python (torchft/manager_test.py):
  - test_manager_error_handler: Tests that the Manager processes exceptions passed to its internal error handler.
  - test_direct_error_pipe: Verifies that an exception sent directly via the IPC pipe is correctly picked up by the Manager.
  - test_manager_failure_e2e: An end-to-end test where LighthouseServer.inject_failure triggers a notification that propagates through the listener process and IPC pipe and results in the Manager capturing the error.
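For orientation, here is a rough sketch in the spirit of test_inject_failure above. The LighthouseServer/LighthouseClient names and the inject_failure/subscribe_failures calls come from this PR, but the import path, constructor arguments, the inject_failure argument, and the notification's contents are assumptions.

```python
# Rough sketch of the inject_failure -> subscribe_failures round trip.
# Import path, constructor arguments, and notification fields are assumptions.
from torchft import LighthouseClient, LighthouseServer


def check_failure_round_trip() -> None:
    server = LighthouseServer(bind="[::]:0", min_replicas=1)
    try:
        client = LighthouseClient(server.address())
        stream = client.subscribe_failures()

        # Manually trigger a failure notification for a made-up replica id.
        server.inject_failure("replica_0")

        # The stream should yield a notification for the injected failure.
        notification = next(stream)
        print("received failure notification:", notification)
    finally:
        server.shutdown()


if __name__ == "__main__":
    check_failure_round_trip()
```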
Linter
I am still getting the following error after running lintrunner -a, but I couldn’t debug it:
Advice (pyre) command-failed
Failed due to JSONDecodeError:
Expecting value: line 1 column 1 (char 0)
Successfully applied all patches.
Other minor changes
Note: In order to test the code using train_ddp.py, I fixed an error introduced by commit 652a00948fbd96d20fbc0e361da6026f2bf4dbba and changed the API of DistributedSampler to use replica_group_id.