GraphGym
Crash at end of regression runs
When I make a run with cfg.dataset.task_type = 'regression', the code crashes at the end of the run. The error message is:
Traceback (most recent call last):
  File "main_pyg.py", line 55, in <module>
    agg_runs(cfg.out_dir, cfg.metric_best)
  File "~/Code/GraphGym/graphgym/utils/agg_runs.py", line 100, in agg_runs
    [stats[metric] for stats in stats_list])
  File "~/Code/GraphGym/graphgym/utils/agg_runs.py", line 100, in <listcomp>
    [stats[metric] for stats in stats_list])
KeyError: 'accuracy'
The problem seems to be that accuracy is not a metric logged for regression tasks. Here are the relevant lines in agg_runs.py:
    if metric_best == 'auto':
        metric = 'auc' if 'auc' in stats_list[0] else 'accuracy'
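The failure can be reproduced outside GraphGym with a few lines. This is only a sketch: the contents of stats_list are made up for illustration (I am assuming a regression run logs something like 'loss' and 'mse' per epoch), and pick_metric_auto is a hypothetical wrapper around the expression quoted above.

```python
# Hypothetical per-epoch stats for a regression run; the keys are an assumption.
stats_list = [
    {"epoch": 0, "loss": 0.9, "mse": 0.9},
    {"epoch": 1, "loss": 0.5, "mse": 0.5},
]


def pick_metric_auto(stats_list):
    # Same expression as the 'auto' branch quoted above.
    return "auc" if "auc" in stats_list[0] else "accuracy"


# No 'auc' key is present, so the fallback 'accuracy' is chosen.
metric = pick_metric_auto(stats_list)

missing_key = None
try:
    values = [stats[metric] for stats in stats_list]
except KeyError as err:
    # The lookup fails because 'accuracy' was never logged.
    missing_key = err.args[0]
    print(f"KeyError: {missing_key!r}")
```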
Here's a fix:
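    # Assumes cfg is importable inside agg_runs.py.
    if metric_best == 'auto':
        if cfg.dataset.task_type == 'classification':
            metric = 'auc' if 'auc' in stats_list[0] else 'accuracy'
        elif cfg.dataset.task_type == 'regression':
            metric = 'mse'
        # Note: any other task_type leaves metric undefined here.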
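One caveat with the fix above: I have not checked how agg_runs picks the best epoch once the metric is chosen. If it takes the maximum of the metric, then 'mse' would select the worst epoch, since lower is better for an error metric. Below is a rough standalone sketch of how both the metric and its direction could be resolved together. The names resolve_metric, best_epoch, and LOWER_IS_BETTER are hypothetical and not part of GraphGym, and the set of lower-is-better metric names is my assumption.

```python
# Assumption: metrics for which a smaller value means a better model.
LOWER_IS_BETTER = {"mse", "mae", "rmse", "loss"}


def resolve_metric(stats_list, metric_best="auto", task_type="classification"):
    """Return (metric_name, lower_is_better) used to pick the best epoch."""
    if metric_best != "auto":
        metric = metric_best
    elif task_type == "regression":
        metric = "mse"
    else:
        metric = "auc" if "auc" in stats_list[0] else "accuracy"
    if metric not in stats_list[0]:
        # Fail early with a readable message instead of a bare KeyError later.
        available = sorted(stats_list[0])
        raise KeyError(f"metric {metric!r} not logged; available: {available}")
    return metric, metric in LOWER_IS_BETTER


def best_epoch(stats_list, metric, lower_is_better):
    """Return the epoch whose metric value is best in the given direction."""
    pick = min if lower_is_better else max
    return pick(stats_list, key=lambda stats: stats[metric])["epoch"]


# Usage with made-up regression stats.
stats = [
    {"epoch": 0, "mse": 0.9},
    {"epoch": 1, "mse": 0.4},
    {"epoch": 2, "mse": 0.6},
]
metric, lower = resolve_metric(stats, "auto", "regression")
chosen = best_epoch(stats, metric, lower)
```

Passing the task type as an argument keeps the helper independent of the global cfg, which also makes it easy to test in isolation.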