ColossalAI
[utils] lazy init.
Things to hack:
- [x] torch.Tensor factory function
- [x] eager-mode
- [x] nn.Parameter
- [x] materialize & traceable
New features
- A naive LazyTensor
- LazyInit for a model
- Layer-by-layer initialization
- AoT tracing enabled
- Shard and discard before OOM
What is Lazy?
Laziness is a behavior where the computation framework and device can defer work until an observer requests a value from outside the PyTorch ecosystem, for example by calling print. Only when such a request arrives does PyTorch compute the required values. Laziness typically pays off during compiler optimization: with the stashed computation-graph information, different devices can apply transformations such as operator fusion tailored to their own characteristics. In Colossal-AI, the entire training cycle is a window for laziness: we can lazily inspect the whole computation graph before deciding how to distribute, redistribute, and compute the model.
The simplest example is model initialization: instead of rushing to load the weights, first work out how they should be distributed across the devices, and only then load them.
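The idea can be sketched in plain Python, with no PyTorch required: a thunk records the pending computation and only runs it when a value is actually observed from outside. (The `Lazy` class below is purely illustrative, not a Colossal-AI API.)

```python
class Lazy:
    """Defer a computation until its value is actually observed."""

    def __init__(self, fn):
        self.fn = fn          # the pending computation
        self._value = None
        self._done = False

    def map(self, op):
        # Chain another operation without executing anything yet.
        return Lazy(lambda: op(self.force()))

    def force(self):
        # Only here does real work happen (e.g., when printing).
        if not self._done:
            self._value = self.fn()
            self._done = True
        return self._value

x = Lazy(lambda: sum(range(10)))   # nothing computed yet
y = x.map(lambda v: v * 2)         # still nothing computed
print(y.force())                   # computation runs now: 90
```

Because nothing runs until `force()`, an optimizer sitting between `map` calls could inspect and fuse the whole chain before any value is produced.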
What is LazyInit?
Half a year ago, Colossal-AI introduced the LazyInit feature, which hijacks model initialization to load the model weights piece by piece onto the device (similar to ColoInit) and distributes each piece according to a reasonable ShardingSpec before the next piece of weight fills up the machine's memory.
Naive method
LazyInit
Users familiar with Colossal-AI can proficiently use ColoInitContext to allocate and initialize model weights, but this is not acceptable for AutoParallel. The search strategy of AutoParallel may change the placement of model weights, while ColoInitContext initially places all weights on a single device as a Replica. For smaller networks (e.g., ResNet50), we can tolerate loading the model first and then redistributing it. For ultra-large models like BLOOM, however, even if the SPMDSolver could in theory search out a strategy like ZeRO, a single device cannot hold the model in the first place: Trace and Compile cannot be performed, the model's Config cannot even be obtained, and everything ends before it starts. LazyMode is therefore a critical step, and the process of AutoParallel is now as follows:
With this process, initialization of ultra-large models is well handled: in theory, any model that can be trained can now also be initialized and actually trained.
Details
LazyTensor
- Use LazyTensor instead of torch.Tensor.
>>> x = LazyTensor(torch.zeros, 2, 3)
>>> x += 1
>>> y = x * x
>>> y = y.cuda().half()
>>> y[0, 0] = 0
>>> y = y.materialize() # materialize the tensor
>>> print(y)
tensor([[0., 1., 1.],
[1., 1., 1.]], device='cuda:0', dtype=torch.float16)
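A naive LazyTensor along these lines can be sketched in plain Python: the constructor stores the factory function and its arguments, every subsequent operation is appended to an op list, and materialize() replays the whole recipe. (The `NaiveLazyTensor` class and its `apply` method are illustrative names; the real Colossal-AI implementation additionally dispatches torch.Tensor methods so that `x += 1` etc. are recorded transparently.)

```python
class NaiveLazyTensor:
    """Record a factory call plus a chain of ops; run them on materialize()."""

    def __init__(self, factory, *args, **kwargs):
        self._factory = (factory, args, kwargs)
        self._ops = []  # operations to replay on the materialized value

    def apply(self, op):
        # Record an operation instead of executing it.
        self._ops.append(op)
        return self

    def materialize(self):
        factory, args, kwargs = self._factory
        value = factory(*args, **kwargs)
        for op in self._ops:
            value = op(value)
        return value

# Mirrors the doctest above, using flat lists instead of torch tensors.
x = NaiveLazyTensor(lambda n: [0.0] * n, 6)   # like torch.zeros(2, 3)
x.apply(lambda v: [e + 1 for e in v])         # like x += 1
x.apply(lambda v: [e * e for e in v])         # like y = x * x
print(x.materialize())                        # [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```

Until materialize() is called, no storage for the tensor values exists; only the recipe does.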
- Generate MetaTensor from LazyTensor.
>>> x = LazyTensor(torch.zeros, 2, 3)
>>> x = x.reshape(3, 2)
>>> x = x.traceable() # generate ``MetaTensor``
>>> print(x)
MetaTensor(..., size=(3, 2), device=cpu, dtype=torch.float32)
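Conceptually, a MetaTensor carries only metadata (shape, dtype, device) so that shapes can be traced without allocating real storage. A minimal stand-in, with illustrative names rather than the Colossal-AI API:

```python
from math import prod

class MiniMetaTensor:
    """Shape/dtype-only tensor: ops propagate metadata, never data."""

    def __init__(self, size, dtype="torch.float32", device="cpu"):
        self.size = tuple(size)
        self.dtype = dtype
        self.device = device

    def reshape(self, *shape):
        # Pure shape inference: check element counts match, allocate nothing.
        if prod(shape) != prod(self.size):
            raise ValueError("reshape cannot change the number of elements")
        return MiniMetaTensor(shape, self.dtype, self.device)

    def __repr__(self):
        return f"MetaTensor(..., size={self.size}, device={self.device}, dtype={self.dtype})"

x = MiniMetaTensor((2, 3))
print(x.reshape(3, 2))  # size=(3, 2), no memory ever allocated
```

This is also why tracing a model too large for one device becomes possible: the traced graph only ever touches metadata.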
- Use LazyTensor to generate sharded nn.Parameter.
>>> x = LazyTensor(torch.zeros, 2, 3)
>>> x.spec = ... # some ``ShardingSpec``
>>> x.distribute() # distribute the tensor according to the ``ShardingSpec``
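The effect of distribute() can be sketched without torch: given a spec that names a split dimension and a device count, each device keeps only its own slice, so the full tensor never needs to exist on a single device. (The `shard_rows` helper and its spec format are invented for illustration; the real ShardingSpec is much richer.)

```python
def shard_rows(matrix, num_devices):
    """Split a 2-D list along dim 0 into one shard per device (illustrative)."""
    n = len(matrix)
    chunk = (n + num_devices - 1) // num_devices  # ceil division
    return [matrix[i * chunk:(i + 1) * chunk] for i in range(num_devices)]

full = [[r * 3 + c for c in range(3)] for r in range(4)]  # a 4x3 "tensor"
shards = shard_rows(full, 2)
print(shards[0])  # rows 0-1: [[0, 1, 2], [3, 4, 5]]
print(shards[1])  # rows 2-3: [[6, 7, 8], [9, 10, 11]]
```

Combined with laziness, each shard can be materialized directly on its target device, which is exactly the "shard and discard before OOM" behavior listed above.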
LazyInit
ctx = LazyInitContext()
with ctx:
    model = GPT()
with ctx.traceable(model):
    gm = symbolic_trace(model, meta_args)
# choose one
ctx.materialize(model)
ctx.distribute(model)
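The context-manager half can be sketched as monkey-patching factory functions for the duration of the with-block, so model code that calls them unchanged receives lazy stand-ins instead of real tensors. (The `factories` registry and `MiniLazyInitContext` below are illustrative; the real LazyInitContext hijacks the torch.Tensor factory functions listed at the top of this page.)

```python
factories = {"zeros": lambda n: [0.0] * n}  # stand-in for torch factory functions

class MiniLazyInitContext:
    """Swap factory functions for lazy stand-ins inside the with-block."""

    def __enter__(self):
        self._saved = dict(factories)
        for name, fn in self._saved.items():
            # Replace each factory with one that defers the real call.
            factories[name] = lambda *a, fn=fn: ("lazy", fn, a)
        return self

    def __exit__(self, *exc):
        factories.update(self._saved)  # restore the real factories
        return False

    @staticmethod
    def materialize(lazy):
        tag, fn, args = lazy
        assert tag == "lazy"
        return fn(*args)

with MiniLazyInitContext() as ctx:
    w = factories["zeros"](4)       # deferred: no real allocation yet
print(ctx.materialize(w))           # [0.0, 0.0, 0.0, 0.0]
```

Outside the with-block the original factories are restored, so only code that runs during initialization is affected.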