Initialize Models on GPU
This might not be easily possible.
In PyTorch, `nn.Module` instances always seem to get initialized on the CPU, not the GPU, even though it is possible to initialize tensors directly on the GPU. GPU initialization is much faster, and we're already starting to see some measurably painful slow startups for LLMs. We should investigate whether there's a way to initialize models faster. The ideal would probably be to provide some sort of context manager that hacks PyTorch into initializing things on the GPU.
It seems like PyTorch devs keep trying to avoid adding a device context manager... not entirely sure I agree with them, but that's that :(
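(For what it's worth, newer PyTorch releases, 2.0 and up, did eventually make `torch.device` usable as a context manager, so factory calls inside the block create tensors on that device. A minimal sketch, using `"meta"` here only so it runs without a GPU; swap in `"cuda"` for the real thing:)

```python
import torch
import torch.nn as nn

# In PyTorch >= 2.0, a torch.device object works as a context manager:
# every tensor/parameter created inside it lands on that device directly,
# skipping the CPU-init-then-copy round trip.
with torch.device("meta"):
    lin = nn.Linear(1024, 1024)  # parameters created on the meta device

print(lin.weight.device)  # meta
```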
On the bright side, most (maybe all) PyTorch modules accept `device=...` in their `__init__`, and will then create their tensors directly on the target device. So if we built our own model classes with a `device=...` argument that propagates to all submodules, those models could be initialized directly on the GPU.
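A sketch of that pattern, with a hypothetical `MLP` class (the class and its layer sizes are made up for illustration; the `device=` forwarding is the point):

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, d_in: int, d_hidden: int, d_out: int, device=None):
        super().__init__()
        # Forward device= to every submodule so their parameters are
        # allocated on the target device directly, with no CPU detour.
        self.fc1 = nn.Linear(d_in, d_hidden, device=device)
        self.fc2 = nn.Linear(d_hidden, d_out, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))

# Pick the GPU when available; fall back to CPU so this sketch still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MLP(16, 32, 8, device=device)
```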
If a model is specified on the meta device, the Trainer will correctly initialize it on the GPU if one is specified.
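The underlying mechanism the Trainer can lean on looks roughly like this (a sketch of the meta-device workflow using stock PyTorch APIs, not Composer's actual internals; materializing to `"cpu"` here only so it runs anywhere):

```python
import torch
import torch.nn as nn

# Build on the meta device: shapes and dtypes are recorded,
# but no parameter memory is actually allocated.
with torch.device("meta"):
    model = nn.Linear(1024, 1024)

# Materialize: allocate uninitialized storage on the real target device,
# then re-run the module's own init to fill in real values.
model = model.to_empty(device="cpu")  # "cuda" in practice
model.reset_parameters()
```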