d2l-en

[MXNet] Speedup SSD Scratch Implementation

Open AnirudhDagar opened this issue 3 years ago • 2 comments

The MXNet scratch implementation of SSD is currently much slower than its PyTorch scratch counterpart.

AnirudhDagar avatar Dec 28 '20 02:12 AnirudhDagar

Can you fix it?

astonzhang avatar Dec 28 '20 19:12 astonzhang

I benchmarked the training loop to find the bottleneck. Interestingly, just replacing the scratch implementation d2l.multibox_target with the built-in npx.multibox_target produced the desired speedup. This indicates that the custom multibox_target function implemented from scratch for MXNet is the speed bottleneck. Notably, the equivalent scratch implementation of d2l.multibox_target in PyTorch shows no such slowdown.
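
To isolate a suspected bottleneck like this, it helps to time the candidate function outside the training loop so the measurement is not dominated by data loading or the forward pass. A minimal sketch using Python's timeit module (the toy candidate function here is a hypothetical stand-in for d2l.multibox_target, not the real thing):

```python
from timeit import timeit

# Hypothetical stand-in for the function under suspicion
# (e.g. d2l.multibox_target); any callable works the same way.
def candidate(xs):
    return [x * x for x in xs]

data = list(range(1000))

# Average seconds per call over 1000 repetitions, measured in isolation
per_call = timeit(lambda: candidate(data), number=1000) / 1000
print(f'{per_call * 1e6:.1f} µs per call')
```

Swapping in each candidate function one at a time and comparing per-call times is how the d2l.multibox_target vs. npx.multibox_target gap below was narrowed down.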

Scratch implementation using d2l.multibox_target is slow in SSD training

.
.
.
            anchors, cls_preds, bbox_preds = net(X)
            # Label the category and offset of each anchor box
            bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors,
                                                                      Y)
            # Calculate the loss function using the predicted and labeled
            # category and offset values
.
.
.
print(f'{len(train_iter._dataset) / timer.stop():.1f} examples/sec on '
      f'{str(device)}')

>>> 2818.4 examples/sec on gpu

Scratch implementation, but using npx.multibox_target, is much faster in SSD training

.
.
.
            anchors, cls_preds, bbox_preds = net(X)
            # Label the category and offset of each anchor box
            bbox_labels, bbox_masks, cls_labels = npx.multibox_target(
                anchors, Y, cls_preds.transpose(0, 2, 1))
            # Calculate the loss function using the predicted and labeled
            # category and offset values
.
.
.
print(f'{len(train_iter._dataset) / timer.stop():.1f} examples/sec on '
      f'{str(device)}')

>>> 5244.6 examples/sec on gpu
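
From the two throughput figures above, the end-to-end training speedup from the single swap works out to roughly 1.86x:

```python
# Throughputs reported in the two runs above (examples/sec)
scratch = 2818.4   # with d2l.multibox_target
builtin = 5244.6   # with npx.multibox_target

speedup = builtin / scratch
print(f'{speedup:.2f}x')  # ~1.86x end-to-end training speedup
```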

These are timeit results for the function multibox_target in PyTorch vs. MXNet:

%%timeit
#@tab pytorch
labels = multibox_target(anchors.unsqueeze(dim=0),
                         ground_truth.unsqueeze(dim=0))

>>> 651 µs ± 867 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
labels = multibox_target(np.expand_dims(anchors, axis=0),
                         np.expand_dims(ground_truth, axis=0))

>>> 21.5 ms ± 24.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

This requires further investigation: even though the scratch function multibox_target implements the same logic in both frameworks, the MXNet version is roughly 33x slower than the PyTorch version in the timings above, which is not expected. Either there is room for speedup in the MXNet implementation of the function, or PyTorch is simply faster at this workload, in which case we can't really do much about this issue.
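
One place to look (an assumption, not a confirmed diagnosis of d2l.multibox_target): scratch implementations that loop in Python over framework arrays pay per-operation dispatch overhead on every iteration, while a fully vectorized version pays it once. A sketch of the effect, using NumPy as a neutral stand-in for either framework's ndarray:

```python
import time
import numpy as np

# 5000 synthetic anchor boxes in (x1, y1, x2, y2) format
boxes = np.random.rand(5000, 4)

# Looping in Python: one small array op per anchor box
start = time.perf_counter()
areas_loop = np.array([(b[2] - b[0]) * (b[3] - b[1]) for b in boxes])
t_loop = time.perf_counter() - start

# Vectorized: a single op over all anchor boxes at once
start = time.perf_counter()
areas_vec = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
t_vec = time.perf_counter() - start

assert np.allclose(areas_loop, areas_vec)
print(f'loop: {t_loop * 1e3:.2f} ms, vectorized: {t_vec * 1e3:.3f} ms')
```

If the MXNet scratch path has more such per-element work than the PyTorch one, or if MXNet's per-op dispatch cost is simply higher, that would explain the gap; profiling the function op-by-op would confirm or rule this out.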

AnirudhDagar avatar Dec 29 '20 05:12 AnirudhDagar