torchtune
[FR] Sample Packing with correct attention mask
Sample packing with correct attention mask (where the model can't attend to other examples in the batch) and ideally correct RoPE offset would be extremely beneficial. In SFT, examples tend to be highly correlated, so there's an opportunity for the model to cheat when training.
It significantly improves training speed when training on examples with diverse lengths and a large maximum sequence length.
Known existing implementations:
- https://github.com/MeetKai/functionary/tree/main/functionary/train/packing
- https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/multipack.py
This kind of mask is supported by FA:
- https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1698610752
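For concreteness, here is a rough sketch (not torchtune code) of the varlen interface from the flash-attn package that the linked issue refers to: `flash_attn_varlen_func` takes cumulative sequence lengths (`cu_seqlens`) marking example boundaries, so attention never crosses packed examples. The lengths, head counts, and dtypes below are made up for illustration, and a CUDA device is assumed.

```python
# Sketch: per-example causal attention over a packed sequence via flash-attn's
# varlen interface. Assumes the flash-attn package and a CUDA device.
import torch
from flash_attn import flash_attn_varlen_func

# Three examples of lengths 5, 3, 8 packed into one sequence of 16 tokens.
seq_lens = torch.tensor([5, 3, 8], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(
    seq_lens.cumsum(0, dtype=torch.int32), (1, 0)
)  # -> [0, 5, 8, 16]
total_tokens, n_heads, head_dim = int(seq_lens.sum()), 8, 64

q = torch.randn(total_tokens, n_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# causal=True applies causal masking *within* each example; cu_seqlens marks the
# boundaries, so tokens of one example cannot attend to tokens of another.
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seq_lens.max()), max_seqlen_k=int(seq_lens.max()),
    causal=True,
)
```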
One thing to consider is how the examples should be packed -- e.g. naive greedy packing vs. a more elaborate bin packing algorithm. I think even a naive greedy approach would bring a lot of benefit.
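For illustration, a naive greedy packer could look something like the sketch below (assuming examples are already tokenized; the function name is made up):

```python
# Illustrative sketch of naive greedy packing (not the torchtune implementation):
# walk the tokenized examples once and start a new pack whenever the next example
# would overflow max_seq_len. Overlong examples are simply truncated here.
from typing import List

def greedy_pack(examples: List[List[int]], max_seq_len: int) -> List[List[List[int]]]:
    packs, current, current_len = [], [], 0
    for tokens in examples:
        tokens = tokens[:max_seq_len]  # simplistic handling of overlong examples
        if current and current_len + len(tokens) > max_seq_len:
            packs.append(current)
            current, current_len = [], 0
        current.append(tokens)
        current_len += len(tokens)
    if current:
        packs.append(current)
    return packs  # each pack is a list of examples to concatenate into one sequence

# e.g. greedy_pack([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_seq_len=6)
#   -> [[[1, 2, 3], [4, 5]], [[6, 7, 8, 9]]]
```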
Thanks for opening this feature request. Indeed, this very thing is being worked on in #875. I am currently investigating how to make the sample masking work with flash attention (we currently use SDPA, which does not support arbitrary masks, so we may have to use Tri Dao's implementation as you pointed out). If you have thoughts on this, we would love your feedback on the PR.
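To make the discussion concrete, the kind of per-sample mask in question can be written down in plain PyTorch as a block-diagonal causal mask (a sketch for illustration only, not what the PR does):

```python
# Sketch of a block-diagonal causal mask for a packed sequence. `seq_lens` are the
# lengths of the examples packed into one sequence; mask[i, j] is True iff token i
# may attend to token j (same example, and j is not in the future).
from typing import List

import torch

def packed_causal_mask(seq_lens: List[int]) -> torch.Tensor:
    total = sum(seq_lens)
    # document_ids[t] = index of the packed example that token t belongs to
    document_ids = torch.repeat_interleave(
        torch.arange(len(seq_lens)), torch.tensor(seq_lens)
    )
    same_example = document_ids.unsqueeze(0) == document_ids.unsqueeze(1)
    causal = torch.tril(torch.ones(total, total, dtype=torch.bool))
    return same_example & causal

# For seq_lens=[2, 3] the mask is block-diagonal: tokens 0-1 only see tokens 0-1,
# and tokens 2-4 only see earlier tokens within 2-4.
print(packed_causal_mask([2, 3]).int())
```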
> ideally correct RoPE offset would be extremely beneficial
Do you mind elaborating on this?
What I meant by the RoPE comment -- and maybe this is already handled automatically -- is that if we just concatenate examples, as with naive packing e.g. in HF transformers, each token's positional embedding will not represent its actual position within the example, but rather its position in the concatenated sequence.
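For illustration, one way to get within-example positions is to reset the position ids at each packed example boundary (a sketch; not necessarily how #875 handles it):

```python
# Sketch: position ids that restart at 0 for each packed example, so RoPE (or any
# positional encoding that accepts explicit positions) sees within-example positions
# rather than positions in the concatenated sequence.
from typing import List

import torch

def packed_position_ids(seq_lens: List[int]) -> torch.Tensor:
    return torch.cat([torch.arange(n) for n in seq_lens])

# Naive concatenation of examples of lengths [4, 3] would use positions 0..6;
# with the reset, positions are [0, 1, 2, 3, 0, 1, 2].
print(packed_position_ids([4, 3]))
```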
Completed by #875