torch_xla multi-device support

dskhudia opened this issue 3 years ago • 5 comments

Adds support for multi-device training using torch_xla.

The resnet9 model on cifar10 trains fine on 2 GPUs using torch_xla.

Some of the issues I ran into while adding this support:

  • With torch_xla on GPUs, dist.is_initialized() returns False.
    • Everything is handled through environment variables.
  • composer.utils.dist.initialize_dist is called twice in the YAHP flow.
    • This makes torch_xla think we have 2x the number of GPUs.
  • Some environment variables need to be set up before forking.
  • xm.* functions cannot be called until the torch_xla environment is properly initialized.
    • We call them before initialization, when constructing the device object.
  • broadcast doesn't exist in the current release of torch_xla.
    • Implemented it with all_reduce instead.
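Since broadcast was missing, it was emulated with all_reduce. A minimal sketch of that trick, simulated in plain Python so it runs without any devices: each rank contributes its tensor if it is the source and zeros otherwise, so a sum-reduction leaves every rank holding the source's values. The helper names and the list-of-lists representation of per-rank tensors are assumptions for illustration only; on real hardware the sum would come from the backend's collective (for example xm.all_reduce with xm.REDUCE_SUM in torch_xla), not from this simulation.

```python
from typing import List

Tensor = List[float]  # stand-in for a device tensor (assumption for illustration)


def simulated_all_reduce_sum(per_rank: List[Tensor]) -> List[Tensor]:
    """Elementwise sum across ranks; every rank receives the same result."""
    total = [sum(values) for values in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def broadcast_via_all_reduce(per_rank: List[Tensor], src: int) -> List[Tensor]:
    """Emulate broadcast(src) with a SUM all_reduce.

    Non-source ranks replace their contribution with zeros of the same shape,
    so the reduced sum equals the source rank's tensor on every rank.
    """
    masked = [
        list(t) if rank == src else [0.0] * len(t)
        for rank, t in enumerate(per_rank)
    ]
    return simulated_all_reduce_sum(masked)


if __name__ == "__main__":
    tensors = [[1.0, 2.0], [10.0, 20.0], [5.0, 5.0]]
    print(broadcast_via_all_reduce(tensors, src=1))
    # [[10.0, 20.0], [10.0, 20.0], [10.0, 20.0]]
```

The catch with this approach is that every non-source rank must already hold a buffer with the same shape and dtype as the source tensor, and it costs a full reduction rather than a one-to-many copy.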

Test:

composer -n 2 examples/run_composer_trainer.py -f composer/yamls/models/resnet9_cifar10.yaml --train_dataset.cifar10.datadir /tmp/cifar --train_dataset.cifar10.download true --val_dataset.cifar10.datadir /tmp/cifar --val_dataset.cifar10.download true --max_duration 2ep --train_subset_num_batches 2 --device 'tpu'

dskhudia, Sep 22 '22 18:09

@hanlint: device_tpu should really be renamed to device_xla.

dskhudia, Sep 22 '22 18:09

> composer.utils.dist.initialize_dist is called twice in YAHP flow.

Is this function not idempotent? Why does calling it more than once matter? (Asking because we'll hit this soon if people start calling it outside the trainer for entrypoints.)

mvpatel2000, Sep 22 '22 18:09

As written, calling it a second time has no effect. However, with torch_xla on GPUs, torch.distributed.is_initialized() always returns False, so initialization ran again and torch_xla assumed 2x the number of available devices. This is no longer an issue, since I have also added a torch_xla check for repeated calls.
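The fix described above amounts to making initialization idempotent with a guard that does not rely on torch.distributed.is_initialized(). A minimal sketch, assuming a module-level flag; the names initialize_dist_sketch and _XLA_INITIALIZED, and the setup_backend callback, are hypothetical and are not Composer's actual implementation.

```python
from typing import Callable

# Hypothetical module-level guard. torch.distributed.is_initialized() cannot be
# trusted here because it reports False under torch_xla on GPUs.
_XLA_INITIALIZED = False


def initialize_dist_sketch(setup_backend: Callable[[], None]) -> bool:
    """Run backend setup at most once per process.

    Returns True if setup ran on this call, False if it was skipped.
    """
    global _XLA_INITIALIZED
    if _XLA_INITIALIZED:
        # A second call (e.g. from the YAHP flow) must not re-register devices.
        return False
    setup_backend()
    _XLA_INITIALIZED = True
    return True


if __name__ == "__main__":
    calls = []
    print(initialize_dist_sketch(lambda: calls.append("init")))  # True
    print(initialize_dist_sketch(lambda: calls.append("init")))  # False
    print(len(calls))  # 1
```

Because the flag is per process, each forked worker still initializes exactly once; only repeated calls within the same process become no-ops.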

dskhudia, Sep 22 '22 21:09

> Looks good, but a general comment is to break out the XLA-specific functionality into its own methods (e.g. xla_all_reduce, xla_broadcast) for modularity.

In most cases except broadcast, it will be a thin wrapper. Did you mean I should do this for every function, or just where there is extra logic?

dskhudia, Sep 23 '22 16:09

Yeah, I think it would be cleaner if all functions were structured like:

def broadcast():
    if xla:
        _xla_broadcast(..)
    else:
        _dist_broadcast(..)  # or dist.broadcast if it really is just a one-liner

It's more verbose, but easier to read. The helper functions can be private, so no docstrings are needed.

hanlint, Sep 25 '22 16:09

Closing this out; we'll revisit it in the future.

bandish-shah, Nov 03 '22 18:11