graphium
Torchmetrics usage improvements with classes instead of functionals
Changelogs
- [x] Relaxing the constraints on the Torchmetrics version
- [x] Changing the code to use the `update` and `compute` methods to avoid memory issues with large validation sets
- [x] Major changes and cleanup to `predictor_summaries.py` so that it only handles metric updates and computation, plus new unit tests
- [x] Fixing all unit tests, including those testing the training and finetuning of the models
- [x] Adding more unit tests to cover the changes in `MetricWrapper`
- [x] Adding more unit tests to cover `SingleTaskSummary` and `MultiTaskSummary`
- [x] Adding more unit tests in `test_training.py` and `test_finetuning.py` to make sure that the full pipeline runs
- [x] Moving to a `ProgressBar` callback to fix the progress bar logging
- [x] Running on a real task and ensuring that we recover the same performance, and that all metrics are reported correctly on CPU
- [x] Running again on a single GPU
- [ ] Running again with DDP on multiple GPUs, and checking that it gives the expected values (all predictions are synced before computing the metrics)
- [ ] Updating all the config files to be compatible with the new dataloader and with some changes in Torchmetrics (for example, `F1Score` and `Accuracy` now require the `task` parameter)
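To illustrate why the `update`/`compute` pattern avoids the memory issue, here is a minimal, hypothetical sketch in plain Python (it is not graphium or Torchmetrics code; `StreamingAccuracy`, `merge` and `evaluate` are illustrative names). The metric keeps only running counters instead of every prediction, and `merge` stands in for the cross-process sync that DDP performs before `compute`.

```python
# Hypothetical sketch, not graphium or Torchmetrics code: a streaming accuracy
# metric that stores two counters instead of the full list of predictions.
from dataclasses import dataclass
from typing import Iterable, Sequence, Tuple


@dataclass
class StreamingAccuracy:
    """Accumulates sufficient statistics batch by batch (update), then
    derives the final value once at the end (compute)."""

    correct: int = 0
    total: int = 0

    def update(self, preds: Sequence[int], target: Sequence[int]) -> None:
        # Memory stays constant: only the counters grow, not a prediction buffer.
        if len(preds) != len(target):
            raise ValueError("preds and target must have the same length")
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def merge(self, other: "StreamingAccuracy") -> None:
        # Stand-in for a distributed sync: per-process states are summed
        # before compute(), so the result matches a single-process run.
        self.correct += other.correct
        self.total += other.total

    def compute(self) -> float:
        if self.total == 0:
            raise ValueError("compute() called before any update()")
        return self.correct / self.total

    def reset(self) -> None:
        # Clear the state between epochs.
        self.correct = 0
        self.total = 0


def evaluate(batches: Iterable[Tuple[Sequence[int], Sequence[int]]]) -> float:
    """Runs a validation loop: one update() per batch, one compute() at the end."""
    metric = StreamingAccuracy()
    for preds, target in batches:
        metric.update(preds, target)
    return metric.compute()
```

Summing the states before `compute` matters: averaging per-process accuracies instead would weight unequal shard sizes incorrectly. In Torchmetrics the analogous mechanism is a `Metric` subclass whose states are registered with `add_state(..., dist_reduce_fx="sum")`.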
Checklist:
- [x] Was this PR discussed in an issue? Issue #466
- [x] Add tests to cover the fixed bug(s) or the newly introduced feature(s) (if appropriate).
- [x] Update the API documentation if a new function is added or an existing one is deleted.
- [x] Write concise and explanatory changelogs above.
- [x] If possible, assign one of the following labels to the PR: `feature`, `fix` or `test` (or ask a maintainer to do it for you).