
Difference between checkpoint FASTSAM3D and Finetuned-SAMMED3D

Open · MinxuanQin opened this issue 1 year ago · 9 comments

Thank you for sharing the excellent code and checkpoints! I have run the code described in Readme.md and would like to check whether I have understood it correctly.

The current versions of distillation.py and validate_student.py use an ImageEncoder with the so-called "woatt" attention (window attention), not with 3D sparse flash attention. validate_student.py loads the tiny image encoder (the first checkpoint uploaded on GitHub) as the image encoder; the remaining parts come from the fine-tuned teacher model (the second uploaded checkpoint, "Finetuned-SAMMED3D"). Does the third checkpoint, "FASTSAM3D", combine the tiny encoder with those remaining parts?

I also think those checkpoints do not use build_sam3D_flash.py, build_sam3D_dilatedattention.py, or build_3D_decoder.py. Is that right? Does this checkpoint perform best among all the encoder and decoder variants? Thank you!

MinxuanQin avatar Sep 10 '24 15:09 MinxuanQin
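A minimal sketch of the setup described in the question, assuming a SAM-Med3D-style model object with an image_encoder attribute; the builder function and checkpoint paths are placeholders, not the repository's actual API:

```python
# Hypothetical sketch: load the fine-tuned teacher checkpoint for the whole
# model, then overwrite only the image encoder with the distilled tiny encoder.
import torch

def assemble_student_model(build_model, teacher_ckpt_path, tiny_encoder_ckpt_path):
    model = build_model()  # prompt encoder + mask decoder + image encoder skeleton

    # 1) Fine-tuned teacher weights for all components.
    teacher_state = torch.load(teacher_ckpt_path, map_location="cpu")
    model.load_state_dict(teacher_state.get("model_state_dict", teacher_state), strict=False)

    # 2) Distilled 6-layer encoder weights for the image encoder only.
    tiny_state = torch.load(tiny_encoder_ckpt_path, map_location="cpu")
    model.image_encoder.load_state_dict(
        tiny_state.get("model_state_dict", tiny_state), strict=False
    )
    return model.eval()
```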

The flash attention part is only used for inference, not for distillation. Feel free to use flash attention for your inference.

skill-diver avatar Sep 16 '24 23:09 skill-diver
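A repo-independent sketch of what "flash attention only at inference" can look like in PyTorch, assuming the encoder's attention goes through scaled_dot_product_attention; model and image are placeholders:

```python
# Restrict PyTorch's fused attention to the flash kernel for an inference pass
# (PyTorch >= 2.0, CUDA, half precision). Purely illustrative.
import torch

@torch.no_grad()
def encode_with_flash(model, image):
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        return model.image_encoder(image)
```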

Thank you for your reply! So you have distilled a lightweight image encoder with only 6 layers, where the first two layers do not contain attention layers. For inference, there are no checkpoints with flash attention available, so I can distill an image encoder with flash attention myself and then use it for inference. Do I understand that correctly?

MinxuanQin avatar Sep 19 '24 12:09 MinxuanQin
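The architecture described above, sketched for illustration only (dimensions, norms, and MLP sizes are guesses, not the repository's configuration):

```python
# A 6-block ViT-style encoder where the first two blocks contain no attention,
# only an MLP, and the remaining four keep self-attention.
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim, num_heads, use_attention=True):
        super().__init__()
        self.use_attention = use_attention
        if use_attention:
            self.norm1 = nn.LayerNorm(dim)
            self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        if self.use_attention:
            h = self.norm1(x)
            x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.mlp(self.norm2(x))

# First two blocks without attention, last four with attention.
blocks = nn.ModuleList([Block(384, 6, use_attention=(i >= 2)) for i in range(6)])
```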

You are correct except for one point: you can use our checkpoint for inference; it supports flash attention.

skill-diver avatar Sep 19 '24 13:09 skill-diver

Thank you for your quick reply! I am not familiar with flash attention, so this may be a silly question: based on your answer and the code, I think flash attention is an efficient attention computation mechanism and does not change the network architecture. So the checkpoint works both with and without flash attention, right?

MinxuanQin avatar Sep 19 '24 13:09 MinxuanQin

I have another question regarding the distillation process. According to utils/prepare_nnunet.py, the images and labels from one dataset should be stored under label_name/dataset_name/imagesTr and label_name/dataset_name/labelsTr, but preparelabel.py and validation.py only see the top-level directory (label_name), because they use all_dataset_paths = glob(join(args.test_data_path)) rather than all_dataset_paths = glob(join(args.test_data_path, '*', '*')).

For the distillation process, could I use each image only once, i.e. not use the data generated by utils/prepare_nnunet.py? That script generates duplicated images for each class of a single subject, and preparelabel.py does not need the labels.

Thank you very much for your help!

MinxuanQin avatar Sep 19 '24 14:09 MinxuanQin
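To make the glob point concrete, a small sketch with illustrative paths (directory names are placeholders):

```python
# With data laid out as <test_data_path>/<label_name>/<dataset_name>/imagesTr,
# the two glob patterns return different levels of the tree.
from glob import glob
from os.path import join

test_data_path = "data/validation"

# Matches only the top-level path itself: ['data/validation']
top_level = glob(join(test_data_path))

# Matches every <label_name>/<dataset_name> directory, one entry per dataset:
# ['data/validation/liver/Dataset001', 'data/validation/kidney/Dataset002', ...]
per_dataset = glob(join(test_data_path, "*", "*"))
```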

You are right: the checkpoint works both with and without flash attention.

skill-diver avatar Sep 20 '24 18:09 skill-diver
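A small check illustrating the point just confirmed: flash attention is a faster way of computing the same attention function, so identical weights give (numerically) identical outputs with or without it. The sketch below compares a naive softmax attention against PyTorch's fused scaled_dot_product_attention on random tensors:

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 8, 128, 64) for _ in range(3))  # (batch, heads, tokens, head_dim)

# Naive attention: softmax(Q K^T / sqrt(d)) V computed explicitly.
scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
naive = scores.softmax(dim=-1) @ v

# Fused attention: dispatches to flash / memory-efficient / math kernels.
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(naive, fused, atol=1e-5))  # True, up to numerical tolerance
```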

You still need to use prepare_nnunet.py. The model needs to learn from these preprocessed images (cropping, registration, etc. are necessary).

skill-diver avatar Sep 20 '24 18:09 skill-diver

Got it. Thank you very much!

MinxuanQin avatar Sep 21 '24 16:09 MinxuanQin

I have a question regarding the distillation loss. In the paper, the objective of the layer-wise progressive distillation process is described as

$$\mathbb{E}_x \left( \frac{1}{k} \sum_{i=1}^{k} \left\Vert f_{\text{teacher}}^{(2i)}(x) - f_{\text{student}}^{(i)}(x) \right\Vert \right)$$

where $k$ varies from 1 to 6 depending on the current and total training iterations. In distillation.py, I think the variable curlayer in the class BaseTrainer plays the role of $k$, but the loss is defined as loss = self.seg_loss(output[self.curlayer], label[self.curlayer]), which, as I read it, computes the L2 norm only for the current layer rather than summing from $i=1$ to $i=k$.

In addition, I have read in the paper that the number of iterations is set to 36 for the first, layer-wise distillation stage. I would like to know how many iterations were used for the logit-level distillation stage. Thank you!

MinxuanQin avatar Sep 21 '24 18:09 MinxuanQin
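To spell out the two readings being compared, a sketch with placeholder feature lists (teacher_feats[j] standing for teacher layer j+1 and student_feats[i] for student layer i+1); this is not the repository's code:

```python
import torch

def progressive_distill_loss(teacher_feats, student_feats, k):
    # Paper's formula as quoted above: average the L2 distance between
    # teacher layer 2i and student layer i over i = 1..k.
    losses = [
        torch.norm(teacher_feats[2 * i - 1] - student_feats[i - 1])
        for i in range(1, k + 1)
    ]
    return torch.stack(losses).mean()

def current_layer_only_loss(teacher_feats, student_feats, k):
    # How the quoted line loss = self.seg_loss(output[self.curlayer], label[self.curlayer])
    # reads: only the k-th pair contributes.
    return torch.norm(teacher_feats[2 * k - 1] - student_feats[k - 1])
```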