Pointnet_Pointnet2_pytorch
Semantic Segmentation Inference Takes Too Long
I decided to use semantic segmentation in a SLAM project. I wrote a segmentation script that takes XYZ + RGB input and runs the pretrained models provided here. Unfortunately I didn't understand how to use the provided test code, so I had to improvise a bit. The problem I'm having now is that the segmentation is extremely slow: I am running on a CUDA GPU, and even if I downsample with a voxel size of 0.5, one segmentation pass takes at least 50 seconds. Does anyone have any suggestions? If you have code that runs inference on XYZ + RGB data, I would be very grateful if you could share it. I call this script from a main script that I can't share, but that part is not the problem.
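For context, the downsampling I mention is just Open3D's voxel filter, roughly like this (a minimal sketch with placeholder random data; the xyz/rgb names and sizes are assumptions):

```python
import numpy as np
import open3d as o3d

# Placeholder frame; in my project these come from the SLAM front end
xyz = np.random.rand(100000, 3) * 10.0
rgb = np.random.randint(0, 256, (100000, 3)).astype(np.float64)

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(xyz)
pcd.colors = o3d.utility.Vector3dVector(rgb / 255.0)  # Open3D expects colors in [0, 1]
down = pcd.voxel_down_sample(voxel_size=0.5)          # the 0.5 voxel size mentioned above
xyz_ds = np.asarray(down.points)
rgb_ds = np.asarray(down.colors) * 255.0
```

The segmentation script itself is below: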
import torch
import numpy as np
import importlib
import os
import sys
import open3d as o3d
import argparse
import logging
from datetime import datetime
import cv2
# Add the directory containing your model to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
NUM_CLASSES = 13
BATCH_SIZE = 32
NUM_POINT = 4096 # Adjusted to match the original code
# Define color map
COLOR_MAP = {
    'ceiling': [0, 255, 0],
    'floor': [0, 0, 255],
    'wall': [0, 255, 255],
    'beam': [255, 255, 0],
    'column': [255, 0, 255],
    'window': [100, 100, 255],
    'door': [200, 200, 100],
    'table': [170, 120, 200],
    'chair': [255, 0, 0],
    'sofa': [200, 100, 100],
    'bookcase': [10, 200, 100],
    'board': [200, 200, 200],
    'clutter': [50, 50, 50]
}
LABEL_TO_NAMES = {i: name for i, name in enumerate(COLOR_MAP.keys())}
# Set up CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
def split_point_cloud_into_blocks(xyz, rgb, num_point, block_size=1.0, stride=0.5):
    """
    Split the point cloud into overlapping blocks on the XY plane.
    """
    # Determine the min and max coordinates
    coord_min = torch.min(xyz, dim=0)[0]
    coord_max = torch.max(xyz, dim=0)[0]
    # Calculate the number of blocks in each dimension
    grid_x = int(torch.ceil((coord_max[0] - coord_min[0] - block_size) / stride)) + 1
    grid_y = int(torch.ceil((coord_max[1] - coord_min[1] - block_size) / stride)) + 1
    blocks = []
    for idx_y in range(grid_y):
        for idx_x in range(grid_x):
            s_x = coord_min[0] + idx_x * stride
            s_y = coord_min[1] + idx_y * stride
            e_x = s_x + block_size
            e_y = s_y + block_size
            # Find points within this block
            block_mask = (xyz[:, 0] >= s_x) & (xyz[:, 0] <= e_x) & \
                         (xyz[:, 1] >= s_y) & (xyz[:, 1] <= e_y)
            block_point_indices = torch.where(block_mask)[0]
            if block_point_indices.size(0) == 0:
                continue
            block_xyz = xyz[block_point_indices]
            block_rgb = rgb[block_point_indices]
            # Center XY coordinates on the block center (Z is left untouched)
            block_xyz_centered = block_xyz - torch.tensor(
                [s_x + block_size / 2.0, s_y + block_size / 2.0, 0], device=device)
            # Normalize XYZ coordinates by the scene extent
            block_xyz_normalized = block_xyz / coord_max
            # Normalize RGB to [0, 1]
            block_rgb_normalized = block_rgb / 255.0
            # Combine into the 9-dim per-point feature expected by the model
            block_points = torch.cat((block_xyz_centered, block_rgb_normalized, block_xyz_normalized), dim=1)
            # Pad or sample to exactly num_point points
            if block_points.size(0) >= num_point:
                # Randomly sample num_point points (index on the same device to avoid CPU->GPU transfers)
                idx = torch.randperm(block_points.size(0), device=device)[:num_point]
                block_points = block_points[idx]
                block_point_indices = block_point_indices[idx]
            else:
                # Pad with duplicated points
                idx = torch.randint(block_points.size(0), (num_point - block_points.size(0),), device=device)
                block_points = torch.cat((block_points, block_points[idx]), dim=0)
                block_point_indices = torch.cat((block_point_indices, block_point_indices[idx]), dim=0)
            blocks.append((block_points, block_point_indices))
    return blocks
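To get a feel for how many forward passes this produces: with block_size=1.0 and stride=0.5, the grid formulas above give roughly two blocks per metre in each direction. A quick back-of-the-envelope count (the 20 m scene extent is just an assumed example, not my actual map size):

```python
import math

block_size, stride, num_votes = 1.0, 0.5, 3
extent_x = extent_y = 20.0  # assumed scene extent in metres
# Same grid formulas as in split_point_cloud_into_blocks above
grid_x = math.ceil((extent_x - block_size) / stride) + 1
grid_y = math.ceil((extent_y - block_size) / stride) + 1
print(grid_x * grid_y)               # up to 1521 non-empty blocks per vote
print(grid_x * grid_y * num_votes)   # up to 4563 forward passes, each at batch size 1
```

which would go a long way toward explaining the 50 seconds.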
def segment_point_cloud_with_voting(classifier, xyz, rgb, num_point, num_votes=3, block_size=1.0, stride=0.5):
    """
    Perform segmentation with voting over jittered copies of the cloud.
    """
    num_classes = NUM_CLASSES
    num_points = xyz.size(0)
    vote_label_pool = torch.zeros((num_points, num_classes), dtype=torch.float32, device=device)
    for vote in range(num_votes):
        # Jitter the points slightly so each vote sees a different block assignment
        xyz_vote = xyz + torch.rand_like(xyz) * 0.02 - 0.01
        # Split the point cloud into blocks
        blocks = split_point_cloud_into_blocks(xyz_vote, rgb, num_point, block_size, stride)
        for block_points, block_point_indices in blocks:
            # Shape [1, 9, num_point]: each block is a batch of size 1
            batch_points_tensor = block_points.unsqueeze(0).transpose(2, 1)
            # Perform segmentation
            with torch.no_grad():
                seg_pred, _ = classifier(batch_points_tensor)
            batch_pred_label = seg_pred.argmax(dim=2)
            # Accumulate one-hot votes per point
            one_hot = torch.nn.functional.one_hot(batch_pred_label[0], num_classes=num_classes).float()
            vote_label_pool.index_add_(0, block_point_indices, one_hot)
    # Get final predicted labels
    final_pred_labels = vote_label_pool.argmax(dim=1)
    return final_pred_labels
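One thing I notice while cleaning this up: BATCH_SIZE = 32 is defined at the top but never used, so every block above goes through the network alone. Since all blocks are padded or sampled to exactly num_point points, they could be stacked and run batch_size at a time. A sketch of that variant, reusing the definitions above and assuming the classifier accepts a [B, 9, num_point] tensor as in the loop:

```python
def segment_blocks_batched(classifier, blocks, vote_label_pool, batch_size=BATCH_SIZE):
    """Run blocks through the classifier batch_size at a time instead of one by one."""
    for i in range(0, len(blocks), batch_size):
        chunk = blocks[i:i + batch_size]
        # Stack [num_point, 9] blocks into a [B, 9, num_point] batch
        batch = torch.stack([bp for bp, _ in chunk]).transpose(2, 1)
        with torch.no_grad():
            seg_pred, _ = classifier(batch)      # [B, num_point, num_classes]
        pred_label = seg_pred.argmax(dim=2)      # [B, num_point]
        for j, (_, indices) in enumerate(chunk):
            one_hot = torch.nn.functional.one_hot(pred_label[j], NUM_CLASSES).float()
            vote_label_pool.index_add_(0, indices, one_hot)
    return vote_label_pool
```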
def segment_frame(classifier, xyz, rgb, num_point, num_votes=3, block_size=1.0, stride=0.5):
    """
    Perform segmentation on a single frame of point cloud data.
    Args:
        classifier: The trained classifier model
        xyz: Nx3 array of point coordinates
        rgb: Nx3 array of point colors (0-255)
        num_point: Number of points per sample
        num_votes: Number of votes for segmentation
        block_size: Block size for splitting the point cloud
        stride: Stride for splitting the point cloud
    Returns:
        segmented_points: Nx6 array of segmented points (XYZ + RGB)
        labels: N array of predicted labels
    """
    # Convert inputs to PyTorch tensors on the appropriate device
    xyz = torch.tensor(xyz, dtype=torch.float32, device=device)
    rgb = torch.tensor(rgb, dtype=torch.float32, device=device)
    # Perform segmentation
    labels = segment_point_cloud_with_voting(classifier, xyz, rgb, num_point, num_votes, block_size, stride)
    # Look up class colors via a palette tensor instead of a per-point Python loop;
    # float32 so it can be concatenated with xyz (torch.cat requires matching dtypes)
    palette = torch.tensor(list(COLOR_MAP.values()), dtype=torch.float32, device=device)
    segmented_colors = palette[labels]
    # Combine XYZ coordinates with segmented colors
    segmented_points = torch.cat((xyz, segmented_colors), dim=1)
    return segmented_points.cpu().numpy(), labels.cpu().numpy()
# Load the classifier (this can be done once and the classifier can be passed to segment_frame)
def load_classifier(model_path, model_name):
    MODEL = importlib.import_module(model_name)
    classifier = MODEL.get_model(NUM_CLASSES)
    state_dict = torch.load(model_path, map_location=device)['model_state_dict']
    # Strip the 'module.' prefix left over from DataParallel checkpoints
    new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    classifier.load_state_dict(new_state_dict)
    classifier = classifier.to(device)
    classifier.eval()
    return classifier  # JIT compilation removed
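Finally, this is roughly how I drive the two functions (a sketch: the checkpoint path and model name follow this repo's usual layout but are assumptions about your setup, and the random frame stands in for my SLAM data):

```python
if __name__ == '__main__':
    # Assumed names following this repo's layout; model_name must be importable
    # given the sys.path line at the top of the script
    classifier = load_classifier(
        'log/sem_seg/pointnet2_sem_seg/checkpoints/best_model.pth',
        'pointnet2_sem_seg')
    # Placeholder frame; real xyz/rgb come from the SLAM front end
    xyz = np.random.rand(100000, 3).astype(np.float32) * 10.0
    rgb = np.random.randint(0, 256, (100000, 3)).astype(np.float32)
    segmented_points, labels = segment_frame(classifier, xyz, rgb, NUM_POINT)
    print(segmented_points.shape, labels.shape)  # (N, 6) and (N,)
```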