
TRL ORPO gives NaN for everything

Open gagan3012 opened this issue 2 years ago • 6 comments

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 2.5000000000000002e-08, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': -3.114448070526123, 'logits/chosen': -3.114448070526123, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 5.0000000000000004e-08, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 7.500000000000001e-08, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 1.0000000000000001e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 1.2500000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 1.5000000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 1.7500000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 2.0000000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 2.2500000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}

Using the example script orpo.py, I get this error.

gagan3012 avatar Mar 23 '24 03:03 gagan3012

I'm experiencing the same issue :( It seems like the grad_norm suddenly diverges to infinity after some iterations.

hbin0701 avatar Mar 23 '24 06:03 hbin0701

@gagan3012 @hbin0701 do you see this with some specific dataset? Here is my run of orpo.py:

https://wandb.ai/krasul/huggingface/runs/rqu2awe3?nw=nwuserkrasul

using:

python examples/scripts/orpo.py \
    --model_name_or_path=gpt2 \
    --per_device_train_batch_size 4 \
    --max_steps 1000 \
    --learning_rate 1e-3 \
    --gradient_accumulation_steps 1 \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir="orpo_anthropic_hh" \
    --optim rmsprop \
    --warmup_steps 150 \
    --report_to wandb \
    --bf16 \
    --logging_first_step \
    --no_remove_unused_columns \
    --use_peft \
    --lora_r=16 \
    --lora_alpha=16

kashif avatar Mar 23 '24 09:03 kashif

I was using Mistral 0.2.

gagan3012 avatar Mar 24 '24 19:03 gagan3012

Hello @gagan3012, I just saw this issue and would like to add some comments!

Although I do not know the specific environment or dataset you are using, it is generally recommended that you use a lower learning rate and beta for larger models.

For example, this code for reproducing kaist-ai/mistral-orpo-capybara-7k uses a maximum learning rate of 5e-6 and a beta of 0.05 (this code is not for the TRL ORPOTrainer, by the way):

accelerate launch --config_file ./src/accelerate/fsdp.yaml main.py \
    --lr 5e-6 \
    --torch_compile False \
    --beta 0.05 \
    --lr_scheduler_type inverse_sqrt \
    --warmup_steps 100 \
    --model_name mistralai/Mistral-7B-v0.1 \
    --data_name argilla/distilabel-capybara-dpo-7k-binarized \
    --num_train_epochs 3 \
    --optim adamw_bnb_8bit \
    --gradient_accumulation_steps 1 \
    --prompt_max_length 1792 \
    --response_max_length 2048 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --num_proc 8 \
    --flash_attention_2

I am not sure which dataset you are training on, but I would start with a beta of 0.1 and a learning rate of 5e-6. I will add some general guidelines for selecting the learning rate and beta by model size and dataset style to this repo this week!
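
For the TRL ORPOTrainer specifically, those starting values would translate to something like the following minimal sketch (ORPOConfig exposes beta; learning_rate and the scheduler are the usual TrainingArguments fields; the values here are illustrative starting points, not tuned recommendations):

from trl import ORPOConfig

config = ORPOConfig(
    output_dir="orpo-mistral",         # placeholder output path
    learning_rate=5e-6,                # far below the 1e-3 used in the gpt2 repro above
    beta=0.1,                          # ORPO's odds-ratio weight; try 0.05 for larger models
    lr_scheduler_type="inverse_sqrt",
    warmup_steps=100,
)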

jiwooya1000 avatar Mar 25 '24 12:03 jiwooya1000

Hello, when using the original ORPO repo I don't face this issue, but I do face it when I use TRL, which is very puzzling.

gagan3012 avatar Mar 26 '24 21:03 gagan3012

Is your prompt preparation correct?

TRL expects the "chosen" and "rejected" columns to be a) formatted (but not tokenized) and b) to EXCLUDE the prompt.

TRL also does not add any bos or eos tokens, so you need to do that in the chat_template. Further, since you'll be formatting chosen and rejected columns without the prompt, you need to ensure that the bos is NOT included there...
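
For concreteness, a hedged sketch of what a single row could look like under that convention (the Mistral-style template tokens are chosen purely for illustration):

# "prompt" carries the fully formatted prompt, including the bos token;
# "chosen"/"rejected" carry only the formatted responses, with eos but NO bos
# and NO prompt text, since TRL concatenates prompt + response itself.
row = {
    "prompt": "<s>[INST] What does ORPO stand for? [/INST]",
    "chosen": " Odds Ratio Preference Optimization.</s>",
    "rejected": " I have no idea.</s>",
}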

RonanKMcGovern avatar Mar 27 '24 10:03 RonanKMcGovern

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot] avatar Apr 22 '24 15:04 github-actions[bot]

When the prompt exceeds the max_length, the log probabilities for both chosen and rejected turn to NaN. Consider filtering out cases where the prompt is longer than max_length or max_prompt_length. The reason for trimming cases where the prompt exceeds max_prompt_length as well is that if the chosen or rejected segments are significantly shorter than the prompt, it may hinder effective learning.
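
A minimal sketch of that filtering with the datasets library (the dataset id and the max_prompt_length value are placeholders; match them to your setup and your ORPOConfig):

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
dataset = load_dataset("your/preference-dataset", split="train")  # placeholder id

max_prompt_length = 128  # keep in sync with ORPOConfig.max_prompt_length

# Drop rows whose tokenized prompt already exhausts the budget, leaving no
# room for response tokens (the situation that produces NaN log-probs).
def prompt_fits(example):
    return len(tokenizer(example["prompt"])["input_ids"]) <= max_prompt_length

dataset = dataset.filter(prompt_fits)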

conseq2 avatar Apr 26 '24 02:04 conseq2

I also have a similar problem, but my case is different from what is mentioned above: my dataset doesn't have a prompt column; the prompts are already concatenated into the chosen/rejected fields.

paulcx avatar Apr 26 '24 22:04 paulcx

This line: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py#L618

There is a possibility that torch.exp(policy_chosen_logps) or torch.exp(policy_rejected_logps) will be exactly 1. Then torch.log1p(-1) evaluates to -inf and the loss turns into NaN.
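
The failure mode is easy to reproduce in isolation (a minimal sketch of the mechanism, not the trainer code itself):

import torch

# A sequence log-prob of exactly 0.0 means probability 1, which can happen
# numerically, e.g. when truncation leaves no response tokens to score.
logps = torch.tensor(0.0)

# log(1 - p) via log1p: log1p(-exp(0)) = log1p(-1.0) = log(0) = -inf
print(torch.log1p(-torch.exp(logps)))  # tensor(-inf)

# The odds ratio then subtracts two such terms, and -inf - (-inf) is NaN.
print(torch.log1p(-torch.exp(logps)) - torch.log1p(-torch.exp(logps)))  # tensor(nan)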

poutyface avatar May 07 '24 05:05 poutyface

This line: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py#L618

There is a possibility that torch.exp(policy_chosen_logps) or torch.exp(policy_rejected_logps) will be exactly 1. Then torch.log1p(-1) evaluates to -inf and the loss turns into NaN.

Any solution?

paulcx avatar May 08 '24 02:05 paulcx

Adding eps=1e-5 to the log1p argument works fine for me.
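
For reference, a sketch of where such an epsilon could go in the odds-ratio term (illustrative only, not the exact TRL source):

import torch

EPS = 1e-5  # keeps log1p's argument strictly above -1, so log(0) never occurs

def log_odds(chosen_logps: torch.Tensor, rejected_logps: torch.Tensor) -> torch.Tensor:
    # log(p / (1 - p)) for chosen minus the same for rejected, eps-guarded.
    return (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps) + EPS)
        - torch.log1p(-torch.exp(rejected_logps) + EPS)
    )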

poutyface avatar May 08 '24 12:05 poutyface

Adding eps=1e-5 to the log1p argument works fine for me.

You mean torch.log1p(x + eps)?

paulcx avatar May 08 '24 21:05 paulcx

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot] avatar Jun 02 '24 15:06 github-actions[bot]

When the prompt exceeds the max_length, the log probabilities for both chosen and rejected turn to NaN. Consider filtering out cases where the prompt is longer than max_length or max_prompt_length. The reason for trimming cases where the prompt exceeds max_prompt_length as well is that if the chosen or rejected segments are significantly shorter than the prompt, it may hinder effective learning.

I was facing similar issues with NaNs, and the problem went away when I filtered out of my dataset the examples where the length of the prompt + chosen/rejected exceeded a certain length.

KoutchemeCharles avatar Aug 01 '24 15:08 KoutchemeCharles

When the prompt exceeds the max_length, the log probabilities for both chosen and rejected turn to NaN. Consider filtering out cases where the prompt is longer than max_length or max_prompt_length. The reason for trimming cases where the prompt exceeds max_prompt_length as well is that if the chosen or rejected segments are significantly shorter than the prompt, it may hinder effective learning.

FYI, also make sure to set ORPOConfig.max_length sufficiently larger than max_prompt_length, especially if you ignore the loss on the prompt and only calculate it on the response part. The default value of max_length is 512.

It does print warning messages, but I missed them and spent some time figuring it out. I leave this comment here for future reference.

https://github.com/huggingface/trl/blob/v0.10.1/trl/trainer/orpo_trainer.py#L216

        if args.max_length is None:
            warnings.warn(
                "`max_length` is not set in the ORPOConfig's init"
                " it will default to `512` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_length = 512
        else:
            max_length = args.max_length
        if args.max_prompt_length is None:
            warnings.warn(
                "`max_prompt_length` is not set in the ORPOConfig's init"
                " it will default to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_prompt_length = 128
        else:
            max_prompt_length = args.max_prompt_length
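
A hedged sketch of setting both limits explicitly so neither warning fires (the values are illustrative; pick them for your data):

from trl import ORPOConfig

config = ORPOConfig(
    output_dir="orpo-out",   # placeholder output path
    max_prompt_length=512,   # token budget for the prompt
    max_length=1024,         # prompt + response; keep well above max_prompt_length
)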

CoaLee avatar Dec 05 '24 07:12 CoaLee

I solved this problem by changing float16 to bfloat16.
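
That switch is a single flag in the training config (a sketch; bfloat16 shares float32's exponent range, so it is much less prone to overflow than float16, at the cost of some precision):

from trl import ORPOConfig

config = ORPOConfig(
    output_dir="orpo-out",  # placeholder output path
    bf16=True,              # bfloat16 mixed precision (needs Ampere+ GPUs or TPUs)
    fp16=False,             # make sure float16 is off
)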

xhkxhk avatar Apr 17 '25 10:04 xhkxhk

There is a possibility that torch.exp(policy_chosen_logps) or torch.exp(policy_rejected_logps) will be exactly 1. Then torch.log1p(-1) evaluates to -inf and the loss turns into NaN.

Still facing the issue. Does anyone have a recommendation for how to fix it?

amit-gupta- avatar Sep 09 '25 16:09 amit-gupta-