mmpose
mmpose copied to clipboard
[Bug] (suggested fix) `mmpose.models.pose_estimators.topdown.TopdownPoseEstimator` is unable to be symbolically traced because of untraceable `add_pred_to_datasample()` and `loss()`
Prerequisite
- [X] I have searched Issues and Discussions but cannot get the expected help.
- [X] The bug has not been fixed in the latest version (https://github.com/open-mmlab/mmpose).
Environment
computer not available at the time
Using: torch 2.0.0+cu118 torchvision: 0.15.0+cu118 mmengine: 0.10.3 mmrazor: 1.0.0 mmpose: 1.3.1
Reproduces the problem - code sample
Using mmrazor to quantize this model, I stumbled upon an error when the symbolic_trace for the fx graph was being made.
Applied fixes for torch 2.0.0 incompatibility suggested in mmrazor #632 and a fix for nn.Parameters inside TopdownPoseEstimator not being traced in mmrazor #633
from mmrazor.models.task_modules.tracer.fx.custom_tracer import CustomTracer
from mmpose.models.pose_estimators.topdown import TopdownPoseEstimator
from mmengine.config import Config
cfg = Config.fromfile('/mmpose/configs/body_2d_keypoint/rtmpose/coco/rtmpose-t_8xb256-420e_coco-256x192.py')
rtmpose = TopdownPoseEstimator(
backbone=cfg.model.backbone,
neck=cfg.model.neck,
head=cfg.model.head,
train_cfg=cfg.train_cfg,
data_preprocessor=cfg.model.data_preprocessor,
)
tracer = CustomTracer(
skipped_methods=[
'mmpose.models.heads.RTMCCHead.loss',
'mmpose.models.heads.RTMCCHead.predict',
]
)
traced_graph = tracer.trace(rtmpose)
Reproduces the problem - error message
Traceback (most recent call last):
File "..../site-packages/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 421, in trace
'output', (self.create_arg(fn(*args)), ), {},
File "..../site-packages/mmpose/models/pose_estimators/base.py", line 161, in forward
return self.predict(inputs, data_samples)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "..../site-packages/mmpose/models/pose_estimators/topdown.py", line 117, in predict
results = self.add_pred_to_datasample(batch_pred_instances,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "..../site-packages/mmpose/models/pose_estimators/topdown.py", line 138, in predict
assert len(batch_pred_instances) == len(batch_data_samples)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "..../site-packages/torch/fx/proxy.py", line 420, in __len__
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
Additional information
for loss() I suggest the following patch:
@@ -68,8 +68,8 @@ def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
feats = self.extract_feat(inputs)
- losses = dict()
-
if self.with_head:
- losses.update(
self.head.loss(feats, data_samples, train_cfg=self.train_cfg))
+ losses = self.head.loss(feats, data_samples, train_cfg=self.train_cfg)
+ else:
+ losses = dict()
return losses
for add_pred_to_datasample() I suggest the following:
@@ -138,48 +138,62 @@ def add_pred_to_datasample(self, batch_pred_instances: InstanceList,
- assert len(batch_pred_instances) == len(batch_data_samples)
- if batch_pred_fields is None:
- batch_pred_fields = []
output_keypoint_indices = self.test_cfg.get('output_keypoint_indices',
None)
-
- for pred_instances, pred_fields, data_sample in zip_longest(
- batch_pred_instances, batch_pred_fields, batch_data_samples):
-
- gt_instances = data_sample.gt_instances
-
- # convert keypoint coordinates from input space to image space
- input_center = data_sample.metainfo['input_center']
- input_scale = data_sample.metainfo['input_scale']
- input_size = data_sample.metainfo['input_size']
-
- pred_instances.keypoints[..., :2] = \
- pred_instances.keypoints[..., :2] / input_size * input_scale \
- + input_center - 0.5 * input_scale
- if 'keypoints_visible' not in pred_instances:
- pred_instances.keypoints_visible = \
- pred_instances.keypoint_scores
-
- if output_keypoint_indices is not None:
- # select output keypoints with given indices
- num_keypoints = pred_instances.keypoints.shape[1]
- for key, value in pred_instances.all_items():
- if key.startswith('keypoint'):
- pred_instances.set_field(
- value[:, output_keypoint_indices], key)
-
- # add bbox information into pred_instances
- pred_instances.bboxes = gt_instances.bboxes
- pred_instances.bbox_scores = gt_instances.bbox_scores
-
- data_sample.pred_instances = pred_instances
-
- if pred_fields is not None:
- if output_keypoint_indices is not None:
- # select output heatmap channels with keypoint indices
- # when the number of heatmap channel matches num_keypoints
- for key, value in pred_fields.all_items():
- if value.shape[0] != num_keypoints:
- continue
- pred_fields.set_field(value[output_keypoint_indices],
- key)
- data_sample.pred_fields = pred_fields
+ batch_data_samples = _add_pred_to_datasample(
+ output_keypoint_indices,
+ batch_pred_instances,
+ batch_pred_fields,
+ batch_data_samples
+ )
return batch_data_samples
+
+
+ @torch.fx.wrap
+ def _add_pred_to_datasample(
+ output_keypoint_indices,
+ batch_pred_instances: InstanceList,
+ batch_pred_fields: Optional[PixelDataList],
+ batch_data_samples: SampleList) -> SampleList:
+ assert len(batch_pred_instances) == len(batch_data_samples)
+ if batch_pred_fields is None:
+ batch_pred_fields = []
+
+ for pred_instances, pred_fields, data_sample in zip_longest(
+ batch_pred_instances, batch_pred_fields, batch_data_samples):
+
+ gt_instances = data_sample.gt_instances
+
+ # convert keypoint coordinates from input space to image space
+ input_center = data_sample.metainfo['input_center']
+ input_scale = data_sample.metainfo['input_scale']
+ input_size = data_sample.metainfo['input_size']
+
+ pred_instances.keypoints[..., :2] = \
+ pred_instances.keypoints[..., :2] / input_size * input_scale \
+ + input_center - 0.5 * input_scale
+ if 'keypoints_visible' not in pred_instances:
+ pred_instances.keypoints_visible = \
+ pred_instances.keypoint_scores
+
+ if output_keypoint_indices is not None:
+ # select output keypoints with given indices
+ num_keypoints = pred_instances.keypoints.shape[1]
+ for key, value in pred_instances.all_items():
+ if key.startswith('keypoint'):
+ pred_instances.set_field(
+ value[:, output_keypoint_indices], key)
+
+ # add bbox information into pred_instances
+ pred_instances.bboxes = gt_instances.bboxes
+ pred_instances.bbox_scores = gt_instances.bbox_scores
+
+ data_sample.pred_instances = pred_instances
+
+ if pred_fields is not None:
+ if output_keypoint_indices is not None:
+ # select output heatmap channels with keypoint indices
+ # when the number of heatmap channel matches num_keypoints
+ for key, value in pred_fields.all_items():
+ if value.shape[0] != num_keypoints:
+ continue
+ pred_fields.set_field(value[output_keypoint_indices],
+ key)
+ data_sample.pred_fields = pred_fields
+ return batch_data_samples
This solves the issue with fx tracing, although there's still other issues I have yet to solve.
Added reproducing code and full fix suggestion