
Training loss collapses after a few epochs

Open sulaimanvesal opened this issue 1 year ago • 1 comment

Hi team,

Thank you for all your support.

I have been using MONAI for a while for our multi-modal image segmentation (3D DynUNet). Recently, after fixing many issues, I was able to run a model. However, there is some odd behavior during training: the model converges nicely at first, but after 10-15 epochs both the loss and the Dice metric collapse and go to zero.

Early in training I also get the warning `invalid value encountered in true_divide`, which after debugging I believe comes from the Dice metric function. I made sure that all labels contain only binary values [0, 1]; the model performs lesion segmentation, so the task is binary.

I was wondering if I am doing something wrong here.

initialize network with normal
Tue Aug  9 06:33:18 2022 Epoch: 0
Final training  0/49 loss: 0.9928 time 2634.39s
y_pred should be a binarized tensor.
y should be a binarized tensor.
Final validation stats 0/49 , Dice_TC: 0.069751926 , Dice_Avg: 0.069751926 , time 491.24s
new best (0.000000 --> 0.069752).
Saving checkpoint weights_multiInput/model_fold2_multiInput.pt
Tue Aug  9 07:25:25 2022 Epoch: 1
Final training  1/49 loss: 0.9752 time 2545.19s
Final validation stats 1/49 , Dice_TC: 0.2406731 , Dice_Avg: 0.2406731 , time 474.64s
new best (0.069752 --> 0.240673).
Saving checkpoint weights_multiInput/model_fold2_multiInput.pt
Tue Aug  9 08:15:46 2022 Epoch: 2
Final training  2/49 loss: 0.9378 time 2539.51s
Final validation stats 2/49 , Dice_TC: 0.32946482 , Dice_Avg: 0.32946482 , time 463.56s
new best (0.240673 --> 0.329465).
Saving checkpoint weights_multiInput/model_fold2_multiInput.pt
Tue Aug  9 09:05:50 2022 Epoch: 3
Final training  3/49 loss: 0.9069 time 2520.52s
Final validation stats 3/49 , Dice_TC: 0.4997478 , Dice_Avg: 0.4997478 , time 473.23s
new best (0.329465 --> 0.499748).
Saving checkpoint weights_multiInput/model_fold2_multiInput.pt
Tue Aug  9 09:55:44 2022 Epoch: 4
Final training  4/49 loss: 0.9046 time 2647.55s
Final validation stats 4/49 , Dice_TC: 0.48587176 , Dice_Avg: 0.48587176 , time 444.82s
Tue Aug  9 10:47:16 2022 Epoch: 5
Final training  5/49 loss: 0.8924 time 2472.73s
Final validation stats 5/49 , Dice_TC: 0.5072948 , Dice_Avg: 0.5072948 , time 462.37s
new best (0.499748 --> 0.507295).
Saving checkpoint weights_multiInput/model_fold2_multiInput.pt
Tue Aug  9 11:36:12 2022 Epoch: 6
Final training  6/49 loss: 0.8802 time 2534.62s
Final validation stats 6/49 , Dice_TC: 0.51497203 , Dice_Avg: 0.51497203 , time 504.30s
new best (0.507295 --> 0.514972).
Saving checkpoint weights_multiInput/model_fold2_multiInput.pt
Tue Aug  9 12:26:52 2022 Epoch: 7
Final training  7/49 loss: 0.5781 time 2409.70s
Final validation stats 7/49 , Dice_TC: 2.0384609e-14 , Dice_Avg: 2.0384609e-14 , time 479.41s
Tue Aug  9 13:15:01 2022 Epoch: 8
Final training  8/49 loss: 0.2768 time 2407.49s
Final validation stats 8/49 , Dice_TC: 2.9713095e-17 , Dice_Avg: 2.9713095e-17 , time 429.29s
Tue Aug  9 14:02:18 2022 Epoch: 9
Final training  9/49 loss: 0.2783 time 2592.64s
Final validation stats 9/49 , Dice_TC: 2.2672533e-18 , Dice_Avg: 2.2672533e-18 , time 471.07s
Tue Aug  9 14:53:22 2022 Epoch: 10
Final training  10/49 loss: 0.2783 time 2667.51s
Final validation stats 10/49 , Dice_TC: 1.17504706e-20 , Dice_Avg: 1.17504706e-20 , time 502.52s
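For context, the `invalid value encountered in true_divide` warning usually points at a Dice computation whose denominator is zero, i.e. both the prediction and the ground truth are empty for a case. Below is a minimal NumPy sketch (not the MONAI implementation) of a Dice coefficient with a smoothing term that avoids the 0/0:

```python
import numpy as np

def dice(pred, gt, smooth=1e-5):
    """Dice coefficient with a smoothing term, so an empty
    prediction/ground-truth pair yields ~1.0 instead of 0/0."""
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    intersection = np.logical_and(pred, gt).sum()
    return (2.0 * intersection + smooth) / (pred.sum() + gt.sum() + smooth)

empty = np.zeros((4, 4))
print(dice(empty, empty))  # → 1.0, no division-by-zero warning
```

A Dice value that is merely *near* zero (like `2e-14` in the log above) typically means the network is predicting pure background, which points at the training dynamics rather than the metric itself.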
def get_loader(train_images_t2w, train_images_adc, train_images_dwi, train_segs,
               valid_images_t2w, valid_images_adc, valid_images_dwi, valid_segs,
               patch_size=[256, 256, 20]):

    data_dicts_train = [{'image_t2': image_name_t2,
                         'image_adc': image_name_adc,
                         'image_dwi': image_name_dwi,
                         'label': label_name}
                        for image_name_t2, image_name_adc, image_name_dwi, label_name
                        in zip(train_images_t2w, train_images_adc, train_images_dwi, train_segs)]
    train_transform = Compose([
        LoadImaged(keys=['image_t2', 'image_adc','image_dwi', 'label']),
        AddChanneld(keys=['image_t2', 'image_adc','image_dwi', 'label']),
        Spacingd(keys=['image_t2', 'image_adc', 'image_dwi', 'label'], pixdim=(0.5, 0.5, 3.0),
                 mode=("bilinear", "bilinear", "bilinear", "nearest")),
        CenterSpatialCropd(keys=['image_t2', 'image_adc','image_dwi', 'label'],roi_size = patch_size), 
        SpatialPadd(keys=['image_t2', 'image_adc','image_dwi', 'label'], spatial_size=patch_size, method='end'),
        ConcatItemsd(keys=['image_t2', 'image_adc','image_dwi'], name="image"),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        RandAdjustContrastd(keys=['image'],gamma=(0.5, 2.5),prob=0.2),
        RandAffined(
           keys=("image", "label"),
           prob=0.5,
           #rotate_range=np.pi / 12,
           translate_range=(256*0.0625, 256*0.0625),
           scale_range=(0.1, 0.1),
           mode="nearest",
           padding_mode="reflection",
        ),
        OneOf(
           [ 
               RandGridDistortiond(keys=("image", "label"), 
                                   prob=0.5, distort_limit=(-0.05, 0.05), 
                                   mode="nearest", 
                                   padding_mode="reflection"),
               RandCoarseDropoutd(
                   keys=("image", "label"),
                   holes=5,
                   max_holes=8,
                   spatial_size=(1, 1, 1),
                   max_spatial_size=(12, 12, 12),
                   fill_value=0.0,
                   prob=0.5,
               ),
           ]
        ),

        Lambdad(keys="image", func=lambda x: x / x.max()), #NormalizeIntensityd(keys=['image'], nonzero=True, channel_wise=True),
        ToTensord(keys=['image', 'label'])
    ])
    data_dicts_valid = [{'image_t2': image_name_t2, 
                         'image_adc': image_name_adc, 
                         'image_dwi': image_name_dwi,
                         'label': label_name} for image_name_t2, image_name_adc, image_name_dwi, label_name in zip(valid_images_t2w,valid_images_adc,valid_images_dwi, valid_segs)]
    valid_transform = Compose([
        LoadImaged(keys=['image_t2', 'image_adc','image_dwi', 'label']), 
        AddChanneld(keys=['image_t2', 'image_adc','image_dwi', 'label']),
        Spacingd(keys=['image_t2', 'image_adc', 'image_dwi', 'label'], pixdim=[0.5, 0.5, 3.0],
                 mode=("bilinear", "bilinear", "bilinear", "nearest")),

        CenterSpatialCropd(keys=['image_t2', 'image_adc','image_dwi', 'label'], roi_size = patch_size), 
        SpatialPadd(keys=['image_t2', 'image_adc','image_dwi', 'label'], spatial_size=patch_size, method='end'),
        ConcatItemsd(keys=['image_t2', 'image_adc','image_dwi'], name="image"),
        Lambdad(keys="image", func=lambda x: x / x.max()), #NormalizeIntensityd(keys=['image'], nonzero=True, channel_wise=True),
        ToTensord(keys=['image', 'label'])
    ])
    
    train_ds = data.Dataset(data=data_dicts_train, transform=train_transform)
    train_loader = data.DataLoader(
        train_ds,
        batch_size=1,
        shuffle=True, #collate_fn=list_data_collate,
        num_workers=12,
        pin_memory=True,
    )
    val_ds = data.Dataset(data=data_dicts_valid, transform=valid_transform)
    val_loader = data.DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=12, #collate_fn=list_data_collate,
        pin_memory=False
    )
    return train_loader, val_loader
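One failure mode worth ruling out with the `x / x.max()` normalisation above: if a cropped volume happens to be all zeros, `x.max()` is 0 and the division fills the image with NaNs, which can destabilise training in exactly the way the log shows. A small NumPy sketch of a batch sanity check (the `check_batch` helper and the shapes are hypothetical, not part of the original code):

```python
import numpy as np

def check_batch(image, label):
    """Hypothetical sanity check for one (image, label) pair:
    catch NaN/inf images and non-binary labels before they
    silently break a binary Dice loss."""
    assert np.isfinite(image).all(), "image contains NaN/inf (all-zero crop divided by its max?)"
    assert set(np.unique(label)) <= {0.0, 1.0}, "label is not binary after resampling"

# an all-zero crop divided by its own max becomes NaN everywhere
crop = np.zeros((1, 8, 8, 4), dtype=np.float32)
with np.errstate(invalid="ignore"):
    normalised = crop / crop.max()
print(np.isnan(normalised).all())  # → True
```

Running such a check over a few batches from `train_loader` would quickly confirm whether the data pipeline, rather than the model, is producing invalid values.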

sulaimanvesal avatar Aug 11 '22 17:08 sulaimanvesal

Hi @sulaimanvesal, could you please provide more information, such as your MONAI version, your data, and how you compute your Dice? Have you also verified the data entering the network? To pin down exactly what went wrong, you could simplify your data preparation and add data-checking steps.
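Regarding the `y_pred should be a binarized tensor` warnings in the log: MONAI's `DiceMetric` expects discrete {0, 1} inputs, so raw logits are usually passed through post-processing transforms such as `Activations(sigmoid=True)` and `AsDiscrete` with a threshold first (transform names from MONAI; the exact threshold argument differs between versions). A plain-NumPy sketch of the equivalent operation:

```python
import numpy as np

def binarize_logits(logits, threshold=0.5):
    """Sigmoid + thresholding: the plain-NumPy equivalent of
    applying a sigmoid activation and then discretising at 0.5."""
    probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
    return (probs >= threshold).astype(np.float32)

logits = np.array([-2.0, 0.0, 3.0])
print(binarize_logits(logits))  # → [0. 1. 1.]
```

If the metric instead receives raw logits or probabilities, the Dice values it reports are not meaningful, which could explain both the warnings and the tiny non-zero values in the log.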

KumoLiu avatar Aug 12 '22 02:08 KumoLiu