add: example notebooks to train and predict
- created examples folder
- added quick start colab notebook that trains & validates 2d refuge data
i tried the colab notebook with a100, l4, and t4 gpus, and the notebook mysteriously crashed on all of them. surprisingly, on kaggle the notebook did run after a minor fix (/content -> /kaggle/working), but it failed with this error:
/kaggle/working/Medical-SAM2/sam2_train/modeling/sam/transformer.py:22: UserWarning: Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
INFO:root:Namespace(net='sam2', encoder='vit_b', exp_name='REFUGE_MedSAM2', vis=True, train_vis=False, prompt='bbox', prompt_freq=2, pretrain=None, val_freq=1, gpu=True, gpu_device=0, image_size=1024, out_size=1024, distributed='none', dataset='REFUGE', sam_ckpt='/kaggle/working/checkpoints/sam2_hiera_tiny.pt', sam_config='sam2_hiera_t', video_length=2, b=2, lr=0.0001, weights=0, multimask_output=1, memory_bank_size=16, data_path='/kaggle/working/data/REFUGE', path_helper={'prefix': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32', 'ckpt_path': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32/Model', 'log_path': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32/Log', 'sample_path': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32/Samples'})
Namespace(net='sam2', encoder='vit_b', exp_name='REFUGE_MedSAM2', vis=True, train_vis=False, prompt='bbox', prompt_freq=2, pretrain=None, val_freq=1, gpu=True, gpu_device=0, image_size=1024, out_size=1024, distributed='none', dataset='REFUGE', sam_ckpt='/kaggle/working/checkpoints/sam2_hiera_tiny.pt', sam_config='sam2_hiera_t', video_length=2, b=2, lr=0.0001, weights=0, multimask_output=1, memory_bank_size=16, data_path='/kaggle/working/data/REFUGE', path_helper={'prefix': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32', 'ckpt_path': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32/Model', 'log_path': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32/Log', 'sample_path': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32/Samples'})
Traceback (most recent call last):
File "/kaggle/working/Medical-SAM2/train_2d.py", line 124, in <module>
main()
File "/kaggle/working/Medical-SAM2/train_2d.py", line 97, in main
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/Medical-SAM2/func_2d/function.py", line 335, in validation_sam
vision_feats_temp = vision_feats[-1].permute(1, 0, 2).view(B, -1, 64, 64)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
ERROR conda.cli.main_run:execute(125): `conda run bash -c python train_2d.py -net sam2 -exp_name REFUGE_MedSAM2 -vis 1 -sam_ckpt /kaggle/working/checkpoints/sam2_hiera_tiny.pt -sam_config sam2_hiera_t -image_size 1024 -out_size 1024 -b 2 -val_freq 1 -dataset REFUGE -data_path /kaggle/working/data/REFUGE` failed. (See above for error)
anyway, thanks. i will try the notebook approach without using conda.
Actually, I removed the part where I added the reshape fix, since my PR was merged :/ The Kaggle one is interesting... I guess the problem is related to GPU architecture differences? The note from my local notebook: `# rewrite this file because it breaks:
# changed view to reshape in the vision_feats_temp var on line 104
# vision_feats_temp = vision_feats[-1].permute(1, 0, 2).reshape(B, -1, 64, 64)`
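For context, the failure is about memory layout rather than the GPU: after `.permute()` the tensor is usually non-contiguous, and `.view()` requires a compatible contiguous layout, whereas `.reshape()` falls back to a copy when needed. A minimal sketch, independent of the repo (shapes only chosen to mimic `vision_feats[-1]`):

```python
import torch

B = 2
x = torch.randn(4096, B, 256)       # (tokens, batch, channels), like vision_feats[-1]
y = x.permute(1, 0, 2)              # (batch, tokens, channels) -> non-contiguous
print(y.is_contiguous())            # False

try:
    y.view(B, -1, 64, 64)           # raises the "size and stride" RuntimeError
except RuntimeError as e:
    print("view failed:", e)

z = y.reshape(B, -1, 64, 64)        # works; equivalent to .contiguous().view(...)
print(z.shape)                      # torch.Size([2, 256, 64, 64])
```

(An equivalent fix would be `.permute(1, 0, 2).contiguous().view(B, -1, 64, 64)`, but `.reshape(...)` is the smaller change.)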
So for colab, can you add the %%writefile cell below right before the training command and try again? @ibinti
%%writefile /content/Medical-SAM2/func_2d/function.py
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import cfg
from conf import settings
from func_2d.utils import *
import pandas as pd
args = cfg.parse_args()
GPUdevice = torch.device('cuda', args.gpu_device)
pos_weight = torch.ones([1]).cuda(device=GPUdevice)*2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
mask_type = torch.float32
torch.backends.cudnn.benchmark = True
def train_sam(args, net: nn.Module, optimizer, train_loader, epoch, writer):
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# train mode
net.train()
optimizer.zero_grad()
# init
epoch_loss = 0
memory_bank_list = []
lossfunc = criterion_G
feat_sizes = [(256, 256), (128, 128), (64, 64)]
with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
for ind, pack in enumerate(train_loader):
to_cat_memory = []
to_cat_memory_pos = []
to_cat_image_embed = []
# input image and gt masks
imgs = pack['image'].to(dtype = mask_type, device = GPUdevice)
masks = pack['mask'].to(dtype = mask_type, device = GPUdevice)
name = pack['image_meta_dict']['filename_or_obj']
# click prompt: unsqueeze to indicate only one click, add more clicks across this dimension
if 'pt' in pack:
pt_temp = pack['pt'].to(device = GPUdevice)
pt = pt_temp.unsqueeze(1)
point_labels_temp = pack['p_label'].to(device = GPUdevice)
point_labels = point_labels_temp.unsqueeze(1)
coords_torch = torch.as_tensor(pt, dtype=torch.float, device=GPUdevice)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
else:
coords_torch = None
labels_torch = None
'''Train image encoder'''
backbone_out = net.forward_image(imgs)
_, vision_feats, vision_pos_embeds, _ = net._prepare_backbone_features(backbone_out)
# dimension hint for your future use
# vision_feats: list: length = 3
# vision_feats[0]: torch.Size([65536, batch, 32])
# vision_feats[1]: torch.Size([16384, batch, 64])
# vision_feats[2]: torch.Size([4096, batch, 256])
# vision_pos_embeds[0]: torch.Size([65536, batch, 256])
# vision_pos_embeds[1]: torch.Size([16384, batch, 256])
# vision_pos_embeds[2]: torch.Size([4096, batch, 256])
'''Train memory attention to condition on memory bank'''
B = vision_feats[-1].size(1) # batch size
if len(memory_bank_list) == 0:
vision_feats[-1] = vision_feats[-1] + torch.nn.Parameter(torch.zeros(1, B, net.hidden_dim)).to(device="cuda")
vision_pos_embeds[-1] = vision_pos_embeds[-1] + torch.nn.Parameter(torch.zeros(1, B, net.hidden_dim)).to(device="cuda")
else:
for element in memory_bank_list:
to_cat_memory.append((element[0]).cuda(non_blocking=True).flatten(2).permute(2, 0, 1)) # maskmem_features
to_cat_memory_pos.append((element[1]).cuda(non_blocking=True).flatten(2).permute(2, 0, 1)) # maskmem_pos_enc
to_cat_image_embed.append((element[3]).cuda(non_blocking=True)) # image_embed
memory_stack_ori = torch.stack(to_cat_memory, dim=0)
memory_pos_stack_ori = torch.stack(to_cat_memory_pos, dim=0)
image_embed_stack_ori = torch.stack(to_cat_image_embed, dim=0)
vision_feats_temp = vision_feats[-1].permute(1, 0, 2).reshape(B, -1, 64, 64)
vision_feats_temp = vision_feats_temp.reshape(B, -1)
image_embed_stack_ori = F.normalize(image_embed_stack_ori, p=2, dim=1)
vision_feats_temp = F.normalize(vision_feats_temp, p=2, dim=1)
similarity_scores = torch.mm(image_embed_stack_ori, vision_feats_temp.t()).t()
similarity_scores = F.softmax(similarity_scores, dim=1)
sampled_indices = torch.multinomial(similarity_scores, num_samples=B, replacement=True).squeeze(1) # Shape [batch_size, 16]
memory_stack_ori_new = (memory_stack_ori[sampled_indices].squeeze(3).permute(1, 2, 0, 3))
memory = memory_stack_ori_new.reshape(-1, memory_stack_ori_new.size(2), memory_stack_ori_new.size(3))
memory_pos_stack_new = (memory_pos_stack_ori[sampled_indices].squeeze(3).permute(1, 2, 0, 3))
memory_pos = memory_pos_stack_new.reshape(-1, memory_stack_ori_new.size(2), memory_stack_ori_new.size(3))
vision_feats[-1] = net.memory_attention(
curr=[vision_feats[-1]],
curr_pos=[vision_pos_embeds[-1]],
memory=memory,
memory_pos=memory_pos,
num_obj_ptr_tokens=0
)
feats = [feat.permute(1, 2, 0).view(B, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])][::-1]
image_embed = feats[-1]
high_res_feats = feats[:-1]
# feats[0]: torch.Size([batch, 32, 256, 256]) #high_res_feats part1
# feats[1]: torch.Size([batch, 64, 128, 128]) #high_res_feats part2
# feats[2]: torch.Size([batch, 256, 64, 64]) #image_embed
'''prompt encoder'''
with torch.no_grad():
if (ind%5) == 0:
points=(coords_torch, labels_torch) # input shape: ((batch, n, 2), (batch, n))
flag = True
else:
points=None
flag = False
se, de = net.sam_prompt_encoder(
points=points, #(coords_torch, labels_torch)
boxes=None,
masks=None,
batch_size=B,
)
# dimension hint for your future use
# se: torch.Size([batch, n+1, 256])
# de: torch.Size([batch, 256, 64, 64])
'''train mask decoder'''
low_res_multimasks, iou_predictions, sam_output_tokens, object_score_logits = net.sam_mask_decoder(
image_embeddings=image_embed,
image_pe=net.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False, # args.multimask_output if you want multiple masks
repeat_image=False, # the image is already batched
high_res_features = high_res_feats
)
# dimension hint for your future use
# low_res_multimasks: torch.Size([batch, multimask_output, 256, 256])
# iou_predictions.shape:torch.Size([batch, multimask_output])
# sam_output_tokens.shape:torch.Size([batch, multimask_output, 256])
# object_score_logits.shape:torch.Size([batch, 1])
# resize prediction
pred = F.interpolate(low_res_multimasks,size=(args.out_size,args.out_size))
high_res_multimasks = F.interpolate(low_res_multimasks, size=(args.image_size, args.image_size),
mode="bilinear", align_corners=False)
'''memory encoder'''
# newly calculated memory features
maskmem_features, maskmem_pos_enc = net._encode_new_memory(
current_vision_feats=vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_multimasks,
is_mask_from_pts=flag)
# dimension hint for your future use
# maskmem_features: torch.Size([batch, 64, 64, 64])
# maskmem_pos_enc: [torch.Size([batch, 64, 64, 64])]
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(device=GPUdevice, non_blocking=True)
maskmem_pos_enc = maskmem_pos_enc[0].to(torch.bfloat16)
maskmem_pos_enc = maskmem_pos_enc.to(device=GPUdevice, non_blocking=True)
# add single maskmem_features, maskmem_pos_enc, iou
if len(memory_bank_list) < args.memory_bank_size:
for batch in range(maskmem_features.size(0)):
memory_bank_list.append([(maskmem_features[batch].unsqueeze(0)).detach(),
(maskmem_pos_enc[batch].unsqueeze(0)).detach(),
iou_predictions[batch, 0],
image_embed[batch].reshape(-1).detach()])
else:
for batch in range(maskmem_features.size(0)):
# current similarity matrix of the existing memory bank
memory_bank_maskmem_features_flatten = [element[0].reshape(-1) for element in memory_bank_list]
memory_bank_maskmem_features_flatten = torch.stack(memory_bank_maskmem_features_flatten)
# normalise
memory_bank_maskmem_features_norm = F.normalize(memory_bank_maskmem_features_flatten, p=2, dim=1)
current_similarity_matrix = torch.mm(memory_bank_maskmem_features_norm,
memory_bank_maskmem_features_norm.t())
# replace diagonal (diagonal similarity is always 1)
current_similarity_matrix_no_diag = current_similarity_matrix.clone()
diag_indices = torch.arange(current_similarity_matrix_no_diag.size(0))
current_similarity_matrix_no_diag[diag_indices, diag_indices] = float('-inf')
# first find the minimum similarity from memory feature and the maximum similarity from memory bank
single_key_norm = F.normalize(maskmem_features[batch].reshape(-1), p=2, dim=0).unsqueeze(1)
similarity_scores = torch.mm(memory_bank_maskmem_features_norm, single_key_norm).squeeze()
min_similarity_index = torch.argmin(similarity_scores)
max_similarity_index = torch.argmax(current_similarity_matrix_no_diag[min_similarity_index])
# replace with less similar object
if similarity_scores[min_similarity_index] < current_similarity_matrix_no_diag[min_similarity_index][max_similarity_index]:
# soft iou, not strictly greater than the current iou
if iou_predictions[batch, 0] > memory_bank_list[max_similarity_index][2] - 0.1:
memory_bank_list.pop(max_similarity_index)
memory_bank_list.append([(maskmem_features[batch].unsqueeze(0)).detach(),
(maskmem_pos_enc[batch].unsqueeze(0)).detach(),
iou_predictions[batch, 0],
image_embed[batch].reshape(-1).detach()])
# backpropagation
loss = lossfunc(pred, masks)
pbar.set_postfix(**{'loss (batch)': loss.item()})
epoch_loss += loss.item()
loss.backward()
optimizer.step()
optimizer.zero_grad()
pbar.update()
return epoch_loss/len(train_loader)
def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# eval mode
net.eval()
n_val = len(val_loader)
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
GPUdevice = torch.device('cuda:' + str(args.gpu_device))
# init
lossfunc = criterion_G
memory_bank_list = []
feat_sizes = [(256, 256), (128, 128), (64, 64)]
total_loss = 0
total_eiou = 0
total_dice = 0
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
for ind, pack in enumerate(val_loader):
to_cat_memory = []
to_cat_memory_pos = []
to_cat_image_embed = []
name = pack['image_meta_dict']['filename_or_obj']
imgs = pack['image'].to(dtype = torch.float32, device = GPUdevice)
masks = pack['mask'].to(dtype = torch.float32, device = GPUdevice)
if 'pt' in pack:
pt_temp = pack['pt'].to(device = GPUdevice)
pt = pt_temp.unsqueeze(1)
point_labels_temp = pack['p_label'].to(device = GPUdevice)
point_labels = point_labels_temp.unsqueeze(1)
coords_torch = torch.as_tensor(pt, dtype=torch.float, device=GPUdevice)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
else:
coords_torch = None
labels_torch = None
'''test'''
with torch.no_grad():
""" image encoder """
backbone_out = net.forward_image(imgs)
_, vision_feats, vision_pos_embeds, _ = net._prepare_backbone_features(backbone_out)
B = vision_feats[-1].size(1)
""" memory condition """
if len(memory_bank_list) == 0:
vision_feats[-1] = vision_feats[-1] + torch.nn.Parameter(torch.zeros(1, B, net.hidden_dim)).to(device="cuda")
vision_pos_embeds[-1] = vision_pos_embeds[-1] + torch.nn.Parameter(torch.zeros(1, B, net.hidden_dim)).to(device="cuda")
else:
for element in memory_bank_list:
maskmem_features = element[0]
maskmem_pos_enc = element[1]
to_cat_memory.append(maskmem_features.cuda(non_blocking=True).flatten(2).permute(2, 0, 1))
to_cat_memory_pos.append(maskmem_pos_enc.cuda(non_blocking=True).flatten(2).permute(2, 0, 1))
to_cat_image_embed.append((element[3]).cuda(non_blocking=True)) # image_embed
memory_stack_ori = torch.stack(to_cat_memory, dim=0)
memory_pos_stack_ori = torch.stack(to_cat_memory_pos, dim=0)
image_embed_stack_ori = torch.stack(to_cat_image_embed, dim=0)
vision_feats_temp = vision_feats[-1].permute(1, 0, 2).reshape(B, -1, 64, 64)
vision_feats_temp = vision_feats_temp.reshape(B, -1)
image_embed_stack_ori = F.normalize(image_embed_stack_ori, p=2, dim=1)
vision_feats_temp = F.normalize(vision_feats_temp, p=2, dim=1)
similarity_scores = torch.mm(image_embed_stack_ori, vision_feats_temp.t()).t()
similarity_scores = F.softmax(similarity_scores, dim=1)
sampled_indices = torch.multinomial(similarity_scores, num_samples=B, replacement=True).squeeze(1) # Shape [batch_size, 16]
memory_stack_ori_new = (memory_stack_ori[sampled_indices].squeeze(3).permute(1, 2, 0, 3))
memory = memory_stack_ori_new.reshape(-1, memory_stack_ori_new.size(2), memory_stack_ori_new.size(3))
memory_pos_stack_new = (memory_pos_stack_ori[sampled_indices].squeeze(3).permute(1, 2, 0, 3))
memory_pos = memory_pos_stack_new.reshape(-1, memory_stack_ori_new.size(2), memory_stack_ori_new.size(3))
vision_feats[-1] = net.memory_attention(
curr=[vision_feats[-1]],
curr_pos=[vision_pos_embeds[-1]],
memory=memory,
memory_pos=memory_pos,
num_obj_ptr_tokens=0
)
feats = [feat.permute(1, 2, 0).view(B, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])][::-1]
image_embed = feats[-1]
high_res_feats = feats[:-1]
""" prompt encoder """
if (ind%5) == 0:
flag = True
points = (coords_torch, labels_torch)
else:
flag = False
points = None
se, de = net.sam_prompt_encoder(
points=points,
boxes=None,
masks=None,
batch_size=B,
)
low_res_multimasks, iou_predictions, sam_output_tokens, object_score_logits = net.sam_mask_decoder(
image_embeddings=image_embed,
image_pe=net.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False,
repeat_image=False,
high_res_features = high_res_feats
)
# prediction
pred = F.interpolate(low_res_multimasks,size=(args.out_size,args.out_size))
high_res_multimasks = F.interpolate(low_res_multimasks, size=(args.image_size, args.image_size),
mode="bilinear", align_corners=False)
""" memory encoder """
maskmem_features, maskmem_pos_enc = net._encode_new_memory(
current_vision_feats=vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_multimasks,
is_mask_from_pts=flag)
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(device=GPUdevice, non_blocking=True)
maskmem_pos_enc = maskmem_pos_enc[0].to(torch.bfloat16)
maskmem_pos_enc = maskmem_pos_enc.to(device=GPUdevice, non_blocking=True)
""" memory bank """
if len(memory_bank_list) < 16:
for batch in range(maskmem_features.size(0)):
memory_bank_list.append([(maskmem_features[batch].unsqueeze(0)),
(maskmem_pos_enc[batch].unsqueeze(0)),
iou_predictions[batch, 0],
image_embed[batch].reshape(-1).detach()])
else:
for batch in range(maskmem_features.size(0)):
memory_bank_maskmem_features_flatten = [element[0].reshape(-1) for element in memory_bank_list]
memory_bank_maskmem_features_flatten = torch.stack(memory_bank_maskmem_features_flatten)
memory_bank_maskmem_features_norm = F.normalize(memory_bank_maskmem_features_flatten, p=2, dim=1)
current_similarity_matrix = torch.mm(memory_bank_maskmem_features_norm,
memory_bank_maskmem_features_norm.t())
current_similarity_matrix_no_diag = current_similarity_matrix.clone()
diag_indices = torch.arange(current_similarity_matrix_no_diag.size(0))
current_similarity_matrix_no_diag[diag_indices, diag_indices] = float('-inf')
single_key_norm = F.normalize(maskmem_features[batch].reshape(-1), p=2, dim=0).unsqueeze(1)
similarity_scores = torch.mm(memory_bank_maskmem_features_norm, single_key_norm).squeeze()
min_similarity_index = torch.argmin(similarity_scores)
max_similarity_index = torch.argmax(current_similarity_matrix_no_diag[min_similarity_index])
if similarity_scores[min_similarity_index] < current_similarity_matrix_no_diag[min_similarity_index][max_similarity_index]:
if iou_predictions[batch, 0] > memory_bank_list[max_similarity_index][2] - 0.1:
memory_bank_list.pop(max_similarity_index)
memory_bank_list.append([(maskmem_features[batch].unsqueeze(0)),
(maskmem_pos_enc[batch].unsqueeze(0)),
iou_predictions[batch, 0],
image_embed[batch].reshape(-1).detach()])
# binary mask and calculate loss, iou, dice
total_loss += lossfunc(pred, masks)
pred = (pred> 0.5).float()
temp = eval_seg(pred, masks, threshold)
total_eiou += temp[0]
total_dice += temp[1]
'''vis images'''
if ind % args.vis == 0:
namecat = 'Test'
for na in name:
img_name = na
namecat = namecat + img_name + '+'
vis_image(imgs,pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=None)
pbar.update()
return total_loss/ n_val , tuple([total_eiou/n_val, total_dice/n_val])
So for colab, can you add this cell right before the training command and try again?
i did not add the cell you suggested and run it, because i have already forked the repo and can modify my fork directly. instead, i commented out line 335 of function.py and added .reshape() like you showed:
# vision_feats_temp = vision_feats[-1].permute(1, 0, 2).view(B, -1, 64, 64)
vision_feats_temp = vision_feats[-1].permute(1, 0, 2).reshape(B, -1, 64, 64)
this fix made both kaggle p100 and colab t4 happy. here are the training log outputs from both.
kaggle p100
INFO:root:Total score: 0.73396235704422, IOU: 0.062239742812431074, DICE: 0.09917049956518037 || @ epoch 0.
Total score: 0.73396235704422, IOU: 0.062239742812431074, DICE: 0.09917049956518037 || @ epoch 0.
Epoch 0: 100%|███████████| 200/200 [03:23<00:00, 1.02s/img, loss (batch)=0.119]
INFO:root:Train loss: 0.2960984502360225 || @ epoch 0.
Train loss: 0.2960984502360225 || @ epoch 0.
time_for_training 203.05405259132385
INFO:root:Total score: 0.18758922815322876, IOU: 0.6464821135604341, DICE: 0.7748540666699409 || @ epoch 0.
Total score: 0.18758922815322876, IOU: 0.6464821135604341, DICE: 0.7748540666699409 || @ epoch 0.
colab t4
INFO:root:Total score: 0.6537907123565674, IOU: 0.03657788351116989, DICE: 0.05533080002906395 || @ epoch 0.
Total score: 0.6537907123565674, IOU: 0.03657788351116989, DICE: 0.05533080002906395 || @ epoch 0.
Epoch 0: 100% 200/200 [04:01<00:00, 1.21s/img, loss (batch)=0.105]
INFO:root:Train loss: 0.2971741591766477 || @ epoch 0.
Train loss: 0.2971741591766477 || @ epoch 0.
time_for_training 241.99381804466248
INFO:root:Total score: 0.19345636665821075, IOU: 0.6315837719022945, DICE: 0.7623434653878212 || @ epoch 0.
Total score: 0.19345636665821075, IOU: 0.6315837719022945, DICE: 0.7623434653878212 || @ epoch 0.
thanks!
this is to make the story complete with 3d training. if the cuda extension is built on colab or kaggle, train_3d.py also runs without a problem. use l4 on colab to provide enough gpu memory; t4 and p100 will run out of memory after a couple of training steps.
one builds the cuda extension like this:
!python setup.py build_ext --inplace
setup.py is not in the Medical-SAM2 repo, so i copied one from the upstream meta segment-anything-2 repo. add this cell right before running train_3d.py. the only minor modification is to change the path to _C.so, i.e. "sam2" => "sam2_train" in the two places on the srcs and ext_modules lines; this is due to the package path difference between the Medical-SAM2 repo and the segment-anything-2 repo. (a small sanity-check cell after the setup.py below shows how to confirm the extension built.)
%%writefile /content/Medical-SAM2/setup.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from setuptools import find_packages, setup
# Package metadata
NAME = "SAM 2"
VERSION = "1.0"
DESCRIPTION = "SAM 2: Segment Anything in Images and Videos"
URL = "https://github.com/facebookresearch/segment-anything-2"
AUTHOR = "Meta AI"
AUTHOR_EMAIL = "[email protected]"
LICENSE = "Apache 2.0"
# Read the contents of README file
with open("README.md", "r", encoding="utf-8") as f:
LONG_DESCRIPTION = f.read()
# Required dependencies
REQUIRED_PACKAGES = [
"torch>=2.3.1",
"torchvision>=0.18.1",
"numpy>=1.24.4",
"tqdm>=4.66.1",
"hydra-core>=1.3.2",
"iopath>=0.1.10",
"pillow>=9.4.0",
]
EXTRA_PACKAGES = {
"demo": ["matplotlib>=3.9.1", "jupyter>=1.0.0", "opencv-python>=4.7.0"],
"dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
}
# By default, we also build the SAM 2 CUDA extension.
# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
# By default, we allow SAM 2 installation to proceed even with build errors.
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
# Catch and skip errors during extension building and print a warning message
# (note that this message only shows up under verbose build mode
# "pip install -v -e ." or "python setup.py build_ext -v")
CUDA_ERROR_MSG = (
"{}\n\n"
"Failed to build the SAM 2 CUDA extension due to the error above. "
"You can still use SAM 2 and it's OK to ignore the error above, although some "
"post-processing functionality may be limited (which doesn't affect the results in most cases; "
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
)
def get_extensions():
if not BUILD_CUDA:
return []
try:
from torch.utils.cpp_extension import CUDAExtension
srcs = ["sam2_train/csrc/connected_components.cu"]
compile_args = {
"cxx": [],
"nvcc": [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
}
ext_modules = [CUDAExtension("sam2_train._C", srcs, extra_compile_args=compile_args)]
except Exception as e:
if BUILD_ALLOW_ERRORS:
print(CUDA_ERROR_MSG.format(e))
ext_modules = []
else:
raise e
return ext_modules
try:
from torch.utils.cpp_extension import BuildExtension
class BuildExtensionIgnoreErrors(BuildExtension):
def finalize_options(self):
try:
super().finalize_options()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
def build_extensions(self):
try:
super().build_extensions()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
def get_ext_filename(self, ext_name):
try:
return super().get_ext_filename(ext_name)
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
return "_C.so"
cmdclass = {
"build_ext": (
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
if BUILD_ALLOW_ERRORS
else BuildExtension.with_options(no_python_abi_suffix=True)
)
}
except Exception as e:
cmdclass = {}
if BUILD_ALLOW_ERRORS:
print(CUDA_ERROR_MSG.format(e))
else:
raise e
# Setup configuration
setup(
name=NAME,
version=VERSION,
description=DESCRIPTION,
long_description=LONG_DESCRIPTION,
long_description_content_type="text/markdown",
url=URL,
author=AUTHOR,
author_email=AUTHOR_EMAIL,
license=LICENSE,
packages=find_packages(exclude="notebooks"),
package_data={"": ["*.yaml"]}, # SAM 2 configuration files
include_package_data=True,
install_requires=REQUIRED_PACKAGES,
extras_require=EXTRA_PACKAGES,
python_requires=">=3.10.0",
ext_modules=get_extensions(),
cmdclass=cmdclass,
)
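after running the build command above, a quick sanity check (a hypothetical cell, not part of the repo) can confirm whether the compiled extension is actually importable:

```python
# Hypothetical check cell: verify the compiled CUDA extension can be imported.
import importlib

try:
    importlib.import_module("sam2_train._C")
    print("sam2_train._C loaded: CUDA post-processing extension is available.")
except ImportError as err:
    print("extension not available; some post-processing may be limited:", err)
```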
this is the log output from train_3d.py on colab l4:
Epoch 0: 0% 0/24 [00:00<?, ?img/s]/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
Epoch 0: 100% 24/24 [00:40<00:00, 1.68s/img, loss (batch)=0.00301]
INFO:root:Train loss: 0.051431525489487206, 0.007973176463565324, 0.09488987206714228 || @ epoch 0.
Train loss: 0.051431525489487206, 0.007973176463565324, 0.09488987206714228 || @ epoch 0.
time_for_training 40.224876165390015
INFO:root:Total score: 0.16452749073505402, IOU: 0.8724716256719525, DICE: 0.9101958117103393 || @ epoch 0.
Total score: 0.16452749073505402, IOU: 0.8724716256719525, DICE: 0.9101958117103393 || @ epoch 0.
it would be nice if setup.py were included in Medical-SAM2, along with instructions for anyone having conflicting cuda extension issues on their system.
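in the meantime, the environment variables already defined in this setup.py give a workaround; for example (hypothetical cells, not from the repo):

```
# skip building the CUDA extension entirely (some post-processing may be limited):
!SAM2_BUILD_CUDA=0 python setup.py build_ext --inplace

# or make CUDA build errors fatal instead of silently skipping the extension:
!SAM2_BUILD_ALLOW_ERRORS=0 python setup.py build_ext --inplace
```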
thanks!
Hi @ibinti, thank you so much! I was stuck on this problem but couldn't find an appropriate time to focus on it, so this helped me! However, I have a problem with 3d training on Colab/Kaggle. After resolving a bunch of errors, I got this one:
x = F.scaled_dot_product_attention(
AttributeError: module 'torch.nn.functional' has no attribute 'scaled_dot_product_attention'. Did you mean: '_scaled_dot_product_attention'?
Then I changed it to F._scaled_dot_product_attention in sam2_train/modeling/backbones/hieradet.py, and got this error:
x = F._scaled_dot_product_attention(
File "/opt/conda/envs/medsam2/lib/python3.10/site-packages/torch/nn/functional.py", line 4848, in _scaled_dot_product_attention
B, Nt, E = q.shape
ValueError: too many values to unpack (expected 3)
Do you have any idea?
hello @rabiaedayilmaz , i did not use conda, as i've found it less beneficial on platforms like colab and kaggle compared to local machines. instead, i removed the conda dependencies and installed the required packages directly using pip. this simplified the setup and made everything more straightforward. if you're interested, i can share a working colab notebook that demonstrates this approach, but i need to make some modifications to the dataset path since i used my dataset on kaggle. let me know.
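roughly, the setup cell looks something like this (the package list here is an assumption based on the setup.py requirements above, not the exact cell from my notebook):

```
# hypothetical pip-based setup cell; package list assumed from the setup.py above
!pip install -q "torch>=2.3.1" "torchvision>=0.18.1" "numpy>=1.24.4" \
    "tqdm>=4.66.1" "hydra-core>=1.3.2" "iopath>=0.1.10" "pillow>=9.4.0" pandas
```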
Hi @ibinti , I see. After removing conda, it works properly. Thanks!
I would love to have a notebook capable of 3D training on my own dataset. If you could provide yours to start with, that would be great.
Thank you.