SparseGCN performance issue
SparseGCN has a serious performance issue.
The training time per epoch on the Tox21 example is almost 30 times longer than PadGCN's.
Results on my local PC (CPU):
Log about SparseGCN
$ python gcn_sparse_pattern_example.py
Iter 0/50 (241.6028 s) valid loss: 0.1522 valid roc_auc score: 0.6647
Iter 1/50 (208.9176 s) valid loss: 0.1649 valid roc_auc score: 0.6955
Iter 2/50 (225.3157 s) valid loss: 0.1516 valid roc_auc score: 0.7013
Log about PadGCN
$ python gcn_pad_pattern_example.py
Iter 0/50 (18.0109 s) valid loss: 0.1648 valid roc_auc score: 0.5425
Iter 1/50 (7.1491 s) valid loss: 0.1645 valid roc_auc score: 0.5315
Iter 2/50 (6.5680 s) valid loss: 0.1512 valid roc_auc score: 0.5597
This performance issue is related to google/jax#2242.
SparseGCN uses jax.ops.index_add, and a large Python for loop that calls jax.ops.index_add leads to a serious performance problem.
According to the issue comments, I have to rewrite the training loop using lax.scan or lax.fori_loop in order to resolve this. I think that if the training loop is rewritten with lax.scan or lax.fori_loop, it will improve the performance of not only SparseGCN but also PadGCN. Therefore, it is really important to resolve this issue.
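For reference, here is a minimal toy sketch of the lax.fori_loop semantics that the rewrite would target (a toy carry, not jaxchem code): it behaves like a Python for loop over a carried value, but the body is traced once and compiled by XLA.

from jax import lax

# Toy example only: lax.fori_loop(lower, upper, body_fun, init) is roughly
#   val = init
#   for i in range(lower, upper):
#       val = body_fun(i, val)
# except that body_fun is traced once and compiled instead of being unrolled.
def body_fun(i, val):
    return val + i

total = lax.fori_loop(0, 10, body_fun, 0)  # == sum(range(10)) == 45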
However, lax.scan and lax.fori_loop impose a functional programming style and are difficult to work with, so rewriting the training loop is not straightforward and I'm struggling with this issue. Below I explain what is blocking my work.
1. lax.scan and lax.fori_loop don't allow side effects
DeepChem's DiskDataset provides the iterbatches method. We can use this method to write a training loop like the one below.
for epoch in range(num_epochs):
    for batch in train_dataset.iterbatches(batch_size=batch_size):
        params, predict = forward(batch, params)
However, lax.scan and lax.fori_loop don't allow side effects (such as consuming an iterator/generator). So I tried the implementation below, but it didn't work. I opened an issue about this topic; please see https://github.com/google/jax/issues/3567. A possible workaround is sketched after the snippet.
train_iterator = train_dataset.iterbatches(batch_size=batch_size)

def run_epoch(init_params):
    def body_fun(idx, params):
        # this iterator doesn't work... the batch is always the same inside the loop
        batch = next(train_iterator)
        params, predict = forward(batch, params)
        return params
    return lax.fori_loop(0, train_num_batches, body_fun, init_params)

for epoch in range(num_epochs):
    params = run_epoch(params)
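One workaround I'm considering (just a sketch under assumptions, not something jaxchem does today) is to materialize all mini-batches into fixed-shape arrays before the loop and index them by the loop counter, so body_fun stays free of side effects. This assumes every batch has the same shape (for example by padding the last batch) and a simple (X, y) batch layout; forward, params, train_dataset and batch_size are the same names as above.

import jax.numpy as jnp
from jax import lax

# Sketch only: collect every mini-batch once, outside the compiled loop.
# Assumes all batches share the same shape (e.g. the last batch is padded).
batches = list(train_dataset.iterbatches(batch_size=batch_size))
all_x = jnp.stack([b[0] for b in batches])  # (num_batches, batch_size, ...)
all_y = jnp.stack([b[1] for b in batches])

def run_epoch(init_params):
    def body_fun(idx, params):
        # index the pre-loaded arrays by the loop counter instead of
        # calling next() on a Python iterator
        x = lax.dynamic_index_in_dim(all_x, idx, keepdims=False)
        y = lax.dynamic_index_in_dim(all_y, idx, keepdims=False)
        params, predict = forward((x, y), params)
        return params
    return lax.fori_loop(0, all_x.shape[0], body_fun, init_params)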
2. Values inside the body_fun of lax.scan or lax.fori_loop cannot change their shape
All values in lax.scan or lax.fori_loop (the carry, the arguments, and so on) must keep a fixed shape (see the documentation). This is a hard limitation of lax.scan and lax.fori_loop. (To be honest, there are also some additional limitations, like https://github.com/google/jax/issues/2962.)
One pain point is that accumulation operations (like appending a value to a list on each iteration) are difficult to express. Here is an example:
# NG
val = []
for i in range(10):
    val.append(i)

# OK
val = np.zeros(10)
for i in range(10):
    val[i] = i
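For what it's worth, lax.scan can express the fixed-shape version of this accumulation directly, since its second output stacks the per-iteration values into a preallocated array. A minimal toy sketch:

import jax.numpy as jnp
from jax import lax

# Toy sketch: the second return value of body_fun is collected across
# iterations into a fixed-shape array, playing the role of the preallocated val.
def body_fun(carry, i):
    return carry, i  # (new carry, value accumulated for this step)

carry, val = lax.scan(body_fun, 0, jnp.arange(10))
# val == jnp.arange(10)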
This may become a problem as the number of metrics we want to collect grows, and sometimes a creative implementation is needed (see https://github.com/google/jax/issues/1708).
Another pain point is that sparse pattern mini-batches are incompatible with this limitation.
With sparse pattern modeling, the mini-batch data changes shape from batch to batch, as shown below.
Here is an example from PyTorch Geometric (x is the node feature matrix).
>>> from torch_geometric.datasets import TUDataset
>>> from torch_geometric.data import DataLoader
>>> dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
>>> loader = DataLoader(dataset, batch_size=32, shuffle=True)
>>> for batch in loader:
... batch
...
Batch(batch=[1137], edge_index=[2, 4368], x=[1137, 21], y=[32])
Batch(batch=[1144], edge_index=[2, 4408], x=[1144, 21], y=[32])
Batch(batch=[1191], edge_index=[2, 4444], x=[1191, 21], y=[32])
Batch(batch=[1087], edge_index=[2, 4288], x=[1087, 21], y=[32])
Batch(batch=[644], edge_index=[2, 2540], x=[644, 21], y=[24])
Sparse pattern modeling constructs one big graph per batch by merging all the graphs, so each mini-batch has a different shape. A sketch of this merging step follows the example below.
# batch_size = 3
graph 1 : node_feat (5, 100)  edge_idx (2, 4)
graph 2 : node_feat (9, 100)  edge_idx (2, 6)
graph 3 : node_feat (5, 100)  edge_idx (2, 7)
-> mini-batch graph : node_feat (19, 100)  edge_idx (2, 17)
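To make the merging concrete, here is a rough sketch of how such a batched graph could be built (illustration only, not the jaxchem implementation): node feature matrices are concatenated, and each graph's edge indices are shifted by the number of nodes that came before it.

import numpy as np

# Illustration only: merge a list of (node_feat, edge_idx) pairs into one big graph.
def merge_graphs(graphs):
    node_feats, edge_idxs, offset = [], [], 0
    for node_feat, edge_idx in graphs:
        node_feats.append(node_feat)
        edge_idxs.append(edge_idx + offset)  # shift indices by the running node count
        offset += node_feat.shape[0]
    return np.concatenate(node_feats, axis=0), np.concatenate(edge_idxs, axis=1)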
This is a serious problem for implementing the sparse pattern model, and I'm now thinking about how to resolve this shape issue. One solution is to pad the mini-batch graph data like below.
mini-batch graph : node_feat (19, 100)  edge_idx (2, 17) -> node_feat (19, 100)  edge_idx (2, 17)
mini-batch graph : node_feat (15, 100)  edge_idx (2, 13) -> node_feat (19, 100)  edge_idx (2, 17)
mini-batch graph : node_feat (12, 100)  edge_idx (2, 11) -> node_feat (19, 100)  edge_idx (2, 17)
However, we need to be careful about the padding values, because they can affect the node aggregation of the sparse pattern.
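One padding scheme that might avoid this (again just a sketch, assuming a sum-style index_add aggregation; pad_graph, max_nodes and max_edges are hypothetical names): pad the node features with zero rows and point every padded edge at a dummy node, so the padded entries contribute nothing to real nodes.

import jax.numpy as jnp

# Hypothetical sketch, not the current jaxchem implementation.
def pad_graph(node_feat, edge_idx, max_nodes, max_edges):
    num_nodes = node_feat.shape[0]
    num_edges = edge_idx.shape[1]
    dummy = max_nodes - 1  # dummy node index (assumes max_nodes > num_nodes)
    # pad node features with zero rows (the last row is the zero-feature dummy node)
    padded_nodes = jnp.pad(node_feat, ((0, max_nodes - num_nodes), (0, 0)))
    # padded edges are self-loops on the dummy node, so a sum aggregation
    # only adds zeros to the dummy node and never touches real nodes
    padded_edges = jnp.pad(edge_idx, ((0, 0), (0, max_edges - num_edges)),
                           constant_values=dummy)
    return padded_nodes, padded_edges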
3. It is difficult to debug the body_fun of lax.scan or lax.fori_loop
It is difficult to debug body_fun, for example by adding a print call inside lax.scan or lax.fori_loop. This point is also discussed in https://github.com/google/jax/issues/999, but that issue is still open...
train_iterator = train_dataset.iterbatches(batch_size=batch_size)

def run_epoch(init_params):
    def body_fun(idx, params):
        # this iterator doesn't work... the batch is always the same inside the loop
        batch = next(train_iterator)
        params, predict = forward(batch, params)
        # the actual prediction values are not printed here (print only sees a tracer)
        print(predict)
        return params
    return lax.fori_loop(0, train_num_batches, body_fun, init_params)

for epoch in range(num_epochs):
    params = run_epoch(params)
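Depending on the installed JAX version, jax.experimental.host_callback.id_print may help here: it sends the traced value back to the host and prints it at run time. A minimal toy sketch, assuming that experimental API is available:

from jax import lax
from jax.experimental import host_callback

# Toy sketch: id_print returns its input, and keeping the returned value
# ensures the print is not optimized away.
def body_fun(i, total):
    total = total + i
    total = host_callback.id_print(total)  # prints the concrete value each iteration
    return total

result = lax.fori_loop(0, 5, body_fun, 0)  # prints 0, 1, 3, 6, 10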