
XLA support

Open pcuenca opened this issue 3 years ago • 1 comment

I'm working on this for my own use, so that I can train on Google Cloud TPU instances. So far I've made only the minimal adjustments necessary to make it work for me, but I can try to improve it if you are interested in this feature. Some of the remaining topics to be addressed are:

  • State restoration in load_resume_state.
  • Net and weights loading optimization. We should load the network and any pre-trained weights just once in the main process, then copy them to the rest of the devices using xmp.MpModelWrapper().
  • Maybe use torch_xla.distributed.parallel_loader instead of EnlargedSampler. I ran some preliminary tests and both seem to work fine, so I kept EnlargedSampler for simplicity. (Both ideas are sketched right after this list.)
  • Adapt tests / inference code. So far I only need the training loop for my purposes.
  • Documentation updates.
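For illustration, this is roughly the pattern the PyTorch-XLA documentation recommends for the two points above. It is a sketch under my own assumptions, not the code in this PR: the `nn.Linear` network, the toy dataset, and `_mp_fn` are placeholders.

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# Build the network (and load any pretrained weights) once in the parent
# process; with a forked start method the workers share these host pages,
# and MpModelWrapper moves the weights onto each TPU device lazily.
WRAPPED_MODEL = xmp.MpModelWrapper(nn.Linear(16, 16))  # toy stand-in network

def _mp_fn(index):
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)

    # Per-process sharding still comes from a distributed sampler (the role
    # EnlargedSampler plays today); parallel_loader then transfers batches
    # to the XLA device in background threads.
    dataset = torch.utils.data.TensorDataset(torch.randn(128, 16))
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=sampler)
    device_loader = pl.MpDeviceLoader(loader, device)

    for (batch,) in device_loader:
        model(batch)  # the real training iteration would go here

if __name__ == '__main__':
    # nprocs=None uses all TPU cores; MpModelWrapper expects start_method='fork'.
    xmp.spawn(_mp_fn, args=(), nprocs=None, start_method='fork')
```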

XLA is enabled through a new accelerator option in the configuration file: accelerator: xla (the default is cuda). Setting num_gpu to auto makes the training script use all the available TPU devices. Note that XLA supports either 1 device or all of them (typically 8); it is not possible to use more than one but fewer than the number of installed TPU cores. I haven't added a check for that, so the code will crash if you request anything between 1 and the maximum.
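As a concrete example, the relevant part of an option file would look roughly like this (illustrative, not taken from this PR; only the last two keys relate to this change, the rest is the usual BasicSR layout):

```yaml
name: my_tpu_experiment
model_type: SRModel
num_gpu: auto        # with XLA, auto means "all available TPU cores"
accelerator: xla     # new option; omit it or set cuda for the existing behaviour
```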

You need to launch the training process without any parallel launcher; the code in this PR forks the worker processes automatically, following the recommended approach I've seen in all PyTorch-XLA tutorials. So, you launch as usual:

python basicsr/train.py -opt <your_configuration.yml>

and the script will parallelize across devices nonetheless.
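For reference, the "fork automatically" behaviour follows the standard PyTorch-XLA pattern sketched below. This is not the exact code in the PR; parse_options and train_pipeline are hypothetical stand-ins for whatever the script actually calls, and I'm assuming the accelerator key ends up at the top level of the parsed options.

```python
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_entry(index, opt):
    # index is the process ordinal passed in by xmp.spawn; the body just runs
    # the normal training pipeline, which picks up xm.xla_device() internally.
    train_pipeline(opt)

def main():
    opt = parse_options()  # hypothetical: reads -opt <your_configuration.yml>
    if opt.get('accelerator', 'cuda') == 'xla':
        # Fork one worker per TPU core instead of relying on an external launcher.
        xmp.spawn(_mp_entry, args=(opt,), nprocs=None)
    else:
        train_pipeline(opt)  # existing CUDA / CPU code path, unchanged

if __name__ == '__main__':
    main()
```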

One limitation is that this only parallelizes training across the TPU devices attached to the same host; it does not support multi-host distributed training.

I tried to minimize the impact on model classes so they can support XLA training. So far I have only adapted sr_model, and the required change was minimal, but this could vary for other architectures.
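For context, the kind of minimal change I mean usually boils down to routing the optimizer step through torch_xla. A rough sketch of that kind of change follows; the helper is hypothetical and not the exact diff in sr_model.

```python
import torch_xla.core.xla_model as xm

def step_optimizer(optimizer, use_xla: bool):
    """Hypothetical helper showing the typical one-line difference under XLA.

    xm.optimizer_step() performs the cross-replica gradient reduction and then
    calls optimizer.step(); barrier=True also flushes the pending XLA graph,
    which matters when no parallel loader is marking steps for you.
    """
    if use_xla:
        xm.optimizer_step(optimizer, barrier=True)
    else:
        optimizer.step()
```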

Overall, I feel a bit unhappy that the parallelization works by special-casing a few places in the existing code. This is partly caused by the peculiarities of PyTorch-XLA, which break assumptions the existing code makes about distributed training: that there is an external distributed "launcher", and that communication goes through the torch.distributed interface. However, it is the simplest approach I could think of that avoids a major refactor of the existing code or a complete duplication of the training script. If you can think of a better alternative, by all means please let me know.

Thanks a lot for making this project available!

pcuenca avatar Feb 23 '22 17:02 pcuenca

This pull request introduces 5 alerts when merging 36dca17f2dc637c4b932c4c7e785092a2901ef8d into 6697f41600769d43ea201db5bc02100c095d682f - view on LGTM.com

new alerts:

  • 3 for Except block handles 'BaseException'
  • 2 for Unused import

lgtm-com[bot] avatar Feb 23 '22 18:02 lgtm-com[bot]