Adding Custom Embedding, Enabling us to initialize on Heterogeneous Devices
The code below:

- Adds a `SharedEmbedding` class that lets us get rid of an `F.linear` call (a sketch follows this list). This is necessary with certain wrapping structures (our HF ones); otherwise FSDP emits a strange error.
- Changes how we wrap HF modules, which allows us to initialize models on heterogeneous devices, e.g. initialize on `cpu` on rank 0 and on `meta` on all other ranks.
- NOTE: this change will break tied word embeddings.
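For reference, here's a minimal sketch of the idea (the actual class in the PR may differ in details). Instead of computing logits with a bare `F.linear(x, self.wte.weight)` in the model's forward, the unembedding is routed through the embedding module itself, so the weight is only ever touched through the module that FSDP wraps:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SharedEmbedding(nn.Embedding):
    """Embedding that can also run the tied output projection.

    Calling it with ``unembed=True`` computes the logits with the embedding
    weight, so the model's forward never accesses ``self.wte.weight``
    directly -- the access stays inside the module that owns the weight.
    """

    def forward(self, input: torch.Tensor, unembed: bool = False) -> torch.Tensor:
        if unembed:
            # Output projection: (..., d_model) @ weight.T -> (..., vocab_size).
            return F.linear(input, self.weight)
        return super().forward(input)
```

In the model, `logits = F.linear(x, self.wte.weight)` then becomes `logits = self.wte(x, unembed=True)`.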
@alextrott16
1. Yeah, let me create a separate PR for the HF changes. We'd need to propagate it to enc-dec as well. Also, I think we need to pass in a flag for the mixed initializations? See here: https://github.com/mosaicml/llm-foundry/pull/298#discussion_r1223384528
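For the mixed-initialization path, a rough sketch of what I mean (assuming a recent PyTorch for the `torch.device` context manager; `build_model` and `MyModel` are placeholders, not names from this PR):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def build_model(cfg) -> FSDP:
    # Rank 0 materializes real weights on CPU; every other rank builds a
    # weightless skeleton on the meta device, so only one copy of the model
    # ever lives in CPU RAM.
    init_device = 'cpu' if dist.get_rank() == 0 else 'meta'
    with torch.device(init_device):
        model = MyModel(cfg)  # hypothetical model constructor

    return FSDP(
        model,
        sync_module_states=True,  # broadcast rank 0's weights to all ranks
        # Non-zero ranks need their meta params materialized before the sync.
        param_init_fn=(
            None if dist.get_rank() == 0
            else lambda module: module.to_empty(
                device=torch.cuda.current_device(), recurse=False
            )
        ),
    )
```

`sync_module_states=True` then broadcasts rank 0's real weights over the freshly materialized parameters on the other ranks; that's roughly the behavior a flag would need to toggle.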
- It should be fine because the module name doesn't change; e.g., it's still `self.wte` (see the snippet below).
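For concreteness, a toy check (using the `SharedEmbedding` sketch above; the attribute name `wte` follows the thread) showing that the parameter key, and hence checkpoint compatibility, is unchanged:

```python
import torch.nn as nn


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # SharedEmbedding subclasses nn.Embedding, so the attribute and its
        # state_dict key are identical to before the swap.
        self.wte = SharedEmbedding(100, 32)


print(list(Toy().state_dict().keys()))  # ['wte.weight'], same as nn.Embedding
```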