
Let the `device_placement` parameter be overridden per object type

YouJiacheng opened this issue 3 years ago • 2 comments

I'd like it to be possible for the model (and optimizer) and the DataLoader to use different `device_placement` values, i.e. I want to let accelerate handle device placement for the model, but not for the dataloader.

Use case: I have a complicated data-loading process (e.g. using DALI), so I'd like my DataLoader to act like a Sampler, i.e. to only provide indices. If I set `device_placement=True`, the indices provided by the DataLoader are sent to the device, and I then have to send them back to the CPU for my data-loading process. If I set `device_placement=False`, I have to handle device placement for the model and optimizer myself.

It is true that I can use the identity function instead of `default_collate` as the `collate_fn` of my DataLoader, so that the indices, being a list of ints with no `to` method, are ignored by `send_to_device`. But I think it is more explicit and elegant to make the DataLoader skip device placement than to make the data "incompatible" with `send_to_device`.

In addition, this behavior is not well documented: only the docstring of `prepare_data_loader` says "put_on_device only works if the batches are nested list, tuples or dictionaries of tensors" (and it actually works for any object with a `to` method, not only tensors). That phrasing is also vague: what happens when it doesn't work — is an error raised, is the whole container ignored, or is a single element ignored? The Args spec of `send_to_device` says `data` should be a nested list/tuple/dictionary of `torch.Tensor`, so users might expect an error to be raised if the input doesn't satisfy the spec.
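To make the undocumented behavior concrete, here is a minimal, hypothetical stand-in for `send_to_device` (the real helper lives in `accelerate.utils`; `FakeTensor` is invented for illustration and this sketch only mimics the recursive structure): anything with a `to` method is moved, and anything else is silently left in place rather than raising an error.

```python
class FakeTensor:
    """Stand-in for torch.Tensor: has a .to(device) method."""
    def __init__(self, data, device="cpu"):
        self.data, self.device = data, device
    def to(self, device):
        return FakeTensor(self.data, device)

def send_to_device(obj, device):
    """Simplified mimic of accelerate.utils.send_to_device: recurse into
    lists/tuples/dicts, call .to on anything that has it, and silently
    leave everything else unchanged (no error is raised)."""
    if isinstance(obj, (list, tuple)):
        return type(obj)(send_to_device(o, device) for o in obj)
    if isinstance(obj, dict):
        return {k: send_to_device(v, device) for k, v in obj.items()}
    if hasattr(obj, "to"):
        return obj.to(device)
    return obj  # plain ints, strings, ... are ignored element-wise

# With default_collate, indices become a tensor-like batch -> moved:
moved = send_to_device(FakeTensor([3, 1, 4]), "cuda:0")
print(moved.device)  # cuda:0 — must be sent back to CPU for DALI

# With an identity collate_fn, indices stay a plain list -> untouched:
print(send_to_device([3, 1, 4], "cuda:0"))  # [3, 1, 4]
```

This shows the answer to the question above: placement fails per element, silently, rather than raising on the whole container.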

It is also true that I can manually set `DataLoaderShard.device = None` to achieve what I want, but that depends on the implementation, not the interface.

YouJiacheng avatar Aug 10 '22 10:08 YouJiacheng

We could indeed support `device_placement=True` or a list of strings (like `["model", "optimizer"]`) to allow for different placements.

sgugger avatar Aug 10 '22 12:08 sgugger

Sounds good! Or maybe pass overridden `device_placement` values in `prepare`?

```python
arg0, arg1, arg2 = accelerator.prepare(arg0, arg1, arg2, device_placement=[True, True, False])
```

YouJiacheng avatar Aug 10 '22 13:08 YouJiacheng

Sorry, it took me a bit of time but the PR above introduces the API you requested!

sgugger avatar Sep 23 '22 14:09 sgugger