CenterPoint-KITTI

gaussian_focal_loss raises an error during training

Open · evil-master opened this issue 1 year ago · 1 comment

```
File "/home/neousys/cjg/CenterPoint-KITTI/pcdet/models/dense_heads/centerpoint_head_single.py", line 687, in gaussian_focal_loss
    pos_loss = (-(pred + eps).log() * (1 - pred).pow(alpha) * pos_weights).float()
RuntimeError: expected backend CUDA and dtype Float but got backend CUDA and dtype Byte
```

Hello, my environment is CUDA 10.0, torch 1.1 and spconv 1.0. After hitting this error I tried forcing a cast to float, but it did not work. What should I do?

evil-master, Jul 21 '23

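I solved it by changing int to float; it may be caused by torch 1.1.

```python
def gaussian_focal_loss(pred, gaussian_target, alpha=2.0, gamma=4.0):
    eps = 1e-12
    # On torch 1.1, eq() returns a uint8 (Byte) mask, and Float * Byte raises
    # "expected ... dtype Float but got ... dtype Byte".
    pos_weights = gaussian_target.eq(1)
    neg_weights = (1 - gaussian_target).pow(gamma)
    # Cast to float before multiplying; calling .float() on the product is too late.
    pos_loss = -(pred.float() + eps).log() * (1 - pred.float()).pow(alpha) * pos_weights.float()
    neg_loss = -(1 - pred.float() + eps).log() * pred.pow(alpha) * neg_weights
    return pos_loss + neg_loss
```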
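A version-agnostic variant of the same fix, offered as a minimal sketch rather than the repository's own code: it assumes `pred` holds sigmoid probabilities in [0, 1] and `gaussian_target` is a heatmap whose peaks are exactly 1 (as the `eq(1)` in the original implies). Every operand is cast to `pred`'s floating dtype before any multiplication, which avoids the Float/Byte mismatch on torch 1.1 (where comparison ops return uint8 and mixed-dtype arithmetic is rejected) and is harmless on newer releases.

```python
import torch


def gaussian_focal_loss(pred, gaussian_target, alpha=2.0, gamma=4.0):
    """Element-wise Gaussian focal loss (sketch).

    All operands are cast to pred's float dtype up front, so no
    Float * Byte (or Float * Bool) multiplication can happen.
    """
    eps = 1e-12
    pred = pred.float()
    gaussian_target = gaussian_target.to(pred.dtype)
    # Peak locations (target == 1) are positives; cast the mask explicitly.
    pos_weights = gaussian_target.eq(1).to(pred.dtype)
    # Negatives near a peak are down-weighted by (1 - target)^gamma.
    neg_weights = (1 - gaussian_target).pow(gamma)
    pos_loss = -(pred + eps).log() * (1 - pred).pow(alpha) * pos_weights
    neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights
    return pos_loss + neg_loss


if __name__ == "__main__":
    pred = torch.tensor([0.9, 0.1, 0.5])
    # A uint8 target would also work, since it is cast before use.
    target = torch.tensor([1.0, 0.0, 0.5])
    loss = gaussian_focal_loss(pred, target)
    print(loss.dtype, loss.sum().item())
```

Casting once at the top keeps the arithmetic lines identical to the original formula, so the only behavioral change is that the dtype of every operand is fixed before it is used.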

evil-master, Aug 10 '23