[REQUEST] Allow fallback to torch.distributed.all_reduce for multi-node inference.
Is your feature request related to a problem? Please describe. At present, DeepSpeed’s inference communication backend defaults to using the SHM-based operation torch.ops.deepspeed.inference_all_reduce_ for performing all_reduce operations when the shared memory (SHM) op is available. This default behavior, while effective for intra-node communication (e.g., dual-socket CPU machines), becomes problematic in multi-node inference scenarios.
The SHM implementation is currently written with x86-specific intrinsics, so it is incompatible with other CPU architectures such as Arm. Moreover, based on my exploration, there is no mechanism to explicitly fall back to a different communication backend (such as torch.distributed.all_reduce) when the SHM op is present. This limits flexibility for multi-node or non-x86 deployments where SHM may not work or may not be optimal.
Describe the solution you'd like
We propose modifying the inference_all_reduce method in deepspeed/comm/torch.py to introduce a runtime-aware mechanism that conditionally chooses between the SHM and PyTorch’s default distributed backends based on node locality.
A sample implementation might look like this:
def inference_all_reduce(self, tensor, op, group=None):
    # Use the SHM op only when every rank in the group lives on the same node.
    all_local_ranks = True
    if group is not None:
        world_size = torch.distributed.get_world_size(group=group)
        # LOCAL_SIZE is expected to hold the number of ranks on this node;
        # default to 1 so an unset variable routes to the torch.distributed path.
        local_size = int(os.getenv("LOCAL_SIZE", "1"))
        if world_size > local_size:
            all_local_ranks = False
    shm_op_available = hasattr(torch.ops, 'deepspeed') and \
        hasattr(torch.ops.deepspeed, 'inference_all_reduce_')
    if not shm_op_available or not all_local_ranks:
        # Multi-node group or SHM op unavailable: use the standard PyTorch collective.
        op = self._reduce_op(op)
        return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
    else:
        # Intra-node group: use the SHM-based all_reduce.
        return torch.ops.deepspeed.inference_all_reduce_(tensor)
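With this check, intra-node process groups (world size no larger than LOCAL_SIZE) keep the SHM fast path, while groups that span nodes, or builds where the SHM op is absent (e.g., Arm today), transparently fall back to torch.distributed.all_reduce. The sketch assumes the launcher exports LOCAL_SIZE as the number of ranks on the local node; defaulting it to 1 when unset makes the torch.distributed path the safe default.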
Describe alternatives you've considered - but not tested
- Add a runtime flag or env var (e.g., DS_SHM_COMM_ALL_REDUCE_OFF) to disable SHM-based collectives and fall back to torch.distributed ops (a rough sketch follows this list).
- Introduce a parameter to deepspeed.init_inference() to control SHM usage, useful for hybrid systems (e.g., dual-socket or mixed-arch setups).
- Skip the SHM build in setup.py for unsupported architectures such as Arm to avoid the x86-intrinsic issues.
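For the first alternative, here is a minimal, untested sketch of what such an env-var kill switch could look like inside the same inference_all_reduce method. DS_SHM_COMM_ALL_REDUCE_OFF is only a proposed name, not an existing DeepSpeed flag, and os/torch are assumed to be imported in the enclosing module:
def inference_all_reduce(self, tensor, op, group=None):
    # Hypothetical toggle: DS_SHM_COMM_ALL_REDUCE_OFF is a proposed name, not an existing flag.
    shm_disabled = os.getenv("DS_SHM_COMM_ALL_REDUCE_OFF", "0").lower() in ("1", "true", "yes")
    shm_op_available = hasattr(torch.ops, 'deepspeed') and \
        hasattr(torch.ops.deepspeed, 'inference_all_reduce_')
    if shm_disabled or not shm_op_available:
        # Explicitly requested (or forced) fallback to the PyTorch distributed backend.
        op = self._reduce_op(op)
        return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
    return torch.ops.deepspeed.inference_all_reduce_(tensor)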
Additional context
We are working on enabling multi-node inference for large language models on CPU-only clusters, including support for Arm processors. The current behavior of always using SHM collectives creates runtime issues in these setups. Adding flexibility in how collectives are selected or built would improve support.
Hi @phalani-paladugu, thanks for the suggestion. I agree that a multi-node fallback should be added to support multi-node inference. For single-node SHM, I notice there is a RISC-V implementation; is it possible to do the same for the Arm architecture?
@delock Yes, SHM support for Arm is currently a work in progress.