
[speculator training] Speculator training

daviswer opened this issue 11 months ago • 11 comments

Add support for speculator training, piggybacking off the existing training utilities.

Training script and speculator-specific utilities are inside the new speculator subfolder.

Uses the distributed setup, checkpointing, and dataloaders from this repo. Adds speculator-specific fields to the training config file (these are ignored during non-speculator training). It might make more sense to pull these new fields out into a separate config subclass under the speculator utilities; open to suggestions.
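For illustration, a minimal sketch of what such a config subclass could look like; the field names and the `train_config` import below are assumptions for illustration, not the PR's actual code:

```python
# Hypothetical sketch of a speculator-specific config subclass.
# Field names are illustrative; `train_config` is assumed to be the
# existing fms-fsdp training config dataclass.
from dataclasses import dataclass

from fms_fsdp.config import train_config


@dataclass
class speculator_train_config(train_config):
    # Speculator-only fields, ignored (or absent) during regular training.
    n_speculator_heads: int = 3           # hypothetical: tokens predicted ahead
    speculator_width: int = 4096          # hypothetical: speculator hidden size
    speculator_tie_weights: bool = False  # hypothetical: share head weights
```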

Uses speculator architecture from fms-extras.

Uses an altered Llama-7b and generate() function from base fms, allowing the speculator to access embedding vectors rather than just logits/token predictions. ~~Do not merge this until that issue can be resolved.~~

daviswer avatar Mar 01 '24 20:03 daviswer

The plan is to move the include_embeds=True versions of Llama/GPTBigCode/generate() into fms-extras. Once that is done I'll update the relevant imports here, and then we can push this in.

daviswer avatar Mar 20 '24 17:03 daviswer

I've pulled all the include_embeds stuff out of fms and into this repo. We now have EmbedLLaMA and EmbedGPTBigCode subclasses that override the corresponding forward functions, plus an altered version of generate for use only with this script. We register the subclassed models for use with get_model in the training script.
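A rough sketch of that subclass-and-register pattern, for readers following along; the trunk/head attribute names and the register_model call are assumptions based on the description above, not verbatim PR code:

```python
# Sketch only: attribute names on the LLaMA internals and the
# register_model signature are assumptions; see train_speculator_utils
# in the PR for the real implementation.
from fms.models import register_model
from fms.models.llama import LLaMA


class EmbedLLaMA(LLaMA):
    """LLaMA whose forward can return embedding vectors alongside logits."""

    def forward(self, x, include_embeds=False, **kwargs):
        # Assumption: the model exposes its transformer trunk and LM head
        # separately, so the pre-head hidden states can be captured.
        embeds = self.base_model(x, **kwargs)  # hidden states, pre-head
        logits = self.head(embeds)             # token predictions
        if include_embeds:
            return logits, embeds  # the speculator trains against both
        return logits


# Register the variant so the training script can build it via get_model.
register_model("embedllama", "7b", EmbedLLaMA)
```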

daviswer avatar Mar 29 '24 18:03 daviswer

Code is ready for review. The mypy errors are import errors: ~~it doesn't have fms-extras and~~ it doesn't like the local import of train_speculator_utils. Should I move the speculator subfolder under fms_fsdp so that I can use an absolute path for that?

daviswer avatar Mar 29 '24 18:03 daviswer

Hi! What's the status on this PR? I'd like to train a few speculator models, but I'm not sure how to get started, due to a lack of documentation...

AlpinDale avatar Aug 08 '24 22:08 AlpinDale

> Hi! What's the status on this PR? I'd like to train a few speculator models, but I'm not sure how to get started, due to a lack of documentation...

Hi @AlpinDale! Working on getting the documentation and code ready for this. Planning to have it ready sometime in the next 3 weeks. Will keep you updated if we get it done sooner.

JRosenkranz avatar Aug 09 '24 15:08 JRosenkranz

Thanks for the reply, @JRosenkranz

I'd love to wait, but I have access to a large cluster of H100s for a limited time, so I want to make the most of it by training as many MLPSpeculator models as possible on various popular models. If it's doable, I'd love some basic instructions on how to get this PR running and start training runs; I can figure out the rest. Different story if the PR itself isn't ready, however 😅

AlpinDale avatar Aug 09 '24 17:08 AlpinDale

Hi, adding a +1 to @AlpinDale. We are interested in experimenting with the MLP speculator, specifically on the latest Llama 3.1 models.

Excellent work overall, @JRosenkranz!

vdabravolski avatar Aug 12 '24 16:08 vdabravolski

Hi @AlpinDale @vdabravolski,

This PR (#35) is outdated. We expect to release a stable code version in about 3 weeks.

We understand @AlpinDale's urgency and are trying to put this PR in shape so that you can use it in the interim. We hit issues running it against the main branches of foundation-model-stack and fms-extras, and are working on resolving them. If that doesn't work out, we can point you to specific branches of the foundation-model-stack, fms-extras, and fms-fsdp repos in the meantime, so that you can train custom speculators while we polish the changes and merge them into their respective mains.

There are already a bunch of speculators available here and here, in case there is any overlap with your requirements. For example, the llama3-70b speculator works for llama3.1-70b as well, as mentioned here (so the llama3-8b one might also work for llama3.1-8b).

sahilsuneja1 avatar Aug 13 '24 13:08 sahilsuneja1

@AlpinDale @vdabravolski PR #35 has been updated: it should now work with the foundation-model-stack and fms-extras main branches. I added a sample training script containing example arguments to pass to the speculator training routine. Most arg names should be straightforward. For more details, please refer to https://arxiv.org/pdf/2404.19124, https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/, and https://github.com/foundation-model-stack/fms-fsdp/blob/main/docs/configurations.md
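As a rough illustration of what launching a run might look like (the script path and every flag name below are guesses for illustration; check the sample script in the PR for the real arguments):

```python
# Illustrative launch sketch only: the script path and all flag names are
# assumptions, not the actual interface of the sample training script.
import subprocess

subprocess.run(
    [
        "torchrun", "--nproc_per_node=8",
        "speculator/train_speculator.py",  # hypothetical path
        "--model_variant=llama3-8b",       # hypothetical: base model to distill from
        "--n_speculator_heads=3",          # hypothetical: lookahead tokens per step
        "--speculator_width=4096",         # hypothetical: speculator hidden size
        "--learning_rate=1e-3",            # hypothetical
    ],
    check=True,
)
```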

sahilsuneja1 avatar Aug 14 '24 15:08 sahilsuneja1

Is this expected to be merged soon?

philschmid avatar Aug 20 '24 13:08 philschmid

> Is this expected to be merged soon?

@philschmid We are expecting to have speculator training merged sometime in the next 2 weeks.

JRosenkranz avatar Aug 20 '24 13:08 JRosenkranz

@philschmid This has been finished and merged in #114. The speculator training implementation is now available in main. Please let us know if you have any feedback or questions.

CC: @AlpinDale @vdabravolski

JRosenkranz avatar Sep 10 '24 14:09 JRosenkranz

Closing in favor of #114

JRosenkranz avatar Sep 10 '24 14:09 JRosenkranz