
Bugfix/get device

Open · DrMicrobit opened this pull request 6 months ago • 0 comments

This PR resolves #371: M2 Mac: Runtime error in training of model after call to torchinfo.summary()

With torchinfo 1.8 on a Mac with an M2 chip, the following code resulted in a runtime error:

import torch
import torch.nn as nn
from torchinfo import summary

batch_size = 64  # any batch size reproduces the issue
device = torch.accelerator.current_accelerator()  # "mps" on an M2 Mac
model = nn.Sequential(nn.Flatten(), nn.Linear(3072, 10)).to(device)
summary(model, input_size=(batch_size, 3, 32, 32))
...
out = model(data)  # data lives on the accelerator

with the error message

RuntimeError: Tensor for argument weight is on cpu but expected on mps

The same code ran fine on Linux with an Nvidia card.

Cause of bug: In torchinfo.py, the function get_device() appears to recognise only CUDA as an accelerator, but other platforms use different backends; M-series Macs, for example, use "mps". As a result, when device= is not given in the call to summary(), torchinfo moves the model to "cpu". This then causes a runtime error during training (or evaluation) when the data is on the accelerator but the model (or parts of it) is on the CPU.
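To make the failure mode concrete, here is a minimal sketch of the CUDA-only selection logic described above. The function name and signature are hypothetical, not the actual torchinfo code; the point is only that any accelerator other than CUDA silently falls back to "cpu":

```python
def cuda_only_get_device(cuda_available: bool) -> str:
    """Hypothetical sketch: only CUDA is checked, so "mps" is never chosen."""
    return "cuda" if cuda_available else "cpu"

# On an M2 Mac there is no CUDA, so the model lands on "cpu" even though
# the training data later lives on "mps" -> device mismatch at runtime.
assumed_device = cuda_only_get_device(cuda_available=False)
print(assumed_device)  # "cpu", but the data will be on "mps"
```

This is why the error only appears later, at the first forward pass after summary(): the mismatch is created during the summary call but observed during training.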

Bug fix: I have created a PR that should fix the bug for any accelerator recognised by PyTorch.

New behaviour of get_device():

Unchanged:

  • If input_data is given, the device is not changed (to allow for multi-device models, etc.)

Changed:

  • Otherwise, return the device of the model's first parameter,
  • otherwise return the current accelerator if one is available,
  • otherwise return cpu.
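The fallback order above can be sketched as follows. This is an illustrative stand-in, not the actual PR code: the function name and the string-based device arguments are hypothetical, chosen so the ordering of the three fallbacks is easy to see.

```python
from typing import Any, Optional

def pick_device(
    input_data: Optional[Any],
    first_param_device: Optional[str],
    accelerator: Optional[str],
) -> Optional[str]:
    """Hypothetical sketch of the revised fallback order in get_device()."""
    if input_data is not None:
        return None                # leave devices untouched (multi-device models)
    if first_param_device is not None:
        return first_param_device  # follow the model's own parameters
    if accelerator is not None:
        return accelerator         # e.g. "cuda", "mps", "xpu"
    return "cpu"                   # last resort
```

With this order, a model already on "mps" stays on "mps" (second branch), and a fresh model on an M2 Mac is placed on the accelerator rather than "cpu" (third branch), which avoids the mismatch reported in #371.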

DrMicrobit — Jul 09 '25 12:07