DeepHypergraph

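Setting largest=False in torch.topk makes it return the smallest distances, i.e., the nearest neighbors. When the features already live on the GPU, this can be more efficient than converting the tensor to a NumPy array and querying scipy.spatial.cKDTree, because it avoids copying data between the CPU and the GPU.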
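A minimal illustration of what `largest=False` changes, using a made-up row of distances (the values are illustrative, not from the issue). The tensor stays on the GPU when one is available; only the final `.tolist()` copies the small result back to the host.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Toy distances from one vertex to five vertices; index 0 is the vertex itself.
dist_row = torch.tensor([0.0, 4.0, 1.5, 9.0, 2.5], device=device)

# largest=False: the k smallest distances, i.e. the k nearest vertices (ascending order).
near_vals, near_idx = torch.topk(dist_row, k=3, largest=False)
# largest=True (the default): the k largest distances, i.e. the k farthest vertices.
far_vals, far_idx = torch.topk(dist_row, k=3)

print(near_idx.tolist())  # [0, 2, 4]
print(far_idx.tolist())   # [3, 1, 4]
```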

Open · yuanyz0825 opened this issue 1 year ago · 1 comment

import torch

@staticmethod
def _e_list_from_feature_kNN(features: torch.Tensor, k: int):
    r"""Construct hyperedges from the feature matrix. Each hyperedge in the hypergraph consists of a central vertex and its :math:`k-1` nearest neighbor vertices.

    Args:
        ``features`` (``torch.Tensor``): The feature matrix.
        ``k`` (``int``): The number of nearest neighbors.
    """
    assert features.ndim == 2, "The feature matrix should be 2-D."
    assert (
        k <= features.shape[0]
    ), "The number of nearest neighbors should be less than or equal to the number of vertices."

    # Pairwise Euclidean (p=2) distances between all vertices, shape (N, N),
    # computed on the same device as `features`.
    dist_matrix = torch.cdist(features, features, p=2)
    # largest=False keeps the k smallest distances per row. Each vertex is at
    # distance 0 from itself, so a row holds the center plus its k-1 neighbors.
    _, nbr_indices = torch.topk(dist_matrix, k, largest=False)

    return nbr_indices.tolist()
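The method above is a `@staticmethod` excerpted from a larger class, so this sketch copies its body into a standalone function under a hypothetical name, `e_list_from_feature_knn`, runs it on a made-up 5×1 feature matrix, and cross-checks the result against the NumPy + `scipy.spatial.cKDTree` route mentioned in the issue. It assumes `torch`, `numpy`, and `scipy` are installed and does not measure speed.

```python
import numpy as np
import torch
from scipy.spatial import cKDTree

def e_list_from_feature_knn(features: torch.Tensor, k: int) -> list:
    # Standalone copy of the method above (hypothetical name, no enclosing class).
    assert features.ndim == 2, "The feature matrix should be 2-D."
    assert k <= features.shape[0], "k must not exceed the number of vertices."
    dist_matrix = torch.cdist(features, features, p=2)
    _, nbr_indices = torch.topk(dist_matrix, k, largest=False)
    return nbr_indices.tolist()

device = "cuda" if torch.cuda.is_available() else "cpu"
# Five vertices with 1-D features; the distinct gaps rule out distance ties.
features = torch.tensor([[0.0], [1.0], [3.0], [7.0], [15.0]], device=device)

e_list = e_list_from_feature_knn(features, k=3)
print(e_list)  # [[0, 1, 2], [1, 0, 2], [2, 1, 0], [3, 2, 1], [4, 3, 2]]

# Cross-check against the CPU route: copy to NumPy, then query a cKDTree.
points = features.cpu().numpy()
_, tree_idx = cKDTree(points).query(points, k=3)
print(np.array_equal(np.array(e_list), tree_idx))  # True
```

Each hyperedge starts with its central vertex here because no two vertices share a feature vector; with duplicate rows the tie at distance 0 means the center may not be listed first. The trade-off of this approach is that `torch.cdist` materializes the full N×N distance matrix, so memory grows quadratically with the number of vertices.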

yuanyz0825, Jul 18 '23 14:07

Thanks! I'll update this part soon.

yifanfeng97, Jul 28 '23 06:07