MoleculeSTM

About mol_to_graph_data_obj_simple functions

Status: Open · lhkhiem28 opened this issue 10 months ago · 2 comments

Thank you for the interesting repo.

I went through the code and noticed that you use two different mol_to_graph_data_obj_simple functions, one for contrastive pre-training and one for property-prediction fine-tuning:
- pre-training: https://github.com/chao1224/MoleculeSTM/blob/main/MoleculeSTM/datasets/utils.py#L44
- fine-tuning: https://github.com/chao1224/MoleculeSTM/blob/main/MoleculeSTM/datasets/MoleculeNet_Graph.py#L17

Could you explain why this is necessary? Since you use the same GNN architecture for pre-training and fine-tuning, does using different mol_to_graph_data_obj_simple functions affect the GNN's behavior?
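To make the concern concrete: GNNs built on these featurizers usually embed each integer feature column with its own lookup table, so the featurizer fixes both the number of columns and the valid index range of each column. The sketch below is a dependency-free check of that contract. The function name `check_atom_features` and the vocabulary sizes are illustrative assumptions, not taken from the MoleculeSTM code; in practice the sizes would come from the featurizer actually used (for OGB, `ogb.utils.features.get_atom_feature_dims()`).

```python
from typing import List, Sequence


def check_atom_features(
    features: Sequence[Sequence[int]], vocab_sizes: Sequence[int]
) -> List[str]:
    """Report atom rows that do not fit per-column embedding tables.

    features: one row of integer indices per atom (hypothetical input).
    vocab_sizes: number of embeddings available for each column (assumed).
    Returns human-readable problems; an empty list means every index is in
    range. It cannot detect two featurizers that share shapes but assign
    different meanings to the same index.
    """
    problems: List[str] = []
    for atom_idx, row in enumerate(features):
        if len(row) != len(vocab_sizes):
            problems.append(
                f"atom {atom_idx}: expected {len(vocab_sizes)} columns, got {len(row)}"
            )
            continue  # column-wise checks are meaningless for this row
        for col, (value, size) in enumerate(zip(row, vocab_sizes)):
            if not 0 <= value < size:
                problems.append(
                    f"atom {atom_idx}, column {col}: index {value} not in [0, {size})"
                )
    return problems


# Illustrative vocabularies (assumptions): a 2-column "simplified" scheme
# versus a 9-column OGB-style scheme.
SIMPLE_VOCAB = [120, 3]
OGB_STYLE_VOCAB = [119, 4, 12, 12, 10, 6, 6, 2, 2]

simple_rows = [[5, 0], [7, 0]]  # two atoms featurized with the 2-column scheme
print(check_atom_features(simple_rows, SIMPLE_VOCAB))     # []
print(check_atom_features(simple_rows, OGB_STYLE_VOCAB))  # two column-count errors
```

A model pre-trained on one scheme and fine-tuned on the other would either fail in this way or, worse, pass such a check while reading every index with a different meaning, which is why the same featurizer has to be used in both stages.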

Looking forward to hearing from you soon.

Thanks.

lhkhiem28 · Apr 19 '24 03:04

Hi @lhkhiem28,

Thank you for checking this, and please use the OGB version for the featurization.

We tested both versions, as we did in GraphMVP. I merged the wrong version into the previous code release.

chao1224 · Apr 19 '24 05:04

So the function in utils.py is the correct one? And is it the featurization that the checkpoint here was trained with? https://huggingface.co/chao1224/MoleculeSTM/tree/main/pretrained_MoleculeSTM/SciBERT-Graph-3e-5-1-1e-4-1-InfoNCE-0.1-32-32
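One partial way to check the second question locally, sketched under assumptions: compare the row counts of the checkpoint's input embedding tables with the per-column vocabulary sizes of the featurizer. The parameter-name prefix `atom_encoder.atom_embedding_list` and the shapes below are hypothetical examples, not confirmed for this checkpoint, so inspect the real state-dict keys first. Real shapes could be collected with something like `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}`, assuming the file is a plain state dict.

```python
import re
from typing import Dict, List, Sequence, Tuple


def embedding_rows(shapes: Dict[str, Tuple[int, ...]], prefix: str) -> List[int]:
    """Collect num_embeddings for keys '<prefix>.<i>.weight', ordered by i."""
    pattern = re.compile(re.escape(prefix) + r"\.(\d+)\.weight")
    rows: Dict[int, int] = {}
    for name, shape in shapes.items():
        match = pattern.fullmatch(name)
        if match:
            rows[int(match.group(1))] = shape[0]  # first dim = vocabulary size
    return [rows[i] for i in sorted(rows)]


def matches_featurizer(
    shapes: Dict[str, Tuple[int, ...]], prefix: str, vocab_sizes: Sequence[int]
) -> bool:
    """True if there is exactly one table per feature column, each of the expected size."""
    return embedding_rows(shapes, prefix) == list(vocab_sizes)


# Hypothetical shapes standing in for a loaded checkpoint.
shapes = {
    "atom_encoder.atom_embedding_list.0.weight": (119, 300),
    "atom_encoder.atom_embedding_list.1.weight": (4, 300),
    "gnns.0.mlp.0.weight": (600, 300),
}
print(embedding_rows(shapes, "atom_encoder.atom_embedding_list"))  # [119, 4]
print(matches_featurizer(shapes, "atom_encoder.atom_embedding_list", [119, 4]))  # True
```

Matching sizes is necessary but not sufficient: two featurizers with identical vocabulary sizes could still order categories differently, so confirmation from the maintainers is still needed.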

@chao1224 Can you confirm?

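import numpy as np
import torch
from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector
from torch_geometric.data import Data


def mol_to_graph_data_obj_simple(mol):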
    """ used in MoleculeNetGraphDataset() class
    Converts rdkit mol objects to graph data object in pytorch geometric
    NB: Uses OGB atom and bond features (ogb.utils.features), represented as integer indices
    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr """

    # atoms
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_feature = atom_to_feature_vector(atom)
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # bonds
    if len(mol.GetBonds()) <= 0:  # mol has no bonds
        num_bond_features = 3  # OGB bond features: bond type, stereo, is_conjugated
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)
    else:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = bond_to_feature_vector(bond)

            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    return data
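For readers without RDKit or PyTorch Geometric at hand, the edge layout this function produces can be sketched in plain Python. Bonds are given here as hypothetical `(begin_idx, end_idx, features)` tuples rather than RDKit bond objects, and `bonds_to_coo` is an illustrative name, not part of the repo. The point is that each bond becomes two directed edges with identical features, and a bond-free molecule yields an empty edge index (the torch version above additionally keeps the width as a `[0, 3]` tensor).

```python
from typing import List, Sequence, Tuple

Bond = Tuple[int, int, Sequence[int]]  # (begin atom, end atom, bond features)


def bonds_to_coo(
    bonds: Sequence[Bond], num_bond_features: int = 3
) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (edge_index, edge_attr) as nested lists.

    edge_index has shape [2, 2 * num_bonds]: row 0 holds source atoms and
    row 1 holds target atoms. edge_attr has 2 * num_bonds rows, each of
    length num_bond_features.
    """
    sources: List[int] = []
    targets: List[int] = []
    edge_attr: List[List[int]] = []
    for i, j, feats in bonds:
        if len(feats) != num_bond_features:
            raise ValueError(
                f"bond ({i}, {j}) has {len(feats)} features, expected {num_bond_features}"
            )
        # Store both directions so message passing runs i -> j and j -> i.
        sources += [i, j]
        targets += [j, i]
        edge_attr += [list(feats), list(feats)]
    return [sources, targets], edge_attr


# A three-atom chain 0-1-2 with made-up feature indices.
edge_index, edge_attr = bonds_to_coo([(0, 1, [0, 0, 0]), (1, 2, [0, 0, 0])])
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```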

lhkhiem28 · Apr 19 '24 05:04