
M2 Mac: Runtime error in training of model after call to torchinfo.summary()

Open: DrMicrobit opened this issue 6 months ago.

Describe the bug

For torchinfo 1.8 on a Mac with an M2 chip, the following code raised a runtime error:

import torch
from torch import nn
from torchinfo import summary
device = torch.accelerator.current_accelerator()
model = nn.Sequential(nn.Flatten(), nn.Linear(3072, 10)).to(device)
summary(model, input_size=(batch_size, 3, 32, 32))
...
out = model(data)

with the error message:

RuntimeError: Tensor for argument weight is on cpu but expected on mps

The same code ran fine on Linux with an NVIDIA card.

Expected behavior

No runtime error.

Desktop

  • OS: macOS Sequoia 15.5
  • CPU: M2 chip ("Apple silicon")
  • torchinfo version: 1.8

Quick workarounds

  1. (preferred) Pass device= to torchinfo.summary(), e.g. summary(model, input_size=(batch_size, 3, 32, 32), device=device)
  2. (also works) Move the model back to the device with model = model.to(device) after calling torchinfo.summary()
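To check whether workaround 2 is actually needed, it helps to list which parameters and buffers are not on the intended device after summary() returns. The sketch below uses only standard torch.nn.Module methods (named_parameters, named_buffers, to); the helper names same_device, misplaced_tensors and ensure_on_device are hypothetical and not part of torchinfo or PyTorch.

```python
import torch
from torch import nn


def same_device(a: torch.device, b: torch.device) -> bool:
    """Compare devices, treating a missing index (e.g. "mps") as matching "mps:0"."""
    if a.type != b.type:
        return False
    return a.index is None or b.index is None or a.index == b.index


def misplaced_tensors(model: nn.Module, device: torch.device) -> list[str]:
    """Return the names of parameters and buffers that are NOT on `device`."""
    named = list(model.named_parameters()) + list(model.named_buffers())
    return [name for name, t in named if not same_device(t.device, device)]


def ensure_on_device(model: nn.Module, device: torch.device) -> nn.Module:
    """Hypothetical helper mirroring workaround 2: move the model back only if needed."""
    stray = misplaced_tensors(model, device)
    if stray:
        print(f"moving {len(stray)} tensor(s) back to {device}: {stray}")
        model = model.to(device)
    return model
```

Called as model = ensure_on_device(model, device) right after summary(), it reports any tensors that ended up off the accelerator and moves the model back; when everything is already in place it returns the model unchanged.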

Cause of bug

In torchinfo.py, the function get_device() appears to recognise only CUDA as an accelerator, whereas other platforms may have different accelerators; M-series Macs, for example, use "mps". As a result, when device= is not passed to summary(), torchinfo apparently moves the model to the CPU. Training (or evaluation) then fails with a runtime error, because the data is on the accelerator while the model (or parts of it) is on the CPU.
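An accelerator-aware device lookup could look roughly like the sketch below. It is not torchinfo's code and not the code in the PR; the function name get_device_sketch and the fallback order (explicit device, then a non-CPU parameter device, then PyTorch's current accelerator, then CPU) are assumptions chosen to generalise a CUDA-only check. It relies on torch.accelerator, which exists only in recent PyTorch releases (the reporter's own snippet already uses it).

```python
import torch
from torch import nn


def get_device_sketch(model: nn.Module, device=None) -> torch.device:
    """Hypothetical accelerator-aware replacement for a CUDA-only device lookup."""
    # 1. An explicitly requested device always wins.
    if device is not None:
        return device if isinstance(device, torch.device) else torch.device(device)
    # 2. If the model already lives on any non-CPU device (cuda, mps, xpu, ...),
    #    keep it there instead of only checking for CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None and first_param.device.type != "cpu":
        return first_param.device
    # 3. Otherwise prefer whatever accelerator PyTorch reports, if any.
    accel = getattr(torch, "accelerator", None)  # absent in older PyTorch versions
    if accel is not None and accel.is_available():
        return accel.current_accelerator()
    # 4. Fall back to the CPU.
    return torch.device("cpu")
```

With the model from the report already on "mps", step 2 returns that device, so the model is not moved to the CPU.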

Bug fix

I have created a PR that should fix the bug for any accelerator recognised by PyTorch.

DrMicrobit, Jul 09 '25 12:07