tugraph-analytics

Support subgraph sampling for GNN

Open zekisong opened this issue 8 months ago • 6 comments

Is there any plan to support subgraph sampling in the future? Our needs are as follows:

  1. Sample N neighbors for each node.
  2. Sample the subgraph to a depth of K.
  3. Include vertex and edge properties in the sampled subgraph.

zekisong, Apr 16 '25

Hello.

  1. We are currently planning to support fetching edges through iterators, so that your algorithm UDF can decide how many edges to load in each iteration. We will provide an example in a subsequent version.
  2. In a UDF such as kHop, you can decide the format of the output subgraph, for example outputting the current vertex together with all of its edges in a specific format. I'm not sure whether this meets your needs.
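Point 1 boils down to pulling a bounded number of edges per step instead of loading them all at once. A minimal, GeaFlow-independent sketch, assuming only a plain `java.util.Iterator` over edges (`EdgeBatches` and `take` are hypothetical names, not GeaFlow APIs):

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/** Hypothetical helper: bounded, incremental loading from an edge iterator. */
final class EdgeBatches {
    private EdgeBatches() {}

    /** Pulls at most maxEdges elements from edges; anything left stays unread. */
    static <E> List<E> take(Iterator<E> edges, int maxEdges) {
        if (maxEdges < 0) {
            throw new IllegalArgumentException("maxEdges must be non-negative: " + maxEdges);
        }
        // Cap the initial capacity so a huge maxEdges does not pre-allocate memory.
        List<E> batch = new ArrayList<>(Math.min(maxEdges, 1024));
        while (batch.size() < maxEdges && edges.hasNext()) {
            batch.add(edges.next());
        }
        return batch;
    }
}
```

Calling `take(edgeIter, 10)` repeatedly walks a high-degree vertex's edges in chunks of 10 without materializing all of them. Note that this is a prefix (first-N) read; a random sample of the neighbors needs a sampling step such as a reservoir on top of it.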

QiZhang1997, Apr 16 '25

Thank you, this completely meets our needs.

zekisong, Apr 16 '25

We will support this as soon as possible. Thank you~

Loognqiang, Apr 20 '25

@Loognqiang Hello, I have designed a GNN sampling scheme. You can take a look.

1.1 Core Capabilities Already Available

GeaFlow already provides solid infrastructure:

Edge Iterator Mechanism: The system has implemented EdgeScanIterator, which supports iterative edge access and filtering.

Edge Limit Support: The IEdgeLimit interface supports limiting the number of incoming and outgoing edges per vertex.

Filter Pushdown: StatePushDown supports pushing filter conditions and limits down to the storage layer.

Vertex-Grouped Edge Iteration: EdgeListScanIterator implements edge iteration grouped by source vertex ID.

Algorithm Runtime Context: AlgorithmRuntimeContext provides iterator interfaces for loading edges.

K-hop Algorithm Example: The existing K-hop implementation demonstrates multi-layer neighbor traversal patterns.

1.2 Capabilities Requiring Extension

Current architecture limitations:

  1. LimitFilter only supports sequential limiting (keeping the first N edges), not random sampling.
  2. There is no dedicated layered sampling API for GNN scenarios.
  3. Sampling strategies (random sampling, importance sampling, etc.) cannot be configured flexibly.

II. Core Design Plan

2.1 New GNN Sampling API Design

2.1.1 Sampling Configuration Interface

/**
 * GNN subgraph sampling configuration
 */
public interface IGNNSamplingConfig extends Serializable {
    
    // Sample size configuration per layer
    List<Integer> getLayerSampleSizes();
    
    // Sampling depth (K)
    int getSamplingDepth();
    
    // Sampling strategy
    SamplingStrategy getStrategy();
    
    // Edge direction
    EdgeDirection getDirection();
    
    // Include vertex properties
    boolean includeVertexProperties();
    
    // Include edge properties
    boolean includeEdgeProperties();
}
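The interface exposes both `getLayerSampleSizes()` and `getSamplingDepth()`, which can disagree (the usage example in 3.1 sets both). One option, sketched below under the assumption that depth should always equal the number of per-layer sizes, is to derive the depth and validate in the builder. `SamplingConfigSketch` is a hypothetical stand-in for the `GNNSamplingConfig` builder used in 3.1; strategy, direction and property flags are omitted to keep it short.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Hypothetical immutable config; only the fields involved in validation are modeled. */
final class SamplingConfigSketch {
    private final List<Integer> layerSampleSizes;

    private SamplingConfigSketch(List<Integer> layerSampleSizes) {
        this.layerSampleSizes = Collections.unmodifiableList(new ArrayList<>(layerSampleSizes));
    }

    List<Integer> getLayerSampleSizes() {
        return layerSampleSizes;
    }

    // Derived rather than stored, so it can never disagree with the per-layer sizes.
    int getSamplingDepth() {
        return layerSampleSizes.size();
    }

    static Builder builder() {
        return new Builder();
    }

    static final class Builder {
        private List<Integer> layerSampleSizes = new ArrayList<>();
        private Integer samplingDepth; // optional; checked against the sizes if set

        Builder layerSampleSizes(List<Integer> sizes) {
            this.layerSampleSizes = new ArrayList<>(sizes);
            return this;
        }

        Builder samplingDepth(int depth) {
            this.samplingDepth = depth;
            return this;
        }

        SamplingConfigSketch build() {
            if (layerSampleSizes.isEmpty()) {
                throw new IllegalStateException("at least one layer sample size is required");
            }
            for (Integer size : layerSampleSizes) {
                if (size == null || size <= 0) {
                    throw new IllegalStateException("layer sample sizes must be positive: " + layerSampleSizes);
                }
            }
            if (samplingDepth != null && samplingDepth != layerSampleSizes.size()) {
                throw new IllegalStateException("samplingDepth " + samplingDepth
                    + " does not match " + layerSampleSizes.size() + " layer sample sizes");
            }
            return new SamplingConfigSketch(layerSampleSizes);
        }
    }
}
```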

2.1.2 Sampling Strategy Enum

public enum SamplingStrategy {
    RANDOM,           // Random sampling
    TOP_K,            // Top K
    WEIGHTED_RANDOM,  // Weighted random sampling
    RESERVOIR         // Reservoir sampling
}
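The enum names the strategies but not how they would work. For WEIGHTED_RANDOM, one well-known single-pass option is the Efraimidis-Spirakis A-Res algorithm. It gives each item the key u^(1/w), with u uniform, and keeps the k largest keys. This is equivalent to drawing items one at a time, each with probability proportional to its weight among the items not yet drawn. The sketch below assumes each edge carries a positive `double` weight. `WeightedSampler` is a hypothetical name, and lists are used instead of an iterator for brevity.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;

/** Hypothetical WEIGHTED_RANDOM implementation (Efraimidis-Spirakis A-Res). */
final class WeightedSampler {
    private WeightedSampler() {}

    private static final class Candidate {
        final double key;
        final int index;

        Candidate(double key, int index) {
            this.key = key;
            this.index = index;
        }
    }

    /** Samples up to k items without replacement; returned order is unspecified. */
    static <T> List<T> sample(List<T> items, List<Double> weights, int k, Random random) {
        if (items.size() != weights.size()) {
            throw new IllegalArgumentException("items and weights must have the same size");
        }
        if (k <= 0) {
            return new ArrayList<>();
        }
        // Min-heap on key: the weakest of the current k candidates sits on top.
        PriorityQueue<Candidate> heap =
            new PriorityQueue<>(Comparator.comparingDouble((Candidate c) -> c.key));
        for (int i = 0; i < items.size(); i++) {
            double w = weights.get(i);
            if (w <= 0) {
                continue; // non-positive weights are never selected
            }
            double u = 1.0 - random.nextDouble(); // in (0, 1], so log(u) is finite
            double key = Math.log(u) / w;          // log(u^(1/w)): same ordering, no underflow
            if (heap.size() < k) {
                heap.offer(new Candidate(key, i));
            } else if (key > heap.peek().key) {
                heap.poll();
                heap.offer(new Candidate(key, i));
            }
        }
        List<T> result = new ArrayList<>(heap.size());
        for (Candidate c : heap) {
            result.add(items.get(c.index));
        }
        return result;
    }
}
```

Because the heap never holds more than k entries, memory stays O(k) per vertex, which matters for high-degree vertices. TOP_K, by contrast, would be deterministic: it keeps the k largest weights directly.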

2.1.3 Extend AlgorithmRuntimeContext

Add new methods to the AlgorithmRuntimeContext interface:

/**
 * Load edge iterator with sampling configuration
 */
CloseableIterator<RowEdge> loadEdgesWithSampling(
    EdgeDirection direction, 
    IGNNSamplingConfig samplingConfig
);

/**
 * Load K-layer neighbor subgraph
 */
GNNSubgraph loadKLayerSubgraph(
    IGNNSamplingConfig samplingConfig
);

2.2 Random Sampling Filter Implementation

2.2.1 RandomSamplingEdgeLimit

Implement the existing IEdgeLimit interface:

public class RandomSamplingEdgeLimit implements IEdgeLimit {
    private final long sampleSize;
    private final EdgeDirection direction;

    public RandomSamplingEdgeLimit(long sampleSize, EdgeDirection direction) {
        this.sampleSize = sampleSize;
        this.direction = direction;
    }

    // Note: a limit only caps how many edges are read per vertex (sequentially,
    // in storage order); the random choice itself must come from a sampling filter.
    @Override
    public long inEdgeLimit() {
        return direction == EdgeDirection.IN ? sampleSize : Long.MAX_VALUE;
    }

    @Override
    public long outEdgeLimit() {
        return direction == EdgeDirection.OUT ? sampleSize : Long.MAX_VALUE;
    }

    @Override
    public LimitType limitType() {
        return LimitType.COMPOSED;
    }
}

2.2.2 RandomSamplingFilter

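Implement a random sampling filter based on reservoir sampling:

public class RandomSamplingFilter extends BaseGraphFilter {
    private final IGraphFilter baseFilter;
    private final int sampleSize;
    private final Random random;
    private final List<IEdge> reservoir;
    private int edgeCount;

    public RandomSamplingFilter(IGraphFilter baseFilter, int sampleSize, Random random) {
        this.baseFilter = baseFilter;
        this.sampleSize = sampleSize;
        this.random = random;
        this.reservoir = new ArrayList<>(sampleSize);
    }

    @Override
    public boolean filterEdge(IEdge edge) {
        if (!baseFilter.filterEdge(edge)) {
            return false;
        }

        // Reservoir sampling (Algorithm R): after n >= sampleSize candidate edges,
        // each one is in the reservoir with probability sampleSize / n
        if (edgeCount < sampleSize) {
            reservoir.add(edge);
        } else {
            int replaceIndex = random.nextInt(edgeCount + 1);
            if (replaceIndex < sampleSize) {
                reservoir.set(replaceIndex, edge);
            }
        }
        edgeCount++;
        // Never pass edges through directly: an edge admitted early may still be
        // evicted, so the sample is only final after the vertex's last edge
        return false;
    }

    /** Sampled edges of the current vertex; valid once its edges are exhausted. */
    public List<IEdge> getSampledEdges() {
        return reservoir;
    }

    /** Clears the state; call before scanning the next vertex's edges. */
    public void reset() {
        reservoir.clear();
        edgeCount = 0;
    }
}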
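The reservoir logic at the heart of RandomSamplingFilter can also be written as a standalone function over an iterator. That is the shape the `reservoirSample(edgeIter, layerSizes[0])` call in the later UDF demo suggests. A sketch, assuming plain `java.util` types; `Reservoir.reservoirSample` is a hypothetical name:

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: uniform sampling of up to k elements from a stream. */
final class Reservoir {
    private Reservoir() {}

    /** Algorithm R: one pass, O(k) memory, no need to know the stream length. */
    static <T> List<T> reservoirSample(Iterator<T> stream, int k, Random random) {
        List<T> reservoir = new ArrayList<>(Math.max(k, 0));
        if (k <= 0) {
            return reservoir;
        }
        int seen = 0;
        while (stream.hasNext()) {
            T item = stream.next();
            if (seen < k) {
                reservoir.add(item); // fill the reservoir first
            } else {
                // Keep the (seen + 1)-th item with probability k / (seen + 1),
                // evicting a uniformly chosen current member.
                int j = random.nextInt(seen + 1);
                if (j < k) {
                    reservoir.set(j, item);
                }
            }
            seen++;
        }
        return reservoir;
    }
}
```

If a vertex has k or fewer edges, all of them are returned in their original order. Otherwise every edge ends up in the sample with the same probability, k / degree.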

2.3 Layered Sampling Iterator

2.3.1 LayeredSamplingIterator

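import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Queue;

public class LayeredSamplingIterator<K, EV>
    implements CloseableIterator<LayeredEdges<K, EV>> {

    private final GraphState<K, ?, EV> graphState;
    private final IGNNSamplingConfig config;
    private final Queue<LayerTask<K>> taskQueue = new ArrayDeque<>();
    private final Map<K, Integer> visitedVertices = new HashMap<>(); // Vertex -> layer first reached

    private int currentLayer = 0;

    public LayeredSamplingIterator(GraphState<K, ?, EV> graphState,
                                   IGNNSamplingConfig config, K rootId) {
        this.graphState = graphState;
        this.config = config;
        // Seed layer 0 with the root and mark it visited so back edges cannot re-queue it
        taskQueue.offer(new LayerTask<>(rootId, 0));
        visitedVertices.put(rootId, 0);
    }

    @Override
    public boolean hasNext() {
        return !taskQueue.isEmpty() && currentLayer < config.getSamplingDepth();
    }

    @Override
    public LayeredEdges<K, EV> next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        List<IEdge<K, EV>> edges = new ArrayList<>();
        int layerSampleSize = config.getLayerSampleSizes().get(currentLayer);

        // Drain only the vertices of the current layer. Targets discovered below are
        // queued for the next call; draining until the queue is empty would merge all layers.
        int layerWidth = taskQueue.size();
        for (int i = 0; i < layerWidth; i++) {
            LayerTask<K> task = taskQueue.poll();

            // Configure the per-vertex edge limit using StatePushDown
            StatePushDown pushDown = StatePushDown.of()
                .withEdgeLimit(new RandomSamplingEdgeLimit(
                    layerSampleSize,
                    config.getDirection()
                ));

            CloseableIterator<IEdge<K, EV>> edgeIter =
                graphState.staticGraph().getEdgeIterator(
                    Collections.singletonList(task.vertexId),
                    pushDown
                );
            try {
                // Collect sampled edges and queue unseen targets for the next layer
                while (edgeIter.hasNext()) {
                    IEdge<K, EV> edge = edgeIter.next();
                    edges.add(edge);

                    K targetId = edge.getTargetId();
                    if (!visitedVertices.containsKey(targetId)) {
                        taskQueue.offer(new LayerTask<>(targetId, currentLayer + 1));
                        visitedVertices.put(targetId, currentLayer + 1);
                    }
                }
            } finally {
                edgeIter.close();
            }
        }

        currentLayer++;
        return new LayeredEdges<>(currentLayer - 1, edges);
    }

    @Override
    public void close() {
        taskQueue.clear();
    }
}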
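Stripped of GeaFlow types, the layered iteration is a breadth-first expansion in which each layer samples up to its fanout per frontier vertex. The sketch below assumes an in-memory adjacency map with `long` vertex IDs and uses shuffle-and-truncate as the uniform sampler. `LayeredSampler` is a hypothetical name. Like `visitedVertices` in the proposal, it expands each vertex at most once; that is a design choice, not the only option.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

/** Hypothetical in-memory version of layer-by-layer (fanout) neighbor sampling. */
final class LayeredSampler {
    private LayeredSampler() {}

    /** Returns the sampled edges of each layer as {src, dst} pairs. */
    static List<List<long[]>> sample(Map<Long, List<Long>> adjacency, long root,
                                     List<Integer> fanouts, Random random) {
        List<List<long[]>> layers = new ArrayList<>();
        Set<Long> visited = new HashSet<>();
        visited.add(root); // the root must not be re-queued by a back edge
        List<Long> frontier = Collections.singletonList(root);

        for (int layer = 0; layer < fanouts.size() && !frontier.isEmpty(); layer++) {
            List<long[]> edges = new ArrayList<>();
            List<Long> next = new ArrayList<>();
            for (long src : frontier) {
                List<Long> neighbors = new ArrayList<>(adjacency.getOrDefault(src, Collections.emptyList()));
                Collections.shuffle(neighbors, random); // uniform sample without replacement
                int take = Math.min(fanouts.get(layer), neighbors.size());
                for (long dst : neighbors.subList(0, take)) {
                    edges.add(new long[] {src, dst});
                    if (visited.add(dst)) {
                        next.add(dst); // first time reached: expand it in the next layer
                    }
                }
            }
            layers.add(edges);
            frontier = next; // only vertices discovered in this layer are expanded next
        }
        return layers;
    }
}
```

Shuffling copies the whole neighbor list, which is fine for a sketch. Against GeaFlow storage, a reservoir over the edge iterator would avoid materializing high-degree neighbor lists.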

2.4 GNN Subgraph Data Structure

2.4.1 GNNSubgraph

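public class GNNSubgraph<K, VV, EV> implements Serializable {

    private final K rootVertexId;
    private final Map<Integer, Set<K>> layeredVertices; // Layer -> vertex set
    private final Map<Integer, List<IEdge<K, EV>>> layeredEdges; // Layer -> edge list
    private final Map<K, IVertex<K, VV>> vertexMap; // Vertex properties

    public GNNSubgraph(K rootVertexId,
                       Map<Integer, Set<K>> layeredVertices,
                       Map<Integer, List<IEdge<K, EV>>> layeredEdges,
                       Map<K, IVertex<K, VV>> vertexMap) {
        this.rootVertexId = rootVertexId;
        this.layeredVertices = layeredVertices;
        this.layeredEdges = layeredEdges;
        this.vertexMap = vertexMap;
    }

    public K getRootVertexId() {
        return rootVertexId;
    }
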
    public int getDepth() {
        return layeredVertices.size();
    }
    
    public Set<K> getVerticesAtLayer(int layer) {
        return layeredVertices.getOrDefault(layer, Collections.emptySet());
    }
    
    public List<IEdge<K, EV>> getEdgesAtLayer(int layer) {
        return layeredEdges.getOrDefault(layer, Collections.emptyList());
    }
    
    public IVertex<K, VV> getVertex(K vertexId) {
        return vertexMap.get(vertexId);
    }
}

III. Usage Examples

3.1 Basic GNN Sampling UDF

@Description(name = "gnn_sampling", description = "GNN subgraph sampling")
public class GNNSamplingAlgorithm implements AlgorithmUserFunction<Object, GNNMessage> {
    
    private AlgorithmRuntimeContext<Object, GNNMessage> context;
    private IGNNSamplingConfig samplingConfig;
    
    @Override
    public void init(AlgorithmRuntimeContext<Object, GNNMessage> context, Object[] params) {
        this.context = context;
        
        // Configuration: sample 10 neighbors per layer, depth of 2
        this.samplingConfig = GNNSamplingConfig.builder()
            .layerSampleSizes(Arrays.asList(10, 10))
            .samplingDepth(2)
            .strategy(SamplingStrategy.RANDOM)
            .direction(EdgeDirection.OUT)
            .includeVertexProperties(true)
            .includeEdgeProperties(true)
            .build();
    }
    
    @Override
    public void process(RowVertex vertex, Optional<Row> updatedValues, 
                       Iterator<GNNMessage> messages) {
        if (context.getCurrentIterationId() == 1L) {
            // First iteration: load K-layer subgraph
            GNNSubgraph subgraph = context.loadKLayerSubgraph(samplingConfig);
            
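            // Send the subgraph to the GNN training component. GNNSubgraph still
            // needs a serialize() method that produces the String stored in the
            // "subgraph" output column.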
            context.take(ObjectRow.create(
                vertex.getId(), 
                subgraph.serialize()
            ));
        }
    }
    
    @Override
    public void finish(RowVertex vertex, Optional<Row> newValue) {
        // Cleanup resources
    }
    
    @Override
    public StructType getOutputType(GraphSchema graphSchema) {
        return new StructType(
            new TableField("vertex_id", graphSchema.getIdType(), false),
            new TableField("subgraph", StringType.INSTANCE, false)
        );
    }
}

3.2 Progressive Sampling Using Iterator

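@Override
public void process(RowVertex vertex, Optional<Row> updatedValues,
                   Iterator<GNNMessage> messages) {
    // A vertex can only load its own edges, so looping over all layers inside one
    // process() call would re-read the same edges K times. Instead, layer k is
    // expanded in iteration k + 1 by the vertices sampled in the previous layer.
    int layer = (int) context.getCurrentIterationId() - 1;
    if (layer >= samplingConfig.getSamplingDepth()) {
        return;
    }
    if (layer > 0 && !messages.hasNext()) {
        return; // not sampled by the previous layer
    }

    // Current layer sample size
    int sampleSize = samplingConfig.getLayerSampleSizes().get(layer);

    // Load edges using the iterator to control memory usage
    // (withLayerSampleSize is not declared on IGNNSamplingConfig yet)
    CloseableIterator<RowEdge> edgeIter = context.loadEdgesWithSampling(
        EdgeDirection.OUT,
        samplingConfig.withLayerSampleSize(sampleSize)
    );

    List<RowEdge> sampledEdges = new ArrayList<>();
    try {
        // Control the load count in the UDF
        while (edgeIter.hasNext() && sampledEdges.size() < sampleSize) {
            sampledEdges.add(edgeIter.next());
        }
    } finally {
        edgeIter.close();
    }

    // Process the current layer's sampling results
    processLayerEdges(layer, sampledEdges);

    // Activate the sampled neighbors so they expand the next layer. GNNMessage is
    // assumed to carry the layer (and, to keep per-root subgraphs apart, the root ID).
    for (RowEdge edge : sampledEdges) {
        context.sendMessage(edge.getTargetId(), new GNNMessage(layer + 1));
    }
}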

kaori-seasons, Nov 10 '25

@kaori-seasons My suggestion is to implement this as a general-purpose sampling UDF rather than one limited to GNN.

In addition, GeaFlow already supports the edge iterator, so users can easily implement their own sampling UDFs on top of it (PR).

qingwen220, Nov 10 '25

@qingwen220 Following your UDF suggestion, I have written a demo use case; please take a look when you have time.

@Description(name = "neighbor_sampling", description = "K-hop neighbor sampling with configurable strategies")
public class NeighborSamplingUDF implements AlgorithmUserFunction<Object, SamplingMessage> {

    private AlgorithmRuntimeContext<Object, SamplingMessage> context;
    private int[] layerSizes;  // e.g., [10, 5] for 2-hop with 10 and 5 samples
    private int depth;
    private EdgeDirection direction;

    @Override
    public void init(AlgorithmRuntimeContext<Object, SamplingMessage> context, Object[] params) {
        this.context = context;
        // Assumes the runtime passes the layer sizes as an int[]
        this.layerSizes = (int[]) params[0];
        this.depth = layerSizes.length;
        // The query passes the direction as a string such as 'OUT'
        this.direction = EdgeDirection.valueOf(String.valueOf(params[1]));
    }

    @Override
    public void process(RowVertex vertex, Optional<Row> updatedValues,
                       Iterator<SamplingMessage> messages) {
        // Hop k is sampled in iteration k + 1; stop once all layers are done
        int layer = (int) context.getCurrentIterationId() - 1;
        if (layer >= depth) {
            return;
        }

        // Use the existing edge iterator (StatePushDown applies any filters/limits)
        CloseableIterator<RowEdge> edgeIter = context.loadEdgesIterator(direction);
        List<RowEdge> sampledEdges;
        try {
            // Sample with the size configured for this layer; reservoirSample and
            // createOutput are helpers not shown here
            sampledEdges = reservoirSample(edgeIter, layerSizes[layer]);
        } finally {
            edgeIter.close();
        }

        // Output or send to next layer
        context.take(createOutput(vertex, sampledEdges));
    }

    // finish() and getOutputType() omitted for brevity
}

Usage

-- Create the sampling function  
CREATE FUNCTION neighbor_sample AS 'com.example.NeighborSamplingUDF';  
  
-- Use in graph query  
MATCH (v:Person)  
CALL neighbor_sample(v, [10, 5], 'OUT')   
YIELD vertex_id, neighbors  
RETURN vertex_id, neighbors;

kaori-seasons, Nov 10 '25