Add graph batching support to `dgl.knn_graph()`
🚀 Feature
Adding graph batching support to dgl.knn_graph()
Motivation
Being able to provide a batch of coordinates/features to knn_graph() would allow users to develop much more efficient graph construction pipelines in their projects.
Hi, if your goal is to speed up the data pipeline that prepares batched graphs from coordinates/features, then GraphDataLoader is designed for that. Its design follows PyTorch's DataLoader. You can define your own Dataset whose __getitem__ function returns one graph constructed from the coordinates. The dataloader will then batch the graphs automatically, and the pipeline can be further accelerated with multiprocessing (using the num_workers argument).
Hi, @jermainewang. Thanks for your quick response.
This sounds like a good solution for use cases where we only construct our input DGLGraphs once, before the forward pass of our model. However, in the use case I have in mind, I actually need to reconstruct my DGLGraph batches multiple times during a single forward pass of my model. That's why I am looking for a convenient way of constructing knn_graphs from a batch of node coordinates (e.g., where the node coordinates are a Tensor of size [B, N, 3]).
Do you know of any ways to accomplish batched construction of knn_graphs in DGL without the use of a GraphDataLoader?
Is N very large in your case? If not, perhaps you could compute a pairwise distance tensor (of shape [B, N, N]) using tensor operators and then use topk to find the indices of the closest neighbors of each node in the batch. You can then use the indices to construct a batched graph.
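A minimal PyTorch-only sketch of the pairwise-distance + topk idea, assuming a padded float tensor of shape [B, N, D]. The function name batched_knn_edges is hypothetical. It returns global edge endpoints, where node j of point set b gets id b * N + j, so no edge crosses point sets.

```python
import torch


def batched_knn_edges(coords: torch.Tensor, k: int):
    """Hypothetical helper: kNN edges for a padded batch of point sets.

    coords: float tensor of shape [B, N, D].
    Returns (src, dst), 1-D int64 tensors of length B * N * k. Each edge
    points from a neighbor (src) to the node whose neighborhood it is (dst).
    """
    B, N, _ = coords.shape
    # Pairwise Euclidean distances within each point set: [B, N, N].
    dist = torch.cdist(coords, coords)
    # For each node, the local indices of its k closest points: [B, N, k].
    # The node itself is normally among them, since its self-distance is 0.
    knn_idx = dist.topk(k, dim=-1, largest=False).indices
    # Shift local indices by b * N so node ids are unique across the batch.
    offsets = (torch.arange(B, device=coords.device) * N).view(B, 1, 1)
    src = (knn_idx + offsets).reshape(-1)
    dst = (torch.arange(N, device=coords.device).view(1, N, 1) + offsets)
    dst = dst.expand(B, N, k).reshape(-1)
    return src, dst
```

With DGL, src and dst can then be passed to dgl.graph((src, dst), num_nodes=B * N). Since the edges come out grouped by point set, with N * k edges per set, the batch information can be attached with the workaround described at the end of this thread. Note that the distance tensor costs O(B·N²) memory, which is why this approach suits moderate N.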
cc the authors of knn_graph @hetong007 @lygztq to see if there are other good ideas.
Hi, if your input is a 3D tensor (e.g., [B, N, 3]), I think you can get a batched graph with knn_graph. You can find an example at the end of the doc here.
If your input is not padded, say you have a set of inputs ([N1, 3], [N2, 3], ...), you can use dgl.segmented_knn_graph instead.
Hi, @jermainewang and @lygztq. Thank you very much for your ideas here. To the best of my knowledge, the knn_graph() and segmented_knn_graph() functions both return combined graphs without batch metadata (i.e., without batch_num_nodes and batch_num_edges). This would be problematic for the DGL use case I have in mind. Is there a streamlined way in DGL to batch the individual subgraphs quickly and efficiently?
@amorehead To work around the issue, you can use set_batch_num_nodes and set_batch_num_edges to override that information manually. There is an ongoing PR to fix this issue.