GraphMixer
GraphMixer copied to clipboard
dst节点的邻居采样可能存在问题
作者您好!感谢您的出色工作~
我发现在为root_nodes进行邻居节点采样时,dst节点的采样可能存在问题,以下是问题的概述与修复方式: 在每次采样时,root_nodes会以batch_size * (2+sample_num)的形式送入采样器,后者会对每个root_node返回其历史邻居。但是,看上去,root_nodes的前batch_size个节点一定是该batch中的src,第batch_size到2*batch_size一定是该batch中的dst。 因此,当采样器遍历前batch_size个root_node(src)时,采样器中的指针会根据时间变化,这导致在采样dst节点时,倒数neighbor个历史邻居很可能不满足采样要求,导致无法采集到任何邻居。
我想询问是否应该在dst采样前执行一次sample.reset()?
示例代码如下,位于construct_graph.py中:
def get_mini_batch(sampler, root_nodes, ts, num_hops, extra_neg_samples): # neg_samples is not used
"""
Call function fetch_subgraph()
Return: Subgraph of each node.
"""
all_graphs = []
train_ptr = len(root_nodes) // (extra_neg_samples + 2)
for i, z in enumerate(zip(root_nodes, ts)):
if i == train_ptr:
sampler.reset()
root_node, root_time = z
all_graphs.append(fetch_subgraph(sampler, root_node, root_time, num_hops))
return all_graphs
期待您的回复~ @CongWeilin