state_dict_factory: llama checkpoint - support SWIGLU
DeepSpeed supports loading a checkpoint for inference with a different DP/TP/PP configuration. This requires splitting or merging parameters according to their tensor-parallel (TP) attributes. Currently, this is done by matching model-specific parameter names, which is not good practice and should be changed.
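A minimal sketch of the split/merge step, assuming each weight is a 2-D `[out_features, in_features]` matrix (plain Python lists stand in for tensors) and Megatron-style TP, where column-parallel weights are sharded along dim 0 and row-parallel weights along dim 1. The helpers `split`, `merge`, `reshard` and the `SPLIT_DIM_BY_NAME` table are hypothetical illustrations, not DeepSpeed's `state_dict_factory` API; the table only shows what a name-based lookup looks like.

```python
from typing import List

# A 2-D weight as a list of rows: shape [out_features, in_features].
Matrix = List[List[float]]


def split(weight: Matrix, tp_size: int, dim: int) -> List[Matrix]:
    """Split a weight into tp_size equal shards along dim (0 = rows, 1 = columns)."""
    size = len(weight) if dim == 0 else len(weight[0])
    if size % tp_size:
        raise ValueError(f"dim {dim} of size {size} is not divisible by tp_size={tp_size}")
    n = size // tp_size
    if dim == 0:
        return [weight[r * n:(r + 1) * n] for r in range(tp_size)]
    return [[row[r * n:(r + 1) * n] for row in weight] for r in range(tp_size)]


def merge(shards: List[Matrix], dim: int) -> Matrix:
    """Concatenate TP shards back into one full weight along dim."""
    if dim == 0:
        return [row for shard in shards for row in shard]
    return [[x for shard in shards for x in shard[i]] for i in range(len(shards[0]))]


def reshard(shards: List[Matrix], new_tp_size: int, dim: int) -> List[Matrix]:
    """Change the TP degree of one parameter: merge all shards, then re-split."""
    return split(merge(shards, dim), new_tp_size, dim)


# Hypothetical name -> split-dim table: the kind of model-specific,
# name-based lookup that the description says should be replaced.
SPLIT_DIM_BY_NAME = {
    "mlp.h_to_4h.weight": 0,  # column-parallel: output features are sharded
    "lm_head.weight": 0,      # vocab-parallel: vocabulary rows are sharded
}
```

Because the split dimension is looked up by name, every new model naming scheme needs another table entry, which is the maintenance problem the description points out.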
This commit makes the changes required to support the MDS LLaMA model. There are two changes:
- Support for lm_head.weight
- Support for mlp.h_to_4h.weight for SWIGLU
SWIGLU requires different handling; however, no metadata is available that identifies mlp.h_to_4h.weight as a SWIGLU weight. Therefore, for now we use a hack to detect it.
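A sketch of why SWIGLU needs different handling, assuming (as in Megatron-style SwiGLU) that each TP rank's mlp.h_to_4h.weight stacks its gate and up projection shards as `[gate_r; up_r]` along the output dimension, and the activation chunks that output into two halves. Plain dim-0 concatenation of such shards interleaves gate and up rows, so the two halves have to be merged and split separately. The function names are hypothetical, and the detection hack mentioned above is not reproduced here.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def _gate_up(weight: Matrix) -> Tuple[Matrix, Matrix]:
    """Split a fused [gate; up] weight into its gate and up halves (by rows)."""
    if len(weight) % 2:
        raise ValueError("fused SWIGLU weight must have an even number of rows")
    mid = len(weight) // 2
    return weight[:mid], weight[mid:]


def merge_swiglu_h_to_4h(shards: List[Matrix]) -> Matrix:
    """Merge per-rank [gate_r; up_r] shards into one full [gate; up] weight."""
    gates: Matrix = []
    ups: Matrix = []
    for shard in shards:
        gate, up = _gate_up(shard)
        gates.extend(gate)
        ups.extend(up)
    return gates + ups


def split_swiglu_h_to_4h(weight: Matrix, tp_size: int) -> List[Matrix]:
    """Split a full [gate; up] weight so that rank r holds [gate_r; up_r]."""
    gate, up = _gate_up(weight)
    if len(gate) % tp_size:
        raise ValueError(f"{len(gate)} gate rows not divisible by tp_size={tp_size}")
    n = len(gate) // tp_size
    return [gate[r * n:(r + 1) * n] + up[r * n:(r + 1) * n] for r in range(tp_size)]


# Usage: 4 gate rows (0..3) and 4 up rows (10..13), one column each.
full = [[0], [1], [2], [3], [10], [11], [12], [13]]
tp2 = split_swiglu_h_to_4h(full, 2)
# tp2 == [[[0], [1], [10], [11]], [[2], [3], [12], [13]]]
naive = [row for shard in tp2 for row in shard]  # plain dim-0 concat: gate/up interleaved
tp4 = split_swiglu_h_to_4h(merge_swiglu_h_to_4h(tp2), 4)  # reshard TP=2 -> TP=4
```

Resharding goes through the full `[gate; up]` layout, so any source TP degree can be converted to any target degree that divides the gate rows.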
@nelyahu Do we have any update?
ping @nelyahu
@nelyahu - closing this as stale for now, happy to come back to it, just re-open or tag us