alignn
Angle Information
Hi @bdecost,
When I adapted your code to print the line graph, I ran into a result I don't understand.
In the code below, I look up (node_i, node_j) from ij_pair and (nodej_, nodek_) from jk_pair. I would expect node_j and nodej_ to be the same node, but in some cases they do not match when I run the code.
From your paper, I understood that each angle in the line graph is defined by node_i, node_j, and node_k. Why is the shared node_j not the same?
############################################ CIF file (mp-2500.cif)
'''generated using pymatgen'''
data_AlCu
_symmetry_space_group_name_H-M 'P 1'
_cell_length_a 6.37716407
_cell_length_b 6.37716407
_cell_length_c 6.92031335
_cell_angle_alpha 57.14549155
_cell_angle_beta 57.14549155
_cell_angle_gamma 37.46229262
_symmetry_Int_Tables_number 1
_chemical_formula_structural AlCu
_chemical_formula_sum 'Al5 Cu5'
_cell_volume 140.31041575
cell_formula_units_Z 5
loop
_symmetry_equiv_pos_site_id
symmetry_equiv_pos_as_xyz
1 'x, y, z'
loop
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Al Al0 1 0.50000000 0.50000000 0.50000000 1
Al Al1 1 0.15622000 0.15622000 0.53856900 1
Al Al2 1 0.84378000 0.84378000 0.46143100 1
Al Al3 1 0.37823100 0.37823100 0.00427500 1
Al Al4 1 0.62176900 0.62176900 0.99572500 1
Cu Cu5 1 0.00000000 0.00000000 0.00000000 1
Cu Cu6 1 0.25794700 0.25794700 0.75941600 1
Cu Cu7 1 0.74205300 0.74205300 0.24058400 1
Cu Cu8 1 0.10895200 0.10895200 0.22813800 1
Cu Cu9 1 0.89104800 0.89104800 0.77186200 1
############################################# Code
import os
import dgl
import numpy as np
import torch
from jarvis.core.atoms import Atoms
from jarvis.core.graphs import Graph
from torch_geometric.data import InMemoryDataset, Data, Batch
from torch_geometric.utils.convert import from_networkx

raw_path = './mp-2500.cif'
crystal = Atoms.from_cif(raw_path, use_cif2cell=False)
coords = crystal.cart_coords
graph = Graph.atom_dgl_multigraph(crystal, cutoff=8.0, atom_features='cgcnn', max_neighbors=12, compute_line_graph=True, use_canonize=False)
for i in [0,1]:
    '''Atom-Bond Graph'''
    if i==0:
        g = from_networkx(dgl.DGLGraph.to_networkx(graph[i], node_attrs=['atom_features'], edge_attrs=['r']))
        x = torch.tensor([x.detach().numpy() for x in g.atom_features])
        z = torch.tensor(crystal.atomic_numbers)
        pos = torch.tensor(coords, dtype=torch.float)
        edge_id = g.id
        edge_pos = torch.tensor([x.detach().numpy() for x in g.r])
        edge_index = g.edge_index
        edge_distance = torch.tensor(np.linalg.norm(graph[i].edata['r'], axis=1))
        ab_g = Data(x=x, z=z, pos=pos, edge_id=edge_id, edge_index=edge_index, edge_distance=edge_distance, edge_pos=edge_pos, idx=n)
    '''Line Graph'''
    if i==1:
        g = from_networkx(dgl.DGLGraph.to_networkx(graph[i], node_attrs=['r'], edge_attrs=['h']))
        x = torch.tensor(np.linalg.norm(graph[i].ndata['r'], axis=1))
        pos = torch.tensor([x.detach().numpy() for x in g.r])
        edge_id = g.id
        edge_index = g.edge_index
        edge_angle = g.h
        ba_g = Data(x=x, pos=pos, edge_id=edge_id, edge_index=edge_index, edge_angle=edge_angle, idx=n)
dataset = [ab_g, ba_g]

'''dataset[1] = Line Graph'''
'''dataset[0] = Atom-Bond Graph'''
ij_pair = dataset[1].edge_index[0]
jk_pair = dataset[1].edge_index[1]
node_i = dataset[0].edge_index[0][ij_pair]
node_j = dataset[0].edge_index[1][ij_pair]
nodej_ = dataset[0].edge_index[0][jk_pair]
nodek_ = dataset[0].edge_index[1][jk_pair]

################################################################# Result
node_i[0:10], node_j[0:10], nodej_[0:10], nodek_[0:10]
I'm having a bit of trouble reproducing your result -- mainly because I am not able to parse the CIF data you shared. I have tried the jarvis, ase, and pymatgen parsers. Could you maybe share a bit more about the structure, or post a jarvis or pymatgen json representation? The structure does not seem to quite match the CIF data from https://materialsproject.org/materials/mp-2500?material_ids=mp-2500
One other question (see later discussion for context) -- what is n? It seems like it's doing something index related but I don't see it defined.
Your expectations are right though: in the directed graph, bond pairs like i -> j -> k should have a pair of edges (i, j) and (j, k) in the primary graph, and the line graph should have an edge connecting these two bonds, (i, j) -> (j, k).
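To make that expected invariant concrete, here is a minimal pure-Python sketch (no dgl required) that builds a directed line graph from a toy edge list and checks the shared atom. The helper name build_line_graph and the toy edges are made up for illustration; this is not how dgl implements line_graph, and keeping or dropping backtracking pairs like (i, j) -> (j, i) is a choice made here, not something taken from the alignn code.

```python
# Toy directed primary graph: each bond is stored once per direction.
# The position in the list plays the role of the edge (bond) id.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]

def build_line_graph(edges, backtracking=True):
    """Return line-graph edges (a, b): bond a ends at the atom bond b starts from."""
    lg_edges = []
    for a, (i, j) in enumerate(edges):
        for b, (j2, k) in enumerate(edges):
            if a == b or j != j2:
                continue  # bonds do not chain as i -> j -> k
            if not backtracking and k == i:
                continue  # optionally skip i -> j -> i
            lg_edges.append((a, b))
    return lg_edges

lg_edges = build_line_graph(edges)

# Invariant: for every line-graph edge, dst of bond a equals src of bond b.
shared_ok = all(edges[a][1] == edges[b][0] for a, b in lg_edges)
print(lg_edges)
print(shared_ok)
```

If this invariant fails on a converted graph, the line-graph indices are most likely pointing at a different edge ordering than the one they were built from.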
I'm wondering if something is going wrong during the dgl -> networkx -> pyG conversion?
I'm not really familiar with pytorch geometric, but I think this is building an index array into the line graph edges, which should share edge ids in the primary graph
ij_pair = dataset[1].edge_index[0]
jk_pair = dataset[1].edge_index[1]
as far as I can tell this bit of code is equivalent to the dgl function line_graph.edges()
In [105]: lg.edges()
Out[105]:
(tensor([ 0, 0, 0, ..., 251, 251, 251]),
tensor([ 1, 25, 59, ..., 246, 248, 250]))
then I think this bit of code is looking up the corresponding node/atom ids, right?
node_i = dataset[0].edge_index[0][ij_pair] # src, bond 1
node_j = dataset[0].edge_index[1][ij_pair] # dst, bond 1
nodej_ = dataset[0].edge_index[0][jk_pair] # src, bond 2
nodek_ = dataset[0].edge_index[1][jk_pair] # dst, bond 2
and (dst, bond 1) should match (src, bond 2).
That all seems fine. One question -- what is n? I see it's doing something index related in ba_g = Data(x=x, pos=pos, edge_id=edge_id, edge_index=edge_index, edge_angle=edge_angle, idx=n), but I don't see where it is defined.
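As a rough illustration of that lookup, the sketch below uses plain Python lists in the pyG edge_index layout and reports which line-graph edges break the (dst, bond 1) == (src, bond 2) rule. The function find_mismatches and the toy index arrays are hypothetical, chosen only to show the idea; with real tensors the same comparison could be done with indexing as in the snippet above.

```python
# Primary graph in pyG layout: row 0 = src atom, row 1 = dst atom.
edge_index = [[0, 1, 1, 2],
              [1, 0, 2, 1]]

# Line graph in pyG layout: row 0 = id of bond 1, row 1 = id of bond 2.
lg_edge_index = [[0, 0, 1, 2, 3, 3],
                 [1, 2, 0, 3, 1, 2]]

def find_mismatches(edge_index, lg_edge_index):
    """List (triplet id, bond 1, bond 2) where dst of bond 1 != src of bond 2."""
    src, dst = edge_index
    mismatches = []
    for t, (ij, jk) in enumerate(zip(*lg_edge_index)):
        if dst[ij] != src[jk]:
            mismatches.append((t, (src[ij], dst[ij]), (src[jk], dst[jk])))
    return mismatches

print(find_mismatches(edge_index, lg_edge_index))

# A deliberately wrong line-graph edge: bond 0 = (0, 1) paired with bond 3 = (2, 1).
broken = [[0], [3]]
print(find_mismatches(edge_index, broken))
```

Printing the offending pairs, rather than only a True/False flag, makes it easier to see whether the mismatches follow a pattern such as a shifted or regrouped edge order.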
Here's my minimal test example (sticking with dgl):
using the non-symmetrized CIF from MP, I would load the data like this. We might need to check if the graph building code needs to get pushed back upstream to jarvis-tools...
import dgl
import dgl.function as fn
import torch
from jarvis.core.atoms import Atoms
from alignn.graphs import Graph
crystal = Atoms.from_cif("debugging/mp-2500-dl.cif", use_cif2cell=False)
g, lg = Graph.atom_dgl_multigraph(crystal)
print(g)
Graph(num_nodes=10, num_edges=252,
ndata_schemes={'atom_features': Scheme(shape=(92,), dtype=torch.float32), 'lattice_mat': Scheme(shape=(3, 3), dtype=torch.float64), 'V': Scheme(shape=(), dtype=torch.float32)}
edata_schemes={'r': Scheme(shape=(3,), dtype=torch.float32)})
next, we can add the node id pairs to the edge data to propagate to the line graph
# add node and edge ids to propagate to line graph
src, dst = g.edges()
pair_ids = torch.stack((src,dst), dim=-1)
g.edata["pair_ids"] = pair_ids
we get a 2D array with (src, dst) pairs for each edge
In [110]: g.edata["pair_ids"]
Out[110]:
tensor([[0, 6],
[6, 0],
[0, 7],
[7, 0],
[0, 8],
[8, 0],
[0, 8],
[8, 0],
[0, 9],
[9, 0],
[0, 9],
...
We can rebuild the line graph to propagate these features (or we could just add them to the edge data since we know the node/edge ordering hasn't changed)
lg = g.line_graph(shared=True)
# get linear indices into line graph source and dest ids for triplets
u, v = lg.edges()
If we index into the line graph node data (corresponding to the bond data), the ids should match as (i, j), (j, k), i.e. the second column of the line graph source ids should match the first column of the dest ids
In [113]: lg.ndata["pair_ids"][u]
Out[113]:
tensor([[0, 6],
[0, 6],
[0, 6],
...,
[8, 9],
[8, 9],
[8, 9]])
In [114]: lg.ndata["pair_ids"][v]
Out[114]:
tensor([[6, 0],
[6, 1],
[6, 2],
...,
[9, 7],
[9, 4],
[9, 8]])
We can check all edges in the line graph:
# (src, dst) atom pairs of the source and destination bond of every line graph edge
lg_src = lg.ndata["pair_ids"][u]
lg_dst = lg.ndata["pair_ids"][v]
In [121]: (lg_src[:,1] == lg_dst[:,0]).all()
Out[121]: tensor(True)
So this seems ok to me: everything matches as expected. It would be nice to check the actual structure you are having an issue with, but at the moment I think that would probably just produce a different starting graph structure. It might be a good idea to make sure node and edge indices match across the networkx and pyG graphs; that's what I would check first.
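One way such an index check could look is sketched below in plain Python. It assumes a hypothetical scenario in which a conversion step keeps the same edges but regroups them by source atom; I have not confirmed that the dgl -> networkx -> pyG path actually does this. The names shared_atom_ok and remap_edge_ids and all toy data are illustrative only.

```python
# Edge order as stored in the original graph (edge id = list position).
edges_orig = [(0, 1), (1, 0), (1, 2), (2, 1), (0, 2), (2, 0)]

# Line-graph edges, expressed as pairs of edge ids into edges_orig.
lg_edges = [(0, 2), (3, 1), (4, 3), (2, 5)]

# Hypothetical converted graph: same edges, regrouped by source atom.
edges_conv = sorted(edges_orig)

def shared_atom_ok(edges, lg_edges):
    """True if dst of bond a equals src of bond b for every line-graph edge."""
    return all(edges[a][1] == edges[b][0] for a, b in lg_edges)

def remap_edge_ids(edges_old, edges_new):
    """Map old edge ids to new ones. Assumes no parallel edges between atoms."""
    new_id = {e: n for n, e in enumerate(edges_new)}
    return [new_id[e] for e in edges_old]

old_to_new = remap_edge_ids(edges_orig, edges_conv)
lg_fixed = [(old_to_new[a], old_to_new[b]) for a, b in lg_edges]

print(shared_atom_ok(edges_orig, lg_edges))  # original ordering
print(shared_atom_ok(edges_conv, lg_edges))  # stale ids on reordered edges
print(shared_atom_ok(edges_conv, lg_fixed))  # ids remapped to the new ordering
```

Note that the remapping by (src, dst) pair only works when each pair is unique. The real graph here is a multigraph with repeated pairs such as (0, 8), so an actual fix would need to carry an explicit edge id through the conversion instead.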