Adding Support for Apple Silicon Mac GPU

Open • NMZ0429 opened this issue 1 year ago • 1 comment

Feature request

  • I would like to make this library compatible with Apple-silicon GPUs (PyTorch's MPS backend); only two lines of code need to be added.

What is the expected behavior?

  • Currently, training on Apple's GPU almost works when device_name is set to "mps". However, at the end of training, when the TabModel.explain method is called (fit calls it to compute feature importances), it raises an error.

  • Specifically, if I start training with the model initialized as follows, line 354 of TabModel.explain raises TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

    import torch
    from pytorch_tabnet.tab_model import TabNetRegressor

    regressor = TabNetRegressor(
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=2e-2),
        device_name="mps",
        mask_type="entmax",
    )

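    The error comes from the lines linked below. The input data is float64, and because data.to(self.device) runs before .float(), the float64 tensor is copied to the MPS device before it is cast. The MPS backend does not support float64, so that copy fails. https://github.com/dreamquark-ai/tabnet/blob/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/pytorch_tabnet/abstract_model.py#L353-L356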
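To see where the float64 comes from and why the order of the calls matters, here is a minimal sketch assuming only PyTorch and NumPy are installed. The helper to_device_as_float32 is hypothetical (not part of pytorch_tabnet), and casting before the transfer is a different approach from the explicit MPS check proposed in this issue. NumPy arrays default to float64 and torch.from_numpy keeps that dtype, so casting with .float() before .to(device) means only float32 data is ever copied to the device.

```python
import numpy as np
import torch


def to_device_as_float32(batch: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Hypothetical helper: cast to float32 on the CPU *before* moving to `device`.

    `batch.to(device).float()` would try to copy a float64 tensor to MPS first,
    which fails; `batch.float().to(device)` only ever copies float32 data.
    """
    return batch.float().to(device)


# NumPy defaults to float64, and torch.from_numpy preserves that dtype.
x = np.random.rand(4, 3)
batch = torch.from_numpy(x)
print(batch.dtype)  # torch.float64

# Use MPS when it is available, otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
moved = to_device_as_float32(batch, device)
print(moved.dtype, moved.device.type)
```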

What is motivation or use case for adding/changing the behavior?

  • I believe that training on the GPU would benefit users of Apple-silicon Macs.

How should this be implemented in your opinion?

  • I confirmed that adding the two lines below to TabModel.explain solves the issue:
     for batch_nb, data in enumerate(dataloader):
    +    if self.device == torch.device("mps"):
    +        data = data.to(torch.float32)
         data = data.to(self.device).float()

         M_explain, masks = self.network.forward_masks(data)
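The patched loop can be exercised outside pytorch_tabnet with a small stand-in. This sketch assumes only PyTorch; run_explain_loop and fake_forward_masks are hypothetical stand-ins for TabModel.explain and self.network.forward_masks, not library functions. It applies the same torch.device("mps") check as the proposed patch, so on an Apple-silicon Mac the float64 batches are cast before they are moved, and elsewhere it simply runs on the CPU.

```python
import torch
from torch.utils.data import DataLoader


def fake_forward_masks(data: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for self.network.forward_masks;
    # it returns a tensor with the same dtype as its input.
    return data.abs()


def run_explain_loop(dataloader: DataLoader, device: torch.device) -> list:
    dtypes = []
    # batch_nb is unused here; kept to mirror the loop in TabModel.explain.
    for batch_nb, data in enumerate(dataloader):
        # The two lines proposed in the issue: cast on the CPU first,
        # because a float64 tensor cannot be copied to MPS.
        if device == torch.device("mps"):
            data = data.to(torch.float32)
        data = data.to(device).float()

        out = fake_forward_masks(data)
        dtypes.append(out.dtype)
    return dtypes


# float64 features; a tensor can be passed directly to DataLoader as a dataset.
features = torch.rand(10, 5, dtype=torch.float64)
loader = DataLoader(features, batch_size=4, shuffle=False)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(run_explain_loop(loader, device))  # [torch.float32, torch.float32, torch.float32]
```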

Are you willing to work on this yourself? yes

NMZ0429 • Dec 01 '24 15:12

Hello @NMZ0429,

Thanks for this proposal; feel free to open a PR!

Optimox • Dec 14 '24 09:12