to_hetero() tries to activate None

Open fierval opened this issue 2 years ago • 5 comments

🐛 Describe the bug

Converting the following simple model with to_hetero():

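import torch
from torch_geometric.nn import TransformerConv


# The original snippet started at __init__, so this class header is
# reconstructed and the name `Model` is a placeholder. `ELConfig` is the
# reporter's own config class and is not shown in the issue.
class Model(torch.nn.Module):
    def __init__(self, config: ELConfig):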
        super().__init__()

        self.config = config
        self.conv_layers = 2

        self.conv1 = TransformerConv(
            256,
            256,
            heads=4,
            dropout=0.6,
            edge_dim=4
        )

        self.conv2 = TransformerConv(
            256 * 4,
            256,
            heads=1,
            dropout=0.6,
            edge_dim=4
        )

    def forward(self, x, edge_index, edge_attr):

        x = self.conv1(x, edge_index, edge_attr).relu()
        x = self.conv2(x, edge_index, edge_attr).relu()
        
        return x

With metadata: (['category', 'vertex'], [('category', 'assign', 'vertex'), ('vertex', 'groups', 'vertex'), ('vertex', 'continues', 'vertex'), ('vertex', 'keyof', 'vertex')])

Yields the following:

def forward(self, x, edge_index, edge_attr):
    x__category = x.get('category')
    x__vertex = x.get('vertex');  x = None
    edge_index__category__assign__vertex = edge_index.get(('category', 'assign', 'vertex'))
    edge_index__vertex__groups__vertex = edge_index.get(('vertex', 'groups', 'vertex'))
    edge_index__vertex__continues__vertex = edge_index.get(('vertex', 'continues', 'vertex'))
    edge_index__vertex__keyof__vertex = edge_index.get(('vertex', 'keyof', 'vertex'));  edge_index = None
    edge_attr__category__assign__vertex = edge_attr.get(('category', 'assign', 'vertex'))
    edge_attr__vertex__groups__vertex = edge_attr.get(('vertex', 'groups', 'vertex'))
    edge_attr__vertex__continues__vertex = edge_attr.get(('vertex', 'continues', 'vertex'))
    edge_attr__vertex__keyof__vertex = edge_attr.get(('vertex', 'keyof', 'vertex'));  edge_attr = None
    conv1__vertex1 = self.conv1.category__assign__vertex((x__category, x__vertex), edge_index__category__assign__vertex, edge_attr__category__assign__vertex);  x__category = None
    conv1__vertex2 = self.conv1.vertex__groups__vertex((x__vertex, x__vertex), edge_index__vertex__groups__vertex, edge_attr__vertex__groups__vertex)
    conv1__vertex3 = self.conv1.vertex__continues__vertex((x__vertex, x__vertex), edge_index__vertex__continues__vertex, edge_attr__vertex__continues__vertex)
    conv1__vertex4 = self.conv1.vertex__keyof__vertex((x__vertex, x__vertex), edge_index__vertex__keyof__vertex, edge_attr__vertex__keyof__vertex);  x__vertex = None
    conv1__vertex5 = torch.add(conv1__vertex1, conv1__vertex2);  conv1__vertex1 = conv1__vertex2 = None
    conv1__vertex6 = torch.add(conv1__vertex3, conv1__vertex4);  conv1__vertex3 = conv1__vertex4 = None
    conv1__vertex = torch.add(conv1__vertex5, conv1__vertex6);  conv1__vertex5 = conv1__vertex6 = None
    relu__category = None.relu()
    relu__vertex = conv1__vertex.relu();  conv1__vertex = None
    conv2__vertex1 = self.conv2.category__assign__vertex((relu__category, relu__vertex), edge_index__category__assign__vertex, edge_attr__category__assign__vertex);  relu__category = edge_index__category__assign__vertex = edge_attr__category__assign__vertex = None
    conv2__vertex2 = self.conv2.vertex__groups__vertex((relu__vertex, relu__vertex), edge_index__vertex__groups__vertex, edge_attr__vertex__groups__vertex);  edge_index__vertex__groups__vertex = edge_attr__vertex__groups__vertex = None
    conv2__vertex3 = self.conv2.vertex__continues__vertex((relu__vertex, relu__vertex), edge_index__vertex__continues__vertex, edge_attr__vertex__continues__vertex);  edge_index__vertex__continues__vertex = edge_attr__vertex__continues__vertex = None
    conv2__vertex4 = self.conv2.vertex__keyof__vertex((relu__vertex, relu__vertex), edge_index__vertex__keyof__vertex, edge_attr__vertex__keyof__vertex);  relu__vertex = edge_index__vertex__keyof__vertex = edge_attr__vertex__keyof__vertex = None
    conv2__vertex5 = torch.add(conv2__vertex1, conv2__vertex2);  conv2__vertex1 = conv2__vertex2 = None
    conv2__vertex6 = torch.add(conv2__vertex3, conv2__vertex4);  conv2__vertex3 = conv2__vertex4 = None
    conv2__vertex = torch.add(conv2__vertex5, conv2__vertex6);  conv2__vertex5 = conv2__vertex6 = None
    relu_1__category = None.relu()
    relu_1__vertex = conv2__vertex.relu();  conv2__vertex = None
    return {'category': relu_1__category, 'vertex': relu_1__vertex}

The generated line relu__category = None.relu() then causes a crash, because no convolution ever produces an output for category.

Environment

  • PyG version: 2.0.3
  • PyTorch version: 1.10
  • OS: Win/Linux
  • Python version: 3.8
  • CUDA/cuDNN version: 11.3 / 8.1
  • How you installed PyTorch and PyG (conda, pip, source): pip
  • Any other relevant information (e.g., version of torch-scatter): all

fierval • Apr 03 '22 05:04

The issue is that there is no edge type pointing to category. Upgrading to torch-geometric==2.0.4 should at least warn you about this.

rusty1s • Apr 03 '22 14:04
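The diagnosis above can be checked directly from the metadata tuple before calling to_hetero(). The following is a minimal sketch in plain Python, not part of PyG; the helper name node_types_without_incoming_edges is hypothetical and only the metadata values are taken from the report.

```python
from typing import List, Tuple

EdgeType = Tuple[str, str, str]  # (source node type, relation, destination node type)


def node_types_without_incoming_edges(
    node_types: List[str], edge_types: List[EdgeType]
) -> List[str]:
    """Return node types that never appear as the destination of an edge type.

    Hypothetical helper: such node types receive no messages, so a converted
    model has nothing to compute for them.
    """
    destinations = {dst for _, _, dst in edge_types}
    return [node_type for node_type in node_types if node_type not in destinations]


# Metadata copied from the report.
node_types = ['category', 'vertex']
edge_types = [
    ('category', 'assign', 'vertex'),
    ('vertex', 'groups', 'vertex'),
    ('vertex', 'continues', 'vertex'),
    ('vertex', 'keyof', 'vertex'),
]

missing = node_types_without_incoming_edges(node_types, edge_types)
print(missing)  # ['category']
```

A non-empty result means that at least one node type is only ever a source, which matches the None.relu() line in the generated forward().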

Figured as much but thought code like None.relu() was worth reporting.

fierval • Apr 03 '22 15:04

Yes, definitely :) I think the current workaround of warning the user is okay, but I agree that in your case it should definitely crash prior to model execution. I need to look into torch.fx to see if there is some way to check for this.

rusty1s • Apr 04 '22 08:04

So every node type in the heterogeneous graph needs an edge type pointing to it? How should we resolve this problem? I tried making the graph undirected, but it didn't work. Should we add the reverse edge types to our data ourselves?

stayones • Jun 27 '22 09:06

Yes, otherwise certain node types will not get properly updated during message passing. The ToUndirected transform should take care of that. Let me know if you encounter any issues with that.

rusty1s • Jun 28 '22 09:06
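To illustrate why reverse edge types resolve the problem, here is a plain-Python sketch that works on the metadata only and does not call PyG. The helper add_reverse_edge_types is hypothetical. The 'rev_' prefix is an assumption that mirrors how ToUndirected appears to name reverse relations on heterogeneous data, so check the names your PyG version actually produces. The metadata passed to to_hetero() presumably also has to be read from the transformed data for the new edge types to be picked up.

```python
from typing import List, Tuple

EdgeType = Tuple[str, str, str]  # (source node type, relation, destination node type)


def add_reverse_edge_types(edge_types: List[EdgeType]) -> List[EdgeType]:
    """Append a (dst, 'rev_<rel>', src) edge type for every edge type with src != dst.

    Hypothetical helper; the 'rev_' naming is an assumption, not a PyG guarantee.
    Edge types whose source and destination match need no extra entry here.
    """
    result = list(edge_types)
    for src, rel, dst in edge_types:
        reverse = (dst, f'rev_{rel}', src)
        if src != dst and reverse not in result:
            result.append(reverse)
    return result


# Edge types copied from the report.
edge_types = [
    ('category', 'assign', 'vertex'),
    ('vertex', 'groups', 'vertex'),
    ('vertex', 'continues', 'vertex'),
    ('vertex', 'keyof', 'vertex'),
]

augmented = add_reverse_edge_types(edge_types)
destinations = {dst for _, _, dst in augmented}

print(augmented[-1])               # ('vertex', 'rev_assign', 'category')
print('category' in destinations)  # True
```

With the reverse edge type in place, category becomes a destination and would receive its own convolution output instead of None.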