tgn

Potential problem with embedding computation?

Open • antoniofilipovic opened this issue 3 years ago • 2 comments

Hi,

Let's say we are using GraphSumEmbedding as a layer in TGN.

As I understand from your paper, in the propagation step we are supposed to concatenate a node's own embedding from the previous layer with the previous-layer embeddings of its neighbors.

For example, to calculate the layer-2 embedding of a node with id "a" whose neighbors are the nodes "b", "c", and "d", we need the layer-1 embeddings of the neighbors "b", "c", and "d", and we also need the layer-1 embedding of node "a" itself.
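To make the dependency concrete, here is a small sketch. It is not code from the repo. It uses a hypothetical static toy graph (timestamps ignored) and lists which (node, layer) embeddings a given layer consumes under the propagation rule described above.

```python
from typing import Dict, List, Set, Tuple

def required_inputs(node: str, layer: int,
                    neighbors: Dict[str, List[str]]) -> Set[Tuple[str, int]]:
    """(node, layer) embeddings that layer `layer` of `node` directly consumes."""
    if layer == 0:
        return set()  # layer 0 is just memory + raw features, no recursion
    deps = {(n, layer - 1) for n in neighbors.get(node, [])}
    deps.add((node, layer - 1))  # the node's OWN previous-layer embedding
    return deps

def all_required(node: str, layer: int,
                 neighbors: Dict[str, List[str]]) -> Set[Tuple[str, int]]:
    """Transitively collect every (node, layer) pair the recursion touches."""
    needed: Set[Tuple[str, int]] = set()
    stack = [(node, layer)]
    while stack:
        cur_node, cur_layer = stack.pop()
        for dep in required_inputs(cur_node, cur_layer, neighbors):
            if dep not in needed:
                needed.add(dep)
                stack.append(dep)
    return needed

# Hypothetical adjacency for the example: "a" has neighbors "b", "c", "d".
neighbors = {"a": ["b", "c", "d"], "b": ["a"], "c": ["a"], "d": ["a"]}
print(sorted(required_inputs("a", 2, neighbors)))
# [('a', 1), ('b', 1), ('c', 1), ('d', 1)]
```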

But from the code it doesn't look that way to me. I could be wrong, but you always pass source_node_features to the aggregate function. That means you always use the layer-0 features of the source nodes (memory + raw features):

source_embedding = self.aggregate(n_layers, source_node_features,
                                  source_nodes_time_embedding,
                                  neighbor_embeddings,
                                  edge_time_embeddings,
                                  edge_features,
                                  mask)

Then, in GraphSumEmbedding.aggregate, you do:

source_features = torch.cat([source_node_features,
                             source_nodes_time_embedding.squeeze()], dim=1)

This is a problem: when you calculate embeddings at layer 2, you use neighbor_embeddings from layer 1 but source_node_features from layer 0.

This happens because you set source_node_features once at the beginning and never update it:

source_node_features = self.node_features[source_nodes_torch, :]

if self.use_memory:
  source_node_features = memory[source_nodes, :] + source_node_features

I think the computation of the source nodes' own embeddings at layer n_layers - 1 is missing entirely.
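Here is a minimal sketch of the difference, under heavy simplifying assumptions. It uses a hypothetical four-node toy graph with scalar layer-0 features. It replaces linear_1/linear_2 with identity weights and drops the ReLU, time encodings, and edge features. The `fixed` flag switches between two versions of the self term. With `fixed=False` it uses the layer-0 features, as the current code does. With `fixed=True` it uses the node's own embedding at layer n_layers - 1, which is my reading of the paper.

```python
from typing import Dict, List

# Hypothetical toy data: scalar layer-0 features (standing in for
# memory + raw features) and a static adjacency list, no timestamps.
features: Dict[str, float] = {"a": 1.0, "b": 2.0, "c": 3.0, "d": 4.0}
neighbors: Dict[str, List[str]] = {"a": ["b", "c", "d"], "b": ["a"], "c": ["a"], "d": ["a"]}

def embed(node: str, n_layers: int, fixed: bool) -> float:
    """Sum-style layer with identity weights: neighbor sum + self term."""
    if n_layers == 0:
        return features[node]
    neighbor_sum = sum(embed(n, n_layers - 1, fixed) for n in neighbors[node])
    if fixed:
        own = embed(node, n_layers - 1, fixed)  # self term at layer n_layers - 1
    else:
        own = features[node]                    # current behavior: always layer 0
    return neighbor_sum + own

for layers in (1, 2):
    print(layers, embed("a", layers, fixed=False), embed("a", layers, fixed=True))
# 1 10.0 10.0
# 2 13.0 22.0
```

With one layer the two versions agree, because the layer-0 features are exactly the layer n_layers - 1 input. From two layers on, they diverge.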

Below is the source code of GraphEmbedding.compute_embedding:

def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                        use_time_proj=True):
    """Recursive implementation of curr_layers temporal graph attention layers.
    src_idx_l [batch_size]: users / items input ids.
    cut_time_l [batch_size]: scalar representing the instant of the time where we want to extract the user / item representation.
    curr_layers [scalar]: number of temporal convolutional layers to stack.
    num_neighbors [scalar]: number of temporal neighbor to consider in each convolutional layer.
    """

    assert (n_layers >= 0)

    source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)
    timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1)

    # query node always has the start time -> time span == 0
    source_nodes_time_embedding = self.time_encoder(torch.zeros_like(
      timestamps_torch))

    source_node_features = self.node_features[source_nodes_torch, :]

    if self.use_memory:
      source_node_features = memory[source_nodes, :] + source_node_features

    if n_layers == 0:
      return source_node_features
    else:

      neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(
        source_nodes,
        timestamps,
        n_neighbors=n_neighbors)

      neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)

      edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)

      edge_deltas = timestamps[:, np.newaxis] - edge_times

      edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device)

      neighbors = neighbors.flatten()
      neighbor_embeddings = self.compute_embedding(memory,
                                                   neighbors,
                                                   np.repeat(timestamps, n_neighbors),
                                                   n_layers=n_layers - 1,
                                                   n_neighbors=n_neighbors)

      effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1
      neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1)
      edge_time_embeddings = self.time_encoder(edge_deltas_torch)

      edge_features = self.edge_features[edge_idxs, :]

      mask = neighbors_torch == 0

      source_embedding = self.aggregate(n_layers, source_node_features,
                                        source_nodes_time_embedding,
                                        neighbor_embeddings,
                                        edge_time_embeddings,
                                        edge_features,
                                        mask)

      return source_embedding
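For reference, the batching in compute_embedding relies on two things. First, neighbors.flatten() and np.repeat(timestamps, n_neighbors) must produce aligned arrays, so that every sampled neighbor is embedded at its source node's timestamp. Second, view(len(source_nodes), effective_n_neighbors, -1) must restore the per-source grouping. Below is a pure-Python sketch of that bookkeeping. The batch values are hypothetical, and the helpers only mimic the NumPy/PyTorch calls.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")

def flatten(rows: Sequence[Sequence[T]]) -> List[T]:
    """Row-major flatten, like ndarray.flatten() on a [batch, n_neighbors] array."""
    return [x for row in rows for x in row]

def repeat(values: Sequence[T], k: int) -> List[T]:
    """Repeat each element k times in place, like np.repeat(values, k) on 1-D input."""
    return [v for v in values for _ in range(k)]

def unflatten(flat: Sequence[T], batch: int, k: int) -> List[List[T]]:
    """Regroup into [batch][k], like tensor.view(batch, k, -1) on the leading dims."""
    return [list(flat[i * k:(i + 1) * k]) for i in range(batch)]

# Hypothetical batch: 2 source nodes, n_neighbors = 3 (id 0 is padding).
neighbors = [[5, 7, 0], [2, 9, 4]]
timestamps = [10.0, 20.0]

flat_neighbors = flatten(neighbors)   # [5, 7, 0, 2, 9, 4]
flat_times = repeat(timestamps, 3)    # [10.0, 10.0, 10.0, 20.0, 20.0, 20.0]
pairs = list(zip(flat_neighbors, flat_times))
print(unflatten(pairs, len(neighbors), 3))
# [[(5, 10.0), (7, 10.0), (0, 10.0)], [(2, 20.0), (9, 20.0), (4, 20.0)]]
```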

And the code for GraphSumEmbedding.aggregate:

def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
              neighbor_embeddings,
              edge_time_embeddings, edge_features, mask):
    neighbors_features = torch.cat([neighbor_embeddings, edge_time_embeddings, edge_features],
                                   dim=2)
    neighbor_embeddings = self.linear_1[n_layer - 1](neighbors_features)
    neighbors_sum = torch.nn.functional.relu(torch.sum(neighbor_embeddings, dim=1))

    source_features = torch.cat([source_node_features,
                                 source_nodes_time_embedding.squeeze()], dim=1)
    source_embedding = torch.cat([neighbors_sum, source_features], dim=1)
    source_embedding = self.linear_2[n_layer - 1](source_embedding)

    return source_embedding
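Reading the two methods together, the layer-l update that aggregate computes can be written as shown below. This is a sketch with the following conventions:

* biases are omitted;
* h^{(0)} is memory + raw node features;
* φ is the time encoder;
* N_i(t) is the set of sampled temporal neighbors of i before t;
* W_1^{(l)} and W_2^{(l)} stand for linear_1[l - 1] and linear_2[l - 1];
* concatenation follows the order used in the code.

```latex
\tilde{h}_i^{(l)}(t) = \mathrm{ReLU}\Big(\sum_{j \in N_i(t)} W_1^{(l)}\big[\, h_j^{(l-1)}(t) \,\|\, \phi(t - t_j) \,\|\, e_{ij} \,\big]\Big)

h_i^{(l)}(t) = W_2^{(l)}\big[\, \tilde{h}_i^{(l)}(t) \,\|\, h_i^{(l-1)}(t) \,\|\, \phi(0) \,\big]
```

As far as I can tell, the current code feeds h_i^{(0)}(t) into the second equation for every l. That coincides with h_i^{(l-1)}(t) only when l = 1.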

antoniofilipovic, Feb 19 '22 16:02

Hi @antoniofilipovic,

You are correct. In practice, it does not make much of a difference since we mostly use only 1 graph layer with TGN, but this is indeed incorrect and will lead to different results with 2 or more layers.

Would you be willing to open up a pull request for this?

Best, Emanuele

emalgorithm, Apr 13 '23 08:04

@emalgorithm would you be willing to push an update and close this issue?

ezzeldinadel, Nov 02 '23 15:11