DeepHypergraph
DeepHypergraph copied to clipboard
torch.topk 函数设置 largest=False 是为了找到最小的距离,即最近的邻居。这种方法可能比转换张量为 numpy 数组并使用 scipy.spatial.cKDTree 更有效,因为它避免了 CPU 和 GPU 之间的数据传输。
@staticmethod
def _e_list_from_feature_kNN(features: torch.Tensor, k: int):
r"""Construct hyperedges from the feature matrix. Each hyperedge in the hypergraph is constructed by the central vertex and its :math:k-1
neighbor vertices.
Args:
``features`` (``torch.Tensor``): The feature matrix.
``k`` (``int``): The number of nearest neighbors.
"""
assert features.ndim == 2, "The feature matrix should be 2-D."
assert (
k <= features.shape[0]
), "The number of nearest neighbors should be less than or equal to the number of vertices."
dist_matrix = torch.cdist(features, features, p=2)
_, nbr_indices = torch.topk(dist_matrix, k, largest=False)
return nbr_indices.tolist()
感谢,我这最近更新一下这一块!