Pointnet_Pointnet2_pytorch
Wrong neighbor selection in query_ball_point
https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/eb64fe0b4c24055559cea26299cb485dcb43d8dd/models/pointnet2_utils.py#L87
```python
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius ** 2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
```
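The elements of group_idx are point indices, not distances, so group_idx.sort orders the in-radius points by index rather than by distance. When more than nsample points fall inside the radius, the code keeps the lowest-indexed ones, not the nearest ones.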
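To make the difference concrete, here is a minimal pure-Python sketch for a single query point (no PyTorch). The function names `by_index` and `by_distance` and the sample distances are hypothetical. `by_index` mimics what the quoted code keeps for one row (in-radius indices in ascending index order), while `by_distance` keeps the in-radius indices ordered by squared distance. The padding step is omitted here.

```python
def by_index(sqrdists, radius, nsample):
    """Mimic the quoted code for one row: in-radius indices sorted by index."""
    n = len(sqrdists)
    # Out-of-radius points get the sentinel index n, so they sort to the end.
    idx = [i if d <= radius ** 2 else n for i, d in enumerate(sqrdists)]
    return sorted(idx)[:nsample]


def by_distance(sqrdists, radius, nsample):
    """Keep in-radius indices ordered by squared distance (nearest first)."""
    n = len(sqrdists)
    order = sorted(range(n), key=lambda i: sqrdists[i])
    # Same sentinel rule, applied after ordering by distance.
    idx = [i if sqrdists[i] <= radius ** 2 else n for i in order]
    return idx[:nsample]


# Hypothetical squared distances from one query point to N = 5 points.
sqrdists = [0.9, 0.5, 0.1, 0.0, 2.0]
print(by_index(sqrdists, radius=1.0, nsample=2))     # [0, 1]
print(by_distance(sqrdists, radius=1.0, nsample=2))  # [3, 2]
```

Both rules return only points inside the radius; they disagree only when more than nsample points fall inside it.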
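The code should be:
```python
# Sort each row by squared distance; group_idx holds the matching point indices.
sort_dis, group_idx = sqrdists.sort(dim=-1)
# Mark points outside the radius with the sentinel index N.
group_idx[sort_dis > radius ** 2] = N
# Keep the nsample nearest candidates.
group_idx = group_idx[:, :, :nsample]
# Fill sentinel slots with the index of the nearest point.
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
group_idx[mask] = group_first[mask]
```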
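As a slow reference for checking a vectorized version, here is a pure-Python sketch that follows the proposed fix step by step on nested lists shaped [B][S][N] of squared distances. The function name `group_nearest` and the sample data are hypothetical, and no PyTorch is needed; each comment names the step of the fix it mirrors.

```python
def group_nearest(sqrdists, radius, nsample):
    """Mirror the proposed fix on nested lists of shape [B][S][N].

    Returns indices of shape [B][S][nsample]: the nearest in-radius points,
    with empty slots filled by the index of the nearest point.
    """
    out = []
    for batch in sqrdists:
        rows = []
        for row in batch:
            n = len(row)
            # Sort indices by squared distance (the sqrdists.sort step).
            order = sorted(range(n), key=lambda i: row[i])
            # Points outside the radius become the sentinel n.
            idx = [i if row[i] <= radius ** 2 else n for i in order]
            # Keep the first nsample entries.
            idx = idx[:nsample]
            # Replace sentinels with the first (nearest) index. As in the
            # PyTorch version, if no point is in the radius, slots stay n.
            first = idx[0]
            rows.append([first if i == n else i for i in idx])
        out.append(rows)
    return out


d = [[[0.9, 0.5, 0.1, 0.0, 2.0], [0.0, 3.0, 4.0, 0.2, 5.0]]]
print(group_nearest(d, radius=1.0, nsample=3))  # [[[3, 2, 1], [0, 3, 0]]]
```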
You can also use PyTorch3D's implementation of ball query instead, as described in https://github.com/yanx27/Pointnet_Pointnet2_pytorch/issues/178#issuecomment-1587086798
I think you are right: when more than nsample points fall inside the epsilon ball, the original code does not select the k nearest neighbors (kNN) among them.