accelerate
Allow the `device_placement` parameter to be overridden for different object types
I'd like it to be possible for the model (with its optimizer) and the DataLoader to use different device_placement values, i.e., I want to let accelerate handle device placement for the model, but not for the DataLoader.
Use case: I have a complicated data-loading process (e.g. using DALI), so I'd like my DataLoader to act like a Sampler, i.e. only provide indices. If I set device_placement=True, the indices provided by the DataLoader will be sent to the device, and I then need to send them back to the CPU for my data-loading process. If I set device_placement=False, I need to handle device placement for the model and optimizer myself.
It is true that I can use an identity function instead of default_collate as the collate_fn of my DataLoader, so that the indices arrive as a plain list of ints, which have no to method and are therefore ignored by send_to_device. But I think it is more explicit and elegant to make the DataLoader skip device placement than to make the data "incompatible" with send_to_device.
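To make the workaround concrete, here is a minimal, simplified model of the dispatch the paragraph above relies on: containers are walked recursively, objects with a `to` method are moved, and everything else is passed through untouched. `move_to_device` and `FakeTensor` are hypothetical names used only for illustration (the fake class stands in for `torch.Tensor` so the sketch runs without PyTorch); this is an assumption about the behavior, not accelerate's actual `send_to_device` code.

```python
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only tracks its device."""

    def __init__(self, value: Any, device: str = "cpu") -> None:
        self.value = value
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Return a copy "moved" to the requested device, like Tensor.to.
        return FakeTensor(self.value, device)


def move_to_device(data: Any, device: str) -> Any:
    """Simplified, hypothetical model of recursive device placement."""
    if isinstance(data, list):
        return [move_to_device(x, device) for x in data]
    if isinstance(data, tuple):
        return tuple(move_to_device(x, device) for x in data)
    if isinstance(data, dict):
        return {k: move_to_device(v, device) for k, v in data.items()}
    if hasattr(data, "to"):
        return data.to(device)
    # Objects without a .to method (e.g. plain ints) are returned unchanged.
    return data


# An identity collate_fn would yield the raw list of sampled indices.
indices = [3, 1, 4]
batch = {"x": FakeTensor(1.0), "idx": indices}
moved = move_to_device(batch, "cuda:0")
print(moved["x"].device)  # cuda:0
print(moved["idx"])       # [3, 1, 4]
```

Under this model the indices survive only because ints happen to lack a `to` method, which is exactly the implicit "incompatibility" the issue would rather not depend on.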
In addition, this behavior is not well documented. Only the docstring of prepare_data_loader says "put_on_device only works if the batches are nested list, tuples or dictionaries of tensors" (it actually works for any object with a to method, not only tensors), and even that is vague: what happens if it doesn't work? Does it raise an error, ignore the whole container, or ignore just the offending element? The Args spec of send_to_device says data should be a nested list/tuple/dictionary of torch.Tensor, so users might expect an error to be raised if the input doesn't satisfy the spec.
It is also true that I can manually set DataLoaderShard.device = None to achieve what I want, but that relies on the implementation, not the interface.
We could indeed support device_placement=True or a list of strings (like ["model", "optimizer"]) to allow for different placements.
Sounds good! Or maybe pass per-object device_placement overrides to prepare?
arg0, arg1, arg2 = accelerator.prepare(arg0, arg1, arg2, device_placement=[True, True, False])
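A per-argument list like the one proposed above could be resolved roughly as follows. This is a hypothetical sketch: `resolve_device_placement`, `prepare_sketch`, and the `move` callback are illustrative names, not accelerate's API, and the sketch makes no claim about how the eventual PR implements it. It assumes one boolean per positional argument, falling back to a single default when no list is given.

```python
from typing import Any, Callable, List, Optional, Sequence


def resolve_device_placement(num_objects: int,
                             device_placement: Optional[Sequence[bool]] = None,
                             default: bool = True) -> List[bool]:
    """Return one placement flag per object (hypothetical helper)."""
    if device_placement is None:
        return [default] * num_objects
    flags = list(device_placement)
    if len(flags) != num_objects:
        raise ValueError(
            f"Got {len(flags)} device_placement values for {num_objects} objects"
        )
    return flags


def prepare_sketch(*objects: Any,
                   device_placement: Optional[Sequence[bool]] = None,
                   move: Callable[[Any], Any] = lambda obj: obj,
                   default: bool = True) -> tuple:
    """Apply `move` only to the objects whose flag is True."""
    flags = resolve_device_placement(len(objects), device_placement, default)
    return tuple(move(obj) if flag else obj for obj, flag in zip(objects, flags))


# Record which objects would have been placed on the device.
moved = []


def record(obj: Any) -> Any:
    moved.append(obj)
    return obj


model, optimizer, loader = prepare_sketch(
    "model", "optimizer", "loader",
    device_placement=[True, True, False], move=record,
)
print(moved)  # ['model', 'optimizer']
```

Validating the list length up front makes a mismatched call fail loudly instead of silently placing the wrong objects.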
Sorry, it took me a bit of time but the PR above introduces the API you requested!