
Updating memory fails for datasets that are not bipartite

Open daniel-gomm opened this issue 1 year ago • 1 comment

Hi,

If I am not mistaken, there seems to be a bug when using the model on a unipartite dataset while updating the memory at the end of each batch (memory_update_at_start=False).

Running the model like this incorrectly triggers the AssertionError "Trying to update memory to time in the past" raised by the memory_updater module. This is due to lines 185-186 in tgn.py.


def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                  edge_idxs, n_neighbors=20):
    ...
    if self.use_memory:
      if self.memory_update_at_start:
        # Update memory for all nodes with messages stored in previous batches
        memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                      self.memory.messages)
      else:
        memory = self.memory.get_memory(list(range(self.n_nodes)))
        last_update = self.memory.last_update

      ...

    if self.use_memory:
      if self.memory_update_at_start:
        # Persist the updates to the memory only for sources and destinations (since now we have
        # new messages for them)
        self.update_memory(positives, self.memory.messages)

        assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
          "Something wrong in how the memory was updated"

        # Remove messages for the positives since we have already updated the memory using them
        self.memory.clear_messages(positives)

      unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                                    source_node_embedding,
                                                                    destination_nodes,
                                                                    destination_node_embedding,
                                                                    edge_times, edge_idxs)
      unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                              destination_node_embedding,
                                                                              source_nodes,
                                                                              source_node_embedding,
                                                                              edge_times, edge_idxs)
      if self.memory_update_at_start:
        self.memory.store_raw_messages(unique_sources, source_id_to_messages)
        self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
      else:
        self.update_memory(unique_sources, source_id_to_messages)                  <-- 185
        self.update_memory(unique_destinations, destination_id_to_messages)        <-- 186

    ...

    return source_node_embedding, destination_node_embedding, negative_node_embedding
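
For context, the assertion that fires is the guard inside the memory updater, which enforces that a node's memory only ever moves forward in time. Below is a minimal sketch of that guard, paraphrased from the repository's SequenceMemoryUpdater; treat the exact names and message as approximate:


  def update_memory(self, unique_node_ids, unique_messages, timestamps):
    if len(unique_node_ids) <= 0:
      return

    # A node's memory may only be advanced, never rewound. If the same node id
    # arrives in two consecutive update_memory calls with decreasing
    # timestamps, this check fails.
    assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), \
      "Trying to update memory to time in the past"

    memory = self.memory.get_memory(unique_node_ids)
    self.memory.last_update[unique_node_ids] = timestamps

    updated_memory = self.memory_updater(unique_messages, memory)
    self.memory.set_memory(unique_node_ids, updated_memory)
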

When source_nodes and destination_nodes contain non-overlapping node ids, this is not a problem. However, on a unipartite graph the same node id can appear in both source_nodes and destination_nodes, which causes the described issue whenever that node id is associated with a later timestamp on the source side than on the destination side.
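
As a concrete illustration, consider a single batch on a unipartite graph; the node ids and timestamps below are made up:


  import numpy as np

  # Hypothetical batch: node 7 is a destination at t=5.0 and a source at t=12.0.
  source_nodes      = np.array([3, 7])
  destination_nodes = np.array([7, 4])
  edge_times        = np.array([5.0, 12.0])

  # With memory_update_at_start=False, lines 185-186 run back to back:
  #   self.update_memory(unique_sources, ...)       # sets last_update[7] = 12.0
  #   self.update_memory(unique_destinations, ...)  # tries last_update[7] = 5.0
  # The second call moves node 7 backwards in time, so the memory updater's
  # "Trying to update memory to time in the past" assertion fires.
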

This problem can be resolved by replacing:


      if self.memory_update_at_start:
        self.memory.store_raw_messages(unique_sources, source_id_to_messages)
        self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
      else:
        self.update_memory(unique_sources, source_id_to_messages)
        self.update_memory(unique_destinations, destination_id_to_messages)

with:


      self.memory.store_raw_messages(unique_sources, source_id_to_messages)
      self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)

      if not self.memory_update_at_start:
        unique_node_ids = np.unique(np.concatenate((unique_sources, unique_destinations)))
        self.update_memory(unique_node_ids, self.memory.messages)
        self.memory.clear_messages(unique_node_ids)
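
This works because both code paths now buffer the raw messages first, and the else branch then updates the memory once per unique node id from the stored messages. A node that appears on both the source and the destination side of the batch is therefore advanced a single time using all of its messages, rather than twice in an arbitrary order. Clearing the messages afterwards mirrors what the memory_update_at_start=True path does once it has consumed them.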

Edit: Found an issue in the fix initially proposed and updated it to match the pull request.

daniel-gomm • Nov 20 '23