Pretrained model on BTCV doesn't reproduce the mean dice score
Hi there,
The provided model weights for BTCV (swinunetr-base) don't reproduce the reported mean Dice score on the validation set. I only get a mean Dice score of around 0.16 to 0.2, which is far below the stated 0.8.
Basically, I used the Google Colab code as follows:
```python
with torch.no_grad():
    dice_list_case = []
    for i, batch in enumerate(val_loader):
        val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
        # original_affine = batch['label_meta_dict']['affine'][0].numpy()
        _, _, h, w, d = val_labels.shape
        target_shape = (h, w, d)
        # img_name = batch['image_meta_dict']['filename_or_obj'][0].split('/')[-1]
        # print("Inference on case {}".format(img_name))
        val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model, overlap=0.5, mode="gaussian")
        val_outputs = torch.softmax(val_outputs, 1).cpu().numpy()
        val_outputs = np.argmax(val_outputs, axis=1).astype(np.uint8)[0]
        val_labels = val_labels.cpu().numpy()[0, 0, :, :, :]
        val_outputs = resample_3d(val_outputs, target_shape)
        dice_list_sub = []
        for i in range(1, 14):
            organ_Dice = dice(val_outputs == i, val_labels == i)
            dice_list_sub.append(organ_Dice)
        mean_dice = np.mean(dice_list_sub)
        print("Mean Organ Dice: {}".format(mean_dice))
        dice_list_case.append(mean_dice)
        # nib.save(nib.Nifti1Image(val_outputs.astype(np.uint8), original_affine),
        #          os.path.join(output_directory, img_name))
    print("Overall Mean Dice: {}".format(np.mean(dice_list_case)))
```
The model was loaded from the pretrained weights you provided (listed in the table below; a rough loading sketch follows it), and the data transforms and data loader are set up exactly as in the provided code:
| Name | Dice (overlap=0.7) | Dice (overlap=0.5) | Feature Size | # params (M) | Self-Supervised Pre-trained | Download |
|---|---|---|---|---|---|---|
| Swin UNETR/Base | 82.25 | 81.86 | 48 | 62.1 | Yes | model |
| Swin UNETR/Small | 79.79 | 79.34 | 24 | 15.7 | No | model |
| Swin UNETR/Tiny | 72.05 | 70.35 | 12 | 4.0 | No | model |
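For reference, I instantiate and load the Base model roughly like this. The checkpoint filename is a placeholder, and the constructor arguments are my reading of the Base configuration from the table (feature_size=48, 13 organs plus background), so please correct me if they differ from test.py:

```python
import torch
from monai.networks.nets import SwinUNETR

# Swin UNETR/Base for BTCV: 1 input channel, 14 output classes, feature_size=48.
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=14,
    feature_size=48,
    use_checkpoint=True,
).cuda()

checkpoint = torch.load("swin_unetr_base.pt", map_location="cpu")  # placeholder path
state_dict = checkpoint.get("state_dict", checkpoint)  # some checkpoints nest the weights
model.load_state_dict(state_dict)
model.eval()
```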
I wonder if I actually missed anything here. I'd appreciate your feedback! Thanks.
Hi @ZEKAICHEN, thanks for raising the issue. We've double-checked and re-run test.py. If the code from https://github.com/Project-MONAI/research-contributions/tree/main/SwinUNETR/BTCV is used together with the downloaded Swin UNETR/Base model, it should give the Dice scores below with overlap=0.5:
```
Inference on case img0035.nii.gz
Mean Organ Dice: 0.7715836852979835
Inference on case img0036.nii.gz
Mean Organ Dice: 0.8377579306350628
Inference on case img0037.nii.gz
Mean Organ Dice: 0.8386162560902106
Inference on case img0038.nii.gz
Mean Organ Dice: 0.7809781930534572
Inference on case img0039.nii.gz
Mean Organ Dice: 0.8375578949580794
Inference on case img0040.nii.gz
Mean Organ Dice: 0.8275152177091785
Overall Mean Dice: 0.815668196290662
```
Could you provide more details of your testing implementation? We can help dig deeper into the problem. Thanks!
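In the meantime, one quick check that can rule out a silent weight-loading problem (a Dice around 0.16 to 0.2 can indicate the checkpoint was never actually applied, or that the output-channel count is wrong): load with `strict=False`, inspect the mismatched keys, and run a dummy forward pass. This is only a sketch; the checkpoint path is a placeholder, and `model` is assumed to be the SwinUNETR instance from above.

```python
import torch

checkpoint = torch.load("swin_unetr_base.pt", map_location="cpu")  # placeholder path
state_dict = checkpoint.get("state_dict", checkpoint)

# strict=False reports mismatches instead of raising; both lists should be empty
# if the downloaded weights really match the instantiated model.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing keys:", len(missing))
print("unexpected keys:", len(unexpected))

# A dummy forward pass confirms the expected 14 output channels (13 organs + background).
with torch.no_grad():
    out = model(torch.zeros(1, 1, 96, 96, 96).cuda())
print("output shape:", tuple(out.shape))
```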