
Bug that may cause an out-of-index error in the cross_entropy function

Open • KaneiGi opened this issue 1 year ago • 1 comment

Thanks for your error report and we appreciate it a lot.

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. The bug has not been fixed in the latest version.

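Describe the bug

In `mmseg/models/losses/cross_entropy_loss.py`, line 66:

`label_weights = torch.stack([class_weight[cls] for cls in label]).to(device=class_weight.device)`

There is no check that each `cls` in `label` falls within the range of `class_weight`, so a label value outside that range (for example, a pixel marked with `ignore_index`) causes an out-of-index error.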
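A minimal sketch of how the unchecked lookup can fail, assuming 19 classes and an ignore value of 255 (common Cityscapes settings, but assumptions here, not values taken from the report). The helper `lookup_weights_unchecked` is a hypothetical stand-in written in the style of the quoted line, not the library's function.

```python
import torch

# Assumed values for illustration only.
num_classes = 19
ignore_index = 255

class_weight = torch.ones(num_classes)
# A flat label tensor that contains one ignored pixel.
label = torch.tensor([0, 3, ignore_index, 18])


def lookup_weights_unchecked(class_weight, label):
    """Hypothetical helper: per-element lookup with no range check."""
    return torch.stack([class_weight[cls] for cls in label])


try:
    lookup_weights_unchecked(class_weight, label)
    raised = False
except (IndexError, RuntimeError) as err:
    # On CPU the out-of-range label surfaces as an indexing error.
    raised = True
    print(f"{type(err).__name__}: {err}")
```

Any label value greater than or equal to `len(class_weight)` triggers the same failure, so the lookup is only safe once such labels are masked out or validated.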

Reproduction

  1. What command or script did you run?

`python train.py pidnet-s_2xb6-120k_1024x1024-cityscapes.py`


2. Did you make any modifications on the code or config? Did you understand what you have modified?

3. What dataset did you use?
cityscapes

**System environment:**

- sys.platform: linux
- Python: 3.10.0 (default, May 11 2024, 13:44:05) [GCC 7.5.0]
- CUDA available: False
- MUSA available: False
- numpy_random_seed: 304
- GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
- PyTorch: 1.12.1+cu102
- PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)

1. Please run `python mmseg/utils/collect_env.py` to collect necessary environment information and paste it here.
2. You may add additional information that may be helpful for locating the problem, such as
- How you installed PyTorch \[e.g., pip, conda, source\]
- Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)

**Error traceback**

If applicable, paste the error traceback here.

```none
A placeholder for traceback.
```

![Image](https://github.com/user-attachments/assets/d0226b05-aca3-4ab1-82db-28928f602455)

Proposed fix (the original line 66 is kept as a comment):

    if (avg_factor is None) and reduction == 'mean':
        if class_weight is None:
            if avg_non_ignore:
                avg_factor = label.numel() - (
                    label == ignore_index).sum().item()
            else:
                avg_factor = label.numel()

        else:
            # the average factor should take the class weights into account
            mask = label == ignore_index
            # map ignored labels to class 0 so the lookup stays in range
            masked_label = label * ~mask
            label_weights = class_weight[masked_label]
            # ignored positions contribute no weight
            label_weights[mask] = 0

            # label_weights = torch.stack([class_weight[cls] for cls in label
            #                              ]).to(device=class_weight.device)

            if avg_non_ignore:
                label_weights[label == ignore_index] = 0
            avg_factor = label_weights.sum()
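The masked lookup from the fix can be tried in isolation. The following is a minimal sketch, assuming an ignore value of 255 and a toy three-class weight vector. The function name `weighted_avg_factor` is hypothetical and only wraps the steps shown above.

```python
import torch


def weighted_avg_factor(label, class_weight, ignore_index=255):
    """Hypothetical helper: sum of per-pixel class weights, ignoring
    positions whose label equals ignore_index."""
    mask = label == ignore_index
    # Ignored labels are mapped to class 0 so indexing stays in range.
    masked_label = label * ~mask
    # Advanced indexing returns a new tensor, so editing it in place
    # does not modify class_weight.
    label_weights = class_weight[masked_label]
    # Ignored positions contribute no weight.
    label_weights[mask] = 0
    return label_weights.sum()


# Toy inputs, chosen for illustration only.
class_weight = torch.tensor([1.0, 2.0, 0.5])
label = torch.tensor([[0, 1, 255],
                      [2, 255, 1]])

avg_factor = weighted_avg_factor(label, class_weight)
print(avg_factor.item())  # 1.0 + 2.0 + 0.5 + 2.0 = 5.5
```

Because the lookup is a single indexing operation, it works for labels of any shape and avoids the per-element Python loop. Note that this sketch only guards against `ignore_index`. Other out-of-range label values would still need a separate check.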

If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!

KaneiGi • Jul 12 '24 07:07

I had the same issue. Your fix works for me. Thank you.

AreopagX • Oct 30 '24 08:10