
Inquiry About Training Computational Resources

JinyuCai124578 opened this issue 1 year ago · 7 comments

Could you please share the hardware configuration and the overall computational power you employed during the training phase? I encountered a "CUDA OUT OF MEMORY" error when using an RTX 4090. I'm wondering if an A100 with 40 GB of memory is enough for training.

JinyuCai124578 avatar Oct 15 '24 08:10 JinyuCai124578

The 4090 is equipped with 24 GB of VRAM, which should be enough. Are you running 2D image training or 3D video training?

ff98li avatar Dec 11 '24 08:12 ff98li

Running on an HPC:

[screenshot of the error]

Any idea as to how to go about it would be great.

AfrifaEben7 avatar Mar 24 '25 22:03 AfrifaEben7

> Running on an HPC:
>
> [screenshot of the error]
>
> Any idea as to how to go about it would be great.

Try reducing the image size or the video length. The image size can be set in the config file (default 1024).

ff98li avatar Mar 25 '25 04:03 ff98li
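Roughly speaking, activation memory tracks the number of image tokens times the number of frames kept in the video batch, so both knobs help. An illustrative back-of-the-envelope sketch in Python; the stride-4 patch embed is the upstream SAM2 value, and none of the numbers are calibrated to Medical-SAM2:

def tokens_per_frame(image_size, patch_stride=4):
    # Tokens grow quadratically with side length, so halving the
    # image size cuts per-frame tokens (and activations) by ~4x.
    return (image_size // patch_stride) ** 2

for image_size in (1024, 512):
    for num_frames in (8, 2):
        print(image_size, num_frames, tokens_per_frame(image_size) * num_frames)

Video length only scales the total roughly linearly, which is part of why shrinking the image size is usually the bigger lever.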

> Try reducing the image size or the video length. The image size can be set in the config file (default 1024).

Thanks for the reply. I did: I reduced the image size to 369, like that of the BTCV dataset, and even used only two frames per video for the length. I now have this error.


train_3d.py -net sam2 -exp_name btcv_train_test -sam_ckpt /mmfs1/cm/shared/apps_local/medical-sam2/checkpoints/sam2_hiera_small.pt -sam_config sam2_hiera_s -image_size 1024 -val_freq 1 -prompt bbox -prompt_freq 2 -dataset btcv -data_path Exp/test_369/data -image_size 369 -out_size 369

Error:


Traceback (most recent call last):
  File "/mmfs1/cm/shared/apps_local/medical-sam2/train_3d.py", line 114, in <module>
    main()
  File "/mmfs1/cm/shared/apps_local/medical-sam2/train_3d.py", line 97, in main
    loss, prompt_loss, non_prompt_loss = function.train_sam(args, net, optimizer1, optimizer2, nice_train_loader, epoch)
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/cm/shared/apps_local/medical-sam2/func_3d/function.py", line 85, in train_sam
    train_state = net.train_init_state(imgs_tensor=imgs_tensor)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/cm/shared/apps_local/medical-sam2/sam2_train/sam2_video_predictor.py", line 247, in train_init_state
    self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
  File "/mmfs1/cm/shared/apps_local/medical-sam2/sam2_train/sam2_video_predictor.py", line 1279, in _get_image_feature
    backbone_out = self.forward_image(image)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/cm/shared/apps_local/medical-sam2/sam2_train/modeling/sam2_base.py", line 465, in forward_image
    backbone_out = self.image_encoder(img_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/cm/shared/apps_local/python/3.11/envs/medsam-2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/cm/shared/apps_local/python/3.11/envs/medsam-2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/cm/shared/apps_local/medical-sam2/sam2_train/modeling/backbones/image_encoder.py", line 31, in forward
    features, pos = self.neck(self.trunk(sample))
                              ^^^^^^^^^^^^^^^^^^
  File "/mmfs1/cm/shared/apps_local/python/3.11/envs/medsam-2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/cm/shared/apps_local/python/3.11/envs/medsam-2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/cm/shared/apps_local/medical-sam2/sam2_train/modeling/backbones/hieradet.py", line 284, in forward
    x = x + self._get_pos_embed(x.shape[1:3])
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/cm/shared/apps_local/medical-sam2/sam2_train/modeling/backbones/hieradet.py", line 273, in _get_pos_embed
    pos_embed = pos_embed + window_embed.tile(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (93) must match the size of tensor b (88) at non-singleton dimension 3

AfrifaEben7 avatar Mar 25 '25 05:03 AfrifaEben7
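For what it's worth, the 93-vs-88 mismatch above is reproducible with a little arithmetic: Hiera's patch embed (conv kernel 7, stride 4, padding 3 in the upstream SAM2 code) turns a 369-px input into a 93×93 feature map, while the windowed positional embedding (spatial size 8 in the small config) only tiles to a multiple of 8. A minimal sketch, assuming those upstream hyperparameters:

def patch_embed_hw(image_size, kernel=7, stride=4, padding=3):
    # Output side length of the stride-4 patch embed convolution.
    return (image_size + 2 * padding - kernel) // stride + 1

WINDOW = 8  # spatial size of the learned window positional embedding

for size in (369, 384, 1024):
    hw = patch_embed_hw(size)        # feature-map side length
    tiled = (hw // WINDOW) * WINDOW  # extent covered by tiling the window embed
    print(size, hw, tiled, "OK" if hw == tiled else "mismatch")

# 369 -> 93 vs 88: the exact mismatch in the traceback above.
# 384 -> 96 vs 96: a multiple of 32 keeps the tiling consistent.

So an image size that is a multiple of 32 (e.g. 384 instead of 369) should avoid this particular error.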

@AfrifaEben7 did you only modify the image size in the args without modifying the yaml config file? The training script args control the image size coming out of the dataloader, while the yaml config file controls the model architecture. You need to modify both.

ff98li avatar Mar 25 '25 06:03 ff98li
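A minimal sanity check along these lines, as a sketch only: the config path and the model/image_size field layout below are assumptions based on the upstream SAM2 configs, not verified against this repo.

import yaml

def check_image_size(args_image_size, config_path="sam2_configs/sam2_hiera_s.yaml"):
    # The -image_size arg sizes the dataloader output; the yaml sizes the
    # model's positional embeddings. They have to agree.
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    model_image_size = cfg["model"]["image_size"]  # assumed field location
    assert args_image_size == model_image_size, (
        f"dataloader image size {args_image_size} != "
        f"model image size {model_image_size}; update both places"
    )

check_image_size(384)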

@ff98li, thank you. It worked, but it seems I still have a memory issue. After 11 epochs, it stopped and threw the same memory error.

Thanks anyway, you have really helped.

AfrifaEben7 avatar Mar 26 '25 03:03 AfrifaEben7

I also modified the image size in the yaml, so why did the following problem appear?

src shape: torch.Size([8, 64, 64, 64])
dense_prompt_embeddings shape: torch.Size([8, 256, 32, 32])
Traceback (most recent call last):
  File "/data/wcldata/medsam/Medical-SAM2-main/train_2d_lora.py", line 169, in <module>
    main()
  File "/data/wcldata/medsam/Medical-SAM2-main/train_2d_lora.py", line 127, in main
    tol, (eiou, edice) = function.validation_sam(args, test_loader, epoch, net)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/wcldata/medsam/Medical-SAM2-main/func_2d/function.py", line 391, in validation_sam
    low_res_multimasks, iou_predictions, sam_output_tokens, object_score_logits = net.sam_mask_decoder(
                                                                                   ^^^^^^^^^^^^^^^^^^^^^
  File "/data/wcldata/miniconda3/envs/medsam2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/wcldata/miniconda3/envs/medsam2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/wcldata/medsam/Medical-SAM2-main/sam2_train/modeling/sam/mask_decoder.py", line 136, in forward
    masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
                                                            ^^^^^^^^^^^^^^^^^^^
  File "/data/wcldata/medsam/Medical-SAM2-main/sam2_train/modeling/sam/mask_decoder.py", line 207, in predict_masks
    src = src + dense_prompt_embeddings
          ~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (64) must match the size of tensor b (32) at non-singleton dimension 3

dawwnn avatar Apr 27 '25 11:04 dawwnn
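The shapes printed at the top of that traceback hint at the cause: src is a 64×64 spatial grid (what a 1024-px input gives at the usual 16× backbone stride), while dense_prompt_embeddings is 32×32 (what a 512-px setting would give). A minimal sketch of the arithmetic, assuming the standard SAM 16× stride; the function name is illustrative, not from the repo:

BACKBONE_STRIDE = 16

def embedding_hw(image_size):
    # Side length of the image-embedding grid at 1/16 resolution.
    return image_size // BACKBONE_STRIDE

print(embedding_hw(1024))  # 64 -> matches the src grid in the traceback
print(embedding_hw(512))   # 32 -> matches the dense prompt-embedding grid

If those two differ, some image-size setting (the -image_size arg, the yaml image_size, or whatever train_2d_lora.py reads for its prompt encoder) is still out of sync, much like the earlier 93-vs-88 case.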