mmdetection3d
Cannot disable fp16.
Hello, I'm building a model that contains a pre-trained teacher model, like this:
def __init__(self, xxx, teacher_model_cfg=None):
    ...
    if teacher_model_cfg:
        # Build the teacher, initialize it from its init_cfg, and freeze it.
        self.teacher_model = builder.build_model(teacher_model_cfg)
        if self.teacher_model.init_cfg:
            initialize(self.teacher_model, self.teacher_model.init_cfg)
        for p in self.teacher_model.parameters():
            p.requires_grad = False
I'd like to use fp16 for training, so I add this to the config file:
fp16 = dict(loss_scale='dynamic')
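As far as I understand, this config makes the runner use mmcv's Fp16OptimizerHook, which runs something like wrap_fp16_model on the whole detector; depending on the mmcv/PyTorch version this either halves the weights directly or relies on the auto_fp16/autocast machinery, and in both cases the teacher submodule is treated the same as the student. A rough illustration of my reading of it (not code I run myself, and the toy module here is just a stand-in for the full detector):

from mmcv.runner import wrap_fp16_model
import torch.nn as nn

detector = nn.Sequential(nn.Linear(4, 4))  # stand-in for the full student+teacher detector
# On older mmcv/torch this converts the weights to half and patches norm layers;
# on newer versions it mainly sets fp16_enabled=True on submodules that define it.
# Either way the frozen teacher inside the detector is affected as well.
wrap_fp16_model(detector)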
However, when I call the teacher model like this:
with torch.no_grad():
    teacher_x = self.teacher_model.forward_teacher(points)
I get an error in mmdet3d/ops/spconv/ops.py: the dtype of the filters is float32, but the input is float16.
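Just to illustrate what the error amounts to (this is plain PyTorch, not the actual spconv kernel): a layer whose weights stayed in float32 is being fed float16 features.

# Minimal illustration of the same kind of dtype mismatch (not the spconv code itself):
import torch

conv = torch.nn.Conv2d(4, 4, 3)       # weights remain float32
x = torch.randn(1, 4, 8, 8).half()    # features arrive as float16
conv(x)  # RuntimeError: input type and weight type should be the same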
I think the teacher model (CenterPoint) might not support fp16, so I added force_fp32 to the teacher model's method:
@force_fp32()
def forward_teacher(self,
                    points=None,
                    img=None,
                    img_metas=None):
    voxels, num_points, coors = self.voxelize(points)
    voxel_features = self.pts_voxel_encoder(voxels, num_points, coors)
    batch_size = coors[-1, 0] + 1
    x = self.pts_middle_encoder(voxel_features, coors, batch_size)
    return x
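From reading the mmcv source, force_fp32 seems to be a no-op unless the module's fp16_enabled attribute is True, so among the other things I tried were variations along these lines (a rough sketch of what I mean, using the names from forward_teacher above; I'm not sure either is the intended usage):

# 1) make sure the force_fp32 decorator on the teacher is actually active
self.teacher_model.fp16_enabled = True

# 2) or cast the features back to float32 by hand before the sparse middle encoder
voxel_features = voxel_features.float()
x = self.pts_middle_encoder(voxel_features, coors, batch_size)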
But I still get the same error. I have tried several things, but I just cannot disable fp16 in the teacher model. So I'd like to ask: how can I disable fp16 in the teacher model, or how can I fix this error so that fp16 also works in the teacher model? Thank you.