
Inference issue in 3D: code does not produce output for non-tracked frames

Open ranjanjayapal opened this issue 1 year ago • 1 comment

I wrote a script that processes images from a specified folder and generates outputs based on the model's predictions. However, only the frames that receive prompts produce results; the other frames generate no output at all. This suggests that prompts (initial points) are not being passed to the track_step function for those frames.

Expected Behavior: The script should produce outputs for all frames in the sequence, regardless of whether they contain a tracked object. If a frame contains an object being tracked, the model should generate predictions for it; otherwise, the script should still output something (e.g., an empty mask).

Actual Behavior: Frames with prompts specified in pt_dict generate outputs correctly. Frames not specified in pt_dict produce no results, leaving an incomplete set of outputs with missing frames.
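
As a stopgap, the saving loop can fall back to an empty mask whenever a frame is missing from the results, so every frame at least produces a file; this papers over the symptom rather than fixing the tracking itself. A minimal sketch (it reuses os/cv2/numpy from the imports in the script below and assumes each mask-logits entry has shape [1, H, W], as in my code):

def save_all_frames(video_segments, frame_id, obj_list, out_dir, image_size):
    # Write one PNG per frame; frames absent from video_segments get an empty mask.
    os.makedirs(out_dir, exist_ok=True)
    for fid in frame_id:
        for obj in obj_list:
            pred = video_segments.get(fid, {}).get(obj)  # None if the frame was never tracked
            if pred is None:
                mask = np.zeros((image_size, image_size), dtype=np.uint8)  # empty-mask fallback
            else:
                # mask logits: > 0 corresponds to foreground probability > 0.5
                mask = (pred[0, :, :] > 0).cpu().numpy().astype(np.uint8) * 255
            cv2.imwrite(os.path.join(out_dir, f"{fid}_{obj}.png"), mask)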

Script Code:

import os
import cv2
import shutil
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch
import cfg
from func_3d.utils import get_network, set_log_dir, create_logger

def read_mask_and_save_as_image(file_path):
    # Read the mask file from npy format
    mask = np.load(file_path)

    print("*" * 50)
    print(mask)
    print(mask.shape)
    print(np.unique(mask))
    
    # Ensure the mask is in uint8 format for displaying with imshow
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8)
    
    # Optionally save the mask as an image file (currently disabled)
    # plt.imsave('output_image.png', mask, cmap='gray')

def load_and_run_images(args, img_path, name):
    # Load the pre-trained weights
    GPUdevice = torch.device('cuda', args.gpu_device)
    net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution=args.distributed)
    weights = torch.load(args.sam_ckpt)
    net.load_state_dict(weights, strict=False)
    prompt = args.prompt
    prompt_freq = args.prompt_freq
    
    # '''Load masks from folder'''
    # data_seg_3d_shape = np.load(mask_path + '/0.npy').shape # shape of one mask (2D)
    # num_frame = len(os.listdir(mask_path)) # number of masks
    # data_seg_3d = np.zeros(data_seg_3d_shape + (num_frame,)) # create an empty 3D array for all masks

    '''Getting ready for inference'''
    newsize = (args.image_size, args.image_size) # resize image shape for inference with the RGB mode
    num_frame = len(os.listdir(img_path)) # number of images
    video_length = int(num_frame) # use all frames in the folder as one video
    starting_frame = 0 # Starting frame for each video
    img_tensor = torch.zeros(video_length, 3, args.image_size, args.image_size) # Create an empty tensor for images
    print("shape of empty img_tensor: ", img_tensor.shape)
    
    point_label_dict = {}
    pt_dict = {}

    for frame_index in range(starting_frame, starting_frame + video_length):
        print("Loading frame: ", os.path.join(img_path, f"{str(frame_index)}.jpg"))
        img = Image.open(os.path.join(img_path, f"{str(frame_index)}.jpg")).convert("RGB") # Open the image file
        img = img.resize(newsize) # Resize the image
        cv2.imwrite(f"data/btcv/resized_images/{name}/{str(frame_index)}.jpg", np.array(img)) # Save the image to check if it is correctly resized
        img = torch.tensor(np.array(img)).permute(2, 0, 1)
        img_tensor[frame_index - starting_frame, :, :, :] = img

        if frame_index == 0:
            diff_obj_pt_dict = {
                1: torch.Tensor([[265, 660]]), # example click coordinates (x, y) for object 1
            }
            diff_obj_point_label_dict = {
                1: torch.Tensor([1]), # positive point label
            }
            pt_dict[0] = diff_obj_pt_dict
            point_label_dict[0] = diff_obj_point_label_dict
        elif frame_index == 28:
            diff_obj_pt_dict = {
                1: torch.Tensor([[223, 694]]), # example click coordinates (x, y) for object 1
            }
            diff_obj_point_label_dict = {
                1: torch.Tensor([1]), # positive point label
            }
            pt_dict[28] = diff_obj_pt_dict
            point_label_dict[28] = diff_obj_point_label_dict
    image_meta_dict = {"filename_or_obj": name}
    inp_dict = {
        'image':img_tensor,
        'p_label':point_label_dict,
        'pt':pt_dict,
        'image_meta_dict':image_meta_dict,
    }
    # print("inp dict: ", inp_dict)

    # run the model
    torch.cuda.empty_cache()
    img_tensor = inp_dict["image"]
    point_labels_dict = inp_dict["p_label"]
    pt_dict = inp_dict["pt"]

    img_tensor = img_tensor.squeeze(0)
    img_tensor = img_tensor.to(dtype=torch.float32, device=GPUdevice)

    name = inp_dict["image_meta_dict"]["filename_or_obj"]
    frame_id = list(range(img_tensor.size(0)))
    inference_state = net.val_init_state(imgs_tensor=img_tensor)
    # NOTE: only frames 0 and 28 have entries in pt_dict, so every other id here
    # raises KeyError below and falls back to an all-zero mask prompt.
    prompt_frame_id = list(range(0, len(frame_id), prompt_freq))
    # prompt_frame_id = list(pt_dict.keys()) # alternative: prompt only the frames that actually have points
    obj_list = [1]  # assuming there is only one object for now
    with torch.no_grad():
        # First, add the prompt frames to the train state
        for id in prompt_frame_id:
            for ann_obj_id in obj_list:
                try:
                    if prompt == 'click':
                        print("Adding prompt frame id: ", id, " obj id: ", ann_obj_id)
                        # if id in pt_dict and ann_obj_id in pt_dict[id]:
                        points = pt_dict[id][ann_obj_id].to(device=GPUdevice)
                        labels = point_labels_dict[id][ann_obj_id].to(device=GPUdevice)
                        # else:
                        #     points = None
                        #     labels = None
                        _, _, _ = net.add_new_points(
                            inference_state=inference_state,
                            frame_idx=id,
                            obj_id=ann_obj_id,
                            points=points,
                            labels=labels,
                            clear_old_points=False,
                        )
                except KeyError:
                    # No click for this prompt frame: fall back to an empty mask prompt
                    _, _, _ = net.add_new_mask(
                        inference_state=inference_state,
                        frame_idx=id,
                        obj_id=ann_obj_id,
                        mask=torch.zeros(img_tensor.shape[2:]).to(device=GPUdevice),
                    )
        # Second, propagate the model forward through all frames while updating the prompts per frame
        video_segments = {}  # video_segments contains the per-frame segmentation results
        print("inference state output dict: ", inference_state["output_dict"])
        for out_frame_idx, out_obj_ids, out_mask_logits in net.propagate_in_video(inference_state, start_frame_idx=0):
            video_segments[out_frame_idx] = {
                out_obj_id: out_mask_logits[i]
                for i, out_obj_id in enumerate(out_obj_ids)
            }
        print(f"Processed {name}")
        # print(f"Video Segments: {video_segments}")

        for id in frame_id:
            for ann_obj_id in obj_list:
                pred = video_segments[id][ann_obj_id] # KeyError here if a frame was never tracked
                pred = pred.unsqueeze(0)
                # mask logits: > 0 corresponds to foreground probability > 0.5
                cv2.imwrite(os.path.join("output", name, f"{id}.png"), (pred[0, 0, :, :] > 0).cpu().numpy().astype(np.uint8) * 255)

    net.reset_state(inference_state) # reset state on each run


# Example usage
img_folder = "img0003"
args = cfg.parse_args() # args must include sam_ckpt (path to the pretrained weights)
# for mask_filename in sorted(os.listdir(os.path.join("data/btcv/Test/image", mask_folder))):
#     read_mask_and_save_as_image(os.path.join("data/btcv/Test/image", mask_folder, mask_filename))
os.makedirs(os.path.join("output", img_folder), exist_ok=True)
shutil.rmtree(os.path.join("data/btcv/resized_images", img_folder), ignore_errors=True)  # Remove the directory if it exists
os.makedirs(os.path.join("data/btcv/resized_images", img_folder), exist_ok=True)
load_and_run_images(args, os.path.join("data/btcv/Test/image", img_folder), img_folder)
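
For anyone debugging the same thing, the mismatch is easy to make visible: prompt_frame_id is derived from prompt_freq, while click points only exist for frames 0 and 28, so every other prompt frame takes the zero-mask fallback in the except branch. A sketch reusing the names from the script above:

frame_id = list(range(video_length))
prompt_frame_id = list(range(0, len(frame_id), prompt_freq))
missing = [i for i in prompt_frame_id if i not in pt_dict]
print("prompted frames:", prompt_frame_id)
print("frames with click points:", sorted(pt_dict.keys()))  # [0, 28] in this script
print("frames falling back to a zero mask:", missing)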

ranjanjayapal avatar Oct 02 '24 17:10 ranjanjayapal

Hi, I encountered the same problem when trying to save the inference results. I modified func_3d/dataset/btcv.py and it works:

def __getitem__(self, index):
    point_label = 1
    newsize = (self.img_size, self.img_size)

    """Get the images"""
    name = self.name_list[index]
    img_path = os.path.join(self.data_path, self.mode, 'image', name)
    mask_path = os.path.join(self.data_path, self.mode, 'mask', name)
    # print(f"img_path {img_path}")
    # print(f"mask_path {mask_path}")
    
    data_seg_3d_shape = np.load(mask_path + '/0.npy').shape
    num_frame = len(os.listdir(mask_path))
    print(f"mask_path {mask_path} num_frame {num_frame}")
    data_seg_3d = np.zeros(data_seg_3d_shape + (num_frame,))
    for i in range(num_frame):
        data_seg_3d[..., i] = np.load(os.path.join(mask_path, f'{i}.npy'))
    # for i in range(data_seg_3d.shape[-1]):
    #     if np.sum(data_seg_3d[..., i]) > 0:
    #         data_seg_3d = data_seg_3d[..., i:]
    #         break
    starting_frame_nonzero = 0
    # for j in reversed(range(data_seg_3d.shape[-1])):
    #     if np.sum(data_seg_3d[..., j]) > 0:
    #         data_seg_3d = data_seg_3d[..., :j+1]
    #         break
    num_frame = data_seg_3d.shape[-1]
    if self.video_length is None:
        video_length = int(num_frame)
    else:
        video_length = self.video_length
    if num_frame > video_length and self.mode == 'Training':
        starting_frame = np.random.randint(0, num_frame - video_length + 1)
        # starting_frame = 0
    else:
        starting_frame = 0
    img_tensor = torch.zeros(video_length, 3, self.img_size, self.img_size)
    mask_dict = {}
    point_label_dict = {}
    pt_dict = {}
    bbox_dict = {}

    for frame_index in range(starting_frame, starting_frame + video_length):
        img = Image.open(os.path.join(img_path, f'{frame_index + starting_frame_nonzero}.jpg')).convert('RGB')
        mask = data_seg_3d[..., frame_index]
        # mask = np.rot90(mask)
        obj_list = np.unique(mask[mask > 0])
        diff_obj_mask_dict = {}
        if self.prompt == 'bbox':
            diff_obj_bbox_dict = {}
        elif self.prompt == 'click':
            diff_obj_pt_dict = {}
            diff_obj_point_label_dict = {}
        else:
            raise ValueError('Prompt not recognized')
        for obj in obj_list:
            obj_mask = mask == obj
            # if self.transform_msk:
            obj_mask = Image.fromarray(obj_mask)
            obj_mask = obj_mask.resize(newsize)
            obj_mask = torch.tensor(np.array(obj_mask)).unsqueeze(0).int()
                # obj_mask = self.transform_msk(obj_mask).int()
            diff_obj_mask_dict[obj] = obj_mask
            print(f"obj_mask.shape {obj_mask.shape}")
            if self.prompt == 'click':
                diff_obj_point_label_dict[obj], diff_obj_pt_dict[obj] = random_click(np.array(obj_mask.squeeze(0)), point_label, seed=None)
            if self.prompt == 'bbox':
                diff_obj_bbox_dict[obj] = generate_bbox(np.array(obj_mask.squeeze(0)), variation=self.variation, seed=self.seed)
        # if self.transform:
            # state = torch.get_rng_state()
            # img = self.transform(img)
            # torch.set_rng_state(state)
        img = img.resize(newsize)
        img = torch.tensor(np.array(img)).permute(2, 0, 1)

        img_tensor[frame_index - starting_frame, :, :, :] = img
        mask_dict[frame_index - starting_frame] = diff_obj_mask_dict
        if self.prompt == 'bbox':
            bbox_dict[frame_index - starting_frame] = diff_obj_bbox_dict
        elif self.prompt == 'click':
            pt_dict[frame_index - starting_frame] = diff_obj_pt_dict
            point_label_dict[frame_index - starting_frame] = diff_obj_point_label_dict


    image_meta_dict = {'filename_or_obj':name}
    if self.prompt == 'bbox':
        return {
            'image':img_tensor,
            'label': mask_dict,
            'bbox': bbox_dict,
            'image_meta_dict':image_meta_dict,
        }
    elif self.prompt == 'click':
        return {
            'image':img_tensor,
            'label': mask_dict,
            'p_label':point_label_dict,
            'pt':pt_dict,
            'image_meta_dict':image_meta_dict,
        }
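
The key difference from the script in the issue is that this __getitem__ fills pt_dict and point_label_dict for every frame in the clip (sampling a click from the ground-truth mask with random_click), so every index in prompt_frame_id resolves to a prompt. A rough usage sketch (the class name BTCV and the constructor arguments are assumptions; check the actual signature in func_3d/dataset/btcv.py):

dataset = BTCV(args, data_path='data/btcv', mode='Test', prompt='click')  # assumed constructor; verify in btcv.py
sample = dataset[0]
print(sorted(sample['pt'].keys()))  # one entry per frame, e.g. [0, 1, 2, ...]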

qianxingyucode avatar Apr 14 '25 03:04 qianxingyucode