The calculation of average loss.

Open anti-destiny opened this issue 2 years ago • 2 comments

In distributed training, multiple processes typically run simultaneously, and for detection tasks the number of objects on each process can differ. We should therefore be careful when calculating the average loss. For example, one can call all_reduce to get the number of objects across all processes:

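import torch
import torch.distributed as dist

num_reg = torch.tensor(float(dt_bbox.shape[0]), device=dt_bbox.device)
if dist.is_available() and dist.is_initialized():
    # Sum the object counts over all processes, then divide by the number of processes.
    # all_reduce works in place and returns None for synchronous calls, so do not assign its result.
    dist.all_reduce(num_reg)
    num_reg = num_reg / dist.get_world_size()
num_reg = num_reg.clamp(min=1.0)

loss_bbox = torch.abs(dt_bbox - tgt_bbox).sum() / num_reg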
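
To make the idea above easier to try out, here is a minimal sketch that wraps it in two functions. The names sync_avg_factor and l1_bbox_loss are hypothetical and chosen only for illustration; they are not taken from mmdetection3d. The sketch assumes dt_bbox and tgt_bbox have the same shape, and it falls back to the local object count when no process group is initialized.

```python
import torch
import torch.distributed as dist


def sync_avg_factor(num_local: int, device=None) -> torch.Tensor:
    """Mean object count per process, clamped to at least 1 (hypothetical helper)."""
    factor = torch.tensor(float(num_local), device=device)
    if dist.is_available() and dist.is_initialized():
        # all_reduce sums in place (default op is SUM) and returns None
        # for synchronous calls, so its result is not assigned.
        dist.all_reduce(factor)
        factor = factor / dist.get_world_size()
    return factor.clamp(min=1.0)


def l1_bbox_loss(dt_bbox: torch.Tensor, tgt_bbox: torch.Tensor) -> torch.Tensor:
    """Sum of absolute errors divided by the synchronized factor."""
    factor = sync_avg_factor(dt_bbox.shape[0], device=dt_bbox.device)
    return torch.abs(dt_bbox - tgt_bbox).sum() / factor
```

Dividing by the mean count per process, rather than the total count, is intended to pair with gradient averaging across processes (as DistributedDataParallel does), so that the averaged result corresponds to a per-object mean over all processes. Clamping to 1 keeps the loss finite on a batch with no objects.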

However, I cannot find such an implementation in the detection heads of mmdetection3d (e.g., CenterPointHead).

My question is: has the mmdetection3d team noticed this problem? Do you think the correctness of the average matters when training detectors?

anti-destiny, May 31 '22 09:05

Since Base3DDetector is based on BaseDetector in mmdetection, you can find the related code at https://github.com/open-mmlab/mmdetection/blob/240d7a31c745578aa8c4df54c3074ce78b690c34/mmdet/models/detectors/base.py#L249. We use all_reduce in the _parse_losses function.

ZCMax, Jun 01 '22 06:06

Yes. I can see that the 'all_reduce' function is called in '_parse_losses'. However, the number of objects is not considered in your implementation.

For example, suppose there are two processes. For process A, avg_loss = 1.0 and n = 2; for process B, avg_loss = 2.0 and n = 4. The correct average loss is therefore (1.0 * 2 + 2.0 * 4) / (2 + 4) ≈ 1.67. However, simply calling all_reduce on the per-process averages gives (1.0 + 2.0) / 2 = 1.5.
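
The arithmetic in this example can be checked without any distributed setup. The sketch below simulates the two processes as plain Python tuples; the variable names are illustrative only. It also shows that dividing each process's loss sum by the mean object count and then averaging across processes recovers the per-object mean.

```python
# (sum of per-object losses, number of objects) for each simulated process
per_process = [(1.0 * 2, 2), (2.0 * 4, 4)]
world_size = len(per_process)

# Averaging the per-process means ignores how many objects each process saw.
local_means = [s / n for s, n in per_process]
mean_of_means = sum(local_means) / world_size

# Per-object average over all processes.
total_loss = sum(s for s, _ in per_process)
total_objects = sum(n for _, n in per_process)
global_mean = total_loss / total_objects

# Normalize each local sum by the mean object count, then average over processes.
mean_count = total_objects / world_size
synced = sum(s / mean_count for s, _ in per_process) / world_size

print(mean_of_means, round(global_mean, 2), round(synced, 2))  # 1.5 1.67 1.67
```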

anti-destiny, Jun 01 '22 09:06