MoleculeSTM
About mol_to_graph_data_obj_simple functions
Thank you for an interesting repo.
I went through the code and noticed that you use two different mol_to_graph_data_obj_simple functions for contrastive pre-training and property prediction fine-tuning:
pre-training: https://github.com/chao1224/MoleculeSTM/blob/main/MoleculeSTM/datasets/utils.py#L44
fine-tuning: https://github.com/chao1224/MoleculeSTM/blob/main/MoleculeSTM/datasets/MoleculeNet_Graph.py#L17
Could you explain why we have to do that? While you used the same GNN architecture for pre-training and fine-tuning, does using different mol_to_graph_data_obj_simple functions affect the GNN's behavior?
Looking forward to hearing from you soon.
Thanks.
Hi @lhkhiem28,
Thank you for checking this, and please use the OGB version for the featurization.
We tested both versions, as we did in GraphMVP. I merged the wrong version for the previous code release.
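To illustrate why the choice of featurizer matters here: a GNN that consumes index features typically keeps one embedding table per feature column, so a checkpoint is only usable with a featurizer that emits the same number of columns, with every index inside its table's range. The following is a minimal pure-Python sketch of that compatibility check. The function name `check_featurization` and the table sizes are hypothetical and are not taken from the MoleculeSTM code.

```python
from typing import List, Sequence


def check_featurization(rows: Sequence[Sequence[int]], table_sizes: List[int]) -> List[str]:
    """Return the problems found when feeding `rows` to per-column embedding tables.

    rows: one list of integer feature indices per atom.
    table_sizes: number of entries in each embedding table
                 (hypothetical values, for illustration only).
    """
    problems = []
    for atom_idx, row in enumerate(rows):
        # A different number of columns means a different featurizer was used.
        if len(row) != len(table_sizes):
            problems.append(
                f"atom {atom_idx}: {len(row)} feature columns, expected {len(table_sizes)}"
            )
            continue
        # Each index must be a valid row of its embedding table.
        for col, (value, size) in enumerate(zip(row, table_sizes)):
            if not 0 <= value < size:
                problems.append(
                    f"atom {atom_idx}, column {col}: index {value} outside [0, {size})"
                )
    return problems


# Hypothetical model whose input layer was trained on 2-column atom features.
two_column_tables = [120, 4]

# Output of a featurizer that emits 2 columns per atom.
compatible_rows = [[5, 0], [7, 0]]
# Output of a featurizer that emits 9 columns per atom.
wider_rows = [[5, 0, 4, 5, 3, 0, 2, 0, 0]]

print(check_featurization(compatible_rows, two_column_tables))
print(check_featurization(wider_rows, two_column_tables))
```

Note that a check like this only catches shape and range mismatches. Two featurizers with the same number of columns but different index meanings would pass it and still give the model inputs it was not trained on.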
So, the function in utils.py is correct? And is that function aligned with the checkpoint here: https://huggingface.co/chao1224/MoleculeSTM/tree/main/pretrained_MoleculeSTM/SciBERT-Graph-3e-5-1-1e-4-1-InfoNCE-0.1-32-32 ?
@chao1224 Can you confirm?
import numpy as np
import torch
from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector
from torch_geometric.data import Data


def mol_to_graph_data_obj_simple(mol):
    """ used in MoleculeNetGraphDataset() class
    Converts rdkit mol objects to graph data object in pytorch geometric
    NB: Uses simplified atom and bond features, and represent as indices
    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr """
    # atoms
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_feature = atom_to_feature_vector(atom)
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # bonds
    if len(mol.GetBonds()) <= 0:  # mol has no bonds
        num_bond_features = 3  # OGB bond features: bond type, stereo, conjugation
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)
    else:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = bond_to_feature_vector(bond)
            # add both directions so the graph is treated as undirected
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return data
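
The edge construction in the function above can be illustrated without RDKit or PyTorch. The sketch below is a simplified stand-in, not the repository's code: `bonds_to_coo` is a hypothetical helper, and the bond features are made-up values. It shows that each undirected bond becomes two directed edges that share the same feature vector, giving an edge index with two rows and 2 * num_bonds columns.

```python
from typing import List, Tuple


def bonds_to_coo(
    bonds: List[Tuple[int, int]], bond_features: List[List[int]]
) -> Tuple[List[List[int]], List[List[int]]]:
    """Build a directed COO edge list from undirected bonds.

    Returns (edge_index, edge_attr), where edge_index is [sources, targets]
    and edge_attr holds one feature vector per directed edge.
    """
    if len(bonds) != len(bond_features):
        raise ValueError("need exactly one feature vector per bond")
    sources, targets, edge_attr = [], [], []
    for (i, j), feat in zip(bonds, bond_features):
        # Forward edge i -> j and backward edge j -> i share the same features.
        sources += [i, j]
        targets += [j, i]
        edge_attr += [list(feat), list(feat)]
    return [sources, targets], edge_attr


# Hypothetical 3-atom chain 0-1-2 with made-up 3-value bond features.
edge_index, edge_attr = bonds_to_coo([(0, 1), (1, 2)], [[0, 0, 0], [1, 0, 1]])
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(edge_attr)   # [[0, 0, 0], [0, 0, 0], [1, 0, 1], [1, 0, 1]]
```

With an empty bond list the helper returns two empty rows and no edge features, which mirrors the empty (2, 0) and (0, num_bond_features) tensors in the no-bond branch.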