Performance drops significantly when switching from apex to torch.cuda.amp

Open superaha opened this issue 2 years ago • 8 comments

Hi there,

As discussed here: apex vs torch.cuda.amp. Apex is no longer recommended.

I have implemented a version that replaces apex with torch.cuda.amp. However, after 300 epochs, the mAP with yolox-l is only 0.29 on the COCO dataset using the default hyperparameters.
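
For reference, the replacement follows the standard torch.cuda.amp pattern, sketched here with a dummy model rather than my actual YOLOX changes:

    import torch
    from torch import nn
    from torch.cuda.amp import autocast, GradScaler

    # Generic torch.cuda.amp training step (a sketch with a dummy model,
    # not the actual YOLOX trainer). GradScaler replaces apex's loss scaling.
    model = nn.Linear(8, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = GradScaler()

    inputs = torch.randn(4, 8, device="cuda")
    targets = torch.randn(4, 1, device="cuda")

    optimizer.zero_grad()
    with autocast():  # forward pass runs in mixed precision
        loss = nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)         # unscales gradients, then steps the optimizer
    scaler.update()                # adjusts the scale factor for the next step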

I have experimented with different learning rates, but that did not seem to help.

Does anyone have any suggestions?

Thanks.

superaha avatar Aug 05 '21 00:08 superaha

Hi, thanks for your experiments. Actually, we ran several experiments with torch.cuda.amp too. The main reason for the significant drop comes from here. It is not allowed to compute BCE under torch.cuda.amp autocast, so you have to add a manual cast for this loss function.

We don't have much time for this issue at the moment, but we agree that cuda.amp would be the better choice. It would be much appreciated if you could figure it out and send us a PR.
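
A minimal sketch of that manual cast, with placeholder tensor names and dummy shapes rather than the actual YOLOX code:

    import torch
    import torch.nn.functional as F
    from torch.cuda.amp import autocast

    # Step out of the autocast region and compute BCE on fp32 tensors.
    # cls_preds / gt_cls_per_image are placeholder dummies here.
    cls_preds = torch.rand(10, 80, device="cuda")         # predicted probabilities
    gt_cls_per_image = torch.rand(10, 80, device="cuda")  # soft targets
    with autocast(enabled=False):
        pair_wise_cls_loss = F.binary_cross_entropy(
            cls_preds.float(), gt_cls_per_image.float(), reduction="none"
        ).sum(-1)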

GOATmessi7 avatar Aug 05 '21 02:08 GOATmessi7

Hi, as far as I know, binary_cross_entropy cannot be used inside torch.cuda.amp autocast (it should be binary_cross_entropy_with_logits). How did you train your model?
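
A toy repro of that limitation (my own example, not YOLOX code):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, device="cuda")  # raw logits
    t = torch.rand(4, device="cuda")   # targets in [0, 1]
    with torch.cuda.amp.autocast():
        ok = F.binary_cross_entropy_with_logits(x, t)  # autocast-safe
        # F.binary_cross_entropy(torch.sigmoid(x), t)  # raises a RuntimeError
        # under autocast (binary_cross_entropy is unsafe to autocast)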

zhangming8 avatar Aug 05 '21 08:08 zhangming8

Yes, BCE will complain. So my remedy is as follows:

    cls_preds_ = (
        cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
        * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
    )
    # Convert the joint probability back to logits so the autocast-safe
    # *_with_logits variant can be used.
    cls_preds = torch.logit(cls_preds_.sqrt())
    pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
        cls_preds, gt_cls_per_image, reduction="none"
    ).sum(-1)

Just use torch.logit to convert the probability back to logits to make torch.cuda.amp happy. With this remedy, I can train the model. The only issue is that the mAP curve starts at a relatively low point.
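
One caveat worth noting: torch.logit returns ±inf when its input reaches exactly 0 or 1, so passing its eps argument may be the safer variant here:

    # eps clamps the input to [eps, 1 - eps] before log(x / (1 - x)),
    # avoiding +/-inf logits when the sigmoid product saturates.
    # (cls_preds_ is the tensor from the snippet above.)
    cls_preds = torch.logit(cls_preds_.sqrt(), eps=1e-6)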

[figure: mAP training curve, which starts at a relatively low point]

The curve still shows a trend toward better performance after 300 epochs. I tried a larger learning rate, but that leads to unstable training and can result in a NaN loss.

I am also experimenting with 500 epochs to see if this will help.

Thanks.

superaha avatar Aug 05 '21 20:08 superaha

Hi, I am using PyTorch's official amp in my implementation of YOLOX and it works. You can add a nested torch.cuda.amp.autocast subregion that disables amp when computing binary_cross_entropy. You can refer to: https://github.com/zhangming8/yolox-pytorch/blob/main/models/losses/yolox_loss.py#L271 https://github.com/zhangming8/yolox-pytorch/blob/main/models/yolox.py#L60
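
A minimal sketch of that subregion pattern, with a dummy model and placeholder names rather than the exact code in the linked files:

    import torch
    from torch import nn
    from torch.cuda.amp import autocast

    model = nn.Linear(8, 2).cuda()
    inputs = torch.randn(4, 8, device="cuda")
    targets = torch.rand(4, 2, device="cuda")

    with autocast():
        logits = model(inputs)             # mixed-precision forward pass
        with autocast(enabled=False):      # fp32 island just for the loss
            probs = torch.sigmoid(logits.float())
            loss = nn.functional.binary_cross_entropy(probs, targets)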

zhangming8 avatar Aug 10 '21 14:08 zhangming8

@zhangming8 I am also using with autocast(False): to disable amp for the BCE loss.

hiyyg avatar Aug 11 '21 12:08 hiyyg

Okay, I tried disabling autocast for the matching cost computation. Still, the performance is almost the same as the curve above, where the logit function is applied. @hiyyg, are you able to reproduce a good result?

superaha avatar Aug 11 '21 17:08 superaha

No, but in my experiments, using PyTorch's amp has always been better than using apex.

hiyyg avatar Aug 15 '21 03:08 hiyyg

Should I use with autocast(False): here?

https://github.com/Megvii-BaseDetection/YOLOX/blob/d78fe47802d3987a915645179d87f5c1a96e646f/yolox/models/yolo_head.py#L394-L401

developer0hye avatar Aug 12 '22 01:08 developer0hye