heterograph.set_batch_num_edges could run automatically if batch_num_nodes is set
Since a graph batch (GB) contains no edges between its individual graphs, batch_num_edges can be derived once batch_num_nodes is set for the GB. I propose computing it automatically, either by calling set_batch_num_edges with no arguments or via something like set_batch_num_edges(auto=True).
Motivation
I needed to make subgraphs of a GB while keeping the batch information consistent. I figured out that, after calculating batch_num_nodes for the subgraph, the edges of each individual graph are exactly those whose source and destination fall in the same group of batch_num_nodes.
Example: graph g is a new subgraph of a GB and has batch_num_nodes = [100, 100, 100]. To get the node IDs of each individual graph in g, we take the cumulative sum (CS) of batch_num_nodes: CS = [100, 200, 300]. Nodes with indices < 100 are in the first graph, 100 <= indices < 200 are in the second, and 200 <= indices < 300 are in the third.
With these indices in hand, and the certainty that there are no edges between nodes of different graphs in a batch, one can simply look at the source and destination of each edge to determine which graph in the batch it belongs to. So batch_num_edges comes for free.
I believe this feature is a good QOL improvement as it removes one source of user error when calculating batch_num_edges by hand.
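To make the cumulative-sum argument above concrete, here is a minimal pure-Python sketch of the idea. The helper name count_batch_edges is hypothetical and not part of DGL; it assumes node IDs are ordered graph by graph, as in the example.

```python
from bisect import bisect_right
from itertools import accumulate


def count_batch_edges(batch_num_nodes, src, dst):
    """Count edges per graph from per-graph node counts (hypothetical helper)."""
    # Exclusive upper bound of each graph's node-ID range, e.g. [100, 200, 300].
    bounds = list(accumulate(batch_num_nodes))
    counts = [0] * len(batch_num_nodes)
    for u, v in zip(src, dst):
        # bisect_right gives the number of bounds <= u, i.e. the graph index of u.
        gu = bisect_right(bounds, u)
        gv = bisect_right(bounds, v)
        if gu != gv or gu >= len(counts):
            raise ValueError(f"edge ({u}, {v}) does not fit the batch layout")
        counts[gu] += 1
    return counts


bnn = [100, 100, 100]
src = [0, 99, 100, 250, 299]
dst = [5, 0, 199, 200, 298]
print(count_batch_edges(bnn, src, dst))  # [2, 1, 2]
```

The explicit check on gu and gv is only a guard: in a valid batch it never fires, which is exactly why the edge counts come for free.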
Code that I'm using
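import torch

# sg is a subgraph of a batched graph whose batch_num_nodes has already been set
bnn = sg.batch_num_nodes()
# Last and first node ID of each graph in the batch (inclusive bounds)
e_tail = torch.cumsum(bnn, dim=0) - 1
e_head = torch.cat([torch.zeros(1, dtype=e_tail.dtype, device=e_tail.device), e_tail[:-1] + 1])
source, dest = sg.edges()
# Shape (num_edges, 1) so the comparisons broadcast against the (num_graphs,) bounds
source = source.unsqueeze(1)
dest = dest.unsqueeze(1)
# mask[i, j] is True when both endpoints of edge i lie in graph j
mask = (source >= e_head) & (source <= e_tail) & (dest >= e_head) & (dest <= e_tail)
bne = torch.count_nonzero(mask, dim=0)
sg.set_batch_num_edges(bne)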
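The snippet above builds a (num_edges, num_graphs) mask, which can get large for big batches. A possible alternative, sketched here under the same assumptions, assigns each edge a graph index with torch.bucketize and counts with torch.bincount. The function name infer_batch_num_edges is hypothetical, not a DGL API, and it takes plain tensors so it can be tried without DGL.

```python
import torch


def infer_batch_num_edges(batch_num_nodes, src, dst):
    """Derive per-graph edge counts from per-graph node counts (hypothetical helper)."""
    # Exclusive upper bound of each graph's node-ID range, cast to the edge ID dtype.
    bounds = torch.cumsum(batch_num_nodes, dim=0).to(src.dtype)
    # right=True maps v to index i with bounds[i-1] <= v < bounds[i].
    src_gid = torch.bucketize(src, bounds, right=True)
    dst_gid = torch.bucketize(dst, bounds, right=True)
    num_graphs = bounds.numel()
    # Guard: every edge must stay inside one graph and inside the batch.
    if not torch.equal(src_gid, dst_gid) or bool((src_gid >= num_graphs).any()):
        raise ValueError("edges do not fit the given batch_num_nodes")
    return torch.bincount(src_gid, minlength=num_graphs)


bnn = torch.tensor([100, 100, 100])
src = torch.tensor([0, 99, 100, 250, 299])
dst = torch.tensor([5, 0, 199, 200, 298])
print(infer_batch_num_edges(bnn, src, dst).tolist())  # [2, 1, 2]
```

For a homogeneous subgraph this could presumably be used as sg.set_batch_num_edges(infer_batch_num_edges(sg.batch_num_nodes(), *sg.edges())); a heterograph would need the same computation per edge type.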