YOLOX
Performance drops significantly when switching from apex to torch.cuda.amp
Hi there,
As discussed here: apex vs torch.cuda.amp. Apex is no longer recommended.
I have implemented a version that replaces apex with torch.cuda.amp. However, after 300 epochs, the mAP with yolox-l is only 0.29 on the COCO dataset using the default hyperparameters.
I have experimented with different learning rates, which did not seem to help.
Does anyone have any suggestions?
Thanks.
Hi, thanks for your experiments. Actually, we also ran several experiments with torch.cuda.amp. The main reason for the significant drop comes from here. binary_cross_entropy is not allowed under torch.cuda.amp autocast, so you have to add a manual cast for this loss function.
We don't have much time for this issue at the moment, but we agree that cuda.amp would be a better choice. It would be much appreciated if you could figure it out and send us a PR.
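For context, a minimal standalone sketch (hypothetical tensors, not the YOLOX code) of the autocast-safe alternative PyTorch itself recommends: feed raw logits to binary_cross_entropy_with_logits, which folds the sigmoid into the loss and is allowed under autocast, instead of calling binary_cross_entropy on probabilities.

```python
import torch
import torch.nn.functional as F

# Hypothetical standalone example, not taken from YOLOX.
# F.binary_cross_entropy on probabilities is rejected inside a CUDA
# autocast region; F.binary_cross_entropy_with_logits on the raw
# logits computes the same loss and is autocast-safe.
logits = torch.randn(4, 8)
targets = torch.rand(4, 8)

safe = F.binary_cross_entropy_with_logits(logits, targets)
# Mathematically the same value (sigmoid folded into the loss):
plain = F.binary_cross_entropy(torch.sigmoid(logits), targets)
assert torch.allclose(safe, plain, atol=1e-6)
```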
Hi, as far as I know, binary_cross_entropy cannot be used under torch.cuda.amp autocast (it should be binary_cross_entropy_with_logits). How were you able to train your model?
Yes, BCE will complain. So my remedy is as follows:
```python
cls_preds_ = (
    cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
    * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
)
cls_preds = torch.logit(cls_preds_.sqrt())
pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
    cls_preds, gt_cls_per_image, reduction="none"
).sum(-1)
```
Just use torch.logit to convert the probabilities back to logits to make torch.cuda.amp happy. With this remedy, I can train the model. The only issue is that the mAP curve starts at a relatively low point.
It still shows a trend toward better performance after 300 epochs. I tried a larger learning rate, but that leads to unstable training and may result in NaN loss.
I am also experimenting with 500 epochs to see if this will help.
Thanks.
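A quick sanity check of the logit trick (standalone sketch, not the YOLOX code): since sigmoid(logit(p)) = p, feeding torch.logit(p) to the autocast-safe binary_cross_entropy_with_logits reproduces plain BCE on the probabilities, as long as p stays away from exactly 0 or 1.

```python
import torch
import torch.nn.functional as F

# Probabilities clamped away from 0/1 so torch.logit stays finite.
p = torch.rand(4, 8).clamp(1e-4, 1 - 1e-4)
t = torch.rand(4, 8)

# The remedy: recover logits from probabilities, then use the
# autocast-safe with-logits variant of the loss.
trick = F.binary_cross_entropy_with_logits(torch.logit(p), t)
plain = F.binary_cross_entropy(p, t)
assert torch.allclose(trick, plain, atol=1e-5)
```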
Hi, I am using PyTorch's official amp in my implementation of YOLOX and it works. You can add a subregion of torch.cuda.amp.autocast and disable amp when calculating binary_cross_entropy. You can refer to: https://github.com/zhangming8/yolox-pytorch/blob/main/models/losses/yolox_loss.py#L271 https://github.com/zhangming8/yolox-pytorch/blob/main/models/yolox.py#L60
@zhangming8 I am also using with autocast(False): to disable amp for the BCE loss.
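A sketch of that pattern (assumed names and shapes, not the exact code from either repo): wrap only the BCE call in an autocast-disabled subregion and cast its inputs to fp32, so the rest of the forward pass keeps running in mixed precision.

```python
import torch
import torch.nn.functional as F

def bce_cost_fp32(probs, targets):
    # Hypothetical helper. Disable autocast only for this op:
    # binary_cross_entropy is rejected under CUDA autocast, so
    # compute it in full precision with a manual .float() cast.
    with torch.cuda.amp.autocast(enabled=False):
        return F.binary_cross_entropy(
            probs.float(), targets.float(), reduction="none"
        ).sum(-1)
```

Everything outside the `with` block still benefits from amp; only the unsafe loss op runs in fp32.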
Okay, I tried disabling autocast for the matching cost computation. Still, the performance is almost the same as the above curve, where the logit function is applied. @hiyyg are you able to reproduce a good result?
No, but in my experiments, using PyTorch's amp is always better than using apex.
Should I use with autocast(False) here?
https://github.com/Megvii-BaseDetection/YOLOX/blob/d78fe47802d3987a915645179d87f5c1a96e646f/yolox/models/yolo_head.py#L394-L401