[GraphBolt] modify logic for `HeteroItemSet` indexing
Description
First, let's take a look at the current code for indexing a `HeteroItemSet` (found in `HeteroItemSet.__getitem__`):
```python
elif isinstance(index, Iterable):
    if not isinstance(index, torch.Tensor):
        index = torch.tensor(index)
    assert torch.all((index >= 0) & (index < self._length))
    key_indices = (
        torch.searchsorted(self._offsets, index, right=True) - 1
    )
    data = {}
    for key_id, key in enumerate(self._keys):
        mask = (key_indices == key_id).nonzero().squeeze(1)
        if len(mask) == 0:
            continue
        data[key] = self._itemsets[key][
            index[mask] - self._offsets[key_id]
        ]
    return data
```
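To make the bucketing concrete, here is a small standalone sketch (the offsets and indices are made up for illustration, not taken from GraphBolt) of how `searchsorted` maps each global index to the key that owns it:

```python
import torch

# Hypothetical offsets for a HeteroItemSet with two keys:
# key 0 owns global ids [0, 5), key 1 owns [5, 12).
offsets = torch.tensor([0, 5, 12])
index = torch.tensor([1, 7, 4, 11])

# right=True returns the insertion point to the right of each id;
# subtracting 1 gives the bucket (key id) that owns the id.
key_indices = torch.searchsorted(offsets, index, right=True) - 1
print(key_indices)  # tensor([0, 1, 0, 1])

for key_id in range(len(offsets) - 1):
    # Every iteration rescans all of key_indices.
    mask = (key_indices == key_id).nonzero().squeeze(1)
    print(key_id, index[mask] - offsets[key_id])
# 0 tensor([1, 4])
# 1 tensor([2, 6])
```

Note that each loop iteration compares against the full `key_indices` tensor, which is exactly where the per-key cost comes from.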
Say the length of the index is N and the number of etypes/ntypes is K. The time complexity of the current implementation is then O(N * K), which is mainly introduced by the line

```python
mask = (key_indices == key_id).nonzero().squeeze(1)
```

If there are a lot of etypes, this line can easily become the bottleneck.
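A rough micro-benchmark sketch (the sizes are hypothetical and the absolute numbers depend on hardware) makes the scaling with K visible:

```python
import time
import torch

N, K = 100_000, 500  # hypothetical batch size and number of types
key_indices = torch.randint(0, K, (N,))

start = time.perf_counter()
for key_id in range(K):
    # One full comparison over all N entries per key: O(N * K) overall.
    mask = (key_indices == key_id).nonzero().squeeze(1)
print(f"K={K}: {time.perf_counter() - start:.3f}s")
```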
This draft PR proposes an alternative to the current logic:
```python
elif isinstance(index, Iterable):
    if not isinstance(index, torch.Tensor):
        index = torch.tensor(index)
    sorted_index, indices = index.sort()
    assert sorted_index[0] >= 0 and sorted_index[-1] < self._length
    index_offsets = torch.searchsorted(sorted_index, self._offsets)
    data = {}
    for key_id, key in enumerate(self._keys):
        if index_offsets[key_id] == index_offsets[key_id + 1]:
            continue
        current_indices, _ = indices[
            index_offsets[key_id] : index_offsets[key_id + 1]
        ].sort()
        data[key] = self._itemsets[key][
            index[current_indices] - self._offsets[key_id]
        ]
    return data
```
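Again as a standalone sketch with made-up values: after sorting the query, `searchsorted(sorted_index, self._offsets)` delimits one contiguous slice per key, so each key's work only touches its own entries instead of scanning everything:

```python
import torch

offsets = torch.tensor([0, 5, 12])   # hypothetical key boundaries
index = torch.tensor([7, 1, 11, 4])  # unsorted query

sorted_index, indices = index.sort()  # values [1, 4, 7, 11], perm [1, 3, 0, 2]
index_offsets = torch.searchsorted(sorted_index, offsets)
print(index_offsets)  # tensor([0, 2, 4])

for key_id in range(len(offsets) - 1):
    lo, hi = index_offsets[key_id], index_offsets[key_id + 1]
    if lo == hi:  # no queried ids fall into this key's range
        continue
    # Re-sorting the permutation slice restores the original query order
    # of the ids belonging to this key.
    current_indices, _ = indices[lo:hi].sort()
    print(key_id, index[current_indices] - offsets[key_id])
# 0 tensor([1, 4])
# 1 tensor([2, 6])
```

The inner re-sort of the permutation slice is what keeps the per-key results in the same order as the mask-based version.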
whose time complexity is O(N log N), where the log factor is introduced by the sorting operation.
This improves performance when there are many etypes, but may cost more time when there are few. A thoughtful design strikes a balance between the two approaches.
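One possible way to strike that balance (purely illustrative; the cutoff value and the helper names `_getitem_mask_scan` / `_getitem_sorted` are assumptions for this sketch, not part of this PR) is to dispatch on the number of keys:

```python
# Illustrative dispatch only: threshold and helpers are hypothetical.
_SORT_PATH_MIN_KEYS = 32  # assumed cutoff; would need tuning via benchmarks

def _getitem_iterable(self, index):
    if len(self._keys) < _SORT_PATH_MIN_KEYS:
        # Few types: the per-key mask scan (original O(N * K) logic) wins.
        return self._getitem_mask_scan(index)
    # Many types: the sort-based O(N log N) path avoids K full scans.
    return self._getitem_sorted(index)
```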
Update on June 18
Benchmark: https://docs.google.com/document/d/1Bbmp8gMekiGIYYxEMVbmXSANRZlZ_nTNbhpWul4RaKA/edit?usp=sharing
The results show that the original algorithm is faster than the new algorithm (theoretical time complexity O(N log N)) for almost all values of batch_size and num_types.
Checklist
Please feel free to remove inapplicable items for your PR.
- [ ] The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature])
- [ ] I've leveraged the tools to beautify the Python and C++ code.
- [ ] The PR is complete and small. Read the Google eng practice (a CL equals a PR) to understand more about small PRs. In DGL, we consider PRs with fewer than 200 lines of core code change to be small (examples, tests, and documentation can be exempted).
- [ ] All changes have test coverage
- [ ] Code is well-documented
- [ ] To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change
- [ ] Related issue is referred in this PR
- [ ] If the PR is for a new model/paper, I've updated the example index here.
Changes
To trigger regression tests:
`@dgl-bot run [instance-type] [which tests] [compare-with-branch]`; for example: `@dgl-bot run g4dn.4xlarge all dmlc/master` or `@dgl-bot run c5.9xlarge kernel,api dmlc/master`
Commit ID: 6c3a7f203cb4b94964ea2cd4e1e172190f488ad1
Build ID: 1
Status: ✅ CI test succeeded.
Report path: link
Full logs path: link
@dgl-bot
Commit ID: b4c0ea46dfefd4426afc83d6a6098b4495a653b5
Build ID: 2
Status: ✅ CI test succeeded.
Report path: link
Full logs path: link
Do you have a benchmark comparing the new approach to the old one for different K values?
Commit ID: 50da7181836ef6ecedd78cbd58e32bd9c9d5003b
Build ID: 3
Status: ❌ CI test failed in Stage [Distributed Torch CPU Unit test].
Report path: link
Full logs path: link
@mfbalin @frozenbugs See benchmark results in the description. The new implementation does not seem to be as efficient as we thought. Maybe we should keep it as is?
Commit ID: 591c122323236759ec8e2df4021308331e93cf6b
Build ID: 4
Status: ✅ CI test succeeded.
Report path: link
Full logs path: link
Let me take a look at the code to see if we missed anything. Thank you for the benchmark.