
[BREAKING][Async SGLang Rollout] Efficient and model-agnostic multi-turn messages tokenization and masking

Checklist Before Starting

  • [x] Search for similar PR(s).

What does this PR do?

Implement efficient, model-agnostic multi-turn message tokenization and masking based solely on the chat template

Specific Changes

Challenges

  1. Template-specific configs
    Current rollout requires hand-crafting tokenization and loss-masking rules for every chat template, leading to verbose, error-prone if/else logic whenever a new template is added.
  2. Redundant tokenization
    On each generation turn, we re-tokenize the entire conversation history, wasting time on duplicated work.

Solution

  1. Delta-based masking
    • Rather than custom code per template, derive the assistant’s loss-mask span by comparing two tokenizations of the same history:
      prev = tokenizer.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=True)
      curr = tokenizer.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=True)
      # Mask only the new assistant tokens (tokenize=True returns lists of token ids)
      loss_mask[len(prev) : len(curr)] = 1
      
    • This works for any chat template—no more per-template branches.
  2. Zero-redundancy tokenization
    • Benchmarking shows apply_chat_template(tokenize=False) is far faster than with tokenize=True.
    • We render the chat template twice without tokenizing (past messages vs. past messages + current message) to extract just the new message text, then tokenize only that delta (see the sketch below this list).
  3. Sanity-check mode
    • To guard against edge-case template mismatches, we compare fast delta tokenization against a full re-tokenization of the history and assert they match in sanity-check mode.
    • This makes it easy to validate new templates for compatibility with the fast tokenization mode.
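
Below is a minimal, illustrative sketch of the fast (delta) path and of what sanity-check mode compares, assuming a Hugging Face tokenizer with a chat template; the model name and variable names are examples only, not verl's actual API:

from transformers import AutoTokenizer

# Example model; any tokenizer with a chat template works the same way.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "2 + 2 = 4."},
]

# Render the template twice WITHOUT tokenizing (the cheap operation): once up to
# the generation prompt, and once including the new assistant turn.
prev_text = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True, tokenize=False)
curr_text = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)

# Tokenize only the newly rendered text instead of the whole history.
prev_ids = tokenizer.encode(prev_text, add_special_tokens=False)
delta_ids = tokenizer.encode(curr_text[len(prev_text):], add_special_tokens=False)

# Only the new assistant tokens contribute to the loss.
loss_mask = [0] * len(prev_ids) + [1] * len(delta_ids)

# sanity_check mode: compare against a full re-tokenization of the whole history.
# A mismatch here is exactly the template edge case the mode is meant to surface.
full_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=True)
print("fast delta matches full re-tokenization:", prev_ids + delta_ids == full_ids)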

Usage Example

multi_turn:
  enable: True
  tokenization_mode: fast  # one of: fast, full, sanity_check
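
In fast mode only the newly added messages are tokenized; full re-tokenizes the entire conversation history every turn (the original behavior); sanity_check runs both and asserts that they produce identical token sequences.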

Test

Correctness validation

  1. Validated with sanity_check mode on my own multi-turn RL project (QwQ-32B + GRPO)
  2. [WIP] Validating with verl's multi-turn GSM8K example

Speed Benchmark

Simulated multi-turn rollout using snapshot data from a prior RL experiment:

Sample count: 757
Avg messages per sample: 54.28
Avg tokens per sample: 18942.74

Testing code (data, tokenizer, and tool_schema are loaded from the experiment snapshot):

import time
from uuid import uuid4

from tqdm import tqdm

# Import path may differ slightly depending on the verl version.
from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum

n = 2
repeat = 5
elapsed = 0
for round in range(repeat):
    print(f"Start round {round}")
    for data_idx, row in tqdm(enumerate(data), total=len(data)):
        for rollout_offset in range(n):
            messages = row["messages"]
            tool_resps = []

            start = time.time()
            req = AsyncRolloutRequest(
                batch_data_id=data_idx, rollout_offset=rollout_offset,
                request_id=str(uuid4()), state=AsyncRolloutRequestStateEnum.PENDING,
                messages=messages[:2], tool_schemas=tool_schema, tools_kwargs={},
                input_ids=None, response_ids=[], attention_mask=None,
                response_attention_mask=[], response_position_ids=[], response_loss_mask=[],
                reward_scores={}, max_prompt_len=30720, max_response_len=10240,
                max_model_len=40960, tokenization_mode="fast", tokenizer=tokenizer,
            )

            for message in messages[2:]:
                # Flush accumulated tool responses before any non-tool message.
                if message["role"] != "tool":
                    req.add_tool_response_messages(tokenizer, [resp["content"] for resp in tool_resps])
                    tool_resps = []

                if message["role"] == "assistant":
                    req.get_prompt_ids(tokenizer)
                    req.add_assistant_message(tokenizer, message["content"], message["content_ids"], message.get("tool_calls"))
                elif message["role"] == "tool":
                    tool_resps.append(message)
            elapsed += time.time() - start
elapsed /= len(data) * repeat * n
print(f"Fast Tokenization avg time: {elapsed:.4f}s")

Result

Start round 0
100%|█████████████████████| 757/757 [03:34<00:00,  3.54it/s]
Start round 1
100%|█████████████████████| 757/757 [03:33<00:00,  3.54it/s]
Start round 2
100%|█████████████████████| 757/757 [03:34<00:00,  3.54it/s]
Start round 3
100%|█████████████████████| 757/757 [03:33<00:00,  3.54it/s]
Start round 4
100%|█████████████████████| 757/757 [03:33<00:00,  3.54it/s]
Fast Tokenization avg time: 0.1411s

Start round 0
100%|█████████████████████| 757/757 [11:37<00:00,  1.08it/s]
Start round 1
100%|█████████████████████| 757/757 [11:37<00:00,  1.09it/s]
Start round 2
100%|█████████████████████| 757/757 [11:38<00:00,  1.08it/s]
Start round 3
100%|█████████████████████| 757/757 [11:38<00:00,  1.08it/s]
Start round 4
100%|█████████████████████| 757/757 [11:38<00:00,  1.08it/s]
Current Tokenization avg time: 0.4610s

Conversations with more turns or more tokens see even larger speedups.

Additional Info.

  • Inference: SGLang

Checklist Before Submitting

  • [x] Read the Contribute Guide.
  • [x] Apply pre-commit checks.
  • [x] Add [BREAKING] to the PR title if it breaks any API.
  • [x] Update the documentation about your changes in the docs.
  • [x] Add CI test(s) if necessary.

jybsuper avatar May 24 '25 08:05 jybsuper

I asked several friends to validate this, stay tuned @jybsuper. After validation, we should add documentation in verl @SwordFaith, then we can merge.

zhaochenyang20 avatar May 26 '25 01:05 zhaochenyang20

Great job, big thanks to Yanbin! Excited to hear back from you about the review comments!

SwordFaith avatar May 26 '25 08:05 SwordFaith

  1. Sanity check on by default, with skip_sanity_check added as an option.
  2. The data & rollout configs get a chat_template option: the default null means use the original chat template, otherwise it should be a Jinja file, with pre-placed Jinja templates (QwQ & Qwen3).
  3. User-specific chat templates (e.g. the original Search-R1 / ReTool versions): support the fast concat logic and our special finetune tag.
  4. Whitelist for known models: store the chat template hash for templates that mismatch between train/inference or can't pass the sanity check; replace the chat template by default.

SwordFaith avatar May 28 '25 17:05 SwordFaith

LGTM

SwordFaith avatar Jun 06 '25 06:06 SwordFaith

Hi @vermouth1992, this PR broke the async vLLM server. Considering this was an update for SGLang, I am quite surprised.

CC @wuxibin89

Here is the problem:

  • omegaconf.errors.ConfigAttributeError: Key 'format' is not in struct, caused by the removal of that key in this PR.
  • But the vLLM async server depends on this parameter, and trying to add actor_rollout_ref.rollout.multi_turn.format=hermes back doesn't work.

https://github.com/volcengine/verl/blob/49b08e95090fc1924161c640d190637337a747b9/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L206

casper-hansen avatar Jun 12 '25 13:06 casper-hansen

Hi @casper-hansen, thanks for reporting this issue!

The format field was removed 2 weeks ago when this PR was created, at which point the vLLM async server wasn't utilizing this field. All vLLM tests passed during the PR review process, which is why the removal proceeded.

That said, I'm quite surprised that vLLM has been repurposing the chat template format field for tool parser functionality. As you may have noticed, the default value chatml is actually a chat template format name from SGLang's multi-turn rollout; it's not even a valid tool parser identifier in vLLM's server. This reuse of a field intended for a different purpose creates unnecessary confusion. Moving forward, it would be beneficial if vLLM could add a dedicated field for the tool parser rather than overloading existing fields with unrelated functionality. This would help prevent such integration issues.

Given the collaborative nature of VeRL and the number of moving parts, tracking cross-PR dependencies is indeed challenging, especially when PRs have long review cycles. This is exactly where proper tests become critical. I'd suggest updating the test here to include OmegaConf.set_struct(config, True) to catch any unexpected removal of config fields.
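
A minimal sketch of that suggestion, assuming the test loads the trainer config with OmegaConf (the file path and keys below are illustrative, not the actual test):

from omegaconf import OmegaConf

# With struct mode enabled, accessing a key that has been removed from the YAML
# raises ConfigAttributeError, so field removals are caught by the test rather
# than at runtime.
config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
OmegaConf.set_struct(config, True)

# Fails loudly if, e.g., rollout.multi_turn.format has been removed upstream.
_ = config.actor_rollout_ref.rollout.multi_turn.format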

Let me know if you'd like help fixing this—happy to contribute a patch to resolve the issue properly.

jybsuper avatar Jun 12 '25 18:06 jybsuper

@jybsuper I would appreciate a patch for this so that the ChatCompletionScheduler can use similar config arguments. I do have a preference for the scheduler because of how easy it is for me to implement custom multi-turn training workflows.

casper-hansen avatar Jun 13 '25 12:06 casper-hansen

Sounds good. I will create a PR for the fix soon.

jybsuper avatar Jun 13 '25 22:06 jybsuper