
Allow passing uninitialized Accelerator when using AccelerateMixin

Open: BenjaminBossan opened this issue 2 years ago

Right now, when using AccelerateMixin, an already instantiated Accelerator object must be passed:

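from accelerate import Accelerator
from skorch import NeuralNet
from skorch.hf import AccelerateMixin

class AcceleratedNet(AccelerateMixin, NeuralNet): pass  # net with accelerate support

accelerator = Accelerator(...)
net = AcceleratedNet(
    MyModule,
    accelerator=accelerator,
)
net.fit(X, y)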

This new feature should allow passing the Accelerator class directly, with optional extra arguments specified using the double-underscore (__) notation:

net = AcceleratedNet(
    MyModule,
    accelerator=Accelerator,
    accelerator__mixed_precision='fp16',
)
net.fit(X, y)
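To illustrate how the `__` notation routes arguments, here is a minimal sketch that collects every `accelerator__*` keyword argument and strips the prefix. It assumes the keyword arguments arrive as a plain dict; the helper name `params_for_prefix` is hypothetical and not part of skorch's API, and skorch's own handling of module and optimizer prefixes may differ in detail.

```python
def params_for_prefix(prefix, params):
    """Hypothetical helper: collect 'prefix__name' entries as {'name': value}."""
    marker = prefix + "__"
    return {
        key[len(marker):]: value  # strip 'accelerator__' from the key
        for key, value in params.items()
        if key.startswith(marker)  # ignore arguments meant for other components
    }


kwargs = {
    "module": "MyModule",
    "accelerator__mixed_precision": "fp16",
    "optimizer__lr": 0.01,
}
print(params_for_prefix("accelerator", kwargs))  # {'mixed_precision': 'fp16'}
```

Arguments for other components, such as `optimizer__lr`, are left untouched, so each component only receives the parameters addressed to it.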

The change should be implemented in a backwards-compatible manner, so that passing an instantiated Accelerator continues to work.
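One way to keep both call styles working is to check whether the value passed as `accelerator` is a class or an instance. The sketch below assumes a hypothetical helper `initialize_accelerator` and uses a stand-in `FakeAccelerator` class in place of `accelerate.Accelerator`, so it runs without the library installed; neither name comes from skorch or accelerate.

```python
import inspect


class FakeAccelerator:
    """Hypothetical stand-in for accelerate.Accelerator, for illustration only."""

    def __init__(self, mixed_precision="no"):
        self.mixed_precision = mixed_precision


def initialize_accelerator(accelerator, **kwargs):
    """Hypothetical helper: return an accelerator from a class or an instance."""
    if inspect.isclass(accelerator):
        # New behavior: instantiate the class with the prefix-stripped arguments.
        return accelerator(**kwargs)
    if kwargs:
        # Assumption: extra arguments cannot be applied to an existing instance.
        raise ValueError(
            "accelerator__* arguments require passing the Accelerator class, "
            "not an instance"
        )
    # Old behavior: an already instantiated accelerator is used as-is.
    return accelerator


acc = initialize_accelerator(FakeAccelerator, mixed_precision="fp16")
print(acc.mixed_precision)  # fp16

existing = FakeAccelerator()
print(initialize_accelerator(existing) is existing)  # True
```

Raising an error when an instance is combined with `accelerator__*` arguments is only one design choice; another would be to re-create the object from `type(accelerator)` with the new arguments.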

The motivation is not so much to enable, for instance, a grid search over the Accelerator parameters, which would make little sense. Rather, it is mainly about consistency: the Accelerator could then be instantiated and configured in the same way as other components, such as the module and the optimizer.

BenjaminBossan, Apr 28 '23 15:04