Bug that may cause an out-of-range index error in the cross_entropy function
Thanks for your error report and we appreciate it a lot.
Checklist
- I have searched related issues but cannot get the expected help.
- The bug has not been fixed in the latest version.
Describe the bug
In mmseg's `models/losses/cross_entropy_loss.py`, line 66:
`label_weights = torch.stack([class_weight[cls] for cls in label ]).to(device=class_weight.device)`
There is no check that every `cls` in `label` is a valid index into `class_weight`, so any label value outside that range raises an index error.
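A minimal sketch of how the lookup can fail, assuming 19 classes and an ignore index of 255 (the usual Cityscapes setup in mmseg; both numbers and the toy `label` tensor are assumptions for illustration, not values taken from this report):

```python
import torch

# Assumed values for illustration only.
num_classes, ignore_index = 19, 255
class_weight = torch.ones(num_classes)
label = torch.tensor([0, 3, ignore_index, 18])

try:
    # Same expression as the line quoted above.
    torch.stack([class_weight[cls] for cls in label])
    failed = False
except (IndexError, RuntimeError) as err:
    # The exception type and message may differ between CPU and GPU.
    failed = True
    print(f"lookup failed: {err}")

# Diagnostic: which label values cannot be used as an index into class_weight?
out_of_range = (label < 0) | (label >= num_classes)
bad_values = label[out_of_range].unique().tolist()
print("out-of-range label values:", bad_values)
```

Running a check like `out_of_range` on a batch before the loss is computed is one way to tell whether the offending value is the ignore index or a genuinely mislabeled pixel.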
Reproduction
1. What command or script did you run?
`python train.py pidnet-s_2xb6-120k_1024x1024-cityscapes.py`
2. Did you make any modifications on the code or config? Did you understand what you have modified?
3. What dataset did you use?
cityscapes
**System environment:**
sys.platform: linux
Python: 3.10.0 (default, May 11 2024, 13:44:05) [GCC 7.5.0]
CUDA available: False
MUSA available: False
numpy_random_seed: 304
GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
PyTorch: 1.12.1+cu102
PyTorch compiling details: PyTorch built with:
- GCC 7.3
- C++ Version: 201402
- Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
1. Please run `python mmseg/utils/collect_env.py` to collect necessary environment information and paste it here.
2. You may add additional information that may be helpful for locating the problem, such as
- How you installed PyTorch \[e.g., pip, conda, source\]
- Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
**Error traceback**
If applicable, paste the error traceback here.
```none
A placeholder for traceback.
```

**Proposed fix**

```python
if (avg_factor is None) and reduction == 'mean':
    if class_weight is None:
        if avg_non_ignore:
            avg_factor = label.numel() - (label == ignore_index).sum().item()
        else:
            avg_factor = label.numel()
    else:
        # the average factor should take the class weights into account
        mask = label == ignore_index
        # point ignored positions at class 0 so the lookup stays in range
        masked_label = label * ~mask
        label_weights = class_weight[masked_label]
        label_weights[mask] = 0
        # label_weights = torch.stack([class_weight[cls] for cls in label
        # ]).to(device=class_weight.device)
        if avg_non_ignore:
            label_weights[label == ignore_index] = 0
        avg_factor = label_weights.sum()
```
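The masked lookup in the fix above can be tried on a toy tensor. This is only a sketch: `masked_class_weights` is a hypothetical helper name, and the three-class weights and the default `ignore_index=255` are assumptions chosen for illustration.

```python
import torch


def masked_class_weights(label, class_weight, ignore_index=255):
    """Hypothetical helper mirroring the masked lookup in the fix above."""
    mask = label == ignore_index
    masked_label = label * ~mask          # ignored positions now point at class 0
    weights = class_weight[masked_label]  # tensor indexing returns a copy
    weights[mask] = 0                     # ignored positions contribute nothing
    return weights


class_weight = torch.tensor([0.5, 1.0, 2.0])      # assumed 3-class weights
label = torch.tensor([[0, 2, 255], [1, 255, 2]])  # 255 marks ignored pixels

weights = masked_class_weights(label, class_weight)
avg_factor = weights.sum()
print(weights)
print(avg_factor)
```

Because indexing with a tensor returns a copy, zeroing the ignored positions leaves `class_weight` itself unchanged. Note that this approach masks only `ignore_index`; any other label value outside the range of `class_weight` would still raise an error.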
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
Had the same issue. Your fix works for me. Thank you