pytorch-lightning
Enable `on_before_batch_transfer` for `DPStrategy` and `IPUAccelerator`
Proposed refactor
Motivation
Currently, batch transfer hooks are disabled for IPUs and the DP strategy, for very valid reasons. But whether `on_before_batch_transfer` gets called shouldn't depend on whether a strategy/accelerator supports explicitly transferring batches to the device. Users should still be able to use this hook for batch transforms, etc.
I already created the PR: #14023
But then I thought, let's discuss the design here. If the current one looks good, please give it a 👍; otherwise, here are all the proposals.
Current behavior:
    fetcher(batch_to_device=strategy.batch_to_device)

    def strategy.batch_to_device():
        apply_batch_transfer()

    def apply_batch_transfer():
        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()
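To make the limitation concrete, here is a minimal, self-contained sketch of the current flow in plain Python. `ToyModule`, `DefaultStrategy`, and `NoTransferStrategy` are hypothetical stand-ins, not Lightning classes, and the hook signatures are simplified (the real hooks take more arguments). It only illustrates that when a strategy bypasses the bundled transfer call, `on_before_batch_transfer` is skipped along with it.

```python
from typing import Any, List


class ToyModule:
    """Hypothetical stand-in for a LightningModule; records which hooks ran."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def on_before_batch_transfer(self, batch: Any) -> Any:
        self.calls.append("before")
        return batch

    def transfer_batch_to_device(self, batch: Any) -> Any:
        self.calls.append("transfer")
        return batch

    def on_after_batch_transfer(self, batch: Any) -> Any:
        self.calls.append("after")
        return batch


def apply_batch_transfer(module: ToyModule, batch: Any) -> Any:
    # All three hooks are bundled into one call, as in the outline above.
    batch = module.on_before_batch_transfer(batch)
    batch = module.transfer_batch_to_device(batch)
    return module.on_after_batch_transfer(batch)


class DefaultStrategy:
    """Toy strategy that delegates to the bundled transfer call."""

    def __init__(self, module: ToyModule) -> None:
        self.module = module

    def batch_to_device(self, batch: Any) -> Any:
        return apply_batch_transfer(self.module, batch)


class NoTransferStrategy(DefaultStrategy):
    """Toy strategy that does not transfer batches itself (assumed stand-in for DP/IPU)."""

    def batch_to_device(self, batch: Any) -> Any:
        # Bypassing the bundled call also bypasses on_before_batch_transfer.
        return batch


default_module = ToyModule()
DefaultStrategy(default_module).batch_to_device([1, 2])

skipped_module = ToyModule()
NoTransferStrategy(skipped_module).batch_to_device([1, 2])
```

After running this, `default_module.calls` holds all three hook names in order, while `skipped_module.calls` stays empty.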
Pitch
possible suggestion 1:
    def strategy.batch_to_device():
        # now this needs to be done in all the strategies where
        # the strategy.batch_to_device hook is overridden
        on_before_batch_transfer()
        apply_batch_transfer()

    def apply_batch_transfer():
        transfer_batch_to_device()
        on_after_batch_transfer()
However, IMO this hook call should happen inside the loops, not in the strategies, since it is common to all of them and has no limitations.
possible solution 2 using DataFetcher (currently implemented in PR):
fetch(pre_batch_to_device=lm.on_before_batch_transfer, batch_to_device=strategy.batch_to_device)
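As a rough illustration of this second option, the sketch below shows a generator-style fetcher that applies a pre-transfer callable before the strategy's transfer callable. It is not the `DataFetcher` implementation from the PR: the `fetch` function, its `dataloader` argument, and the helpers `double` and `identity_transfer` are assumptions made for the example, and the hook signature is simplified to a single `batch` argument.

```python
from typing import Any, Callable, Iterable, Iterator, Optional

Transform = Callable[[Any], Any]


def fetch(
    dataloader: Iterable[Any],
    pre_batch_to_device: Optional[Transform] = None,
    batch_to_device: Optional[Transform] = None,
) -> Iterator[Any]:
    """Yield batches, running the pre-transfer hook independently of the transfer."""
    for batch in dataloader:
        if pre_batch_to_device is not None:
            # Runs for every strategy, even one that never moves the batch.
            batch = pre_batch_to_device(batch)
        if batch_to_device is not None:
            batch = batch_to_device(batch)
        yield batch


def double(batch: Any) -> Any:
    # Hypothetical user transform, standing in for on_before_batch_transfer.
    return [x * 2 for x in batch]


def identity_transfer(batch: Any) -> Any:
    # Hypothetical strategy that leaves device placement to something else.
    return batch


batches = list(
    fetch([[1, 2], [3]], pre_batch_to_device=double, batch_to_device=identity_transfer)
)
```

Because the pre-transfer callable is passed separately, the loop owns the decision to call it and strategies that override `batch_to_device` need no changes.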
Additional context
cc @justusschock @awaelchli @rohitgr7 @tchaton @borda @carmocca @ninginthecloud
completed here: #14023