tgn
Updating memory fails for datasets that are not bipartite
Hi,
If I am not mistaken, there seems to be a bug when using the model on a unipartite dataset while updating the memory at the end of each batch (`memory_update_at_start=False`).
Running the model in this configuration incorrectly triggers the `AssertionError: Trying to update to time in the past` from the `memory_updater` module. This is caused by lines 185-186 in `tgn.py`:
```python
def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                edge_idxs, n_neighbors=20):
    ...
    if self.use_memory:
        if self.memory_update_at_start:
            # Update memory for all nodes with messages stored in previous batches
            memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                          self.memory.messages)
        else:
            memory = self.memory.get_memory(list(range(self.n_nodes)))
            last_update = self.memory.last_update
    ...
    if self.use_memory:
        if self.memory_update_at_start:
            # Persist the updates to the memory only for sources and destinations (since now we have
            # new messages for them)
            self.update_memory(positives, self.memory.messages)

            assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
                "Something wrong in how the memory was updated"

            # Remove messages for the positives since we have already updated the memory using them
            self.memory.clear_messages(positives)

        unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                                      source_node_embedding,
                                                                      destination_nodes,
                                                                      destination_node_embedding,
                                                                      edge_times, edge_idxs)
        unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                                destination_node_embedding,
                                                                                source_nodes,
                                                                                source_node_embedding,
                                                                                edge_times, edge_idxs)
        if self.memory_update_at_start:
            self.memory.store_raw_messages(unique_sources, source_id_to_messages)
            self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
        else:
            self.update_memory(unique_sources, source_id_to_messages)              # <-- line 185
            self.update_memory(unique_destinations, destination_id_to_messages)    # <-- line 186
    ...
    return source_node_embedding, destination_node_embedding, negative_node_embedding
```
When `source_nodes` and `destination_nodes` contain non-overlapping node ids, this is not a problem. However, in a unipartite graph the same node id can appear in both `source_nodes` and `destination_nodes`, which causes the described issue whenever that node id is associated with a later timestamp on the source side than on the destination side. Since line 185 updates all sources before line 186 updates any destinations, the node's memory is first advanced to the later source-side timestamp, and the subsequent destination-side update to the earlier timestamp then trips the assertion.
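To make the failure mode concrete, here is a minimal, self-contained sketch. It is not the actual TGN code; `ToyMemory` is a hypothetical stand-in that only enforces the same monotonicity invariant as TGN's `memory_updater`:

```python
import numpy as np

class ToyMemory:
    """Hypothetical stand-in for TGN's memory, with the same monotonicity check."""
    def __init__(self, n_nodes):
        self.last_update = np.zeros(n_nodes)

    def update(self, node_ids, timestamps):
        # Memory may only move forward in time, as in TGN's memory_updater
        assert (timestamps >= self.last_update[node_ids]).all(), \
            "Trying to update to time in the past"
        self.last_update[node_ids] = timestamps

# Chronological unipartite batch: (3 -> 2) at t=5, then (2 -> 1) at t=10.
# Node 2 is a destination at t=5 and a source at t=10.
memory = ToyMemory(n_nodes=4)
sources, destinations = np.array([3, 2]), np.array([2, 1])
edge_times = np.array([5.0, 10.0])

memory.update(sources, edge_times)       # all sources first: node 2 -> t=10
memory.update(destinations, edge_times)  # AssertionError: node 2 back at t=5
```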
This problem can be resolved by replacing:
```python
if self.memory_update_at_start:
    self.memory.store_raw_messages(unique_sources, source_id_to_messages)
    self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
else:
    self.update_memory(unique_sources, source_id_to_messages)
    self.update_memory(unique_destinations, destination_id_to_messages)
```
with:
```python
# Always store the raw messages first, for both sources and destinations
self.memory.store_raw_messages(unique_sources, source_id_to_messages)
self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)

if not self.memory_update_at_start:
    # Update each affected node exactly once from its stored messages,
    # then clear them (np is numpy, already imported in tgn.py)
    unique_node_ids = np.unique(np.concatenate((unique_sources, unique_destinations)))
    self.update_memory(unique_node_ids, self.memory.messages)
    self.memory.clear_messages(unique_node_ids)
```
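This way, a node that appears on both sides of the batch is updated only once, from the full set of its stored messages, rather than twice in a source-first order that can run backwards in time. Continuing the toy sketch from above (reusing `ToyMemory`, `sources`, `destinations`, and `edge_times`), the fixed control flow corresponds to updating each unique node at the latest timestamp it carries in the batch, which matches the behavior of TGN's default "last" message aggregator when `update_memory` receives the union of messages:

```python
# Continuing the ToyMemory sketch: one update per node, at its latest
# batch timestamp, mirroring the fixed control flow.
memory = ToyMemory(n_nodes=4)
all_ids = np.concatenate((sources, destinations))      # [3, 2, 2, 1]
all_times = np.concatenate((edge_times, edge_times))   # [5, 10, 5, 10]
unique_ids = np.unique(all_ids)                        # [1, 2, 3]
latest = np.array([all_times[all_ids == n].max() for n in unique_ids])

memory.update(unique_ids, latest)  # node 2 updated once, at t=10 -> no error
```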
Edit: I found an issue in the fix initially proposed here and have updated it to match the pull request.