sensAI

Add support for MPS backend

Open • schroedk opened this issue 10 months ago • 3 comments

The class TorchModel currently has the following __init__:

class TorchModel(ABC, ToStringMixin):
    """
    sensAI abstraction for torch models, which supports one-line training, allows for convenient model application,
    has basic mechanisms for data scaling, and soundly handles persistence (via pickle).
    An instance wraps a torch.nn.Module, which is constructed on demand during training via the factory method
    createTorchModule.
    """
    log: logging.Logger = log.getChild(__qualname__)

    def __init__(self, cuda=True) -> None:
        self.cuda: bool = cuda
        self.module: Optional[torch.nn.Module] = None
        self.outputScaler: Optional[TensorScaler] = None
        self.inputScaler: Optional[TensorScaler] = None
        self.trainingInfo: Optional[TrainingInfo] = None
        self._gpu: Optional[int] = None
        self._normalisationCheckThreshold: Optional[int] = 5

It is also responsible for moving the inputs of the torch model to the corresponding device (here):

if self._is_cuda_enabled():
    torch.cuda.set_device(self._gpu)
    inputs = [t.cuda() for t in inputs]
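A minimal sketch of a device-agnostic alternative, assuming PyTorch 1.12 or later (the release that introduced the MPS backend). `select_device` and `move_inputs` are hypothetical helper names, not part of sensAI. The key point is that `Tensor.to(device)` works for CUDA, MPS and CPU alike, whereas `Tensor.cuda()` only targets CUDA.

```python
from typing import List, Optional

import torch


def select_device(preferred: Optional[str] = None) -> torch.device:
    """Hypothetical helper: honour an explicit device, else fall back CUDA -> MPS -> CPU."""
    if preferred is not None:
        # e.g. "cuda:1", "mps" or "cpu"
        return torch.device(preferred)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps is missing in PyTorch versions before 1.12, hence the guard
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def move_inputs(inputs: List[torch.Tensor], device: torch.device) -> List[torch.Tensor]:
    # Tensor.to works for every backend, unlike Tensor.cuda(), which is CUDA-only
    return [t.to(device) for t in inputs]
```

Passing an explicit string such as "cuda:1" would take over the role that `torch.cuda.set_device(self._gpu)` plays today in selecting a specific GPU.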

I would like to suggest adding support for other torch backends, in particular the MPS (Metal Performance Shaders) backend for Apple machines.

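My first impression is that this could be a breaking change, since the device choice is currently encoded in the boolean cuda constructor argument and attribute, so let's discuss it here.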
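One way to avoid a breaking change, sketched under the assumption that a new optional `device` argument is added next to the existing `cuda` flag. `TorchModelSketch` is a hypothetical, stripped-down stand-in for `TorchModel`, and storing the device as a string is an illustrative choice, not an agreed design. Omitting `device` keeps today's behaviour, and `__setstate__` derives `device` for objects pickled before the change, which only carry `cuda`.

```python
from typing import Optional


class TorchModelSketch:
    """Hypothetical stand-in for TorchModel, reduced to device handling only."""

    def __init__(self, cuda: bool = True, device: Optional[str] = None) -> None:
        if device is None:
            # Legacy behaviour: the boolean flag chooses between CUDA and CPU
            device = "cuda" if cuda else "cpu"
        # An explicit device (e.g. "mps" or "cuda:1") takes precedence over the flag
        self.device: str = device
        # Keep the old attribute so code that reads `model.cuda` still gets a bool
        self.cuda: bool = device.startswith("cuda")

    def __setstate__(self, state: dict) -> None:
        # Objects pickled before the change only store `cuda`; derive `device` from it
        if "device" not in state:
            state["device"] = "cuda" if state.get("cuda", False) else "cpu"
        self.__dict__.update(state)
```

Existing calls such as `TorchModelSketch(cuda=False)` keep working unchanged, while `TorchModelSketch(device="mps")` opts into the new backend.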

schroedk, Mar 26 '24 16:03