dgl.to_homogeneous does not support DGLBlock MFGs
🐛 Bug
I am attempting to learn on a heterogeneous DGLGraph created by dgl.heterograph, with node features. I am using an edge-wise sampler created by dgl.dataloading.NeighborSampler, dgl.dataloading.negative_sampler, and dgl.dataloading.as_edge_prediction_sampler, which I then pass to dgl.dataloading.DataLoader to create a DataLoader iterable. For example, please see below:
# define node samplers
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)
neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)

# convert to edge sampler
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler,
    exclude="reverse_types",  # exclude reverse edges
    reverse_etypes=reverse_edge_dict,  # define reverse edge types
    negative_sampler=neg_sampler,
)

# define training dataloader
train_dataloader = dgl.dataloading.DataLoader(
    train_graph,
    train_eids,
    sampler,
    batch_size=128,  # for example
    shuffle=True,
    drop_last=False,
    num_workers=0,
)
Iterating over the DataLoader produces DGLBlock message flow graph (MFG) objects.
for input_nodes, pos_graph, neg_graph, blocks in train_dataloader:
    # training loop code here
    break
Within each batch, I need to convert my heterograph to a homogeneous graph (e.g., to pass to the HGTConv convolutional layer). However, when I attempt to convert the first DGLBlock to a homogeneous graph and concatenate the node features, I receive the error below.
dgl.to_homogeneous(blocks[0], ndata = ['node_index'])
> DGLError: Expect number of features to match number of nodes (len(u)). Got 81810 and 113708 instead.
Discussion with @jermainewang about this error:
DGLBlock is designed to represent a bipartite computation graph. We reuse the heterogeneous graph data structure because one can view it as a graph with two groups of nodes (i.e., source and destination nodes). However, dgl.to_homogeneous will try to merge the source and destination nodes into one set of nodes (which is the definition of to_homogeneous), but that will break the integrity of the data structure as a bipartite computation graph... I will talk to the team to see if we can quickly patch dgl.to_homogeneous to support MFGs too.
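To make the bipartite structure described above concrete, here is a toy sketch in plain Python (no DGL required). It is not DGL's implementation: `make_toy_block` is a hypothetical helper, and it assumes a single node type and the MFG convention that destination nodes are listed first among the source nodes. The point is only that the source side and the destination side of a block generally have different sizes, so per-node data sized for one side does not fit the other.

```python
def make_toy_block(seed_nodes, edges):
    """Toy stand-in for a one-node-type block (hypothetical helper).

    seed_nodes: global IDs of the destination nodes.
    edges: iterable of (src, dst) global IDs, with every dst in seed_nodes.
    """
    # Assumption: destination nodes come first in the source node list.
    src_ids = list(seed_nodes)
    seen = set(src_ids)
    for src, _dst in edges:
        if src not in seen:
            seen.add(src)
            src_ids.append(src)
    return {"src_ids": src_ids, "dst_ids": list(seed_nodes)}


block = make_toy_block(seed_nodes=[7, 3], edges=[(5, 7), (3, 7), (9, 3), (5, 3)])
num_src = len(block["src_ids"])
num_dst = len(block["dst_ids"])

# The two sides differ in size, so there is no single node count to merge into.
print("source nodes:", num_src, "destination nodes:", num_dst)
```

In this toy case the destination nodes are a prefix of the source nodes, which is the property the workaround later in this thread relies on when it builds a graph over the source nodes only.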
Creating this issue to track progress on patching dgl.to_homogeneous to support MFGs.
cc: @jermainewang @mufeili, many thanks for your help!
@ayushnoori Check out this implementation. It may not handle all the corner cases but the general idea should apply. cc @czkkkkkk
import dgl
import numpy as np


def block_to_homogeneous(block, source_ndata, source_ndata_name):
    """Convert a DGLBlock to a homogeneous graph.

    This function performs a two-step process to convert a given block into a
    homogeneous graph. It first transforms the block into a heterogeneous graph,
    treating its source nodes as heterogeneous nodes. Then, this heterogeneous
    graph is further converted to a homogeneous graph. The function's outputs
    include the homogeneous graph, the mapping from node type to node type ID,
    node type count offsets, and the number of destination nodes for each node
    type. Notably, the `source_ndata` of the original block is retained and
    stored under the `source_ndata_name` key of the homogeneous graph's ndata.

    Parameters
    ----------
    block : DGLBlock
        The block to convert.
    source_ndata : dict[str, Tensor]
        The source node features of the block.
    source_ndata_name : str
        The name of the source node features in the homogeneous graph.

    Returns
    -------
    DGLGraph
        The converted homogeneous graph.
    dict[str, int]
        The mapping from node type to node type ID.
    numpy.ndarray
        The node type count offsets (cumulative, starting at 0).
    dict[str, int]
        The number of destination nodes for each node type.
    """
    # Record how many destination and source nodes each node type has.
    num_dst_nodes_dict = {}
    num_src_nodes_dict = {}
    for ntype in block.dsttypes:
        num_dst_nodes_dict[ntype] = block.number_of_dst_nodes(ntype)
    for ntype in block.srctypes:
        num_src_nodes_dict[ntype] = block.number_of_src_nodes(ntype)

    # Rebuild the block's edges as an ordinary heterograph over the source nodes.
    hetero_edges = {}
    for srctype, etype, dsttype in block.canonical_etypes:
        src, dst = block.all_edges(etype=etype, order="eid")
        hetero_edges[(srctype, etype, dsttype)] = (src, dst)
    hetero_g = dgl.heterograph(
        hetero_edges,
        num_nodes_dict=num_src_nodes_dict,
        idtype=block.idtype,
        device=block.device,
    )
    ntype_to_id = {ntype: hetero_g.get_ntype_id(ntype) for ntype in hetero_g.ntypes}
    hetero_g.ndata[source_ndata_name] = source_ndata

    # Flatten to a homogeneous graph, keeping the per-type node counts.
    homo_g, ntype_counts, _ = dgl.to_homogeneous(
        hetero_g, ndata=[source_ndata_name], return_count=True
    )
    # ntype_offset[i] is the first homogeneous node ID of node type i.
    ntype_offset = np.insert(np.cumsum(ntype_counts), 0, 0)
    return homo_g, ntype_to_id, ntype_offset, num_dst_nodes_dict
from typing import Dict

import dgl
import torch
import torch.nn as nn
from dgl.nn import HGTConv


class HGTGNN(nn.Module):
    def __init__(
        self,
        in_size: int,
        hid_size: int,
        num_heads: int,
        num_ntypes: int,
        num_etypes: int,
        num_layers: int,
    ):
        super().__init__()
        self.in_size = in_size
        self.out_size = hid_size
        self.conv_layers = nn.ModuleList()
        # HGTConv outputs head_size * num_heads features, so pass the per-head size.
        self.conv_layers.append(
            HGTConv(
                in_size,
                hid_size // num_heads,
                num_heads,
                num_ntypes,
                num_etypes,
            )
        )
        for _ in range(num_layers - 1):
            self.conv_layers.append(
                HGTConv(
                    hid_size,
                    hid_size // num_heads,
                    num_heads,
                    num_ntypes,
                    num_etypes,
                )
            )

    def forward(
        self,
        mfgs,
        X_node_dict: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        assert len(mfgs) == len(self.conv_layers)
        for ntype, X in X_node_dict.items():
            assert X.shape[1] == self.in_size
        H_node_dict = X_node_dict
        for i, conv in enumerate(self.conv_layers):
            # There are several steps to conduct HGT on a DGLBlock.
            # 1. Convert the DGLBlock to a homogeneous graph. The
            #    homogeneous graph is based on the source nodes of the DGLBlock.
            # 2. Run HGTConv on the homogeneous graph to compute homogeneous
            #    features.
            # 3. Extract the heterogeneous features of the destination nodes
            #    from the homogeneous features.
            homo_ndata_name = "x"
            (
                homo_g,
                ntype_to_id,
                ntype_offset,
                num_dst_nodes_dict,
            ) = block_to_homogeneous(mfgs[i], H_node_dict, homo_ndata_name)
            # Run HGTConv on the homogeneous graph.
            homo_features = conv(
                homo_g,
                homo_g.ndata[homo_ndata_name],
                homo_g.ndata[dgl.NTYPE],
                homo_g.edata[dgl.ETYPE],
            )
            # Convert the output features back to a dict.
            dst_features = {}
            for ntype in mfgs[i].dsttypes:
                ntype_id = ntype_to_id[ntype]
                # Rows belonging to this node type in the homogeneous output.
                feature = homo_features[
                    ntype_offset[ntype_id] : ntype_offset[ntype_id + 1]
                ]
                # Destination nodes come first among the source nodes of a block.
                dst_features[ntype] = feature[: num_dst_nodes_dict[ntype]]
            H_node_dict = dst_features
        return H_node_dict
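The offset arithmetic in step 3 above can be checked in isolation. The following is a minimal sketch in plain Python, assuming nodes in the homogeneous graph are grouped by node type ID and that each type's destination nodes are the first rows of its group. `dst_slices` and the counts are hypothetical and chosen only for illustration.

```python
from itertools import accumulate


def dst_slices(ntype_counts, num_dst_per_type):
    """Return the (start, stop) row range of destination nodes per type ID."""
    # offsets[i] is the first homogeneous row of node type i.
    offsets = [0] + list(accumulate(ntype_counts))
    return [
        (offsets[i], offsets[i] + n_dst)
        for i, n_dst in enumerate(num_dst_per_type)
    ]


ntype_counts = [4, 3]  # hypothetical source-node counts for type 0 and type 1
num_dst = [2, 1]       # hypothetical destination-node counts per type
homo_rows = list(range(sum(ntype_counts)))  # stand-in for feature rows

slices = dst_slices(ntype_counts, num_dst)
dst_rows = [homo_rows[start:stop] for start, stop in slices]
print(slices, dst_rows)
```

Each slice starts at its type's offset and covers only the leading destination rows, which mirrors the two-stage indexing (`ntype_offset` slice, then `[: num_dst_nodes_dict[ntype]]`) in the `forward` method.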
Hi @ayushnoori, now that DGL GraphBolt has been released, I am closing this issue as is. Please try our new dataloading package, DGL GraphBolt, and let us know if it does not solve your problem.
Thanks, @frozenbugs. Is there a guide available to facilitate transitioning over to GraphBolt from the older DGL version – and if so, could you please point me to it?
This is the official guide: https://docs.dgl.ai/en/latest/stochastic_training/index.html
And this is the minibatch training user guide: https://docs.dgl.ai/en/latest/guide/minibatch.html