pytorch_geometric

Cannot convert hetero GNN model into torchscript

Open · stephenzwj opened this issue 3 years ago · 11 comments

🐛 Describe the bug

Is there a way to use to_hetero(model, metadata(), aggr='sum') and convert the model to TorchScript? I tried the following code, but it does not work.

import torch
import torch.nn.functional as F
from torch import tensor
from torch_geometric.nn import SAGEConv, to_hetero
metadata = (['pax', 'card', 'bank'],
 [('pax', 'topup', 'card'),
  ('pax', 'topup', 'bank'),
  ('pax', 'link', 'bank'),
  ('card', 'rev_topup', 'pax'),
  ('bank', 'rev_topup', 'pax'),
  ('bank', 'rev_link', 'pax')
  ]
  )
x_dict = {'pax': tensor([[2., 0.]]),
            'card': tensor([[2., 0.]]),
            'bank': tensor([[2., 0.]])}
edge_index_dict = {('pax', 'topup', 'card'): tensor([[],[]],dtype=torch.int64),
                    ('pax', 'topup', 'bank'): tensor([[],[]],dtype=torch.int64),
                    ('pax', 'link', 'bank'): tensor([[],[]],dtype=torch.int64),
                    ('card', 'rev_topup', 'pax'): tensor([[0],
                            [0]]),
                    ('bank', 'rev_topup', 'pax'): tensor([[0],
                            [0]]),
                    ('bank', 'rev_link', 'pax'): tensor([[0],
                            [0]])
                    }
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels).jittable()
        self.conv2 = SAGEConv((-1, -1), out_channels).jittable()
    def forward(self, x, edge_index): 
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x
model = GNN(hidden_channels=64, out_channels=2)
model = to_hetero(model, metadata, aggr='sum')
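model(x_dict, edge_index_dict)  # eager forward pass works and initializes the lazy (-1) input sizes
model = torch.jit.script(model)  # fails with the RuntimeError below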

The error is

RuntimeError: 
'Tensor (inferred)' object has no attribute or method 'get'.:
  File "<eval_with_key>.3", line 5
def forward(self, x, edge_index):
    x__pax = x.get('pax')
             ~~~~~ <--- HERE
    x__card = x.get('card')
    x__bank = x.get('bank');  x = None

Environment

  • PyG version: 2.0.4
  • PyTorch version: 1.10.0
  • OS: macOS Big Sur
  • Python version: 3.7.11
  • CUDA/cuDNN version: cpu
  • How you installed PyTorch and PyG (conda, pip, source): pip
  • Any other relevant information (e.g., version of torch-scatter): torch-scatter 2.0.9, torch-sparse 0.6.13

stephenzwj · Mar 13 '22

You are right that heterogeneous GNN models currently do not support TorchScript. This is mostly due to a limitation of TorchScript: tuples cannot be used as dictionary keys. Here is a simple example to reproduce:

from typing import Dict, Tuple

import torch
from torch import Tensor


class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        x_dict: Dict[str, Tensor],
        edge_index_dict: Dict[Tuple[str, str, str], Tensor],
    ) -> Dict[str, Tensor]:
        return x_dict


model = GNN()
model = torch.jit.script(model)
RuntimeError: Cannot create dict for key type '(str, str, str)', only int, float, complex, Tensor, device and string keys are supported

We would need to wait for TorchScript to catch up, or implement some custom logic to represent edge types as single strings rather than tuples.
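A minimal sketch of the "single string" idea in plain Python, assuming the `{src}__{rel}__{dst}` encoding that is used later in this thread. The helper names (`edge_type_to_str`, `str_to_edge_type`, `to_str_keys`) are hypothetical, not PyG API. The round-trip check rejects names (for example ones containing `__` or starting/ending with `_`) that would not split back into the same triplet.

```python
from typing import Dict, Tuple, TypeVar

EdgeType = Tuple[str, str, str]
V = TypeVar("V")

# Separator for the "{src}__{rel}__{dst}" string encoding (an assumption).
SEP = "__"


def edge_type_to_str(edge_type: EdgeType) -> str:
    """Join a (src, rel, dst) triplet into a single string key."""
    key = SEP.join(edge_type)
    # Refuse triplets that would not split back into the same three names.
    if len(edge_type) != 3 or tuple(key.split(SEP)) != tuple(edge_type):
        raise ValueError(f"edge type {edge_type!r} cannot be encoded with {SEP!r}")
    return key


def str_to_edge_type(key: str) -> EdgeType:
    """Split a 'src__rel__dst' key back into its triplet."""
    parts = key.split(SEP)
    if len(parts) != 3:
        raise ValueError(f"expected 'src{SEP}rel{SEP}dst', got {key!r}")
    return parts[0], parts[1], parts[2]


def to_str_keys(edge_dict: Dict[EdgeType, V]) -> Dict[str, V]:
    """Re-key a tuple-keyed dict (e.g. edge_index_dict) with string keys."""
    return {edge_type_to_str(k): v for k, v in edge_dict.items()}
```

The resulting `Dict[str, Tensor]` is a type TorchScript accepts, whereas the tuple-keyed original is not.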

rusty1s · Mar 13 '22

Hi @rusty1s, thank you so much for the clarification. Would it be difficult for users to implement custom logic that represents edge types as single strings rather than tuples? If not, I can probably try it on my end.

stephenzwj · Mar 13 '22

I'm not sure yet, TBH. It sounds pretty complex to me. I think we could start by making HeteroConv jittable, which seems easier than doing it for to_hetero(). Given that one can represent an edge type (src_node_type, rel_type, dst_node_type) as a string {src_node_type}__{rel_type}__{dst_node_type}, one could achieve this as follows:

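def forward(self, x_dict: Dict[str, Tensor], edge_index_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    out_dict: Dict[str, List[Tensor]] = {}
    for edge_type, edge_index in edge_index_dict.items():
        src_node_type, _, dst_node_type = edge_type.split('__')
        out = self.convs[edge_type]((x_dict[src_node_type], x_dict[dst_node_type]), edge_index)
        out_dict.setdefault(dst_node_type, []).append(out)  # group outputs by destination type
    # Sum the per-edge-type outputs for each destination node type.
    return {key: torch.stack(outs, dim=0).sum(dim=0) for key, outs in out_dict.items()}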
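The sketch above looks up each layer with the runtime key `edge_type`. As far as I understand, TorchScript only allows indexing an `nn.ModuleDict` with string literals, so a scriptable variant would more likely iterate over the `ModuleDict` itself (TorchScript unrolls that loop, which also rules out `continue` inside it). Below is a minimal, self-contained sketch of that pattern; `StringKeyedHeteroLayer` is a hypothetical name, and an `nn.Linear` plus `index_add` stands in for a PyG conv so the example runs without PyG.

```python
from typing import Dict, List

import torch
from torch import Tensor


class StringKeyedHeteroLayer(torch.nn.Module):
    """Hypothetical minimal hetero layer keyed by 'src__rel__dst' strings."""

    def __init__(self, edge_types: List[str], in_channels: int, out_channels: int):
        super().__init__()
        # One linear message function per edge type (stand-in for a PyG conv).
        self.lins = torch.nn.ModuleDict(
            {et: torch.nn.Linear(in_channels, out_channels) for et in edge_types}
        )
        self.out_channels = out_channels

    def forward(
        self, x_dict: Dict[str, Tensor], edge_index_dict: Dict[str, Tensor]
    ) -> Dict[str, Tensor]:
        out_dict: Dict[str, Tensor] = {}
        # TorchScript unrolls this loop, so each submodule is accessed statically.
        for edge_type, lin in self.lins.items():
            if edge_type in edge_index_dict:  # an `if` instead of `continue`
                parts = edge_type.split("__")
                src, dst = parts[0], parts[2]
                edge_index = edge_index_dict[edge_type]
                # Transform source features, then gather one message per edge.
                msg = lin(x_dict[src])[edge_index[0]]
                out = torch.zeros(
                    [x_dict[dst].size(0), self.out_channels],
                    dtype=msg.dtype,
                    device=msg.device,
                )
                # Sum the messages arriving at each destination node.
                out = out.index_add(0, edge_index[1], msg)
                # Sum across edge types that share a destination node type.
                if dst in out_dict:
                    out_dict[dst] = out_dict[dst] + out
                else:
                    out_dict[dst] = out
        return out_dict
```

Usage: `torch.jit.script(StringKeyedHeteroLayer(["pax__topup__card", "card__rev_topup__pax"], 4, 8))` compiles, and the scripted module takes string-keyed dictionaries as inputs.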

rusty1s · Mar 13 '22

True, the HeteroConv layer could be a quick start. Thank you so much for the hints. Is there any plan to officially support this in PyG?

stephenzwj · Mar 13 '22

I think so. We definitely want to make further improvements to our heterogeneous GNN pipeline. Any help is highly appreciated, so let me know if you have any success converting HeteroConv into a jittable instance.

rusty1s · Mar 14 '22

I wrote the following JittableHeteroConv, and it works on my side when using torch.jit.trace instead of torch.jit.script. I also managed to deploy it using TorchServe.

The code is as follows:

from typing import Dict

import torch
from torch import Tensor, tensor
from torch_geometric.nn import HeteroConv, SAGEConv
from torch_geometric.nn.conv.hgt_conv import group
class JittableHeteroConv(HeteroConv):
    def forward(
        self, 
        x_dict: Dict[str, Tensor], 
        edge_index_dict: Dict[str, Tensor],
        *args_dict,
        **kwargs_dict,
        ) -> Dict[str, Tensor]:
        r"""
        Args:
            x_dict (Dict[str, Tensor]): A dictionary holding node feature
                information for each individual node type.
            edge_index_dict (Dict[str, Tensor]): A dictionary
                holding graph connectivity information for each individual
                edge type.
            *args_dict (optional): Additional forward arguments of individual
                :class:`torch_geometric.nn.conv.MessagePassing` layers.
            **kwargs_dict (optional): Additional forward arguments of
                individual :class:`torch_geometric.nn.conv.MessagePassing`
                layers.
                For example, if a specific GNN layer at edge type
                :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a
                forward argument, then you can pass them to
                :meth:`~torch_geometric.nn.conv.HeteroConv.forward` via
                :obj:`edge_attr_dict = { edge_type: edge_attr }`.
        """
        out_dict = {}
        for edge_type, edge_index in edge_index_dict.items():
            src, rel, dst = edge_type.split('__')
            str_edge_type = '__'.join((src, rel, dst))
            if str_edge_type not in self.convs:
                continue
            args = []
            for value_dict in args_dict:
                if edge_type in value_dict:
                    args.append(value_dict[edge_type])
                elif src == dst and src in value_dict:
                    args.append(value_dict[src])
                elif src in value_dict or dst in value_dict:
                    args.append(
                        (value_dict.get(src, None), value_dict.get(dst, None)))
            kwargs = {}
            for arg, value_dict in kwargs_dict.items():
                arg = arg[:-5]  # `{*}_dict`
                if edge_type in value_dict:
                    kwargs[arg] = value_dict[edge_type]
                elif src == dst and src in value_dict:
                    kwargs[arg] = value_dict[src]
                elif src in value_dict or dst in value_dict:
                    kwargs[arg] = (value_dict.get(src, None),
                                   value_dict.get(dst, None))
            conv = self.convs[str_edge_type]
            if src == dst:
                out = conv(x_dict[src], edge_index, *args, **kwargs)
            else:
                out = conv((x_dict[src], x_dict[dst]), edge_index, *args, **kwargs)
            if dst not in out_dict.keys():
                out_dict[dst] = []
            out_dict[dst].append(out)
        for key, value in out_dict.items():
            out_dict[key] = group(value, self.aggr)
        return out_dict
metadata = (['pax', 'card', 'bank'],
 [('pax', 'topup', 'card'),
  ('pax', 'topup', 'bank'),
  ('pax', 'link', 'bank'),
  ('card', 'rev_topup', 'pax'),
  ('bank', 'rev_topup', 'pax'),
  ('bank', 'rev_link', 'pax')
  ]
  )
x_dict = {'pax': tensor([[2., 0.]]),
            'card': tensor([[2., 0.]]),
            'bank': tensor([[2., 0.]])}
edge_index_dict = {('pax', 'topup', 'card'): tensor([[0],[0]],dtype=torch.int64),
                    ('pax', 'topup', 'bank'): tensor([[0],[0]],dtype=torch.int64),
                    ('pax', 'link', 'bank'): tensor([[0],[0]],dtype=torch.int64),
                    ('card', 'rev_topup', 'pax'): tensor([[0],
                            [0]]),
                    ('bank', 'rev_topup', 'pax'): tensor([[0],
                            [0]]),
                    ('bank', 'rev_link', 'pax'): tensor([[0],
                            [0]])
                    }
str_edge_index_dict = {'__'.join(key): value for key, value in edge_index_dict.items()}
class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = JittableHeteroConv({
                ('pax', 'topup', 'card'): SAGEConv(-1, hidden_channels).jittable(),
                ('pax', 'topup', 'bank'): SAGEConv((-1, -1), hidden_channels).jittable(),
                ('pax', 'link', 'bank'): SAGEConv((-1, -1), hidden_channels).jittable(),
                ('card', 'rev_topup', 'pax'): SAGEConv(-1, hidden_channels).jittable(),
                ('bank', 'rev_topup', 'pax'): SAGEConv((-1, -1), hidden_channels).jittable(),
                ('bank', 'rev_link', 'pax'): SAGEConv((-1, -1), hidden_channels).jittable(),
            }, aggr='sum')
            self.convs.append(conv)
    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        return x_dict['pax']
model = HeteroGNN(hidden_channels=64, out_channels=64,
                  num_layers=1)
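# Dry run: initializes the lazy (-1) input channels before tracing.
out = model(x_dict, str_edge_index_dict)
model.eval()
# Tracing records the ops executed for these example inputs (string edge-type keys).
traced_model = torch.jit.trace(model, example_inputs=(x_dict, str_edge_index_dict))
traced_model.save('tmp_models/hgnn_model.pt')  # the tmp_models/ directory must already exist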

stephenzwj · Mar 14 '22

Thank you! This looks great. I'm happy that it unblocks you for now. The main question is how we want to integrate this into the library. IMO, HeteroConv should expose its own jittable() functionality, which takes care of converting any PyG GNN layer to a jittable instance. It should also support both triplets and single strings as edge types. Do you want to make a first draft of this?
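One way such a helper might accept both key styles is to normalize user-facing dictionaries on the eager side, before they reach a scripted or traced module that only understands string keys. The sketch below is hypothetical (`normalize_edge_keys` is not a PyG function) and covers only the key handling, not the conversion of the individual layers.

```python
from typing import Dict, Tuple, TypeVar, Union

V = TypeVar("V")
EdgeKey = Union[str, Tuple[str, str, str]]


def normalize_edge_keys(edge_dict: Dict[EdgeKey, V], sep: str = "__") -> Dict[str, V]:
    """Map (src, rel, dst) triplets and 'src__rel__dst' strings to string keys."""
    out: Dict[str, V] = {}
    for key, value in edge_dict.items():
        if isinstance(key, tuple):
            str_key = sep.join(key)
            # Reject triplets that would not split back into the same names.
            if len(key) != 3 or tuple(str_key.split(sep)) != key:
                raise ValueError(f"invalid edge type triplet {key!r}")
        else:
            str_key = key
            if len(str_key.split(sep)) != 3:
                raise ValueError(f"cannot split {str_key!r} into (src, rel, dst)")
        if str_key in out:
            # The same edge type was given both as a triplet and as a string.
            raise ValueError(f"duplicate edge type {str_key!r}")
        out[str_key] = value
    return out
```

Keeping the conversion outside the compiled module sidesteps the TorchScript restriction on tuple dictionary keys entirely.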

rusty1s · Mar 15 '22

Thank you for the comment. There are a few quick workarounds:

  1. HeteroConv.jittable() can return an instance of JittableHeteroConv.
  2. Any PyG GNN layer can be converted to its jittable form in JittableHeteroConv.__init__().

Supporting both triplets and single strings as edge types does not seem easy, due to the lack of support in TorchScript.

stephenzwj · Mar 22 '22

Yeah, you are right. HeteroConv.jittable() would need to make each sub-GNN jittable, as well as convert edge types to single strings. Let me know if you are interested in contributing this!

rusty1s · Mar 23 '22

Hey @rusty1s, has there been any progress on making HeteroConv jittable?

rishi2019194 · Jun 13 '24

It might well be that this issue is already resolved (at least HeteroConv works fine with torch.compile now).

rusty1s · Jun 24 '24