
Enable `on_before_batch_transfer` for `DPStrategy` and `IPUAccelerator`

rohitgr7 opened this issue 3 years ago

Proposed refactor

Motivation

Currently, the batch transfer hooks are disabled for IPUs and the DP strategy, for valid reasons. But the call to `on_before_batch_transfer` shouldn't be gated on whether a strategy/accelerator supports explicit transfer of batches to the device. Users should still be able to use this hook for batch transforms, etc.
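
For illustration, a minimal sketch of the kind of user code the motivation refers to; the model and the transform are made up, only the hook name and signature come from Lightning:

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    # Hypothetical batch transform: runs on the CPU batch before (or regardless of)
    # any device transfer performed by the strategy.
    def on_before_batch_transfer(self, batch, dataloader_idx):
        x, y = batch
        x = (x - x.mean()) / (x.std() + 1e-8)  # e.g. per-batch normalization
        return x, y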

I already created the PR: #14023

But then I thought, let's discuss the design here. If the current one looks good, please give it a 👍; otherwise, here are the proposals.

Current behavior:

fetcher(batch_to_device=strategy.batch_to_device)

def strategy.batch_to_device():
    apply_batch_transfer()

def apply_batch_transfer():
    on_before_batch_transfer()
    transfer_batch_to_device()
    on_after_batch_transfer()
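
For concreteness, a rough Python version of this chain, assuming a simplified strategy that holds a reference to the LightningModule; the method names mirror the pseudocode above, but the class itself is illustrative, not the actual Lightning internals:

class SimplifiedStrategy:
    def __init__(self, lightning_module, device):
        self.lm = lightning_module
        self.device = device

    def batch_to_device(self, batch, dataloader_idx=0):
        return self._apply_batch_transfer(batch, dataloader_idx)

    def _apply_batch_transfer(self, batch, dataloader_idx):
        # the three LightningModule hooks, in the order shown above
        batch = self.lm.on_before_batch_transfer(batch, dataloader_idx)
        batch = self.lm.transfer_batch_to_device(batch, self.device, dataloader_idx)
        batch = self.lm.on_after_batch_transfer(batch, dataloader_idx)
        return batch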

Pitch

possible suggestion 1:

def strategy.batch_to_device():
    on_before_batch_transfer()  # <- now this has to be duplicated in every strategy that overrides the batch_to_device hook
    apply_batch_transfer()

def apply_batch_transfer():
    transfer_batch_to_device()
    on_after_batch_transfer()

But IMO, this hook call should happen inside the loops, not in the strategies, since it is common to all of them and has no device-specific limitations (see the sketch below).
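
A rough sketch of what the loop-side call could look like under that idea; the function is hypothetical and only shows where the hook call would move:

def fetch_batches(dataloader, lm, strategy, dataloader_idx=0):
    for batch in dataloader:
        # common to every strategy/accelerator, so it can live in the loop
        batch = lm.on_before_batch_transfer(batch, dataloader_idx)
        # device transfer (and on_after_batch_transfer) stays strategy-specific
        # and may be a no-op for DP / IPU
        batch = strategy.batch_to_device(batch)
        yield batch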

possible suggestion 2, using the DataFetcher (this is what the PR currently implements):

fetch(pre_batch_to_device=lm.on_before_batch_transfer, batch_to_device=strategy.batch_to_device)
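
A minimal sketch of the fetcher-side wiring this implies; the class and argument names are illustrative, not the exact DataFetcher API:

class SimpleFetcher:
    def __init__(self, dataloader, pre_batch_to_device=None, batch_to_device=None):
        self.dataloader = dataloader
        self.pre_batch_to_device = pre_batch_to_device  # e.g. lm.on_before_batch_transfer
        self.batch_to_device = batch_to_device          # e.g. strategy.batch_to_device

    def __iter__(self):
        for batch in self.dataloader:
            # always runs, regardless of whether the strategy supports device transfer
            if self.pre_batch_to_device is not None:
                batch = self.pre_batch_to_device(batch)
            # may be skipped/no-op for DP and IPU, which handle transfer themselves
            if self.batch_to_device is not None:
                batch = self.batch_to_device(batch)
            yield batch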

Additional context


If you enjoy Lightning, check out our other projects! ⚡

  • Metrics: Machine learning metrics for distributed, scalable PyTorch applications.

  • Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.

  • Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.

  • Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.

  • Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.

cc @justusschock @awaelchli @rohitgr7 @tchaton @borda @carmocca @ninginthecloud

rohitgr7 · Aug 06 '22 19:08

completed here: #14023

rohitgr7 · Sep 01 '22 18:09