
ImportError: cannot import name 'GradScaler' from 'torch'

Open aymuos15 opened this issue 9 months ago • 3 comments

Traceback (most recent call last):
  File "/usr/local/bin/nnUNetv2_train", line 5, in <module>
    from nnunetv2.run.run_training import run_training_entry
  File "/usr/local/lib/python3.10/dist-packages/nnunetv2/run/run_training.py", line 13, in <module>
    from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
  File "/usr/local/lib/python3.10/dist-packages/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py", line 43, in <module>
    from torch import GradScaler
ImportError: cannot import name 'GradScaler' from 'torch' (/usr/local/lib/python3.10/dist-packages/torch/__init__.py)

My Dockerfile uses this base image: nvcr.io/nvidia/pytorch:24.01-py3. The only other thing I install is the library itself.

Any suggestions to fix?

aymuos15 avatar Mar 17 '25 17:03 aymuos15

I set the GradScaler param to None for now. I'd still like a better suggestion, though, if possible.

aymuos15 avatar Mar 17 '25 19:03 aymuos15

Hi @aymuos15, what's your torch version? Can you try upgrading it to the latest stable release and try again?

seziegler avatar Mar 24 '25 15:03 seziegler

Hello @seziegler! I came across the same issue using the repo's default pyproject.toml. The minimum PyTorch version there is set to v2.1.2, while the top-level GradScaler import used in nnUNetTrainer is only available in core PyTorch from v2.3.0 (cf. torch/__init__.py).

It would be nice to update pyproject.toml for users who forget to install PyTorch first. Do I need to open an issue for this?
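For anyone who wants to check whether an installed torch is new enough before digging further, here is a minimal sketch. The helper names are hypothetical, and the 2.3 threshold is simply the one stated above rather than something verified against every release. It only parses the version string, so it does not need torch itself to run:

```python
import re


def version_tuple(version: str) -> tuple[int, int]:
    """Extract (major, minor) from a version string such as '2.3.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def has_top_level_gradscaler(torch_version: str) -> bool:
    """Assumption: `from torch import GradScaler` works from torch 2.3 onwards."""
    return version_tuple(torch_version) >= (2, 3)
```

With torch installed, `has_top_level_gradscaler(torch.__version__)` should tell you which side of the threshold you are on. Pre-release and local suffixes (for example `a0` or `+cu121`) are ignored by the parser.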

rcremese avatar Jun 02 '25 12:06 rcremese

Note that git blame on the import line shows that the syntax changed in https://github.com/MIC-DKFZ/nnUNet/commit/74799d51be224c6a3ea37e71c94ffcc069b5d690, which was first included in nnunetv2==2.6.0.

This means that if you want to use an older version of torch (<2.3), you can't use newer versions of nnunetv2. This is inconvenient, because torch==2.2.2 was the last version to support Intel Macs, meaning that nnunetv2==2.6.0 broke compatibility with a lot of older machines (likely to be in use by medical professionals without the funds to upgrade to M1 Macs). Even though training is costly, inference is still possible (albeit slow) on these kinds of machines, so downstream nnunetv2 users who deploy models to users with older machines get bitten here.

Perhaps it would have been possible to do something like:

try:
    from torch import GradScaler           # torch >= 2.3
except ImportError:
    from torch.cuda.amp import GradScaler  # torch < 2.3

That way nnunetv2 would have stayed backwards-compatible with torch for Intel Macs.
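The same fallback idea can be sketched more generally as a small helper that tries several module paths in order. This is only an illustration: `import_first_available` is a hypothetical name, not something nnunetv2 or torch provides, and the torch usage in the trailing comment assumes torch is installed.

```python
import importlib


def import_first_available(name, module_paths):
    """Return attribute `name` from the first module in `module_paths` that provides it."""
    errors = []
    for module_path in module_paths:
        try:
            module = importlib.import_module(module_path)
            return getattr(module, name)
        except (ImportError, AttributeError) as exc:
            # Remember why this candidate failed, then try the next one.
            errors.append(f"{module_path}: {exc}")
    raise ImportError(f"cannot import {name!r}; tried: " + "; ".join(errors))


# Hypothetical usage, assuming torch is installed:
# GradScaler = import_first_available("GradScaler", ["torch", "torch.cuda.amp"])
```

The plain try/except is shorter for a single import. A helper like this mostly pays off when several names need the same treatment, and it reports every location that was tried if none of them works.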

joshuacwnewton avatar Jun 26 '25 18:06 joshuacwnewton