verifiers
Add design patterns / architecture overview to docs
This is a more architecturally focused addition to the docs. I did it mainly for my own understanding, but I think it might be useful to others!
You could add a section about the flow:
GRPO Repeated Sampling Flow Implementation
This document explains how this GRPO implementation generates multiple different completions for each prompt when num_generations > 1.
Overview
When num_generations > 1, the same prompt gets sent multiple times to the environment, and stochastic sampling creates different completions. The system uses a RepeatSampler to create repeated indices, then shuffles the results before training.
Complete Flow
1. RepeatSampler Creates the Sampling Pattern
The RepeatSampler class is configured in _get_train_sampler() with:
return RepeatSampler(
    data_source=self.train_dataset,
    mini_repeat_count=self.num_generations,  # each prompt index repeated num_generations times
    batch_size=self.generation_batch_size // self.num_generations,
    repeat_count=self.num_iterations * self.gradient_accumulation_steps,
    shuffle=self.shuffle_dataset,
    seed=self.args.seed,
)
Key insight: mini_repeat_count=self.num_generations means each prompt gets repeated num_generations times consecutively. The shuffle=self.shuffle_dataset parameter shuffles the dataset indices before repeating them.
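To make the repeat order concrete, here is a minimal, self-contained sketch of that shuffle-then-repeat logic (an illustration under simplified assumptions; the function name and signature are made up, and this is not the actual RepeatSampler code):

import random

def repeat_sample(num_prompts, mini_repeat_count, batch_size, repeat_count, shuffle=True, seed=0):
    indexes = list(range(num_prompts))
    if shuffle:
        random.Random(seed).shuffle(indexes)  # shuffling happens BEFORE repeating
    # Walk the (possibly shuffled) indices in chunks of unique prompts, then emit
    # each chunk repeat_count times, with every index repeated mini_repeat_count
    # times in a row
    for start in range(0, num_prompts, batch_size):
        chunk = indexes[start:start + batch_size]
        for _ in range(repeat_count):
            for idx in chunk:
                for _ in range(mini_repeat_count):
                    yield idx

print(list(repeat_sample(4, mini_repeat_count=2, batch_size=2, repeat_count=1, shuffle=False)))
# -> [0, 0, 1, 1, 2, 2, 3, 3]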
2. Sampling Pattern Example
From the helpful comment in the code, here's what happens with num_generations=2:
                                     |          Accum step 0         |
                                     |     GPU 0     |     GPU 1     |
                global_step   step    <-──>  num_generations=2
                                      <-──────>  per_device_train_batch_size=3
 grad_accum  ▲  ▲     0        0       0   0   1   1   2   2   <- Generate for prompts 0,1,2 (each repeated twice)
     =2      ▼  |     0        1       3   3   4   4   5   5   <- Generate for prompts 3,4,5 (each repeated twice)
                |
                |     1        2       6   6   7   7   8   8   <- Generate for prompts 6,7,8 (each repeated twice)
 grad_accum=4   ▼     1        3       9   9  10  10  11  11   <- Generate for prompts 9,10,11 (each repeated twice)
3. DataLoader Returns Same Prompts Multiple Times
When the dataloader uses the repeated indices, it fetches the same prompt multiple times:
# For num_generations=3, the batch might look like:
[
    {'prompt': "What is 2+2?", 'answer': "4"},
    {'prompt': "What is 2+2?", 'answer': "4"},   # Same prompt again
    {'prompt': "What is 2+2?", 'answer': "4"},   # Same prompt again
    {'prompt': "Solve x+1=5", 'answer': "x=4"},
    {'prompt': "Solve x+1=5", 'answer': "x=4"},  # Same prompt again
    {'prompt': "Solve x+1=5", 'answer': "x=4"}   # Same prompt again
]
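To see this end to end, a plain Python list can stand in for the sampler, since a PyTorch DataLoader accepts any iterable of indices as sampler (a hedged demo; the dataset and collate_fn here are placeholders, not verifiers code):

from torch.utils.data import DataLoader

dataset = [
    {'prompt': "What is 2+2?", 'answer': "4"},
    {'prompt': "Solve x+1=5", 'answer': "x=4"},
]
# num_generations=3 -> repeated indices [0, 0, 0, 1, 1, 1]
loader = DataLoader(dataset, batch_size=6, sampler=[0, 0, 0, 1, 1, 1], collate_fn=lambda batch: batch)
print(next(iter(loader)))  # each prompt appears 3 times, back to back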
4. Submission to AsyncBatchGenerator
The all_prompts list (which contains repeated prompts) gets submitted to the AsyncBatchGenerator:
request = BatchRequest(
    batch_id=batch_id,
    env_inputs={'prompt': all_prompts, 'answer': all_answers, 'task': all_tasks, 'info': all_infos},
    # ... other parameters
)
self.async_generator.submit_batch(request)
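The submit/retrieve split is what lets generation overlap with training. Here is a minimal sketch of that pattern, assuming a background event loop on a worker thread (illustrative only; the real AsyncBatchGenerator has queues, timeouts, and more machinery):

import asyncio, threading

class TinyAsyncGenerator:
    def __init__(self):
        self._loop = asyncio.new_event_loop()
        threading.Thread(target=self._loop.run_forever, daemon=True).start()
        self._futures = {}

    def submit_batch(self, batch_id, prompts):
        # Schedule rollouts on the background loop; the caller returns immediately
        coro = self._rollout_all(prompts)
        self._futures[batch_id] = asyncio.run_coroutine_threadsafe(coro, self._loop)

    async def _rollout_all(self, prompts):
        await asyncio.sleep(0.01)  # stands in for the actual API calls
        return [f"completion for: {p}" for p in prompts]

    def get_batch(self, batch_id):
        return self._futures.pop(batch_id).result()  # blocks until generation is done

gen = TinyAsyncGenerator()
gen.submit_batch(0, ["What is 2+2?"] * 3)
print(gen.get_batch(0))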
5. Environment Receives Multiple Identical Prompts
The environment receives this list of prompts where the same prompt appears multiple times:
# env_inputs passed to AsyncBatchGenerator:
{
    'prompt': ["What is 2+2?", "What is 2+2?", "What is 2+2?", "Solve x+1=5", "Solve x+1=5", "Solve x+1=5"],
    'answer': ["4", "4", "4", "x=4", "x=4", "x=4"]
}
6. Environment Processes Each Prompt Individually
The key insight is in the run_rollouts() method. It receives the list of prompts (including duplicates) and processes each one individually:
rollout_tasks = [
    self._run_single(semaphore, client, model, prompt, answer, task, info, sampling_args, **kwargs)
    for prompt, answer, task, info in zip(prompts, answers, tasks, infos)  # each prompt processed separately
]
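The fan-out itself is standard asyncio. A minimal sketch, with _run_single collapsed into a stub so the example is runnable (names simplified; not the actual verifiers code):

import asyncio

async def run_rollouts(prompts, max_concurrent=8):
    semaphore = asyncio.Semaphore(max_concurrent)

    async def run_single(prompt):
        async with semaphore:          # cap the number of in-flight requests
            await asyncio.sleep(0.01)  # stands in for the API call
            return f"completion for: {prompt}"

    # Duplicates in `prompts` each become their own independent task
    return await asyncio.gather(*(run_single(p) for p in prompts))

print(asyncio.run(run_rollouts(["What is 2+2?"] * 3)))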
7. Independent API Calls with Stochastic Sampling
For each prompt in the list (including duplicates), the environment calls its rollout() method, which eventually calls get_model_response():
response = client.chat.completions.create(
    model=model,
    messages=prompt,
    **sanitized_args  # contains temperature, top_p, etc.
)
The sampling_args from the GRPO trainer include parameters like:
- temperature > 0 (enables randomness)
- top_p (nucleus sampling)
- top_k (top-k sampling)
- etc.
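For illustration, the dict that ends up passed through to the API call might look like this (the values are assumptions for the example, not verifiers defaults):

sampling_args = {
    "temperature": 1.0,  # > 0 so repeated prompts can diverge
    "top_p": 0.95,       # nucleus sampling
    "max_tokens": 512,
}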
8. Reward Computation and Advantage Calculation
In _compute_advantages(), the rewards are processed in groups:
def _compute_advantages(self, rewards: torch.Tensor) -> torch.Tensor:
    # Reshape rewards to group by prompt: (num_prompts, num_generations)
    mean_grouped = rewards.view(-1, self.num_generations).mean(dim=1)
    std_grouped = rewards.view(-1, self.num_generations).std(dim=1)

    # Expand back to original shape for normalization
    mean_grouped = mean_grouped.repeat_interleave(self.num_generations, dim=0)
    std_grouped = std_grouped.repeat_interleave(self.num_generations, dim=0)

    # Compute advantages (rewards - baseline)
    advantages = rewards - mean_grouped
    if self.scale_rewards:
        advantages = advantages / (std_grouped + 1e-4)
    return advantages
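You can check the grouping math standalone with the example rewards used later in this doc (num_generations=3, two prompt groups):

import torch

rewards = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.9, 0.5])
grouped = rewards.view(-1, 3)                        # shape (2, 3): one row per prompt
baseline = grouped.mean(dim=1).repeat_interleave(3)  # [0.8, 0.8, 0.8, 0.6667, 0.6667, 0.6667]
print(rewards - baseline)  # tensor([ 0.1000,  0.0000, -0.1000, -0.0667,  0.2333, -0.1667])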
9. Shuffling Before Training
After collecting all completions and computing advantages, the data is shuffled before being split for gradient accumulation:
# Concatenate all data for shuffling
full_batch = {
    "prompt_ids": prompt_ids,
    "prompt_mask": prompt_mask,
    "completion_ids": completion_ids,
    "completion_mask": completion_mask,
    "old_per_token_logps": None,
    "advantages": advantages,
}

# Shuffle and split for gradient accumulation
full_batch = shuffle_tensor_dict(full_batch)
self._buffered_inputs = split_tensor_dict(full_batch, self.gradient_accumulation_steps)
This shuffling ensures that completions from the same prompt are mixed across different gradient accumulation steps, improving training stability.
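For intuition, plausible minimal versions of the two helpers look like this (a sketch assuming every non-None value shares the same batch dimension; not the actual verifiers implementations):

import torch

def shuffle_tensor_dict(tensor_dict):
    batch_size = next(v.size(0) for v in tensor_dict.values() if v is not None)
    perm = torch.randperm(batch_size)  # ONE permutation for all keys, so rows stay aligned
    return {k: (v[perm] if v is not None else None) for k, v in tensor_dict.items()}

def split_tensor_dict(tensor_dict, num_chunks):
    return [
        {k: (v.chunk(num_chunks)[i] if v is not None else None) for k, v in tensor_dict.items()}
        for i in range(num_chunks)
    ]

The single shared permutation is the important detail: completion_ids, masks, and advantages must be shuffled identically, or advantages would be attributed to the wrong completions.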
Complete Example Flow
Let's trace a concrete example with num_generations=3:
Step 1: RepeatSampler Creates Repeated Indices
# Original dataset: ["What is 2+2?", "Solve x+1=5"]
# RepeatSampler yields: [0, 0, 0, 1, 1, 1]
Step 2: DataLoader Returns Repeated Prompts
# Batch from dataloader:
[
    {'prompt': "What is 2+2?", 'answer': "4"},
    {'prompt': "What is 2+2?", 'answer': "4"},   # Same prompt
    {'prompt': "What is 2+2?", 'answer': "4"},   # Same prompt
    {'prompt': "Solve x+1=5", 'answer': "x=4"},
    {'prompt': "Solve x+1=5", 'answer': "x=4"},  # Same prompt
    {'prompt': "Solve x+1=5", 'answer': "x=4"}   # Same prompt
]
Step 3: Environment Receives Repeated Prompts
# env_inputs passed to AsyncBatchGenerator:
{
    'prompt': ["What is 2+2?", "What is 2+2?", "What is 2+2?", "Solve x+1=5", "Solve x+1=5", "Solve x+1=5"],
    'answer': ["4", "4", "4", "x=4", "x=4", "x=4"]
}
Step 4: Independent API Calls with Stochastic Sampling
The environment calls rollout() → get_model_response() for each prompt separately:
# Call 1: "What is 2+2?" → client.chat.completions.create(...) → "2+2=4"
# Call 2: "What is 2+2?" → client.chat.completions.create(...) → "Let me calculate: 2+2 equals 4"
# Call 3: "What is 2+2?" → client.chat.completions.create(...) → "The answer is 4"
# Call 4: "Solve x+1=5" → client.chat.completions.create(...) → "x+1=5, so x=4"
# Call 5: "Solve x+1=5" → client.chat.completions.create(...) → "Subtract 1: x=5-1=4"
# Call 6: "Solve x+1=5" → client.chat.completions.create(...) → "x=4"
Step 5: Advantages Computed Across Groups
# Rewards: [0.9, 0.8, 0.7, 0.6, 0.9, 0.5]
# Grouped by prompt: [[0.9, 0.8, 0.7], [0.6, 0.9, 0.5]]
# Mean per group: [0.8, 0.67]
# Advantages: [0.1, 0.0, -0.1, -0.07, 0.23, -0.17]
Step 6: Shuffling Before Training
# Before shuffling (grouped by prompt):
# prompt_ids: [prompt0_gen0, prompt0_gen1, prompt0_gen2, prompt1_gen0, prompt1_gen1, prompt1_gen2]
# advantages: [0.1, 0.0, -0.1, -0.07, 0.23, -0.17]
# After shuffle_tensor_dict():
# prompt_ids: [prompt1_gen1, prompt0_gen0, prompt1_gen2, prompt0_gen2, prompt1_gen0, prompt0_gen1]
# advantages: [0.23, 0.1, -0.17, -0.1, -0.07, 0.0]
# Split into gradient_accumulation_steps=2:
# Step 0: [prompt1_gen1, prompt0_gen0, prompt1_gen2]
# Step 1: [prompt0_gen2, prompt1_gen0, prompt0_gen1]
Implementation Details
- No Special API Parameter: The system doesn't use n > 1 in the API call. Instead, it sends the same prompt multiple times as separate requests.
- Stochastic Sampling Required: Without temperature > 0 or other stochastic parameters, all repeated prompts would generate identical completions.
- Independent API Calls: Each prompt (including duplicates) gets processed as a completely separate API call.
- Two-Level Shuffling: RepeatSampler shuffles dataset indices before repeating them; shuffle_tensor_dict() shuffles the final batch before splitting for gradient accumulation.
- Async Processing: The AsyncBatchGenerator allows all API calls to happen concurrently despite being independent requests.
Good contribution. I've incorporated it with some light edits.