HALOs
Comments in KTO Trainer `forward()`
Hi there,
I'm reading through the `forward()` function in the KTO Trainer, and the comments describing its arguments state that, if the data was read in correctly, the chosen and rejected logps should each have size batch_size/2. However, this doesn't make sense to me: it sounds like a constraint of paired preference training rather than of KTO's unpaired training method.
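To make the point concrete: in an unpaired batch, each example is labeled desirable or undesirable on its own, so the chosen and rejected groups only have to add up to the batch size, not split it evenly. The sketch below is a minimal illustration under that assumption; `split_by_label` is a hypothetical helper written for this example, not part of the HALOs code.

```python
from typing import List, Tuple


def split_by_label(logps: List[float], desirable: List[bool]) -> Tuple[List[float], List[float]]:
    """Split per-example log-probs into chosen (desirable) and rejected (undesirable) groups.

    Hypothetical helper for illustration only; not the HALOs API.
    """
    if len(logps) != len(desirable):
        raise ValueError("logps and desirable must have the same length")
    chosen = [lp for lp, d in zip(logps, desirable) if d]
    rejected = [lp for lp, d in zip(logps, desirable) if not d]
    return chosen, rejected


# An unpaired batch of 6 examples: 4 desirable, 2 undesirable (made-up values).
batch_logps = [-1.2, -0.8, -2.5, -0.9, -3.1, -1.7]
batch_labels = [True, True, False, True, False, True]

chosen_logps, rejected_logps = split_by_label(batch_logps, batch_labels)
print(len(chosen_logps), len(rejected_logps))  # 4 2
```

The two sizes sum to the batch size (6), but neither is batch_size/2, which is a perfectly valid batch for unpaired training.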
Here's the comment from lines 875-877 of trainers.py:
chosen_logps: log probabilities of chosen examples (should be batch size / 2 if data was read in correctly)
rejected_logps: log probabilities of rejected examples (should be batch size / 2 if data was read in correctly)
KL_logps: log probabilities of the unmatched y'|x (used to estimate the KL divergence between policy and reference; should be batch size)
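The `KL_logps` line is consistent with this: every prompt x in the batch gets one unmatched completion y', so that tensor has one entry per example. A minimal sketch of one way to build such mismatched pairs, assuming completions are simply rotated by one position within the batch; `make_mismatched_pairs` is hypothetical, and HALOs may construct its KL examples differently.

```python
from typing import List, Tuple


def make_mismatched_pairs(prompts: List[str], completions: List[str]) -> List[Tuple[str, str]]:
    """Pair each prompt x with a completion y' taken from a different example.

    Hypothetical helper: rotating completions by one position means that, for
    any batch of at least two examples, y' always comes from another index.
    """
    n = len(prompts)
    if n != len(completions):
        raise ValueError("prompts and completions must have the same length")
    if n < 2:
        raise ValueError("need at least two examples to build mismatched pairs")
    return [(prompts[i], completions[(i + 1) % n]) for i in range(n)]


pairs = make_mismatched_pairs(["x0", "x1", "x2"], ["y0", "y1", "y2"])
print(pairs)  # [('x0', 'y1'), ('x1', 'y2'), ('x2', 'y0')]
```

One unmatched pair per example yields batch-size KL log-probs regardless of how many examples are desirable or undesirable.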
Please let me know if this makes sense; I'm happy to open a PR.
You're correct! The comment is from when I was debugging the code during development, and it's outdated. Feel free to open a PR and I'll merge it in. Thanks!