
[Feature] Support calculating loss during validation

Open fanqiNO1 opened this issue 1 year ago • 1 comments

Background

Early stopping may use validation loss as its monitored metric, but mmengine currently does not support calculating validation loss and parsing it as a metric.

However, because model implementations are inconsistent and calculating validation loss is not a common requirement, the calculation should not be initiated by mmengine. Instead, the model initiates it and returns the loss following a convention, and mmengine parses that return value and reports it as a metric.

This PR therefore implements this return-and-resolve convention without introducing a breaking change.

Design

To avoid introducing a breaking change, we chose to have the model compute the loss in val_step (that is, in model.forward with mode='predict', or in predict), wrap it in a BaseDataElement, and append it to the end of the val_step result.

Therefore, in ValLoop, mmengine needs to take the last item of the val_step result and determine whether it is a validation loss. If it is, mmengine performs the relevant computation on it and returns it at the end of ValLoop, and the other metrics (e.g., accuracy) are computed from the remaining items. If it is not, the result is left unprocessed.
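The convention described above can be sketched without any mmengine dependency. This is only an illustration under stated assumptions: `LossElement` is a hypothetical stand-in for a BaseDataElement carrying a loss dict, `split_val_loss` and `run_val_loop` are made-up names, and averaging the loss over iterations is an assumed choice for "the relevant computation", not necessarily what the PR does.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Sequence, Tuple


@dataclass
class LossElement:
    """Hypothetical stand-in for a BaseDataElement that only carries a loss dict."""
    loss: Dict[str, float] = field(default_factory=dict)


def split_val_loss(outputs: Sequence[Any]) -> Tuple[List[Any], Dict[str, float]]:
    """Split a trailing loss element off a val_step result, if one is present."""
    if len(outputs) > 0 and isinstance(outputs[-1], LossElement):
        return list(outputs[:-1]), dict(outputs[-1].loss)
    # No trailing loss element: the result is passed through untouched.
    return list(outputs), {}


def run_val_loop(batches: Iterable[Sequence[Any]]) -> Tuple[List[Any], Dict[str, float]]:
    """Collect predictions from every batch and average each named loss."""
    totals: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    predictions: List[Any] = []
    for outputs in batches:
        preds, loss = split_val_loss(outputs)
        predictions.extend(preds)  # only these items reach the other metrics
        for name, value in loss.items():
            totals[name] = totals.get(name, 0.0) + value
            counts[name] = counts.get(name, 0) + 1
    mean_loss = {name: totals[name] / counts[name] for name in totals}
    return predictions, mean_loss


if __name__ == "__main__":
    batches = [
        ["pred_a", "pred_b", LossElement({"loss": 1.0})],
        ["pred_c", LossElement({"loss": 3.0})],
    ]
    print(run_val_loop(batches))
```

Because a result without a trailing loss element passes through unchanged, models that do not opt in behave as before, which is the point of the non-breaking design.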

Adaptation

Custom Model

Take https://github.com/open-mmlab/mmengine/blob/02f80e8bdd38f6713e04a872304861b02157905a/examples/distributed_training.py#L14-#L25 as an example.

class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
-          return x, labels
+          val_loss = {'loss': F.cross_entropy(x, labels)}
+          return x, labels, BaseDataElement(loss=val_loss)

MMPreTrain

Take https://github.com/open-mmlab/mmpretrain/blob/17a886cb5825cd8c26df4e65f7112d404b99fe12/mmpretrain/models/classifiers/image.py#L248-L249 as an example.

    def predict(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None,
                **kwargs) -> List[DataSample]:
        """Predict results from a batch of inputs.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample], optional): The annotation
                data of every samples. Defaults to None.
            **kwargs: Other keyword arguments accepted by the ``predict``
                method of :attr:`head`.
        """
        feats = self.extract_feat(inputs)
-       return self.head.predict(feats, data_samples, **kwargs)
+       preds = self.head.predict(feats, data_samples, **kwargs)
+       loss = self.head.loss(feats, data_samples)
+       loss_sample = DataSample(loss=loss)
+       preds.append(loss_sample)
+       return preds

MMPose

Calculating the loss in this way may not be correct.

Take https://github.com/open-mmlab/mmpose/blob/5a3be9451bdfdad2053a90dc1199e3ff1ea1a409/mmpose/models/pose_estimators/topdown.py#L99-#L120 as an example.

    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W)
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples

        Returns:
            list[:obj:`PoseDataSample`]: The pose estimation results of the
            input images. The return value is `PoseDataSample` instances with
            ``pred_instances`` and ``pred_fields``(optional) field , and
            ``pred_instances`` usually contains the following keys:

                - keypoints (Tensor): predicted keypoint coordinates in shape
                    (num_instances, K, D) where K is the keypoint number and D
                    is the keypoint dimension
                - keypoint_scores (Tensor): predicted keypoint scores in shape
                    (num_instances, K)
        """
        assert self.with_head, (
            'The model must have head to perform prediction.')

        if self.test_cfg.get('flip_test', False):
            _feats = self.extract_feat(inputs)
            _feats_flip = self.extract_feat(inputs.flip(-1))
            feats = [_feats, _feats_flip]
+           loss = self.head.loss(_feats, data_samples, train_cfg=self.train_cfg)
        else:
            feats = self.extract_feat(inputs)
+           loss = self.head.loss(feats, data_samples, train_cfg=self.train_cfg)

        preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)

        if isinstance(preds, tuple):
            batch_pred_instances, batch_pred_fields = preds
        else:
            batch_pred_instances = preds
            batch_pred_fields = None

        results = self.add_pred_to_datasample(batch_pred_instances,
                                              batch_pred_fields, data_samples)
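+       # Assumption: loss_sample is not defined in the original snippet; the loss
+       # is wrapped here in the same way as in the MMPreTrain example above.
+       loss_sample = PoseDataSample(loss=loss)
+       results.append(loss_sample)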
        return results

In addition, you should add dict(type='GenerateTarget', encoder=codec) to val_pipeline, as is done in train_pipeline, so that the targets the head's loss relies on are generated during validation as well.
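A minimal sketch of such a val_pipeline, assuming a top-down heatmap config: the codec settings and the surrounding transforms (LoadImage, GetBBoxCenterScale, TopdownAffine, PackPoseInputs) are illustrative assumptions, not taken from this issue, and should be matched to the actual config in use.

```python
# Assumed codec settings for illustration only; reuse the codec from your own config.
codec = dict(
    type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2)

val_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    # Added so that training targets exist when the head's loss is computed
    # during validation; placed before packing, as in train_pipeline.
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs'),
]
```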

fanqiNO1 · Feb 22 '24 10:02

I hope val loss support is merged into mmengine as soon as possible; it is a very useful feature.

MikasaLee · Feb 23 '24 12:02