KTOTrainer hangs even before the first step in multi-GPU training:
Hello! Is #1342 fixed? I'm experiencing a similar issue, except it hangs even before the first step completes (I'm using accelerate + deepspeed).
It hangs here, in get_batch_loss_metrics() of KTOTrainer:
print("this runs!")
print(f"CHOSEN LOGPS SHAPE: {policy_chosen_logps.shape}\n")
mean_chosen_reward = self.accelerator.gather(chosen_rewards.detach()).nanmean().nan_to_num(0)
mean_rejected_reward = self.accelerator.gather(rejected_rewards.detach()).nanmean().nan_to_num(0)
mean_margin = mean_chosen_reward - mean_rejected_reward
mean_logps_chosen = self.accelerator.gather(policy_chosen_logps.detach()).nanmean().nan_to_num(0)
mean_logps_rejected = self.accelerator.gather(policy_rejected_logps.detach()).nanmean().nan_to_num(0)
print("this doesn't")
I'm running on 2x A100 (also tried with 4 and 8 GPUs) with a batch size of 8. Interestingly, I get:
CHOSEN LOGPS SHAPE: torch.Size([1])
CHOSEN LOGPS SHAPE: torch.Size([5])
Could this be the reason it hangs, and how can I fix it? Thanks in advance!
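For reference, here is a minimal standalone sketch of what I suspect is happening (a hypothetical repro, not the TRL code; run with accelerate launch). gather() is a collective, so if the per-rank tensors differ in shape, or one of them is empty, the ranks can deadlock; padding them to a common length first would sidestep that.
# Hypothetical repro of mismatched per-rank shapes, not the actual KTOTrainer code.
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Pretend each rank ends up with a different number of "chosen" examples,
# like the [1] vs [5] shapes printed above.
n_local = accelerator.process_index + 1
chosen_logps = torch.randn(n_local, device=accelerator.device)

# gather() wraps an all_gather collective, which expects matching shapes on every
# rank; padding to a common length with NaN avoids the mismatch, and nanmean()
# ignores the padding afterwards.
padded = accelerator.pad_across_processes(chosen_logps, dim=0, pad_index=float("nan"))
mean_chosen = accelerator.gather(padded).nanmean().nan_to_num(0)
if accelerator.is_main_process:
    print(mean_chosen)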
Some suggestions:
- Are you using the latest code? For example, do you have the following above your print statement? If not, I would update it (a mirrored guard for the rejected side is sketched right after these suggestions):
# lists can't be empty -- if they are, then accelerate.gather will hang
if policy_chosen_logps.shape[0] == 0:
policy_chosen_logps = torch.Tensor([torch.nan]).to(self.accelerator.device)
- I'm using accelerate 0.27.2 and deepspeed 0.13.3. Which versions are you using?
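For completeness, the same guard would presumably need to be mirrored on the rejected side as well; a sketch (variable names follow the snippet above and may not match the exact TRL source):
# mirror the guard so no rank hands an empty tensor to accelerator.gather()
if policy_rejected_logps.shape[0] == 0:
    policy_rejected_logps = torch.Tensor([torch.nan]).to(self.accelerator.device)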
Thanks for the reply! Yes, I have that code above the print statement. I was using accelerate 0.27.2 and deepspeed 0.13.4, so I'll downgrade to 0.13.3 and try again, and also follow some of the suggestions in #1342.
Thanks everyone! @hbin0701, can you try TRL main? We recently added many fixes for KTO.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread.
Yup, it works nicely :) One thing that concerns me is that, for some reason, this takes considerably more time than DPOTrainer on the same data. Any idea why?
@hbin0701 KTO is twice as data-efficient as DPO: n pairwise preferences for DPO become 2n examples for KTO. So on the same dataset, KTO should take roughly twice as long as DPO to train, because you're making twice as many updates.
This also means that if the batch size is the same, you're using less memory with KTO and can thus align larger models and/or longer sequences.
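As a rough back-of-the-envelope check (hypothetical numbers, not taken from this thread):
# effective batch = per-device batch size * number of GPUs * gradient accumulation steps
n_pairs = 10_000
effective_batch = 8 * 2 * 1
dpo_steps = n_pairs // effective_batch        # each pair is one DPO example    -> 625 steps
kto_steps = (2 * n_pairs) // effective_batch  # each pair becomes 2 KTO examples -> 1250 steps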
Thanks for the comment! Closing this now :)