RGCN
Memory overflow during validation/test
I am using your code as part of our experiments, but there is a problem with the test_graph: because the code uses all training triplets to build the graph during validation and test, 64 GB of memory is not enough when I test the model. Could you give me some suggestions? Thank you very much!
Hi, I have also encountered the out-of-memory problem. It seems that massive CPU memory is needed when there are too many triplets:
In my experiments, the information of the graph is as follows:
It seems that the only way to avoid OOM is to reduce the number of triplets, so that fewer edge_type entries are indexed into w.
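For what it's worth, here is a minimal sketch of that workaround: subsample the training triplets before building the evaluation graph, so that fewer edges (and therefore fewer rows of w) are materialised. The function name and the cap of 100000 triplets are illustrative assumptions, not something from the repo.

import numpy as np

def subsample_triplets(train_triplets, max_triplets=100000, seed=0):
    # Keep at most max_triplets edges so the per-edge weight tensor created
    # inside message() stays within CPU memory.
    train_triplets = np.asarray(train_triplets)
    if len(train_triplets) <= max_triplets:
        return train_triplets
    rng = np.random.default_rng(seed)
    keep = rng.choice(len(train_triplets), size=max_triplets, replace=False)
    return train_triplets[keep]

Whatever graph-building helper the repo uses for the evaluation graph would then be called on subsample_triplets(train_triplets) instead of the full training set, at the cost of message passing over an incomplete graph.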
Hi He, could you please tell me how you fixed this overflow issue? Many thanks in advance.
One solution seems to be to modify valid() to include batching in the evaluation phase (and to modify calc_mrr accordingly so that it also returns the hits; a sketch of that change follows the function below):
def valid(valid_triplets, model, test_graph, all_triplets, batch_size=1024):
    with torch.no_grad():
        model.eval()

        # The entity embeddings do not depend on the evaluation batch,
        # so compute them once instead of once per batch.
        entity_embedding = model(test_graph.entity, test_graph.edge_index,
                                 test_graph.edge_type, test_graph.edge_norm)

        mrr = 0.0
        hits = {1: 0.0, 3: 0.0, 10: 0.0}
        num_batches = (len(valid_triplets) + batch_size - 1) // batch_size

        for i in range(0, len(valid_triplets), batch_size):
            batch_valid_triplets = valid_triplets[i:i + batch_size]
            mrr_b, hits_bdict = calc_mrr(entity_embedding, model.relation_embedding,
                                         batch_valid_triplets, all_triplets,
                                         hits=[1, 3, 10])
            mrr += mrr_b
            hits[1] += hits_bdict[1]
            hits[3] += hits_bdict[3]
            hits[10] += hits_bdict[10]

        # Average over the number of batches (the last batch may be smaller).
        mrr /= num_batches
        hits[1] /= num_batches
        hits[3] /= num_batches
        hits[10] /= num_batches

        print(f'MRR: {mrr}, Hits@1: {hits[1]}, Hits@3: {hits[3]}, Hits@10: {hits[10]}')
        return mrr
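For the calc_mrr change, here is a minimal sketch of the intended return signature only. It assumes the existing function already computes one filtered rank per evaluation triplet; compute_filtered_ranks is a hypothetical stand-in for that existing ranking logic, not a function from the repo.

import torch

def compute_filtered_ranks(entity_embedding, relation_embedding, eval_triplets, all_triplets):
    # Hypothetical placeholder: replace with the repo's existing filtered-ranking
    # code, which should yield a 1-based rank for every evaluation triplet.
    raise NotImplementedError

def calc_mrr(entity_embedding, relation_embedding, eval_triplets, all_triplets, hits=[1, 3, 10]):
    # ranks: 1-based rank of the true entity for every evaluation triplet.
    ranks = compute_filtered_ranks(entity_embedding, relation_embedding,
                                   eval_triplets, all_triplets).float()
    mrr = torch.mean(1.0 / ranks).item()
    hits_dict = {k: torch.mean((ranks <= k).float()).item() for k in hits}
    return mrr, hits_dict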
The above, however, does not seem to work for FB15k-237. Could the source of the issue be this line: https://github.com/JinheonBaek/RGCN/blob/818bf70b00d5cd178a7496a748e4f18da3bcde82/main.py#L25C41-L25C47?
In case it helps, here is the memory profiling of the message function during training and during validation.
During Training:
Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   188   1904.4 MiB   1904.4 MiB           1   @profile
   189                                         def message(self, x_j, edge_index_j, edge_type, edge_norm):
   190                                             """
   191                                             """
   192
   193                                             # Call the function that might be causing the memory overflow
   194   1904.4 MiB      0.0 MiB           1       w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
   195
   196                                             # If no node features are given, we implement a simple embedding
   197                                             # loopkup based on the target node index and its edge type.
   198   1904.4 MiB      0.0 MiB           1       if x_j is None:
   199                                                 w = w.view(-1, self.out_channels)
   200                                                 index = edge_type * self.in_channels + edge_index_j
   201                                                 out = torch.index_select(w, 0, index)
   202                                             else:
   203   1904.4 MiB      0.0 MiB           1           w = w.view(self.num_rel, self.in_chan, self.out_chan)
   204   3047.9 MiB   1143.5 MiB           1           w = torch.index_select(w, 0, edge_type)
   205   3047.9 MiB      0.0 MiB           1           out = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)
   206
   207   3047.9 MiB      0.0 MiB           1       if edge_norm is not None:
   208   3047.9 MiB      0.0 MiB           1           out = out * edge_norm.view(-1, 1)
   209
   210   3047.9 MiB      0.0 MiB           1       return out
During Validation:
Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   188    844.2 MiB    844.2 MiB           1   @profile
   189                                         def message(self, x_j, edge_index_j, edge_type, edge_norm):
   190                                             """
   191                                             """
   192
   193                                             # Call the function that might be causing the memory overflow
   194    844.2 MiB      0.0 MiB           1       w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
   195
   196                                             # If no node features are given, we implement a simple embedding
   197                                             # loopkup based on the target node index and its edge type.
   198    844.2 MiB      0.0 MiB           1       if x_j is None:
   199                                                 w = w.view(-1, self.out_channels)
   200                                                 index = edge_type * self.in_channels + edge_index_j
   201                                                 out = torch.index_select(w, 0, index)
   202                                             else:
   203    844.2 MiB      0.0 MiB           1           w = w.view(self.num_rel, self.in_chan, self.out_chan)
   204  11635.1 MiB  10790.8 MiB           1           w = torch.index_select(w, 0, edge_type)
   205  11743.1 MiB    108.0 MiB           1           out = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)
   206
   207  11743.1 MiB      0.0 MiB           1       if edge_norm is not None:
   208  11743.2 MiB      0.2 MiB           1           out = out * edge_norm.view(-1, 1)
   209
   210  11743.2 MiB      0.0 MiB           1       return out
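For anyone who wants to reproduce this kind of line-by-line report outside the repo, here is a small standalone script using the memory_profiler package (pip install memory-profiler). The edge count mimics the validation case above; num_rels and the 100-dimensional channels are placeholder values, not taken from the repo's configuration, and running it needs roughly 11 GB of free RAM.

from memory_profiler import profile
import torch

@profile
def expand_relation_weights(num_edges=282884, num_rels=36, in_chan=100, out_chan=100):
    # Mimics lines 203-204 of message(): index_select materialises one
    # (in_chan x out_chan) weight matrix per edge, which is where the memory goes.
    edge_type = torch.randint(0, num_rels, (num_edges,))
    w = torch.randn(num_rels, in_chan, out_chan)
    w = torch.index_select(w, 0, edge_type)
    return w.shape

if __name__ == '__main__':
    expand_relation_weights()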
It appears that the memory overflow happens specifically during validation because edge_type is much larger during validation than during training.
During training: size of edge_type = 30000
During validation: size of edge_type = 282884
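A quick back-of-the-envelope check makes the connection concrete. Assuming 100-dimensional float32 embeddings (an assumption, but one that matches the profiler increments above), the per-edge weight tensor created at line 204 costs:

# Approximate size in MiB of the tensor produced by
# w = torch.index_select(w, 0, edge_type): one in_chan x out_chan float32
# matrix per edge. The channel sizes of 100 are an assumption.
def per_edge_weight_mib(num_edges, in_chan=100, out_chan=100, bytes_per_float=4):
    return num_edges * in_chan * out_chan * bytes_per_float / 2**20

print(per_edge_weight_mib(30000))    # ~1144 MiB, close to the 1143.5 MiB increment seen in training
print(per_edge_weight_mib(282884))   # ~10791 MiB, close to the 10790.8 MiB increment seen in validation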