fms-fsdp

Enable HF PretrainedModel loading for speculative model training

Open JRosenkranz opened this issue 1 year ago • 3 comments

This PR enables HF PretrainedModel loading. To use this feature, set the architecture to "hf_pretrained" and the variant to a Hugging Face variant (model_id). This was enabled by removing the need to create special adapters: the model is instead wrapped in a HiddenStatesExtractor, which extracts hidden states from the base model. With this new wrapper, the adapters and overridden model classes that used include_embeds are no longer required, and generate can be used from fms main.
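
For readers unfamiliar with the wrapper idea, here is a minimal sketch in plain Python. It is not the PR's code: `ToyBaseModel`, its `headless`/`head` methods, and `HiddenStatesExtractorSketch` are hypothetical stand-ins, and the only assumption taken from the description is that the wrapper records the base model's hidden states while leaving its normal output untouched.

```python
class ToyBaseModel:
    """Hypothetical stand-in for a base model split into headless and head parts."""

    def headless(self, token_ids):
        # Pretend each token maps to a 2-element hidden vector.
        return [[float(t), float(t) * 2.0] for t in token_ids]

    def head(self, hidden_states):
        # Pretend the "logit" for each position is the sum of its hidden vector.
        return [sum(h) for h in hidden_states]


class HiddenStatesExtractorSketch:
    """Hypothetical wrapper: runs the base model and keeps its hidden states."""

    def __init__(self, base):
        self.base = base
        self.hidden_states_output = None

    def __call__(self, token_ids):
        hidden = self.base.headless(token_ids)
        # Record the hidden states so a speculator could train on them later.
        self.hidden_states_output = hidden
        # Return the ordinary model output, unchanged by the wrapper.
        return self.base.head(hidden)


wrapper = HiddenStatesExtractorSketch(ToyBaseModel())
logits = wrapper([1, 2, 3])
```

Because the wrapper only observes the base model, no per-architecture adapter is needed in this sketch; any model exposing the same two pieces could be wrapped the same way.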

JRosenkranz avatar Oct 18 '24 18:10 JRosenkranz

Nice, this does make more sense once models are partitioned into headless/head components!

daviswer avatar Oct 18 '24 19:10 daviswer

Couldn't follow the reset logic. Rest everything looks good!

sahilsuneja1 avatar Oct 21 '24 15:10 sahilsuneja1

> Couldn't follow the reset logic. Rest everything looks good!

Resetting always occurs on prefill. past_key_value_states is None on every prefill: stage 1 always passes past_key_value_states=None, and stage 2 sets past_key_value_states to None on the first call to the model. This way, we always get the latest hidden_states_output.
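
A rough sketch of that reset rule, in plain Python rather than the PR's actual code. `toy_model` and `ResettingExtractorSketch` are hypothetical, and accumulating hidden states across decode steps is an assumption made for illustration; the only behavior taken from the explanation above is that the stored hidden states are reset whenever past_key_value_states is None.

```python
def toy_model(token_ids, past_key_value_states):
    """Hypothetical model: returns (logits, hidden_states, new_cache)."""
    hidden = [t * 10 for t in token_ids]
    # The toy cache is just the list of tokens seen so far.
    new_cache = (past_key_value_states or []) + list(token_ids)
    return hidden, hidden, new_cache


class ResettingExtractorSketch:
    """Hypothetical wrapper that resets its stored hidden states on prefill."""

    def __init__(self, model_fn):
        self.model_fn = model_fn
        self.hidden_states_output = None

    def __call__(self, token_ids, past_key_value_states=None):
        logits, hidden, cache = self.model_fn(token_ids, past_key_value_states)
        if past_key_value_states is None:
            # Prefill: no cache was passed in, so drop anything stored earlier.
            self.hidden_states_output = list(hidden)
        else:
            # Decode step (assumed behavior): extend what the prefill stored.
            self.hidden_states_output = self.hidden_states_output + list(hidden)
        return logits, cache


extractor = ResettingExtractorSketch(toy_model)
_, cache = extractor([1, 2])        # prefill: state is reset
after_prefill = list(extractor.hidden_states_output)
_, cache = extractor([3], cache)    # decode: state grows
after_decode = list(extractor.hidden_states_output)
_, cache = extractor([7, 8])        # new prefill: state is reset again
after_second_prefill = list(extractor.hidden_states_output)
```

Keying the reset on the missing cache means the caller never has to clear the wrapper explicitly; starting a new sequence is enough.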

JRosenkranz avatar Oct 22 '24 20:10 JRosenkranz