pytorch_geometric
to_hetero() tries to activate None
🐛 Describe the bug
Converting the following simple model with to_hetero():
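```python
from __future__ import annotations  # ELConfig annotation is not evaluated at runtime

import torch
from torch_geometric.nn import TransformerConv


# The class name is not given in the report; "Model" is a placeholder.
# ELConfig is the reporter's own config class (not shown).
class Model(torch.nn.Module):
    def __init__(self, config: ELConfig):
        super().__init__()
        self.config = config
        self.conv_layers = 2
        # 4 concatenated heads of 256 channels -> 256 * 4 output features
        self.conv1 = TransformerConv(
            256,
            256,
            heads=4,
            dropout=0.6,
            edge_dim=4
        )
        self.conv2 = TransformerConv(
            256 * 4,
            256,
            heads=1,
            dropout=0.6,
            edge_dim=4
        )

    def forward(self, x, edge_index, edge_attr):
        x = self.conv1(x, edge_index, edge_attr).relu()
        x = self.conv2(x, edge_index, edge_attr).relu()
        return x
```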
With metadata: (['category', 'vertex'], [('category', 'assign', 'vertex'), ('vertex', 'groups', 'vertex'), ('vertex', 'continues', 'vertex'), ('vertex', 'keyof', 'vertex')])
Yields the following:
```python
def forward(self, x, edge_index, edge_attr):
    x__category = x.get('category')
    x__vertex = x.get('vertex'); x = None
    edge_index__category__assign__vertex = edge_index.get(('category', 'assign', 'vertex'))
    edge_index__vertex__groups__vertex = edge_index.get(('vertex', 'groups', 'vertex'))
    edge_index__vertex__continues__vertex = edge_index.get(('vertex', 'continues', 'vertex'))
    edge_index__vertex__keyof__vertex = edge_index.get(('vertex', 'keyof', 'vertex')); edge_index = None
    edge_attr__category__assign__vertex = edge_attr.get(('category', 'assign', 'vertex'))
    edge_attr__vertex__groups__vertex = edge_attr.get(('vertex', 'groups', 'vertex'))
    edge_attr__vertex__continues__vertex = edge_attr.get(('vertex', 'continues', 'vertex'))
    edge_attr__vertex__keyof__vertex = edge_attr.get(('vertex', 'keyof', 'vertex')); edge_attr = None
    conv1__vertex1 = self.conv1.category__assign__vertex((x__category, x__vertex), edge_index__category__assign__vertex, edge_attr__category__assign__vertex); x__category = None
    conv1__vertex2 = self.conv1.vertex__groups__vertex((x__vertex, x__vertex), edge_index__vertex__groups__vertex, edge_attr__vertex__groups__vertex)
    conv1__vertex3 = self.conv1.vertex__continues__vertex((x__vertex, x__vertex), edge_index__vertex__continues__vertex, edge_attr__vertex__continues__vertex)
    conv1__vertex4 = self.conv1.vertex__keyof__vertex((x__vertex, x__vertex), edge_index__vertex__keyof__vertex, edge_attr__vertex__keyof__vertex); x__vertex = None
    conv1__vertex5 = torch.add(conv1__vertex1, conv1__vertex2); conv1__vertex1 = conv1__vertex2 = None
    conv1__vertex6 = torch.add(conv1__vertex3, conv1__vertex4); conv1__vertex3 = conv1__vertex4 = None
    conv1__vertex = torch.add(conv1__vertex5, conv1__vertex6); conv1__vertex5 = conv1__vertex6 = None
    relu__category = None.relu()
    relu__vertex = conv1__vertex.relu(); conv1__vertex = None
    conv2__vertex1 = self.conv2.category__assign__vertex((relu__category, relu__vertex), edge_index__category__assign__vertex, edge_attr__category__assign__vertex); relu__category = edge_index__category__assign__vertex = edge_attr__category__assign__vertex = None
    conv2__vertex2 = self.conv2.vertex__groups__vertex((relu__vertex, relu__vertex), edge_index__vertex__groups__vertex, edge_attr__vertex__groups__vertex); edge_index__vertex__groups__vertex = edge_attr__vertex__groups__vertex = None
    conv2__vertex3 = self.conv2.vertex__continues__vertex((relu__vertex, relu__vertex), edge_index__vertex__continues__vertex, edge_attr__vertex__continues__vertex); edge_index__vertex__continues__vertex = edge_attr__vertex__continues__vertex = None
    conv2__vertex4 = self.conv2.vertex__keyof__vertex((relu__vertex, relu__vertex), edge_index__vertex__keyof__vertex, edge_attr__vertex__keyof__vertex); relu__vertex = edge_index__vertex__keyof__vertex = edge_attr__vertex__keyof__vertex = None
    conv2__vertex5 = torch.add(conv2__vertex1, conv2__vertex2); conv2__vertex1 = conv2__vertex2 = None
    conv2__vertex6 = torch.add(conv2__vertex3, conv2__vertex4); conv2__vertex3 = conv2__vertex4 = None
    conv2__vertex = torch.add(conv2__vertex5, conv2__vertex6); conv2__vertex5 = conv2__vertex6 = None
    relu_1__category = None.relu()
    relu_1__vertex = conv2__vertex.relu(); conv2__vertex = None
    return {'category': relu_1__category, 'vertex': relu_1__vertex}
```
The line `relu__category = None.relu()` then crashes with `AttributeError: 'NoneType' object has no attribute 'relu'`.
Environment
- PyG version: 2.0.3
- PyTorch version: 1.10
- OS: Win/Linux
- Python version: 3.8
- CUDA/cuDNN version: 11.3 / 8.1
- How you installed PyTorch and PyG (conda, pip, source): pip
- Any other relevant information (e.g., version of torch-scatter): all
The issue is that there is no edge type pointing to `category`. Upgrading to `torch-geometric==2.0.4` should at least warn you about this.
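In the generated `forward()` above, each node type's output is the sum of the outputs of all edge types whose destination is that node type. A node type that is never a destination has nothing to sum, so its output is `None`. The following plain-Python sketch mirrors that grouping on the metadata from this report. `group_by_destination` is a hypothetical helper, not part of PyG.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]
Metadata = Tuple[List[str], List[EdgeType]]


def group_by_destination(metadata: Metadata) -> Dict[str, List[EdgeType]]:
    """Map every node type to the edge types that deliver messages to it.

    Hypothetical helper mimicking how the to_hetero()-generated forward()
    sums per-edge-type outputs into one result per destination node type.
    """
    node_types, edge_types = metadata
    groups: Dict[str, List[EdgeType]] = {nt: [] for nt in node_types}
    for edge_type in edge_types:
        _src, _rel, dst = edge_type
        groups[dst].append(edge_type)
    return groups


metadata = (
    ['category', 'vertex'],
    [('category', 'assign', 'vertex'), ('vertex', 'groups', 'vertex'),
     ('vertex', 'continues', 'vertex'), ('vertex', 'keyof', 'vertex')],
)

groups = group_by_destination(metadata)
for node_type, incoming in groups.items():
    # An empty list means there is nothing to sum, so the output is None.
    print(node_type, len(incoming))
# category 0
# vertex 4
```

`category` has zero incoming edge types, which is exactly where the generated code falls back to `None.relu()`.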
Figured as much, but I thought code like `None.relu()` was worth reporting.
Yes, definitely :) I think the current workaround of warning the user is okay, but I agree that in your case it should crash before the model is executed. I need to look into `torch.fx` to see if there is some way to check for this.
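Until such a check exists inside `to_hetero()`, one option is to validate the metadata yourself before converting the model. That way the error surfaces at construction time rather than as `None.relu()` during the forward pass. This is a minimal sketch. `check_incoming_edges` is a hypothetical name, and raising `ValueError` is a choice made here, not PyG behavior.

```python
from typing import List, Tuple

EdgeType = Tuple[str, str, str]


def check_incoming_edges(node_types: List[str], edge_types: List[EdgeType]) -> None:
    """Raise if some node type is never the destination of any edge type.

    Hypothetical pre-flight check: such node types would receive no
    messages, so their output after to_hetero() would be None.
    """
    destinations = {dst for _src, _rel, dst in edge_types}
    missing = [nt for nt in node_types if nt not in destinations]
    if missing:
        raise ValueError(
            f"Node types {missing} receive no messages; their output after "
            "to_hetero() would be None. Add reverse edge types, e.g. with "
            "the ToUndirected transform."
        )


node_types = ['category', 'vertex']
edge_types = [('category', 'assign', 'vertex'), ('vertex', 'groups', 'vertex'),
              ('vertex', 'continues', 'vertex'), ('vertex', 'keyof', 'vertex')]

try:
    check_incoming_edges(node_types, edge_types)
except ValueError as err:
    print(err)  # reports that 'category' receives no messages
```

With a `HeteroData` object, you could call `check_incoming_edges(*data.metadata())` right before `to_hetero(model, data.metadata())`, since `metadata()` returns the `(node_types, edge_types)` pair.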
So every node type in the heterogeneous graph needs an edge type pointing to it? How should we resolve this? I tried converting the graph to undirected, but it didn't work. Should we add the reverse edge types to our data ourselves?
Yes, otherwise certain node types will not get updated during message passing. The `ToUndirected` transform should take care of that. Let me know if you run into any issues with it.
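The `ToUndirected` transform (`torch_geometric.transforms.ToUndirected`) handles each edge type `(src, rel, dst)` in one of two ways:

- If it connects two different node types, the transform adds a reverse edge type `(dst, 'rev_' + rel, src)`.
- With the default `merge=True`, edge types whose source and destination match are instead made undirected in place.

The sketch below reproduces only the resulting metadata change, to show that `category` gains an incoming edge type. `reverse_metadata` is a hypothetical helper, and the `rev_` prefix is assumed to match the transform's default naming.

```python
from typing import List, Tuple

EdgeType = Tuple[str, str, str]


def reverse_metadata(
    node_types: List[str], edge_types: List[EdgeType]
) -> Tuple[List[str], List[EdgeType]]:
    """Mimic the edge types ToUndirected() (merge=True) would produce.

    Hypothetical helper: edge types between two different node types get a
    reverse type (dst, 'rev_' + rel, src); same-type edge types are kept
    unchanged here (the real transform symmetrizes their edge_index).
    """
    new_edge_types = list(edge_types)
    for src, rel, dst in edge_types:
        if src != dst:
            new_edge_types.append((dst, 'rev_' + rel, src))
    return node_types, new_edge_types


node_types, edge_types = reverse_metadata(
    ['category', 'vertex'],
    [('category', 'assign', 'vertex'), ('vertex', 'groups', 'vertex'),
     ('vertex', 'continues', 'vertex'), ('vertex', 'keyof', 'vertex')],
)
print(edge_types[-1])  # ('vertex', 'rev_assign', 'category')
```

In practice you would apply the real transform to the data object, e.g. `data = T.ToUndirected()(data)` with `import torch_geometric.transforms as T`, and then pass the new `data.metadata()` to `to_hetero()`.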