[BUG] CUDA error: invalid argument when using PyTorch 2.7 with DeepSpeed 0.16.4 and CUDA 12.8

Open: rpgmaker opened this issue 9 months ago • 58 comments

Describe the bug

I am currently unable to run the latest DeepSpeed version (0.16.4) with CUDA 12.8 and PyTorch 2.7. I receive the following error stack (GPU: RTX 3090 Ti FE):

[rank0]: RuntimeError: CUDA error: invalid argument
[rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

To Reproduce

Model Name: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

Use the following DeepSpeed config:

{
    "train_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 2e-5
        }
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },
    "fp16": {
        "enabled": true
    }
}

Expected behavior

I expect it to run and let me train the model without raising CUDA error: invalid argument.

ds_report output

[2025-03-18 20:48:02,997] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
 [WARNING]  FP Quantizer is using an untested triton version (3.2.0), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
gds .................... [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.7
 [WARNING]  using untested triton version (3.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/xxxx/lib/python3.12/site-packages/torch']
torch version .................... 2.7.0.dev20250309+cu128
deepspeed install path ........... ['/xxxx/lib/python3.12/site-packages/deepspeed']
deepspeed info ................... 0.16.4, unknown, unknown
torch cuda version ............... 12.8
torch hip version ................ None
nvcc version ..................... 12.8
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0
shared memory (/dev/shm) size .... 61.66 GB

Screenshots

N/A

System info (please complete the following information):

  • OS: Ubuntu 24.10
  • GPU count and types: 1x RTX 3090 Ti FE
  • Interconnects: N/A
  • Python version: 3.12.7

Launcher context

Are you launching your experiment with the deepspeed launcher, MPI, or something else? No

Docker context

Are you using a specific docker image that you can share? No

Additional context

N/A

rpgmaker, Mar 19 '25

Hi @rpgmaker, can you share the full error message you get, to help us identify the cause? Thanks!

hwchen2017, Mar 19 '25

Sure, here is the full stack trace:

[2025-03-19 14:54:53,885] [INFO] [stage_1_and_2.py:152:__init__] Round robin gradient partitioning: False
[rank0]: Traceback (most recent call last):
[rank0]:   File "/xxxxxx/main.py", line 41, in <module>
[rank0]:     run()
[rank0]:   File "/xxxxxx/main.py", line 36, in run
[rank0]:     setup_runner_qwen_1_5B()
[rank0]:   File "/xxxxxx/main.py", line 15, in setup_runner_qwen_1_5B
[rank0]:     runner.run(epoch=50)
[rank0]:   File "/xxxxxx/runner.py", line 47, in run
[rank0]:     self.trainer.train(epoch)
[rank0]:   File "/xxxxxx/trainer/deepspeed/deepspeed_trainer.py", line 39, in train
[rank0]:     self.__setup()
[rank0]:   File "/xxxxxx/trainer/deepspeed/deepspeed_trainer.py", line 73, in __setup
[rank0]:     model, optimizer, _, _ = deepspeed.initialize(
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "xxxxxx/lib/python3.12/site-packages/deepspeed/__init__.py", line 193, in initialize
[rank0]:     engine = DeepSpeedEngine(args=args,
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 317, in __init__
[rank0]:     self._configure_optimizer(optimizer, model_parameters)
[rank0]:   File "xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 1385, in _configure_optimizer
[rank0]:     self.optimizer = self._configure_zero_optimizer(basic_optimizer)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 1643, in _configure_zero_optimizer
[rank0]:     optimizer = DeepSpeedZeroOptimizer(
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 407, in __init__
[rank0]:     weights_partition = get_accelerator().pin_memory(weights_partition)
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "xxxxxx/lib/python3.12/site-packages/deepspeed/accelerator/cuda_accelerator.py", line 293, in pin_memory
[rank0]:     return tensor.pin_memory()
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: CUDA error: invalid argument
[rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

[rank0]:[W319 14:54:59.544769873 ProcessGroupNCCL.cpp:1476] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

rpgmaker, Mar 19 '25

Can you rerun the script after setting the environment variable CUDA_LAUNCH_BLOCKING=1 by

export CUDA_LAUNCH_BLOCKING=1

and then share the full stack trace? Thanks!
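If exporting the variable in a shell is inconvenient (for example when the script is launched from an IDE), it can also be set from Python. A minimal sketch, assuming the setting has to be in place before the CUDA context is created, so to be safe it should run before torch or deepspeed is imported; launch_blocking_enabled is a hypothetical helper for illustration only:

```python
import os

# Set before `import torch` / `import deepspeed` (or at least before the first
# CUDA call); once the CUDA runtime has initialized, changing it may have no effect.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def launch_blocking_enabled() -> bool:
    """Hypothetical helper: True if synchronous kernel launches were requested."""
    return os.environ.get("CUDA_LAUNCH_BLOCKING") == "1"


# import torch      # import torch/deepspeed only after the line above
# import deepspeed
```

With launches synchronous, the Python frame that raises should be the one that actually issued the failing CUDA call.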

hwchen2017, Mar 20 '25

Sure, here is the new stack trace with the environment variable set:

[2025-03-19 18:52:01,562] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
[2025-03-19 18:52:07,164] [INFO] [logging.py:128:log_dist] [Rank -1] DeepSpeed info: version=0.16.4, git-hash=unknown, git-branch=unknown
[2025-03-19 18:52:07,165] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-19 18:52:07,165] [INFO] [comm.py:673:init_distributed] Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...
[2025-03-19 18:52:07,223] [INFO] [comm.py:728:mpi_discovery] Discovered MPI settings of world_rank=0, local_rank=0, world_size=1, master_addr=10.0.0.246, master_port=29500
[2025-03-19 18:52:07,223] [INFO] [comm.py:689:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2025-03-19 18:52:07,243] [INFO] [config.py:734:__init__] Config mesh_device None world_size = 1
[2025-03-19 18:52:07,344] [INFO] [logging.py:128:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
Using /home/xxxxxx/.cache/torch_extensions/py312_cu128 as PyTorch extensions root...
Emitting ninja build file /home/xxxxxx/.cache/torch_extensions/py312_cu128/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cpu_adam...
Time to load cpu_adam op: 2.1535685062408447 seconds
Adam Optimizer #0 is created with AVX2 arithmetic capability.
Config: alpha=0.000020, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1
[2025-03-19 18:52:10,568] [INFO] [logging.py:128:log_dist] [Rank 0] Using DeepSpeed Optimizer param name adamw as basic optimizer
[2025-03-19 18:52:10,568] [INFO] [logging.py:128:log_dist] [Rank 0] Removing param_group that has no 'params' in the basic Optimizer
[2025-03-19 18:52:10,574] [INFO] [logging.py:128:log_dist] [Rank 0] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam
[2025-03-19 18:52:10,574] [INFO] [utils.py:59:is_zero_supported_optimizer] Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=<class 'deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam'>
[2025-03-19 18:52:10,574] [INFO] [logging.py:128:log_dist] [Rank 0] Creating torch.float16 ZeRO stage 2 optimizer
[2025-03-19 18:52:10,574] [INFO] [stage_1_and_2.py:149:__init__] Reduce bucket size 200000000
[2025-03-19 18:52:10,574] [INFO] [stage_1_and_2.py:150:__init__] Allgather bucket size 200000000
[2025-03-19 18:52:10,574] [INFO] [stage_1_and_2.py:151:__init__] CPU Offload: True
[2025-03-19 18:52:10,574] [INFO] [stage_1_and_2.py:152:__init__] Round robin gradient partitioning: False
[rank0]: Traceback (most recent call last):
[rank0]:   File "/xxxxxx/main.py", line 41, in <module>
[rank0]:     run()
[rank0]:   File "/xxxxxx/main.py", line 36, in run
[rank0]:     setup_runner_qwen_1_5B()
[rank0]:   File "/xxxxxx/main.py", line 15, in setup_runner_qwen_1_5B
[rank0]:     runner.run(epoch=50)
[rank0]:   File "/xxxxxx/runner.py", line 47, in run
[rank0]:     self.trainer.train(epoch)
[rank0]:   File "/xxxxxx/trainer/deepspeed/deepspeed_trainer.py", line 39, in train
[rank0]:     self.__setup()
[rank0]:   File "/xxxxxx/trainer/deepspeed/deepspeed_trainer.py", line 73, in __setup
[rank0]:     model, optimizer, _, _ = deepspeed.initialize(
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/xxxxxx/lib/python3.12/site-packages/deepspeed/__init__.py", line 193, in initialize
[rank0]:     engine = DeepSpeedEngine(args=args,
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 317, in __init__
[rank0]:     self._configure_optimizer(optimizer, model_parameters)
[rank0]:   File "/xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 1385, in _configure_optimizer
[rank0]:     self.optimizer = self._configure_zero_optimizer(basic_optimizer)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 1643, in _configure_zero_optimizer
[rank0]:     optimizer = DeepSpeedZeroOptimizer(
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 407, in __init__
[rank0]:     weights_partition = get_accelerator().pin_memory(weights_partition)
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/xxxxxx/lib/python3.12/site-packages/deepspeed/accelerator/cuda_accelerator.py", line 293, in pin_memory
[rank0]:     return tensor.pin_memory()
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: CUDA error: invalid argument
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

[rank0]:[W319 18:52:15.078933979 ProcessGroupNCCL.cpp:1476] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

rpgmaker, Mar 20 '25

It appears that pin_memory() triggers this error. Can you try setting "pin_memory": false to see whether that solves the problem? Memory pinning can incur high memory consumption and cause memory allocation errors. See #3481.

hwchen2017, Mar 22 '25

I tried setting "pin_memory": false and still hit the same issue.

config:

{
    "train_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 2e-5
        }
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": false
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },
    "fp16": {
        "enabled": true
    }
}

error trace:

[2025-03-21 17:51:00,755] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
[2025-03-21 17:51:06,216] [INFO] [logging.py:128:log_dist] [Rank -1] DeepSpeed info: version=0.16.4, git-hash=unknown, git-branch=unknown
[2025-03-21 17:51:06,216] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-21 17:51:06,216] [INFO] [comm.py:673:init_distributed] Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...
[2025-03-21 17:51:06,275] [INFO] [comm.py:728:mpi_discovery] Discovered MPI settings of world_rank=0, local_rank=0, world_size=1, master_addr=10.0.0.246, master_port=29500
[2025-03-21 17:51:06,275] [INFO] [comm.py:689:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2025-03-21 17:51:06,295] [INFO] [config.py:734:__init__] Config mesh_device None world_size = 1
[2025-03-21 17:51:06,403] [INFO] [logging.py:128:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
Using /home/xxxxxx/.cache/torch_extensions/py312_cu128 as PyTorch extensions root...
Emitting ninja build file /home/xxxxxx/.cache/torch_extensions/py312_cu128/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cpu_adam...
Time to load cpu_adam op: 2.176302909851074 seconds
Adam Optimizer #0 is created with AVX2 arithmetic capability.
Config: alpha=0.000020, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1
[2025-03-21 17:51:09,684] [INFO] [logging.py:128:log_dist] [Rank 0] Using DeepSpeed Optimizer param name adamw as basic optimizer
[2025-03-21 17:51:09,684] [INFO] [logging.py:128:log_dist] [Rank 0] Removing param_group that has no 'params' in the basic Optimizer
[2025-03-21 17:51:09,689] [INFO] [logging.py:128:log_dist] [Rank 0] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam
[2025-03-21 17:51:09,689] [INFO] [utils.py:59:is_zero_supported_optimizer] Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=<class 'deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam'>
[2025-03-21 17:51:09,689] [INFO] [logging.py:128:log_dist] [Rank 0] Creating torch.float16 ZeRO stage 2 optimizer
[2025-03-21 17:51:09,689] [INFO] [stage_1_and_2.py:149:__init__] Reduce bucket size 200000000
[2025-03-21 17:51:09,689] [INFO] [stage_1_and_2.py:150:__init__] Allgather bucket size 200000000
[2025-03-21 17:51:09,689] [INFO] [stage_1_and_2.py:151:__init__] CPU Offload: True
[2025-03-21 17:51:09,689] [INFO] [stage_1_and_2.py:152:__init__] Round robin gradient partitioning: False
[rank0]: Traceback (most recent call last):
[rank0]:   File "/xxxxxx/main.py", line 41, in <module>
[rank0]:     run()
[rank0]:   File "/xxxxxx/main.py", line 36, in run
[rank0]:     setup_runner_qwen_1_5B()
[rank0]:   File "/xxxxxx/main.py", line 15, in setup_runner_qwen_1_5B
[rank0]:     runner.run(epoch=50)
[rank0]:   File "/xxxxxx/runner.py", line 47, in run
[rank0]:     self.trainer.train(epoch)
[rank0]:   File "/xxxxxx/trainer/deepspeed/deepspeed_trainer.py", line 39, in train
[rank0]:     self.__setup()
[rank0]:   File "/xxxxxx/trainer/deepspeed/deepspeed_trainer.py", line 73, in __setup
[rank0]:     model, optimizer, _, _ = deepspeed.initialize(
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/xxxxxx/lib/python3.12/site-packages/deepspeed/__init__.py", line 193, in initialize
[rank0]:     engine = DeepSpeedEngine(args=args,
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 317, in __init__
[rank0]:     self._configure_optimizer(optimizer, model_parameters)
[rank0]:   File "/home/xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 1385, in _configure_optimizer
[rank0]:     self.optimizer = self._configure_zero_optimizer(basic_optimizer)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 1643, in _configure_zero_optimizer
[rank0]:     optimizer = DeepSpeedZeroOptimizer(
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 407, in __init__
[rank0]:     weights_partition = get_accelerator().pin_memory(weights_partition)
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/xxxxxx/lib/python3.12/site-packages/deepspeed/accelerator/cuda_accelerator.py", line 293, in pin_memory
[rank0]:     return tensor.pin_memory()
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: CUDA error: invalid argument
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

[rank0]:[W321 17:51:14.766289001 ProcessGroupNCCL.cpp:1476] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

rpgmaker, Mar 22 '25

Hi @tjruwase, can you take a look at this? Related PR: #4131

hwchen2017, Mar 22 '25

@rpgmaker, can you please share your full repro steps so I can try on my side? I see that torch has changed the pin_memory() behavior.

Also, can you report the output of the following:

  1. python -c "import torch; from deepspeed.accelerator import get_accelerator; get_accelerator().pin_memory(torch.empty(1024, device='cuda'))"
  2. python -c "import torch; from deepspeed.accelerator import get_accelerator; get_accelerator().pin_memory(torch.empty(1024, device='cpu'))"

tjruwase, Mar 22 '25

Sure, let me work on it and I will provide the steps with examples. Thanks!

rpgmaker, Mar 22 '25

Below are the results of both the CUDA and the CPU runs. It seems the CPU version did not fail:

python -c "import torch; from deepspeed.accelerator import get_accelerator; get_accelerator().pin_memory(torch.empty(1024, device='cuda'))"

output:
[2025-03-22 12:52:09,315] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/xxxxxxx/lib/python3.12/site-packages/deepspeed/accelerator/cuda_accelerator.py", line 293, in pin_memory
    return tensor.pin_memory()
           ^^^^^^^^^^^^^^^^^^^
RuntimeError: cannot pin 'torch.cuda.FloatTensor' only dense CPU tensors can be pinned

python -c "import torch; from deepspeed.accelerator import get_accelerator; get_accelerator().pin_memory(torch.empty(1024, device='cpu'))"

output:
[2025-03-22 12:53:15,468] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)

For a full repro, use the following steps:

  • Create a Python environment
  • Make sure CUDA 12.8 is installed
  • Run pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
  • Install the following requirements:
      deepspeed
      datasets
      transformers
      sentence_transformers
      mpi4py
  • Download the following Python example to run the training: https://gist.github.com/rpgmaker/5e35561682c3e6cdd4b802a9127cfb9d
  • Run it

rpgmaker, Mar 22 '25

@rpgmaker, thanks for sharing the results of the 'cpu' and 'cuda' tensor pinning. That is what I expected.

Also, thanks for sharing the repro steps. Unfortunately, I am unable to reproduce the error on my A6000. Below is my log showing successful ZeRO initialization, which is the step where you are seeing the error:

[2025-03-22 20:13:40,481] [INFO] [logging.py:107:log_dist] [Rank 0] Using DeepSpeed Optimizer param name adamw as basic optimizer
[2025-03-22 20:13:40,482] [INFO] [logging.py:107:log_dist] [Rank 0] Removing param_group that has no 'params' in the basic Optimizer
[2025-03-22 20:13:40,500] [INFO] [logging.py:107:log_dist] [Rank 0] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam
[2025-03-22 20:13:40,500] [INFO] [utils.py:59:is_zero_supported_optimizer] Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=<class 'deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam'>
[2025-03-22 20:13:40,500] [INFO] [logging.py:107:log_dist] [Rank 0] Creating torch.bfloat16 ZeRO stage 2 optimizer
[2025-03-22 20:13:40,501] [INFO] [stage_1_and_2.py:149:__init__] Reduce bucket size 200000000
[2025-03-22 20:13:40,501] [INFO] [stage_1_and_2.py:150:__init__] Allgather bucket size 200000000
[2025-03-22 20:13:40,501] [INFO] [stage_1_and_2.py:151:__init__] CPU Offload: True
[2025-03-22 20:13:40,501] [INFO] [stage_1_and_2.py:152:__init__] Round robin gradient partitioning: False
[2025-03-22 20:13:51,913] [INFO] [utils.py:781:see_memory_usage] Before initializing optimizer states
[2025-03-22 20:13:51,914] [INFO] [utils.py:782:see_memory_usage] MA 5.49 GB         Max_MA 5.49 GB         CA 7.09 GB         Max_CA 7 GB 
[2025-03-22 20:13:51,914] [INFO] [utils.py:789:see_memory_usage] CPU Virtual Memory:  used = 20.69 GB, percent = 8.2%
[2025-03-22 20:13:52,315] [INFO] [utils.py:781:see_memory_usage] After initializing optimizer states
[2025-03-22 20:13:52,315] [INFO] [utils.py:782:see_memory_usage] MA 5.49 GB         Max_MA 5.49 GB         CA 7.09 GB         Max_CA 7 GB 
[2025-03-22 20:13:52,315] [INFO] [utils.py:789:see_memory_usage] CPU Virtual Memory:  used = 27.07 GB, percent = 10.8%
[2025-03-22 20:13:52,316] [INFO] [stage_1_and_2.py:556:__init__] optimizer state initialized
[2025-03-22 20:13:52,432] [INFO] [utils.py:781:see_memory_usage] After initializing ZeRO optimizer
[2025-03-22 20:13:52,433] [INFO] [utils.py:782:see_memory_usage] MA 5.49 GB         Max_MA 5.49 GB         CA 7.09 GB         Max_CA 7 GB 

I tweaked the ds_config to use bf16 instead of fp16 to avoid overflows; here is the output of the run:

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Prompt: Pjhjsasdaranceasy?
Response:  Mountain World.<|endoftext|>
---
Prompt: Who painted the Mona Lisa?
Response: 
---
Prompt: Explain the theory of relativity.
Response:  The theory of relativity is a fundamental concept in physics that describes how matter and energy interact. It encompasses both special relativity and general relativity. Special relativity explains how objects move at high speeds relative to each other, with time dilation and length contraction. General relativity describes gravity as the curvature of spacetime caused by mass and energy. Together, they provide a comprehensive framework for understanding the universe.
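On the fp16 to bf16 switch: fp16's largest finite value is 65504, while bf16 keeps float32's 8-bit exponent, so it covers roughly float32's range at reduced precision. The sketch below emulates both formats with the standard library only, using struct's "e" (IEEE half-precision) format for fp16 and, as a simplifying assumption, truncating a float32 to its upper 16 bits for bf16 (real conversions usually round to nearest even). It illustrates the range difference; it is not a model of DeepSpeed's loss scaling.

```python
import struct

FP16_MAX = 65504.0  # largest finite IEEE binary16 value


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE binary16; raises OverflowError if out of range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Approximate bfloat16 by keeping the top 16 bits of the float32 encoding.

    This truncates; hardware conversion typically rounds to nearest even.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


if __name__ == "__main__":
    print(to_fp16(FP16_MAX))  # 65504.0
    try:
        to_fp16(70000.0)
    except OverflowError as exc:
        print("fp16 overflow:", exc)
    print(to_bf16(70000.0))  # 69632.0: coarser, but no overflow
```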

tjruwase, Mar 23 '25

I am unsure of the best way to proceed. Below is my ds_report, which differs slightly from the one in your original post:

DeepSpeed general environment info:
torch install path ............... ['/py_venv/torch_2_cuda/lib/python3.12/site-packages/torch']
torch version .................... 2.8.0.dev20250322+cu128
deepspeed install path ........... ['/py_venv/torch_2_cuda/lib/python3.12/site-packages/deepspeed']
deepspeed info ................... 0.16.4, unknown, unknown
torch cuda version ............... 12.8
torch hip version ................ None
nvcc version ..................... 12.8
deepspeed wheel compiled w. ...... torch 2.8, cuda 12.8

tjruwase, Mar 23 '25

File "/home/xxxxxxx/lib/python3.12/site-packages/deepspeed/accelerator/cuda_accelerator.py", line 293, in pin_memory return tensor.pin_memory()

@rpgmaker, can you try changing the line above to return tensor.cpu().pin_memory()?

tjruwase, Mar 23 '25

I tried the .cpu().pin_memory() change and ended up with the same error:

[rank0]:   File "/home/xxxxxx/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 407, in __init__
[rank0]:     weights_partition = get_accelerator().pin_memory(weights_partition)
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/xxxxxx/lib/python3.12/site-packages/deepspeed/accelerator/cuda_accelerator.py", line 293, in pin_memory
[rank0]:     return tensor.cpu().pin_memory()
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: CUDA error: invalid argument
[rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

rpgmaker, Mar 23 '25

That is very strange. Can you try printing tensor.dtype, tensor.shape and tensor.device before the pin_memory() call?

tjruwase, Mar 23 '25

Here is the result:

tensor.dtype:  torch.float32
tensor.shape:  torch.Size([1776255488])
tensor.device:  cpu
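For scale, the partition being pinned is a single flat float32 buffer. Assuming 4 bytes per element (float32, as printed above), its size works out as follows:

```python
numel = 1_776_255_488  # tensor.shape printed above
bytes_per_elem = 4     # torch.float32
nbytes = numel * bytes_per_elem

print(f"{nbytes:,} bytes")          # 7,105,021,952 bytes
print(f"{nbytes / 2**30:.2f} GiB")  # 6.62 GiB
```

So this call asks for about 7.1 GB of host memory to be page-locked in one allocation.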

code changes:

def pin_memory(self, tensor, align_bytes=1):
    print("tensor.dtype: ", tensor.dtype)
    print("tensor.shape: ", tensor.shape)
    print("tensor.device: ", tensor.device)
    return tensor.cpu().pin_memory()

ds_config:

ds_config = {
        "train_batch_size": batch_size,
        "gradient_accumulation_steps": 1,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": learning_rate
            }
        },
        "bf16": {
            "enabled": torch.cuda.is_available()
        },
        "fp16": {
            "enabled": False
        },
        "zero_optimization": {
            "stage": 2, # Use stage 2 for memory efficiency
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": False
            },
            "allgather_partitions": True,
            "allgather_bucket_size": 2e8,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 2e8,
            "contiguous_gradients": True
        }
    }

rpgmaker, Mar 23 '25

tensor.dtype: torch.float32 tensor.shape: torch.Size([1776255488]) tensor.device: cpu

@rpgmaker, thanks. Those are the expected results and do not explain the error. I don't have a 3090 Ti FE to try this on. Can you retry the following, now adjusted for the tensor size?

python -c "import torch; from deepspeed.accelerator import get_accelerator; get_accelerator().pin_memory(torch.empty(1776255488, device='cpu'))"

tjruwase, Mar 23 '25

This is the result. It seems that allocating/pinning the memory is the issue:

python -c "import torch; from deepspeed.accelerator import get_accelerator; get_accelerator().pin_memory(torch.empty(1776255488, device='cpu'))"
[2025-03-23 11:54:42,621] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)
tensor.dtype:  torch.float32
tensor.shape:  torch.Size([1776255488])
tensor.device:  cpu
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/xxxxxxxx/lib/python3.12/site-packages/deepspeed/accelerator/cuda_accelerator.py", line 296, in pin_memory
    return tensor.cpu().pin_memory()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

rpgmaker, Mar 23 '25

@rpgmaker, that is interesting. Maybe the tensor size is the cause, since we already saw that a size of 1024 worked. Can you test a sweep of tensor sizes, starting from 1024 and doubling each time, to find the failure point?
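A minimal sketch of such a doubling sweep, assuming only that a failed pin surfaces as a RuntimeError, as in the tracebacks above. find_failure_point and pin_cpu_tensor are hypothetical helpers, not DeepSpeed APIs; torch is imported lazily inside pin_cpu_tensor so the search logic itself has no torch or GPU dependency.

```python
from typing import Callable, Optional, Tuple


def find_failure_point(
    try_alloc: Callable[[int], None],
    start: int = 1024,
    max_numel: int = 2**32,
) -> Tuple[Optional[int], Optional[int]]:
    """Double the element count from `start` until `try_alloc` raises RuntimeError.

    Returns (last_ok, first_fail). first_fail is None if max_numel was reached
    without a failure; last_ok is None if even `start` failed.
    """
    last_ok: Optional[int] = None
    numel = start
    while numel <= max_numel:
        try:
            try_alloc(numel)
        except RuntimeError:
            return last_ok, numel
        last_ok = numel
        numel *= 2
    return last_ok, None


def pin_cpu_tensor(numel: int) -> None:
    """Hypothetical probe: allocate a float32 CPU tensor and pin it."""
    import torch  # lazy import: only needed when actually probing the machine

    torch.empty(numel, dtype=torch.float32, device="cpu").pin_memory()
```

On the affected machine one would call find_failure_point(pin_cpu_tensor) and, if a finer answer is needed, bisect between the two sizes it returns.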

tjruwase, Mar 23 '25

"offload_optimizer": { "device": "cpu", "pin_memory": False

One workaround is to make this particular pinning respect the pin_memory setting in the ds_config. To unblock you, please try changing the following condition to if self.cpu_offload_pin_memory:

https://github.com/deepspeedai/DeepSpeed/blob/1ca83a6bb9f3fffdb98c94093ab48605294241ae/deepspeed/runtime/zero/stage_1_and_2.py#L406-L407
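To illustrate why "pin_memory": false did not help, here is a simplified model of the two conditions (the original if self.cpu_offload: versus the proposed if self.cpu_offload_pin_memory:) evaluated against the offload_optimizer section of the ds_config. The helper names are hypothetical, the sketch assumes an omitted pin_memory means false, and it is not DeepSpeed's actual config parsing.

```python
from typing import Optional


def offload_enabled(offload_cfg: Optional[dict]) -> bool:
    """Rough stand-in for self.cpu_offload: an offload device other than 'none'."""
    return offload_cfg is not None and offload_cfg.get("device", "none") != "none"


def pins_original(offload_cfg: Optional[dict]) -> bool:
    # Original condition `if self.cpu_offload:` ignores the pin_memory flag.
    return offload_enabled(offload_cfg)


def pins_patched(offload_cfg: Optional[dict]) -> bool:
    # Proposed condition `if self.cpu_offload_pin_memory:` also needs pin_memory.
    return offload_enabled(offload_cfg) and bool(offload_cfg.get("pin_memory", False))


# The offload_optimizer section from the reporter's second config.
reporter_cfg = {"device": "cpu", "pin_memory": False}
print(pins_original(reporter_cfg))  # True: the partition is pinned anyway
print(pins_patched(reporter_cfg))   # False: pinning is skipped
```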

tjruwase, Mar 23 '25

Thanks, will give it a try and let you know the results.

rpgmaker, Mar 23 '25

The condition in question, from DeepSpeed/deepspeed/runtime/zero/stage_1_and_2.py (lines 406 to 407 in 1ca83a6):

if self.cpu_offload:
    weights_partition = get_accelerator().pin_memory(weights_partition)

Using this option actually caused my VS Code to crash.

rpgmaker, Mar 23 '25

@tjruwase, I ran a memory test to see how much I can allocate, and it failed here:

tensor.dtype:  torch.float32
tensor.shape:  torch.Size([268493824])
tensor.device:  cpu
Traceback (most recent call last):
  File "/home/xxxxxxxx/test_tensor.py", line 8, in <module>
    get_accelerator().pin_memory(torch.empty(start * increase, device='cpu'))
  File "/home/xxxxxxxx/lib/python3.12/site-packages/deepspeed/accelerator/cuda_accelerator.py", line 296, in pin_memory
    return tensor.cpu().pin_memory()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
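The failing size sits just above 1 GiB. Assuming the reporter's loop (numel = 1024 * increase, with increase stepping 1, 101, 201, ...) and 4-byte float32 elements, the first failing size and the size tried one iteration earlier compare to 2**30 bytes as follows. This only locates the threshold; it does not by itself explain why pinning beyond it fails on this system.

```python
GIB = 2**30
BYTES_PER_ELEM = 4        # torch.float32
START, STEP = 1024, 100   # reporter's loop: numel = START * increase; increase += STEP

first_fail_numel = 268_493_824               # tensor.shape from the traceback above
increase = first_fail_numel // START         # multiplier at the failing iteration
last_ok_numel = START * (increase - STEP)    # size tried one iteration earlier

first_fail_bytes = first_fail_numel * BYTES_PER_ELEM
last_ok_bytes = last_ok_numel * BYTES_PER_ELEM

print(f"last OK:    {last_ok_bytes:,} bytes ({last_ok_bytes - GIB:+,} vs 1 GiB)")
# last OK:    1,073,565,696 bytes (-176,128 vs 1 GiB)
print(f"first fail: {first_fail_bytes:,} bytes ({first_fail_bytes - GIB:+,} vs 1 GiB)")
# first fail: 1,073,975,296 bytes (+233,472 vs 1 GiB)
```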

rpgmaker, Mar 23 '25

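Here is the code snippet for the test:

import torch
from deepspeed.accelerator import get_accelerator

start = 1024  # base tensor size in elements (float32, 4 bytes each)
increase = 1  # size multiplier, grown by 100 each iteration

while True:  # keep growing the tensor until pin_memory() raises
    get_accelerator().pin_memory(torch.empty(start * increase, device='cpu'))
    increase += 100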

rpgmaker, Mar 23 '25

Regarding the VS Code crash after changing the condition: that is odd. What happens if you comment out that code instead?

tjruwase, Mar 24 '25

Regarding your test snippet, can you further reduce it to the following, without the DeepSpeed accelerator?

import torch

start = 1024
increase = 1

while True:
    torch.empty(start * increase, device='cpu').pin_memory()
    increase += 100

tjruwase, Mar 24 '25

Switching the code to just torch.empty produces the following result:

Traceback (most recent call last):
  File "/home/xxxxxxx/test_tensor.py", line 7, in <module>
    torch.empty(start * increase, device='cpu')
RuntimeError: [enforce fail at alloc_cpu.cpp:119] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 141011357696 bytes. Error code 12 (Cannot allocate memory)

rpgmaker, Mar 24 '25

@tjruwase, regarding commenting out the code: it started printing the epochs and was training, but then crashed at the end when trying to run the test phase.

rpgmaker, Mar 24 '25

torch.empty(start * increase, device='cpu')

RuntimeError: [enforce fail at alloc_cpu.cpp:119] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 141011357696 bytes. Error code 12 (Cannot allocate memory)

How much DRAM do you have? 141011357696 bytes is ~141 GB, so that might explain the error.
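One way to answer the DRAM question from Python, assuming a Linux (or similar Unix) system where os.sysconf exposes SC_PAGE_SIZE and SC_PHYS_PAGES; this is not available on Windows. physical_ram_bytes and fits_in_ram are hypothetical helpers. Pinning may be further constrained by the OS, so fitting in total RAM is a necessary rather than a sufficient condition.

```python
import os


def physical_ram_bytes() -> int:
    """Total physical memory as reported by sysconf (Linux and most Unix systems)."""
    return os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")


def fits_in_ram(nbytes: int) -> bool:
    """Necessary (not sufficient) check that one allocation could fit in DRAM."""
    return nbytes <= physical_ram_bytes()


requested = 141_011_357_696  # bytes, from the allocator error above
print(f"DRAM: {physical_ram_bytes() / 1e9:.1f} GB, requested: {requested / 1e9:.1f} GB")
print("fits:", fits_in_ram(requested))
```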

tjruwase, Mar 24 '25

File "/home/xxxxxxx/test_tensor.py", line 7, in torch.empty(start * increase, device='cpu')

Did you remove the pin_memory() in your test?

tjruwase, Mar 24 '25