XLA support
I'm working on this for my own use, so I can train on Google Cloud TPU instances. So far I've made the minimum adjustments necessary to make it work for me, but I can try to improve it if you are interested in this feature. Some of the remaining topics to be addressed are:
- State restoration in `load_resume_state`.
- Net and weights loading optimization. We should load the network and any pre-trained weights just once in the main process, then copy to the rest of the devices using `xmp.MpModelWrapper()` (see the sketch after this list).
- Maybe use `torch_xla.distributed.parallel_loader` instead of `EnlargedSampler`. I made some preliminary tests and both seem to work fine, so I kept `EnlargedSampler` for simplicity.
- Adapt tests / inference code. So far I only need the training loop for my purposes.
- Documentation updates.
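For the weights-loading item above, what I have in mind is roughly the standard PyTorch-XLA pattern. This is only a minimal sketch, not code from this PR; the toy network below is just a placeholder for whatever architecture the configuration would build:

```python
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Build the network and load any pre-trained weights once, in the main process.
# This toy network is a stand-in for whatever architecture the config selects.
net = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
                    nn.Conv2d(64, 3, 3, padding=1))
wrapped_net = xmp.MpModelWrapper(net)


def train_fn(index):
    device = xm.xla_device()
    # Each forked worker gets its own on-device copy of the shared host weights,
    # instead of re-reading the checkpoint from disk in every process.
    model = wrapped_net.to(device)
    # ... per-device training loop goes here ...


if __name__ == '__main__':
    # nprocs=None uses all available TPU cores.
    xmp.spawn(train_fn, nprocs=None)
```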
XLA is enabled with a new `accelerator: xla` option (default: `cuda`) in the configuration file. Setting `num_gpu` to `auto` makes the training script use all the available TPU devices. Note that XLA supports either 1 device or all of them (typically 8): it is not possible to use more than 1 but fewer than the number of installed TPU cores. I haven't added a check for that, so the code will crash if you request anything between 1 and the maximum.
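As an illustration, a configuration excerpt could look like the following. Only `accelerator` and `num_gpu` are the options discussed here; the other keys are just a hypothetical example of the usual layout:

```yaml
name: my_xla_experiment
model_type: SRModel
accelerator: xla   # default: cuda
num_gpu: auto      # with xla this means all available TPU cores (typically 8)
```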
You need to launch the training process without any parallel launcher: the code in this PR forks the processes automatically, following the recommended approach I've seen in all the PyTorch-XLA tutorials. So you launch as usual:
```
python basicsr/train.py -opt <your_configuration.yml>
```
and the script will parallelize across devices nonetheless.
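For reference, this is roughly the launcher-free forking pattern the PyTorch-XLA tutorials recommend; `train_worker` is a hypothetical placeholder, not the actual function in this PR:

```python
import argparse

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def train_worker(index, opt_path):
    # Hypothetical per-process entry point; the real one would run the full
    # training pipeline with the parsed configuration.
    device = xm.xla_device()
    print(f'worker {index} training on {device} with config {opt_path}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', required=True)
    args = parser.parse_args()
    # The script forks one process per TPU core itself, so no external
    # launcher (torchrun, torch.distributed.launch, ...) is involved.
    xmp.spawn(train_worker, args=(args.opt,), nprocs=None)
```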
One limitation is that this only parallelizes training across the TPU devices attached to the same machine; it does not support multi-host distributed training.
I tried to minimize the changes that model architectures need in order to support XLA training. So far I have only adapted `sr_model`, and the required change was minimal, but this could vary for other architectures.
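This PR text does not spell out the `sr_model` change; a common adaptation in PyTorch-XLA (assumed here, not taken from this PR) is swapping `optimizer.step()` for `xm.optimizer_step()`. A minimal sketch with a hypothetical helper:

```python
import torch.nn.functional as F
import torch_xla.core.xla_model as xm


def training_step(model, optimizer, lq, gt, use_xla=True):
    """One optimization step; a hypothetical helper, not the actual sr_model code."""
    optimizer.zero_grad()
    loss = F.l1_loss(model(lq), gt)
    loss.backward()
    if use_xla:
        # xm.optimizer_step() all-reduces the gradients across TPU cores and then
        # steps, replacing the plain optimizer.step() used on CUDA.
        xm.optimizer_step(optimizer)
    else:
        optimizer.step()
    return loss
```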
Overall, I feel a bit unhappy that the parallelization works by special-casing a few places in the existing code. This is caused in part by the peculiarities of PyTorch-XLA, which clash with assumptions the existing code makes about distributed training: that there is an external distributed "launcher", and that parallelism goes through the `torch.distributed` interface. However, this is the simplest approach I could think of that avoids either a major refactor of the existing code or a complete duplication of the training script. If you can think of a better alternative, by all means please let me know.
Thanks a lot for making this project available!
This pull request introduces 5 alerts when merging 36dca17f2dc637c4b932c4c7e785092a2901ef8d into 6697f41600769d43ea201db5bc02100c095d682f - view on LGTM.com
new alerts:
- 3 for Except block handles 'BaseException'
- 2 for Unused import