torchdrug
[Bug] AttributeError: 'PackedMolecule' object has no attribute 'node_position'
For property prediction, when we replace the model in the tutorial with SchNet and keep the other parameters unchanged, we receive the error below.
Note: I used the Colab version of the tutorial.
```
AttributeError Traceback (most recent call last)
<ipython-input-34-8206e318900d> in <module>()
2 solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
3 gpus=[0], batch_size=1024)
----> 4 solver.train(num_epoch=100)
5 solver.evaluate("valid")

/usr/local/lib/python3.7/dist-packages/torchdrug/core/engine.py in train(self, num_epoch, batch_per_epoch)
153 batch = utils.cuda(batch, device=self.device)
154
--> 155 loss, metric = model(batch)
156 if not loss.requires_grad:
157 raise RuntimeError("Loss doesn't require grad. Did you define any loss in the task?")

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torchdrug/tasks/property_prediction.py in forward(self, batch)
72 metric = {}
73
---> 74 pred = self.predict(batch, all_loss, metric)
75
76 if all([t not in batch for t in self.task]):

/usr/local/lib/python3.7/dist-packages/torchdrug/tasks/property_prediction.py in predict(self, batch, all_loss, metric)
103 def predict(self, batch, all_loss=None, metric=None):
104 graph = batch["graph"]
--> 105 output = self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
106 pred = self.linear(output["graph_feature"])
107 return pred

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torchdrug/models/schnet.py in forward(self, graph, input, all_loss, metric)
67
68 for layer in self.layers:
---> 69 hidden = layer(graph, layer_input)
70 if self.short_cut and hidden.shape == layer_input.shape:
71 hidden = hidden + layer_input

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torchdrug/layers/conv.py in forward(self, graph, input)
89 update = checkpoint.checkpoint(self._message_and_aggregate, *graph.to_tensors(), input)
90 else:
---> 91 update = self.message_and_aggregate(graph, input)
92 output = self.combine(input, update)
93 return output

/usr/local/lib/python3.7/dist-packages/torchdrug/layers/conv.py in message_and_aggregate(self, graph, input)
595 def message_and_aggregate(self, graph, input):
596 node_in, node_out = graph.edge_list.t()[:2]
--> 597 position = graph.node_position
598 rbf_weight = self.rbf_layer(self.rbf(position[node_in], position[node_out]))
599 indices = torch.stack([node_out, node_in, torch.arange(graph.num_edge, device=graph.device)])

AttributeError: 'PackedMolecule' object has no attribute 'node_position'
```
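The `AttributeError` only shows up once the first batch reaches the model inside `solver.train`. A hypothetical pre-flight check like the one below (not part of TorchDrug) could surface the same problem earlier with a clearer message. It relies only on the attribute name the failing frame reads (`graph.node_position`) and on `hasattr`, which returns `False` when the lookup raises `AttributeError`, as it does here.

```python
def require_node_positions(graph):
    """Fail fast if a graph object carries no 3D coordinates.

    Hypothetical helper: it only checks for the attribute that the
    traceback shows the convolution reading (graph.node_position).
    """
    if not hasattr(graph, "node_position"):
        raise ValueError(
            f"{type(graph).__name__} has no node_position; this model needs "
            "3D conformations, so use a dataset that provides node positions."
        )
    return graph
```

For example, it might be called on one sample such as `train_set[0]["graph"]` before `solver.train`, assuming dataset items are dicts with a `"graph"` key (as the `batch["graph"]` line in the `predict` frame suggests) and that an unbatched molecule raises the same `AttributeError` as the `PackedMolecule` above.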
Hi! This error occurs because SchNet requires the 3D conformations of molecules to perform its convolution, but the `datasets.ClinTox` dataset doesn't provide node positions for molecules. You can switch to `datasets.QM9` and pass `node_position=True` when constructing it instead.
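To see why positions are mandatory: the last traceback frame (`message_and_aggregate` in `torchdrug/layers/conv.py`) indexes `graph.node_position` by each edge's endpoints and turns that pair into radial-basis features. Below is a minimal pure-Python sketch of that idea, assuming Gaussian kernels with evenly spaced centers; the function names and the `num_kernel`, `cutoff`, and `gamma` values are illustrative, not TorchDrug's actual API or defaults.

```python
import math


def pairwise_edge_distances(positions, edge_list):
    """Euclidean distance between the two endpoints of every edge.

    positions: one (x, y, z) tuple per node, the role graph.node_position
    plays in the traceback.
    edge_list: (node_in, node_out) index pairs.
    """
    return [
        math.sqrt(sum((a - b) ** 2 for a, b in zip(positions[i], positions[j])))
        for i, j in edge_list
    ]


def gaussian_rbf(distance, centers, gamma):
    """Expand one distance into exp(-gamma * (d - mu)^2), one value per center mu."""
    return [math.exp(-gamma * (distance - mu) ** 2) for mu in centers]


def edge_filters(positions, edge_list, num_kernel=5, cutoff=5.0, gamma=10.0):
    """Distance-based features per edge (hyperparameters are illustrative)."""
    if positions is None:
        # Mirrors the failure in this issue: no coordinates, no distances.
        raise ValueError("continuous-filter convolution needs 3D node positions")
    # Evenly spaced kernel centers on [0, cutoff].
    step = cutoff / (num_kernel - 1)
    centers = [k * step for k in range(num_kernel)]
    return [gaussian_rbf(d, centers, gamma)
            for d in pairwise_edge_distances(positions, edge_list)]


# Usage: three atoms, two bonds.
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
edges = [(0, 1), (0, 2)]
print(pairwise_edge_distances(positions, edges))  # [1.0, 2.0]
print(len(edge_filters(positions, edges)[0]))  # 5
```

Without a coordinate for every node there is no edge distance to expand, so a dataset that carries only the molecular graph cannot feed this layer.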