ensmallen icon indicating copy to clipboard operation
ensmallen copied to clipboard

Friendly Recipe for Reading a Graph

Open LunarEngineerSF opened this issue 1 year ago • 0 comments

Good morning!

I thought I'd contribute this because I found it helpful. I wanted to read a large graph from existing datasets of nodes and edges. Because those datasets were fairly large, I wrote a helper to simplify the process, and I thought it might also serve as useful documentation of how Ensmallen builds graphs.

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
from ensmallen import GraphBuilder, Graph

def read_graph_from_datasets(
    node_dataset: ds.Dataset,
    edge_dataset: ds.Dataset,
    *,
    name: str = "NEATO_KEENO_NAMEO",
    directed: bool = False,
) -> Graph:
    """Ingest node and edge datasets into an Ensmallen graph.

    Parameters
    ----------
    node_dataset: ds.Dataset
        Dataset of nodes. Reads the ``node_id`` column (used as the node
        name) and the ``feature`` column (used as the single node type).
    edge_dataset: ds.Dataset
        Dataset of edges. Reads ``left_feature``, ``left_value``,
        ``right_feature`` and ``right_value``. The source node name is
        ``"<left_feature>_<left_value>"``, the destination node name is
        ``"<right_feature>_<right_value>"``, and the edge type is
        ``"<left_feature>_<right_feature>"``.
    name: str, optional
        Name given to the graph (keyword-only).
    directed: bool, optional
        Whether the graph is directed (keyword-only). Defaults to False.

    Returns
    -------
    graph: Graph
        The graph produced by ``GraphBuilder.build()``.
    """
    builder = GraphBuilder()
    builder.set_directed(directed)
    builder.set_name(name)

    # Add the nodes. Passing ``columns=`` lets the dataset scanner read
    # only the two columns we need instead of every column in the files.
    # NOTE(review): edges below reference nodes as "<feature>_<value>";
    # this assumes ``node_id`` is already stored in that same form — confirm.
    for batch in node_dataset.to_batches(columns=["node_id", "feature"]):
        node_ids = batch.column("node_id").to_pylist()
        features = batch.column("feature").to_pylist()
        for node_id, feature in zip(node_ids, features):
            builder.add_node(
                name=node_id,
                node_type=[feature],  # Must be a list of labels.
            )

    # Add the edges.
    edge_columns = ["left_feature", "left_value", "right_feature", "right_value"]
    for batch in edge_dataset.to_batches(columns=edge_columns):
        left_feature = batch.column("left_feature")
        right_feature = batch.column("right_feature")
        # binary_join_element_wise treats its last argument as the separator.
        sources = pc.binary_join_element_wise(
            left_feature, batch.column("left_value"), "_"
        )
        destinations = pc.binary_join_element_wise(
            right_feature, batch.column("right_value"), "_"
        )
        edge_types = pc.binary_join_element_wise(
            left_feature, right_feature, "_"
        )
        for src, dst, edge_type in zip(
            sources.to_pylist(),
            destinations.to_pylist(),
            edge_types.to_pylist(),
        ):
            builder.add_edge(src=src, dst=dst, edge_type=edge_type)

    # Build the graph.
    return builder.build()

LunarEngineerSF avatar Jan 11 '24 14:01 LunarEngineerSF