Add support for MPS backend
The current class `TorchModel` has the following `__init__`:
```python
class TorchModel(ABC, ToStringMixin):
    """
    sensAI abstraction for torch models, which supports one-line training, allows for convenient model application,
    has basic mechanisms for data scaling, and soundly handles persistence (via pickle).
    An instance wraps a torch.nn.Module, which is constructed on demand during training via the factory method
    createTorchModule.
    """
    log: logging.Logger = log.getChild(__qualname__)

    def __init__(self, cuda=True) -> None:
        self.cuda: bool = cuda
        self.module: Optional[torch.nn.Module] = None
        self.outputScaler: Optional[TensorScaler] = None
        self.inputScaler: Optional[TensorScaler] = None
        self.trainingInfo: Optional[TrainingInfo] = None
        self._gpu: Optional[int] = None
        self._normalisationCheckThreshold: Optional[int] = 5
```
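To make the discussion concrete, here is a minimal sketch of how a generic device specification could be introduced alongside the existing `cuda` flag; the `device` parameter and the helper `_resolve_torch_device` are hypothetical and not part of the current sensAI API, and the fallback policy shown is just one possible choice:

```python
from typing import Optional, Union

import torch


def _resolve_torch_device(cuda: bool, device: Optional[Union[str, torch.device]]) -> torch.device:
    """Hypothetical helper: map the legacy `cuda` flag and an optional explicit
    device specification to a torch.device, preferring the explicit value."""
    if device is not None:
        return torch.device(device)
    if cuda and torch.cuda.is_available():
        return torch.device("cuda")
    # one possible fallback policy: use MPS on Apple machines if it is available
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```

With such a helper, `__init__` could accept an additional `device=None` parameter, so existing callers that only pass `cuda` would keep working unchanged.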
and it is responsible for putting the inputs of the torch model on the corresponding device (here):
```python
if self._is_cuda_enabled():
    torch.cuda.set_device(self._gpu)
    inputs = [t.cuda() for t in inputs]
```
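A backend-agnostic variant of this placement logic could rely on `Tensor.to` with a resolved `torch.device` instead of the CUDA-specific calls; this is only a sketch of the idea, assuming a hypothetical `self.device` attribute that the current implementation does not have:

```python
# Sketch: move inputs to whatever device was resolved for the model
# (self.device is a hypothetical torch.device attribute, e.g. cuda, mps or cpu).
if self.device.type == "cuda":
    torch.cuda.set_device(self.device)
inputs = [t.to(self.device) for t in inputs]
```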
I would like to suggest including support for different torch backends, in particular the MPS backend for Apple machines.
My first impression is that this could be a breaking change, so let's discuss it here.