
[KTO]: Fix nan losses and crashing job

claralp opened this pull request · 23 comments

fixes #1447

  • use nanmean() instead of mean() for losses to avoid NaN losses (see the sketch below)
  • remove the obsolete accelerator.gather for metrics, since the metrics are all collected on the CPU anyway and averaged later when logging them.
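
For illustration, a minimal sketch of the difference (values made up; the NaN stands in for a microbatch that has no samples of one type):

```python
import torch

# Illustrative per-example losses where one entry is NaN, e.g. because the
# microbatch contained no "chosen" (or no "rejected") samples.
losses = torch.tensor([0.7, 0.3, float("nan")])

print(losses.mean())     # tensor(nan)    -- a single NaN poisons the plain mean
print(losses.nanmean())  # tensor(0.5000) -- NaN entries are simply ignored
```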

claralp avatar Mar 22 '24 20:03 claralp

thanks @claralp

kashif avatar Mar 22 '24 21:03 kashif

btw, I also tested it on multiple GPUs now. It is running and both GPUs show usage, but the distribution does not seem to be effective, as there is no speedup compared to 1 GPU.
However, this is the same behavior as before this PR. I might look into that later.

claralp avatar Mar 24 '24 13:03 claralp

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif @lewtun: is there anything more I should validate, or can you approve this for now?
I am looking further into the multi-GPU case at the moment, but I would prefer to keep this PR small and open another one if I find something.

claralp avatar Mar 26 '24 09:03 claralp

@claralp would you be able to test out #1476 ?

kashif avatar Mar 26 '24 09:03 kashif

@kashif for me https://github.com/huggingface/trl/pull/1476 has a few issues (see my comments there).
Apart from that:

  • removing the interleaving of datasets works for me and is a useful change
  • using e.g. metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather(mean_chosen_reward).nanmean().item() instead of metrics[f"{prefix}rewards/chosen"] = chosen_rewards.detach().nanmean().cpu() makes no difference for balanced batches. Both move the mean value to the CPU in the end (one as a tensor and one as a scalar, though). However, self.accelerator.gather(mean_chosen_reward) leads to issues when mean_chosen_reward is NaN, which happens when there are no 'chosen' samples in the batch (see the sketch after this list).
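
A minimal sketch of that failure mode, reproducible without any distributed setup (values hypothetical):

```python
import torch

# A microbatch with no "chosen" samples yields an empty rewards tensor,
# so the per-device mean is already NaN before it is ever gathered.
chosen_rewards = torch.empty(0)
mean_chosen_reward = chosen_rewards.mean()
print(mean_chosen_reward)  # tensor(nan)
```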

claralp avatar Mar 26 '24 14:03 claralp

thanks @claralp. With regard to moving the loss to self.args.device, I thought accelerate handles all of that for us?

kashif avatar Mar 26 '24 14:03 kashif

@kashif for the loss, accelerate handles that for us, yes.
But this is just about the rewards/log-probs that are stored as metrics to be logged. In my understanding, that's why the metrics are collected from all GPUs and sent to the CPU, where the mean value is calculated.
E.g. in dpo_trainer, you use metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()

claralp avatar Mar 26 '24 14:03 claralp

@kashif, I applied the suggested changes from the code style pipeline. Should be green again now

claralp avatar Mar 28 '24 12:03 claralp

we're facing a problem that this PR will possibly solve, thanks for this @claralp! We're going to test it and give feedback.

johncordeiro avatar Mar 28 '24 19:03 johncordeiro

hey! I just tested the code in branch claralp:fix_nans, and it's still hanging during evaluation. I tested with 8 NVIDIA H100s and the following batch-size configuration (see the quick check after the list):

  • per_device_train_batch_size: 4
  • per_device_eval_batch_size: 2
  • gradient_accumulation_steps: 8
  • total_batch_size: 256
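
For reference, a quick check that these numbers are consistent, assuming all 8 GPUs take part in data parallelism:

```python
per_device_train_batch_size = 4
gradient_accumulation_steps = 8
num_gpus = 8

# effective global training batch size per optimizer step
print(per_device_train_batch_size * gradient_accumulation_steps * num_gpus)  # 256
```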

johncordeiro avatar Mar 30 '24 14:03 johncordeiro

@johncordeiro could you try the version here and see if you're still experiencing hanging? https://github.com/kawine/trl If so, more context would be helpful.

@claralp thank you for all the fixes! It seems like this version only uses metrics from the main process' microbatch, by getting rid of the gathers? I have a PR in #1499 that includes most of your changes here but still supports batch-wise gathering, if you want to take a look.

kawine avatar Apr 01 '24 04:04 kawine

@kawine as I understand it and have experienced it, this version does not use only the main process' microbatch. It sends the microbatch metrics of every GPU directly to the CPU and calculates the statistics there.
Using accelerator.gather would only gather them all in the main process first before sending them to the CPU. This works, but is not necessary, as the metrics are collected on the CPU later anyway.

For DPO it is implemented the same without gathering: https://github.com/huggingface/trl/blob/0ee349dcd43b0f4b3169449f16751c38ac4a609f/trl/trainer/dpo_trainer.py#L1055

claralp avatar Apr 02 '24 08:04 claralp

It sends the micro batch of every GPU directly to the CPU to calculate metrics there.

thanks @claralp! Maybe it's just me, but I'm not seeing the CPU collection happen? If there is no explicit gather, then the logging only reflects the statistics from the microbatch on the main process.

The HF documentation for Trainer, which is subclassed here, also says for log_metrics that "Under distributed environment this is done only for a process with rank 0."

kawine avatar Apr 02 '24 17:04 kawine

@johncordeiro do you have prediction_loss_only enabled in your evaluation step? The logits were not propagated, so that part was not working in the previous version.
Another reason could be an issue with the autocast context manager, which was previously only used in the evaluation step (but not the training step). Can you check whether the problem now also occurs in the training step?
If neither is the reason behind the hanging during evaluation, can you point to the line where it hangs?

claralp avatar Apr 02 '24 20:04 claralp

@kawine it uses .tolist() for all metrics, which moves them to the CPU. The log_metrics of the HF Trainer is not used by the KTO trainer; it uses its own subclassed log function.
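
For illustration, a minimal sketch of that store-then-average pattern; the class and method names here are just stand-ins, not the actual KTOTrainer code:

```python
from collections import defaultdict

import torch


class MetricStore:
    """Sketch of per-process metric collection: values are moved to the CPU
    via .tolist() when stored and only averaged at logging time."""

    def __init__(self):
        self._stored_metrics = defaultdict(list)

    def store_metrics(self, metrics: dict) -> None:
        for key, value in metrics.items():
            if torch.is_tensor(value):
                # .tolist() yields plain Python numbers, i.e. CPU values,
                # regardless of which device the tensor lived on
                value = value.reshape(-1).tolist()
            else:
                value = [value]
            self._stored_metrics[key].extend(value)

    def log(self, logs: dict) -> None:
        # average everything stored since the last log call, ignoring NaNs
        for key, values in self._stored_metrics.items():
            logs[key] = torch.tensor(values).nanmean().item()
        self._stored_metrics.clear()
```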

claralp avatar Apr 02 '24 20:04 claralp

@claralp if I add the line metrics[f"device"] = torch.Tensor([float(str(self.args.device)[-1])]).cpu() to get_batch_loss_metrics, I can see in wandb that the value is always 0 (i.e., the main process), suggesting that only metrics from one device are tracked without the explicit gather.

kawine avatar Apr 03 '24 09:04 kawine

@kawine you need to read self.accelerator.process_index to get the current process index when using Accelerator.
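For example, something along these lines (a hypothetical standalone snippet, not code from the PR):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # stands in for self.accelerator inside the trainer

# hypothetical debug metric: record which process produced this microbatch
metrics = {}
metrics["device"] = torch.tensor(float(accelerator.process_index)).cpu()
print(metrics["device"])  # tensor(0.) on the main process, 1., 2., ... on the others
```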
And make sure you use the latest version of this branch; in a previous version, store_metrics was really only called from the main process, which was wrong.

claralp avatar Apr 03 '24 09:04 claralp

still getting the same thing with metrics[f"device"] = torch.Tensor([float(str(self.accelerator.process_index)[-1])]).cpu() and the latest version of accelerate (0.28.0)

kawine avatar Apr 03 '24 09:04 kawine

@kawine, I just tested this with per_device_batch_size=2 on 2 GPUs. When printing the metrics in store_metrics, I get:

{'rewards/chosen': [-0.23836135864257812, -0.9353666305541992], 'rewards/rejected': [], 'kl': 1.9421863555908203, 'logps/chosen': [-33.11909484863281, -32.70828628540039], 'logps/rejected': [], 'device': 1}
{'rewards/chosen': [-0.018026351928710938], 'rewards/rejected': [-0.438812255859375], 'kl': 1.9421863555908203, 'logps/chosen': [-24.179973602294922], 'logps/rejected': [-138.9338836669922], 'device': 0}

I print them in store_metrics because the logging to e.g. wandb shows an average of 0.5 across the 2 devices instead of the per-device metrics.
Which accelerate config do you use?

claralp avatar Apr 03 '24 13:04 claralp

@kashif or @lewtun can you confirm if the current changes work for you as well?

claralp avatar Apr 03 '24 15:04 claralp

sure @claralp let me see!

kashif avatar Apr 03 '24 15:04 kashif

@claralp I've tried this with DeepSpeed and with regular data parallel (examples/accelerate_configs/multi_gpu.yaml), and it still only reports stats from the main process.

If you print the stats in store_metrics, it will indeed print them from each process, so you will see one line with 'device': 1 and another with 'device': 0. But if you look at the chart for train/device in wandb, it only shows 0, because the two processes' stats are never gathered (at least for me).
[screenshot: wandb chart of train/device, constant at 0 - Screen Shot 2024-04-03 at 1:40:13 PM]

If you look at PPOTrainer, the metrics are explicitly gathered across all machines at each step with an all_reduce: https://github.com/huggingface/trl/blob/ab0d11d81550618546cd1d8807627a594a58f029/trl/trainer/ppo_trainer.py#L897
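
For comparison, a minimal sketch of what an explicit cross-process aggregation with accelerate could look like (illustrative only, not the PPOTrainer or KTOTrainer code):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# per-device scalar, e.g. the mean chosen reward of this rank's microbatch
mean_chosen_reward = torch.tensor(0.1, device=accelerator.device)

# collect the per-rank values on every process, then reduce while ignoring NaNs
gathered = accelerator.gather(mean_chosen_reward)
print(gathered.nanmean().item())
```

Launched on a single process this simply returns the local value; under accelerate launch the gathered tensor has one entry per rank.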

kawine avatar Apr 03 '24 20:04 kawine

@kawine reading your plot: is it possible that you are training on multiple different machines with one GPU each? The plot reads "copper-paper" and "noble-pyramid"; I think those names come from k8s?

PhilipMay avatar Apr 08 '24 08:04 PhilipMay

closing this due to #1514

kashif avatar Apr 08 '24 15:04 kashif