ImportError: cannot import name 'GradScaler' from 'torch'
```
Traceback (most recent call last):
  File "/usr/local/bin/nnUNetv2_train", line 5, in <module>
    from nnunetv2.run.run_training import run_training_entry
  File "/usr/local/lib/python3.10/dist-packages/nnunetv2/run/run_training.py", line 13, in <module>
    from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
  File "/usr/local/lib/python3.10/dist-packages/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py", line 43, in <module>
    from torch import GradScaler
ImportError: cannot import name 'GradScaler' from 'torch' (/usr/local/lib/python3.10/dist-packages/torch/__init__.py)
```
My Dockerfile uses the base image nvcr.io/nvidia/pytorch:24.01-py3.
The only other thing I install is the library itself.
Any suggestions to fix?
For now, I have set the grad scaler to None as a workaround, but I would like a better solution if possible.
Hi @aymuos15, what's your torch version? Can you upgrade it to the latest stable release and try again?
Hello @seziegler!
I came across the same issue using the repo's default pyproject.toml. It sets the minimum PyTorch version to 2.1.2, but the top-level GradScaler import in nnUNetTrainer only works from torch 2.3.0 onward, which is when GradScaler started being exported from the core library (cf. torch/__init__.py).
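To check which side of the 2.3.0 boundary a given environment falls on, here is a minimal sketch. It assumes version strings of the form major.minor[.patch] followed by optional suffixes such as a0 or +cu121, and it ignores those suffixes, so a 2.3 pre-release counts as 2.3.0. The function names are hypothetical helpers for illustration, not part of torch or nnunetv2.

```python
import re
from typing import Tuple


def parse_torch_version(version: str) -> Tuple[int, int, int]:
    """Return (major, minor, patch) from strings such as '2.2.2' or '2.3.1+cu121'.

    Pre-release and local suffixes (e.g. 'a0', '+cu121') are ignored;
    a missing patch component is treated as 0.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized torch version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def has_top_level_grad_scaler(version: str) -> bool:
    """True if `from torch import GradScaler` is expected to work (torch >= 2.3.0)."""
    return parse_torch_version(version) >= (2, 3, 0)
```

In a real environment you would pass torch.__version__ to has_top_level_grad_scaler. The third-party packaging.version.Version class parses versions more completely, at the cost of an extra dependency.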
It would be nice to update pyproject.toml so that the minimum torch version matches, for users who forget to install PyTorch first.
Should I open a separate issue for this?
Note that running git blame on the import line shows that the syntax changed with https://github.com/MIC-DKFZ/nnUNet/commit/74799d51be224c6a3ea37e71c94ffcc069b5d690, which was first included in nnunetv2==2.6.0.
This means that if you want to use an older version of torch (<2.3), you can't use newer versions of nnunetv2. This is inconvenient because torch==2.2.2 was the last version to support Intel Macs, so nnunetv2==2.6.0 broke compatibility with a lot of older machines (which are likely to be in use by medical professionals without the funds to upgrade to M1 Macs). Even though training is costly, inference is still possible (albeit slow) on these machines, so downstream nnunetv2 users who deploy models to people with older machines get bitten here.
Perhaps it would have been possible to do something like:
```python
try:
    from torch import GradScaler  # torch >= 2.3
except ImportError:
    from torch.cuda.amp import GradScaler  # torch < 2.3
```
That way, nnunetv2 would have stayed compatible with the older torch releases that still run on Intel Macs.
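One caveat with an import-only fallback: torch.GradScaler (torch >= 2.3) takes device as its first parameter, while the legacy torch.cuda.amp.GradScaler takes init_scale first and only targets CUDA. If a call site passes the device positionally, for example GradScaler("cuda"), the legacy class would misread it as init_scale. The sketch below is a slightly more defensive variant, assuming you control the call sites; import_first and make_grad_scaler are hypothetical helpers, not part of torch or nnunetv2.

```python
import importlib
from typing import Any, Sequence, Tuple


def import_first(candidates: Sequence[Tuple[str, str]]) -> Tuple[str, Any]:
    """Return (module_name, obj) for the first (module, attribute) pair that imports."""
    failures = []
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError as exc:  # also covers ModuleNotFoundError
            failures.append(f"{module_name}: {exc}")
            continue
        if hasattr(module, attr):
            return module_name, getattr(module, attr)
        failures.append(f"{module_name} has no attribute {attr!r}")
    raise ImportError("no candidate could be imported: " + "; ".join(failures))


def make_grad_scaler(device: str = "cuda", **kwargs: Any) -> Any:
    """Hypothetical helper: build a GradScaler on both torch >= 2.3 and older torch.

    The new class takes `device` first; the legacy CUDA-only class takes
    `init_scale` first, so the device is only forwarded to the new class.
    """
    source, scaler_cls = import_first(
        [("torch", "GradScaler"), ("torch.cuda.amp", "GradScaler")]
    )
    if source == "torch":
        return scaler_cls(device, **kwargs)
    if device != "cuda":
        raise ValueError(f"torch < 2.3 only provides a CUDA GradScaler, got device={device!r}")
    return scaler_cls(**kwargs)
```

With this, a trainer would call make_grad_scaler("cuda") instead of constructing GradScaler directly, and keyword arguments such as init_scale pass through unchanged on either torch version.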