torchmetrics Accuracy() fails if `get_metrics()` is called before `test_on_dataset`
Describe the bug
The torchmetrics Accuracy() class raises `RuntimeError: You have to have determined mode.` if `wrapper.get_metrics()` is called before `wrapper.test_on_dataset`, or if `wrapper.test_on_dataset` is never called at all.
In contrast, Baal's Accuracy() class handles this by returning 'test_accuracy': nan.
To Reproduce
In this gist, a LeNet-5 model with MC Dropout is trained on the entire MNIST data for 1 epoch (no acquisitions).
By default, the script uses Baal's Accuracy() class; the option --torchmetrics additionally registers the torchmetrics Accuracy() class. Also by default, the script evaluates on the test set; the option --no-test skips this evaluation.
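The gist itself is not reproduced here, so the following is only a hypothetical reconstruction of how the two flags might be wired up with `argparse`. The names `parse_args` and `metric_names` are illustrative, and the metric keys are simply the ones that appear in the script's printed output.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical reconstruction of the script's command-line flags."""
    parser = argparse.ArgumentParser(description="Baal vs. torchmetrics Accuracy repro")
    # Default: only Baal's Accuracy() is registered.
    parser.add_argument("--torchmetrics", action="store_true",
                        help="also register torchmetrics' Accuracy()")
    # Default: the script evaluates on the test set after training.
    parser.add_argument("--no-test", dest="no_test", action="store_true",
                        help="skip wrapper.test_on_dataset")
    return parser.parse_args(argv)


def metric_names(args):
    """Metric keys expected from get_metrics() for a given flag combination."""
    names = ["train_accuracy", "test_accuracy"]
    if args.torchmetrics:
        names += ["train_torch_accuracy", "test_torch_accuracy"]
    return names


args = parse_args(["--torchmetrics", "--no-test"])
print(args.torchmetrics, args.no_test)  # True True
print(metric_names(args))
```

With both flags set, the test-set metrics are registered but never updated, which is exactly the combination that fails.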
Running `python baal_error_torchmetrics.py --no-test`:
/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py:45: UserWarning: You have chosen to seed training. This will turn on the CUDNN deterministic setting, which can slow down your training considerably!
set_seed(config.seed)
Use GPU: NVIDIA RTX A5000
labelling 100 observations
[1724909-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-09-08T08:28:32.092743Z [info ] Starting training dataset=100 epoch=1
[1724909-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-09-08T08:28:33.700075Z [info ] Training complete train_loss=2.309847593307495
Elapsed training time: 0:0:1
{'dataset_size': 100,
'test_accuracy': nan,
'test_loss': nan,
'train_accuracy': 0.08999999612569809,
'train_loss': 2.309847593307495}
Elapsed total time: 0:0:2
Running `python baal_error_torchmetrics.py --torchmetrics`:
/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py:45: UserWarning: You have chosen to seed training. This will turn on the CUDNN deterministic setting, which can slow down your training considerably!
set_seed(config.seed)
Use GPU: NVIDIA RTX A5000
labelling 100 observations
[1724520-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-09-08T08:26:52.913297Z [info ] Starting training dataset=100 epoch=1
[1724520-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-09-08T08:26:54.735623Z [info ] Training complete train_loss=2.309847593307495
Elapsed training time: 0:0:1
[1724520-MainThread] [baal.modelwrapper:test_on_dataset:123] 2023-09-08T08:26:54.736589Z [info ] Starting evaluating dataset=10000
[1724520-MainThread] [baal.modelwrapper:test_on_dataset:133] 2023-09-08T08:26:57.792895Z [info ] Evaluation complete test_loss=2.2263691425323486
{'dataset_size': 100,
'test_accuracy': 0.19924747943878174,
'test_loss': 2.2263691425323486,
'test_torch_accuracy': 0.19900000095367432,
'train_accuracy': 0.08999999612569809,
'train_loss': 2.309847593307495,
'train_torch_accuracy': 0.09000000357627869}
Elapsed total time: 0:0:6
Running `python baal_error_torchmetrics.py --torchmetrics --no-test`:
/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py:45: UserWarning: You have chosen to seed training. This will turn on the CUDNN deterministic setting, which can slow down your training considerably!
set_seed(config.seed)
Use GPU: NVIDIA RTX A5000
labelling 100 observations
[1724780-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-09-08T08:27:53.001640Z [info ] Starting training dataset=100 epoch=1
[1724780-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-09-08T08:27:54.803111Z [info ] Training complete train_loss=2.309847593307495
Elapsed training time: 0:0:1
/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric Accuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
warnings.warn(*args, **kwargs)
Traceback (most recent call last):
File "/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py", line 206, in <module>
main()
File "/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py", line 115, in main
pprint(wrapper.get_metrics())
File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/baal/metrics/mixin.py", line 71, in get_metrics
metrics = {
File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/baal/metrics/mixin.py", line 72, in <dictcomp>
met_name: get_value(met) for met_name, met in self.metrics.items() if filter in met_name
File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/baal/metrics/mixin.py", line 66, in get_value
val = met.compute().detach().cpu().numpy()
File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/torchmetrics/metric.py", line 531, in wrapped_func
value = compute(*args, **kwargs)
File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/torchmetrics/classification/accuracy.py", line 266, in compute
raise RuntimeError("You have to have determined mode.")
RuntimeError: You have to have determined mode.
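The traceback ends in torchmetrics' `Accuracy.compute`, which refuses to run because the metric's mode is still undetermined. In the torchmetrics version used here (the traceback points to `torchmetrics/classification/accuracy.py`), `Accuracy` appears to infer its mode (binary, multiclass, ...) from the inputs of its first `update()` call, so a metric that was never updated has no mode at all. The toy class below is a simplified, hypothetical model of that behavior (`LazyModeAccuracy` is not a torchmetrics class), meant only to show why calling `compute()` first fails:

```python
import warnings


class LazyModeAccuracy:
    """Toy metric whose 'mode' is only known after the first update().

    Hypothetical, simplified model of the behavior in the traceback;
    this is not torchmetrics code.
    """

    def __init__(self):
        self.mode = None  # undetermined until update() sees some data
        self.correct = 0
        self.total = 0
        self.update_count = 0

    def update(self, preds, target):
        # Infer the mode from the first batch, as a lazily configured metric would.
        if self.mode is None:
            self.mode = "binary" if set(target) <= {0, 1} else "multiclass"
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)
        self.update_count += 1

    def compute(self):
        if self.update_count == 0:
            warnings.warn("compute() called before update()", UserWarning)
        if self.mode is None:
            # Same failure as in the traceback: nothing has set the mode yet.
            raise RuntimeError("You have to have determined mode.")
        return self.correct / self.total


metric = LazyModeAccuracy()
try:
    metric.compute()  # analogous to get_metrics() before test_on_dataset
except RuntimeError as err:
    print(err)  # You have to have determined mode.

metric.update([1, 2, 3], [1, 2, 0])
print(metric.compute())
```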
Expected behavior
`get_metrics()` should report 'test_torch_accuracy': nan for the torchmetrics Accuracy() class, just as it reports 'test_accuracy': nan for Baal's Accuracy() class.
Version:
- OS: Ubuntu 20.04
- Python: 3.9.16
- Baal version: 1.8.0
Additional context
I found a workaround in my code: when I want the training metrics before running `test_on_dataset`, I call `get_metrics("train")` instead, and torchmetrics does not raise an error.
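This workaround works because `get_metrics` only computes metrics whose name contains the filter string, as the dict comprehension in the traceback shows (`if filter in met_name`). The sketch below is a simplified, hypothetical standalone mirror of that logic (the stand-in metric classes are illustrative, not Baal or torchmetrics classes): filtering on `"train"` never calls `compute()` on the un-updated test metric.

```python
def get_metrics(metrics, name_filter=""):
    """Compute only the metrics whose name contains `name_filter` (substring match)."""
    return {name: met.compute() for name, met in metrics.items() if name_filter in name}


class _Updated:
    """Illustrative stand-in for a metric that has seen training batches."""

    def __init__(self, value):
        self.value = value

    def compute(self):
        return self.value


class _NeverUpdated:
    """Illustrative stand-in for a test metric before test_on_dataset."""

    def compute(self):
        raise RuntimeError("You have to have determined mode.")


metrics = {
    "train_torch_accuracy": _Updated(0.09),
    "test_torch_accuracy": _NeverUpdated(),
}
print(get_metrics(metrics, "train"))  # {'train_torch_accuracy': 0.09}
# get_metrics(metrics) with the default "" filter reaches the test metric and raises.
```

Because the match is a plain substring test, the workaround relies on the `train_`/`test_` prefixes keeping the two groups of metric names apart.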
Thank you for submitting this issue!
I'll take a deeper look over the weekend; unfortunately, I don't have a good idea of how to fix this right now.