pytorch_sparse

It seems SparseTensor consumes a lot of GPU memory

Open realCrush opened this issue 3 years ago • 3 comments

I'm developing a graph pooling model similar to DiffPool. I tried to use SparseTensor for the adjacency matrix adj_t and the assignment matrix s (both are {0,1} matrices) and to pool the node features h (mini-batched by diagonal concatenation):

h = s.t() @ h # shape: [num_communities, num_hidden]
adj_t = s.t() @ adj_t @ s # shape: [num_communities, num_communities], SparseTensor
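The two pooling lines above can be reproduced with PyTorch's built-in sparse COO tensors. This is a minimal sketch that swaps in `torch.sparse_coo_tensor` and `torch.sparse.mm` for torch_sparse's SparseTensor so it runs without torch_sparse installed; the toy graph, the hard assignment and the helper name `pool` are hypothetical.

```python
import torch

def pool(s, adj, h):
    """Pool features and adjacency with an assignment matrix s (hypothetical helper).

    s:   sparse COO, [num_nodes, num_communities]
    adj: sparse COO, [num_nodes, num_nodes]
    h:   dense,      [num_nodes, num_hidden]
    """
    s_t = s.t().coalesce()                                   # [num_communities, num_nodes]
    h_out = torch.sparse.mm(s_t, h)                          # dense [num_communities, num_hidden]
    adj_out = torch.sparse.mm(torch.sparse.mm(s_t, adj), s)  # sparse [num_communities, num_communities]
    return h_out, adj_out.coalesce()

# Toy graph: path 0-1-2-3, both edge directions stored.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
adj = torch.sparse_coo_tensor(edge_index, torch.ones(6), (4, 4)).coalesce()

# Hard {0,1} assignment: nodes 0,1 -> community 0, nodes 2,3 -> community 1.
s = torch.sparse_coo_tensor(torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]),
                            torch.ones(4), (4, 2)).coalesce()

h = torch.arange(8, dtype=torch.float).view(4, 2)
h_out, adj_out = pool(s, adj, h)
# h_out   == [[2, 4], [10, 12]]  (per-community feature sums)
# adj_out == [[2, 1], [1, 2]]    (intra-community edges end up on the diagonal)
print(h_out)
print(adj_out.to_dense())
```

The pooled `adj_out` can hold up to num_communities² non-zeros no matter how sparse `adj` is, so its memory is governed by the number of communities rather than by the input graph.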

However, these SparseTensor matrix operations seem to consume too much GPU memory: for example, they take ~34 GB on the COLLAB dataset (hidden dimension 128, batch size 128). Is there a more GPU-memory-efficient way to do these SparseTensor operations?

realCrush, Jul 26 '22 08:07

The memory scales with the number of non-zero entries. DiffPool ultimately creates a dense output adjacency matrix. It also looks like you are not treating the batch dimension correctly here: the resulting adjacency matrix appears to contain links between different graphs.
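To illustrate the batching concern: with diagonally stacked graphs, s.t() @ adj_t @ s stays block-diagonal only if every community column receives nodes from a single graph. The sketch below uses dense tensors for readability; the two toy graphs, the per-graph cluster count `k` and the helper `graphs_per_community` are hypothetical. It compares a shared community space with one where each graph gets its own offset community ids.

```python
import torch
import torch.nn.functional as F

# Two diagonally stacked toy graphs: graph 0 = edge 0-1, graph 1 = edge 2-3.
edge = torch.tensor([[0., 1.], [1., 0.]])
adj = torch.block_diag(edge, edge)           # [4, 4], no cross-graph entries
batch = torch.tensor([0, 0, 1, 1])           # graph id of every node
num_graphs, k = 2, 2                         # k communities per graph (hypothetical)
local_cluster = torch.tensor([0, 1, 0, 1])   # cluster chosen inside each graph

def graphs_per_community(s, batch, num_graphs):
    """How many distinct graphs contribute nodes to each community (column of s)."""
    counts = s.t() @ F.one_hot(batch, num_graphs).float()   # [num_communities, num_graphs]
    return (counts > 0).sum(dim=1)

# (a) Shared community space: cluster j of every graph maps to the same column j.
s_shared = F.one_hot(local_cluster, k).float()
# (b) Offset community space: graph g owns columns g*k .. g*k + k - 1.
s_offset = F.one_hot(local_cluster + batch * k, num_graphs * k).float()

for name, s in [("shared", s_shared), ("offset", s_offset)]:
    pooled = s.t() @ adj @ s
    print(name, graphs_per_community(s, batch, num_graphs).tolist())
    print(pooled)
```

In the shared case every community mixes nodes from both graphs and the pooled matrix [[0, 2], [2, 0]] no longer separates the two examples; with offsets the result stays block-diagonal. Which form a given model produces depends on how s is predicted, which the thread does not show.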

rusty1s, Jul 26 '22 08:07

Thank you for your quick reply! My model predicts a sparse assignment matrix for the nodes of a batched graph, and adj_t in my code is mini-batched diagonally. Since my algorithm is permutation-invariant, the disconnected parts of a batched graph will not become connected after assignment. I've checked your DiffPool example in PyG (https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_diff_pool.py), but it batches graphs in a [num_batch, max_nodes, num_feature] layout and relies on dense convolution layers (DenseConv) for message passing, whereas my model relies on GAT, which PyG does not support as a dense convolution.
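For reference, the [num_batch, max_nodes, num_feature] layout used by that example can be built from a diagonally stacked mini-batch. PyG ships `torch_geometric.utils.to_dense_batch` and `to_dense_adj` for this; the sketch below reimplements the idea in plain PyTorch so it is self-contained. The function names are hypothetical, and it assumes nodes are sorted by graph id, as in a PyG batch.

```python
import torch

def _local_positions(batch):
    """Per-graph node counts and each node's index inside its own graph.

    Assumes `batch` is sorted, i.e. nodes of graph 0 come first, then graph 1, ...
    """
    num_graphs = int(batch.max()) + 1
    counts = torch.bincount(batch, minlength=num_graphs)
    starts = torch.cumsum(counts, dim=0) - counts        # global index of each graph's first node
    local = torch.arange(batch.numel()) - starts[batch]
    return num_graphs, int(counts.max()), starts, local

def to_dense_batch_sketch(x, batch):
    """[num_nodes, F] diagonal batch -> padded [num_graphs, max_nodes, F] plus mask."""
    num_graphs, max_nodes, _, local = _local_positions(batch)
    out = x.new_zeros(num_graphs, max_nodes, x.size(-1))
    out[batch, local] = x
    mask = torch.zeros(num_graphs, max_nodes, dtype=torch.bool)
    mask[batch, local] = True
    return out, mask

def to_dense_adj_sketch(edge_index, batch):
    """edge_index [2, num_edges] of a diagonal batch -> [num_graphs, max_nodes, max_nodes]."""
    num_graphs, max_nodes, starts, _ = _local_positions(batch)
    src, dst = edge_index
    g = batch[src]                                       # edges never cross graphs here
    adj = torch.zeros(num_graphs, max_nodes, max_nodes)
    adj.index_put_((g, src - starts[g], dst - starts[g]),
                   torch.ones(src.numel()), accumulate=True)
    return adj

# Graph 0: path 0-1-2; graph 1: edge 3-4 (both directions stored).
x = torch.arange(10, dtype=torch.float).view(5, 2)
batch = torch.tensor([0, 0, 0, 1, 1])
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 0, 2, 1, 4, 3]])
dense_x, mask = to_dense_batch_sketch(x, batch)     # [2, 3, 2]; mask[1, 2] is False
dense_adj = to_dense_adj_sketch(edge_index, batch)  # [2, 3, 3]
```

Under these assumptions, one option is to run a sparse layer such as GAT on the diagonal batch and convert to the padded layout only for the pooling step.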

Are my pooling operations correct for a mini-batched graph, or do I have to use the [num_batch, max_nodes, num_feature] layout? And which mini-batching scheme is more memory-efficient: the 3-dimensional layout or diagonal concatenation?
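A back-of-the-envelope comparison helps with the second question. The sketch below counts bytes for a padded 3-dimensional float32 adjacency and for COO storage, assuming one int64 (row, col) pair plus one float32 value per non-zero; torch_sparse's SparseTensor may cache extra index arrays such as a CSR rowptr, so its real footprint can be larger. The sizes B, N_MAX and K are hypothetical, not COLLAB statistics, and autograd buffers are ignored.

```python
def dense_adj_bytes(num_graphs, max_nodes, value_bytes=4):
    """Padded [num_graphs, max_nodes, max_nodes] float32 adjacency."""
    return num_graphs * max_nodes * max_nodes * value_bytes

def coo_bytes(nnz, index_bytes=8, value_bytes=4):
    """COO storage: an int64 (row, col) pair plus one float32 value per non-zero."""
    return nnz * (2 * index_bytes + value_bytes)

def pooled_entries(num_graphs, k, shared):
    """Upper bound on non-zeros of the pooled adjacency s.t() @ adj_t @ s.

    shared=True:  one community space of size num_graphs * k for the whole batch
    shared=False: every graph owns its own k communities (block-diagonal result)
    """
    total = num_graphs * k
    return total * total if shared else num_graphs * k * k

# Hypothetical sizes, not measured on COLLAB.
B, N_MAX, K = 128, 500, 50
print(dense_adj_bytes(B, N_MAX))                      # padded input adjacency
print(coo_bytes(pooled_entries(B, K, shared=True)))   # worst case, shared communities
print(coo_bytes(pooled_entries(B, K, shared=False)))  # worst case, per-graph blocks
```

With these numbers the padded input costs 128 MB, while the pooled adjacency can reach about 819 MB as COO if all 128 graphs share one community space, versus 6.4 MB with per-graph blocks: the shared layout allows up to B times more entries. For the input adjacency, which layout wins depends on how far nnz(adj) is below B·N_max² and on how uneven the graph sizes in a batch are.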

realCrush, Jul 26 '22 08:07

Yeah, that's what I meant. Although your graph uses diagonal stacking, the multiplication will still yield an adjacency matrix in which nodes from two different examples can be connected. Sadly, sparse matrices do not support three sparse dimensions (if they did, you could do the reshape), so I cannot think of an efficient way to compute this without a dense adjacency representation :(
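The dense route keeps the batch dimension explicit, so each graph's pooled adjacency is its own [K, K] block and nodes of different graphs cannot be linked. A minimal sketch in plain PyTorch, assuming padded inputs such as those returned by PyG's `to_dense_batch` / `to_dense_adj` and a per-graph assignment s of shape [B, N_max, K]. PyG's `torch_geometric.nn.dense_diff_pool` performs this step (plus a softmax over s and the auxiliary link and entropy losses); `dense_pool_sketch` below is a hypothetical simplification.

```python
import torch
import torch.nn.functional as F

def dense_pool_sketch(x, adj, s, mask=None):
    """DiffPool-style pooling on padded, batched tensors (no softmax, no losses).

    x:    [B, N_max, F]      node features
    adj:  [B, N_max, N_max]  dense adjacency
    s:    [B, N_max, K]      assignment (hard or soft, K communities per graph)
    mask: [B, N_max] bool,   True for real (non-padded) nodes
    """
    if mask is not None:
        m = mask.unsqueeze(-1).to(x.dtype)
        x, s = x * m, s * m            # padded nodes contribute nothing
    s_t = s.transpose(1, 2)            # [B, K, N_max]
    x_out = s_t @ x                    # [B, K, F]
    adj_out = s_t @ adj @ s            # [B, K, K]: one block per graph
    return x_out, adj_out

# Graph 0: path 0-1-2; graph 1: edge 0-1 plus one padded slot.
adj = torch.tensor([[[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]],
                    [[0., 1., 0.], [1., 0., 0.], [0., 0., 0.]]])
x = torch.ones(2, 3, 1)
mask = torch.tensor([[True, True, True], [True, True, False]])
s = F.one_hot(torch.tensor([[0, 0, 1], [0, 1, 0]]), 2).float()  # K = 2
x_out, adj_out = dense_pool_sketch(x, adj, s, mask)
```

Here x_out[1] is [[1], [1]] because the padded slot is masked out. The cost is B·N_max² for the padded input plus B·K² for the output, and unlike a shared community space the output never holds entries between different graphs.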

rusty1s, Jul 27 '22 06:07

This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?

github-actions[bot], Jan 24 '23 01:01