pytorch-image-models
feat(train): add validation metrics and distributed support
This PR extends the validation metrics functionality (precision, recall, F1-score) to the `train.py` script.
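For reviewers less familiar with how multi-class precision, recall and F1 are averaged, here is a dependency-free sketch. It is not the PR's code, which relies on scikit-learn. The helper name `prf` is hypothetical. The `"micro"`, `"macro"` and `"weighted"` modes are assumed to follow the meanings of scikit-learn's `average` argument, and classes with an undefined ratio score 0, similar to `zero_division=0`.

```python
from collections import Counter


def prf(preds, targets, average="macro"):
    """Precision, recall and F1 over integer class labels (hypothetical helper).

    average: "micro", "macro" or "weighted", modelled on scikit-learn's
    `average` argument; undefined ratios (0/0) count as 0.
    """
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    labels = sorted(set(preds) | set(targets))
    tp = Counter(p for p, t in zip(preds, targets) if p == t)
    pred_count = Counter(preds)
    true_count = Counter(targets)

    def ratio(a, b):
        return a / b if b else 0.0

    def f1(p, r):
        return ratio(2 * p * r, p + r)

    if average == "micro":
        # Pool true positives and counts over all classes before dividing.
        total_tp = sum(tp.values())
        p = ratio(total_tp, sum(pred_count.values()))
        r = ratio(total_tp, sum(true_count.values()))
        return p, r, f1(p, r)

    # Per-class scores, then a (weighted) mean over classes.
    per_class = []
    for c in labels:
        p = ratio(tp[c], pred_count[c])
        r = ratio(tp[c], true_count[c])
        per_class.append((p, r, f1(p, r)))

    if average == "macro":
        weights = [1.0] * len(labels)  # every class counts equally
    elif average == "weighted":
        weights = [true_count[c] for c in labels]  # weight by class support
    else:
        raise ValueError(f"unknown average: {average!r}")
    total = sum(weights)
    return tuple(
        ratio(sum(w * m[i] for w, m in zip(weights, per_class)), total)
        for i in range(3)
    )
```

For single-label classification, micro averaging pools every prediction, so precision, recall and F1 all reduce to accuracy. Macro averaging instead gives rare classes the same weight as common ones, which is usually the reason to look beyond accuracy in the first place.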
Changes:
- The `validate` function within `train.py` now supports the `--metrics-avg` flag.
- Uses `torch.distributed.all_gather` to collect predictions and targets from all GPUs before the metrics are calculated on the primary process.
- The feature keeps `scikit-learn` as a soft (optional) dependency and is disabled by default.
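The gather-then-compute flow described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in this PR. The function names `gather_tensor` and `distributed_metrics` are hypothetical. `preds` and `targets` are assumed to be 1-D tensors of class indices (e.g. `output.argmax(dim=1)` concatenated over validation batches). The `--metrics-avg` value is assumed to be passed straight through as scikit-learn's `average` argument. Because `all_gather` requires equally sized tensors, the sketch exchanges lengths first and pads, so ranks holding different numbers of samples are still handled.

```python
import torch
import torch.distributed as dist

# Soft dependency: if scikit-learn is missing, metrics are simply skipped.
try:
    from sklearn.metrics import precision_recall_fscore_support
except ImportError:
    precision_recall_fscore_support = None


def _dist_ready() -> bool:
    return dist.is_available() and dist.is_initialized()


def gather_tensor(t: torch.Tensor) -> torch.Tensor:
    """Concatenate a 1-D tensor from every rank, allowing unequal lengths."""
    if not _dist_ready():
        return t  # single process: nothing to gather
    world_size = dist.get_world_size()
    # all_gather needs equally sized tensors, so exchange lengths first...
    local_len = torch.tensor([t.numel()], device=t.device)
    lens = [torch.zeros_like(local_len) for _ in range(world_size)]
    dist.all_gather(lens, local_len)
    max_len = int(max(n.item() for n in lens))
    # ...then pad every rank's tensor to the longest one.
    padded = torch.zeros(max_len, dtype=t.dtype, device=t.device)
    padded[: t.numel()] = t
    chunks = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(chunks, padded)
    # Drop the padding from each chunk before concatenating.
    return torch.cat([c[: int(n.item())] for c, n in zip(chunks, lens)])


def distributed_metrics(preds, targets, average="macro"):
    """Gather on all ranks, compute metrics on rank 0 only (hypothetical).

    `average` is assumed to carry the --metrics-avg value directly into
    scikit-learn's `average` argument.
    """
    # Every rank must take part in the collectives, even if it returns None.
    all_preds = gather_tensor(preds)
    all_targets = gather_tensor(targets)
    is_primary = not _dist_ready() or dist.get_rank() == 0
    if not is_primary or precision_recall_fscore_support is None:
        return None
    p, r, f1, _ = precision_recall_fscore_support(
        all_targets.cpu().numpy(),
        all_preds.cpu().numpy(),
        average=average,
        zero_division=0,
    )
    return {"precision": float(p), "recall": float(r), "f1": float(f1)}
```

With the NCCL backend, the tensors must live on the rank's GPU before they are gathered. If the evaluation sampler pads the dataset so that every rank receives the same number of samples (as `DistributedSampler` does with `drop_last=False`), the repeated samples end up in the gathered tensors unless they are trimmed.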
This lets users track these more detailed metrics during training, including in multi-GPU environments.