Mask2Former head fails during training because of AMP (ValueError: cost matrix is infeasible)
Thanks for your error report and we appreciate it a lot.
Checklist
- I have searched related issues but cannot get the expected help.
- I have read the FAQ documentation but cannot get the expected help.
- The bug has not been fixed in the latest version.
Describe the bug
A clear and concise description of what the bug is.
Reproduction
- What command or script did you run?
A placeholder for the command.
- Did you make any modifications to the code or config? Do you understand what you have modified?
- What dataset did you use?
Environment
- Please run python mmdet/utils/collect_env.py to collect the necessary environment information and paste it here.
- You may add additional information that may be helpful for locating the problem, such as
- How you installed PyTorch [e.g., pip, conda, source]
- Other environment variables that may be related (such as $PATH, $LD_LIBRARY_PATH, $PYTHONPATH, etc.)
Error traceback
If applicable, paste the error traceback here.
A placeholder for traceback.
Bug fix
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
@AlexeySudakovB01-109 Thank you for the feedback. We will work on fixing it as soon as possible.
any update?
Looking forward to further work!
@we11as22 @hhaAndroid I am facing the same issue when training Mask2Former with --amp. I found a solution here: huggingface/transformers#21644. As it proposes, the error is caused by how scipy.optimize.linear_sum_assignment handles infinite values; replacing them with very large finite numbers fixes the issue. I added the following code before line 125:
```python
cost = torch.minimum(cost, torch.tensor(1e10))
cost = torch.maximum(cost, torch.tensor(-1e10))
```
https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/assigners/hungarian_assigner.py#L123-L131
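For context, here is a minimal standalone sketch (not mmdetection code; the cost values are made up) of why an FP16 overflow in the cost matrix makes scipy.optimize.linear_sum_assignment raise this error, and how the clamp above avoids it:

```python
import torch
from scipy.optimize import linear_sum_assignment

# Under AMP, an FP16 overflow can fill an entire row of the cost
# matrix with inf, so no feasible assignment exists.
cost = torch.tensor([[float('inf'), float('inf')],
                     [0.3, 0.5]])

# linear_sum_assignment(cost.numpy())
# -> ValueError: cost matrix is infeasible

# Clamping to large finite values (the fix above) makes every
# entry usable again, and a valid assignment is found.
cost = torch.minimum(cost, torch.tensor(1e10))
cost = torch.maximum(cost, torch.tensor(-1e10))
row_ind, col_ind = linear_sum_assignment(cost.numpy())
print(row_ind, col_ind)  # [0 1] [1 0]
```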
But when trying to train with AMP again, I encountered a NaN mask loss problem.
Further debugging revealed an issue in the calculation of loss_mask in this line. In AMP mode, multiplying the FP16 variable num_total_masks by the large integer variable self.num_points (12544 in my configs) results in an overflowed FP16 value, which in turn produces the NaN mask loss. By casting num_total_masks to FP32 with avg_factor=num_total_masks.float() * self.num_points, I managed to eliminate this issue. The model can now be trained successfully in AMP mode, but the accuracy is still under validation 😂.
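A quick sketch of the overflow (num_points is from my config; the value of num_total_masks here is purely illustrative):

```python
import torch

num_points = 12544                                  # self.num_points in my config
num_total_masks = torch.tensor(100.0,               # illustrative value; under AMP
                               dtype=torch.float16) # this tensor is FP16

# FP16 can only represent values up to 65504, so the product overflows.
avg_factor = num_total_masks * num_points
print(avg_factor)  # tensor(inf, dtype=torch.float16) -> NaN loss downstream

# Casting to FP32 first keeps the product finite (the fix above).
avg_factor = num_total_masks.float() * num_points
print(avg_factor)  # tensor(1254400.)
```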
Hope this is helpful to you!