pyg_autoscale
SubgraphLoader for heterogeneous graph
Hi,

I am trying to apply pyg_autoscale to a heterogeneous graph and have to modify the `compute_subgraph` method in the `SubgraphLoader` class. I was wondering if you could elaborate on what `offset` and `count` are, and on what `relabel_fn` is doing?

My current understanding is that `compute_subgraph` basically takes the subgraph spanned by `n_id`. Is this understanding accurate?

Many thanks!
```python
def compute_subgraph(self, batches: List[Tuple[int, Tensor]]) -> SubData:
    batch_ids, n_ids = zip(*batches)
    n_id = torch.cat(n_ids, dim=0)
    batch_id = torch.tensor(batch_ids)

    # We collect the in-mini-batch size (`batch_size`), the offset of each
    # partition in the mini-batch (`offset`), and the number of nodes in
    # each partition (`count`):
    batch_size = n_id.numel()
    offset = self.ptr[batch_id]
    count = self.ptr[batch_id.add_(1)].sub_(offset)

    rowptr, col, value = self.data.adj_t.csr()
    rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
                                          self.bipartite)

    adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
                         sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
                         is_sorted=True)

    data = self.data.__class__(adj_t=adj_t)
    for k, v in self.data:
        if isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:
            data[k] = v.index_select(0, n_id)

    return SubData(data, batch_size, n_id, offset, count)
```
Yes, that is correct. Importantly, `batches` denotes a list of contiguous chunks of node indices that we want to group into one single mini-batch/subgraph, for example `[[0, 1, 2], [5, 6, 7], [10, 11, 12, 13]]`, for which `offset` would be `[0, 5, 10]` and `count` would be `[3, 3, 4]`.
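To make `offset` and `count` concrete, here is a minimal sketch of the same indexing arithmetic, assuming a hypothetical partition pointer `ptr` in which the three chunks above happen to be partitions 0, 2 and 4 of the full graph:

```python
import torch

# Hypothetical partition pointer: 5 contiguous partitions over 14 nodes,
# i.e. partition i spans nodes ptr[i]..ptr[i+1]-1.
ptr = torch.tensor([0, 3, 5, 8, 10, 14])

# Mini-batch made of partitions 0, 2 and 4, i.e. the chunks
# [0, 1, 2], [5, 6, 7] and [10, 11, 12, 13]:
batch_id = torch.tensor([0, 2, 4])

offset = ptr[batch_id]              # start index of each partition
count = ptr[batch_id + 1] - offset  # number of nodes in each partition

print(offset.tolist())  # [0, 5, 10]
print(count.tolist())   # [3, 3, 4]
```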
`relabel_fn` then computes the induced subgraph of these chunks of nodes and relabels their node indices to `[0, ..., 9]`.
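To illustrate the relabeling, here is a simplified pure-PyTorch sketch of induced-subgraph extraction on a CSR graph. Note this is a hypothetical illustration only, not the real `relabel_fn` (which also handles the bipartite case, where out-of-mini-batch neighbors are appended to `n_id` instead of dropped):

```python
import torch

def relabel_sketch(rowptr, col, n_id):
    # Simplified illustration: build the induced subgraph of `n_id` and
    # relabel its nodes to [0, ..., n_id.numel() - 1].
    num_nodes = rowptr.numel() - 1

    # Map global node ids to local positions; -1 marks "not in mini-batch".
    node_map = torch.full((num_nodes,), -1, dtype=torch.long)
    node_map[n_id] = torch.arange(n_id.numel())

    new_rowptr, new_col = [0], []
    for global_id in n_id.tolist():
        # Neighbors of this node in the full CSR graph.
        nbrs = col[rowptr[global_id]:rowptr[global_id + 1]]
        kept = node_map[nbrs]
        kept = kept[kept >= 0]  # drop neighbors outside the mini-batch
        new_col.append(kept)
        new_rowptr.append(new_rowptr[-1] + kept.numel())
    return torch.tensor(new_rowptr), torch.cat(new_col)

# Toy graph with 4 nodes and edges 0->1, 0->2, 1->0, 2->3, 3->2:
rowptr = torch.tensor([0, 2, 3, 4, 5])
col = torch.tensor([1, 2, 0, 3, 2])

# Induced subgraph of nodes {0, 1}: only the edges 0->1 and 1->0 survive.
new_rowptr, new_col = relabel_sketch(rowptr, col, torch.tensor([0, 1]))
```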