
CUDA Error Persists in Qwen GRPO Training Despite Setting VLLM_ATTENTION_BACKEND=XFORMERS

Status: Open · AIBionics opened this issue 1 year ago · 3 comments

When training Qwen, I encountered a CUDA error after a few steps.

I have set VLLM_ATTENTION_BACKEND to XFORMERS and confirmed that the environment variable is in effect.

Before running ray start, I exported the variable. After _init_with_resource_pool, I printed VLLM_ATTENTION_BACKEND on every rank, and all of them showed XFORMERS.
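A process's environment is copied when the process is launched, so later exports in the parent shell do not reach workers that are already running. That is why the variable has to be exported before ray start on every node. As a stdlib-only sketch of the per-rank check described above, the code below simulates each rank with a fresh Python interpreter started via subprocess. This stand-in for a Ray worker is an assumption: it is not how verl or Ray actually spawn workers, and the helper names (read_in_fresh_process, check_ranks) are hypothetical.

```python
import os
import subprocess
import sys
from typing import Dict, List, Optional

VAR = "VLLM_ATTENTION_BACKEND"


def read_in_fresh_process(env: Optional[Dict[str, str]] = None) -> Optional[str]:
    """Start a new interpreter (standing in for one rank) and return the value it sees."""
    code = f"import os; print(os.environ.get({VAR!r}, ''))"
    result = subprocess.run(
        [sys.executable, "-c", code],
        env=env,  # None means: inherit this process's current environment
        capture_output=True,
        text=True,
        check=True,
    )
    value = result.stdout.strip()
    return value or None


def check_ranks(world_size: int, expected: str = "XFORMERS") -> List[Optional[str]]:
    """Collect the value every simulated rank sees and fail loudly on any mismatch."""
    seen = [read_in_fresh_process() for _ in range(world_size)]
    bad = {rank: value for rank, value in enumerate(seen) if value != expected}
    if bad:
        raise RuntimeError(f"ranks with unexpected {VAR}: {bad}")
    return seen


if __name__ == "__main__":
    os.environ[VAR] = "XFORMERS"  # equivalent of `export` before launching workers
    print(check_ranks(world_size=4))
```

Because each child is launched after the assignment, every simulated rank reports XFORMERS; a child launched with an environment that lacks the variable reports None.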


python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=128 \
    data.val_batch_size=128 \
    data.max_prompt_length=2048 \
    data.max_response_length=8192 \
    actor_rollout_ref.model.path=/root/Qwen2.5-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.grad_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.rollout.disable_log_stats=False \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=16 \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console','mlflow'] \
    trainer.project_name='kk' \
    trainer.experiment_name='qwen_7b_rl' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=2 \
    trainer.save_freq=20 \
    trainer.test_freq=20 \
    trainer.total_epochs=15 "$@"
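To make the per-step rollout load of this command concrete, the sketch below multiplies out the relevant settings. It assumes one vLLM engine per tensor-parallel group and an even split of the rollout batch across engines. That mapping is an assumption about verl's data dispatch, not something stated in this issue, and RolloutConfig / rollout_scale are hypothetical names.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class RolloutConfig:
    # Defaults copied from the command above.
    train_batch_size: int = 128          # data.train_batch_size (prompts per step)
    rollout_n: int = 16                  # actor_rollout_ref.rollout.n (samples per prompt)
    max_prompt_length: int = 2048        # data.max_prompt_length
    max_response_length: int = 8192      # data.max_response_length
    n_gpus_per_node: int = 8             # trainer.n_gpus_per_node
    nnodes: int = 2                      # trainer.nnodes
    tensor_model_parallel_size: int = 2  # actor_rollout_ref.rollout.tensor_model_parallel_size


def rollout_scale(cfg: RolloutConfig) -> Dict[str, int]:
    """Back-of-the-envelope rollout sizes for one training step."""
    world_size = cfg.n_gpus_per_node * cfg.nnodes
    # Assumption: one vLLM engine per tensor-parallel group of GPUs.
    num_engines = world_size // cfg.tensor_model_parallel_size
    sequences = cfg.train_batch_size * cfg.rollout_n
    # Assumption: sampled sequences are split evenly across engines.
    per_engine = sequences // num_engines
    max_tokens = cfg.max_prompt_length + cfg.max_response_length
    return {
        "world_size": world_size,
        "num_engines": num_engines,
        "sequences_per_step": sequences,
        "sequences_per_engine": per_engine,
        "max_tokens_per_sequence": max_tokens,
        # Upper bound only: vLLM does not need to hold all of these at once.
        "worst_case_tokens_per_engine": per_engine * max_tokens,
    }


if __name__ == "__main__":
    for key, value in rollout_scale(RolloutConfig()).items():
        print(f"{key:<30} {value}")
```

With the command's values this gives 2048 sampled sequences per step (128 prompts times 16 samples), 256 per engine under the even-split assumption, and at most 10240 tokens per sequence.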

Env

Python 3.11.3

bitsandbytes 0.45.2
databricks-sdk 0.43.0
datasets 2.21.0
deepspeed 0.15.0
flash-attn 2.6.1
gguf 0.10.0
latex2sympy2 1.9.1
liger_kernel 0.5.2
lightning-utilities 0.12.0
llvmlite 0.44.0
ninja 1.11.1.1
nltk 3.9.1
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-ml-py 12.560.30
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.1.105
peft 0.12.0
safetensors 0.4.5
scikit-learn 1.6.1
scipy 1.14.1
tensordict 0.5.0
torch 2.4.0
torchaudio 2.5.1
torchmetrics 1.6.1
torchvision 0.19.0
transformers 4.46.1
transformers-stream-generator 0.0.5
triton 3.0.0
vllm 0.6.3
wandb 0.19.6
xformers 0.0.27.post2

Error

(WorkerDict pid=102727, ip=10.95.236.82)   with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined] [repeated 15x across cluster]
Traceback (most recent call last):
  File "/root/workspace/env_run/verl/verl/trainer/main_ppo.py", line 100, in main
    run_ppo(config)
  File "/root/workspace/env_run/verl/verl/trainer/main_ppo.py", line 125, in run_ppo
    ray.get(main_task.remote(config, compute_score))
  File "/root/venv/lib/python3.11/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/ray/_private/worker.py", line 2772, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/ray/_private/worker.py", line 919, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::main_task() (pid=129449, ip=10.95.236.98)
  File "/root/workspace/env_run/verl/verl/trainer/main_ppo.py", line 211, in main_task
    trainer.fit()
  File "/root/workspace/env_run/verl/verl/trainer/ppo/ray_trainer.py", line 744, in fit
    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/env_run/verl/verl/single_controller/ray/base.py", line 42, in func
    output = ray.get(output)
             ^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError(RuntimeError): ray::WorkerDict.actor_rollout_generate_sequences() (pid=102727, ip=10.95.236.82, actor_id=3d46c2bf190029087ce109a20a000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f79eb367850>)
  File "/root/venv/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 1708, in execute_model
    output: SamplerOutput = self.model.sample(
                            ^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/vllm/model_executor/models/qwen2.py", line 433, in sample
    next_tokens = self.sampler(logits, sampling_metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/vllm/model_executor/layers/sampler.py", line 274, in forward
    maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
                                                                 ^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/vllm/model_executor/layers/sampler.py", line 878, in _sample
    return _sample_with_torch(
           ^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/vllm/model_executor/layers/sampler.py", line 847, in _sample_with_torch
    return get_pythonized_sample_results(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/vllm/model_executor/layers/sampler.py", line 713, in get_pythonized_sample_results
    sample_results = _random_sample(seq_groups,
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/vllm/model_executor/layers/sampler.py", line 512, in _random_sample
    random_samples = random_samples.cpu()
                     ^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


The above exception was the direct cause of the following exception:

ray::WorkerDict.actor_rollout_generate_sequences() (pid=102727, ip=10.95.236.82, actor_id=3d46c2bf190029087ce109a20a000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f79eb367850>)
  File "/root/workspace/env_run/verl/verl/workers/fsdp_workers.py", line 468, in generate_sequences
    output = self.rollout.generate_sequences(prompts=prompts)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/env_run/verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py", line 181, in generate_sequences
    output = self.inference_engine.generate(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/vllm/utils.py", line 1063, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/vllm/entrypoints/llm.py", line 353, in generate
    outputs = self._run_engine(use_tqdm=use_tqdm)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/env_run/verl/verl/third_party/vllm/vllm_v_0_6_3/llm.py", line 161, in _run_engine
    outputs = super()._run_engine(use_tqdm=use_tqdm)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/vllm/entrypoints/llm.py", line 879, in _run_engine
    step_outputs = self.llm_engine.step()
                   ^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/vllm/engine/llm_engine.py", line 1386, in step
    outputs = self.model_executor.execute_model(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/env_run/verl/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py", line 163, in execute_model
    all_outputs = self.worker.execute_model(execute_model_req=execute_model_req)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/env_run/verl/verl/third_party/vllm/vllm_v_0_6_3/worker.py", line 267, in execute_model1
    return self.model_runner.execute_model(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/vllm/worker/model_runner_base.py", line 146, in _wrapper
    raise type(err)(f"Error in model execution: "
RuntimeError: Error in model execution: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


During handling of the above exception, another exception occurred:

ray::WorkerDict.actor_rollout_generate_sequences() (pid=102727, ip=10.95.236.82, actor_id=3d46c2bf190029087ce109a20a000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f79eb367850>)
  File "/root/workspace/env_run/verl/verl/single_controller/ray/base.py", line 421, in func
    """
        
  File "/root/workspace/env_run/verl/verl/single_controller/base/decorator.py", line 404, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/env_run/verl/verl/workers/fsdp_workers.py", line 464, in generate_sequences
    with self.rollout_sharding_manager:
  File "/root/workspace/env_run/verl/verl/workers/sharding_manager/fsdp_vllm.py", line 105, in __exit__
    torch.cuda.empty_cache()
  File "/root/venv/lib/python3.11/site-packages/torch/cuda/memory.py", line 170, in empty_cache
    torch._C._cuda_emptyCache()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
(WorkerDict pid=102724, ip=10.95.236.82) INFO 02-11 11:40:52 metrics.py:345] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 3451.9 tokens/s, Running: 8 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.6%, CPU KV cache usage: 0.0%. [repeated 6x across cluster]
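The traceback suggests CUDA_LAUNCH_BLOCKING=1. For it to help, it has to be in each Ray worker's environment before CUDA is initialized there, not just in the driver's shell: either exported before ray start on every node (as was done for VLLM_ATTENTION_BACKEND) or passed through Ray's runtime_env "env_vars" field. The sketch below builds such a runtime_env dict, assuming you control the ray.init call. WORKER_DEBUG_ENV and build_runtime_env are hypothetical names for illustration, not part of verl.

```python
from typing import Dict, Mapping, Optional

# Variables to propagate to every Ray worker. The variable names are real;
# the helper around them is a hypothetical illustration, not a verl API.
WORKER_DEBUG_ENV: Dict[str, str] = {
    "VLLM_ATTENTION_BACKEND": "XFORMERS",  # what the reporter already exports
    "CUDA_LAUNCH_BLOCKING": "1",  # synchronous launches: slower, but errors surface at the faulting call
}


def build_runtime_env(extra: Optional[Mapping[str, str]] = None) -> Dict[str, Dict[str, str]]:
    """Return a dict in the shape Ray's runtime_env uses for environment variables."""
    env_vars = dict(WORKER_DEBUG_ENV)  # copy so the module-level defaults stay untouched
    if extra:
        env_vars.update(extra)  # caller-supplied values win
    return {"env_vars": env_vars}
```

Usage would look like ray.init(runtime_env=build_runtime_env({"NCCL_DEBUG": "WARN"})). HYDRA_FULL_ERROR=1, also suggested in the log, is read by Hydra in the driver process, so it belongs in the shell that runs verl.trainer.main_ppo rather than in the worker environment. Expect rollout to be noticeably slower while CUDA_LAUNCH_BLOCKING is set.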

AIBionics, Feb 11 '25 03:02

I am also getting a similar issue, but mine didn't even run a few steps. Do you mind sharing your installation steps and your CUDA version?

Thanks

3rdAT, Feb 11 '25 04:02

> I am also getting a similar issue, but mine didn't even run a few steps. Do you mind sharing your installation steps and your CUDA version?
>
> Thanks

CUDA version: cuda_12.4.0_550.54.14

Installation steps:

cd verl
pip install -r requirements.txt
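When comparing environments, as requested above, it helps to report the few package versions that matter here in a uniform way. The stdlib sketch below does that with importlib.metadata, and also splits a CUDA runfile tag such as cuda_12.4.0_550.54.14 into the toolkit version (12.4.0) and the bundled driver version (550.54.14). The package list and helper names are assumptions chosen for illustration.

```python
import re
from importlib import metadata
from typing import Dict, Iterable, Optional, Tuple

# Distributions whose versions come up in this thread (names as in the pip lists).
PACKAGES = ("torch", "vllm", "xformers", "flash_attn", "verl", "ray", "transformers")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def parse_cuda_runfile_tag(tag: str) -> Tuple[str, str]:
    """Split a tag like 'cuda_12.4.0_550.54.14' into (toolkit version, bundled driver version)."""
    match = re.fullmatch(r"cuda_(\d+\.\d+\.\d+)_(\d+\.\d+(?:\.\d+)?)", tag)
    if match is None:
        raise ValueError(f"unrecognised CUDA runfile tag: {tag!r}")
    return match.group(1), match.group(2)


if __name__ == "__main__":
    print(parse_cuda_runfile_tag("cuda_12.4.0_550.54.14"))  # ('12.4.0', '550.54.14')
    for name, version in installed_versions().items():
        print(f"{name:<14} {version}")
```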

AIBionics, Feb 11 '25 04:02

I also encountered the same issue. Setting VLLM_ATTENTION_BACKEND to XFORMERS doesn't help. Moreover, the bug seems to occur randomly: running the same script a second time may not trigger it.

My env (Python 3.9.21):

accelerate                        1.4.0
aiohappyeyeballs                  2.4.6
aiohttp                           3.11.13
aiosignal                         1.3.2
annotated-types                   0.7.0
antlr4-python3-runtime            4.9.3
anyio                             4.8.0
asttokens                         3.0.0
async-timeout                     4.0.3
attrs                             25.1.0
beautifulsoup4                    4.13.3
cbor                              1.0.0
certifi                           2025.1.31
charset-normalizer                3.4.1
click                             8.1.8
cloudpickle                       3.1.1
codetiming                        1.4.0
comm                              0.2.2
datasets                          3.3.2
debugpy                           1.8.13
decorator                         5.2.1
dill                              0.3.8
diskcache                         5.6.3
distro                            1.9.0
docker-pycreds                    0.4.0
einops                            0.8.1
eval_type_backport                0.2.2
exceptiongroup                    1.2.2
executing                         2.1.0
faiss-cpu                         1.10.0
fastapi                           0.115.8
filelock                          3.13.1
FlagEmbedding                     1.3.4
flash_attn                        2.7.4.post1
frozenlist                        1.5.0
fsspec                            2024.6.1
gguf                              0.10.0
gitdb                             4.0.12
GitPython                         3.1.44
greenlet                          3.1.1
h11                               0.14.0
httpcore                          1.0.7
httptools                         0.6.4
httpx                             0.28.1
huggingface-hub                   0.29.1
hydra-core                        1.3.2
idna                              3.10
ijson                             3.3.0
importlib_metadata                8.6.1
inscriptis                        2.5.3
interegular                       0.3.3
ipykernel                         6.29.5
ipython                           8.18.1
ir_datasets                       0.5.9
jedi                              0.19.2
Jinja2                            3.1.4
jiter                             0.8.2
joblib                            1.4.2
jsonpatch                         1.33
jsonpointer                       3.0.0
jsonschema                        4.23.0
jsonschema-specifications         2024.10.1
jupyter_client                    8.6.3
jupyter_core                      5.7.2
langchain                         0.3.19
langchain-core                    0.3.40
langchain-text-splitters          0.3.6
langsmith                         0.3.11
lark                              1.2.2
llvmlite                          0.43.0
lm-format-enforcer                0.10.6
lxml                              5.3.1
lz4                               4.4.3
MarkupSafe                        2.1.5
matplotlib-inline                 0.1.7
mistral_common                    1.5.3
mpmath                            1.3.0
msgpack                           1.1.0
msgspec                           0.19.0
multidict                         6.1.0
multiprocess                      0.70.16
mwparserfromhell                  0.6.6
nest_asyncio                      1.6.0
networkx                          3.2.1
nltk                              3.9.1
numba                             0.60.0
numpy                             1.26.4
nvidia-cublas-cu12                12.4.2.65
nvidia-cuda-cupti-cu12            12.4.99
nvidia-cuda-nvrtc-cu12            12.4.99
nvidia-cuda-runtime-cu12          12.4.99
nvidia-cudnn-cu12                 9.1.0.70
nvidia-cufft-cu12                 11.2.0.44
nvidia-curand-cu12                10.3.5.119
nvidia-cusolver-cu12              11.6.0.99
nvidia-cusparse-cu12              12.3.0.142
nvidia-ml-py                      12.570.86
nvidia-nccl-cu12                  2.20.5
nvidia-nvjitlink-cu12             12.4.99
nvidia-nvtx-cu12                  12.4.99
omegaconf                         2.3.0
openai                            1.64.0
opencv-python-headless            4.11.0.86
orjson                            3.10.15
outlines                          0.0.46
packaging                         24.2
pandas                            2.2.3
parso                             0.8.4
partial-json-parser               0.2.1.1.post5
peft                              0.14.0
pexpect                           4.9.0
pickleshare                       0.7.5
pillow                            11.1.0
pip                               25.0
platformdirs                      4.3.6
prometheus_client                 0.21.1
prometheus-fastapi-instrumentator 7.0.2
prompt_toolkit                    3.0.50
propcache                         0.3.0
protobuf                          5.29.3
psutil                            7.0.0
ptyprocess                        0.7.0
pure_eval                         0.2.3
py-cpuinfo                        9.0.0
pyairports                        2.1.1
pyarrow                           19.0.1
pybind11                          2.13.6
pycountry                         24.6.1
pydantic                          2.10.6
pydantic_core                     2.27.2
Pygments                          2.19.1
pylatexenc                        2.10
python-dateutil                   2.9.0.post0
python-dotenv                     1.0.1
pytz                              2025.1
PyYAML                            6.0.2
pyzmq                             26.2.1
ragdata                           0.2.0
ray                               2.42.1
referencing                       0.36.2
regex                             2024.11.6
requests                          2.32.3
requests-toolbelt                 1.0.0
rpds-py                           0.23.1
safetensors                       0.5.2
scikit-learn                      1.6.1
scipy                             1.13.1
sentence-transformers             3.4.1
sentencepiece                     0.2.0
sentry-sdk                        2.22.0
setproctitle                      1.3.5
setuptools                        75.8.0
six                               1.17.0
smmap                             5.0.2
sniffio                           1.3.1
soupsieve                         2.6
SQLAlchemy                        2.0.38
stack_data                        0.6.3
starlette                         0.45.3
sympy                             1.13.1
tenacity                          9.0.0
tensordict                        0.5.0
threadpoolctl                     3.5.0
tiktoken                          0.9.0
tokenizers                        0.21.0
torch                             2.4.0+cu124
torchdata                         0.11.0
torchvision                       0.19.0
tornado                           6.4.2
tqdm                              4.67.1
traitlets                         5.14.3
transformers                      4.49.0
trec-car-tools                    2.6
triton                            3.0.0
txtai                             8.3.1
typing_extensions                 4.12.2
tzdata                            2025.1
unlzw3                            0.2.3
urllib3                           2.3.0
uvicorn                           0.34.0
uvloop                            0.21.0
verl                              0.2.0.dev0
vllm                              0.6.3
wandb                             0.19.7
warc3-wet                         0.2.5
warc3-wet-clueweb09               0.2.5
watchfiles                        1.0.4
wcwidth                           0.2.13
websockets                        15.0
wheel                             0.45.1
xformers                          0.0.27.post2
xxhash                            3.5.0
yarl                              1.18.3
zipp                              3.21.0
zlib-state                        0.1.9
zstandard                         0.23.0

0russwest0, Mar 07 '25 01:03

I encountered exactly the same error when using verl 0.2.0.dev0 with vllm 0.6.3. Upgrading to verl 0.3.0 and vllm 0.8.2 solved it for me.
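The comment above reports that moving from verl 0.2.0.dev0 with vllm 0.6.3 to verl 0.3.0 with vllm 0.8.2 made the error go away. A small sketch that flags installs below those versions follows. The thresholds come only from that single report, not from an official compatibility matrix, and the comparison deliberately ignores dev/post/local suffixes (packaging.version.Version implements full PEP 440 ordering if you need it). All names here are hypothetical.

```python
import re
from typing import Dict, Tuple

# Versions one commenter reported as working; not an official compatibility matrix.
REPORTED_WORKING: Dict[str, str] = {"verl": "0.3.0", "vllm": "0.8.2"}


def release_tuple(version: str) -> Tuple[int, ...]:
    """Leading numeric release of a version string, e.g. '0.2.0.dev0' -> (0, 2, 0).

    Pre/dev/post/local suffixes are ignored, which is coarser than PEP 440.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"no numeric release in {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def at_least(installed: str, minimum: str) -> bool:
    """Compare release tuples after zero-padding them to the same length."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def below_reported_working(installed: Dict[str, str]) -> Dict[str, str]:
    """Return {package: reported-working version} for each package installed below it."""
    return {
        name: wanted
        for name, wanted in REPORTED_WORKING.items()
        if name in installed and not at_least(installed[name], wanted)
    }


if __name__ == "__main__":
    print(below_reported_working({"verl": "0.2.0.dev0", "vllm": "0.6.3"}))
```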

OndineMrCai, Apr 13 '25 04:04

I have also encountered this problem. Have you resolved it? @0russwest0 @AIBionics, I would appreciate it if you could share any insights.

SimonHeye, Sep 24 '25 17:09