use XLA patched linear in FSDP (fix #3811 and #3718) and expose options on padding in all-gather and pinning memory
This PR applies a patch to `nn.Linear` (`torch.nn.functional.linear`) in XLA FSDP so that `nn.Linear`'s backward pass will use its weight parameter rather than an intermediate result. This resolves the issues in https://github.com/pytorch/xla/issues/3811 and https://github.com/pytorch/xla/issues/3718.

It is accomplished via ~~a context manager `xla_patched_linear` around the forward pass~~ a patch to the `nn.Linear` submodules' forward method in an FSDP-wrapped model (to be thread-safe in PJRT) that explicitly defines the backward behavior of `torch.nn.functional.linear` via `XLAPatchedLinear`.
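For context, here is a minimal sketch of the core idea (names are illustrative; this is not the exact code in this PR): a custom `torch.autograd.Function` whose backward is written explicitly against the saved weight parameter, instead of relying on whatever intermediates autograd would otherwise capture from `torch.nn.functional.linear`.

```python
import torch

class PatchedLinearSketch(torch.autograd.Function):
    """Sketch of a linear op whose backward is defined explicitly against the
    saved weight parameter rather than an intermediate result."""

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.matmul(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # d(output)/d(input): uses the weight *parameter* saved in forward.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)
        # d(output)/d(weight): contract over all leading (batch) dimensions.
        if ctx.needs_input_grad[1]:
            go2d = grad_output.reshape(-1, grad_output.size(-1))
            in2d = input.reshape(-1, input.size(-1))
            grad_weight = go2d.t().matmul(in2d)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.reshape(-1, grad_output.size(-1)).sum(0)
        return grad_input, grad_weight, grad_bias
```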
In addition, this PR makes the following (backward-compatible) minor changes:

- ~~expose `use_padding_in_all_gather` in FSDP `__init__` to allow turning off the padding to a multiple of 128 (for https://github.com/pytorch/xla/issues/3510#issuecomment-1101739677), since this sometimes does not work on a few compilers;~~ expose `shard_param_on_dim_0` (default `False`). When `shard_param_on_dim_0` is set to `True`, the parameter tensors are sharded only along their first dimension (dim 0) without being flattened. This can be a workaround for compilers that have trouble handling flattened parameters. This option has no effect if `flatten_parameters` is `True` (see the usage sketch after this list).
- expose `pin_layout_in_collective_ops` in FSDP `__init__` so that one can specify whether to pin the layout in `all_reduce`, `all_gather`, and `reduce_scatter` in the FSDP class.
- remove the rendezvous in init that could be undesirable in some situations.
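As a usage sketch only (the `XlaFullyShardedDataParallel` import path is assumed from the torch_xla FSDP module introduced in https://github.com/pytorch/xla/pull/3431, and the toy model is purely illustrative):

```python
import torch.nn as nn
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# Toy model standing in for a real network.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

# shard_param_on_dim_0 only takes effect with flatten_parameters=False;
# pin_layout_in_collective_ops controls layout pinning in the all_reduce,
# all_gather, and reduce_scatter calls made inside FSDP.
fsdp_model = FSDP(
    model,
    flatten_parameters=False,
    shard_param_on_dim_0=True,
    pin_layout_in_collective_ops=True,
)
```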
The MNIST and ImageNet examples are updated with the new optional flags `--shard_param_on_dim_0` and `--pin_layout_in_collective_ops`, which were tested on a v3-8 TPU VM and work well.

The following two tests on MNIST and ImageNet are added to the FSDP test cases in https://github.com/pytorch/xla/pull/3431#issuecomment-1119737644:
- [x] Test MNIST nested FSDP w/o padding in all-gather on v3-8
- [x] Test ImageNet ResNet-50 nested FSDP w/o padding in all-gather
- [x] Test ViT 10-billion model on v3-128 w/o gradient checkpointing
[OK] Test MNIST nested FSDP w/o padding in all-gather on v3-8

```bash
python3 -u ~/xla_fsdp_dev/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --shard_param_on_dim_0 --pin_layout_in_collective_ops
```

Results: matching expected accuracy for 2 training epochs

```
found 8 checkpoint files in /tmp/mnist-fsdp/final_ckpt_rank-*-of-*.pth
saved consolidated model to /tmp/mnist-fsdp/final_ckpt_consolidated.pth
Checkpoint consolidated, Accuracy=98.92 (note: it can be slightly different from the final training accuracy due to non-sync BatchNorm2d in the model)
Max Accuracy: 98.88%
```
[pending] Test ImageNet ResNet-50 nested FSDP w/o padding in all-gather

```bash
python3 -u ~/xla_fsdp_dev/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp --shard_param_on_dim_0 --pin_layout_in_collective_ops
```

Results: matching expected accuracy for batch size 128

```
Max Accuracy: 75.97%
```
cc: @hjm-aws
Tests showed that it worked OK in the examples above. However, a context manager that patches `torch.nn.functional.linear` is not thread-safe under PJRT (i.e. one thread might exit the context scope and reset `torch.nn.functional.linear` while another is still doing its forward pass).

Let me try an alternative that directly patches at the `nn.Linear` module level.
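To make the thread-safety concern concrete, here is a rough sketch contrasting the two approaches (`patch_f_linear` and `patch_linear_modules` are illustrative names, not the helpers in this PR): the context manager swaps the global `torch.nn.functional.linear` and restores it on exit, which can race with another thread still inside its forward pass under PJRT, while rebinding each `nn.Linear` submodule's `forward` keeps the patch local to the wrapped model.

```python
import types
import torch.nn as nn
import torch.nn.functional as F

# (a) Context-manager style: mutates the *global* F.linear. Under PJRT,
#     thread A may restore the original while thread B is still in forward.
class patch_f_linear:
    def __init__(self, patched_linear):
        self.patched_linear = patched_linear
    def __enter__(self):
        self._orig = F.linear
        F.linear = self.patched_linear
    def __exit__(self, *exc):
        F.linear = self._orig  # racy reset when threads share F.linear

# (b) Module-level style: rebind forward on each nn.Linear instance, so no
#     global state is touched and the patch stays attached to the model.
def patch_linear_modules(module, patched_linear):
    def patched_forward(self, input):
        # patched_linear takes (input, weight, bias), e.g. the autograd
        # Function sketched earlier via PatchedLinearSketch.apply.
        return patched_linear(input, self.weight, self.bias)
    for m in module.modules():
        if isinstance(m, nn.Linear):
            m.forward = types.MethodType(patched_forward, m)
    return module
```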
@hjm-aws I added a new commit that introduces the option `shard_param_on_dim_0` (default `False`). When `shard_param_on_dim_0` is set to `True`, the parameter tensors are sharded only along their first dimension (dim 0) without being flattened. This can be a workaround for compilers that have trouble handling flattened parameters. This option has no effect if `flatten_parameters` is `True`. Please take a look and see if this is what's needed for your compiler.
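As a rough illustration of what `shard_param_on_dim_0` changes (a sketch of the sharding arithmetic only, not the FSDP internals): with flattening, each parameter is viewed as 1-D, padded to a multiple of the world size, and split, whereas dim-0 sharding splits only the first dimension and keeps the remaining dimensions intact.

```python
import torch
import torch.nn.functional as F

world_size, rank = 8, 3                 # illustrative values
weight = torch.randn(4096, 1024)        # e.g. an nn.Linear weight

# Flattened sharding: view as 1-D, pad to a multiple of world_size, split.
flat = weight.reshape(-1)
pad = (-flat.numel()) % world_size
flat_shard = F.pad(flat, (0, pad)).chunk(world_size)[rank]   # 1-D shard

# Dim-0 sharding (shard_param_on_dim_0=True, flatten_parameters=False):
# split only dim 0, so each shard stays a regular 2-D tensor, here (512, 1024).
dim0_shard = weight.chunk(world_size, dim=0)[rank]
```

Keeping each shard in its original (unflattened) layout is what is expected to help the compilers that struggle with flattened 1-D shards.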
However, overall I don't think this is the ideal solution (only sharding the parameters along their first dimension without flattening them). While this could apply to most `nn.Linear`, `nn.Embedding`, and `nn.Conv2d` layers, it is very likely that it won't efficiently address many cases (e.g. a user-specified module that has its own parameter defined via `nn.Parameter`). So I think ultimately the solution should be to address the compiler issue.
I will try to take a look today.