gpt-neox
Migrate tensor parallelism code to use OSLO
Is your feature request related to a problem? Please describe.
It would be good to remove the Megatron tensor parallelism code from NeoX; OSLO currently has support for this and a slightly nicer interface.
Describe the solution you'd like
Steps:
- [ ] Rewrite all current modules as plain PyTorch implementations, removing the `mpu` dependency from internal code as much as possible. So anything that is currently an `mpu.ColumnParallelLinear`, `mpu.RowParallelLinear`, or `mpu.VocabParallelEmbedding` should be replaced with its plain PyTorch equivalent (`nn.Linear` / `nn.Embedding`, respectively); see the sketch after this list.
- [ ] Write a mapping for NeoX modules, which OSLO uses to handle parallelization.
- [ ] Ensure backwards compatibility
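For concreteness, here is a minimal sketch of what the first two steps might look like. Everything in it is an illustrative assumption: the `MLP` class is a simplified stand-in for the real NeoX modules, and the `NEOX_TP_MAPPING` dict is only a placeholder for whatever mapping structure OSLO actually expects.

```python
import torch.nn as nn

# Step 1 sketch, assuming a Megatron-style MLP block; class and attribute
# names are illustrative, not the exact NeoX definitions.
class MLP(nn.Module):
    def __init__(self, hidden_size: int, ffn_size: int):
        super().__init__()
        # was: mpu.ColumnParallelLinear(hidden_size, ffn_size, ...)
        self.dense_h_to_4h = nn.Linear(hidden_size, ffn_size)
        # was: mpu.RowParallelLinear(ffn_size, hidden_size, ...)
        self.dense_4h_to_h = nn.Linear(ffn_size, hidden_size)
        self.activation = nn.GELU()

    def forward(self, hidden_states):
        return self.dense_4h_to_h(self.activation(self.dense_h_to_4h(hidden_states)))

# Step 2 sketch: a hypothetical mapping from NeoX module names to the kind of
# tensor parallelism each should receive. The real format is whatever OSLO
# expects; this dict only illustrates the idea.
NEOX_TP_MAPPING = {
    "attention.query_key_value": "column",  # split output features across ranks
    "attention.dense": "row",               # split input features across ranks
    "mlp.dense_h_to_4h": "column",
    "mlp.dense_4h_to_h": "row",
    "word_embeddings": "vocab",             # vocab-parallel embedding
}
```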
I will actively support this work.
The main problem is that the model is currently loaded on the CPU and then moved to the GPU. OSLO was originally designed for Transformers, and there was no way to pass downloaded checkpoints directly to the GPU in Transformers. (At least not when I was developing it, so I didn't worry about this.) But we need to implement something like `deepspeed.zero.Init` internally so that the model is allocated on the GPU from scratch. I will try this starting tomorrow.
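As a rough sketch of the idea being referenced here: `deepspeed.zero.Init` is a context manager that partitions and places parameters as they are constructed, rather than materializing the full model on the CPU first. The model class below is a hypothetical placeholder, and this assumes a DeepSpeed distributed environment has already been initialized.

```python
import deepspeed
import torch.nn as nn

# Hypothetical stand-in for the real NeoX model class.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4096, 4096)

# Parameters created inside this context are partitioned across ranks and
# placed on device at construction time, instead of being allocated on the
# CPU and moved to the GPU afterwards.
with deepspeed.zero.Init():
    model = ToyModel()
```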
@hyunwoongko actually in NeoX we also load onto the CPU and then move to the GPU, so I'm not sure this is a problem.
> The main problem is that the model is currently loaded on the CPU and then moved to the GPU. OSLO was originally designed for Transformers, and there was no way to pass downloaded checkpoints directly to the GPU in Transformers. (At least not when I was developing it, so I didn't worry about this.) But we need to implement something like `deepspeed.zero.Init` internally so that the model is allocated on the GPU from scratch. I will try this starting tomorrow.
This is actually something we have a workaround for. I don't know if Transformers ever got around to merging it, though.
@sdtblck please check my branch: https://github.com/EleutherAI/gpt-neox/tree/kevin_new — I am restructuring our code to use plain PyTorch.
@sdtblck Did you check my branch?
@hyunwoongko, would you like to restart this effort?