Add `CMPNN` model
## Paper Summary
- Introduces a new model, CMPNN, for Communicative Representation Learning on Attributed Molecular Graphs
- CMPNN improves molecular embedding by:
  - Following the edge-based message passing scheme of DMPNN
  - Introducing node-edge message communication modules such as the Inner Product Kernel, the Gated Graph Kernel, and a Multilayer Perceptron (MLP)
  - Updating both bond and atom embeddings during training
  - Including a message booster to enrich the message generation process (a sketch of one layer follows this list)
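For intuition, here is a minimal sketch of one CMPNN-style layer under my reading of the paper. The function and variable names, the exact booster formula, the edge-update formula, and the assumption that reverse edges are stored adjacently are all mine, not this PR's API:

```python
import torch
from torch_geometric.utils import scatter

def cmpnn_layer(h_node, h_edge, edge_index, edge_lin):
    # Message booster: enrich the aggregated incoming message by taking the
    # element-wise product of sum- and max-pooled edge messages per node
    # (my reading of the paper's booster; the exact form may differ).
    src, dst = edge_index[0], edge_index[1]
    m_sum = scatter(h_edge, dst, dim=0, dim_size=h_node.size(0), reduce='sum')
    m_max = scatter(h_edge, dst, dim=0, dim_size=h_node.size(0), reduce='max')
    m_node = m_sum * m_max

    # Node update via a communication module (additive variant shown).
    h_node = h_node + m_node

    # Edge update in the DMPNN style: the message for edge (u, v) subtracts
    # the state of the reverse edge (v, u). We assume reverse edges are
    # stored adjacently, so edge i's reverse is edge i ^ 1.
    rev = torch.arange(h_edge.size(0), device=h_edge.device) ^ 1
    h_edge = torch.relu(edge_lin(h_node[src] - h_edge[rev]))
    return h_node, h_edge
```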
## Motivation
- Adding an alternative to DMPNN (the GNN architecture used by Chemprop) that improves on DMPNN by adding:
  - Node-edge message communication modules
  - A message booster
  - Bond embedding updates
- Adding a GNN model that updates edge embeddings during training
## Benchmark Results
I have benchmarked CMPNN on a subset of datasets from the TDC ADMET Benchmark Group. I used LitGNN to perform this benchmark, and the results can be found in a W&B report.
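For reference, the datasets and their scaffold splits can be loaded with PyTDC roughly as follows (a sketch of the TDC benchmark-group API; the dataset name shown is one of those in the table):

```python
from tdc.benchmark_group import admet_group

# Download/cache the ADMET benchmark group and fetch one task.
group = admet_group(path='data/')
benchmark = group.get('BBB_Martins')
train_val, test = benchmark['train_val'], benchmark['test']

# Scaffold-based train/valid split for a given seed; the results below
# average over five such runs.
train, valid = group.get_train_valid_split(benchmark='BBB_Martins', seed=1)
```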
| Task | Classification (AUROC ↑) | | Regression (MAE ↓) | | |
|---|---|---|---|---|---|
| Dataset | BBB_Martins | AMES | Solubility_AqSolDB | Lipophilicity_AstraZeneca | LD50_Zhu |
| AttentiveFP | 0.855 ± 0.011 | 0.814 ± 0.008 | 0.776 ± 0.008 | 0.572 ± 0.007 | 0.678 ± 0.012 |
| Chemprop | 0.821 ± 0.112 | 0.842 ± 0.014 | 0.829 ± 0.022 | 0.470 ± 0.009 | 0.606 ± 0.024 |
| Chemprop-RDKit* | 0.869 ± 0.027 | 0.850 ± 0.004 | 0.761 ± 0.025 | 0.467 ± 0.006 | 0.625 ± 0.022 |
| CMPNN | 0.89 ± 0.016 (CMPNN-GRU) | 0.843 ± 0.009 (CMPNN-MLP) | 0.796 ± 0.038 (CMPNN-GRU) | 0.515 ± 0.008 (CMPNN-MLP) | 0.631 ± 0.021 (CMPNN-Additive) |
Table: Prediction results of CMPNN on five chemical graph datasets. The datasets are from the TDC ADMET Benchmark Group, which provides train_val/test scaffold splits. The model was trained and tested five times for each task, and the mean and standard deviation of the AUROC or MAE values are reported. *Chemprop-RDKit uses a hybrid approach that combines the learned molecule embeddings with 200 global molecule features (descriptors).
## Implementation Details
> [!NOTE]
> Here is a fork of the original code with some cleanups, the addition of Poetry for dependency management, etc.
Below are the places where improvements have been made:
- Node-edge message communication modules
  - In the paper, the authors mention different communicators such as the Inner Product Kernel, the Gated Graph Kernel, and a Multilayer Perceptron. However, their code does not use these communicators; instead, it only uses an additive communicator (not mentioned in the paper; I came up with the name 'additive').
  - I have implemented the four communicators proposed in the paper and in their code:
    - Additive
    - Inner product
    - GRU
    - MLP
  - The user can choose a communicator according to their dataset (see the sketch after this list).
  - These communication modules are applied during the convolution layers.
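A minimal sketch of the four communicator variants, assuming each takes the node state `h` and the aggregated incoming message `m`, both of width `hidden_dim`. The exact formulations, especially the inner product one, are my assumptions and may differ from the PR code:

```python
import torch
from torch import nn

class AdditiveCommunicator(nn.Module):
    # Element-wise sum of node state and aggregated message (the variant
    # used in the authors' reference code).
    def forward(self, h, m):
        return h + m

class InnerProductCommunicator(nn.Module):
    # Gate the additive update by the element-wise product of the inputs
    # (an assumed reading of the paper's Inner Product Kernel).
    def forward(self, h, m):
        return (h + m) * torch.sigmoid(h * m)

class GRUCommunicator(nn.Module):
    # Treat the message as the GRU input and the node state as the
    # previous hidden state.
    def __init__(self, hidden_dim):
        super().__init__()
        self.gru = nn.GRUCell(hidden_dim, hidden_dim)

    def forward(self, h, m):
        return self.gru(m, h)

class MLPCommunicator(nn.Module):
    # Concatenate node state and message, then project back to hidden_dim.
    def __init__(self, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
        )

    def forward(self, h, m):
        return self.mlp(torch.cat([h, m], dim=-1))
```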
- Final communication module
  - In the paper, the authors mention applying the same communication module to the message from incoming bonds `m(v)`, the current atom representation `h^K(v)`, and the atom's initial representation `x(v)`.
  - In Section 3.3 of the paper, where they describe the different node-edge message communication modules, all the communication modules operate on `m(v)` and `h^{K-1}(v)`.
  - In their code, they use an MLP communication module.
  - I have kept this part the same and hardcoded the MLP communication module (see the sketch below).
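A sketch of that final step, assuming the initial atom features `x_v` have already been projected to `hidden_dim` (the variable names and the hidden size are mine):

```python
import torch
from torch import nn

hidden_dim = 300  # assumed hidden size

# Final communication, hardcoded to an MLP: concatenate the aggregated
# incoming message m(v), the final node state h^K(v), and the (projected)
# initial representation x(v), then map back to hidden_dim.
final_comm = nn.Sequential(
    nn.Linear(3 * hidden_dim, hidden_dim),
    nn.ReLU(),
)

m_v = torch.randn(5, hidden_dim)  # toy tensors for 5 atoms
h_K = torch.randn(5, hidden_dim)
x_v = torch.randn(5, hidden_dim)
out = final_comm(torch.cat([m_v, h_K, x_v], dim=-1))
```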
## Checklist
> [!NOTE]
> For the `CMPNN` PyG implementation, I have used `AttentiveFP` as a template.
- [x] Add `torch_geometric/nn/models/cmpnn.py`
- [x] Add `CMPNN` to `torch_geometric/nn/models/__init__.py`
- [x] Add `test/nn/models/test_cmpnn.py`
- [x] Add `message_booster` mode to the `torch_geometric/nn/aggr/multi.py:MultiAggregation` class
- [x] Add a test for `message_booster` mode in `test/nn/aggr/test_multi.py`
- [x] Add `examples/cmpnn.py` (a rough usage sketch follows this list)
- [x] Add CMPNN to the 'Implemented GNN Models' section of the README
- [x] Update CHANGELOG
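For reviewers, a rough usage sketch: the constructor mirrors `AttentiveFP` since it was the template, but the argument names and values here are assumptions rather than the final API:

```python
import torch
from torch_geometric.nn.models import CMPNN

# Hypothetical constructor arguments, modeled on AttentiveFP's signature.
model = CMPNN(in_channels=39, hidden_channels=300, out_channels=1,
              edge_dim=10, num_layers=3, communicator='gru')

x = torch.randn(4, 39)                            # atom features
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
edge_attr = torch.randn(4, 10)                    # bond features
batch = torch.zeros(4, dtype=torch.long)          # single molecule

out = model(x, edge_index, edge_attr, batch)
```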
Thank you! Please let me know if any changes are required.
Thank you for your efforts in implementing CMPNN. Based on my tests, this pull request correctly implements the key modules described in the original paper. One question I have is whether it would be more appropriate to separate BatchGRU from CMPNN, considering that BatchGRU serves the role of a readout function.
Thank you @AzureLeon1 for taking the time to review the PR. Regarding the BatchGRU, I have a few thoughts:
- If you look at the forward method of `AttentiveFP`, it utilizes a PyTorch `GRUCell`, while `CMPNN` implements `BatchGRU` using the PyTorch `GRU`.
- Another option would be to move it under `nn.pool`.
I feel `BatchGRU` is specific to CMPNN, but if you (or the maintainers of PyG) see utility in having it in a separate file, I'm happy to make the necessary changes.
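For context, `BatchGRU` roughly does the following (a sketch, not the actual implementation):

```python
import torch
from torch import nn
from torch_geometric.utils import to_dense_batch

class BatchGRUSketch(nn.Module):
    # Pad each molecule's atom states into a dense batch and run a
    # bidirectional GRU over the atoms. Operating on whole molecules at
    # once is why it behaves like a readout rather than a per-node cell
    # such as GRUCell.
    def __init__(self, hidden_dim):
        super().__init__()
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True,
                          bidirectional=True)

    def forward(self, h, batch):
        dense, mask = to_dense_batch(h, batch)  # [num_mols, max_atoms, H]
        out, _ = self.gru(dense)                # [num_mols, max_atoms, 2H]
        return out[mask]                        # back to [num_atoms, 2H]
```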