
How to run multi-label segmentation?

Open gulubao opened this issue 2 years ago • 15 comments

I encountered some issues with multi-label segmentation, and I would like to ask for your help.

The demo ISIC dataset has a single label, and the demo BRATS dataset has multiple labels, but they are merged into a single mask in the BRATSDataset3D class.

I am interested in performing multi-label segmentation on my own dataset, and I am wondering how to set the dataset and model for this purpose.

Could you please provide a demo for multi-label segmentation?

gulubao avatar Apr 12 '23 07:04 gulubao

Did you try using v2? The architecture diagram for v2 looks like it supports multiple classes, but I can't find any code that actually implements this. I tried to modify v1 myself, but with my limited experience it is hard to adapt the loss and the multi-class output: the code runs without errors, yet the model still has no multi-class capability. Do you have any findings or ideas on this?

theneao avatar Apr 13 '23 13:04 theneao

I am attempting to train on v2. I am not sure whether the author adjusted the loss function for multi-class segmentation, since neither the MSE nor the VB term seems to limit the number of categories.

Here are the adjustments I made for multi-label:

1. Preprocess the mask labels.
   1.1 Remap the labels across the entire dataset to 0, 1, ..., n.
   1.2 When loading in a custom `Dataset`, normalize the mask with `mask = mask / n`.
   1.3 Resize masks with `transforms.Resize((args.image_size, args.image_size), interpolation=InterpolationMode.NEAREST)`.
   1.4 Delete `torch.where(mask > 0, 1, 0)`.
2. In `gaussian_diffusion.training_losses_segmentation`, change `res = torch.where(mask > 0, 1, 0)` to `res = torch.softmax(mask, dim=0)`.
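The mask preprocessing steps above might look roughly like this (a sketch using plain `torch.nn.functional.interpolate` in place of the torchvision transform; the function name and arguments are illustrative, not from the repo):

```python
import torch
import torch.nn.functional as F

def preprocess_mask(mask: torch.Tensor, n: int, image_size: int) -> torch.Tensor:
    """Sketch of steps 1.1-1.4; assumes labels were already remapped to 0..n."""
    # 1.3: nearest-neighbour resize keeps labels discrete (no blended values);
    # interpolate expects [N, C, H, W], so add and strip batch/channel dims
    mask = F.interpolate(mask.float()[None, None], size=(image_size, image_size),
                         mode="nearest")[0, 0]
    # 1.2: scale integer labels 0..n into [0, 1]; 1.4: note there is no
    # torch.where(mask > 0, 1, 0) binarization step
    return mask / n
```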

Due to limited computational resources I had to reduce the `batch_size`, so I replaced every `nn.BatchNorm2d` with `nn.InstanceNorm2d`.
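The norm swap can be done recursively over any model; this helper is a sketch (not from the repo) of one way to do it, since InstanceNorm normalizes each sample separately and so does not degrade at very small batch sizes:

```python
import torch.nn as nn

def batchnorm_to_instancenorm(module: nn.Module) -> None:
    """Recursively replace every nn.BatchNorm2d with nn.InstanceNorm2d in place."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # keep the channel count; affine=True preserves learnable scale/shift
            setattr(module, name, nn.InstanceNorm2d(child.num_features, affine=True))
        else:
            batchnorm_to_instancenorm(child)
```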

I am still in the training process, and I am not sure if these adjustments will work.

I hope the author could share the multi-label segmentation method shown in Fig. 2 of the paper.

gulubao avatar Apr 14 '23 06:04 gulubao

I made similar adjustments on v1 before, but in the end I found that the loss is computed on the predicted noise (this can be changed to predict x0), while the prediction has only one channel, so it cannot represent multiple classes.

The main reason, I think, may be that `out_channels` in the UNet definition was never modified.

But after modifying it, I ran into `model_output, model_var_values = th.split(model_output, C, dim=1)`: it is unclear how the channels should be allocated between the two in the multi-class case. I also found that assigning to `model_output` with something like `model_output[:, :0, :]` directly causes a dimension error in the next loop, even though the dimensions are correct after the first run. Of course, my own mistakes may also be the cause; you can try it yourself.
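For reference, in the guided-diffusion convention this code builds on, `learn_sigma` doubles the UNet's output channels, so for C classes one plausible allocation is C mean channels plus C variance channels. This is a sketch of that split; whether MedSegDiff keeps this convention for C > 1 is exactly what is unclear in this thread:

```python
import torch as th

C = 5                                      # number of classes (example value)
model_output = th.randn(2, 2 * C, 64, 64)  # UNet output with out_channels = 2 * C

# First C channels: predicted mean; last C: per-class variance interpolation values.
mean, model_var_values = th.split(model_output, C, dim=1)
assert mean.shape == model_var_values.shape == (2, C, 64, 64)
```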

If it's convenient, you can directly contact me through my homepage email, or let me know your other contact information by email.

theneao avatar Apr 15 '23 10:04 theneao

As for separating the channels, I don't know why `torch.split()` behaves differently from indexing with `model_output[:, :0, :, :]`, but I can only specify the split with `model_output, model_var_values = th.split(model_output, 4, dim=1)` or `model_output, model_var_values = th.split(model_output, [4, 1], dim=1)` (I have a 5-class task, and I think `output_channel` can be set to 5 or 8). However, continuing from there still produces errors, in

```python
def _predict_xstart_from_eps(self, x_t, t, eps):
    assert x_t.shape == eps.shape
```

The dimension error on `x_t` is still reported here. `x_t` is the forward-diffused image that is fed in, and I believe it could be copied to the same number of channels as `eps` for the computation, but I have not tried this yet. In addition, after comparing the two versions of the code, I found that v2 makes some modifications to the network structure described in the v2 paper, but nothing related to the number of classes.
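The "copy x_t to the same number of channels as eps" idea is mechanically simple, though untested for correctness here; a sketch of just the tensor mechanics:

```python
import torch as th

x_t = th.randn(2, 1, 64, 64)  # single-channel diffused mask
eps = th.randn(2, 5, 64, 64)  # 5-class noise prediction

# expand() repeats x_t across the channel dimension without copying memory,
# which would satisfy the shape assert in _predict_xstart_from_eps
x_t_rep = x_t.expand(-1, eps.shape[1], -1, -1)
assert x_t_rep.shape == eps.shape
```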

theneao avatar Apr 15 '23 17:04 theneao

I did not try adding channels for multi-class segmentation; instead I tried to generate multiple pixel values representing the different classes in the single output channel of the current code.

I ran experiments on BRATS; however, the results were very poor. My adjustments are in my comment above.

The figures below show the 50th slice in folder slice0001. The darker figure is the GT segmentation; the gray one is the model output.

(images: `slice0001_slice50_GT`, `slice0001_slice50_output_ens`)

gulubao avatar Apr 17 '23 07:04 gulubao

The results are indeed very poor. On my dataset even binary segmentation performs badly, with poor fine-grained detail, far below a plain U-Net. It is unclear what causes this; I am tempted to give up on this project.

Recently, there have been many segmentation networks based on diffusion models for similar tasks. If you are interested, we can discuss them through private email.

> I did not try to add channels to do multi-segmentation but tried to generate multiple pixel values representing different classes in the last channel of the current code. […]

theneao avatar Apr 18 '23 03:04 theneao

@gulubao When I was working on the multi-label classification sample code, I ran into some problems; could we get in touch?

email: [email protected]

lixiang007666 avatar Apr 20 '23 03:04 lixiang007666

> I did not try to add channels to do multi-segmentation but tried to generate multiple pixel values representing different classes in the last channel of the current code. […]

I think it may be the problem of loss function design

theneao avatar Apr 21 '23 13:04 theneao

> I did not try to add channels to do multi-segmentation but tried to generate multiple pixel values representing different classes in the last channel of the current code. […]

Have you solved your multi-label classification yet?

jaceqin avatar May 16 '23 08:05 jaceqin

@gulubao @theneao Hi guys, I think the output sample includes the image too; I mean it also gives the segmentation of the brain border. Is this the case for you as well? Looking at the results, I think it's the same as your outputs. Let me know, and correct me if I am missing something.

> I did not try to add channels to do multi-segmentation but tried to generate multiple pixel values representing different classes in the last channel of the current code. […]

saisusmitha avatar May 22 '23 09:05 saisusmitha

Any updates on this? I'm also interested in multi-class segmentation.

agentdr1 avatar Aug 30 '23 16:08 agentdr1

There are some results of multi-class segmentation on the BraTS dataset, but I don't know how to threshold the output mask to match the ground-truth labels. (result images: 0, 2, 3, 5)
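If the network output encodes labels as values k/n in a single channel (as in the preprocessing described earlier in this thread), one plausible way to threshold it back to discrete labels is to round into label bins; with a C-channel output, `pred.argmax(dim=1)` would be the analogue. This is a sketch, not the authors' method:

```python
import torch

def discretize(pred: torch.Tensor, n: int) -> torch.Tensor:
    """Map a continuous prediction in [0, 1] back to integer labels 0..n
    by inverting the mask/n encoding and rounding to the nearest bin."""
    return torch.clamp(torch.round(pred * n), 0, n).long()

discretize(torch.tensor([0.0, 0.17, 0.34, 0.9]), 3)  # -> tensor([0, 1, 1, 3])
```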

smallboy-code avatar Oct 27 '23 12:10 smallboy-code

@smallboy-code Hello, what changes have you made to the code?

Ray010221 avatar Dec 22 '23 11:12 Ray010221

> I am attempting to train on v2. […] Here are the adjustments I made for multilabel […]

@gulubao How are you dealing with the masks: multi-channel or one-hot? And do you modify args.in_ch?
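For concreteness, the two mask encodings in question might look like this sketch (the `in_ch` arithmetic here is an assumption about how image and mask would be concatenated, not the repo's default):

```python
import torch
import torch.nn.functional as F

# A [H, W] integer label mask vs. its [C, H, W] one-hot encoding.
mask = torch.tensor([[0, 1], [2, 2]])                              # labels 0..2
one_hot = F.one_hot(mask, num_classes=3).permute(2, 0, 1).float()  # [3, H, W]
assert one_hot.shape == (3, 2, 2)

# If the C one-hot channels were concatenated with the image, args.in_ch would
# presumably become image_channels + C instead of image_channels + 1.
```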

thd2020 avatar Sep 26 '24 08:09 thd2020