[Feature] Support calculating loss during validation
Background
Early stopping may use the validation loss as the monitored metric, but mmengine currently does not support calculating validation loss and parsing it as a metric.
However, because model implementations vary and calculating validation loss is not a common requirement, the calculation should not be initiated by mmengine. Instead, the model should compute and return the loss following a convention, and mmengine should parse it and report it as a metric.
This PR therefore aims to implement this return-and-resolve convention without introducing breaking changes.
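As context for why this matters, here is a hedged sketch of how the resulting metric could eventually drive early stopping with mmengine's existing EarlyStoppingHook. The metric name val_loss below is an assumption; the exact name depends on how the parsing convention is finalized.

# Hypothetical usage: once ValLoop reports the accumulated validation loss
# as a metric, the existing EarlyStoppingHook could simply monitor it.
custom_hooks = [
    dict(
        type='EarlyStoppingHook',
        monitor='val_loss',  # assumed metric name, not defined by this PR
        rule='less',         # stop when the monitored value stops decreasing
        patience=5),
]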
Design
In order not to introduce breaking changes, we chose to have the model return the loss it computes in val_step (i.e., model.forward with mode='predict', or the predict method), wrap it in a BaseDataElement, and append it after the val step results.
mmengine therefore needs to take the last item of the val_step results in ValLoop and determine whether it is a validation loss. If it is, ValLoop performs the relevant computation and returns it at the end of the loop, and computes the other metrics (e.g., accuracy) from the remaining items. If it is not a validation loss, it is left unprocessed.
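To make the convention concrete, below is a minimal sketch of the ValLoop-side handling it implies, assuming the model appends a BaseDataElement carrying a loss dict as the last item of the val_step outputs. The helper name split_val_loss and the val_ metric prefix are hypothetical, not part of this PR.

from collections import defaultdict

import torch
from mmengine.structures import BaseDataElement


def split_val_loss(outputs, loss_buffer):
    """Pop the trailing loss element, if any, and accumulate its values."""
    if outputs and isinstance(outputs[-1], BaseDataElement) \
            and 'loss' in outputs[-1]:
        for name, value in outputs[-1].loss.items():
            loss_buffer[name].append(value.detach().cpu())
        return outputs[:-1]  # the evaluator only sees the real predictions
    return outputs


# Sketch of how a ValLoop-style run could use it:
# loss_buffer = defaultdict(list)
# for data_batch in dataloader:
#     outputs = runner.model.val_step(data_batch)
#     outputs = split_val_loss(outputs, loss_buffer)
#     evaluator.process(data_samples=outputs, data_batch=data_batch)
# metrics = evaluator.evaluate(len(dataloader.dataset))
# metrics.update({f'val_{name}': torch.stack(values).mean().item()
#                 for name, values in loss_buffer.items()})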
Adaptation
Custom Model
Take https://github.com/open-mmlab/mmengine/blob/02f80e8bdd38f6713e04a872304861b02157905a/examples/distributed_training.py#L14-#L25 as an example.
class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
-           return x, labels
+           val_loss = {'loss': F.cross_entropy(x, labels)}
+           return x, labels, BaseDataElement(loss=val_loss)
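As a quick sanity check (a sketch assuming the imports of the linked distributed_training.py example plus torch and BaseDataElement), the adapted model now returns the two original prediction outputs followed by a trailing data element carrying the loss, which is what ValLoop would look for:

model = MMResNet50()
imgs = torch.randn(2, 3, 224, 224)
labels = torch.randint(0, 1000, (2,))
preds, gt_labels, loss_element = model(imgs, labels, mode='predict')
print(loss_element.loss)  # {'loss': tensor(...)}, appended as the last item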
MMPreTrain
Take https://github.com/open-mmlab/mmpretrain/blob/17a886cb5825cd8c26df4e65f7112d404b99fe12/mmpretrain/models/classifiers/image.py#L248-L249 as an example.
def predict(self,
            inputs: torch.Tensor,
            data_samples: Optional[List[DataSample]] = None,
            **kwargs) -> List[DataSample]:
    """Predict results from a batch of inputs.

    Args:
        inputs (torch.Tensor): The input tensor with shape
            (N, C, ...) in general.
        data_samples (List[DataSample], optional): The annotation
            data of every samples. Defaults to None.
        **kwargs: Other keyword arguments accepted by the ``predict``
            method of :attr:`head`.
    """
    feats = self.extract_feat(inputs)
-   return self.head.predict(feats, data_samples, **kwargs)
+   preds = self.head.predict(feats, data_samples, **kwargs)
+   loss = self.head.loss(feats, data_samples)
+   loss_sample = DataSample(loss=loss)
+   preds.append(loss_sample)
+   return preds
MMPose
Note: calculating the loss in this way may not be correct.
Take https://github.com/open-mmlab/mmpose/blob/5a3be9451bdfdad2053a90dc1199e3ff1ea1a409/mmpose/models/pose_estimators/topdown.py#L99-#L120 as an example.
def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
    """Predict results from a batch of inputs and data samples with post-
    processing.

    Args:
        inputs (Tensor): Inputs with shape (N, C, H, W)
        data_samples (List[:obj:`PoseDataSample`]): The batch
            data samples

    Returns:
        list[:obj:`PoseDataSample`]: The pose estimation results of the
        input images. The return value is `PoseDataSample` instances with
        ``pred_instances`` and ``pred_fields``(optional) field , and
        ``pred_instances`` usually contains the following keys:

            - keypoints (Tensor): predicted keypoint coordinates in shape
                (num_instances, K, D) where K is the keypoint number and D
                is the keypoint dimension
            - keypoint_scores (Tensor): predicted keypoint scores in shape
                (num_instances, K)
    """
    assert self.with_head, (
        'The model must have head to perform prediction.')

    if self.test_cfg.get('flip_test', False):
        _feats = self.extract_feat(inputs)
        _feats_flip = self.extract_feat(inputs.flip(-1))
        feats = [_feats, _feats_flip]
+       loss = self.head.loss(_feats, data_samples, train_cfg=self.train_cfg)
    else:
        feats = self.extract_feat(inputs)
+       loss = self.head.loss(feats, data_samples, train_cfg=self.train_cfg)

    preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)

    if isinstance(preds, tuple):
        batch_pred_instances, batch_pred_fields = preds
    else:
        batch_pred_instances = preds
        batch_pred_fields = None

    results = self.add_pred_to_datasample(batch_pred_instances,
                                          batch_pred_fields, data_samples)
+   # wrap the loss so it can be appended to the results, analogous to the
+   # other examples above (assumes PoseDataSample is available here)
+   loss_sample = PoseDataSample(loss=loss)
+   results.append(loss_sample)

    return results
In addition, you should add dict(type='GenerateTarget', encoder=codec) to val_pipeline, similar to train_pipeline.
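For illustration, a hedged sketch of what the adjusted val_pipeline could look like in a typical topdown config; the surrounding transforms and the codec follow the usual mmpose config layout and may differ in your config.

val_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    # added so that encoded ground-truth targets are available and
    # head.loss can run during validation
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs'),
]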
I hope validation loss support can be merged into mmengine as soon as possible; it would be a very useful feature.