[Enhancement] dgl.batch is slow due to FFI
According to our tests, dgl.batch is a bottleneck in graph classification. One main reason is that it frequently calls number_of_edges and number_of_nodes, which in turn call the underlying C++ functions via FFI. Unfortunately, the FFI overhead accounts for 97% of the time spent in these two functions. To tackle the problem, we came up with two options:
- Replace the FFI with another mechanism.
- Duplicate the attributes on the Python-level object.
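To illustrate option 2, here is a minimal, self-contained sketch of the caching idea. The class names and the simulated "FFI" counter are hypothetical, not DGL internals; the point is only that the expensive boundary is crossed once per graph instead of once per lookup:

```python
class FFIBackedGraph:
    """Stand-in for a graph whose node count lives behind a (slow) FFI call."""

    def __init__(self, n_nodes):
        self._n_nodes = n_nodes
        self.ffi_calls = 0  # track how often we cross the simulated FFI boundary

    def _ffi_number_of_nodes(self):
        # In DGL this would be the C++ call reached via FFI.
        self.ffi_calls += 1
        return self._n_nodes


class CachedGraph(FFIBackedGraph):
    """Option 2: duplicate the count as a plain Python attribute after the first call."""

    def __init__(self, n_nodes):
        super().__init__(n_nodes)
        self._cached_nodes = None

    def number_of_nodes(self):
        if self._cached_nodes is None:
            self._cached_nodes = self._ffi_number_of_nodes()
        return self._cached_nodes
```

With this wrapper, a batching loop that queries number_of_nodes hundreds of times per graph pays the FFI cost only once; the trade-off is that the cache must be invalidated whenever the graph is mutated.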
For option 1, the longer-term path would be to first replace the internals of DGLGraph with the new SparseMatrix object and then redirect dgl.batch to the corresponding sparse operations. This would make all APIs go through PyTorch's native FFI.
We have observed the same bottleneck in batched graph classification with dgl 2.1.0+cu118, especially when running multiple models on the same machine (each model on its own GPU). The CPU overhead is so large that the GPU has to wait for the DataLoader, which accounts for over 30% of the total time.
Profiling stats:

```
Line #  Hits   Time          Per Hit   % Time  Line Contents
==============================================================
    23                                         @profile
    24                                         def collate(samples):
    25                                             # input samples
    26  2462   110399208.0   44841.3      0.0      graphs, graph_label, node_label_list, sgs = map(list, zip(*samples))
    27  2462   1e+11         5e+07       49.8      batched_graph = dgl.batch(graphs)
    28  2462   1e+11         5e+07       49.6      batched_sg = dgl.batch(sgs)
    29                                             # return the batched data
    30  2462   1116662987.0  453559.3     0.5      return batched_graph, torch.tensor(graph_label), torch.cat(node_label_list), batched_sg
```
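The profile shows roughly half the collate time in each dgl.batch call. Part of that cost is computing per-graph node/edge counts and turning them into batched node-ID offsets; with the counts already cached as Python ints (option 2 above), the offsets reduce to a single prefix sum with no per-graph FFI call. A sketch, where `num_nodes_list` is a hypothetical precomputed list of counts:

```python
from itertools import accumulate


def batch_offsets(num_nodes_list):
    """Return the starting node ID of each graph inside the batched graph.

    Pure-Python prefix sum: graph i's nodes are renumbered starting at
    offsets[i], so no per-graph FFI query is needed during batching.
    """
    return [0] + list(accumulate(num_nodes_list))[:-1]
```

For example, three graphs with 3, 5, and 2 nodes get offsets [0, 3, 8] in the batched graph.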
In the training loop:
```
Line #  Hits   Time          Per Hit   % Time  Line Contents
==============================================================
   538                                         @profile
   539                                         def train(self):
   ...
   545   358   3e+10         9e+07       30.2      for batched_data in self.dataloader:
   546   354   9136435.0     25809.1      0.0          batched_graph, _, batched_node_labels, _ = batched_data
   547                                                 # put to GPU
   548   354   331649634.0   936863.4     0.3          batched_graph = batched_graph.to(self.train_config['device'])
   549   354   21899359.0    61862.6      0.0          batched_node_labels = batched_node_labels.to(self.train_config['device'])
   550                                                 # train mode
   551   354   69193048.0    195460.6     0.1          self.model.train()
   552
   553   354   1475964364.0  4e+06        1.5          node_logits = self.model(batched_graph)
   ...
```
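One workaround while dgl.batch remains CPU-bound is to overlap collation with training so the GPU is not stalled waiting on the next batch. Setting num_workers > 0 on the torch DataLoader is the usual way to do this with processes; the same idea can be sketched in pure Python with a one-slot prefetcher (here `collate` and `sample_batches` stand in for the real pipeline, and a thread only helps to the extent the collate work releases the GIL):

```python
from concurrent.futures import ThreadPoolExecutor


def prefetch_batches(sample_batches, collate):
    """Yield collated batches, preparing batch i+1 while batch i is consumed."""
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = None
        for raw in sample_batches:
            nxt = pool.submit(collate, raw)  # start collating the next batch
            if pending is not None:
                yield pending.result()       # hand the previous batch to training
            pending = nxt
        if pending is not None:
            yield pending.result()
```

This hides the collate latency behind the training step but does not reduce the FFI overhead itself; the options above are still needed to fix the root cause.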
Is there any update or workaround for this issue?