[Enhancement] dgl.batch is slow due to FFI

Open peizhou001 opened this issue 2 years ago • 2 comments

According to our tests, dgl.batch is a bottleneck in graph classification. One main reason is that it frequently calls number_of_edges and number_of_nodes, which in turn call the underlying C++ functions via FFI. Unfortunately, FFI accounts for 97% of the time spent in these two functions. To tackle the problem, we see two options:

  1. Replace FFI with another mechanism.
  2. Duplicate the attributes on the Python-level object so repeated queries stay in Python (sketched below).
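For illustration only, a minimal sketch of option 2. CachedGraph is a hypothetical wrapper, not a proposed API; in practice the caching would live inside DGLGraph itself and would need invalidation when the graph is mutated:

```python
import dgl

class CachedGraph:
    """Hypothetical wrapper illustrating option 2: cache node/edge
    counts on the Python side so repeated queries skip the FFI."""

    def __init__(self, g: dgl.DGLGraph):
        self.g = g
        # One FFI round-trip each at construction time; afterwards,
        # number_of_nodes/number_of_edges are plain attribute reads.
        self._num_nodes = g.num_nodes()
        self._num_edges = g.num_edges()

    def number_of_nodes(self):
        return self._num_nodes

    def number_of_edges(self):
        return self._num_edges

g = CachedGraph(dgl.rand_graph(10, 20))
assert g.number_of_nodes() == 10  # no FFI call here
```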

peizhou001 avatar Mar 20 '23 06:03 peizhou001

For option 1, the longer-term path is to first replace the internals of DGLGraph with the new SparseMatrix object and then redirect dgl.batch to the corresponding sparse operations. This will make all APIs go through PyTorch's native FFI.
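As a rough illustration of why this direction helps: batching amounts to stacking per-graph adjacency matrices block-diagonally, which needs only PyTorch tensor ops and no per-graph FFI round-trips. A conceptual sketch, assuming COO edge indices; block_diag_coo is hypothetical and not part of any planned API:

```python
import torch

def block_diag_coo(edge_indices, num_nodes_list):
    """Stack per-graph COO adjacencies into one block-diagonal sparse
    matrix using only PyTorch tensor ops."""
    # Offset each graph's node IDs by the sizes of the graphs before it.
    offsets = torch.cumsum(torch.tensor([0] + num_nodes_list[:-1]), dim=0)
    indices = torch.cat([idx + off
                         for idx, off in zip(edge_indices, offsets)], dim=1)
    n = sum(num_nodes_list)
    return torch.sparse_coo_tensor(indices, torch.ones(indices.shape[1]),
                                   (n, n))

# Two toy graphs: 3 nodes / 2 edges and 2 nodes / 1 edge.
adj = block_diag_coo([torch.tensor([[0, 1], [1, 2]]),
                      torch.tensor([[0], [1]])], [3, 2])
```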

jermainewang avatar Mar 21 '23 13:03 jermainewang

We have observed the same bottleneck in batched graph classification with dgl 2.1.0+cu118, especially when running multiple models on the same machine (each model using its own GPU). The CPU overhead is very large, and the GPU has to wait for the DataLoader, which takes over 30% of the total time.
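Not a fix for the FFI cost itself, but one way to hide it is to run the collate in DataLoader worker processes so that dgl.batch overlaps with GPU compute. A minimal sketch, using a toy dataset in place of the real one:

```python
import dgl
import torch
from torch.utils.data import DataLoader

# Toy dataset of (graph, label) pairs; stands in for the real one.
dataset = [(dgl.rand_graph(10, 20), i % 2) for i in range(256)]

def collate(samples):
    # Batching still pays the FFI cost, but in worker processes it
    # overlaps with GPU compute on the main process.
    graphs, labels = map(list, zip(*samples))
    return dgl.batch(graphs), torch.tensor(labels)

loader = DataLoader(dataset, batch_size=64, shuffle=True,
                    collate_fn=collate, num_workers=4,
                    persistent_workers=True)
```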

Profiling stats:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    23                                           @profile
    24                                           def collate(samples):
    25                                               # input samples
    26      2462  110399208.0  44841.3      0.0      graphs, graph_label, node_label_list, sgs = map(list, zip(*samples))
    27      2462        1e+11    5e+07     49.8      batched_graph = dgl.batch(graphs)
    28      2462        1e+11    5e+07     49.6      batched_sg = dgl.batch(sgs)
    29                                               # return the batched data
    30      2462 1116662987.0 453559.3      0.5      return batched_graph, torch.tensor(graph_label),  torch.cat(node_label_list), batched_sg

In the training loop:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   538                                               @profile
   539                                               def train(self):
...
   545       358        3e+10    9e+07     30.2              for batched_data in self.dataloader:
   546       354    9136435.0  25809.1      0.0                  batched_graph, _, batched_node_labels, _ = batched_data
   547                                                           # put to GPU
   548       354  331649634.0 936863.4      0.3                  batched_graph = batched_graph.to(self.train_config['device'])
   549       354   21899359.0  61862.6      0.0                  batched_node_labels = batched_node_labels.to(self.train_config['device'])
   550                                                           # train mode
   551       354   69193048.0 195460.6      0.1                  self.model.train()
   552                                           
   553       354 1475964364.0    4e+06      1.5                  node_logits = self.model(batched_graph)
...
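For reference, a sketch of a possible stopgap: if per-graph node/edge counts are cached in the dataset, the batched graph can be assembled with tensor concatenation plus a single dgl.graph call. fast_batch is a hypothetical, untested helper, and feature tensors would still need to be concatenated separately:

```python
import dgl
import torch

def fast_batch(graphs, num_nodes_list, num_edges_list):
    # Hypothetical helper: build the batched graph from cached counts,
    # avoiding dgl.batch's per-graph number_of_nodes/number_of_edges calls.
    offsets = torch.cumsum(torch.tensor([0] + num_nodes_list[:-1]), dim=0)
    src_all, dst_all = [], []
    for g, off in zip(graphs, offsets):
        src, dst = g.edges()  # still one FFI call per graph
        src_all.append(src + off)
        dst_all.append(dst + off)
    bg = dgl.graph((torch.cat(src_all), torch.cat(dst_all)),
                   num_nodes=sum(num_nodes_list))
    # Record per-graph sizes so readout ops like dgl.mean_nodes still work.
    bg.set_batch_num_nodes(torch.tensor(num_nodes_list))
    bg.set_batch_num_edges(torch.tensor(num_edges_list))
    return bg

gs = [dgl.rand_graph(10, 20) for _ in range(4)]
bg = fast_batch(gs, [10] * 4, [20] * 4)
```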

Is there any update or workaround for this issue?

itewqq avatar May 10 '24 10:05 itewqq