GenericSSL

Diffusion denoising branches are useless on unlabelled data

[Open] hhjjaaa opened this issue 11 months ago · 6 comments

I reproduced the code you provided on the Synapse dataset and got the same results as in the paper. However, when I visualized the diffusion branch's output on unlabeled data, I found it was all noise: it could not generate useful pseudo-labels to help train the weight-adjustment branch, so the pseudo-labels came entirely from the weight-adjustment branch. Snipaste_2024-12-03_19-58-26 Snipaste_2024-12-04_12-59-55

hhjjaaa · Jan 04 '25 08:01
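One way to make "the output was all noise" quantitative rather than purely visual is to measure how often the two branches' hard pseudo-labels agree. A minimal sketch, reusing the `model`, `fetch_data`, `eval_loader`, and `pred_type` strings from the visualization code below (the helper itself is illustrative, not part of the repo):

```python
import torch

@torch.no_grad()
def branch_agreement(model, eval_loader, num_cls):
    """Mean foreground Dice between the hard pseudo-labels of the
    diffusion branch and the weight-adjustment branch on one batch."""
    model.eval()
    images, _ = fetch_data(next(iter(eval_loader)))  # same helper as in the code below
    lab_xi = model(images, pred_type="ddim_sample").argmax(dim=1)   # (B, D, H, W)
    lab_psi = model(images, pred_type="D_psi_l").argmax(dim=1)      # (B, D, H, W)
    dices = []
    for c in range(1, num_cls):  # skip background class 0
        a = (lab_xi == c).float()
        b = (lab_psi == c).float()
        denom = a.sum() + b.sum()
        if denom > 0:
            dices.append((2 * (a * b).sum() / denom).item())
    return sum(dices) / max(len(dices), 1)
```

If the diffusion branch really outputs noise, this agreement should stay near zero for the whole run, whereas a healthy run should see it climb over training.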

This is my visualization code:

```python
import numpy as np
import torch
from matplotlib.colors import ListedColormap


def label_to_color(label, num_classes):
    if num_classes != 14:
        raise ValueError(f"Expected num_classes=14, but got {num_classes}")
    colors = [
        (0.0, 0.0, 0.0),  # Class 0 - Background - Black
        (1.0, 0.0, 0.0),  # Class 1 - Red
        (0.0, 1.0, 0.0),  # Class 2 - Green
        (0.0, 0.0, 1.0),  # Class 3 - Blue
        (1.0, 1.0, 0.0),  # Class 4 - Yellow
        (1.0, 0.0, 1.0),  # Class 5 - Magenta
        (0.0, 1.0, 1.0),  # Class 6 - Cyan
        (0.5, 0.5, 0.5),  # Class 7 - Gray
        (0.5, 0.0, 0.0),  # Class 8 - Dark Red
        (0.0, 0.5, 0.0),  # Class 9 - Dark Green
        (0.0, 0.0, 0.5),  # Class 10 - Dark Blue
        (0.5, 0.5, 0.0),  # Class 11 - Olive
        (0.5, 0.0, 0.5),  # Class 12 - Purple
        (0.3, 0.8, 0.5)   # Class 13 - Additional Color
    ]

    # If the number of classes exceeds the predefined colors, use random colors
    # (though the check above limits num_classes to the 14 colors defined here)
    if num_classes > len(colors):
        additional_colors = np.random.rand(num_classes - len(colors), 3)
        colors.extend(additional_colors.tolist())

    cmap = ListedColormap(colors[:num_classes])

    # Convert label to a numpy array
    label = label.numpy() if isinstance(label, torch.Tensor) else label

    # Apply color mapping
    label_color = cmap(label)

    return label_color
```
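For reference, a quick usage sketch of the function above on a synthetic 2D slice (the random slice is purely illustrative):

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic 2D label slice with class indices in [0, 13]
fake_slice = np.random.randint(0, 14, size=(64, 64))
rgba = label_to_color(fake_slice, num_classes=14)  # (64, 64, 4) RGBA array

plt.imshow(rgba)
plt.axis('off')
plt.show()
```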

I added the following visualization in the training loop:

```python
if epoch_num % 10 == 0:  # Visualize every 10 epochs
    model.eval()
    with torch.no_grad():
        # Get a batch of images and labels for visualization
        batch = next(iter(eval_loader))
        images, gts = fetch_data(batch)  # images: (B, C, D, H, W), gts: (B, C, D, H, W) or (B, D, H, W)

        # Get the outputs from the model's two branches
        p_u_xi = model(images, pred_type="ddim_sample")  # Shape: (B, C, D, H, W)
        p_u_psi = model(images, pred_type="D_psi_l")  # Shape: (B, C, D, H, W)

        # Convert outputs to probability maps.
        # Note: F.gumbel_softmax draws fresh Gumbel noise on every call, so the
        # visualized diffusion-branch map is stochastic by construction.
        smoothing = GaussianSmoothing(config.num_cls, 3, 1)
        pred_xi = smoothing(F.gumbel_softmax(p_u_xi, dim=1))
        pred_psi = F.softmax(p_u_psi, dim=1)  # (B, C, D, H, W)

        # Get predicted classes
        pred_psi = torch.argmax(pred_psi, dim=1, keepdim=True)  # (B, 1, D, H, W)

        for i in range(min(3, images.size(0))):  # Visualize the first 3 images
            image = images[i].cpu()  # Shape: (C, D, H, W)
            gt = gts[i].cpu()  # Shape: (C, D, H, W) or (D, H, W)
            px_i = pred_xi[i].cpu()  # Shape: (C, D, H, W)
            ppsi = pred_psi[i].cpu()  # Shape: (1, D, H, W)
            org_ppsi = p_u_psi[i].cpu()

            # Print shapes for debugging
            print(
                f"Before processing - Image shape: {image.shape}, GT shape: {gt.shape}, "
                f"pred_xi shape: {px_i.shape}, pred_psi shape: {ppsi.shape}")

            # Select a specific depth slice
            depth_idx = image.size(1) // 2  # Middle slice index
            channel_idx = 0  # Select the first channel

            # Process image slice
            if image.ndimension() == 4:
                # Image shape is (C, D, H, W)
                if image.size(0) > 1:
                    image_slice = image[channel_idx, depth_idx, :, :]  # Select specific channel and depth
                else:
                    image_slice = image.squeeze(0)[depth_idx, :, :]  # Single channel
            elif image.ndimension() == 3:
                # Image shape is (D, H, W)
                image_slice = image[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected image dimensions: {image.ndimension()}")

            # Process Ground Truth slice
            if gt.ndimension() == 4:
                if gt.size(0) > 1:
                    # One-hot labels (C, D, H, W): reduce to class indices first
                    gt_slice = torch.argmax(gt, dim=0)[depth_idx, :, :]
                else:
                    gt_slice = gt.squeeze(0)[depth_idx, :, :]
            elif gt.ndimension() == 3:
                gt_slice = gt[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected GT dimensions: {gt.ndimension()}")

            # Process predicted slices
            if px_i.ndimension() == 4:
                if px_i.size(0) > 1:
                    pred_xi_slice = px_i[channel_idx, depth_idx, :, :]
                else:
                    pred_xi_slice = px_i.squeeze(0)[depth_idx, :, :]
            elif px_i.ndimension() == 3:
                pred_xi_slice = px_i[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected pred_xi dimensions: {px_i.ndimension()}")

            if ppsi.ndimension() == 4:
                if ppsi.size(0) > 1:
                    pred_psi_slice = ppsi[channel_idx, depth_idx, :, :]
                else:
                    pred_psi_slice = ppsi.squeeze(0)[depth_idx, :, :]
            elif ppsi.ndimension() == 3:
                pred_psi_slice = ppsi[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected pred_psi dimensions: {ppsi.ndimension()}")

            if org_ppsi.ndimension() == 4:
                if org_ppsi.size(0) > 1:
                    p_u_psi_slice = org_ppsi[channel_idx, depth_idx, :, :]
                else:
                    p_u_psi_slice = org_ppsi.squeeze(0)[depth_idx, :, :]
            elif org_ppsi.ndimension() == 3:
                p_u_psi_slice = org_ppsi[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected p_u_psi dimensions: {org_ppsi.ndimension()}")

            # Ensure predictions are also 2D
            if pred_xi_slice.ndimension() > 2:
                pred_xi_slice = torch.argmax(pred_xi_slice, dim=0)
            if pred_psi_slice.ndimension() > 2:
                pred_psi_slice = torch.argmax(pred_psi_slice, dim=0)

            print(
                f"After processing - Image slice shape: {image_slice.shape}, GT slice shape: {gt_slice.shape}, "
                f"pred_xi_slice shape: {pred_xi_slice.shape}, pred_psi_slice shape: {pred_psi_slice.shape}")

            # Convert labels to color maps
            gt_color = label_to_color(gt_slice, config.num_cls)
            pred_psi_color = label_to_color(pred_psi_slice, config.num_cls)

            # Convert the image slice to a numpy array for plotting
            image_slice = image_slice.numpy()  # (H, W)

            # Create a grid of images
            fig, axs = plt.subplots(1, 5, figsize=(20, 5))

            # Input grayscale image
            axs[0].imshow(image_slice, cmap='gray')
            axs[0].set_title('Input Image')
            axs[0].axis('off')

            # Ground Truth
            axs[1].imshow(gt_color)
            axs[1].set_title('Ground Truth')
            axs[1].axis('off')

            # Predicted p_u_xi (diffusion branch)
            axs[2].imshow(pred_xi_slice, cmap='gray')
            axs[2].set_title('Predicted p_u_xi')
            axs[2].axis('off')

            # Predicted p_u_psi (weight-adjustment branch, hard labels)
            axs[3].imshow(pred_psi_color)
            axs[3].set_title('Predicted p_u_psi')
            axs[3].axis('off')

            # Raw (pre-softmax) p_u_psi output
            axs[4].imshow(p_u_psi_slice, cmap='gray')
            axs[4].set_title('Predicted org_p_u_psi')
            axs[4].axis('off')

            # Convert the figure to a Tensor to add to TensorBoard
            fig.canvas.draw()
            img_tensor = torch.from_numpy(np.array(fig.canvas.renderer.buffer_rgba())).permute(2, 0, 1)[:3, :, :] / 255.0
            writer.add_image(f'Validation/Predictions_{epoch_num}_{i}', img_tensor, epoch_num)
            plt.close(fig)

    model.train()
```

hhjjaaa · Jan 04 '25 08:01

The output on unlabeled data is noise throughout training: Snipaste_2024-12-03_19-58-26

hhjjaaa · Jan 04 '25 09:01

How did you get this picture? Snipaste_2024-12-03_19-58-26

hhjjaaa · Jan 04 '25 11:01

Was the visualization done during the early training stage? If so, the diffusion process would not yet have had time to learn effective representations.

McGregorWwww · Jan 06 '25 02:01
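If it is an early-stage effect, one way to check whether the diffusion branch ever leaves the noise regime is to log the mean entropy of its predictions each epoch: values pinned near log(num_cls) indicate near-uniform, noise-like output. A minimal sketch, assuming the ddim_sample output can be treated as logits (as the gumbel_softmax call in the code above does) and reusing `model`, `fetch_data`, and `eval_loader`:

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def diffusion_entropy(model, eval_loader, num_cls):
    """Mean per-voxel entropy of the diffusion branch's softmax output.
    Returns (entropy, ceiling); an entropy near the ceiling log(num_cls)
    means near-uniform, noise-like predictions."""
    model.eval()
    images, _ = fetch_data(next(iter(eval_loader)))
    probs = F.softmax(model(images, pred_type="ddim_sample"), dim=1)
    ent = -(probs * probs.clamp_min(1e-8).log()).sum(dim=1)  # (B, D, H, W)
    return ent.mean().item(), math.log(num_cls)
```

Logged every few epochs, this would show whether a Synapse run stays pinned at the uniform-entropy ceiling while an M&Ms run drops away from it.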

On the M&Ms dataset the diffusion branch did learn the features (it learned well from the early through the middle stages of training) and the segmentation results were as expected, but on the Synapse dataset the diffusion branch's output was all noise at every stage.

hhjjaaa · Jan 06 '25 04:01
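Since the same code reportedly behaves differently only when the dataset changes, a cheap sanity check is to compare input intensity statistics between the two datasets, as diffusion models can be sensitive to input scaling. A sketch with hypothetical loader handles (`synapse_loader`, `mms_loader`), assuming both yield batches compatible with `fetch_data`:

```python
import torch

@torch.no_grad()
def intensity_stats(loader, name):
    """Print basic intensity statistics of the first batch from a loader."""
    images, _ = fetch_data(next(iter(loader)))
    print(f"{name}: min={images.min():.3f} max={images.max():.3f} "
          f"mean={images.mean():.3f} std={images.std():.3f}")

intensity_stats(synapse_loader, "Synapse")  # hypothetical loader handles
intensity_stats(mms_loader, "M&Ms")
```

A large mismatch in range or standard deviation between the two datasets would point at preprocessing rather than the method itself.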

Snipaste_2024-12-03_19-58-26 Snipaste_2024-12-04_12-59-55


hhjjaaa · Jan 07 '25 06:01