Bugfix/get device
This PR resolves #371: M2 Mac: Runtime error in training of model after call to torchinfo.summary()
For torchinfo 1.8 on a Mac with an M2 chip, the following code resulted in a runtime error:
```python
import torch
from torch import nn
from torchinfo import summary

device = torch.accelerator.current_accelerator()
model = nn.Sequential(nn.Flatten(), nn.Linear(3072, 10)).to(device)
summary(model, input_size=(batch_size, 3, 32, 32))
...
out = model(data)
```
with the error message

```
RuntimeError: Tensor for argument weight is on cpu but expected on mps
```
The same code ran fine on Linux with an Nvidia card.
Cause of the bug:
In torchinfo.py, the function get_device() only recognises CUDA as an accelerator, whereas other platforms may use different accelerators; for example, M-chip Macs use "mps". As a result, when device= is not given in the call to summary(), torchinfo pushes the model to the CPU, which then causes a runtime error during training (or evaluation) when the data is on the accelerator but the model (or parts of it) is on the CPU.
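A minimal sketch of this failure mode (illustrative only, not torchinfo's exact code):

```python
import torch

# CUDA-only device selection: on an M2 Mac, torch.cuda.is_available() is
# False, so this falls back to "cpu" even though "mps" is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```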
Bug fix: This PR should fix the bug for any accelerator recognised by PyTorch.
New behaviour of get_device():

Unchanged:
- If input_data is given, the device is not changed (to allow for multi-device models, etc.).

Changed:
- Otherwise, the device of the model's first parameter is returned,
- otherwise, the current accelerator is returned if one is available,
- otherwise, cpu is returned (see the sketch after this list).
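A minimal sketch of the new logic, using PyTorch's torch.accelerator API; names and structure are simplified and illustrative, not the exact implementation in this PR:

```python
import torch
from torch import nn


def get_device(model: nn.Module, input_data) -> torch.device | None:
    # If input_data is given, do not change the device
    # (allows multi-device models, etc.).
    if input_data is not None:
        return None
    # Otherwise use the device of the model's first parameter.
    try:
        return next(model.parameters()).device
    except StopIteration:
        pass  # model has no parameters
    # Otherwise fall back to the current accelerator, if one is available,
    if torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    # and finally to the CPU.
    return torch.device("cpu")
```

With this, calling summary() without device= on an M2 Mac resolves to "mps" instead of "cpu", so the model and the training data end up on the same device.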