Inference issue in 3D: code does not produce output for non-tracked frames
I wrote a script that processes the images in a specified folder and generates outputs from the model's predictions. However, only the frames the model is tracking produce results; the other frames generate no output at all. This behavior suggests that prompts (or initial points) are not being passed correctly to the track_step function for those frames.
Expected Behavior: The script should produce an output for every frame in the sequence, regardless of whether the frame is part of the tracked objects. If a frame contains an object being tracked, the model should generate a prediction; otherwise it should still output something (e.g., an empty mask).
Actual Behavior: Frames the model is prompted on (those specified in pt_dict) generate outputs correctly. Frames not specified in pt_dict produce no results, leaving an incomplete set of outputs with missing frames.
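To illustrate what I think is happening (a minimal sketch; prompt_freq = 2 and a 30-frame clip are hypothetical values, not my actual config): prompt_frame_id is built from prompt_freq, but pt_dict only contains the frames I clicked on, so every other prompt frame falls through to the empty-mask fallback.

import torch

# Hypothetical values for illustration; the real script reads these from cfg.
prompt_freq = 2
video_length = 30
pt_dict = {0: {1: torch.Tensor([[265, 660]])}, 28: {1: torch.Tensor([[223, 694]])}}

prompt_frame_id = list(range(0, video_length, prompt_freq))  # [0, 2, 4, ..., 28]
unprompted = [fid for fid in prompt_frame_id if fid not in pt_dict]
print(unprompted)  # [2, 4, ..., 26] -- all of these hit the add_new_mask fallback below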
Script Code:
import os
import cv2
import shutil
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch
import cfg
from func_3d.utils import get_network, set_log_dir, create_logger
def read_mask_and_save_as_image(file_path):
# Read the mask file from npy format
mask = np.load(file_path)
print("*" * 50)
print(mask)
print(mask.shape)
print(np.unique(mask))
# Ensure the mask is in uint8 format for displaying with imshow
if mask.dtype != np.uint8:
mask = (mask * 255).astype(np.uint8)
    # Save the mask as an image file (saving is currently commented out for debugging)
    # plt.imsave('output_image.png', mask, cmap='gray')
def load_and_run_images(args, img_path, name):
# Load the pre-trained weights
GPUdevice = torch.device('cuda', args.gpu_device)
    net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution=args.distributed)
weights = torch.load(args.sam_ckpt)
net.load_state_dict(weights, strict=False)
prompt = args.prompt
prompt_freq = args.prompt_freq
# '''Load masks from folder'''
# data_seg_3d_shape = np.load(mask_path + '/0.npy').shape # shape of one mask (2D)
# num_frame = len(os.listdir(mask_path)) # number of masks
    # data_seg_3d = np.zeros(data_seg_3d_shape + (num_frame,)) # create an empty 3D array for all masks
'''Getting ready for inference'''
newsize = (args.image_size, args.image_size) # resize image shape for inference with the RGB mode
num_frame = len(os.listdir(img_path)) # number of images
    video_length = int(num_frame) # use every frame in the folder
starting_frame = 0 # Starting frame for each video
img_tensor = torch.zeros(video_length, 3, args.image_size, args.image_size) # Create an empty tensor for images
print("shape of empty img_tensor: ", img_tensor.shape)
point_label_dict = {}
pt_dict = {}
for frame_index in range(starting_frame, starting_frame + video_length):
print("Loading frame: ", os.path.join(img_path, f"{str(frame_index)}.jpg"))
img = Image.open(os.path.join(img_path, f"{str(frame_index)}.jpg")).convert("RGB") # Open the image file
img = img.resize(newsize) # Resize the image
cv2.imwrite(f"data/btcv/resized_images/{name}/{str(frame_index)}.jpg", np.array(img)) # Save the image to check if it is correctly resized
img = torch.tensor(np.array(img)).permute(2, 0, 1)
img_tensor[frame_index - starting_frame, :, :, :] = img
if frame_index == 0:
diff_obj_pt_dict = {
                1: torch.Tensor([[265, 660]]), # Example click coordinates for object 1
}
diff_obj_point_label_dict = {
1: torch.Tensor([1]), # Example point label
}
pt_dict[0] = diff_obj_pt_dict
point_label_dict[0] = diff_obj_point_label_dict
elif frame_index == 28:
diff_obj_pt_dict = {
                1: torch.Tensor([[223, 694]]), # Example click coordinates for object 1
}
diff_obj_point_label_dict = {
1: torch.Tensor([1]), # Example point label
}
pt_dict[28] = diff_obj_pt_dict
point_label_dict[28] = diff_obj_point_label_dict
image_meta_dict = {"filename_or_obj": name}
inp_dict = {
'image':img_tensor,
'p_label':point_label_dict,
'pt':pt_dict,
'image_meta_dict':image_meta_dict,
}
# print("inp dict: ", inp_dict)
# run the model
torch.cuda.empty_cache()
img_tensor = inp_dict["image"]
point_labels_dict = inp_dict["p_label"]
pt_dict = inp_dict["pt"]
img_tensor = img_tensor.squeeze(0)
img_tensor = img_tensor.to(dtype=torch.float32, device=GPUdevice)
name = inp_dict["image_meta_dict"]["filename_or_obj"]
frame_id = list(range(img_tensor.size(0)))
inference_state = net.val_init_state(imgs_tensor=img_tensor)
    prompt_frame_id = list(range(0, len(frame_id), prompt_freq))  # prompt every prompt_freq-th frame, whether or not it has a click
    # prompt_frame_id = list(pt_dict.keys())  # alternative: prompt only the annotated frames
obj_list = [1] # Assuming there's only one object for now
with torch.no_grad():
# First, add the prompt frames to the train state
for id in prompt_frame_id:
for ann_obj_id in obj_list:
try:
if prompt == 'click':
print("Adding prompt frame id: ", id, " obj id: ", ann_obj_id)
# if id in pt_dict and ann_obj_id in pt_dict[id]:
points = pt_dict[id][ann_obj_id].to(device=GPUdevice)
labels = point_labels_dict[id][ann_obj_id].to(device=GPUdevice)
# else:
# points = None
# labels = None
_, _, _ = net.add_new_points(
inference_state=inference_state,
frame_idx=id,
obj_id=ann_obj_id,
points=points,
labels=labels,
clear_old_points=False,
)
                except KeyError:
                    # no click stored for this frame; fall back to an all-zero mask prompt
                    _, _, _ = net.add_new_mask(
                        inference_state=inference_state,
                        frame_idx=id,
                        obj_id=ann_obj_id,
                        mask=torch.zeros(img_tensor.shape[2:]).to(device=GPUdevice),
                    )
# Second, propagate the model forward through all frames while updating the prompts per frame
video_segments = {} # video_segments contains the per-frame segmentation results
print("inference state output dict: ", inference_state["output_dict"])
for out_frame_idx, out_obj_ids, out_mask_logits in net.propagate_in_video(inference_state, start_frame_idx=0):
video_segments[out_frame_idx] = {
out_obj_id: out_mask_logits[i]
for i, out_obj_id in enumerate(out_obj_ids)
}
print(f"Processed {name}")
# print(f"Video Segments: {video_segments}")
for id in frame_id:
for ann_obj_id in obj_list:
            pred = video_segments[id][ann_obj_id]
            pred = pred.unsqueeze(0)
            # threshold and cast to uint8 so cv2.imwrite receives a valid image
            cv2.imwrite(os.path.join("output", name, f"{id}.png"), ((pred[0, 0, :, :].cpu().numpy() > 0.5) * 255).astype(np.uint8))
net.reset_state(inference_state) # reset state on each run
# Example usage
img_folder = "img0003"
args = cfg.parse_args() # parse the config (checkpoint path, prompt settings, etc.)
# for mask_filename in sorted(os.listdir(os.path.join("data/btcv/Test/image", mask_folder))):
# read_mask_and_save_as_image(os.path.join("data/btcv/Test/image", mask_folder, mask_filename))
os.makedirs(os.path.join("output", img_folder), exist_ok=True)
shutil.rmtree(os.path.join("data/btcv/resized_images", img_folder), ignore_errors=True) # Remove the directory if it exists
os.makedirs(os.path.join("data/btcv/resized_images", img_folder), exist_ok=True)
load_and_run_images(args, os.path.join("data/btcv/Test/image", img_folder), img_folder)
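For what it's worth, the workaround I am experimenting with is to register prompts only for the frames that actually have clicks and let propagate_in_video cover the rest, instead of pushing all-zero masks through the KeyError path. A sketch using the same variable names as the script above (I have not verified this is the intended usage of the API):

# Sketch: prompt only the annotated frames, then propagate to every frame.
prompt_frame_id = sorted(pt_dict.keys())  # e.g. [0, 28] instead of range(0, len(frame_id), prompt_freq)
video_segments = {}
with torch.no_grad():
    for fid in prompt_frame_id:
        for ann_obj_id in obj_list:
            points = pt_dict[fid][ann_obj_id].to(device=GPUdevice)
            labels = point_labels_dict[fid][ann_obj_id].to(device=GPUdevice)
            net.add_new_points(
                inference_state=inference_state,
                frame_idx=fid,
                obj_id=ann_obj_id,
                points=points,
                labels=labels,
                clear_old_points=False,
            )
    # propagate_in_video should then yield mask logits for every frame, prompted or not
    for out_frame_idx, out_obj_ids, out_mask_logits in net.propagate_in_video(inference_state, start_frame_idx=0):
        video_segments[out_frame_idx] = {obj: out_mask_logits[i] for i, obj in enumerate(out_obj_ids)}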
Hi, I encountered the same problem when trying to save the inference results. I modified func_3d/dataset/btcv.py and now it works:
def __getitem__(self, index):
point_label = 1
newsize = (self.img_size, self.img_size)
"""Get the images"""
name = self.name_list[index]
img_path = os.path.join(self.data_path, self.mode, 'image', name)
mask_path = os.path.join(self.data_path, self.mode, 'mask', name)
# print(f"img_path {img_path}")
# print(f"mask_path {mask_path}")
data_seg_3d_shape = np.load(mask_path + '/0.npy').shape
num_frame = len(os.listdir(mask_path))
print(f"mask_path {mask_path} num_frame {num_frame}")
data_seg_3d = np.zeros(data_seg_3d_shape + (num_frame,))
for i in range(num_frame):
data_seg_3d[..., i] = np.load(os.path.join(mask_path, f'{i}.npy'))
# for i in range(data_seg_3d.shape[-1]):
# if np.sum(data_seg_3d[..., i]) > 0:
# data_seg_3d = data_seg_3d[..., i:]
# break
starting_frame_nonzero = 0
# for j in reversed(range(data_seg_3d.shape[-1])):
# if np.sum(data_seg_3d[..., j]) > 0:
# data_seg_3d = data_seg_3d[..., :j+1]
# break
num_frame = data_seg_3d.shape[-1]
if self.video_length is None:
video_length = int(num_frame)
else:
video_length = self.video_length
if num_frame > video_length and self.mode == 'Training':
starting_frame = np.random.randint(0, num_frame - video_length + 1)
# starting_frame = 0
else:
starting_frame = 0
img_tensor = torch.zeros(video_length, 3, self.img_size, self.img_size)
mask_dict = {}
point_label_dict = {}
pt_dict = {}
bbox_dict = {}
for frame_index in range(starting_frame, starting_frame + video_length):
img = Image.open(os.path.join(img_path, f'{frame_index + starting_frame_nonzero}.jpg')).convert('RGB')
mask = data_seg_3d[..., frame_index]
# mask = np.rot90(mask)
obj_list = np.unique(mask[mask > 0])
diff_obj_mask_dict = {}
if self.prompt == 'bbox':
diff_obj_bbox_dict = {}
elif self.prompt == 'click':
diff_obj_pt_dict = {}
diff_obj_point_label_dict = {}
else:
raise ValueError('Prompt not recognized')
for obj in obj_list:
obj_mask = mask == obj
# if self.transform_msk:
obj_mask = Image.fromarray(obj_mask)
obj_mask = obj_mask.resize(newsize)
obj_mask = torch.tensor(np.array(obj_mask)).unsqueeze(0).int()
# obj_mask = self.transform_msk(obj_mask).int()
diff_obj_mask_dict[obj] = obj_mask
print(f"obj_mask.shape {obj_mask.shape}")
if self.prompt == 'click':
diff_obj_point_label_dict[obj], diff_obj_pt_dict[obj] = random_click(np.array(obj_mask.squeeze(0)), point_label, seed=None)
if self.prompt == 'bbox':
diff_obj_bbox_dict[obj] = generate_bbox(np.array(obj_mask.squeeze(0)), variation=self.variation, seed=self.seed)
# if self.transform:
# state = torch.get_rng_state()
# img = self.transform(img)
# torch.set_rng_state(state)
img = img.resize(newsize)
img = torch.tensor(np.array(img)).permute(2, 0, 1)
img_tensor[frame_index - starting_frame, :, :, :] = img
mask_dict[frame_index - starting_frame] = diff_obj_mask_dict
if self.prompt == 'bbox':
bbox_dict[frame_index - starting_frame] = diff_obj_bbox_dict
elif self.prompt == 'click':
pt_dict[frame_index - starting_frame] = diff_obj_pt_dict
point_label_dict[frame_index - starting_frame] = diff_obj_point_label_dict
image_meta_dict = {'filename_or_obj':name}
if self.prompt == 'bbox':
return {
'image':img_tensor,
'label': mask_dict,
'bbox': bbox_dict,
'image_meta_dict':image_meta_dict,
}
elif self.prompt == 'click':
return {
'image':img_tensor,
'label': mask_dict,
'p_label':point_label_dict,
'pt':pt_dict,
'image_meta_dict':image_meta_dict,
        }
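To double-check that every frame now gets a mask, I run a quick sanity check against the output folder (folder names follow the script earlier in this thread):

import os

img_dir = "data/btcv/Test/image/img0003"
out_dir = "output/img0003"
num_frames = len(os.listdir(img_dir))
saved = {int(os.path.splitext(f)[0]) for f in os.listdir(out_dir) if f.endswith(".png")}
missing = sorted(set(range(num_frames)) - saved)
print("missing frames:", missing)  # expect [] once every frame produces an output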