
[Feature] How to get flops of the MMRecognizer3D?

Open wukurua opened this issue 11 months ago • 2 comments

What is the problem this feature will solve?

I want to measure the GFLOPs of RGBPoseConv3D (configs\skeleton\posec3d\rgbpose_conv3d\rgbpose_conv3d.py), but the current code doesn't seem to support MMRecognizer3D. I hope tools/analysis_tools/get_flops.py can get a section for the multi-modal 3D recognizer framework. Also, if anyone knows the params and GFLOPs of RGBPoseConv3D, or its input_shape, please share them.

def main():

    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    elif len(args.shape) == 4:
        # n, c, h, w = args.shape for 2D recognizer
        input_shape = tuple(args.shape)
    elif len(args.shape) == 5:
        # n, c, t, h, w = args.shape for 3D recognizer or
        # n, m, t, v, c = args.shape for GCN-based recognizer
        # n, v, t, m, c = args.shape for GCN-based recognizer
        input_shape = tuple(args.shape)
    else:
        raise ValueError('invalid input shape')
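One way the shape handling above could be extended to multi-modal models is to accept one named shape per modality instead of a single tuple. The sketch below is only an illustration of that idea: `parse_modal_shapes` and its `name=d1,d2,...;name=...` string format are hypothetical and are not part of get_flops.py or MMAction2, and the shapes in the usage line are illustrative.

```python
def parse_modal_shapes(spec):
    """Parse a string such as 'rgb=1,3,8,224,224;pose=1,17,32,56,56'.

    Hypothetical helper (not an MMAction2 API): returns a dict that maps
    each modality name to a 5-D shape tuple (n, c, t, h, w).
    """
    shapes = {}
    for item in spec.split(';'):
        # Split 'name=dims' at the first '=' only.
        name, _, dims = item.partition('=')
        name = name.strip()
        if not name or not dims.strip():
            raise ValueError(f'invalid modality spec: {item!r}')
        shape = tuple(int(d) for d in dims.split(','))
        if len(shape) != 5:
            raise ValueError(f'{name}: expected 5 dims, got {len(shape)}')
        shapes[name] = shape
    return shapes


shapes = parse_modal_shapes('rgb=1,3,8,224,224;pose=1,17,32,56,56')
print(shapes)
```

A dict keeps each modality's shape attached to its name, so the dummy inputs can later be passed to the model by keyword rather than by position.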

What is the feature?

Get the FLOPs of MMRecognizer3D.

What alternatives have you considered?

No response

wukurua avatar Mar 04 '24 11:03 wukurua

the same question here...

Yangxinyee avatar May 22 '24 15:05 Yangxinyee

I tried to get the FLOPs and #params for this RGBPose model. Here are the results:

Input shape: {'rgb': [1, 3, 8, 224, 224], 'pose': [1, 17, 32, 56, 56]}
Flops: 56.98 GFLOPs
Params: 36.15 M

I used the code repository https://github.com/kennymckormick/pyskl with its get_flops script, and modified the flops_counter script file as follows:

rgb_batch = torch.ones(()).new_empty(
    (1, *[1, 3, 8, 224, 224]),
    dtype=next(flops_model.parameters()).dtype,
    device=next(flops_model.parameters()).device)
pose_batch = torch.ones(()).new_empty(
    (1, *[1, 17, 32, 56, 56]),
    dtype=next(flops_model.parameters()).dtype,
    device=next(flops_model.parameters()).device)
....
_ = flops_model(rgb_batch, pose_batch)
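The same idea (run one forward pass with a dummy tensor per modality and accumulate the cost per layer) can be sketched in plain PyTorch with forward hooks. This is a minimal illustration, not the pyskl or MMAction2 counter: `TwoStreamToy`, `count_macs`, and `count_params` are hypothetical names, the toy model stands in for the real recognizer, only `Conv3d` and `Linear` layers are counted, and the result is in multiply-accumulates (tools differ on whether they report MACs or 2 x MACs as FLOPs).

```python
import torch
from torch import nn


class TwoStreamToy(nn.Module):
    """Hypothetical stand-in for a two-input (RGB + pose heatmap) model."""

    def __init__(self):
        super().__init__()
        self.rgb_conv = nn.Conv3d(3, 4, kernel_size=3, padding=1)
        self.pose_conv = nn.Conv3d(17, 4, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(8, 5)

    def forward(self, rgb, pose):
        r = self.pool(self.rgb_conv(rgb)).flatten(1)
        p = self.pool(self.pose_conv(pose)).flatten(1)
        return self.fc(torch.cat([r, p], dim=1))


def count_macs(model, *inputs):
    """Count multiply-accumulates of Conv3d and Linear layers in one pass."""
    total = 0
    handles = []

    def conv_hook(module, _inputs, out):
        nonlocal total
        # MACs per output element: (in_channels / groups) * kernel volume.
        per_output = module.in_channels // module.groups
        for k in module.kernel_size:
            per_output *= k
        total += out.numel() * per_output

    def linear_hook(module, _inputs, out):
        nonlocal total
        total += out.numel() * module.in_features

    for m in model.modules():
        if isinstance(m, nn.Conv3d):
            handles.append(m.register_forward_hook(conv_hook))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))

    model.eval()
    with torch.no_grad():
        model(*inputs)  # both modalities go through a single forward call
    for h in handles:
        h.remove()
    return total


def count_params(model):
    return sum(p.numel() for p in model.parameters())


model = TwoStreamToy()
# Small dummy inputs keep the example fast; the real shapes would be
# (1, 3, 8, 224, 224) for RGB and (1, 17, 32, 56, 56) for pose.
rgb = torch.zeros(1, 3, 2, 8, 8)
pose = torch.zeros(1, 17, 4, 4, 4)
print('MACs:', count_macs(model, rgb, pose))
print('Params:', count_params(model))
```

Because the hooks fire on whatever layers the forward pass touches, the counter itself does not need to know how many inputs the model takes; only the call site changes.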

ahmed-nady avatar Jul 03 '24 18:07 ahmed-nady