pytorch-image-models

feat(train): add validation metrics and distributed support

Open • ha405 opened this pull request

This PR extends the validation metrics functionality (precision, recall, F1-score) to the train.py script.
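For reference, here is a minimal sketch of the metric computation itself. It assumes scikit-learn's precision_recall_fscore_support and assumes that --metrics-avg maps onto its average argument (values such as "macro", "micro", "weighted"). The helper classification_metrics and the HAS_SKLEARN guard are hypothetical names, not the code in this PR.

```python
# Soft dependency: scikit-learn is imported only if available (assumption:
# the PR guards the import in a similar way).
try:
    from sklearn.metrics import precision_recall_fscore_support
    HAS_SKLEARN = True
except ImportError:
    HAS_SKLEARN = False


def classification_metrics(targets, preds, average="macro"):
    """Return precision, recall and F1 for integer class labels.

    `average` is passed straight to scikit-learn; treating it as the value
    of --metrics-avg is a hypothetical mapping for illustration.
    """
    if not HAS_SKLEARN:
        raise RuntimeError("scikit-learn is required for precision/recall/F1")
    # zero_division=0 scores classes that are never predicted as 0 instead of warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        targets, preds, average=average, zero_division=0
    )
    return {"precision": float(precision), "recall": float(recall), "f1": float(f1)}
```

Macro averaging weights every class equally, while micro averaging pools all samples. For single-label classification, micro F1 therefore equals accuracy.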

Changes:

  • The validate function within train.py now supports the --metrics-avg flag.
  • Uses torch.distributed.all_gather to collect predictions and targets from all GPUs, so the metrics are calculated on the primary process over the full validation set rather than over a single GPU's shard.
  • scikit-learn remains a soft (optional) dependency, and the feature is disabled by default.
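The gathering step can be sketched as follows. This is a hypothetical illustration, not the PR's implementation. The functions gather_across_ranks and collect_for_metrics are illustrative names. The sketch assumes every rank holds the same number of validation samples, because all_gather expects equal-shaped tensors on each rank.

```python
import torch
import torch.distributed as dist


def _distributed() -> bool:
    # True only when a process group has been initialized (multi-GPU run).
    return dist.is_available() and dist.is_initialized()


def gather_across_ranks(tensor: torch.Tensor) -> torch.Tensor:
    """Concatenate `tensor` from every rank along dim 0 via all_gather.

    Collective call: every rank must call it. Assumes equal shapes on all
    ranks. Returns the local tensor unchanged in single-process runs.
    """
    if not _distributed():
        return tensor
    tensor = tensor.contiguous()
    buffers = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(buffers, tensor)
    return torch.cat(buffers, dim=0)


def collect_for_metrics(local_preds: torch.Tensor, local_targets: torch.Tensor):
    """Gather on all ranks, but hand the arrays back only on the primary process."""
    preds = gather_across_ranks(local_preds)
    targets = gather_across_ranks(local_targets)
    if _distributed() and dist.get_rank() != 0:
        return None  # non-primary ranks skip metric computation
    return preds.cpu().numpy(), targets.cpu().numpy()
```

With the NCCL backend the tensors must live on the GPU. If per-rank sample counts can differ, the tensors would need padding before the gather, or torch.distributed.all_gather_object could be used instead.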

This lets users get these more detailed metrics during training, including in multi-GPU runs.

ha405 • Aug 22 '25 06:08