
[REQUEST] Auto-Tuning CPU Core Binding for DeepSpeed&ZenFlow

Open Antlera opened this issue 4 months ago • 32 comments

Description

Currently, DeepSpeed offers --bind_cores_to_rank and --bind_core_list flags to bind CPU cores, but these require explicit specification from the user. While core binding works, it is not fully automated and does not adapt dynamically to different NUMA node configurations or CPU layouts.

Problem

The current core binding functionality requires users to manually specify which cores to bind to each worker. This approach does not provide flexibility in handling complex CPU layouts, particularly when dealing with non-sequential core IDs or varying NUMA node configurations.

Proposed Discussion

We may explore the possibility of auto-tuning CPU core binding to find optimal strategies automatically. This would involve:

  • Automatic Core Binding: Automatically determine the best core binding across NUMA nodes without requiring user input.

  • Adaptation to Non-Sequential Core Layouts (due to logical cores): Account for irregularities in core numbering (e.g., non-sequential core IDs within a NUMA node) to ensure efficient binding.

  • Maximizing Core Utilization: Ensure that the binding maximizes CPU utilization and minimizes memory contention across workers.

Possible Context

  • DeepSpeed CPU binding: https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/utils/numa.py#L117
  • ZenFlow CPU binding: zenflow_optimizer_process() at #7391

Antlera avatar Aug 09 '25 04:08 Antlera

@delock @sfc-gh-truwase @tohtana, I've moved the discussion to this issue, feel free to continue the conversation here.

Antlera avatar Aug 09 '25 04:08 Antlera

I built a benchmark that tests CPUAdam performance separately. The tensor size simulates the Qwen2.5-3B model running on multiple cards.

import argparse
import time

import torch
import deepspeed
import deepspeed.comm as dist
from deepspeed.ops.op_builder import CPUAdamBuilder

parser = argparse.ArgumentParser(description="test cpu adam performance")
parser.add_argument("--local_rank", type=int, help="local rank, passed in by the deepspeed launcher")
parser.add_argument("--size", type=int, default=1542969344*2, help="total number of elements to update across all ranks")
args = parser.parse_args()

deepspeed.init_distributed()
# Without a launcher-provided local rank, treat the run as a single worker.
if args.local_rank is None:
    world_size = 1
else:
    world_size = dist.get_world_size()
adam = CPUAdamBuilder().load()
adam.create_adam(0,2e-05,0.9,0.999,1e-08,0.01,True,True)
# Each rank updates an equal share of the elements.
size = args.size//world_size
pdata = torch.ones(size)
pgrad = torch.ones(size)
exp_avg = torch.ones(size)
exp_avg_sq = torch.ones(size)
# Start all ranks together so they compete for memory bandwidth at the same time.
dist.barrier()
t0 = time.time()
times = []
for i in range(20):
    adam.adam_update(0,1,2e-05,0.9,0.999,1e-08,0.01,True, pdata, pgrad, exp_avg, exp_avg_sq)
    t1 = time.time()
    times.append(t1-t0)
    t0 = t1
#print (times)
# Average time per adam_update call on this rank.
print (sum(times)/len(times))

The run command should be:

DS_ACCELERATOR=cpu deepspeed --bind_cores_to_rank --num_gpus <num_cpu_workers> test_cpuadam.py

When running the above command on a 2S x 64-core machine with --num_gpus 2, the average adam_update time is ~0.6s. This can be regarded as a roofline when we evaluate performance in ZeRO-Offload and ZenFlow. In the finetune_llama case in DSE, when finetuning the Qwen2.5-3B model with 2 cards on the same machine, we get the following adam_update times:

  • zero offload: 0.805s
  • zenflow: I didn't find a way to print out the time yet.

So for ZeRO-Offload (ZO), the CPU Adam time is longer than the roofline.

delock avatar Aug 09 '25 09:08 delock

@Antlera there is a way to know which two CPUs are virtual cores of the same physical core: lscpu --extended. The result lists the core ID of each CPU. Sibling cores have the same core ID.

It occurs to me that physical cores are always consecutive. Say the system has two sockets and each socket has 64 physical cores. Then CPUs 0-127 belong to different physical cores, and CPUs 128-255 are the sibling cores of CPUs 0-127.

Thus DeepSpeed gets the number of physical cores and does CPU core binding on these physical cores. This should be a reasonable way to bind cores to different workers and let each worker use a different set of physical cores.
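The sibling relationship described above can also be read programmatically. The sketch below parses text in the format printed by lscpu --parse=CPU,CORE and keeps the lowest-numbered CPU of each core ID. The sample text is made up for illustration (a 4-core, 8-thread machine), and physical_cpus is a hypothetical helper, not a DeepSpeed function.

```python
def physical_cpus(lscpu_parse_text):
    """Return one CPU ID per physical core (the lowest-numbered sibling)."""
    first_cpu_of_core = {}
    for line in lscpu_parse_text.splitlines():
        line = line.strip()
        # lscpu --parse emits header lines that start with '#'.
        if not line or line.startswith("#"):
            continue
        cpu, core = (int(field) for field in line.split(",")[:2])
        if core not in first_cpu_of_core or cpu < first_cpu_of_core[core]:
            first_cpu_of_core[core] = cpu
    return sorted(first_cpu_of_core.values())


# Hypothetical sample: CPUs 4-7 are the hyperthread siblings of CPUs 0-3.
SAMPLE = """\
# CPU,Core
0,0
1,1
2,2
3,3
4,0
5,1
6,2
7,3
"""

print(physical_cpus(SAMPLE))
```

On a real machine the input would come from running lscpu --parse=CPU,CORE; grouping by core ID avoids relying on any particular CPU numbering scheme.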

delock avatar Aug 09 '25 16:08 delock

If I understand correctly, the ZenFlow optimizer runs in a separate process, so this is the process that needs core binding. However, this also means it runs in parallel with PyTorch, which has a couple of active threads of its own (e.g. for submitting kernels to the GPU). So a proper solution should use separate core bindings for the PyTorch threads and the ZenFlow threads. https://github.com/deepspeedai/DeepSpeed/pull/7391/files#diff-c36f7081aa96d0cca5833845684635eddfe454beb4a76c826bb3b91d1ac38401R811

delock avatar Aug 10 '25 04:08 delock

@Antlera I gave the CPUAdam benchmark an update. We defined the problem statement as "Given 3B parameters in CPU memory, how do we update them as fast as possible?"

Given this problem statement, the benchmark code above is designed so that each worker updates a tensor of 3B/nranks elements (nranks = number of workers = num_gpus) in parallel. The average time spent on each worker is printed as the result.

Here is the CPUAdam time with different numbers of ranks (running on Xeon 6972P * 2S):

  • nranks = 1: 0.2s
  • nranks = 2: 0.189s
  • nranks = 4: 0.188s
  • nranks = 8: 0.187s
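To put these times in perspective, here is a back-of-envelope estimate of the aggregate memory bandwidth they imply. It rests on assumptions that are not stated in the thread: all ranks run concurrently, tensors are fp32 (the torch.ones default in the benchmark), and each element costs four reads (param, grad, exp_avg, exp_avg_sq) plus three writes (param, exp_avg, exp_avg_sq), ignoring caches and write-allocate traffic.

```python
NUM_ELEMENTS = 1542969344 * 2   # default --size in the benchmark
BYTES_PER_FP32 = 4
ACCESSES_PER_ELEMENT = 7        # assumed: 4 reads + 3 writes per Adam update


def effective_bandwidth_gbs(step_seconds):
    """Aggregate bytes moved per second, in GB/s, under the assumptions above."""
    total_bytes = NUM_ELEMENTS * BYTES_PER_FP32 * ACCESSES_PER_ELEMENT
    return total_bytes / step_seconds / 1e9


for nranks, seconds in [(1, 0.2), (2, 0.189), (4, 0.188), (8, 0.187)]:
    print(f"nranks={nranks}: ~{effective_bandwidth_gbs(seconds):.0f} GB/s")
```

If the estimate lands near the machine's measured memory bandwidth, adding ranks or cores is unlikely to shorten the update further.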

I'll be traveling for the next 2 weeks. We can continue our discussion then. If we are able to shorten the optimizer time through a high-CPU-throughput design and then hide this latency using your ZenFlow technique, we may be able to achieve very low offload finetuning overhead.

delock avatar Aug 11 '25 03:08 delock

@delock Thanks a lot for the quick and thoughtful feedback! This benchmark looks great — I’ll incorporate the results and also refine the ZenFlow benchmarks and CPU-binding part accordingly. I’ve also got some deadlines coming up soon, but let’s keep in touch on this topic. I’ll work on optimizing this part — it will be super helpful for our efforts.

Antlera avatar Aug 11 '25 03:08 Antlera

Hi @Antlera. Some of my thoughts on CPU affinity for DeepSpeed+ZenFlow:

  1. Each ZenFlow optimizer worker needs to run on a separate set of physical CPU cores.
  2. OMP_NUM_THREADS needs to be set to the number of cores bound to each ZenFlow optimizer worker.
  3. The DeepSpeed worker needs to avoid using the cores bound to the ZenFlow optimizer worker.

So a possible solution could be:

  1. Still use bind_cores_to_rank to start DeepSpeed.
  2. Before spawning the ZenFlow optimizer, get the current CPU bind list and OMP_NUM_THREADS, reserve a couple of cores for DeepSpeed itself by setting its affinity, then pass the rest of the CPU list to the ZenFlow optimizer.
  3. After the ZenFlow optimizer worker has started, set its CPU affinity and OMP_NUM_THREADS according to the CPU list passed in.
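The three steps above can be sketched as follows. This is a minimal illustration, not the code in the branch or PR: split_rank_cores and apply_binding are hypothetical names, and os.sched_setaffinity is Linux-only.

```python
import os


def split_rank_cores(bound_cores, reserved_for_pytorch):
    """Split one rank's cores into (PyTorch cores, optimizer cores)."""
    cores = sorted(bound_cores)
    if not 0 < reserved_for_pytorch < len(cores):
        raise ValueError("must leave at least one core on each side")
    return cores[:reserved_for_pytorch], cores[reserved_for_pytorch:]


def apply_binding(cores):
    """Pin the calling process to `cores` and size its OpenMP pool to match.

    Call this early in the target process: OMP_NUM_THREADS is only read
    when the OpenMP runtime initializes.
    """
    os.environ["OMP_NUM_THREADS"] = str(len(cores))
    os.sched_setaffinity(0, cores)  # 0 means the calling process


# Example: a rank bound to cores 0-15 keeps 2 cores and hands 14 to the optimizer.
pytorch_cores, optimizer_cores = split_rank_cores(range(0, 16), reserved_for_pytorch=2)
print(pytorch_cores, optimizer_cores)
```

The parent would call apply_binding(pytorch_cores) on itself and pass optimizer_cores to the spawned optimizer process, which calls apply_binding on that list before loading the CPU Adam op.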

delock avatar Aug 21 '25 11:08 delock

I created a branch that implements core separation between ZenFlow workers and DeepSpeed workers. https://github.com/deepspeedai/DeepSpeed/tree/gma/zenflow_affinity

I used Qwen2.5-3B to test the performance: finetune for 50 steps, then compute the average step time over steps 5-50.

Master (ZenFlow):

  • No overlap: 1564.24ms
  • Overlap: 1381.41ms

zenflow_affinity (use DeepSpeed core binding, leave two cores per DS worker, the rest to the ZenFlow worker):

  • No overlap: 1343.63ms
  • Overlap: 1216.65ms
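As a quick worked comparison of the four step times above, the relative reduction from master to zenflow_affinity can be computed directly (reduction_pct is just an illustrative helper):

```python
def reduction_pct(before_ms, after_ms):
    """Percentage by which the step time shrinks going from before to after."""
    return (before_ms - after_ms) / before_ms * 100


print(f"No overlap: {reduction_pct(1564.24, 1343.63):.1f}% shorter step")
print(f"Overlap:    {reduction_pct(1381.41, 1216.65):.1f}% shorter step")
```

That is roughly 14% without overlap and 12% with overlap.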

delock avatar Aug 21 '25 16:08 delock

@Antlera from the logging with #7506 I observed the following:

  1. In steps with an update, bwd_microstep is longer (1695.09). Is there an explanation for the longer bwd_microstep?
  2. optimizer_transmit_time is 470ms in update steps. Is this time reasonable?
  3. step_microstep is ~750ms in update steps. Is it supposed to overlap with the bwd time?

So far I'm seeing a similar step time pattern between overlapped and non-overlapped updates. The main difference is that the non-overlapped update step takes longer. I'm not sure whether the overlapped update step can be made even shorter.

Image

delock avatar Aug 22 '25 04:08 delock

@delock Give me some time to check this part in more detail, I’ll get back with specifics. Could you share the exact command you used to run this? I’d like to reproduce it on my side.

Antlera avatar Aug 22 '25 04:08 Antlera

Hi @Antlera, here is the command and config file I used:

deepspeed --bind_cores_to_rank --num_gpus=2 finetune_llama.py --model_name Qwen/Qwen2.5-3B --output_dir output --lr 2e-5 --batch_size 8 --deepspeed_config zf_config.json --num_train_epochs 1

zf_config.json:
{
    "train_batch_size": 8,
    "bf16": { "enabled": true },
    "zero_optimization": {
      "stage": 2,
      "offload_optimizer": {
        "device": "cpu",
        "pin_memory": true
      },
      "zenflow": {
            "topk_ratio": 0.1,
            "update_interval": 4,
            "full_warm_up_rounds": 0,
            "overlap_step": true
        }
    },
    "optimizer": {
      "type": "AdamW",
      "params": {
        "lr": 2e-5,
        "betas": [0.9, 0.999],
        "eps": 1e-8,
        "weight_decay": 0.01
      }
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 1.0,
    "zero_allow_untested_optimizer": true,
    "wall_clock_breakdown": true
}

delock avatar Aug 22 '25 05:08 delock

Benchmark on CPU binding methods

Image

Hi @delock. Thanks for confirming the command. Please see my benchmark for CPU core binding and the overhead breakdown. The overhead is shown as the delta versus the command you suggested, with ds_core_num=2 and deepspeed --bind_cores_to_rank.

Setup:

  • Model: Qwen-0.5B
  • GPUs: 2 × L4
  • CPU NUMA layout:
    NUMA node(s): 2
    NUMA node0 CPU(s): 0-15,32-47
    NUMA node1 CPU(s): 16-31,48-63

Binding strategies (suffix -cX = ds_core_num=X):

  • bind_core_list_A: deepspeed --bind_cores_to_rank --bind_core_list "0-31,32-63"
    • Evenly splits CPUs but may cross NUMA nodes; leaves no cores for background work.
  • bind_core_list_B: deepspeed --bind_cores_to_rank --bind_core_list "0-29,30-59"
    • Leaves 4 cores for background activities.
  • bind_core_list_C: deepspeed --bind_cores_to_rank --bind_core_list "0-15,32-47"
    • Both ranks pinned to the same NUMA node.
  • bind_core_list_D: deepspeed --bind_cores_to_rank --bind_core_list "0-15,16-31"
    • Each rank confined to one NUMA node.
  • bind_to_rank: --bind_cores_to_rank without explicit --bind_core_list.

(Darker colors in the plot indicate more cores assigned to the DeepSpeed process.)

Quick takeaways:

  • Using plain --bind_cores_to_rank (blue) is a reasonable “lazy” default.
  • However, ds_core_num=2 is not always ideal — the forward and backward pass may slow down due to poor CPU assignment. Using ~half of the available cores (e.g., blue, c-10) flips the latency delta and performs better.
  • bind_core_list_D (separate node for each rank) shows potential for the best performance, but it requires explicitly handling non-contiguous CPU numbering. With this setup, performance can surpass the naive binding strategy even though only half of the CPUs (i.e., "0-15,16-31") are used.
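One way to derive a bind_core_list_D style string from the NUMA layout is sketched below. It keeps only the first contiguous run of each node's CPU list. This treats the later ranges as hyperthread siblings, which is an assumption about this machine's numbering and not something DeepSpeed does today. The helper names are hypothetical.

```python
def parse_cpu_list(spec):
    """Expand a string such as '0-15,32-47' into a list of CPU IDs."""
    cpus = []
    for part in spec.split(","):
        start, _, end = part.partition("-")
        cpus.extend(range(int(start), int(end or start) + 1))
    return cpus


def first_contiguous_run(cpus):
    """Return the leading run of consecutive CPU IDs."""
    cpus = sorted(cpus)
    run = [cpus[0]]
    for cpu in cpus[1:]:
        if cpu != run[-1] + 1:
            break
        run.append(cpu)
    return run


def bind_core_list_per_node(numa_specs):
    """Build one range per NUMA node, suitable for one rank per node."""
    parts = []
    for spec in numa_specs:
        run = first_contiguous_run(parse_cpu_list(spec))
        parts.append(f"{run[0]}-{run[-1]}" if len(run) > 1 else str(run[0]))
    return ",".join(parts)


# NUMA node0 and node1 CPU lists from the setup above.
print(bind_core_list_per_node(["0-15,32-47", "16-31,48-63"]))
```

For the layout in this benchmark the result is the same string used by bind_core_list_D.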

Antlera avatar Aug 25 '25 05:08 Antlera

Another quick comment on a potential bug. The default --bind_cores_to_rank implementation using numactl can be problematic for Slurm users, since they only have access rights to a subset of cores while numactl still reports the full node view. This can cause the binding-to-rank logic to assign processes to cores that are not actually accessible. For this reason, I suggest adding a Python-side fallback that detects the cores the user truly has access to and then applies a "soft" binding within that range, instead of relying solely on numactl.
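A minimal sketch of the detection half of such a fallback, using only the standard library: os.sched_getaffinity (Linux-only) returns the CPUs the process is actually allowed to run on, which can be fewer than os.cpu_count() reports. The function name visible_cores is made up for illustration.

```python
import os


def visible_cores():
    """CPUs the current process may run on (reflects taskset/cpuset limits)."""
    if hasattr(os, "sched_getaffinity"):     # available on Linux
        return sorted(os.sched_getaffinity(0))
    # No affinity API on this platform: fall back to every CPU the OS reports.
    return list(range(os.cpu_count() or 1))


cores = visible_cores()
print(f"{len(cores)} of {os.cpu_count()} CPUs are usable by this process")
```

Inside a restricted Slurm allocation the first number would be smaller than the second, which is the signal that a node-wide core list should not be trusted.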

Antlera avatar Aug 25 '25 05:08 Antlera

Hi @delock. For your logs and questions. The transmit throughput looks a bit slow here, only around 12 GB/s. From my side I usually see ~200 ms for this stage (for example, 3B model takes ~236 ms). Whether overlap is happening can be directly checked from optimizer_receive_params_time: if it is close to zero, then you are indeed fully overlapped. The microstep timing itself just measures the current step’s clock, but the required parameters may have already been updated earlier. So even if step_microstep looks long, it doesn’t necessarily mean the optimizer is blocking the bwd step. In this case, the binding seems to have introduced some unnecessary overhead in fwd/bwd, but the overlap itself looks fine. You can cross-check the details against my benchmark for reference.

Antlera avatar Aug 25 '25 06:08 Antlera

@Antlera Thanks for this very detailed analysis! It gives a good suggestion for what the default value should be. Making ds_core_num bigger when there is an abundant number of cores might be a good idea.

About the question of a 'fallback' mode, I think it's a valid argument given the Slurm example. A possible solution would be:

  1. Compare affinity between ranks to check whether they have been bound by numactl. If all ranks have the same affinity, numactl is not in action and a local 'soft' version needs to take over.
  2. In the soft version, figure out the set of physical cores within the current affinity.
  3. Then shard the physical cores in the current affinity on a best-effort basis, and reserve ds_core_num per worker.
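Step 1 can be sketched as a pure function over the affinities gathered from all ranks. How the affinities are exchanged between ranks is left out here, and needs_soft_binding is a hypothetical name, not the function used in the PR.

```python
def needs_soft_binding(rank_affinities):
    """True when every rank reports the same CPU set, i.e. nothing bound them apart."""
    distinct = {frozenset(affinity) for affinity in rank_affinities}
    return len(rank_affinities) > 1 and len(distinct) == 1


# Two ranks that both see cores 0-31: no per-rank binding is active.
unbound = [range(0, 32), range(0, 32)]
# Two ranks already pinned to disjoint halves by the launcher.
bound = [range(0, 16), range(16, 32)]

print(needs_soft_binding(unbound), needs_soft_binding(bound))
```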

delock avatar Aug 27 '25 06:08 delock

Thanks! I see optimizer_receive_params_time is very small, so in my case the optimizer time should be almost fully overlapped. Thanks for the explanation!

delock avatar Aug 27 '25 06:08 delock

I have implemented this soft fallback in https://github.com/deepspeedai/DeepSpeed/pull/7506.

Also the default value of pt_reserved_cores is kept as 1 because I think the optimal value might be machine dependent. 10 is obviously not a good default value if each rank has < 10 cores. Since it is a field in config file, maybe it can be covered by autotuning. Need comments from @sfc-gh-truwase .

delock avatar Aug 27 '25 10:08 delock

Yes, it is machine-dependent. From my experiments, reserving about half of the cores works best, though I’m not entirely sure of the underlying reason. Need comments from @sfc-gh-truwase .

Antlera avatar Aug 28 '25 13:08 Antlera

@delock Thanks for implementing the soft fallback in (#7506). I’ll run a quick test on it soon.

Antlera avatar Aug 28 '25 13:08 Antlera

Thanks for the great discussion, guys. I have a few thoughts below. It might be worth finding time to discuss.

  1. Yes, I think auto-tuning could be useful here.
  2. If we agree on auto-tuning, then we need to decide whether to revive the DeepSpeed Auto-tuning feature for this or create a simple standalone sweep script (e.g., ds_nvme_tune) that users can run to derive the optimal pt_reserved_cores value for their environment.
  3. Although DeepSpeed Auto-tuning has the benefit of e2e automation it has not been maintained recently and may require substantial effort to revive. On the other hand, a standalone approach as used in DeepNVMe will be quicker to implement but requires manual effort to run with the discovered optimal settings.
  4. I will let you guys decide between DeepSpeed Auto-tuning or standalone script.
  5. Did we consider whether pt_reserved_cores should be float type as opposed to integer? This will make specifying fractional reservation of total cores more portable.
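To illustrate point 5, here is one possible way a fractional pt_reserved_cores could be interpreted. The semantics are an assumption made for this sketch (a float strictly between 0 and 1 means a fraction of the rank's cores, anything else is an absolute count), and resolve_reserved_cores is a hypothetical helper.

```python
def resolve_reserved_cores(pt_reserved_cores, cores_per_rank):
    """Turn an int count or a float fraction into a usable core count."""
    if cores_per_rank < 2:
        raise ValueError("need at least 2 cores per rank to split them")
    if isinstance(pt_reserved_cores, float) and 0 < pt_reserved_cores < 1:
        count = int(pt_reserved_cores * cores_per_rank)   # fraction of the cores
    else:
        count = int(pt_reserved_cores)                    # absolute count
    # Keep at least one core on each side of the split.
    return max(1, min(count, cores_per_rank - 1))


print(resolve_reserved_cores(0.5, 16))   # half of a 16-core rank
print(resolve_reserved_cores(10, 8))     # a count larger than the rank is clamped
print(resolve_reserved_cores(1, 16))     # the current integer default
```

A fraction keeps the same config usable on machines with different core counts, while the clamp covers the case raised earlier where a fixed value such as 10 exceeds the cores available to a rank.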

sfc-gh-truwase avatar Aug 28 '25 15:08 sfc-gh-truwase

@delock @sfc-gh-truwase Some thoughts on the auto-tuning feature. Personally, I’d lean toward a simple script that runs a dummy model to stress the CPU side. Since the main goal is just to detect whether forward/backward is delayed by CPU bottlenecks, this lightweight approach should be sufficient. We could also combine this with @delock’s earlier CPUAdam benchmark — for very large models, the tuning itself might otherwise become relatively costly.

Antlera avatar Aug 28 '25 21:08 Antlera

@delock I did a very quick test in the Slurm setting. It looks like the current soft fallback still has issues under Slurm. For example, I requested 32 CPU cores, but the actual affinity was something like:

taskset -cp $$
pid 249296's current affinity list: 0,4-11,21-26,30-46

In this case, running without --bind_cores_to_rank leads to CPU memory allocation errors in cpu_adam.

Starting epoch 1/3
Zenflow: clearing selective optimizer states...
Zenflow: clearing selective optimizer states...
[2025-08-28 17:26:36,020] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | selective_optimizer_update: 63.94 | selective_optimizer_process: 282.85 | selective_optimizer_sync: 0.25
[2025-08-28 17:26:36,044] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | fwd_microstep: 279.88 | bwd_microstep: 3919.57 | bwd_inner_microstep: 111.57 | bwd_allreduce_microstep: 3802.59 | step_microstep: 0.00
Step 1, Loss: 14.3255, Time: 4201ms
[2025-08-28 17:26:37,876] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 31.39 | selective_optimizer_sync: 0.00
[2025-08-28 17:26:37,903] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | fwd_microstep: 45.77 | bwd_microstep: 1811.39 | bwd_inner_microstep: 99.24 | bwd_allreduce_microstep: 1712.13 | step_microstep: 0.00
Step 2, Loss: 11.8107, Time: 1858ms
[2025-08-28 17:26:38,140] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 26.84 | selective_optimizer_sync: 0.00
[2025-08-28 17:26:38,167] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | fwd_microstep: 44.31 | bwd_microstep: 219.24 | bwd_inner_microstep: 99.55 | bwd_allreduce_microstep: 119.68 | step_microstep: 0.00
Step 3, Loss: 8.7644, Time: 264ms
Setting pytorch affinity to [0], OMP_NUM_THREADS=1
Setting pytorch affinity to [31], OMP_NUM_THREADS=1
[2025-08-28 17:26:41,967] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-08-28 17:26:42,015] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-08-28 17:26:42,102] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-08-28 17:26:42,143] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-08-28 17:26:44,664] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
[2025-08-28 17:26:45,167] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
Using /sfs/gpfs/tardis/home/erc8gx/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /sfs/gpfs/tardis/home/erc8gx/.cache/torch_extensions/py310_cu118/cpu_adam/build.ninja...
/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Setting zenflow optimizer affinity to [4, 5, 6, 7, 8, 9, 10, 11, 21, 22, 23, 24, 25, 26, 30], OMP_NUM_THREADS=15
ninja: no work to do.
Loading extension module cpu_adam...
[2025-08-28 17:26:47,683] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 12.22 | selective_optimizer_step: 58.96 | selective_optimizer_sync: 0.00
[2025-08-28 17:26:47,684] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | fwd_microstep: 44.49 | bwd_microstep: 310.61 | bwd_inner_microstep: 98.63 | bwd_allreduce_microstep: 211.97 | step_microstep: 9160.24
[2025-08-28 17:26:47,684] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | fwd: 414.37 | bwd: 6260.81 | bwd_inner: 408.97 | bwd_allreduce: 5846.36 | step: 9160.27
Step 4, Loss: 5.6291, Time: 9516ms
Uncaught exception in compile_worker subprocess
Traceback (most recent call last):
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/_inductor/compile_worker/__main__.py", line 41, in main
    SubprocMain(args.workers, read_fd, write_fd).main()
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 204, in __init__
    self.pool = self._new_pool(nprocs, True)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 215, in _new_pool
    _warm_process_pool(pool, nprocs)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 304, in _warm_process_pool
    pool._adjust_process_count()
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/concurrent/futures/process.py", line 697, in _adjust_process_count
    self._spawn_process()
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/concurrent/futures/process.py", line 714, in _spawn_process
    p.start()
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/context.py", line 281, in _Popen
    return Popen(process_obj)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/popen_fork.py", line 66, in _launch
    self.pid = os.fork()
OSError: [Errno 12] Cannot allocate memory
Process SpawnProcess-5:
Traceback (most recent call last):
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/sfs/weka/scratch/erc8gx/codespace/DeepSpeed/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py", line 683, in zenflow_optimizer_process
    optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/sfs/weka/scratch/erc8gx/codespace/DeepSpeed/deepspeed/ops/adam/zenflow_cpu_adam.py", line 137, in _parallel_step
    p.stale_param.data.copy_(p.data.clone())
RuntimeError: [enforce fail at alloc_cpu.cpp:117] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 988065536 bytes. Error code 12 (Cannot allocate memory)
Time to load cpu_adam op: 0.3280184268951416 seconds
ZenFlowCPUAdam initialized with overlap step.
Adam Optimizer #0 is created with AVX2 arithmetic capability.
Config: alpha=0.001000, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1
Using /sfs/gpfs/tardis/home/erc8gx/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /sfs/gpfs/tardis/home/erc8gx/.cache/torch_extensions/py310_cu118/cpu_adam/build.ninja...
/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Setting zenflow optimizer affinity to [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46], OMP_NUM_THREADS=15
ninja: no work to do.
Loading extension module cpu_adam...
[2025-08-28 17:26:49,020] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 27.96 | selective_optimizer_step: 0.00 | selective_optimizer_sync: 0.00
[2025-08-28 17:26:49,071] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | fwd_microstep: 46.96 | bwd_microstep: 1338.52 | bwd_inner_microstep: 98.93 | bwd_allreduce_microstep: 1239.57 | step_microstep: 0.00
Step 5, Loss: 2.6151, Time: 1386ms
Process SpawnProcess-5:
Traceback (most recent call last):
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/sfs/weka/scratch/erc8gx/codespace/DeepSpeed/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py", line 683, in zenflow_optimizer_process
    optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/sfs/weka/scratch/erc8gx/codespace/DeepSpeed/deepspeed/ops/adam/zenflow_cpu_adam.py", line 137, in _parallel_step
    p.stale_param.data.copy_(p.data.clone())
RuntimeError: [enforce fail at alloc_cpu.cpp:117] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 988065536 bytes. Error code 12 (Cannot allocate memory)

However, if I run with:

deepspeed --bind_cores_to_rank --bind_core_list '4-11,21-26'

it works fine.

The root cause seems to be that the fallback logic assumes a contiguous core range, while Slurm often provides a sparse/non-contiguous affinity list. Two possible directions:

  1. At minimum, I should add documentation guidance that under Slurm, users should explicitly set --bind_cores_to_rank (and optionally --bind_core_list) to align with Slurm's CPU allocation, which can be checked with taskset -cp $$.
  2. Try the old-fashioned fallback of splitting based on the affinity list returned by psutil.Process().cpu_affinity(). This works fine in my Slurm environment, although with some performance degradation.
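Direction 2 can be sketched as below: shard the affinity list by position, so every rank only receives CPU IDs that are actually in the list, however sparse it is. The affinity is hard-coded from the taskset output above so the example runs anywhere; in a real run it would come from psutil.Process().cpu_affinity() or os.sched_getaffinity(0). shard_affinity is a hypothetical helper.

```python
def shard_affinity(affinity, num_ranks):
    """Split a (possibly sparse) CPU list into num_ranks near-equal shards."""
    cores = sorted(affinity)
    base, extra = divmod(len(cores), num_ranks)
    shards, start = [], 0
    for rank in range(num_ranks):
        size = base + (1 if rank < extra else 0)  # first `extra` ranks get one more
        shards.append(cores[start:start + size])
        start += size
    return shards


# Affinity list from the taskset output above: 0,4-11,21-26,30-46 (32 CPUs).
AFFINITY = [0, *range(4, 12), *range(21, 27), *range(30, 47)]

shards = shard_affinity(AFFINITY, num_ranks=2)
for rank, shard in enumerate(shards):
    print(rank, shard)
```

A split that assumed a contiguous 0-31 range would instead hand out CPUs such as 1-3 or 12-20, which are not in this allocation.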

Antlera avatar Aug 28 '25 22:08 Antlera

Maybe for the fallback case it would be safer to base the core split on the CPUs visible to the current process, e.g. num_cores = len(psutil.Process().cpu_affinity()), instead of relying on the current cpu_count(logical=False). Another option could be to take the intersection of the physical cores and the process-visible cores. That way we ensure the fallback logic stays safe under Slurm or other limited environments.

Antlera avatar Aug 28 '25 23:08 Antlera

I agree with @Antlera's suggestion of a separate simple script. A probable heuristic would be to start by leaving 1 core for each ZenFlow worker, then gradually increase this number until ZenFlow optimizer performance stops improving. This way we can ensure there are enough cores for the ZenFlow optimizer and give the remaining cores to pt_reserved_cores.

The rationale for the above approach is that the Adam optimizer is quite memory-bound, so there is a point at which there are enough CPU cores to saturate the memory bandwidth. Adding more CPU cores beyond this point would only increase OpenMP synchronization cost. Let me know your thoughts. We might investigate the heuristic in separate work.
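The heuristic above can be sketched as a small sweep. The measurement function is a synthetic stand-in invented for this example (time falls as 1/cores until an assumed saturation point, plus a small per-thread synchronization cost); in practice it would run the CPUAdam benchmark bound to the given number of cores. All names and the 2% threshold are illustrative.

```python
def find_saturation_point(measure_step_time, max_cores, min_gain=0.02):
    """Add optimizer cores one at a time until the gain drops below min_gain."""
    best_cores = 1
    best_time = measure_step_time(1)
    for cores in range(2, max_cores + 1):
        step_time = measure_step_time(cores)
        if step_time > best_time * (1 - min_gain):
            break                      # no meaningful improvement: stop here
        best_cores, best_time = cores, step_time
    return best_cores


def toy_optimizer_time(cores, saturation=6):
    """Synthetic model: bandwidth saturates at `saturation` cores."""
    return 1.0 / min(cores, saturation) + 0.002 * cores


optimizer_cores = find_saturation_point(toy_optimizer_time, max_cores=16)
print(f"optimizer cores: {optimizer_cores}, left for pt_reserved_cores: {16 - optimizer_cores}")
```

With the toy model the sweep stops at the saturation point, and the remaining cores of a 16-core rank would go to pt_reserved_cores.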

delock avatar Aug 29 '25 02:08 delock

@delock Did a very quick test in the slurm setting. It looks like the current soft fallback still has issues under Slurm. For example, I requested 32 CPU cores, but the actual affinity was something like:

taskset -cp $$
pid 249296's current affinity list: 0,4-11,21-26,30-46

In this case, running without --bind_cores_to_rank leads to CPU memory allocation errors in cpu_adam during allocation.

Starting epoch 1/3
Zenflow: clearing selective optimizer states...
Zenflow: clearing selective optimizer states...
[2025-08-28 17:26:36,020] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | selective_optimizer_update: 63.94 | selective_optimizer_process: 282.85 | selective_optimizer_sync: 0.25
[2025-08-28 17:26:36,044] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | fwd_microstep: 279.88 | bwd_microstep: 3919.57 | bwd_inner_microstep: 111.57 | bwd_allreduce_microstep: 3802.59 | step_microstep: 0.00
Step 1, Loss: 14.3255, Time: 4201ms
[2025-08-28 17:26:37,876] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 31.39 | selective_optimizer_sync: 0.00
[2025-08-28 17:26:37,903] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | fwd_microstep: 45.77 | bwd_microstep: 1811.39 | bwd_inner_microstep: 99.24 | bwd_allreduce_microstep: 1712.13 | step_microstep: 0.00
Step 2, Loss: 11.8107, Time: 1858ms
[2025-08-28 17:26:38,140] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 26.84 | selective_optimizer_sync: 0.00
[2025-08-28 17:26:38,167] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | fwd_microstep: 44.31 | bwd_microstep: 219.24 | bwd_inner_microstep: 99.55 | bwd_allreduce_microstep: 119.68 | step_microstep: 0.00
Step 3, Loss: 8.7644, Time: 264ms
Setting pytorch affinity to [0], OMP_NUM_THREADS=1
Setting pytorch affinity to [31], OMP_NUM_THREADS=1
[2025-08-28 17:26:41,967] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-08-28 17:26:42,015] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-08-28 17:26:42,102] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-08-28 17:26:42,143] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-08-28 17:26:44,664] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
[2025-08-28 17:26:45,167] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
Using /sfs/gpfs/tardis/home/erc8gx/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /sfs/gpfs/tardis/home/erc8gx/.cache/torch_extensions/py310_cu118/cpu_adam/build.ninja...
/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Setting zenflow optimizer affinity to [4, 5, 6, 7, 8, 9, 10, 11, 21, 22, 23, 24, 25, 26, 30], OMP_NUM_THREADS=15
ninja: no work to do.
Loading extension module cpu_adam...
[2025-08-28 17:26:47,683] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 12.22 | selective_optimizer_step: 58.96 | selective_optimizer_sync: 0.00
[2025-08-28 17:26:47,684] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | fwd_microstep: 44.49 | bwd_microstep: 310.61 | bwd_inner_microstep: 98.63 | bwd_allreduce_microstep: 211.97 | step_microstep: 9160.24
[2025-08-28 17:26:47,684] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | fwd: 414.37 | bwd: 6260.81 | bwd_inner: 408.97 | bwd_allreduce: 5846.36 | step: 9160.27
Step 4, Loss: 5.6291, Time: 9516ms
Uncaught exception in compile_worker subprocess
Traceback (most recent call last):
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/_inductor/compile_worker/main.py", line 41, in main
    SubprocMain(args.workers, read_fd, write_fd).main()
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 204, in init
    self.pool = self._new_pool(nprocs, True)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 215, in _new_pool
    _warm_process_pool(pool, nprocs)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 304, in _warm_process_pool
    pool._adjust_process_count()
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/concurrent/futures/process.py", line 697, in _adjust_process_count
    self._spawn_process()
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/concurrent/futures/process.py", line 714, in _spawn_process
    p.start()
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/context.py", line 281, in _Popen
    return Popen(process_obj)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/popen_fork.py", line 19, in init
    self._launch(process_obj)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/popen_fork.py", line 66, in _launch
    self.pid = os.fork()
OSError: [Errno 12] Cannot allocate memory
Process SpawnProcess-5:
Traceback (most recent call last):
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/sfs/weka/scratch/erc8gx/codespace/DeepSpeed/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py", line 683, in zenflow_optimizer_process
    optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/sfs/weka/scratch/erc8gx/codespace/DeepSpeed/deepspeed/ops/adam/zenflow_cpu_adam.py", line 137, in parallel_step
    p.stale_param.data.copy(p.data.clone())
RuntimeError: [enforce fail at alloc_cpu.cpp:117] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 988065536 bytes. Error code 12 (Cannot allocate memory)
Time to load cpu_adam op: 0.3280184268951416 seconds
ZenFlowCPUAdam initialized with overlap step.
Adam Optimizer #0 is created with AVX2 arithmetic capability.
Config: alpha=0.001000, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1
Using /sfs/gpfs/tardis/home/erc8gx/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /sfs/gpfs/tardis/home/erc8gx/.cache/torch_extensions/py310_cu118/cpu_adam/build.ninja...
/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Setting zenflow optimizer affinity to [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46], OMP_NUM_THREADS=15
ninja: no work to do.
Loading extension module cpu_adam...
[2025-08-28 17:26:49,020] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 27.96 | selective_optimizer_step: 0.00 | selective_optimizer_sync: 0.00
[2025-08-28 17:26:49,071] [INFO] [logging.py:107:log_dist] [Rank 0] time (ms) | fwd_microstep: 46.96 | bwd_microstep: 1338.52 | bwd_inner_microstep: 98.93 | bwd_allreduce_microstep: 1239.57 | step_microstep: 0.00
Step 5, Loss: 2.6151, Time: 1386ms
Process SpawnProcess-5:
Traceback (most recent call last):
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/sfs/weka/scratch/erc8gx/codespace/DeepSpeed/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py", line 683, in zenflow_optimizer_process
    optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info)
  File "/scratch/erc8gx/devtools/miniconda3/envs/ds_a100/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/sfs/weka/scratch/erc8gx/codespace/DeepSpeed/deepspeed/ops/adam/zenflow_cpu_adam.py", line 137, in parallel_step
    p.stale_param.data.copy(p.data.clone())
RuntimeError: [enforce fail at alloc_cpu.cpp:117] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 988065536 bytes. Error code 12 (Cannot allocate memory)

However, if I run with:

deepspeed --bind_cores_to_rank --bind_core_list '4-11,21-26'

it works fine.

The root cause seems to be that the fallback logic assumes a contiguous core range, while Slurm often provides a sparse/non-contiguous affinity list. Two possible directions:

  1. At minimum, I should add documentation guidance that under Slurm, users should explicitly set --bind_cores_to_rank (and optionally --bind_core_list) to match Slurm’s CPU allocation, which they can check with taskset -cp $$.
  2. Try the old-fashioned fallback of splitting the affinity list returned by psutil.Process().cpu_affinity() among the ranks. This works fine in my Slurm environment, although with some performance degradation.
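
A minimal sketch of what direction 2 could look like, assuming the affinity list is simply sorted and cut into near-equal slices per rank. The helper name shard_cores is hypothetical and is not DeepSpeed's code; os.sched_getaffinity is used here as a standard-library stand-in for psutil.Process().cpu_affinity() and is only available on Linux.

```python
import os


def shard_cores(cores, num_ranks):
    """Split a possibly non-contiguous core list into near-equal slices, one per rank.

    Hypothetical helper for illustration; not part of DeepSpeed.
    """
    if num_ranks <= 0:
        raise ValueError("num_ranks must be positive")
    cores = sorted(cores)
    per_rank, extra = divmod(len(cores), num_ranks)
    shards, start = [], 0
    for rank in range(num_ranks):
        # The first `extra` ranks take one additional core each.
        size = per_rank + (1 if rank < extra else 0)
        shards.append(cores[start:start + size])
        start += size
    return shards


if __name__ == "__main__":
    # Linux-only call; fall back to all CPUs elsewhere.
    if hasattr(os, "sched_getaffinity"):
        allowed = os.sched_getaffinity(0)
    else:
        allowed = range(os.cpu_count() or 1)
    print(shard_cores(allowed, 2))
```

Because the split works on the list of allowed core IDs rather than on a numeric range, gaps in the Slurm allocation do not matter; it does not, however, take NUMA locality into account.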

Hi @Antlera, from the log we see:

Setting pytorch affinity to [0], OMP_NUM_THREADS=1
Setting pytorch affinity to [31], OMP_NUM_THREADS=1
Setting zenflow optimizer affinity to [4, 5, 6, 7, 8, 9, 10, 11, 21, 22, 23, 24, 25, 26, 30], OMP_NUM_THREADS=15
Setting zenflow optimizer affinity to [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46], OMP_NUM_THREADS=15

This means the fallback mechanism didn't assume a contiguous core range. Also, the code under this if statement (https://github.com/deepspeedai/DeepSpeed/pull/7506/files#diff-c36f7081aa96d0cca5833845684635eddfe454beb4a76c826bb3b91d1ac38401R811) ensures that the cores being bound are the intersection of the current affinity and the physical cores.
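
To make the intersection idea concrete, here is a rough sketch. It is not the code from the PR: the function name, the siblings mapping, and the convention of treating the lowest-numbered hyperthread sibling as the "physical" core are all assumptions made for illustration.

```python
def physical_cores_in_affinity(affinity, siblings):
    """Return allowed logical CPUs that represent a physical core.

    affinity: iterable of logical CPU ids the process may run on.
    siblings: dict mapping each logical CPU id to the tuple of hyperthread
              siblings sharing its core (hypothetical input format).
    Assumption: the lowest-numbered sibling stands for the physical core.
    """
    physical = {min(group) for group in siblings.values()}
    return sorted(physical & set(affinity))


# Example: 4 physical cores, 2 hyperthreads each (CPUs 4-7 are the second threads).
siblings = {cpu: (cpu % 4, cpu % 4 + 4) for cpu in range(8)}
print(physical_cores_in_affinity({1, 2, 5, 6}, siblings))
```

The result can only contain cores that are already in the current affinity, so a sparse Slurm allocation is respected by construction.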

It is strange that this error happens. What command are you using to reproduce the memory allocation error?

delock avatar Aug 29 '25 02:08 delock

@delock I used deepspeed --num_gpus=$GPUS_PER_NODE --master_port $MASTER_PORT finetune_llama.py. Let me double check I am at the right branch head.

Antlera avatar Aug 29 '25 03:08 Antlera

I am currently at commit 744399e Merge branch 'master' into gma/zenflow_affinity.

Antlera avatar Aug 29 '25 03:08 Antlera

I tried to emulate this situation with the following command:

taskset -c 0,4-11,21-26,30-46 deepspeed --num_gpus=2 finetune_llama.py --model_name Qwen/Qwen2.5-3B --output_dir output --lr 2e-5 --batch_size 8 --deepspeed_config zf_config.json --num_train_epochs 1

With this command I didn't get a memory allocation error in CPUAdam. The log shows that the zenflow optimizer affinity is the same as in the Slurm case:

Setting zenflow optimizer affinity to [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46], OMP_NUM_THREADS=15
Setting zenflow optimizer affinity to [4, 5, 6, 7, 8, 9, 10, 11, 21, 22, 23, 24, 25, 26, 30], OMP_NUM_THREADS=15

From htop I see that the busy cores are exactly the ones specified with taskset, so taskset should be an effective way to simulate the Slurm allocation.

This indicates the memory allocation error might have another cause. Does Slurm restrict memory usage? Also, if you set --bind_core_list to 0,4-11,21-26,30-46, does the workload still pass? One thing worth noting is that the commas in bind_core_list do not indicate how cores are separated among workers. This parameter lists all the cores available to all workers, and DeepSpeed shards them evenly among the workers.
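
A small sketch of the point about commas, assuming the list is first expanded into individual core IDs and only then cut evenly per worker. Both helper names are hypothetical and this is not DeepSpeed's own parser.

```python
def parse_core_list(spec):
    """Expand a string such as '0,4-11,21-26' into a sorted list of core ids."""
    cores = set()
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            low, high = part.split("-", 1)
            cores.update(range(int(low), int(high) + 1))
        else:
            cores.add(int(part))
    return sorted(cores)


def split_evenly(cores, num_workers):
    """Give each worker an equal-sized consecutive slice of the expanded list."""
    per_worker = len(cores) // num_workers
    return [cores[i * per_worker:(i + 1) * per_worker] for i in range(num_workers)]


cores = parse_core_list("0,4-11,21-26,30-46")
for worker, shard in enumerate(split_evenly(cores, 2)):
    print(worker, shard)
```

With two workers, the 32 cores split into two groups of 16 whose boundary falls between cores 30 and 31, not at any of the commas. Under this assumption the two groups line up with the pytorch and zenflow optimizer affinity lists printed in the log.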

delock avatar Aug 29 '25 03:08 delock

Hi @delock, thanks for validating and confirming this. I re-ran some tests. It's a little weird, but it works after all: in my SLURM environment, I found that by default (after the Slurm job is allocated) our code always starts allocating CPU cores from 0, regardless of what SLURM has actually assigned. That’s why it fails directly in the default case.

But once I add taskset, the CPU view gets overridden immediately, and then the core allocation works correctly, even if the ranges are non-contiguous. What’s more, taskset doesn’t need to match the “real” SLURM-assigned range exactly.

For example, in my SLURM job:

grep Cpus_allowed_list /proc/$$/status
Cpus_allowed_list:      64,97-127

But if I run with:

taskset -c 80-90 deepspeed ...
# or simply
taskset deepspeed ...

the actual binding still falls back into 64,97-127 (the SLURM-assigned range), ignoring the specific range I asked for. In other words, simply invoking taskset refreshes the affinity mask, and DeepSpeed will then shard cores correctly inside the allowed set.

The current code actually works perfectly once taskset is given. So maybe the right approach for now is to add this into the guidance docs (and maybe as a runtime hint in the code): if users see affinity/memory errors under SLURM, suggest trying taskset first.
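
As a rough idea of what such a runtime hint could look like: the function below is hypothetical, not existing DeepSpeed code. It assumes that the SLURM_JOB_ID environment variable is a good enough signal for "running under Slurm" and that a restricted affinity set is the situation worth flagging.

```python
import os


def slurm_affinity_hint(env, affinity, cpu_count):
    """Return a hint string when under Slurm with a restricted CPU set, else None.

    env: mapping of environment variables (e.g. os.environ).
    affinity: collection of logical CPU ids the process may use.
    cpu_count: total number of logical CPUs on the node.
    """
    if "SLURM_JOB_ID" not in env:
        return None
    if len(affinity) >= cpu_count:
        return None
    return (f"Running under Slurm with {len(affinity)} of {cpu_count} CPUs allowed. "
            "If you see affinity or memory allocation errors, try launching "
            "through taskset, e.g. 'taskset -c <cores> deepspeed ...'.")


if __name__ == "__main__" and hasattr(os, "sched_getaffinity"):
    hint = slurm_affinity_hint(os.environ, os.sched_getaffinity(0), os.cpu_count() or 1)
    if hint:
        print(hint)
```

Taking the environment and affinity as arguments keeps the check easy to test outside a Slurm job.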

Antlera avatar Sep 02 '25 03:09 Antlera

Thanks for the details, @Antlera. Let me read the Slurm docs to see if I can find any clue. If not, then let's add taskset as a practical hint.

delock avatar Sep 03 '25 03:09 delock