torchtitan

[Feature] Add Multi-Token Prediction module

Open lessw2020 opened this issue 8 months ago • 3 comments

We need to create a multi-token prediction (MTP) module, following the design DeepSeek-V3 used, that can plug into architectures such as DeepSeek-V3 and enable multi-token prediction training with an arbitrary number of MTP blocks (default = 4). The architecture should be modular: it clones the existing block, appends it to the neck of the model, and integrates automatically into training.
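One way the "integrates automatically into training" part could look is sketched below, assuming the main head and each of the D MTP depths produce logits over the same vocabulary, and that depth k (1-indexed) predicts the token k + 1 positions ahead, so its logits are one position shorter than the previous depth's. The DeepSeek-V3 technical report averages the per-depth MTP losses over the D depths and scales the result by a weight λ before adding it to the main next-token loss; `mtp_training_loss`, `shifted_ce`, and the `mtp_weight` default are hypothetical placeholders, not torchtitan API.

```python
import torch
import torch.nn.functional as F


def shifted_ce(logits: torch.Tensor, tokens: torch.Tensor, offset: int) -> torch.Tensor:
    """Cross-entropy where position i of `logits` predicts tokens[:, i + offset]."""
    n = tokens.size(1) - offset  # number of positions that still have a target
    pred = logits[:, :n]
    target = tokens[:, offset:]
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))


def mtp_training_loss(main_logits, mtp_logits, tokens, mtp_weight: float = 0.3):
    """Main next-token loss plus the weighted mean of the D MTP-depth losses.

    main_logits:      (B, T, V);     position i predicts tokens[:, i + 1]
    mtp_logits[k-1]:  (B, T - k, V); position i predicts tokens[:, i + k + 1]
    """
    main_loss = shifted_ce(main_logits, tokens, offset=1)
    if not mtp_logits:
        return main_loss  # no MTP depths configured: plain next-token training
    depth_losses = [
        shifted_ce(lg, tokens, offset=k + 1) for k, lg in enumerate(mtp_logits, start=1)
    ]
    return main_loss + mtp_weight * torch.stack(depth_losses).mean()
```

With the default of 4 depths proposed here, `mtp_logits` would hold four tensors; at inference the MTP depths can simply be skipped and only the main logits used.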

lessw2020 · Mar 05 '25 19:03

Hey! I've worked on this in https://github.com/janEbert/torchtitan/commit/b308039ddb4fe2f00dee06277a5d146337afa90c. It does not yet correctly support pipelining: the gradients of the shared embedding and unembedding layers need to be all-reduced manually because, to the best of my knowledge, the PyTorch pipelining APIs have no explicit support for shared parameters. I'm also worried that FSDP2 may not handle the shared parameters correctly; this needs testing as well.
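A minimal sketch of that manual all-reduce, assuming plain (non-DTensor) parameters, that every pipeline stage holding a copy of a shared weight has already run its backward pass, and that a process group containing exactly those stage ranks exists (the `group` argument, e.g. created with `torch.distributed.new_group`). Summing rather than averaging matches what autograd computes for a single tied tensor, because the gradient of a shared parameter is the sum of the gradients from each of its uses. The helper name is hypothetical, and interaction with FSDP2 sharding (the open question above) is not handled.

```python
from typing import Iterable, Optional

import torch
import torch.distributed as dist
from torch import nn


def allreduce_shared_grads(
    params: Iterable[nn.Parameter], group: Optional[dist.ProcessGroup] = None
) -> None:
    """Sum gradients of parameters replicated across pipeline stages.

    Every rank in `group` must pass its copies of the same shared tensors
    (e.g. the tied embedding/unembedding weight) in the same order.
    """
    for p in params:
        if p.grad is None:
            # A stage whose copy received no gradient contributes zeros,
            # so all ranks still participate in the collective.
            p.grad = torch.zeros_like(p.detach())
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=group)
```

In a training loop this would run after the pipeline schedule's backward step and before `optimizer.step()`, on every rank that owns a copy of the shared weight.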

I went with the limitation that only full MTP modules can be pipelined, i.e., each pipelined MTP module always includes an embedding, a 2d×d linear projection, a Transformer block, and an unembedding.
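For reference, one MTP module with the components listed above might look like the following sketch, which follows the DeepSeek-V3 description: the previous depth's hidden state and the embedding of the token k positions ahead are each RMS-normalized, concatenated to width 2d, projected back to d, and passed through one Transformer block before the shared output head. `nn.TransformerEncoderLayer` with a causal mask stands in for the model's own block (which the issue proposes cloning); `MTPModule` and its arguments are hypothetical names.

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    # Hand-rolled so the sketch does not depend on nn.RMSNorm (PyTorch >= 2.4)
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class MTPModule(nn.Module):
    """One MTP depth; `embed` and `head` are shared with the main model, not copied."""

    def __init__(self, embed: nn.Embedding, head: nn.Linear, dim: int, n_heads: int):
        super().__init__()
        self.embed, self.head = embed, head  # same objects -> tied parameters
        self.norm_h = RMSNorm(dim)
        self.norm_e = RMSNorm(dim)
        self.proj = nn.Linear(2 * dim, dim, bias=False)  # the 2d x d projection
        self.block = nn.TransformerEncoderLayer(
            dim, n_heads, dim_feedforward=4 * dim, dropout=0.0, batch_first=True
        )

    def forward(self, h_prev: torch.Tensor, tokens: torch.Tensor, k: int):
        """h_prev: (B, T-k+1, d) from depth k-1 (the main trunk is depth 0).

        Position i combines h_prev[:, i] with the embedding of tokens[:, i + k];
        returns hidden states (B, T-k, d) and logits (B, T-k, V).
        """
        emb = self.embed(tokens[:, k:])        # (B, T-k, d)
        h_prev = h_prev[:, : emb.size(1)]      # drop the position with no future token
        x = self.proj(torch.cat([self.norm_h(h_prev), self.norm_e(emb)], dim=-1))
        mask = nn.Transformer.generate_square_subsequent_mask(x.size(1))
        h = self.block(x, src_mask=mask)
        return h, self.head(h)
```

Chaining depths passes each module's hidden states to the next, so the sequence shrinks by one position per depth; because `embed` and `head` are the same objects in every depth, their gradients accumulate automatically on a single device, which is exactly what breaks once those depths land on different pipeline stages.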

If you are fine with an external contribution for this, I'd be happy to get comments on the implementation and, of course, to make it work with pipelining as well.

janEbert · Apr 02 '25 12:04

Hi @janEbert, nice, thanks for the update! Yes, I would definitely be happy to get an external PR from you on this. I have been busy with groupGEMM and haven't had time to start on it. We can look at the pipelining aspect, as we would like MTP to be enabled with it. Looking forward to working with you on your PR!

lessw2020 · Apr 02 '25 20:04

Awesome, thanks! Will submit a PR in due time!

janEbert · Apr 03 '25 07:04