tabnet
Adding Support for Apple Silicon Mac GPU
Feature request
- I would like to make this library compatible with Apple's GPU, and doing so requires only a two-line change.
What is the expected behavior?
- Currently, training on Apple's GPU almost works by setting device_name to "mps". However, at the end of training, when the TabModel.explain method is called, it raises an error.
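For reference, here is a minimal sketch of how a caller might choose device_name at runtime, preferring CUDA, then Apple's MPS backend, then the CPU. The helper name pick_device_name is hypothetical and is not part of pytorch_tabnet. The sketch assumes PyTorch 1.12 or later, where torch.backends.mps.is_available() exists.

```python
import torch


def pick_device_name() -> str:
    """Return a device string to pass as TabNet's device_name argument.

    Hypothetical helper: prefers CUDA, then Apple's MPS backend, then CPU.
    """
    if torch.cuda.is_available():
        return "cuda"
    # True only when this PyTorch build supports MPS and the OS/hardware
    # can actually use it (e.g. an Apple Silicon Mac).
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


if __name__ == "__main__":
    print(pick_device_name())
```

The result could then be passed through, e.g. TabNetRegressor(device_name=pick_device_name(), ...), so the same script runs on Macs and on CUDA machines.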
Specifically, if I start training with the following model initialization:

    regressor = TabNetRegressor(
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=2e-2),
        device_name="mps",
        mask_type="entmax",
    )

then line 354 of TabModel.explain raises:

    TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

The error comes from the following line. It occurs because data is a float64 tensor, while the MPS backend on Apple's GPU does not support float64 and requires float32 instead.
https://github.com/dreamquark-ai/tabnet/blob/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/pytorch_tabnet/abstract_model.py#L353-L356
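The float64 data most likely originates from the input itself. NumPy arrays default to float64, and converting a NumPy array to a tensor (torch.from_numpy, or the default DataLoader collation) preserves that dtype. The sketch below assumes X is a NumPy array passed to explain. It shows the dtype that would reach the device and a possible user-side workaround of casting X to float32 beforehand. It runs on the CPU only and has not been tested against pytorch_tabnet itself.

```python
import numpy as np
import torch

# NumPy's default floating-point dtype is float64.
X = np.random.rand(4, 3)
print(X.dtype)  # float64

# Converting to a tensor keeps float64, which the MPS backend rejects.
batch64 = torch.from_numpy(X)
print(batch64.dtype)  # torch.float64

# Possible user-side workaround (an assumption, not the proposed fix):
# cast the input to float32 first, so tensors built from it are float32.
X32 = X.astype(np.float32)
batch32 = torch.from_numpy(X32)
print(batch32.dtype)  # torch.float32
```

In practice this would mean calling something like regressor.explain(X.astype(np.float32)). However, it does not help when explain is invoked internally during fit on the original data, which is why a fix inside the library is preferable.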
What is motivation or use case for adding/changing the behavior?
- I believe that training on the GPU would benefit users of Apple computers.
How should this be implemented in your opinion?
- I confirmed that adding the two lines below to TabModel.explain solves the issue.
 for batch_nb, data in enumerate(dataloader):
+    if self.device == torch.device("mps"):
+        data = data.to(torch.float32)
     data = data.to(self.device).float()
     M_explain, masks = self.network.forward_masks(data)
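One alternative worth considering, offered as a suggestion rather than as part of the original proposal, is to swap the order of the two calls: data.float().to(self.device). This converts the tensor to float32 while it is still on the CPU, so a float64 tensor is never sent to MPS. It should also make the explicit device check unnecessary, and the final result on CUDA and CPU is unchanged. The standalone sketch below uses a hypothetical helper name, move_batch, and runs on the CPU.

```python
import torch


def move_batch(data: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Cast to float32 first, then move to the target device.

    Casting on the source device means a float64 tensor is never sent to
    a backend such as MPS that cannot hold float64.
    """
    return data.float().to(device)


if __name__ == "__main__":
    batch = torch.rand(2, 3, dtype=torch.float64)
    moved = move_batch(batch, torch.device("cpu"))
    print(moved.dtype, moved.device)  # torch.float32 cpu
```

With this ordering, the same line works for every device, which keeps the explain loop free of backend-specific branches.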
Are you willing to work on this yourself? yes
Hello @NMZ0429,
Thanks for this proposal, feel free to open a PR!