
use XLA patched linear in FSDP (fix #3811 and #3718) and expose options on padding in all-gather and pinning layout in collective ops

Open · ronghanghu opened this pull request 1 year ago · 2 comments

This PR applies a patch to nn.Linear (torch.nn.functional.linear) in XLA FSDP so that nn.Linear's backward pass uses its weight parameter rather than an intermediate result. This resolves the issues in https://github.com/pytorch/xla/issues/3811 and https://github.com/pytorch/xla/issues/3718.

It is accomplished via ~a context manager xla_patched_linear around the forward pass~ a patch to the forward method of the nn.Linear submodules in an FSDP-wrapped model (to be thread-safe in PJRT) that explicitly defines the backward behavior of torch.nn.functional.linear via XLAPatchedLinear.
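
For reference, here is a minimal sketch of the idea under two assumptions: PatchedLinear and patch_linear_forward below are illustrative names (the PR's actual class is XLAPatchedLinear inside the XLA FSDP wrapper), and the gradient formulas follow the standard custom-autograd linear pattern rather than the PR's exact code.

import torch
import torch.nn.functional as F

class PatchedLinear(torch.autograd.Function):
    # Linear op with an explicitly defined backward pass. Saving the weight
    # tensor itself means the backward matmuls reference the actual nn.Linear
    # weight parameter instead of an autograd-recorded intermediate.
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        return F.linear(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        # Collapse any leading batch dims so grad_weight has shape
        # (out_features, in_features) regardless of the input rank.
        grad_2d = grad_output.reshape(-1, grad_output.size(-1))
        input_2d = input.reshape(-1, input.size(-1))
        grad_weight = grad_2d.t().matmul(input_2d)
        grad_bias = grad_2d.sum(0) if bias is not None else None
        return grad_input, grad_weight, grad_bias

def patch_linear_forward(linear: torch.nn.Linear):
    # Patch at the module level (rather than monkeypatching F.linear globally),
    # so each nn.Linear submodule carries its own patched forward and the
    # change is thread-safe under PJRT's multi-threaded per-device execution.
    def forward(x):
        return PatchedLinear.apply(x, linear.weight, linear.bias)
    linear.forward = forward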

In addition, it makes the following backward-compatible minor changes:

  • ~expose use_padding_in_all_gather in FSDP __init__ to allow turning off the padding to a multiple of 128 (for https://github.com/pytorch/xla/issues/3510#issuecomment-1101739677), since this sometimes does not work with a few compilers;~ expose shard_param_on_dim_0 (default False). When shard_param_on_dim_0 is set to True, the parameter tensors are sharded only along their first dimension (dim 0) without being flattened; see the sketch after this list. This can be a workaround for compilers that have trouble handling flattened parameters. This option has no effect if flatten_parameters is True.
  • expose pin_layout_in_collective_ops in FSDP __init__ so that one can specify whether to pin layout in all_reduce, all_gather, and reduce_scatter in the FSDP class.
  • remove the rendezvous in __init__, which could be undesirable in some situations.
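
As a rough illustration of the dim-0 sharding option, here is a sketch of sharding a parameter only along its first dimension; shard_dim_0 is a hypothetical helper written for this description, not the PR's code, and the FSDP class is assumed to plumb pin_layout_in_collective_ops down to the pin_layout argument of the xm collective ops.

import torch

def shard_dim_0(param: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Pad dim 0 up to a multiple of world_size, then keep this rank's slice.
    # The trailing dimensions are preserved instead of flattening the whole
    # parameter into a 1-D buffer, which some compilers handle more easily.
    dim0 = param.size(0)
    padded_dim0 = -(-dim0 // world_size) * world_size  # ceil to a multiple
    if padded_dim0 != dim0:
        pad = param.new_zeros((padded_dim0 - dim0,) + tuple(param.shape[1:]))
        param = torch.cat([param, pad], dim=0)
    chunk = padded_dim0 // world_size
    return param[rank * chunk:(rank + 1) * chunk].clone()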

The MNIST and ImageNet examples are updated with the new optional flags --shard_param_on_dim_0 and --pin_layout_in_collective_ops, which were tested on a v3-8 TPU VM and work well.

The following tests are added to the FSDP test cases in https://github.com/pytorch/xla/pull/3431#issuecomment-1119737644:

  • [x] Test MNIST nested FSDP w/o padding in all-gather on v3-8
  • [x] Test ImageNet ResNet-50 nested FSDP w/o padding in all-gather
  • [x] Test ViT 10-billion model on v3-128 w/o gradient checkpointing

[OK] Test MNIST nested FSDP w/o padding in all-gather on v3-8

python3 -u ~/xla_fsdp_dev/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --shard_param_on_dim_0 --pin_layout_in_collective_ops

Results: matching expected accuracy for 2 training epochs

found 8 checkpoint files in /tmp/mnist-fsdp/final_ckpt_rank-*-of-*.pth
saved consolidated model to /tmp/mnist-fsdp/final_ckpt_consolidated.pth
Checkpoint consolidated, Accuracy=98.92 (note: it can be slightly different from the final training accuracy due to non-sync BatchNorm2d in the model)
Max Accuracy: 98.88%

[OK] Test ImageNet ResNet-50 nested FSDP w/o padding in all-gather

python3 -u ~/xla_fsdp_dev/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp --shard_param_on_dim_0 --pin_layout_in_collective_ops

Results: matching expected accuracy for batch size 128

Max Accuracy: 75.97%

cc: @hjm-aws

ronghanghu · Aug 04 '22 13:08

Tests showed that it worked OK in the examples above. However, a context manager that patches torch.nn.functional.linear is not thread-safe under PJRT (i.e. one thread might exit the context scope and reset torch.nn.functional.linear while another thread is still in its forward pass).

Let me try an alternative that patches directly at the nn.Linear module level.
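
For context, here is a sketch of the kind of global monkeypatch that runs into this race; patch_functional_linear and patched_fn are illustrative names, not the PR's code.

import contextlib
import torch.nn.functional as F

@contextlib.contextmanager
def patch_functional_linear(patched_fn):
    # Globally swap out torch.nn.functional.linear for the duration of one
    # forward pass. Under PJRT each device runs in its own thread of the same
    # process, so the first thread to leave this context restores F.linear
    # while other threads may still be mid-forward, and those threads silently
    # fall back to the unpatched op. Patching each nn.Linear module's forward
    # instead avoids the shared global state entirely.
    original = F.linear
    F.linear = patched_fn
    try:
        yield
    finally:
        F.linear = original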

ronghanghu · Aug 04 '22 15:08

@hjm-aws I added a new commit that introduces the option shard_param_on_dim_0 (default False). When shard_param_on_dim_0 is set to True, the parameter tensors are sharded only along their first dimension (dim 0) without being flattened. This can be a workaround for compilers that have trouble handling flattened parameters. This option has no effect if flatten_parameters is True. Please take a look and see if this is what's needed for your compiler.

However, overall I don't think this is the ideal solution (sharding the parameters only along their first dimension without flattening them). While it could apply to most nn.Linear, nn.Embedding, and nn.Conv2d layers, it very likely won't efficiently handle many other cases (e.g. a user-specified module that defines its own parameters via nn.Parameter). So I think the ultimate solution should be to address the compiler issue.

ronghanghu · Aug 11 '22 04:08

I will try to take a look today.

JackCaoG · Aug 16 '22 17:08