High logprob error with Qwen3-30B-A3B + GSPO
Describe the bug
Large logprob errors with Qwen3-30B-A3B when training with GSPO. Relevant config:
grpo:
  num_prompts_per_step: 256
  num_generations_per_prompt: 16

loss_fn:
  reference_policy_kl_penalty: 0
  ratio_clip_min: 3e-4
  ratio_clip_max: 4e-4
  ratio_clip_c: null
  use_on_policy_kl_approximation: false
  use_importance_sampling_correction: false
  sequence_level_importance_ratios: true
  token_level_loss: false

policy:
  model_name: Qwen/Qwen3-30B-A3B-Instruct-2507
  dtensor_cfg:
    enabled: False
  optimizer: null
  scheduler: null
  sequence_packing:
    enabled: true
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64
  generation:
    vllm_cfg:
      tensor_parallel_size: 2
      gpu_memory_utilization: 0.5
  make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}
  megatron_cfg:
    enabled: true
    empty_unused_memory_level: 1
    activation_checkpointing: true
    converter_type: "LlamaForCausalLM"
    tensor_model_parallel_size: 4
    expert_tensor_parallel_size: 1
    expert_model_parallel_size: 8
    pipeline_model_parallel_size: 1
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    context_parallel_size: 2
    pipeline_dtype: ${policy.precision}
    sequence_parallel: true
    freeze_moe_router: true
    moe_router_dtype: "fp64"
    moe_router_load_balancing_type: "none"
    moe_router_bias_update_rate: 0.0
    moe_permute_fusion: false
    apply_rope_fusion: True
    defer_fp32_logits: true
    optimizer:
      optimizer: "adam"
      lr: 2.0e-6
      min_lr: ${policy.megatron_cfg.optimizer.lr}
      weight_decay: 0.01
      bf16: true
      fp16: false
      params_dtype: "float32"
      adam_beta1: 0.9
      adam_beta2: 0.999
      adam_eps: 1e-8
      sgd_momentum: 0.9
      use_distributed_optimizer: true
      use_precision_aware_optimizer: true
      clip_grad: ${policy.max_grad_norm}
    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: null
      lr_warmup_iters: 0
      lr_warmup_init: ${policy.megatron_cfg.optimizer.lr}
    distributed_data_parallel_config:
      grad_reduce_in_fp32: false
      overlap_grad_reduce: true
      overlap_param_gather: true
      average_in_collective: true
      use_custom_fsdp: false
      data_parallel_sharding_strategy: "optim_grads_params"
    env_vars: null

checkpointing:
  checkpoint_must_save_by: "00:03:15:00"
  save_period: 10
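For context, sequence_level_importance_ratios: true with token_level_loss: false selects a GSPO-style sequence-level objective. The sketch below is only an illustration of what that objective computes, not the NeMo RL loss code; the function name and signature are made up, and it assumes ratio_clip_min/ratio_clip_max act as offsets around a ratio of 1.

import torch

def gspo_sequence_loss(
    logprobs_new: torch.Tensor,   # [B, T] token logprobs under the current policy
    logprobs_old: torch.Tensor,   # [B, T] token logprobs recorded at generation time
    advantages: torch.Tensor,     # [B]    one group-normalized advantage per sequence
    mask: torch.Tensor,           # [B, T] 1 for response tokens, 0 for prompt/padding
    clip_min: float = 3e-4,       # ratio_clip_min from the config above
    clip_max: float = 4e-4,       # ratio_clip_max from the config above
) -> torch.Tensor:
    mask = mask.to(logprobs_new.dtype)
    # Sequence-level log-ratio: length-normalized sum of per-token log-ratios,
    # i.e. the geometric mean of the token ratios (GSPO-style).
    log_ratio = ((logprobs_new - logprobs_old) * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    ratio = log_ratio.exp()  # [B]
    # PPO-style clipping, assuming the clip values are offsets around 1.
    clipped = ratio.clamp(1.0 - clip_min, 1.0 + clip_max)
    loss = -torch.minimum(ratio * advantages, clipped * advantages)
    return loss.mean()

With bounds this tight, even a modest mismatch between the generation-time and trainer-recomputed logprobs pushes most sequences into the clipped region, which is why the logprob error matters so much here.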
@cmunley1 Do the gen_kl_error plots in your runs show problematic spikes? https://github.com/NVIDIA-NeMo/RL/blob/6984ba7d2c59e368e80be022e0d45a0f8170b977/docs/guides/grpo.md#kl-divergence-error
We've noticed that the logprob error can spike from time to time, but it hasn't led to any crash or collapse, which is consistent with the insights from this blog: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda#271211a558b78046af48c3129693f3f1
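For anyone else debugging this, here is a rough sketch of how a training/inference logprob-mismatch metric can be computed from the two sets of per-token logprobs. The function and variable names are illustrative, and the exact gen_kl_error definition NeMo RL logs is the one in the grpo.md link above.

import torch

def logprob_mismatch_metrics(
    train_logprobs: torch.Tensor,  # [B, T] recomputed by the trainer (Megatron)
    gen_logprobs: torch.Tensor,    # [B, T] returned by the generation engine (vLLM)
    mask: torch.Tensor,            # [B, T] 1 for generated tokens
) -> dict:
    mask = mask.to(train_logprobs.dtype)
    diff = (train_logprobs - gen_logprobs) * mask
    n = mask.sum().clamp(min=1)
    # Mean absolute per-token logprob error between the two engines.
    abs_error = diff.abs().sum() / n
    # k3-style estimator of KL(gen || train) using samples from the generation
    # policy: E[exp(d) - 1 - d] with d = logp_train - logp_gen.
    kl_approx = ((diff.exp() - 1.0 - diff) * mask).sum() / n
    return {"logprob_abs_error": abs_error.item(), "kl_approx": kl_approx.item()}

Tracking both numbers per step makes it easier to tell whether the spikes line up with particular prompts or with specific training steps.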