
feat: enable GRPO training with logprobs from offline trajectory data

Open · JRMeyer opened this issue 3 months ago · 0 comments

Summary

This PR enables proper GRPO training with importance sampling when using offline trajectory data (e.g., from vLLM traces). It includes three complementary fixes:

1. Extract logprobs from dict messages

Problem: ART's tokenizer only extracted logprobs from OpenAI Choice objects, but offline trajectory data often stores logprobs in plain Python dicts. As a result, every dict message's logprobs were set to NaN, which forced the importance ratio to always equal 1.0 (silently degrading GRPO into plain REINFORCE).

Solution: Modified tokenize.py to also extract logprobs from dict messages that have the format {"logprobs": {"content": [{"logprob": -0.5}, ...]}}.
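As an illustrative sketch (not ART's actual tokenize.py code), extracting per-token logprobs from a dict message of the stated shape could look like this; the helper name is hypothetical:

```python
def extract_logprobs(message: dict) -> list[float]:
    """Pull per-token logprobs out of a plain-dict chat message shaped
    like {"logprobs": {"content": [{"logprob": -0.5}, ...]}}.

    Returns an empty list when the message carries no logprobs, in
    which case the caller would fall back to NaN placeholders (the
    pre-fix behavior for every dict message).
    """
    logprobs = message.get("logprobs")
    if not logprobs or not logprobs.get("content"):
        return []
    return [entry["logprob"] for entry in logprobs["content"]]

message = {
    "role": "assistant",
    "content": "hi",
    "logprobs": {"content": [{"logprob": -0.5}, {"logprob": -1.25}]},
}
print(extract_logprobs(message))  # [-0.5, -1.25]
```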

2. Strip logprobs before RULER scoring

Problem: When trajectories contain verbose logprobs data, sending them to the RULER judge causes context length errors.

Solution: Strip logprobs from trajectories before sending to RULER using strip_logprobs().
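The PR names strip_logprobs() but not its body; one plausible minimal implementation, assuming trajectories are lists of dict messages, is simply to drop the "logprobs" key from each message before handing the trajectory to the judge:

```python
def strip_logprobs(messages: list[dict]) -> list[dict]:
    """Return shallow copies of the messages with any 'logprobs' key
    removed, so the RULER judge never sees the verbose token-level
    data that blows past its context window."""
    return [
        {key: value for key, value in msg.items() if key != "logprobs"}
        for msg in messages
    ]

trajectory = [
    {"role": "user", "content": "hello"},
    {
        "role": "assistant",
        "content": "hi",
        "logprobs": {"content": [{"logprob": -0.5}]},
    },
]
print(strip_logprobs(trajectory))
```

Copying rather than mutating in place keeps the original trajectory (and its logprobs) intact for the training step that runs after scoring.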

3. Preserve _internal_config.engine_args

Problem: When using TrainableModel._internal_config.engine_args to configure vLLM engine settings (like max_logprobs), the configuration was silently lost when using the SkyPilot backend.

Solution: Add a model_validator(mode="wrap") to preserve _internal_config during Pydantic deserialization.
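Pydantic v2 drops private attributes during (de)serialization, which is how the setting gets lost on the round trip to the backend. A minimal sketch of the wrap-validator approach, with simplified stand-in models (the field set here is an assumption; only TrainableModel, _internal_config, and engine_args come from the PR):

```python
from typing import Any, Optional

from pydantic import BaseModel, PrivateAttr, model_validator


class InternalConfig(BaseModel):
    engine_args: dict[str, Any] = {}


class TrainableModel(BaseModel):
    name: str
    _internal_config: Optional[InternalConfig] = PrivateAttr(default=None)

    @model_validator(mode="wrap")
    @classmethod
    def _preserve_internal_config(cls, data: Any, handler):
        # Pydantic ignores underscore-prefixed keys, so lift the
        # payload out of the raw input before normal validation...
        internal = None
        if isinstance(data, dict):
            internal = data.pop("_internal_config", None)
        model = handler(data)
        # ...and re-attach it to the validated instance afterwards.
        if internal is not None:
            model._internal_config = InternalConfig.model_validate(internal)
        return model


m = TrainableModel.model_validate(
    {"name": "demo", "_internal_config": {"engine_args": {"max_logprobs": 50}}}
)
print(m._internal_config.engine_args)  # {'max_logprobs': 50}
```

Without the validator, the "_internal_config" key would be silently discarded and engine settings like max_logprobs would never reach the vLLM engine on the remote backend.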

Impact

| Aspect | Before | After |
| --- | --- | --- |
| Importance ratio | 1.0 always (for dict messages) | π_new / π_old |
| PPO clipping | Never activates | Activates when ratio outside [0.8, 1.2] |
| Algorithm | REINFORCE | GRPO with importance sampling |
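To make the table concrete, here is a small sketch of the per-token clipped surrogate objective (standard PPO-style clipping with ε = 0.2, matching the [0.8, 1.2] range above; this is illustrative, not ART's training loop):

```python
import math


def clipped_objective(
    logp_new: float, logp_old: float, advantage: float, eps: float = 0.2
) -> float:
    """Per-token PPO-style clipped surrogate.

    The importance ratio pi_new / pi_old is computed in log space.
    When old logprobs are NaN (the pre-fix behavior), the ratio is
    effectively pinned at 1.0, the clip never activates, and the
    objective reduces to plain REINFORCE (advantage * 1.0).
    """
    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = max(min(ratio, 1.0 + eps), 1.0 - eps)
    return min(ratio * advantage, clipped_ratio * advantage)


# Real old logprobs: ratio = exp(0.5) ~= 1.65, clipped down to 1.2.
print(clipped_objective(-0.5, -1.0, advantage=1.0))  # 1.2
# Identical policies: ratio = 1.0, clip inactive.
print(clipped_objective(-1.0, -1.0, advantage=2.0))  # 2.0
```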

Test plan

  • [x] Verified max_logprobs setting works with SkyPilot backend
  • [x] Ran ./scripts/run_checks.sh - all checks pass
  • [ ] Test with training that uses offline trajectory data with logprobs

JRMeyer · Nov 29 '25 02:11