fms-fsdp
[speculator training] Speculator training
Add support for speculator training, piggybacking off the existing training utilities. The training script and speculator-specific utilities live in the new `speculator` subfolder.

Uses the distributed setup, checkpointing, and dataloaders from this repo. Adds speculator-specific fields to the training config file (ignored during non-speculator training). It might make more sense to pull these new fields out into a separate config subclass under the speculator utilities (a rough sketch of that idea follows below) - open to suggestions.
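To make the suggestion concrete, here is a minimal sketch of such a config subclass. The field names and defaults below are hypothetical placeholders, not the actual fields added in this PR:

```python
# Hypothetical sketch of the suggested config split. A shared train_config
# dataclass is assumed; all speculator field names below are placeholders.
from dataclasses import dataclass

@dataclass
class train_config:
    # shared training fields (abbreviated)
    model_variant: str = "llama2-7b"
    ckpt_save_path: str = "/path/to/checkpoints"
    learning_rate: float = 3e-4

@dataclass
class speculator_train_config(train_config):
    # speculator-specific fields, ignored by non-speculator training
    n_speculator_heads: int = 3    # tokens predicted ahead per step
    speculator_width: int = 4096   # hidden size of the speculator MLPs
```

This would keep the shared trainer oblivious to speculator settings while letting the speculator script consume a single config object.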
Uses the speculator architecture from fms-extras.

Uses an altered Llama-7b and `generate()` function from base fms, allowing the speculator to access embedding vectors, not just logits/token predictions. ~~Do not merge this until that issue can be resolved.~~ The plan is to move the `include_embeds=True` versions of Llama/GPTBigCode/`generate()` into fms-extras. Once that is done I'll update the relevant imports here and then we can push this in.
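For context, the `include_embeds=True` change amounts to `generate()` optionally returning the model's last hidden states alongside the sampled tokens, so the speculator can be trained on them. A rough sketch of that interface (greedy decoding only; the model output contract here is an assumption, not the actual fms signature):

```python
import torch

def generate(model, input_ids: torch.Tensor, max_new_tokens: int = 20,
             include_embeds: bool = False):
    """Greedy decoding sketch. When include_embeds is True, also collect
    the last-layer embedding vectors for each generated position."""
    embeds = []
    for _ in range(max_new_tokens):
        # assumed contract: model returns (logits, hidden_states)
        logits, hidden = model(input_ids)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if include_embeds:
            embeds.append(hidden[:, -1:, :])
    if include_embeds:
        return input_ids, torch.cat(embeds, dim=1)
    return input_ids
```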
I've pulled all the `include_embeds` stuff out of fms and into here. We now have `EmbedLLaMA` and `EmbedGPTBigCode` subclasses that override the corresponding forward function, and an altered version of `generate` for use only with this script. We register the subclassed models for use with `get_model` in the training script.
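In outline, that subclass-and-register pattern looks roughly like the sketch below; the parent's internal attribute names and the exact `register_model` factory signature are assumptions based on fms conventions, not the code in this PR:

```python
from fms.models import register_model
from fms.models.llama import LLaMA

class EmbedLLaMA(LLaMA):
    def forward(self, x, include_embeds: bool = False, **kwargs):
        # Hypothetical internals: assume the parent exposes a decoder stack
        # ("base_model") and an output projection ("head"); the real fms
        # attribute names may differ.
        embeds = self.base_model(x, **kwargs)
        logits = self.head(embeds)
        if include_embeds:
            return logits, embeds
        return logits

# Make the subclass reachable via get_model("embedllama", "7b", ...).
# register_model is assumed to accept any factory callable, here the class.
register_model("embedllama", "7b", EmbedLLaMA)
```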
Code is ready for review - mypy errors are import errors: ~~it doesn't have fms-extras, and~~ it doesn't like the local import of `train_speculator_utils`. Should I move the `speculator` subfolder under `fms_fsdp` so that I can use an absolute path for that?
Hi! What's the status on this PR? I'd like to train a few speculator models, but I'm not sure how to get started, due to a lack of documentation...
Hi @AlpinDale Working on getting the documentation and code ready for this. Planning to have it ready sometime in the next 3 weeks. Will keep you updated if we get this done sooner.
Thanks for the reply, @JRosenkranz. I'd love to wait, but I have access to a large cluster of H100s for a limited time, so I want to make the most of it by training as many MLPSpeculator models as possible on various popular models. If it's doable, I'd love some basic instructions on how to get this PR running and start training runs; I can figure out the rest. Different story if the PR itself isn't ready, however 😅
Hi, adding a +1 to @AlpinDale's request. We are interested in experimenting with the MLP speculator, specifically on the latest Llama3.1 models.
Excellent work overall @JRosenkranz !
Hi @AlpinDale @vdabravolski,
PR35 is outdated. We expect to release a stable code version in about 3 weeks.
We understand @AlpinDale's urgency and are trying to get this PR into shape so that you can use it in the interim. We hit issues running it against the main branches of foundation-model-stack and fms-extras, and are working on resolving them. If that doesn't work out, we can point you to the specific branches of the foundation-model-stack, fms-extras, and fms-fsdp repos in the meantime, so that you can train custom speculators while we polish the changes and merge them into their respective mains.
There are already a bunch of speculators available here and here, in case there is any overlap with your requirements. For example, the llama3-70b speculator works for llama3.1-70b as well, as mentioned here (and so llama3-8b might also work for llama3.1-8b).
@AlpinDale @vdabravolski PR35 has been updated; it should now work with the foundation-model-stack and fms-extras main branches. Added a sample training script containing example arguments to pass to the speculator training routine; most arg names should be straightforward. For more details, please refer to https://arxiv.org/pdf/2404.19124, https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/, and https://github.com/foundation-model-stack/fms-fsdp/blob/main/docs/configurations.md
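For anyone skimming, a hedged illustration of what launching the sample script might look like; every path and flag name below is a placeholder and should be checked against the actual sample script and configurations.md:

```python
# Hypothetical launch wrapper; flag names are placeholders, not verified
# against the sample script shipped with this PR.
import subprocess

subprocess.run(
    [
        "torchrun", "--nproc_per_node=8",
        "speculator/train_speculator.py",
        "--model_variant=llama2-7b",          # frozen base model the speculator attaches to
        "--model_path=/path/to/base/model",
        "--data_path=/path/to/dataset",
        "--ckpt_save_path=/path/to/checkpoints",
        "--n_speculator_heads=3",             # tokens predicted ahead per step
        "--speculator_width=4096",            # hidden size of the speculator MLPs
    ],
    check=True,
)
```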
Is this expected to be merged soon?
@philschmid We are expecting to have speculator training merged sometime in the next 2 weeks.
@philschmid This has been finished and merged in #114. The speculator training implementation is now available in main; please let us know if you have any feedback or questions.
CC: @AlpinDale @vdabravolski
Closing in favor of #114