
dgl.to_homogeneous does not support DGLBlock MFGs

Open · ayushnoori opened this issue on Sep 06, 2023

🐛 Bug

I am attempting to learn on a heterogeneous DGLGraph created by dgl.heterograph, with node features. I am using an edge-wise sampler built from dgl.dataloading.NeighborSampler, dgl.dataloading.negative_sampler, and dgl.dataloading.as_edge_prediction_sampler, which I then pass to dgl.dataloading.DataLoader to create a DataLoader iterable. For example, please see below:

# define node samplers
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)
neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)

# convert to edge sampler
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler,
    exclude = "reverse_types", # exclude reverse edges
    reverse_etypes = reverse_edge_dict, # define reverse edge types
    negative_sampler = neg_sampler)

# define training dataloader
train_dataloader = dgl.dataloading.DataLoader(
    train_graph, train_eids, sampler,
    batch_size = 128, # for example
    shuffle = True,
    drop_last = False,
    num_workers = 0)

Iterating over the DataLoader yields mini-batches; the last element of each batch is a list of DGLBlock message flow graph (MFG) objects.

for input_nodes, pos_graph, neg_graph, blocks in train_dataloader:

    # training loop code here
    break

Within each batch, I need to convert my heterograph to a homogeneous graph (e.g., to pass to the HGTConv convolutional layer). However, when I attempt to convert the first DGLBlock to a homogeneous graph and concatenate the node features, I receive the error below.

dgl.to_homogeneous(blocks[0], ndata = ['node_index'])
> DGLError: Expect number of features to match number of nodes (len(u)). Got 81810 and 113708 instead.

Discussion with @jermainewang about this error:

DGLBlock is designed to represent a bipartite computation graph. We reuse the heterogeneous graph data structure because one can view it as a graph with two groups of nodes (i.e., source and destination nodes). However, dgl.to_homogeneous will try to merge the source and destination nodes into one set of nodes (which is the definition of to_homogeneous), but that will break the integrity of the data structure as a bipartite computation graph... I will talk to the team to see if we can quickly patch dgl.to_homogeneous to support MFGs too.
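
This bipartite structure can be checked directly on a sampled block; a minimal sketch, reusing blocks from the loop above:

# A DGLBlock is bipartite: its destination nodes form a subset of its
# source nodes, so the per-type counts generally differ. Summed over
# node types, this is consistent with the 81810 vs. 113708 mismatch
# in the error above.
block = blocks[0]
for ntype in block.srctypes:
    n_dst = block.num_dst_nodes(ntype) if ntype in block.dsttypes else 0
    print(ntype, block.num_src_nodes(ntype), n_dst)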

Creating this issue to track progress on patching dgl.to_homogeneous to support MFGs.

cc: @jermainewang @mufeili, many thanks for your help!

ayushnoori avatar Sep 06 '23 12:09 ayushnoori

@ayushnoori Check out this implementation. It may not handle all the corner cases but the general idea should apply. cc @czkkkkkk

import dgl
import numpy as np
import torch
import torch.nn as nn
from typing import Dict

from dgl.nn import HGTConv


def block_to_homogeneous(block, source_ndata, source_ndata_name):
    """Convert a DGLBlock to a homogeneous graph.

    This function performs a two-step process to convert a given block into a
    homogeneous graph. It first transforms the block into a heterogeneous graph,
    treating its source nodes as heterogeneous nodes. Then, this heterogeneous
    graph is further converted to a homogeneous graph. The function returns
    the homogeneous graph, the mapping from node type to node type ID, the
    node type count offsets, and the number of destination nodes for each
    node type. Notably, the `source_ndata` of the original block is retained
    and stored under the `source_ndata_name` key in the homogeneous graph's
    node data.

    Parameters
    ----------
    block : dgl.Block
        The block to convert.
    source_ndata : dict[str, Tensor]
        The source node features of the block.
    source_ndata_name : str
        The name of the source node features in the homogeneous graph.
    
    Returns
    -------
    DGLGraph
        The converted homogeneous graph.
    dict[str, int]
        The mapping from node type to node type ID.
    list[int]
        The node type count offsets.
    dict[str, int]
        The number of destination nodes for each node type.
    """
    num_dst_nodes_dict = {}
    num_src_nodes_dict = {}
    for ntype in block.dsttypes:
        num_dst_nodes_dict[ntype] = block.number_of_dst_nodes(ntype)
    for ntype in block.srctypes:
        num_src_nodes_dict[ntype] = block.number_of_src_nodes(ntype)

    hetero_edges = {}
    for srctype, etype, dsttype in block.canonical_etypes:
        # Use the canonical etype to avoid ambiguity when the same edge
        # type name connects different node type pairs.
        src, dst = block.all_edges(etype=(srctype, etype, dsttype), order="eid")
        hetero_edges[(srctype, etype, dsttype)] = (src, dst)
    hetero_g = dgl.heterograph(
        hetero_edges,
        num_nodes_dict=num_src_nodes_dict,
        idtype=block.idtype,
        device=block.device,
    )
    ntype_to_id = {ntype: hetero_g.get_ntype_id(ntype) for ntype in hetero_g.ntypes}
    hetero_g.ndata[source_ndata_name] = source_ndata
    homo_g, ntype_counts, _ = dgl.to_homogeneous(
        hetero_g, ndata=[source_ndata_name], return_count=True
    )
    ntype_offset = np.insert(np.cumsum(ntype_counts), 0, 0)
    return homo_g, ntype_to_id, ntype_offset, num_dst_nodes_dict

class HGTGNN(nn.Module):
    def __init__(
        self,
        in_size: int,
        hid_size: int,
        num_heads: int,
        num_ntypes: int,
        num_etypes: int,
        num_layers: int,
    ):
        super().__init__()
        self.in_size = in_size
        self.out_size = hid_size
        self.conv_layers = nn.ModuleList()
        self.conv_layers.append(
            HGTConv(
                in_size,
                hid_size // num_heads,
                num_heads,
                num_ntypes,
                num_etypes,
            )
        )
        for _ in range(num_layers - 1):
            self.conv_layers.append(
                HGTConv(
                    hid_size,
                    hid_size // num_heads,
                    num_heads,
                    num_ntypes,
                    num_etypes,
                )
            )

    def forward(
        self,
        mfgs,
        X_node_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        assert len(mfgs) == len(self.conv_layers)
        for ntype, X in X_node_dict.items():
            assert X.shape[1] == self.in_size
        H_node_dict = X_node_dict
        for i, conv in enumerate(self.conv_layers):
            # There are several steps to conduct HGT on a DGLBlock.
            # 1. Convert the DGLBlock to a homogeneous graph. The
            # homogeneous graph is based on the source nodes of the DGLBlock.
            # 2. Run HGTConv on the homogeneous graph to compute homogeneous
            # features.
            # 3. Extract the heterogeneous features of the destination nodes
            # from the homogeneous features.
            homo_ndata_name = "x"
            (
                homo_g,
                ntype_to_id,
                ntype_offset,
                num_dst_nodes_dict,
            ) = block_to_homogeneous(mfgs[i], H_node_dict, homo_ndata_name)
            # Run hgtconv on the homogeneous graph.
            homo_features = conv(
                homo_g,
                homo_g.ndata[homo_ndata_name],
                homo_g.ndata[dgl.NTYPE],
                homo_g.edata[dgl.ETYPE],
            )
            # Convert the output features back to a dict.
            dst_features = {}
            for ntype in mfgs[i].dsttypes:
                ntype_id = ntype_to_id[ntype]
                feature = homo_features[
                    ntype_offset[ntype_id] : ntype_offset[ntype_id + 1]
                ]
                dst_features[ntype] = feature[: num_dst_nodes_dict[ntype]]
            H_node_dict = dst_features
        return H_node_dict
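
For reference, a minimal sketch of how this helper class might be driven from the training loop in the original report; the feature key "feat" and the size/head hyperparameters are illustrative assumptions, and train_graph / train_dataloader come from the snippets above:

# Hypothetical wiring into the earlier training loop; hyperparameters
# and the "feat" feature key are illustrative assumptions.
model = HGTGNN(
    in_size=64,
    hid_size=128,
    num_heads=4,
    num_ntypes=len(train_graph.ntypes),
    num_etypes=len(train_graph.canonical_etypes),
    num_layers=3,  # must match the number of MFGs produced by the sampler
)

for input_nodes, pos_graph, neg_graph, blocks in train_dataloader:
    # Source-node input features for the first (outermost) block.
    X_node_dict = {
        ntype: blocks[0].srcnodes[ntype].data["feat"]
        for ntype in blocks[0].srctypes
    }
    # Returns a dict of destination-node embeddings, one entry per node type.
    H_node_dict = model(blocks, X_node_dict)
    break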

jermainewang avatar Jan 10 '24 04:01 jermainewang

Hi @ayushnoori, with the DGL GraphBolt release, I am closing this issue. Please try our new dataloading package, DGL GraphBolt, and let us know if GraphBolt does not solve your problem.

frozenbugs avatar Apr 17 '24 02:04 frozenbugs

Thanks, @frozenbugs. Is there a guide available to facilitate transitioning over to GraphBolt from the older DGL version – and if so, could you please point me to it?

ayushnoori avatar Apr 17 '24 02:04 ayushnoori

This is the official guide: https://docs.dgl.ai/en/latest/stochastic_training/index.html

And this is the minibatch training user guide: https://docs.dgl.ai/en/latest/guide/minibatch.html

mfbalin avatar Apr 17 '24 02:04 mfbalin