alignn

load graph data directly from disk to reduce memory requirements for large datasets

Open bdecost opened this issue 3 years ago • 0 comments

For example, the MEGNet dataset is rather large, as are OQMD and AFLOW; holding every graph for these datasets in memory at once is costly.

Proposal: store graph data in HDF5, using the dataset index as the top-level keys. Alternative key: the structure identifier, e.g. "JVASP-1234" or "MP-5678" (the snippet below uses this option).
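A minimal sketch of the two keying options, using h5py with toy NumPy edge lists in place of real graphs. The file name `keys_demo.hdf5`, the ids, and the array contents are made up for illustration; only the group layout matters here.

```python
import h5py
import numpy as np

# toy edge lists standing in for real structure graphs (hypothetical data)
graphs = {
    "JVASP-1234": np.array([[0, 1], [1, 0]]),
    "JVASP-5678": np.array([[0, 1], [1, 2], [2, 0]]),
}

with h5py.File("keys_demo.hdf5", "w") as f:
    by_index = f.create_group("by_index")
    by_id = f.create_group("by_id")
    for idx, (jid, edges) in enumerate(graphs.items()):
        # option 1: dataset index as key (HDF5 names must be strings)
        by_index[str(idx)] = edges
        # option 2: structure identifier as key
        by_id[jid] = edges

with h5py.File("keys_demo.hdf5", "r") as f:
    # both keys point at the same stored graph
    edges_by_idx = f["by_index/1"][()]
    edges_by_id = f["by_id/JVASP-5678"][()]
    print(np.array_equal(edges_by_idx, edges_by_id))  # True
    print(sorted(f["by_index"].keys()))  # ['0', '1']
```

One design point: identifier keys let a structure be looked up directly by name, while index keys need a separate index-to-id table. Also, h5py iterates group members in name order by default, so an index key "10" comes before "2" unless the keys are zero-padded or the order is tracked explicitly.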

```python
import h5py
import pandas as pd
from jarvis.db.figshare import data as jdata

identifier = "jid"  # column holding the structure id
df = pd.DataFrame(jdata("dft_3d"))

with h5py.File("dft_3d.hdf5", "w") as f:
    for _, row in df.iterrows():
        # store graph data in an HDF5 group keyed by structure id,
        # e.g. "JVASP-1234" (use a new name so `identifier` is not overwritten)
        key = row[identifier]
        group = f.create_group(key)
        ndata = group.create_group("ndata")
        edata = group.create_group("edata")

        # build_dgl_graph is a placeholder for the graph constructor (not defined here);
        # JARVIS entries store the structure as an Atoms dict under "atoms"
        graph = build_dgl_graph(row["atoms"])

        # store edge list representation (convert tensors to NumPy for h5py)
        u, v = graph.edges()
        group["u"] = u.numpy()
        group["v"] = v.numpy()

        # store node data in a subgroup "ndata"
        # e.g. f["JVASP-1234/ndata/atomic_number"]
        for name, node_feature in graph.ndata.items():
            ndata[name] = node_feature.numpy()

        # store edge data in a subgroup "edata"
        # e.g. f["JVASP-1234/edata/r"]
        for name, edge_feature in graph.edata.items():
            edata[name] = edge_feature.numpy()
```
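To check the on-disk layout without DGL, here is a sketch that writes and reads back one graph using plain NumPy arrays. `write_graph`, `read_graph`, and `roundtrip_demo.hdf5` are hypothetical names, and the feature names `atomic_number` and `r` only mirror the comments in the snippet above.

```python
import h5py
import numpy as np


def write_graph(f, key, u, v, ndata, edata):
    """Store one graph as an edge list plus node/edge feature subgroups."""
    group = f.create_group(key)
    group["u"] = u
    group["v"] = v
    node_group = group.create_group("ndata")
    edge_group = group.create_group("edata")
    for name, values in ndata.items():
        node_group[name] = values
    for name, values in edata.items():
        edge_group[name] = values


def read_graph(f, key):
    """Read one graph back; only this group's datasets are loaded."""
    group = f[key]
    u = group["u"][()]
    v = group["v"][()]
    ndata = {name: ds[()] for name, ds in group["ndata"].items()}
    edata = {name: ds[()] for name, ds in group["edata"].items()}
    return u, v, ndata, edata


with h5py.File("roundtrip_demo.hdf5", "w") as f:
    write_graph(
        f,
        "JVASP-1234",
        u=np.array([0, 1]),
        v=np.array([1, 0]),
        ndata={"atomic_number": np.array([14, 14])},
        edata={"r": np.array([[2.35, 0.0, 0.0], [-2.35, 0.0, 0.0]])},
    )

with h5py.File("roundtrip_demo.hdf5", "r") as f:
    u, v, ndata, edata = read_graph(f, "JVASP-1234")

print(u, v, ndata["atomic_number"])  # [0 1] [1 0] [14 14]
```

Because h5py reads datasets lazily, `read_graph` touches only the requested structure's group, which is what keeps memory use independent of dataset size.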

Data loading could then look like this:

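```python
import dgl
import h5py
import torch
from torch.utils.data import Dataset


class StructureDataset(Dataset):
    def __init__(self, path, identifiers):
        self.path = path
        self.identifiers = identifiers  # structure ids, e.g. ["JVASP-1234", ...]
        self.dataset = None  # opened lazily, once per DataLoader worker

    def __len__(self):
        return len(self.identifiers)

    def __getitem__(self, idx):
        # https://discuss.pytorch.org/t/dataloader-when-num-worker-0-there-is-bug/25643/16
        if self.dataset is None:
            self.dataset = h5py.File(self.path, "r")

        # look up structure id, e.g. "JVASP-1234"
        key = self.identifiers[idx]
        group = self.dataset[key]
        ndata = group["ndata"]
        edata = group["edata"]

        # load graph from edge list; [()] reads an HDF5 dataset into a NumPy array
        # (num_nodes is inferred from the largest node id)
        u = torch.from_numpy(group["u"][()])
        v = torch.from_numpy(group["v"][()])
        g = dgl.graph((u, v))

        for name in ndata:
            g.ndata[name] = torch.from_numpy(ndata[name][()])

        for name in edata:
            g.edata[name] = torch.from_numpy(edata[name][()])

        return g
```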
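The lazy `h5py.File` open in `__getitem__` follows the linked PyTorch forum thread: each DataLoader worker should create its own file handle rather than inherit one from the parent process. A complementary pattern (an assumption, not part of this proposal) is to drop the handle whenever the dataset object is pickled, so an open file is never copied to workers. The sketch below uses only h5py and NumPy so it runs without DGL or PyTorch; `LazyH5Dataset` and `lazy_demo.hdf5` are hypothetical names.

```python
import pickle

import h5py
import numpy as np


class LazyH5Dataset:
    def __init__(self, path, identifiers):
        self.path = path
        self.identifiers = list(identifiers)
        self.dataset = None  # opened on first access, per process

    def __len__(self):
        return len(self.identifiers)

    def __getstate__(self):
        # never pickle an open h5py.File; the copy reopens the file lazily
        state = self.__dict__.copy()
        state["dataset"] = None
        return state

    def __getitem__(self, idx):
        if self.dataset is None:
            self.dataset = h5py.File(self.path, "r")
        group = self.dataset[self.identifiers[idx]]
        return group["u"][()], group["v"][()]


with h5py.File("lazy_demo.hdf5", "w") as f:
    f["JVASP-1234/u"] = np.array([0, 1])
    f["JVASP-1234/v"] = np.array([1, 0])

ds = LazyH5Dataset("lazy_demo.hdf5", ["JVASP-1234"])
u, v = ds[0]  # opens the file in this process
clone = pickle.loads(pickle.dumps(ds))  # the open handle is not copied
print(clone.dataset is None, clone[0][0].tolist())  # True [0, 1]
```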

bdecost, Apr 21 '21 15:04