Lite-HRNet
Lite-HRNet copied to clipboard
Inference code with images
Hi,
This is great work on a lightweight HRNet model. It would be great if you could add code for running inference on images, or upload some sample output images to the repository. In the meantime, any help on how to run inference would be much appreciated. Thanks!
I adapted this simple snippet from `tools/test.py` and the mmpose API for inference.
Hope this helps.
import math

import matplotlib.pyplot as plt
import numpy as np
from mmpose.apis.inference import init_pose_model, inference_top_down_pose_model

# Run top-down pose inference with Lite-HRNet and visualize the keypoints and
# per-keypoint heatmaps.
#
# Define these names before running the snippet:
#   config_path        -- path to the model config file
#   ckpt_path          -- path to the matching .pth checkpoint
#   device             -- device for the model, e.g. 'cpu' or 'cuda:0'
#   inputs['image']    -- the image passed to mmpose for inference
#   img_data_with_bbox -- list of dicts, each with a 'bbox' key (xyxy format,
#                         matching format='xyxy' below)
model = init_pose_model(config_path, ckpt_path, device=device)
results, heatmaps = inference_top_down_pose_model(
    model,
    inputs['image'],
    img_data_with_bbox,
    return_heatmap=True,
    format='xyxy',
    dataset='TopDownCocoDataset',
)

# Visualize results
# Heatmaps are indexed below as hms[0, k, :, :] and the keypoint count is read
# from axis 1, i.e. presumably shaped (N, K, H, W) -- confirm for your version.
hms = heatmaps[0]['heatmap']
result = results[0]
keypoints = [np.array([v[0], v[1]]) for v in result['keypoints']]

# Plot image and keypoints.
# NOTE(review): result['image'] assumes the caller stored the image in the
# corresponding img_data_with_bbox dict; otherwise this key may be missing.
plt.figure()
plt.scatter(*zip(*keypoints))
plt.imshow(result['image'])
plt.show()

# Plot heatmaps in a grid with 4 columns. The row count grows with the number
# of heatmaps, because a fixed 3x4 grid only holds 12 and crashed beyond that.
n_hms = np.shape(hms)[1]
n_cols = 4
n_rows = max(1, math.ceil(n_hms / n_cols))
# squeeze=False keeps axarr 2-D even when there is only a single row.
f, axarr = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
hm_display = None
for idx, this_ax in enumerate(axarr.flat):
    if idx >= n_hms:
        this_ax.axis('off')  # hide the unused cells in the last row
        continue
    this_ax.set_title(f'{idx}')
    hm_display = this_ax.imshow(hms[0, idx, :, :], cmap='jet', vmin=0, vmax=1)
# One shared colorbar for all heatmaps (all use the same 0..1 range).
if hm_display is not None:
    cb = f.colorbar(hm_display, ax=axarr)
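Note that `img_data_with_bbox` is a list of dicts, each of which must contain a `'bbox'` key, e.g. `[{'bbox': [x1, y1, x2, y2]}]` with pixel coordinates in the `xyxy` format requested in the call above; the boxes usually come from a person detector. For more info, check out the mmpose documentation.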
If you care about inference speed, you can fuse the conv and batch-norm layers in the model; see `tools/test.py` for the code.
Hello @kuldeepbrd1, thanks for your inference code. I'm trying to load the model and export it to an .onnx model, but I'm getting the following error:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert (num_channels % groups == 0), ('num_channels should be '
Traceback (most recent call last):
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/pytorch2onnx.py", line 61, in <module>
main()
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/pytorch2onnx.py", line 44, in main
torch.onnx.export(
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/torch/onnx/__init__.py", line 271, in export
return utils.export(model, args, f, export_params, verbose, training,
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/torch/onnx/utils.py", line 88, in export
_export(model, args, f, export_params, verbose, training, input_names, output_names,
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/torch/onnx/utils.py", line 694, in _export
_model_to_graph(model, args, verbose, input_names,
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/torch/onnx/utils.py", line 457, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args,
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/torch/onnx/utils.py", line 420, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/torch/onnx/utils.py", line 380, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/torch/jit/_trace.py", line 125, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/torch/jit/_trace.py", line 116, in wrapper
outs.append(self.inner(*trace_inputs))
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
result = self.forward(*input, **kwargs)
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/mmcv/runner/fp16_utils.py", line 95, in new_func
return old_func(*args, **kwargs)
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/mmpose/models/detectors/top_down.py", line 136, in forward
return self.forward_train(img, target, target_weight, img_metas,
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/mmpose/models/detectors/top_down.py", line 152, in forward_train
keypoint_losses = self.keypoint_head.get_loss(
File "/Users/prudvikamtam/Downloads/Work/PoseEstimation_Python/Lite-HRNet-conversion/env/lib/python3.9/site-packages/mmpose/models/keypoint_heads/top_down_simple_head.py", line 160, in get_loss
assert target.dim() == 4 and target_weight.dim() == 3
AttributeError: 'NoneType' object has no attribute 'dim'
I've used the following code to load the model:
# Load the Lite-HRNet pose model for ONNX export (asker's code, quoted verbatim).
# args.config is the config file path; parse_args is defined in the export
# script and is not shown here.
args = parse_args()
cfg = Config.fromfile(args.config)
# Placeholder -- not valid Python as written; replace with the checkpoint path string.
PTH_PATH = <path to the .pth file>
# NOTE(review): the traceback above shows torch.onnx.export calling the model's
# forward, which went through TopDown.forward -> forward_train -> get_loss and
# failed because no target was given. mmpose's own ONNX export tool reportedly
# swaps in model.forward_dummy before exporting -- confirm for your mmpose version.
pose_model = init_pose_model(cfg, PTH_PATH, device=torch.device('cpu'))
Looking at the error, it seems the model is not being loaded properly, if I'm not wrong. Could you take a look and verify whether I've initialised it correctly? The .pth that I'm using is this and the config file is this. Thank you.
> I adapted this simple snippet from tools/test.py and mmpose api for inference.
> Hope this helps.
> import matplotlib.pyplot as plt
> from mmpose.apis.inference import init_pose_model, inference_top_down_pose_model
> model = init_pose_model(config_path, ckpt_path, device=device)
> results, heatmaps= inference_top_down_pose_model(model, inputs['image'], img_data_with_bbox, return_heatmap=True, format='xyxy', dataset='TopDownCocoDataset')
> ## Visualize Results
> hms =heatmaps[0]['heatmap']
> result = results[0]
> keypoints = ([np.array([v[0],v[1]]) for v in result['keypoints']])
> #Plot image and keypoints
> plt.figure()
> plt.scatter(*zip(*keypoints))
> plt.imshow(result['image'])
> plt.show()
> #Plot heatmaps in a grid
> n_hms = np.shape(hms)[1]
> f, axarr = plt.subplots(3, 4, figsize=(15,15))
> this_col=0
> for idx in range(n_hms):
>     this_hm = hms[0,idx,:,:]
>     row = idx % 4
>     this_ax = axarr[this_col, row]
>     this_ax.set_title(f'{idx}')
>     hm_display = this_ax.imshow(this_hm, cmap='jet', vmin=0, vmax=1)
>     if row == 3:
>         this_col += 1
> cb=f.colorbar(hm_display, ax=axarr)
>
> Note that img_data_with_bbox is a list of dicts, where dicts should contain 'bbox' key. For more info checkout the mmpose documentation
> If you care about inference speed, you can fuse conv and batch norm layers in the model. See tools/test.py for the code
First, thank you for your answer. Could you give some examples of `img_data_with_bbox`? Does it come from a person detector model?
@kuldeepbrd1 First, thank you for your answer. Could you give some examples of `img_data_with_bbox`? Does it come from a person detector model?