feat: enable GRPO training with logprobs from offline trajectory data
## Summary
This PR enables proper GRPO training with importance sampling when using offline trajectory data (e.g., from vLLM traces). It includes three complementary fixes:
### 1. Extract logprobs from dict messages
**Problem:** ART's tokenizer only extracted logprobs from OpenAI `Choice` objects, but offline trajectory data often stores logprobs in plain Python dicts. Logprobs for dict messages were therefore set to NaN, forcing the importance ratio to 1.0 and effectively reducing GRPO to REINFORCE.
**Solution:** Modified `tokenize.py` to also extract logprobs from dict messages of the form `{"logprobs": {"content": [{"logprob": -0.5}, ...]}}`.
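A minimal sketch of the extraction logic (the function name is illustrative; ART's actual `tokenize.py` code differs), handling both attribute-style `Choice` objects and plain dicts:

```python
def extract_logprobs(message):
    """Return the list of per-token logprobs from a message, or None.

    Supports both OpenAI Choice-style objects (attribute access) and plain
    dicts of the form {"logprobs": {"content": [{"logprob": -0.5}, ...]}}.
    """
    logprobs = (
        message.get("logprobs")
        if isinstance(message, dict)
        else getattr(message, "logprobs", None)
    )
    if not logprobs:
        return None
    content = (
        logprobs.get("content")
        if isinstance(logprobs, dict)
        else getattr(logprobs, "content", None)
    )
    if not content:
        return None
    # Each entry carries one token's log-probability under the old policy.
    return [
        t["logprob"] if isinstance(t, dict) else t.logprob
        for t in content
    ]
```

With this, a dict message from an offline trace yields real logprobs instead of NaN.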
### 2. Strip logprobs before RULER scoring
**Problem:** When trajectories contain verbose token-level logprobs data, sending them to the RULER judge causes context-length errors.
**Solution:** Strip logprobs from trajectories with `strip_logprobs()` before sending them to RULER.
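A sketch of what such a helper does, assuming trajectories are dicts with a `"messages"` list (ART ships its own `strip_logprobs`; this is illustrative only):

```python
import copy

def strip_logprobs(trajectory):
    """Return a deep copy of a trajectory with per-message logprobs removed,
    so the verbose token-level data never reaches the judge's context window.
    The original trajectory (still needed for training) is left untouched.
    """
    cleaned = copy.deepcopy(trajectory)
    for message in cleaned.get("messages", []):
        if isinstance(message, dict):
            message.pop("logprobs", None)
    return cleaned
```

The deep copy matters: the judge sees a slimmed-down view while the training loop keeps the full logprobs.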
### 3. Preserve `_internal_config.engine_args`
**Problem:** When using `TrainableModel._internal_config.engine_args` to configure vLLM engine settings (such as `max_logprobs`), the configuration was silently lost when using the SkyPilot backend.
**Solution:** Add a `model_validator(mode="wrap")` to preserve `_internal_config` through Pydantic deserialization.
## Impact
| Aspect | Before | After |
|---|---|---|
| Importance ratio | 1.0 always (for dict messages) | π_new / π_old |
| PPO clipping | Never activates | Activates when ratio outside [0.8, 1.2] |
| Algorithm | REINFORCE | GRPO with importance sampling |
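To make the table concrete, here is a per-token sketch of the clipped surrogate (PPO-style, as used by GRPO) with clip range [0.8, 1.2], i.e. eps = 0.2. The function name and NaN fallback are illustrative, not ART's exact implementation:

```python
import math

def grpo_token_loss(new_logprob, old_logprob, advantage, clip_eps=0.2):
    """Clipped surrogate loss for one token.

    If old_logprob is NaN (logprobs never extracted), falling back to
    ratio = 1.0 disables importance sampling: clipping never activates
    and the update degenerates to plain REINFORCE.
    """
    if math.isnan(old_logprob):
        ratio = 1.0  # importance sampling effectively off
    else:
        ratio = math.exp(new_logprob - old_logprob)  # pi_new / pi_old
    unclipped = ratio * advantage
    clipped = max(min(ratio, 1.0 + clip_eps), 1.0 - clip_eps) * advantage
    return -min(unclipped, clipped)
```

With real old logprobs, a ratio outside [0.8, 1.2] gets clipped; with NaN old logprobs the loss is always just `-advantage`.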
## Test plan
- [x] Verified `max_logprobs` setting works with SkyPilot backend
- [x] Ran `./scripts/run_checks.sh` - all checks pass
- [ ] Test with training that uses offline trajectory data with logprobs