GRPO with a reward model: CUDA out of memory. How can I fix it? Thank you very much.
train_grpo.py:
import argparse
import os
from typing import Callable, Dict, List, Optional
import torch
from datasets import Dataset, load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoModelForSequenceClassification,
pipeline,
set_seed,
)
from trl import GRPOConfig, GRPOTrainer
class CombinedReward:
"""Combine multiple reward sources with weights.
Each reward function follows signature:
reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]
"""
def __init__(
self,
reward_fns: List[Callable[[List[str], List[str]], List[float]]],
weights: Optional[List[float]] = None,
) -> None:
if not reward_fns:
raise ValueError("reward_fns must not be empty")
self.reward_fns = reward_fns
self.weights = weights or [1.0] * len(reward_fns)
if len(self.weights) != len(self.reward_fns):
raise ValueError("weights length must match reward_fns length")
def __call__(self, completions: List[str], prompts: List[str], **kwargs) -> List[float]:
if not completions:
return []
all_scores: List[List[float]] = []
for reward_fn in self.reward_fns:
scores = reward_fn(completions, prompts, **kwargs)
if len(scores) != len(completions):
raise ValueError("All reward functions must return scores for each completion")
all_scores.append(scores)
# weighted sum
totals: List[float] = [0.0] * len(completions)
for w, scores in zip(self.weights, all_scores):
for i, s in enumerate(scores):
totals[i] += w * float(s)
return totals
def build_reward_model_fn(
reward_model_name: str,
device: Optional[str] = None,
normalize: bool = True,
) -> Callable[[List[str], List[str]], List[float]]:
"""Create a reward function using a sequence classification model.
Returns a function that outputs a scalar reward per completion.
"""
rm_tokenizer = AutoTokenizer.from_pretrained(reward_model_name, use_fast=True)
# ensure padding token exists for batched inference
if rm_tokenizer.pad_token is None:
candidate = rm_tokenizer.eos_token or rm_tokenizer.sep_token or rm_tokenizer.cls_token or rm_tokenizer.unk_token
if candidate is not None:
rm_tokenizer.pad_token = candidate
else:
rm_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    rm_model = AutoModelForSequenceClassification.from_pretrained(
        reward_model_name, torch_dtype=torch.float16, device_map="auto"
    )
if getattr(rm_model.config, "pad_token_id", None) is None and rm_tokenizer.pad_token_id is not None:
rm_model.config.pad_token_id = rm_tokenizer.pad_token_id
# use a pipeline for batching and device placement
pipe_device = 0 if (device == "cuda" or (device is None and torch.cuda.is_available())) else -1
rm_pipe = pipeline(
task="text-classification",
model=rm_model,
tokenizer=rm_tokenizer,
        # device=pipe_device,  # leave unset: the model was already placed via device_map="auto"
        truncation=True,
        top_k=None,  # return scores for every label (replaces the deprecated return_all_scores=True)
        function_to_apply="none",  # use raw logits so we can map scores directly
)
def reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
del prompts # unused here
outputs = rm_pipe(completions, batch_size=kwargs.get("batch_size", 2))
scores: List[float] = []
for out in outputs:
            # top_k=None returns one {label, score} dict per class; with
            # function_to_apply="none" the score is the raw logit. A single-label
            # head is used directly; otherwise take the last class as the "more
            # positive" one.
            if len(out) == 1:
                scores.append(float(out[0]["score"]))
            else:
                scores.append(float(out[-1]["score"]))
if not normalize:
return scores
# z-norm for stability (per-batch)
t = torch.tensor(scores, dtype=torch.float32)
        std = float(t.std(unbiased=False).clamp(min=1e-6))  # unbiased=False avoids NaN when only one completion is scored
mean = float(t.mean())
normed = ((t - mean) / std).tolist()
return [float(x) for x in normed]
return reward_fn
def build_keyword_reward_fn(keywords: List[str], case_sensitive: bool = False, bonus: float = 1.0) -> Callable[[List[str], List[str]], List[float]]:
ks = keywords if case_sensitive else [k.lower() for k in keywords]
def reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
del prompts
scores: List[float] = []
for text in completions:
t = text if case_sensitive else text.lower()
count = sum(1 for k in ks if k in t)
scores.append(bonus * float(count))
return scores
return reward_fn
def build_length_reward_fn(target_min: int, target_max: int, scale: float = 1.0) -> Callable[[List[str], List[str]], List[float]]:
def reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
del prompts
scores: List[float] = []
for text in completions:
n = len(text.split())
if n < target_min:
# linearly penalize up to target_min
gap = float(target_min - n)
scores.append(-gap * scale)
elif n > target_max:
gap = float(n - target_max)
scores.append(-gap * scale)
else:
scores.append(scale) # inside band gets a small positive
return scores
return reward_fn
def get_default_dataset(num_examples: int = 50) -> Dataset:
prompts = [
"Summarize the benefits of unit testing.",
"Explain the difference between synchronous and asynchronous programming.",
"List three tips for writing readable code.",
"What is overfitting and how to prevent it?",
"Describe a use-case for message queues.",
]
prompts = (prompts * ((num_examples + len(prompts) - 1) // len(prompts)))[:num_examples]
return Dataset.from_dict({"prompt": prompts})
def load_prompts_dataset(dataset_name_or_path: Optional[str], split: str = "train", prompt_column: str = "prompt") -> Dataset:
if not dataset_name_or_path:
return get_default_dataset()
if os.path.exists(dataset_name_or_path):
        # local path: assume JSON/JSONL (extend to csv/arrow if needed)
        ds = load_dataset("json", data_files=dataset_name_or_path, split=split)
else:
ds = load_dataset(dataset_name_or_path, split=split)
# ensure prompt column
if prompt_column != "prompt":
ds = ds.rename_column(prompt_column, "prompt")
return ds.select_columns(["prompt"]) if "prompt" in ds.column_names else get_default_dataset()
def main() -> None:
parser = argparse.ArgumentParser(description="GRPO training with combined rewards (TRL)")
parser.add_argument("--model_name", type=str, default="/path/to/Qwen3-32B")
parser.add_argument("--reward_model_name", type=str, default="/path/to/Qwen3-32B")
parser.add_argument("--dataset", type=str, default="/path/to/sample_prompts.jsonl", help="HF dataset name or local json/jsonl path")
parser.add_argument("--prompt_column", type=str, default="prompt")
parser.add_argument("--output_dir", type=str, default="./outputs-grpo")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--learning_rate", type=float, default=5e-6)
parser.add_argument("--per_device_train_batch_size", type=int, default=2)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument("--max_prompt_length", type=int, default=256)
parser.add_argument("--max_completion_length", type=int, default=128)
parser.add_argument("--num_generations", type=int, default=2, help="Samples per prompt for GRPO")
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--use_liger_loss", action="store_true")
parser.add_argument("--keyword", type=str, nargs="*", default=["clear", "concise"])
parser.add_argument("--keyword_weight", type=float, default=0.3)
parser.add_argument("--length_min", type=int, default=30)
parser.add_argument("--length_max", type=int, default=200)
parser.add_argument("--length_weight", type=float, default=0.2)
parser.add_argument("--rm_weight", type=float, default=1.0)
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--max_steps", type=int, default=10, help="Override total training steps when dataset has no length")
args = parser.parse_args()
set_seed(args.seed)
# Load policy model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
args.model_name,
        # dtype and placement are handled by accelerate/DeepSpeed when launching with
        # `accelerate launch`; don't combine device_map="auto" with ZeRO-3
        # torch_dtype="auto",
        # device_map="auto",
)
# Build reward functions
rm_fn = build_reward_model_fn(args.reward_model_name)
kw_fn = build_keyword_reward_fn(args.keyword)
len_fn = build_length_reward_fn(args.length_min, args.length_max)
combined = CombinedReward(
reward_fns=[rm_fn, kw_fn, len_fn],
weights=[args.rm_weight, args.keyword_weight, args.length_weight],
)
    # Adapter for TRL: GRPOTrainer's `reward_funcs` accepts a single callable or a list of callables
def reward_adapter(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
return combined(completions, prompts, **kwargs)
# Data
train_ds = load_prompts_dataset(args.dataset, split="train", prompt_column=args.prompt_column)
grpo_cfg = GRPOConfig(
output_dir=args.output_dir,
learning_rate=args.learning_rate,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
num_train_epochs=args.num_train_epochs,
logging_steps=10,
save_strategy="steps",
save_steps=100,
eval_strategy="no",
remove_unused_columns=False,
        bf16=args.bf16,
        fp16=not args.bf16,
max_prompt_length=args.max_prompt_length,
max_completion_length=args.max_completion_length,
num_generations=args.num_generations,
report_to=["none"],
use_liger_loss=args.use_liger_loss,
push_to_hub=args.push_to_hub,
max_steps=args.max_steps
)
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
args=grpo_cfg,
train_dataset=train_ds,
reward_funcs=reward_adapter,
)
trainer.train()
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
if __name__ == "__main__":
main()
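As a quick sanity check, the rule-based rewards can be exercised on their own, without a GPU or the reward model; a toy example (values made up, independent of the trainer):

```python
# check_rewards.py - toy check of the rule-based rewards defined in train_grpo.py (no GPU needed)
from train_grpo import CombinedReward, build_keyword_reward_fn, build_length_reward_fn

kw_fn = build_keyword_reward_fn(["clear", "concise"])
len_fn = build_length_reward_fn(target_min=5, target_max=20)
combined = CombinedReward([kw_fn, len_fn], weights=[0.3, 0.2])

completions = [
    "A clear and concise answer with enough words to land inside the length band.",
    "Too short.",
]
prompts = ["dummy prompt"] * len(completions)
print(combined(completions, prompts))  # roughly [0.8, -0.6] with these toy inputs
```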
sample_prompts.jsonl:
{"prompt": "Summarize the benefits of continuous integration in software development."}
{"prompt": "Explain event-driven architecture and when to use it."}
{"prompt": "List practical strategies to reduce latency in web applications."}
{"prompt": "Compare synchronous vs asynchronous programming with concise examples."}
{"prompt": "What is the CAP theorem? Provide real-world implications."}
{"prompt": "Describe a robust logging strategy for a microservices system."}
{"prompt": "How to design idempotent APIs? Provide key considerations."}
{"prompt": "Outline steps to secure a REST API used by mobile apps."}
{"prompt": "Explain blue-green deployment and its trade-offs."}
{"prompt": "Give best practices for database schema versioning and migrations."}
{"prompt": "What is message deduplication and why does it matter in queues?"}
{"prompt": "Propose a caching strategy for a news homepage with frequent updates."}
{"prompt": "How to prevent prompt injection in LLM applications?"}
{"prompt": "Design a feature flag system for safe rollout and rollback."}
{"prompt": "Explain ML model monitoring: drift, data quality, performance."}
{"prompt": "How to test distributed systems effectively?"}
{"prompt": "Write guidelines for handling PII data in analytics pipelines."}
{"prompt": "Describe rate limiting algorithms (token bucket vs leaky bucket)."}
{"prompt": "Summarize strategies to manage secrets in cloud environments."}
{"prompt": "Explain circuit breakers and their role in resilient services."}
~/.cache/huggingface/accelerate/default_config.yaml:
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 8
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Launch command: accelerate launch train_grpo.py
With device_map="auto" set on both models, python3 train_grpo.py can train, but GPU utilization is low.
Try setting a smaller gradient_accumulation_steps in default_config.yaml. With GRPOTrainer, GPU memory usage grows significantly as gradient_accumulation_steps gets larger. Low GPU utilization is also a problem for me... maybe that is just natural for GRPO.
Thank you @LWShowTime. Setting gradient_accumulation_steps: 1 does not solve the OOM.
Maybe you can add some information about your GPU devices. With GRPOTrainer I can train Qwen3 8B on 3 GPUs with 64 GB of memory each, and that is already the limit, so I think 32B would need at least 12 GPUs.
Some other things you can try:
- check that bf16 mixed precision is enabled in your training (the default is true)
- use LoRA or another PEFT method; in my case, Qwen3 8B with LoRA on four cards uses about 24 GB of GPU memory per card (a minimal sketch follows this list). @guotong1988
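A minimal sketch of that LoRA route with the script above, assuming TRL's `peft_config` argument; the `target_modules` list is a guess for Qwen-style checkpoints and should be checked against the actual module names of your model:

```python
# LoRA sketch for train_grpo.py (assumes GRPOTrainer's peft_config; target_modules unverified)
from peft import LoraConfig

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

trainer = GRPOTrainer(
    model=model,                  # base policy; the trainer wraps it with LoRA adapters
    processing_class=tokenizer,
    args=grpo_cfg,
    train_dataset=train_ds,
    reward_funcs=reward_adapter,
    peft_config=peft_config,      # only adapter weights get gradients and optimizer states
)
```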
Qwen3 8B on 3 GPUs with 64 GB each? Is it unrealistic, then, to attempt training a 72B model on 8 x L20 (48 GB) without using a reward model?
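For a rough sense of scale, a back-of-the-envelope estimate for full fine-tuning of a 72B model with Adam (assuming bf16 weights and gradients plus fp32 master weights and optimizer states, and ignoring activations and the rollout KV cache):

```python
# Back-of-the-envelope training-state estimate for a 72B model with Adam mixed precision
params = 72e9
weights_gb = params * 2 / 1e9    # bf16 weights            ~ 144 GB
grads_gb = params * 2 / 1e9      # bf16 gradients          ~ 144 GB
optim_gb = params * 12 / 1e9     # fp32 master + Adam m, v ~ 864 GB
total_gb = weights_gb + grads_gb + optim_gb
print(f"~{total_gb:.0f} GB of training state vs {8 * 48} GB total on 8 x L20")
```

Even sharded with ZeRO-3 that is roughly 144 GB of training state per GPU before activations and generation, so without aggressive CPU offload it points toward LoRA/QLoRA rather than full fine-tuning.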
Hello, I've observed a similar situation: with GRPOTrainer, GPU memory grows significantly as gradient_accumulation_steps increases. Could you please explain why? It is so different from SFT. Thank you.
@Fyeward I talked this over with GPT-5 and the explanation makes sense: with gradient accumulation you have to cache some extra tensors, such as the KV cache, and in this case those tensors can be huge. Sorry, I can't give a more detailed calculation because I'm still learning.
Anyway, thank you for your reply. I found that in GRPO gradient_accumulation_steps effectively acts as the steps_per_generation parameter, which I think is the key to understanding this behavior, just like the implementation in #3283.
@Fyeward Good, this is truly an important improvement. But it is strange in my case, because my trl version is 0.23.0, which already includes that change.
I don't think the purpose of that change is to save memory. steps_per_generation doesn't control cache release, so your situation may actually be normal.
@Fyeward You're right. In the trainer, generate_every = self.args.steps_per_generation * self.num_iterations, and self.args.steps_per_generation defaults to gradient_accumulation_steps. So it is normal to hit OOM when you increase the accumulation steps, because the actual LLM generation batch size is multiplied accordingly.
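To make that concrete, a small sketch with the numbers used in this thread (per-device batch size 2 from the script, 8 processes and accumulation 8 from the accelerate config); the formula is an assumption about how recent TRL derives the generation batch when steps_per_generation is left unset:

```python
# Effective GRPO generation batch when steps_per_generation defaults to gradient_accumulation_steps
per_device_train_batch_size = 2
num_processes = 8
gradient_accumulation_steps = 8   # steps_per_generation falls back to this value
num_generations = 2

generation_batch = per_device_train_batch_size * num_processes * gradient_accumulation_steps
unique_prompts = generation_batch // num_generations
print(generation_batch, unique_prompts)  # 128 completions per rollout, from 64 unique prompts
```

With gradient_accumulation_steps: 1 the same formula gives 16 completions from 8 prompts per rollout, which is why the generation-phase memory shrinks together with the accumulation steps.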