trl icon indicating copy to clipboard operation
trl copied to clipboard

GRPO with reward model. CUDA out of memory. How to fix? Thank you very much.

Open guotong1988 opened this issue 1 month ago • 11 comments

train_grpo.py:

import argparse
import os
from typing import Callable, Dict, List, Optional

import torch
from datasets import Dataset, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    pipeline,
    set_seed,
)
from trl import GRPOConfig, GRPOTrainer


class CombinedReward:
    """Weighted sum of several reward sources.

    Every reward function is invoked as
        reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]
    and must return exactly one score per completion.
    """

    def __init__(
        self,
        reward_fns: List[Callable[[List[str], List[str]], List[float]]],
        weights: Optional[List[float]] = None,
    ) -> None:
        if not reward_fns:
            raise ValueError("reward_fns must not be empty")
        self.reward_fns = reward_fns
        # Missing (or empty) weights mean every source counts with weight 1.0.
        self.weights = weights if weights else [1.0] * len(reward_fns)
        if len(self.weights) != len(self.reward_fns):
            raise ValueError("weights length must match reward_fns length")

    def __call__(self, completions: List[str], prompts: List[str], **kwargs) -> List[float]:
        """Return one weighted total per completion (empty list for no completions)."""
        if not completions:
            return []
        expected = len(completions)
        columns: List[List[float]] = []
        for fn in self.reward_fns:
            column = fn(completions, prompts, **kwargs)
            if len(column) != expected:
                raise ValueError("All reward functions must return scores for each completion")
            columns.append(column)

        def _weighted_total(row) -> float:
            # Accumulate in source order, starting from 0.0, one weighted term at a time.
            total = 0.0
            for weight, score in zip(self.weights, row):
                total += weight * float(score)
            return total

        # zip(*columns) walks the completions, yielding each one's scores from every source.
        return [_weighted_total(row) for row in zip(*columns)]


def build_reward_model_fn(
    reward_model_name: str,
    device: Optional[str] = None,
    normalize: bool = True,
    torch_dtype: torch.dtype = torch.float16,
) -> Callable[[List[str], List[str]], List[float]]:
    """Create a reward function backed by a sequence-classification model.

    Args:
        reward_model_name: Hub id or local path loaded with
            ``AutoModelForSequenceClassification``.
        device: Placement of the reward model. ``None`` (default) shards it with
            ``device_map="auto"``. Any other value (e.g. ``"cuda:1"`` or ``"cpu"``)
            is passed as ``device_map`` so the whole model lands on that device.
        normalize: If True, z-normalise the scores within each call.
        torch_dtype: dtype used to load the reward-model weights (default fp16).

    Returns:
        ``reward_fn(completions, prompts, **kwargs) -> List[float]`` giving one
        scalar per completion. Only the completion text is scored (``prompts``
        is ignored). ``kwargs["batch_size"]`` (default 2) sets the pipeline
        batch size.
    """
    rm_tokenizer = AutoTokenizer.from_pretrained(reward_model_name, use_fast=True)

    # Batched inference needs a pad token; prefer reusing an existing special token.
    added_pad_token = False
    if rm_tokenizer.pad_token is None:
        candidate = (
            rm_tokenizer.eos_token
            or rm_tokenizer.sep_token
            or rm_tokenizer.cls_token
            or rm_tokenizer.unk_token
        )
        if candidate is not None:
            rm_tokenizer.pad_token = candidate
        else:
            rm_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            added_pad_token = True

    rm_model = AutoModelForSequenceClassification.from_pretrained(
        reward_model_name,
        torch_dtype=torch_dtype,
        device_map="auto" if device is None else device,
    )
    if added_pad_token:
        # A newly added token has no embedding row yet; grow the table so its id is valid.
        rm_model.resize_token_embeddings(len(rm_tokenizer))
    if getattr(rm_model.config, "pad_token_id", None) is None and rm_tokenizer.pad_token_id is not None:
        rm_model.config.pad_token_id = rm_tokenizer.pad_token_id

    # label name -> class index, so a score can be selected by class rather than by rank.
    label_to_index = {label: int(idx) for idx, label in (rm_model.config.id2label or {}).items()}

    # The model is already placed via device_map, so no ``device`` is given to the pipeline.
    rm_pipe = pipeline(
        task="text-classification",
        model=rm_model,
        tokenizer=rm_tokenizer,
        truncation=True,
        top_k=None,  # return a score for every label
        function_to_apply="none",  # raw logits, no sigmoid/softmax
    )

    def _positive_score(label_scores: List[dict]) -> float:
        """Return the raw logit of the highest-index class (treated as "most positive")."""
        if len(label_scores) == 1:
            return float(label_scores[0]["score"])
        # With top_k=None the pipeline sorts entries by score (highest first), so list
        # position says nothing about the class; pick by label index instead.
        best = max(label_scores, key=lambda entry: label_to_index.get(entry["label"], -1))
        return float(best["score"])

    def reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
        del prompts  # the reward model only sees the completion text
        if not completions:
            return []
        outputs = rm_pipe(completions, batch_size=kwargs.get("batch_size", 2))
        scores: List[float] = [_positive_score(out) for out in outputs]
        if not normalize:
            return scores
        if len(scores) < 2:
            # The unbiased std of a single value is NaN; a mean-centred single score is 0.
            return [0.0] * len(scores)
        # z-norm for stability (per call / per batch).
        t = torch.tensor(scores, dtype=torch.float32)
        std = float(t.std().clamp(min=1e-6))
        mean = float(t.mean())
        return [float(x) for x in ((t - mean) / std).tolist()]

    return reward_fn


def build_keyword_reward_fn(keywords: List[str], case_sensitive: bool = False, bonus: float = 1.0) -> Callable[[List[str], List[str]], List[float]]:
    """Reward completions for mentioning any of ``keywords``.

    Each keyword found as a substring of a completion (case-insensitively unless
    ``case_sensitive``) adds ``bonus``. Several occurrences of the same keyword
    count only once.

    Returns:
        ``reward_fn(completions, prompts, **kwargs) -> List[float]``; ``prompts`` is ignored.
    """
    needles = keywords if case_sensitive else [kw.lower() for kw in keywords]

    def reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
        del prompts  # matching looks at the completion text only

        def _hits(text: str) -> int:
            haystack = text if case_sensitive else text.lower()
            return sum(needle in haystack for needle in needles)

        return [bonus * float(_hits(completion)) for completion in completions]

    return reward_fn


def build_length_reward_fn(target_min: int, target_max: int, scale: float = 1.0) -> Callable[[List[str], List[str]], List[float]]:
    """Reward completions whose whitespace-separated word count lies in a band.

    A count inside ``[target_min, target_max]`` earns ``+scale``. Outside the
    band the score is ``-scale`` times the number of words by which the count
    misses the band (linear, unbounded penalty).

    Returns:
        ``reward_fn(completions, prompts, **kwargs) -> List[float]``; ``prompts`` is ignored.
    """
    def reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
        del prompts

        def _score(text: str) -> float:
            words = len(text.split())
            if words < target_min:
                return -float(target_min - words) * scale  # too short
            if words > target_max:
                return -float(words - target_max) * scale  # too long
            return scale  # inside the band

        return [_score(text) for text in completions]

    return reward_fn


def get_default_dataset(num_examples: int = 50) -> Dataset:
    """Build a toy prompt dataset by cycling through a few built-in prompts.

    Args:
        num_examples: Number of rows to produce; non-positive values give an empty dataset.

    Returns:
        A ``Dataset`` with a single ``"prompt"`` column.
    """
    seed_prompts = (
        "Summarize the benefits of unit testing.",
        "Explain the difference between synchronous and asynchronous programming.",
        "List three tips for writing readable code.",
        "What is overfitting and how to prevent it?",
        "Describe a use-case for message queues.",
    )
    # Row i reuses seed prompt i mod len(seed_prompts), i.e. the prompts repeat in order.
    rows = [seed_prompts[i % len(seed_prompts)] for i in range(num_examples)]
    return Dataset.from_dict({"prompt": rows})


def load_prompts_dataset(dataset_name_or_path: Optional[str], split: str = "train", prompt_column: str = "prompt") -> Dataset:
    """Load a prompt dataset and reduce it to a single ``"prompt"`` column.

    Args:
        dataset_name_or_path: One of:
            - a Hub dataset id;
            - a local dataset directory;
            - a local file. ``.csv``, ``.parquet``, ``.arrow`` and ``.txt`` use
              the matching builder; any other extension, including
              ``.json``/``.jsonl``, is read as JSON.
            An empty value or ``None`` returns ``get_default_dataset()``.
        split: Split to load.
        prompt_column: Column holding the prompt text; it is renamed to ``"prompt"``.

    Returns:
        A ``Dataset`` containing only the ``"prompt"`` column. If no such column
        exists, a warning is emitted and the built-in default dataset is returned.

    Raises:
        ValueError: from ``rename_column`` if ``prompt_column`` is not in the dataset.
    """
    import warnings

    if not dataset_name_or_path:
        return get_default_dataset()
    if os.path.isfile(dataset_name_or_path):
        # Pick the packaged datasets builder from the file extension; JSON stays the fallback.
        suffix = os.path.splitext(dataset_name_or_path)[1].lower()
        builder = {".csv": "csv", ".parquet": "parquet", ".arrow": "arrow", ".txt": "text"}.get(suffix, "json")
        ds = load_dataset(builder, data_files=dataset_name_or_path, split=split)
    else:
        # Hub ids and local dataset directories are both resolved by load_dataset(path).
        ds = load_dataset(dataset_name_or_path, split=split)
    # ensure prompt column
    if prompt_column != "prompt":
        ds = ds.rename_column(prompt_column, "prompt")
    if "prompt" not in ds.column_names:
        warnings.warn(
            f"Dataset {dataset_name_or_path!r} has no 'prompt' column; "
            "falling back to the built-in default prompts.",
            stacklevel=2,
        )
        return get_default_dataset()
    return ds.select_columns(["prompt"])


def main() -> None:
    """Parse CLI flags, build the policy and combined reward, run GRPO, then save.

    Works with ``python train_grpo.py`` and with ``accelerate launch``. Under
    the latter every process runs this function, so every process loads its own
    policy model and its own reward model.
    """
    parser = argparse.ArgumentParser(description="GRPO training with combined rewards (TRL)")
    parser.add_argument("--model_name", type=str, default="/path/to/Qwen3-32B")
    # NOTE(review): a plain causal-LM checkpoint loaded with AutoModelForSequenceClassification
    # gets a freshly initialised score head, so its rewards are not meaningful unless the
    # checkpoint ships a trained reward head.
    parser.add_argument("--reward_model_name", type=str, default="/path/to/Qwen3-32B")
    parser.add_argument("--dataset", type=str, default="/path/to/sample_prompts.jsonl", help="HF dataset name or local json/jsonl path")
    parser.add_argument("--prompt_column", type=str, default="prompt")
    parser.add_argument("--output_dir", type=str, default="./outputs-grpo")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--per_device_train_batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--max_prompt_length", type=int, default=256)
    parser.add_argument("--max_completion_length", type=int, default=128)
    parser.add_argument("--num_generations", type=int, default=2, help="Samples per prompt for GRPO")
    parser.add_argument("--bf16", action="store_true", help="Train with bf16 mixed precision instead of fp16")
    parser.add_argument("--use_liger_loss", action="store_true")
    parser.add_argument("--keyword", type=str, nargs="*", default=["clear", "concise"])
    parser.add_argument("--keyword_weight", type=float, default=0.3)
    parser.add_argument("--length_min", type=int, default=30)
    parser.add_argument("--length_max", type=int, default=200)
    parser.add_argument("--length_weight", type=float, default=0.2)
    parser.add_argument("--rm_weight", type=float, default=1.0)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--max_steps", type=int, default=10, help="Total training steps; any value > 0 overrides --num_train_epochs")
    # Opt-in memory-saving switches for large policies (defaults keep the previous behaviour).
    parser.add_argument("--gradient_checkpointing", action="store_true", help="Recompute activations during backward to reduce activation memory")
    parser.add_argument(
        "--no_ds3_gather_for_generation",
        action="store_true",
        help="With DeepSpeed ZeRO-3, generate without gathering the full policy weights on each GPU (slower, but lets the policy exceed one GPU's memory)",
    )

    args = parser.parse_args()
    set_seed(args.seed)

    # Load policy model and tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # No torch_dtype/device_map: precision and placement are left to the Trainer
    # (and DeepSpeed when launched through accelerate).
    model = AutoModelForCausalLM.from_pretrained(args.model_name)

    # Build reward functions.
    # NOTE: this runs in every launched process, and build_reward_model_fn loads the reward
    # model with device_map="auto" by default. Each process therefore spreads its own full
    # copy over all visible GPUs; a 32B reward model across 8 processes is a likely OOM source.
    rm_fn = build_reward_model_fn(args.reward_model_name)
    kw_fn = build_keyword_reward_fn(args.keyword)
    len_fn = build_length_reward_fn(args.length_min, args.length_max)

    combined = CombinedReward(
        reward_fns=[rm_fn, kw_fn, len_fn],
        weights=[args.rm_weight, args.keyword_weight, args.length_weight],
    )

    # A plain named function (rather than the CombinedReward instance) is handed to TRL,
    # which refers to reward functions by their __name__.
    def reward_adapter(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
        return combined(completions, prompts, **kwargs)

    # Data
    train_ds = load_prompts_dataset(args.dataset, split="train", prompt_column=args.prompt_column)

    grpo_cfg = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_train_epochs=args.num_train_epochs,
        logging_steps=10,
        save_strategy="steps",
        save_steps=100,
        eval_strategy="no",
        remove_unused_columns=False,
        # --bf16 was previously parsed but ignored (fp16 was always on).
        fp16=not args.bf16,
        bf16=args.bf16,
        gradient_checkpointing=args.gradient_checkpointing,
        ds3_gather_for_generation=not args.no_ds3_gather_for_generation,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        report_to=["none"],
        use_liger_loss=args.use_liger_loss,
        push_to_hub=args.push_to_hub,
        max_steps=args.max_steps,
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=grpo_cfg,
        train_dataset=train_ds,
        reward_funcs=reward_adapter,
    )

    trainer.train()
    trainer.save_model(args.output_dir)
    # Only the main process writes the tokenizer, so launched processes don't race on the same files.
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()

sample_prompts.jsonl:

{"prompt": "Summarize the benefits of continuous integration in software development."}
{"prompt": "Explain event-driven architecture and when to use it."}
{"prompt": "List practical strategies to reduce latency in web applications."}
{"prompt": "Compare synchronous vs asynchronous programming with concise examples."}
{"prompt": "What is the CAP theorem? Provide real-world implications."}
{"prompt": "Describe a robust logging strategy for a microservices system."}
{"prompt": "How to design idempotent APIs? Provide key considerations."}
{"prompt": "Outline steps to secure a REST API used by mobile apps."}
{"prompt": "Explain blue-green deployment and its trade-offs."}
{"prompt": "Give best practices for database schema versioning and migrations."}
{"prompt": "What is message deduplication and why does it matter in queues?"}
{"prompt": "Propose a caching strategy for a news homepage with frequent updates."}
{"prompt": "How to prevent prompt injection in LLM applications?"}
{"prompt": "Design a feature flag system for safe rollout and rollback."}
{"prompt": "Explain ML model monitoring: drift, data quality, performance."}
{"prompt": "How to test distributed systems effectively?"}
{"prompt": "Write guidelines for handling PII data in analytics pipelines."}
{"prompt": "Describe rate limiting algorithms (token bucket vs leaky bucket)."}
{"prompt": "Summarize strategies to manage secrets in cloud environments."}
{"prompt": "Explain circuit breakers and their role in resilient services."}

~/.cache/huggingface/accelerate/default_config.yaml:

# Accelerate launch configuration used by `accelerate launch train_grpo.py`.
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  # NOTE(review): this value may disagree with the script's --gradient_accumulation_steps
  # (default 1); confirm which one the Trainer actually uses. In GRPO the per-device
  # generation batch grows with gradient accumulation (steps_per_generation defaults to it).
  gradient_accumulation_steps: 8
  # Optimizer states and parameters are offloaded to CPU memory.
  offload_optimizer_device: cpu
  offload_param_device: cpu
  # NOTE(review): with ZeRO-3 init disabled, each process presumably materialises the full
  # model before DeepSpeed partitions it; confirm against the Accelerate DeepSpeed docs.
  zero3_init_flag: false
  zero3_save_16bit_model: false
  # ZeRO stage 3: parameters, gradients and optimizer states are sharded across processes.
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
# Keep consistent with the precision requested by the script's GRPOConfig (fp16 by default).
mixed_precision: fp16
num_machines: 1
# Number of launched training processes (presumably one per GPU); each process runs main()
# and loads its own policy model and reward model.
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

accelerate launch train_grpo.py

guotong1988 avatar Nov 01 '25 10:11 guotong1988

With `device_map="auto"` set for both models, `python3 train_grpo.py` can train, but GPU utilization is low.

guotong1988 avatar Nov 01 '25 10:11 guotong1988

Try setting a smaller gradient_accumulation_steps in default_config.yaml. With GRPOTrainer, GPU memory usage grows significantly as gradient_accumulation_steps increases. Low GPU utilization is also a problem for me... maybe this is normal for GRPO.

LWShowTime avatar Nov 02 '25 09:11 LWShowTime

Thank you @LWShowTime . Setting gradient_accumulation_steps: 1 does not solve the OOM.

guotong1988 avatar Nov 03 '25 01:11 guotong1988

Thank you @LWShowTime . Setting gradient_accumulation_steps: 1 does not solve the OOM.

Maybe you can add some information about your GPU devices. With GRPOTrainer, I can train Qwen3-8B on 3 GPUs with 64GB of memory each, and that's the limit. So I think 32B might need at least 12 GPUs.

Some other things you can try:

  1. check that bf16 mixed precision is enabled in your training (default is true)
  2. use LoRA or another PEFT training method. In my case, Qwen3-8B with LoRA on four cards uses about 24GB of GPU memory per card. @guotong1988

LWShowTime avatar Nov 03 '25 15:11 LWShowTime

Thank you @LWShowTime . Setting gradient_accumulation_steps: 1 does not solve the OOM.

Maybe you can add some information of your gpu devices. For GRPOTrainer, I can train QWen3 8B on 3 GPUs of 64GB memory, taht's the limit. So I think 32B might need at least 12 GPUs.

Some other things you can try:

  1. check the bf mix_precision is true in your training (default is true)
  2. use LoRA training or other peft training methods. In my cases, QWen 8B using LoRA, four cards, each card's gpu memory usage is about 24GB. @guotong1988 QWen3 8B on 3 GPUs of 64GB memory? Is it too whimsical to attempt training a 72B model on 8 x L20 (48g) without using a reward model?

Qshuangyan avatar Nov 06 '25 23:11 Qshuangyan

Try set smaller gradient_accumulation_steps in default_config.yaml. In GRPOTrainer cases, the memory of GPU will get significant larger with a larger gradient_accumulation_steps. The low use ratio of GPU is also a problem for me... maybe this is natural for GRPO. Hello, I've also observed a similar situation.In GRPOTrainer cases, the memory of GPU will get significant larger with a larger gradient_accumulation_steps.Could you please explain why this is? It's so different from SFT. Thank you

Fyeward avatar Nov 12 '25 10:11 Fyeward

Try set smaller gradient_accumulation_steps in default_config.yaml. In GRPOTrainer cases, the memory of GPU will get significant larger with a larger gradient_accumulation_steps. The low use ratio of GPU is also a problem for me... maybe this is natural for GRPO. Hello, I've also observed a similar situation.In GRPOTrainer cases, the memory of GPU will get significant larger with a larger gradient_accumulation_steps.Could you please explain why this is? It's so different from SFT. Thank you

@Fyeward I talked about this with GPT-5, and the explanation makes sense: when you use gradient accumulation, some extra tensors (such as the KV cache) have to be kept around, and in this case they can be huge. Sorry, I can't give a more detailed calculation because I'm still learning.

LWShowTime avatar Nov 12 '25 12:11 LWShowTime

Try set smaller gradient_accumulation_steps in default_config.yaml. In GRPOTrainer cases, the memory of GPU will get significant larger with a larger gradient_accumulation_steps. The low use ratio of GPU is also a problem for me... maybe this is natural for GRPO. Hello, I've also observed a similar situation.In GRPOTrainer cases, the memory of GPU will get significant larger with a larger gradient_accumulation_steps.Could you please explain why this is? It's so different from SFT. Thank you

@Fyeward I talked about this with GPT-5, and the explanation makes sense: when you use gradient accumulation, some extra tensors (such as the KV cache) have to be kept around, and in this case they can be huge. Sorry, I can't give a more detailed calculation because I'm still learning. Anyway, thank you for your reply. I found that in GRPO, gradient_accumulation_steps serves as the default for the steps_per_generation parameter, which I think is the key to understanding this situation, as implemented in #3283.

Fyeward avatar Nov 12 '25 15:11 Fyeward

Try set smaller gradient_accumulation_steps in default_config.yaml. In GRPOTrainer cases, the memory of GPU will get significant larger with a larger gradient_accumulation_steps. The low use ratio of GPU is also a problem for me... maybe this is natural for GRPO. Hello, I've also observed a similar situation.In GRPOTrainer cases, the memory of GPU will get significant larger with a larger gradient_accumulation_steps.Could you please explain why this is? It's so different from SFT. Thank you

@Fyeward I talked about this with GPT-5, the explanation makes sense. When you use gradient accumulation, you need to cache some extra tensors like KV cache. In this cases, this kind of tensors might be giant. Sorry but I can't give more detailed calculation becuase I'm under learning. Anyway, thank you for your reply. I found that gradient_accumulation_steps works in grpo as the steps_per_generations parameter, which I think is the key to understanding this situation. Just like the implementation of #3283

@Fyeward Good, this is truly an important improvement. But it's odd in my case, because I'm on TRL 0.23.0, which already includes this improvement.

LWShowTime avatar Nov 14 '25 02:11 LWShowTime

Try set smaller gradient_accumulation_steps in default_config.yaml. In GRPOTrainer cases, the memory of GPU will get significant larger with a larger gradient_accumulation_steps. The low use ratio of GPU is also a problem for me... maybe this is natural for GRPO. Hello, I've also observed a similar situation.In GRPOTrainer cases, the memory of GPU will get significant larger with a larger gradient_accumulation_steps.Could you please explain why this is? It's so different from SFT. Thank you

@Fyeward I talked about this with GPT-5, the explanation makes sense. When you use gradient accumulation, you need to cache some extra tensors like KV cache. In this cases, this kind of tensors might be giant. Sorry but I can't give more detailed calculation becuase I'm under learning. Anyway, thank you for your reply. I found that gradient_accumulation_steps works in grpo as the steps_per_generations parameter, which I think is the key to understanding this situation. Just like the implementation of #3283

@Fyeward Good, this is truly an important improvement. But it's odd in my case, because I'm on TRL 0.23.0, which already includes this improvement. I don't think the purpose of this improvement is to save memory. steps_per_generation doesn't control when caches are released, so your situation might actually be normal.

Fyeward avatar Nov 14 '25 07:11 Fyeward

Try set smaller gradient_accumulation_steps in default_config.yaml. In GRPOTrainer cases, the memory of GPU will get significant larger with a larger gradient_accumulation_steps. The low use ratio of GPU is also a problem for me... maybe this is natural for GRPO. Hello, I've also observed a similar situation.In GRPOTrainer cases, the memory of GPU will get significant larger with a larger gradient_accumulation_steps.Could you please explain why this is? It's so different from SFT. Thank you

@Fyeward I talked about this with GPT-5, the explanation makes sense. When you use gradient accumulation, you need to cache some extra tensors like KV cache. In this cases, this kind of tensors might be giant. Sorry but I can't give more detailed calculation becuase I'm under learning. Anyway, thank you for your reply. I found that gradient_accumulation_steps works in grpo as the steps_per_generations parameter, which I think is the key to understanding this situation. Just like the implementation of #3283

@Fyeward Good, this is truly an important improvement. But it is weird for my cases because my trl version is 0.23.0 which is already updated with this amelioration. I don't think the purpose of this improvement is to save memory. Steps_per_generations doesn't control cache release, so your situation might actually be normal

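@Fyeward You're right: `generate_every = self.args.steps_per_generation * self.num_iterations`, and `self.args.steps_per_generation` defaults to `self.gradient_accumulation_steps`. So hitting OOM when you increase gradient_accumulation_steps is expected: the per-device LLM generation batch size is multiplied by it. Setting `steps_per_generation` explicitly, or lowering gradient_accumulation_steps, shrinks that batch again.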

LWShowTime avatar Nov 20 '25 12:11 LWShowTime