The diffusion denoising branch is useless on unlabeled data
I reproduced the code you provided on the Synapse dataset and obtained the same results as in the paper. However, when I visualized the diffusion branch's output on unlabeled data, I found it was all noise: it could not combine with the weight-adjustment branch to generate useful pseudo-labels for training, and the weight-adjustment branch's output ended up serving as the pseudo-labels entirely.
This is my visualization code:

```python
import numpy as np
import torch
from matplotlib.colors import ListedColormap


def label_to_color(label, num_classes):
    if num_classes != 14:
        raise ValueError(f"Expected num_classes=14, but got {num_classes}")
    colors = [
        (0.0, 0.0, 0.0),  # Class 0  - Background - Black
        (1.0, 0.0, 0.0),  # Class 1  - Red
        (0.0, 1.0, 0.0),  # Class 2  - Green
        (0.0, 0.0, 1.0),  # Class 3  - Blue
        (1.0, 1.0, 0.0),  # Class 4  - Yellow
        (1.0, 0.0, 1.0),  # Class 5  - Magenta
        (0.0, 1.0, 1.0),  # Class 6  - Cyan
        (0.5, 0.5, 0.5),  # Class 7  - Gray
        (0.5, 0.0, 0.0),  # Class 8  - Dark Red
        (0.0, 0.5, 0.0),  # Class 9  - Dark Green
        (0.0, 0.0, 0.5),  # Class 10 - Dark Blue
        (0.5, 0.5, 0.0),  # Class 11 - Olive
        (0.5, 0.0, 0.5),  # Class 12 - Purple
        (0.3, 0.8, 0.5),  # Class 13 - Additional Color
    ]
    # If the number of classes exceeds the predefined colors, use random colors
    # (the check above already limits this to 14 classes here)
    if num_classes > len(colors):
        additional_colors = np.random.rand(num_classes - len(colors), 3)
        colors.extend(additional_colors.tolist())
    cmap = ListedColormap(colors[:num_classes])
    # Convert label to a numpy array
    label = label.numpy() if isinstance(label, torch.Tensor) else label
    # Apply color mapping
    label_color = cmap(label)
    return label_color
```
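For reference, a minimal usage sketch; the random `label_slice` here is just a hypothetical stand-in for one 2D depth slice of a Synapse label volume with integer class indices in [0, 13]:

```python
import matplotlib.pyplot as plt
import torch

# Hypothetical stand-in for one depth slice of an argmaxed label volume
label_slice = torch.randint(0, 14, (96, 96))
color = label_to_color(label_slice, num_classes=14)  # (96, 96, 4) RGBA array
plt.imshow(color)
plt.axis('off')
plt.show()
```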
I added the following visualization to the training loop:
```python
import matplotlib.pyplot as plt
import torch.nn.functional as F

# Runs inside the training loop; `model`, `eval_loader`, `fetch_data`,
# `config`, `writer`, `GaussianSmoothing`, and `epoch_num` come from the
# surrounding training script.
if epoch_num % 10 == 0:  # Visualize every 10 epochs
    model.eval()
    with torch.no_grad():
        # Get a batch of images and labels for visualization
        batch = next(iter(eval_loader))
        images, gts = fetch_data(batch)  # images: (B, C, D, H, W), gts: (B, C, D, H, W) or (B, D, H, W)
        # Get the outputs from the model's two branches
        p_u_xi = model(images, pred_type="ddim_sample")  # Shape: (B, C, D, H, W)
        p_u_psi = model(images, pred_type="D_psi_l")     # Shape: (B, C, D, H, W)
        # Convert outputs to probability maps
        smoothing = GaussianSmoothing(config.num_cls, 3, 1)
        pred_xi = smoothing(F.gumbel_softmax(p_u_xi, dim=1))
        pred_psi = F.softmax(p_u_psi, dim=1)  # (B, C, D, H, W)
        # Get predicted classes
        pred_psi = torch.argmax(pred_psi, dim=1, keepdim=True)  # (B, 1, D, H, W)

        for i in range(min(3, images.size(0))):  # Visualize the first 3 images
            image = images[i].cpu()      # Shape: (C, D, H, W)
            gt = gts[i].cpu()            # Shape: (C, D, H, W) or (D, H, W)
            px_i = pred_xi[i].cpu()      # Shape: (C, D, H, W)
            ppsi = pred_psi[i].cpu()     # Shape: (1, D, H, W)
            org_ppsi = p_u_psi[i].cpu()  # Shape: (C, D, H, W)
            # Print shapes for debugging
            print(f"Before processing - Image shape: {image.shape}, GT shape: {gt.shape}, "
                  f"pred_xi shape: {px_i.shape}, pred_psi shape: {ppsi.shape}")

            # Select a specific depth slice
            depth_idx = image.size(1) // 2  # Middle slice index
            channel_idx = 0                 # Select the first channel

            # Process image slice
            if image.ndimension() == 4:  # Image shape is (C, D, H, W)
                if image.size(0) > 1:
                    image_slice = image[channel_idx, depth_idx, :, :]  # Specific channel and depth
                else:
                    image_slice = image.squeeze(0)[depth_idx, :, :]  # Single channel
            elif image.ndimension() == 3:  # Image shape is (D, H, W)
                image_slice = image[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected image dimensions: {image.ndimension()}")

            # Process ground-truth slice; keep all channels so one-hot labels
            # can be converted to class indices below
            if gt.ndimension() == 4:
                if gt.size(0) > 1:
                    gt_slice = gt[:, depth_idx, :, :]  # (C, H, W), one-hot
                else:
                    gt_slice = gt.squeeze(0)[depth_idx, :, :]
            elif gt.ndimension() == 3:
                gt_slice = gt[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected GT dimensions: {gt.ndimension()}")
            # Convert label to class indices (if labels are one-hot encoded)
            if gt_slice.ndimension() == 3:
                gt_slice = torch.argmax(gt_slice, dim=0)

            # Process diffusion-branch slice; keep all channels for the argmax below
            if px_i.ndimension() == 4:
                if px_i.size(0) > 1:
                    pred_xi_slice = px_i[:, depth_idx, :, :]  # (C, H, W)
                else:
                    pred_xi_slice = px_i.squeeze(0)[depth_idx, :, :]
            elif px_i.ndimension() == 3:
                pred_xi_slice = px_i[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected pred_xi dimensions: {px_i.ndimension()}")

            # pred_psi was already argmaxed to (1, D, H, W)
            if ppsi.ndimension() == 4:
                if ppsi.size(0) > 1:
                    pred_psi_slice = ppsi[channel_idx, depth_idx, :, :]
                else:
                    pred_psi_slice = ppsi.squeeze(0)[depth_idx, :, :]
            elif ppsi.ndimension() == 3:
                pred_psi_slice = ppsi[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected pred_psi dimensions: {ppsi.ndimension()}")

            # Raw logits of the psi branch: view a single channel
            if org_ppsi.ndimension() == 4:
                if org_ppsi.size(0) > 1:
                    p_u_psi_slice = org_ppsi[channel_idx, depth_idx, :, :]
                else:
                    p_u_psi_slice = org_ppsi.squeeze(0)[depth_idx, :, :]
            elif org_ppsi.ndimension() == 3:
                p_u_psi_slice = org_ppsi[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected org_ppsi dimensions: {org_ppsi.ndimension()}")

            # Ensure predictions are 2D class maps
            if pred_xi_slice.ndimension() > 2:
                pred_xi_slice = torch.argmax(pred_xi_slice, dim=0)
            if pred_psi_slice.ndimension() > 2:
                pred_psi_slice = torch.argmax(pred_psi_slice, dim=0)
            print(f"After processing - Image slice shape: {image_slice.shape}, GT slice shape: {gt_slice.shape}, "
                  f"pred_xi_slice shape: {pred_xi_slice.shape}, pred_psi_slice shape: {pred_psi_slice.shape}")

            # Convert labels to color maps
            gt_color = label_to_color(gt_slice, config.num_cls)
            pred_psi_color = label_to_color(pred_psi_slice, config.num_cls)
            # Convert the image slice to a numpy array (imshow rescales grayscale automatically)
            image_slice = image_slice.numpy()  # (H, W)

            # Create a grid of images
            fig, axs = plt.subplots(1, 5, figsize=(20, 5))
            axs[0].imshow(image_slice, cmap='gray')
            axs[0].set_title('Input Image')
            axs[1].imshow(gt_color)
            axs[1].set_title('Ground Truth')
            axs[2].imshow(pred_xi_slice, cmap='gray')
            axs[2].set_title('Predicted p_u_xi')
            axs[3].imshow(pred_psi_color)
            axs[3].set_title('Predicted p_u_psi')
            axs[4].imshow(p_u_psi_slice, cmap='gray')
            axs[4].set_title('Predicted org_p_u_psi')
            for ax in axs:
                ax.axis('off')

            # Convert the figure to a tensor and add it to TensorBoard
            fig.canvas.draw()
            img_tensor = torch.from_numpy(np.array(fig.canvas.renderer.buffer_rgba())).permute(2, 0, 1)[:3, :, :] / 255.0
            writer.add_image(f'Validation/Predictions_{epoch_num}_{i}', img_tensor, epoch_num)
            plt.close(fig)
    model.train()
```
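To back the "all noise" observation with numbers rather than screenshots, here is a small diagnostic sketch, not from the repo; `p_u_xi` and `config.num_cls` are assumed to be the variables from the snippet above:

```python
import math

import torch
import torch.nn.functional as F


def diffusion_noise_check(p_u_xi: torch.Tensor, num_cls: int) -> None:
    """Print entropy/histogram diagnostics for the diffusion branch's logits.

    If the branch outputs noise, its softmax is near-uniform, so the per-voxel
    entropy approaches log(num_cls) and the class histogram is roughly flat.
    """
    probs = F.softmax(p_u_xi, dim=1)                      # (B, C, D, H, W)
    entropy = -(probs * torch.log(probs + 1e-8)).sum(1)   # (B, D, H, W)
    print(f"Mean entropy: {entropy.mean():.3f} "
          f"(uniform noise would be {math.log(num_cls):.3f})")
    hist = torch.argmax(probs, dim=1).flatten().bincount(minlength=num_cls)
    print(f"Predicted class histogram: {hist.tolist()}")


# e.g., inside the visualization block above:
# diffusion_noise_check(p_u_xi, config.num_cls)
```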
The output on unlabeled data is noise during training:
How did you get this picture? Was the visualization produced during the early stage of training? If so, the diffusion branch would not yet have learned effective representations.
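One way to settle the early-vs-late question objectively is to log a scalar curve across epochs; a minimal sketch, not from the repo, assuming the `writer`, `epoch_num`, `p_u_xi`, and `pred_psi` from the visualization snippet above:

```python
import torch
import torch.nn.functional as F


def log_branch_agreement(writer, epoch_num, p_u_xi, pred_psi):
    """Log a foreground Dice between the two branches' hard predictions.

    A curve that never rises above chance would suggest the diffusion branch
    is not learning at any stage, not just early in training.
    """
    fg_xi = (torch.argmax(F.softmax(p_u_xi, dim=1), dim=1) > 0).float()
    fg_psi = (pred_psi.squeeze(1) > 0).float()  # pred_psi: (B, 1, D, H, W)
    dice = 2 * (fg_xi * fg_psi).sum() / (fg_xi.sum() + fg_psi.sum() + 1e-8)
    writer.add_scalar('Diag/branch_foreground_dice', dice.item(), epoch_num)
```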
The diffusion branch's output on the M&Ms dataset did learn the features (it learned well from the early through middle stages of training) and the segmentation results were as expected, but its output on the Synapse dataset was all noise at every stage.