farthest-point-sampling
farthest-point-sampling copied to clipboard
suggest a new v2 version code for faster sampling
v0 and v1 are slow when sampling many key points: for example, sampling 4800 points takes hours. I therefore wrote a new version for faster sampling. The code is below:
import numpy as np
class FPS:
    """Farthest-point sampling (FPS) over a point cloud, implemented with NumPy.

    Sampling starts from a randomly chosen point. Every step then picks the
    point whose distance to the set of already-selected points is largest.
    Only the per-point minimum distance to the selected set is kept between
    steps, so one step is a single vectorised pass over all points.

    Args:
        pcd_xyz: Array-like of shape (n_pts, dim) with the point coordinates.
        n_samples: Number of points to select; must be >= 1.
        verbose: When True (the default, matching the historical behaviour)
            progress messages are printed from ``step()`` and ``fit()``.
            Pass False to silence them.

    Raises:
        ValueError: If ``pcd_xyz`` is not a non-empty 2-D array or
            ``n_samples`` is smaller than 1.
    """

    def __init__(self, pcd_xyz, n_samples, verbose=True):
        pcd_xyz = np.asarray(pcd_xyz)
        if pcd_xyz.ndim != 2 or pcd_xyz.shape[0] == 0:
            raise ValueError(
                "pcd_xyz must be a non-empty array of shape (n_pts, dim)"
            )
        if n_samples < 1:
            raise ValueError("n_samples should be >= 1")

        self.verbose = verbose
        self.n_samples = n_samples
        self.pcd_xyz = pcd_xyz
        self.n_pts = pcd_xyz.shape[0]
        self.dim = pcd_xyz.shape[1]
        self.selected_pts = None
        # Kept as (n_samples, 1, dim) so a slice broadcasts against (n_pts, dim).
        self.selected_pts_expanded = np.zeros(shape=(n_samples, 1, self.dim))
        self.remaining_pts = np.copy(pcd_xyz)
        self.grouping_radius = None
        # (n_pts, 1) distances from every point to the most recently used
        # selected point; overwritten on every step().
        self.dist_pts_to_selected = None
        self.labels = None

        # Randomly pick a start. ``high`` is exclusive in np.random.randint,
        # so it must be n_pts (not n_pts - 1) for every point to be eligible
        # and for a single-point cloud to work at all.
        self.start_idx = np.random.randint(low=0, high=self.n_pts)
        self.selected_pts_expanded[0] = self.remaining_pts[self.start_idx]
        self.n_selected_pts = 1
        # (n_pts, 1) running minimum distance of every point to the selected set.
        self.dist_pts_to_selected_min = None
        self.res_selected_idx = None

    def get_selected_pts(self):
        """Return the selected points as an array of shape (n_samples, dim).

        Rows that have not been filled by ``step()`` yet are still zero.
        """
        self.selected_pts = np.squeeze(self.selected_pts_expanded, axis=1)
        return self.selected_pts

    def step(self):
        """Select one more point, or do nothing once n_samples are selected."""
        if self.verbose:
            print(self.n_selected_pts)

        # Guard first so that a direct call can never write past the buffer.
        if self.n_selected_pts >= self.n_samples:
            if self.verbose:
                print("Got enough number samples")
            return

        # Distances from all points to the point chosen most recently.
        last_selected = self.selected_pts_expanded[
            self.n_selected_pts - 1:self.n_selected_pts
        ]
        self.dist_pts_to_selected = self.__distance__(
            self.remaining_pts, last_selected
        ).T  # (n_pts, 1)

        if self.dist_pts_to_selected_min is None:
            self.dist_pts_to_selected_min = self.dist_pts_to_selected.copy()
        else:
            # Vectorised replacement for the former per-point Python loop.
            np.minimum(
                self.dist_pts_to_selected_min,
                self.dist_pts_to_selected,
                out=self.dist_pts_to_selected_min,
            )

        # The farthest point from the selected set becomes the next sample.
        self.res_selected_idx = np.argmax(self.dist_pts_to_selected_min)
        self.selected_pts_expanded[self.n_selected_pts] = self.remaining_pts[
            self.res_selected_idx
        ]
        self.n_selected_pts += 1

    def fit(self):
        """Run the remaining steps and return the selected points.

        Returns:
            Array of shape (n_samples, dim) with the sampled coordinates.
        """
        for i in range(1, self.n_samples):
            self.step()
            if self.verbose:
                print("sampleing no.", i, " point")
        return self.get_selected_pts()

    def group(self, radius):
        """Label every point with the index of a selected point.

        The distance matrix between all points and the points selected so far
        is recomputed here, because ``step()`` only keeps distances to a
        single selected point. This needs O(n_pts * n_selected) memory.

        Args:
            radius: Grouping radius. Distances above it are penalised before
                the arg-min, so an in-radius selected point is preferred.

        Returns:
            Integer array of shape (n_pts,) indexing into the selected points.
        """
        self.grouping_radius = radius  # the grouping radius is not actually used
        selected = self.selected_pts_expanded[:self.n_selected_pts]
        dists = self.__distance__(self.remaining_pts, selected).T  # (n_pts, n_selected)
        # Ignore the "points"-"selected" relations if it's larger than the radius
        dists = np.where(dists > radius, dists + 1000000 * radius, dists)
        # Find the relation with the smallest distance.
        # NOTE: the smallest distance may still be larger than the radius.
        self.labels = np.argmin(dists, axis=1)
        return self.labels

    @staticmethod
    def __distance__(a, b):
        """Euclidean distance along the last axis of the broadcast ``a - b``."""
        return np.linalg.norm(a - b, ord=2, axis=2)
Hi @linharrrrrt,
Sorry for the delay and many thanks for the helpful input! Yes, the FPS implementation in this repo was written not for efficiency but rather to help understand and visualise FPS. As I'm currently working on other projects on a tight schedule, I'll probably merge it later! Also, it would be much easier for me to identify the differences between v1 and v2 if you could open a pull request.
Best, Zirui
Thanks for the work of @linharrrrrt. I've optimized the code using the numba JIT compiler. The overall performance is 30x~60x faster than the pure Python version.
The group method is removed since self.dist_pts_to_selected no longer stores the entire distance matrix. The minimum-distance update strategy is far faster than the approach used in v0 and v1.
Here is the code for anyone that requires high performance in CPU. It takes me only 2.68s ± 1.77ms when sampling 5,000 points out of 50,000. FYI, it should take about 152s in the original v2 code mentioned above.
import numpy as np
from numba import float32, int32
from numba.experimental import jitclass
# Attribute name/type pairs passed to @jitclass for the FPS class below.
spec = [
    ("n_samples", int32),  # number of points to select
    ("selected_pts_expanded", float32[:, :, :]),  # allocated as (n_samples, 1, dim)
    ("selected_pts_idx", int32[:]),  # allocated as (n_samples,); indices into pcd_xyz
    ("pcd_xyz", float32[:, :]),  # float32 copy of the input points
    ("n_selected_pts", int32),  # how many samples have been chosen so far
    ("dist_pts_to_selected_min", float32[:]),  # allocated as (n_pts,); running minimum squared distance
    ("res_selected_idx", int32),  # index into pcd_xyz chosen by the latest step
]
@jitclass(spec)
class FPS:
    """Farthest-point sampling compiled with numba's jitclass.

    Starting from a random point, each step picks the point whose squared
    distance to the already-selected set is largest. Squared distances are
    used because they give the same arg-max as Euclidean distances.
    """

    def __init__(self, pcd_xyz, n_samples):
        """Store a float32 copy of the points and pick a random start.

        Args:
            pcd_xyz: 2-D array of shape (n_pts, dim).
            n_samples: Number of points to select; must be >= 1.
        """
        assert n_samples >= 1, "n_samples should be >= 1"
        self.n_samples = n_samples
        n_pts, dim = pcd_xyz.shape
        self.pcd_xyz = pcd_xyz.astype(np.float32)
        # Random pick a start
        start_idx = np.random.randint(low=0, high=n_pts)
        self.n_selected_pts = 1
        # Start at +inf so the first step() can use the same running-minimum
        # update as every later step.
        self.dist_pts_to_selected_min = np.full((n_pts,), np.inf, dtype=np.float32)
        # Index of the most recently selected point; step() measures
        # distances to it.
        self.res_selected_idx = start_idx
        self.selected_pts_expanded = np.empty((n_samples, 1, dim), dtype=np.float32)
        self.selected_pts_expanded[0] = self.pcd_xyz[start_idx]
        self.selected_pts_idx = np.empty((n_samples,), dtype=np.int32)
        self.selected_pts_idx[0] = start_idx

    def step(self):
        """Select one more point; a no-op once n_samples points are chosen."""
        # Guard first: the writes below index with n_selected_pts, so a
        # direct call after sampling is complete must not reach them.
        if self.n_selected_pts >= self.n_samples:
            return

        dist_to_last = self.distance(
            self.pcd_xyz, self.pcd_xyz[self.res_selected_idx][None, None]
        ).T  # (n_pts, 1)
        self.dist_pts_to_selected_min = np.minimum(
            self.dist_pts_to_selected_min, dist_to_last[:, 0]
        )
        self.res_selected_idx = np.argmax(self.dist_pts_to_selected_min)
        self.selected_pts_expanded[self.n_selected_pts] = self.pcd_xyz[self.res_selected_idx]
        self.selected_pts_idx[self.n_selected_pts] = self.res_selected_idx
        self.n_selected_pts += 1

    def fit(self):
        """
        Run the remaining sampling steps.

        Returns:
            selected_pts_idx: (n_samples,), 1d int array of the indices of selected points,
        """
        assert (
            self.n_samples <= self.pcd_xyz.shape[0]
        ), "n_samples should be less than or equal to the number of points"
        for _ in range(1, self.n_samples):
            self.step()
        return self.selected_pts_idx

    def distance(self, a, b):
        """Squared Euclidean distance along axis 2 of the broadcast ``a - b``."""
        return np.sum((a - b) ** 2, axis=2)
@leonardodalinky Thanks a lot, I really should merge this at some point...😅