fms-fsdp
Enable HF PretrainedModel loading for speculative model training
This PR enables HF PretrainedModel loading. To use this feature, set the architecture to "hf_pretrained" and the variant to a Hugging Face variant (model_id). This was enabled by removing the need for special adapters: the model is now wrapped in a HiddenStatesExtractor, which extracts hidden states from the base model. With this new wrapper, the adapters and overridden model classes that used include_embeds are no longer required, and generate can be used from fms main.
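For readers unfamiliar with the wrapper idea, here is a minimal sketch of how a hidden-states extractor might sit between a headless base model and its head. It is not the code from this PR: the class name `HiddenStatesExtractorSketch`, the toy `toy_headless` and `toy_head` functions, and the use of plain Python lists instead of tensors are all assumptions made to keep the example self-contained.

```python
from typing import Callable, List, Optional

Vector = List[float]


class HiddenStatesExtractorSketch:
    """Hypothetical stand-in for the wrapper described above.

    It runs the headless model, remembers the hidden states it produced,
    and then applies the head, so callers need no special adapter class.
    """

    def __init__(
        self,
        headless_model: Callable[[List[int]], List[Vector]],
        head: Callable[[Vector], Vector],
    ) -> None:
        self.headless_model = headless_model
        self.head = head
        # Populated on every call; read by the training loop afterwards.
        self.hidden_states_output: Optional[List[Vector]] = None

    def __call__(self, token_ids: List[int]) -> List[Vector]:
        hidden_states = self.headless_model(token_ids)
        self.hidden_states_output = hidden_states
        # The head maps each hidden state to output scores.
        return [self.head(h) for h in hidden_states]


def toy_headless(token_ids: List[int]) -> List[Vector]:
    # Toy "base model": one 2-dimensional hidden state per token.
    return [[float(t), float(t) * 2.0] for t in token_ids]


def toy_head(hidden: Vector) -> Vector:
    # Toy "head": a single score per position.
    return [sum(hidden)]


model = HiddenStatesExtractorSketch(toy_headless, toy_head)
logits = model([1, 2, 3])
hidden = model.hidden_states_output
```

Because the wrapper exposes the same call interface as the underlying model, a generic generate routine can drive it unchanged while the speculator reads `hidden_states_output` on the side.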
Nice, this does make more sense once models are partitioned into headless/head components!
Couldn't follow the reset logic. Everything else looks good!
Resetting always occurs on prefill, because past_key_value_states is None on every prefill: stage 1 always has past_key_value_states=None, and stage 2 sets past_key_value_states to None on the first call to the model. This way, we always get the latest hidden_states_output.
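The reset rule can be sketched roughly as follows. This is an illustration under stated assumptions, not the PR's implementation: `HiddenStateBuffer` and `update` are hypothetical names, lists stand in for tensors, and the choice to append hidden states on decode steps (when a cache is present) is an assumption about what happens between resets.

```python
from typing import List, Optional

Vector = List[float]


class HiddenStateBuffer:
    """Hypothetical buffer that resets whenever a prefill is detected."""

    def __init__(self) -> None:
        self.hidden_states_output: Optional[List[Vector]] = None

    def update(
        self,
        hidden_states: List[Vector],
        past_key_value_states: Optional[object],
    ) -> List[Vector]:
        if past_key_value_states is None or self.hidden_states_output is None:
            # Prefill: no cache was passed in, so start from scratch.
            self.hidden_states_output = list(hidden_states)
        else:
            # Decode step (assumed behaviour): extend what prefill produced.
            self.hidden_states_output = self.hidden_states_output + list(hidden_states)
        return self.hidden_states_output


buf = HiddenStateBuffer()
buf.update([[0.1], [0.2]], past_key_value_states=None)   # prefill
buf.update([[0.3]], past_key_value_states=object())      # decode with a cache
after_first = list(buf.hidden_states_output)

buf.update([[0.9]], past_key_value_states=None)          # next prefill resets
after_reset = list(buf.hidden_states_output)
```

Here `after_first` holds the states from the first sequence and `after_reset` holds only the states from the new prefill, which is why stale hidden states never leak across calls.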