pytorch_cluster
Get edge attributes for random walk
Hi,
I'm trying to get the edge attributes for each step in a random walk, with the use case of trying to retain edge types in a walk on a heterogeneous graph I've converted to homogeneous. Using the idea in rusty1s/pytorch_sparse#214, I have this so far:
import torch
from torch_geometric.datasets import DBLP
from torch_cluster import random_walk
def get_edge_attr(edge_index: torch.Tensor, edge_attr: torch.Tensor, query_row: torch.Tensor, query_col: torch.Tensor) -> torch.Tensor:
    row, col = edge_index
    row_mask = row == query_row.view(-1, 1)
    col_mask = col == query_col.view(-1, 1)
    mask = torch.max(torch.logical_and(row_mask, col_mask), dim=0).values
    return edge_attr[mask]
data = DBLP("./tmp")[0].to_homogeneous()
num_walks = 1
walk_length = 5
start = torch.arange(data.num_nodes).view(-1, 1).repeat(1, num_walks).view(-1)
rw = random_walk(data.edge_index[0], data.edge_index[1], start, walk_length, num_nodes=data.num_nodes)
print(rw)
>>> tensor([[ 0, 10514, 19555, 10973, 21385, 9952],
... [ 1, 6520, 20381, 6520, 20381, 6520]])
l, r = rw[:2].unfold(1, 2, 1).flatten(0, 1).t() # sliding window of size 2 over each walk
print(get_edge_attr(data.edge_index, data.edge_type, l, r))
>>> tensor([1, 1, 4, 4, 4, 2, 2, 2])
I know this can be made more efficient with searchsorted as mentioned in the linked issue, but is this correct in general? Is there a built-in way to do this that I've missed or is this my best bet?
Thanks in advance.
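For reference, a rough sketch of the searchsorted variant mentioned above, assuming each (row, col) pair occurs at most once so that it can be encoded as a single integer key. The helper name `lookup_edge_attr`, the toy graph, and the -1 fill value are hypothetical and chosen only for illustration. Unlike the mask version, it returns one attribute per queried pair, in query order:

```python
import torch


def lookup_edge_attr(edge_index, edge_attr, query_row, query_col, num_nodes):
    # Hypothetical helper: one attribute per (query_row, query_col) pair.
    row, col = edge_index
    # Encode every edge as a single integer key and sort the keys once.
    sorted_key, perm = (row * num_nodes + col).sort()
    query_key = query_row * num_nodes + query_col
    # Binary-search each query key in the sorted edge keys.
    pos = torch.searchsorted(sorted_key, query_key)
    pos = pos.clamp(max=sorted_key.numel() - 1)
    found = sorted_key[pos] == query_key
    # Pairs that are not edges keep the placeholder value -1.
    out = edge_attr.new_full(query_key.shape, -1)
    out[found] = edge_attr[perm[pos[found]]]
    return out


# Toy graph with edges 0->1, 1->2, 2->0, 2->1 and one attribute per edge.
edge_index = torch.tensor([[0, 1, 2, 2], [1, 2, 0, 1]])
edge_attr = torch.tensor([10, 20, 30, 40])
walks = torch.tensor([[0, 1, 2, 0], [2, 1, 2, 1]])

# Sliding window of size 2 over each walk gives the (source, target) pairs.
l, r = walks.unfold(1, 2, 1).flatten(0, 1).t()
step_attr = lookup_edge_attr(edge_index, edge_attr, l, r, num_nodes=3)
step_attr = step_attr.view(walks.size(0), -1)
print(step_attr)
```

The sort is done once over the edges, so each query costs a binary search instead of a comparison against every edge. If the graph has parallel edges between the same pair of nodes, this sketch returns the attribute of only one of them.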
I think this would be way easier if we would just return the edge indices in random_walk :(
I guess your approach works, but it will indeed be inefficient.
Yeah, I'm certainly not a huge fan of this approach. It struggles with DBLP on a laptop (not an overpowered one, but pretty solid). For now I'm doing this awfulness:
et = data.adj_t.to_scipy('csc')
l, r = rw.unfold(-1, 2, 1).flatten(0, 1).t().cpu().numpy()
walk_etypes = torch.from_numpy(et[l.tolist(), r.tolist()]).to(rw.device).view(-1, walk_length)
Obviously not ideal compared to returning the edge indices. How difficult would that be to implement?
If you are interested, we can add support for this in pyg-lib, which should be straightforward to add. It also supports nightly builds so it should be ready to use once it lands. Let me know if you have interest in contributing!
Absolutely. Only issue is I haven't written C++ in several years and I've never written anything Torch or CUDA related with it. I can kind of follow along the CPU implementation but the CUDA one looks like it would take me a while to understand. I can do the grunt work but I'd probably need help understanding how it works and what would need to be changed.
Hey, I am also looking for such a feature (i.e., using edge indices from a random walk), and I did a quick update to the random_walk function. The edge indices are computed by the underlying C++/CUDA implementations, so we just need to return them.
Edit:
@jacobdanovitch Your example code will now simplify as follows:
import torch
from torch_geometric.datasets import DBLP
from torch_cluster import random_walk
data = DBLP("./tmp")[0].to_homogeneous()
num_walks = 1
walk_length = 5
start = torch.arange(data.num_nodes).view(-1, 1).repeat(1, num_walks).view(-1)
node_seq, edge_seq = random_walk(data.edge_index[0], data.edge_index[1], start, walk_length, num_nodes=data.num_nodes, return_edge_indices=True)
print("Node seq:", node_seq.shape)
>>> Node seq: torch.Size([26128, 6])
print("Edge seq:", edge_seq.shape)
>>> Edge seq: torch.Size([26128, 5])
visited_edge_types = data.edge_type[edge_seq]
print(visited_edge_types)
>>> tensor([[0, 1, 4, 2, 4],
[0, 1, 4, 2, 4],
[0, 2, 4, 2, 5],
...,
[5, 3, 4, 2, 4],
[5, 3, 4, 2, 0],
[5, 3, 4, 2, 5]])
print("Visited edge types:", visited_edge_types.shape)
>>> Visited edge types: torch.Size([26128, 5])
Reference: #139
Thanks @pbielak!
Works perfectly, thanks so much @pbielak!
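One caveat that may be worth checking when indexing with `edge_seq`: the sketch below assumes that a step where no edge could be traversed (for example, a node without outgoing edges) is reported as -1. That is an assumption to verify against the installed torch_cluster version. If it holds, `data.edge_type[edge_seq]` would silently pick the last edge for those steps. The tensors here are hypothetical stand-ins for `data.edge_type` and `edge_seq`:

```python
import torch

# Hypothetical stand-ins for data.edge_type and the returned edge_seq.
edge_type = torch.tensor([0, 1, 4, 2, 5])
edge_seq = torch.tensor([[0, 2, 3], [4, -1, -1]])  # -1: no edge traversed (assumed)

valid = edge_seq >= 0
# Index with a clamped copy so -1 never wraps around to the last edge,
# then overwrite the invalid steps with a sentinel type of -1.
visited_edge_types = edge_type[edge_seq.clamp(min=0)]
visited_edge_types = visited_edge_types.masked_fill(~valid, -1)
print(visited_edge_types)
```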