skorch
Allow passing uninitialized Accelerator when using AccelerateMixin
Right now, when using AccelerateMixin, an already instantiated Accelerator needs to be passed:
accelerator = Accelerator(...)
net = AcceleratedNet(
    MyModule,
    accelerator=accelerator,
)
net.fit(X, y)
This new feature should allow passing the Accelerator class directly, with optional extra arguments given through the __ notation:
net = AcceleratedNet(
    MyModule,
    accelerator=Accelerator,
    accelerator__mixed_precision='fp16',
)
net.fit(X, y)
The change should be implemented in a backwards compatible manner, so that passing the instantiated Accelerator keeps on working.
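One possible way to support both call styles is sketched below. It is only an illustration of the idea, not skorch's actual implementation. The helper name resolve_accelerator and the FakeAccelerator stand-in are hypothetical, and the stand-in exists only so the snippet runs without accelerate installed. As a simplification, accelerator__ arguments are ignored when an instance is passed.

```python
import inspect


class FakeAccelerator:
    """Hypothetical stand-in for accelerate.Accelerator (keeps the sketch dependency-free)."""

    def __init__(self, mixed_precision="no"):
        self.mixed_precision = mixed_precision


def resolve_accelerator(accelerator, **net_params):
    """Hypothetical helper: return an accelerator instance from a class or an instance."""
    prefix = "accelerator__"
    # Collect the arguments addressed to the accelerator, stripping the prefix.
    kwargs = {
        key[len(prefix):]: value
        for key, value in net_params.items()
        if key.startswith(prefix)
    }
    if inspect.isclass(accelerator):
        # New behaviour: a class was passed, so instantiate it with the collected kwargs.
        return accelerator(**kwargs)
    # Old behaviour: an instance was passed and is returned unchanged.
    # Simplifying assumption of this sketch: accelerator__ kwargs are ignored here.
    return accelerator


# Class plus prefixed argument; unrelated parameters such as lr are skipped.
from_class = resolve_accelerator(
    FakeAccelerator, accelerator__mixed_precision="fp16", lr=0.01
)
# Already instantiated object, as in the current usage.
from_instance = resolve_accelerator(FakeAccelerator(mixed_precision="bf16"))
```

With this kind of dispatch, existing code that passes an instance would keep behaving as before, while the class form would pick up its configuration from the accelerator__ parameters.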
The reason for this change is not so much that we want to, for instance, grid search on the Accelerator parameters, which makes little sense. Mainly, it would be for consistency, so that the Accelerator can be instantiated and configured in the same way as other components like the module and optimizer.