
Out of Memory running Qwen3-30B-A3B with 32k sequence length on 4 nodes (32 GPUs)

Open jiayi37u opened this issue 5 months ago • 6 comments

Description

I’m trying to train Qwen3-30B-A3B with GRPO at a sequence length of 32,000 tokens on 4 nodes (32 GPUs in total, each with 80 GB of VRAM), but I keep hitting CUDA out-of-memory errors. I’d like advice on how best to configure distributed training to avoid OOM.

Environment

Model: Qwen3-30B-A3B

Sequence length: 32,000

Nodes: 4 × 8 NVIDIA A100 GPUs (80 GB each)

Current config:

actor_rollout_ref.actor.megatron.grad_offload=True \
actor_rollout_ref.actor.megatron.optimizer_offload=True \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \
actor_rollout_ref.actor.megatron.context_parallel_size=1 \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=32 \
actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=1 \
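For reference, the parallel sizes above have to tile the 32 GPUs twice: once for the dense (attention) layers and once for the MoE expert layers. The sketch below does that arithmetic, assuming Megatron-Core's convention that the dense data-parallel size is world / (TP × PP × CP) and the expert data-parallel size is world / (ETP × EP × PP), and assuming the 128 routed experts per MoE layer listed in Qwen3-30B-A3B's published config. `ParallelLayout` is a hypothetical helper, not a verl or Megatron API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParallelLayout:
    """Hypothetical helper: checks how Megatron-style parallel sizes tile the GPUs."""

    world_size: int
    tp: int  # tensor_model_parallel_size
    pp: int  # pipeline_model_parallel_size
    cp: int  # context_parallel_size
    ep: int  # expert_model_parallel_size
    etp: int  # expert_tensor_parallel_size
    num_experts: int  # routed experts per MoE layer (assumed 128 for Qwen3-30B-A3B)

    @staticmethod
    def _divide(total: int, part: int, what: str) -> int:
        if total % part:
            raise ValueError(f"{what}: {total} is not divisible by {part}")
        return total // part

    def dense_dp(self) -> int:
        # Dense/attention layers (assumed convention): world = TP * PP * CP * DP
        return self._divide(self.world_size, self.tp * self.pp * self.cp, "dense DP")

    def expert_dp(self) -> int:
        # Expert layers (assumed convention): world = ETP * EP * PP * expert-DP
        return self._divide(self.world_size, self.etp * self.ep * self.pp, "expert DP")

    def experts_per_rank(self) -> int:
        return self._divide(self.num_experts, self.ep, "experts per EP rank")


layout = ParallelLayout(world_size=32, tp=4, pp=1, cp=1, ep=32, etp=1, num_experts=128)
print(layout.dense_dp(), layout.expert_dp(), layout.experts_per_rank())  # 8 1 4
```

Under these assumptions, TP=4 with EP=32 gives each rank 4 experts per MoE layer and 8 data-parallel replicas of the dense weights; with TP=1 the dense data-parallel size becomes 32, so every rank holds the full dense (non-expert) weights.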

jiayi37u, Aug 06 '25 03:08

It seems the params are not being offloaded correctly. By the way, what global batch size did you use?

A1waysBeenHere, Aug 06 '25 06:08

Could you show where the OOM happens?

vermouth1992, Aug 06 '25 10:08

> Could you show where the OOM happens?

I hit the same OOM, though the cause may differ from this issue. For me it occurs at step 2, likely because offloading the parameters and gradients in update_actor during step 1 does not take effect as expected. I suggest checking the following code location (https://github.com/volcengine/verl/blob/main/verl/workers/megatron_workers.py#L619):

        if self._is_offload_param:
            log_gpu_memory_usage("Before offload actor params and grad during update_actor", logger=logger)
            offload_megatron_model_to_cpu(self.actor_module)
            log_gpu_memory_usage("After offload actor params and grad during update_actor", logger=logger)

It appears that the offloading may not be properly releasing GPU memory, leading to excessive memory consumption during training.
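One way to tell whether offloading failed to move tensors off the GPU, or whether the memory is merely still held by PyTorch's caching allocator, is to compare `torch.cuda.memory_allocated()` (bytes in live tensors) with `torch.cuda.memory_reserved()` (bytes the allocator holds) around the offload call. The sketch below is hypothetical and not part of verl: `classify_offload`, `cuda_memory_snapshot`, and the 1 GiB threshold are illustrative assumptions.

```python
GIB = 1024**3


def cuda_memory_snapshot():
    """Return (allocated, reserved) bytes on the current CUDA device, or None without CUDA."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.memory_allocated(), torch.cuda.memory_reserved()


def classify_offload(allocated_before, allocated_after, reserved_after, threshold=GIB):
    """Hypothetical diagnosis from byte counts taken before/after an offload call."""
    if allocated_before - allocated_after < threshold:
        # Live tensors did not shrink: something still references the GPU copies.
        return "not offloaded"
    if reserved_after - allocated_after >= threshold:
        # Tensors are gone, but the caching allocator still holds their memory.
        return "offloaded, memory still reserved"
    return "offloaded and released"
```

In practice one would take a snapshot just before `offload_megatron_model_to_cpu(self.actor_module)` and again after it (and after `torch.cuda.empty_cache()`). The "memory still reserved" outcome points at allocator caching or fragmentation rather than a failed offload.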

Also note that the sequence length needs to be long enough (e.g., 32k) to reproduce this; the issue is much more likely to appear under that kind of memory pressure.

I inspected the memory fragmentation, and it appears that a single GPU memory segment is shared by allocations made in both load_megatron_model_to_gpu and load_megatron_optimizer, which prevents empty_cache from freeing that memory.
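The effect described above follows from how a caching allocator returns memory: a segment reserved from the driver can only be released once every block carved from it is free. The toy model below illustrates that rule. It is a deliberate simplification, not PyTorch's actual CUDA caching allocator, and the segment size and the "model"/"optimizer" owner labels are hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class Segment:
    size: int
    blocks: dict = field(default_factory=dict)  # owner label -> bytes in use

    def free_bytes(self) -> int:
        return self.size - sum(self.blocks.values())


class ToyCachingAllocator:
    """Toy model: memory is reserved in fixed-size segments, and a segment is
    returned only when no live block remains in it."""

    def __init__(self, segment_size: int):
        self.segment_size = segment_size
        self.segments = []  # list of Segment

    def malloc(self, owner: str, nbytes: int) -> None:
        # Reuse free space in an existing segment first; this is what lets
        # unrelated allocations end up sharing one segment.
        for seg in self.segments:
            if seg.free_bytes() >= nbytes:
                seg.blocks[owner] = seg.blocks.get(owner, 0) + nbytes
                return
        seg = Segment(size=max(self.segment_size, nbytes))
        seg.blocks[owner] = nbytes
        self.segments.append(seg)

    def free(self, owner: str) -> None:
        # Freed blocks stay cached inside their segment.
        for seg in self.segments:
            seg.blocks.pop(owner, None)

    def empty_cache(self) -> None:
        # Only segments with no live blocks can be released.
        self.segments = [seg for seg in self.segments if seg.blocks]

    def reserved(self) -> int:
        return sum(seg.size for seg in self.segments)


alloc = ToyCachingAllocator(segment_size=100)
alloc.malloc("model", 60)      # e.g. weights loaded back onto the GPU
alloc.malloc("optimizer", 30)  # e.g. optimizer state placed in the same segment
alloc.free("model")            # model offloaded back to CPU
alloc.empty_cache()
print(alloc.reserved())  # 100: the optimizer block pins the whole segment
```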

jiaqiw09, Aug 06 '25 12:08

Hi @jiaqiw09, I wonder if you ever fixed this issue?

MarkYangjiayi, Nov 19 '25 05:11

> Hi @jiaqiw09, I wonder if you ever fixed this issue?

You can try setting export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True, which helps reduce memory fragmentation.

However, I usually run my scripts on NPUs, and a similar environment variable conflicts with vLLM’s sleep mode. Still, it’s definitely useful during training. If there’s no conflict on CUDA, I think it could help.

Alternatively, you can check whether CUDA provides an API for this. If it does, enable it through the API rather than an environment variable.
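A small sketch of applying the suggestion from Python without clobbering other allocator options that may already be set (PYTORCH_CUDA_ALLOC_CONF takes comma-separated `key:value` pairs, e.g. `max_split_size_mb:128`). The variable has to be set before PyTorch initializes CUDA, so the safest place is before `import torch`. `with_allocator_option` is a hypothetical helper. Recent PyTorch versions also have a private `torch.cuda.memory._set_allocator_settings()` that accepts the same string at runtime, but being underscore-prefixed it is not a stable API.

```python
import os


def with_allocator_option(current, key, value):
    """Hypothetical helper: set key:value in a PYTORCH_CUDA_ALLOC_CONF-style string,
    keeping any other comma-separated options already present."""
    kept = []
    for item in (current or "").split(","):
        item = item.strip()
        if not item or item.partition(":")[0].strip() == key:
            continue  # drop empty entries and any previous value for this key
        kept.append(item)
    kept.append(f"{key}:{value}")
    return ",".join(kept)


# Must run before PyTorch initializes CUDA; the safest place is before `import torch`.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = with_allocator_option(
    os.environ.get("PYTORCH_CUDA_ALLOC_CONF"), "expandable_segments", "True"
)
print(os.environ["PYTORCH_CUDA_ALLOC_CONF"])
```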

jiaqiw09, Nov 19 '25 15:11

@jiaqiw09 Thanks, jiaqi. I later solved it by increasing TP from 1 to 4, but I'll also try your method.

MarkYangjiayi, Nov 20 '25 00:11