Add device arg to Engine

Open vfdev-5 opened this issue 4 years ago • 3 comments

🚀 Feature

The idea is to simplify process_function by adding an optional device argument to Engine, so that the step function no longer has to move the batch to the device itself:

def train_step(engine, batch):
        x, y = batch[0], batch[1]
-       if x.device != device:
-           x = x.to(device, non_blocking=True)
-           y = y.to(device, non_blocking=True)
        ...
        y_pred = model(x)
        ...

device = "cuda"
trainer = Engine(train_step, device=device)

One possible implementation is to register a handler that moves each batch to the device right after it has been fetched, for example:

from ignite.utils import apply_to_tensor


def to_device_handler(engine, device):
    # Move every tensor in the fetched batch to the target device
    engine.state.batch = apply_to_tensor(
        engine.state.batch, lambda x: x.to(device, non_blocking=True)
    )


class Engine(Serializable):

    def __init__(self, process_function: Callable, device: Optional[Union[str, torch.device]] = None):
        ...
        if device is not None:
            # Fires after each batch is fetched, before process_function is called
            self.add_event_handler(Events.GET_BATCH_COMPLETED, to_device_handler, device)

This can support the idea of https://github.com/pytorch/ignite/issues/1949
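
To make the intended behaviour concrete, here is a dependency-free toy model of the proposal. It is only a sketch: FakeTensor, apply_to_fake_tensor and ToyEngine are hypothetical stand-ins invented for illustration (they are not ignite or PyTorch classes), and the event system is reduced to a single hook that runs after a batch is fetched and before the step function is called.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only tracks its device."""
    device: str = "cpu"

    def to(self, device: str, non_blocking: bool = False) -> "FakeTensor":
        # Mimics the shape of Tensor.to(): returns a copy on the new device
        return FakeTensor(device=device)


def apply_to_fake_tensor(obj: Any, func: Callable[[FakeTensor], FakeTensor]) -> Any:
    """Apply func to every FakeTensor inside nested lists, tuples and dicts."""
    if isinstance(obj, FakeTensor):
        return func(obj)
    if isinstance(obj, dict):
        return {k: apply_to_fake_tensor(v, func) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(apply_to_fake_tensor(v, func) for v in obj)
    return obj  # non-tensor leaves pass through unchanged


class ToyEngine:
    """Toy engine (an assumption, not ignite's Engine) with an optional device."""

    def __init__(self, process_function: Callable, device: Optional[str] = None):
        self.process_function = process_function
        self.device = device
        self.batch: Any = None

    def _to_device(self) -> None:
        # Plays the role of the handler attached to the "batch fetched" event
        self.batch = apply_to_fake_tensor(
            self.batch, lambda t: t.to(self.device, non_blocking=True)
        )

    def run(self, data: Iterable) -> List[Any]:
        outputs = []
        for batch in data:
            self.batch = batch
            if self.device is not None:
                self._to_device()
            outputs.append(self.process_function(self, self.batch))
        return outputs


def train_step(engine, batch):
    # No manual .to(device) calls are needed here any more
    x, y = batch[0], batch[1]
    return x.device, y.device


data = [(FakeTensor(), FakeTensor()), (FakeTensor(), FakeTensor())]
with_device = ToyEngine(train_step, device="cuda").run(data)
without_device = ToyEngine(train_step).run(data)
print(with_device)
print(without_device)
```

When device is left as None the batch is passed through untouched, which matches the proposal's intent of keeping the argument optional.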

vfdev-5 May 16 '21 00:05

I would like to work on this issue.

01-vyom Jun 21 '21 15:06

@01-vyom thanks for offering to help with this issue! The issue needs to be updated: the current API is not ideal, and technically we do not want to specify the device as an Engine argument but rather inherit it from idist.device()... As it is still marked as "needs-discussion", I suggest picking another one from the list: https://github.com/pytorch/ignite/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22 What do you think?

vfdev-5 Jun 21 '21 22:06

OK, I will find another issue 👍🏻

01-vyom Jun 22 '21 03:06