Unified Authoring of Overlapped Optimizers, and per parameter optimizer settings
Context and Problem:
Optimizer overlap/fusion is a technique where, upon calling backward(), the optimizer step for each parameter is applied as soon as its gradient is computed. The benefits of this are twofold:
- Model training efficiency: we don't need to store gradients and look them up again later; each update is applied as soon as the gradient is available.
- Memory efficiency: gradients can be released as the backward pass proceeds through the computation graph, so they don't all need to be held in memory simultaneously for opt.step().
These two properties are critical for a distributed, production setting. As such, the technique has taken different forms across many parallelization domains such as TorchRec, DDP, and FSDP.
A new use case we now have is composing these different parallelization strategies, each of which uses, in spirit, some kind of optimizer overlap (the implementations differ: TorchRec uses fused CUDA kernels, while DDP/FSDP may do it explicitly, overlapping gradient bucket reductions with parameter updates for previously reduced buckets). However, one problem we face is that model authors express this functionality differently across the different domains.
We propose a unified way of registering optimizers over model parameters() across distributed settings, as well as keeping behavioral consistency between the single-host and parallelized variants.
The main expected behavior is that after backward() is called, each parameter has been updated and its gradient is set to None (and released from memory).
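For reference, a minimal sketch of this expected behavior for a single unsharded parameter, assuming a recent PyTorch that provides register_post_accumulate_grad_hook (the proposal below attaches hooks to the AccumulateGrad node instead; this is just one way to express the same contract):

```python
import torch

param = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.SGD([param], lr=0.1)

def step_in_backward(p: torch.Tensor) -> None:
    # runs as soon as p.grad has been accumulated during backward()
    opt.step()
    p.grad = None  # release the gradient immediately

param.register_post_accumulate_grad_hook(step_in_backward)

loss = (param * 2).sum()
loss.backward()
assert param.grad is None  # the gradient was consumed and freed during backward()
```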
For TorchRec
In TorchRec's case, we currently have a limitation on expressing which parameters can have which optimizers. In a normal unsharded world you could do:
```python
opt1 = SGD(ebc.embedding_bags["table_0"].parameters())
opt2 = Adam(ebc.embedding_bags["table_1"].parameters())
```
But ShardedEmbeddingBagCollection, and even FusedEmbeddingBagCollection, only allow you to specify one optimizer per module.
Instead, we will now read which parameters should use which optimizer settings from the metadata attached via apply_overlapped_optimizer. The implementation detail is that, in addition to grouping on sharding_type/pooling_type/data_type/etc., we will also group on "fused_params" to create separate lookup kernels for parameter groups that have different optimizers.
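As an illustration (the names and attributes here are hypothetical, not the actual TorchRec grouping code), the extra grouping key could look roughly like this:

```python
from collections import defaultdict

def group_tables(table_configs):
    """Hypothetical sketch: group embedding tables into lookup-kernel groups.

    Tables that match on sharding/pooling/data type but carry different
    optimizer settings land in separate groups, and therefore separate
    fused lookup kernels.
    """
    groups = defaultdict(list)
    for config in table_configs:
        key = (
            config.sharding_type,
            config.pooling_type,
            config.data_type,
            tuple(sorted(config.fused_params.items())),  # per-table optimizer settings
        )
        groups[key].append(config)
    return groups
```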
Some future implications of this are that we will be able to get rid of fused_params as an input to EBCSharder, and fully get rid of compute_kernel.
Proposed pseudocode
```python
model = SparseNN(
    sparse_arch=EmbeddingBagCollection,
    over_arch=Linear(10, 10),
)

apply_overlapped_optimizer(torch.optim.SGD, model.sparse_arch["table_0"].parameters(), lr=.02)
apply_overlapped_optimizer(torch.optim.Adam, model.sparse_arch["table_1"].parameters(), lr=.04)
# chose this to be similar to torch.optim.SGD(model.sparse_arch.parameters(), lr)
apply_overlapped_optimizer(torch.optim.Adagrad, model.over_arch.parameters(), lr=.006)

model.sparse_arch = shard_embedding_modules(model.sparse_arch, [EmbeddingBagCollectionSharder()])
model.over_arch = DistributedDataParallel(model.over_arch)

optimizers = get_optimizers(model)
torch.save(optimizers.state_dict(), "checkpoint.pt")
```
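A sketch of what apply_overlapped_optimizer could do internally: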
```python
def apply_overlapped_optimizer(
    optimizer_class: Type[torch.optim.Optimizer],
    params: Iterator[nn.Parameter],
    **optimizer_kwargs: Any,
) -> List[torch.optim.Optimizer]:
    ...
    for param in params:
        # attach optimizer property/metadata
        # this is param-level metadata that FSDP/DDP/DMP can use to achieve behavioral consistency
        param._optimizer = optimizer_class
        param._optimizer_kwargs = optimizer_kwargs
        param._overlapped_optimizer = optimizer_class([param], **optimizer_kwargs)
        ...
        # bind param as a default argument so each hook sees its own parameter
        def optimizer_hook(*_unused, param=param):
            param._overlapped_optimizer.step()
            param.grad = None

        # param._acc_grad refers to the parameter's AccumulateGrad node;
        # this hook can be thrown away on the sharded/parallelized module
        param._acc_grad.register_hook(optimizer_hook)
```
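To illustrate how the attached metadata could be consumed, here is a hypothetical sketch (not an existing TorchRec/PT-D API) of a parallelism wrapper reading the per-parameter attributes to rebuild equivalent overlapped optimizers after sharding/wrapping:

```python
import torch.nn as nn

def rebuild_overlapped_optimizers(module: nn.Module) -> None:
    """Hypothetical helper: re-apply per-parameter optimizer metadata.

    A wrapper (DDP/FSDP/DMP) is free to replace the hook-based implementation
    with fused kernels or comm-overlapped updates, as long as the observable
    behavior matches: after backward(), the parameter is updated and its
    gradient is None.
    """
    for param in module.parameters():
        optimizer_class = getattr(param, "_optimizer", None)
        if optimizer_class is None:
            continue  # no in-backward optimizer was requested for this parameter
        optimizer_kwargs = getattr(param, "_optimizer_kwargs", {})
        param._overlapped_optimizer = optimizer_class([param], **optimizer_kwargs)
```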
Note that even in an unsharded world, you could get memory benefits from releasing gradients early.
cc @mrshenli @rohan-varma @colin2328 @divchenko @dstaay-fb @xing-liu @zhaojuanmao @wangkuiyi
- 'overlapped' is quite an odd name imho. Something more concrete is better, e.g. 'in autograd optimizer'.
- What happens if I don't specify overlapped optimizers for some parameters within ebc?
> What happens if I don't specify overlapped optimizers for some parameters within ebc?
It will allocate a dense kernel for that table, and you get the grads back (potentially as a ShardedTensor). This isn't the case right now, but it will be part of the new API.
Behavior-wise, the unsharded and sharded models will be identical*
IMO we should call it "apply_optimizer_in_backward". Fused/non-fused is an implementation detail, and whether it's done in torch.autograd or requires comms (e.g. PT-D) can also be flexible.
Landed as part of composability and per parameter optimizers