Support subgraph sampling for GNN
Is there any plan to provide subgraph sampling support in the future? Our needs are as follows:
1. Sample N neighbors for each node.
2. Sample subgraphs to a depth of K.
3. The subgraph contains vertex or edge properties.
Hello.
- We are currently planning to support fetching edges through iterators, so that you can decide how many edges to load in each iteration of your algorithm UDF. We will provide an example in a subsequent version.
- In a UDF such as kHop, you can decide the format of the output subgraph, for example outputting the current vertex together with all of its edges in a specific format. I'm not sure whether this meets your needs.
Thank you, this completely meets our needs.
We will support it as soon as possible. Thank you~
@Loognqiang Hello, I have designed a GNN sampling scheme. You can take a look.
1.1 Core Capabilities Already Available
GeaFlow already provides solid infrastructure:
Edge Iterator Mechanism: The system has implemented EdgeScanIterator, supporting iterative edge access and filtering.
Edge Limit Support: The IEdgeLimit interface supports limiting the number of incoming and outgoing edges per vertex.
Filter Pushdown: StatePushDown supports pushing filter conditions and limits down to the storage layer.
Vertex-Grouped Edge Iteration: EdgeListScanIterator has implemented edge iteration grouped by source vertex ID.
Algorithm Runtime Context: AlgorithmRuntimeContext provides iterator interfaces for loading edges.
K-hop Algorithm Example: The existing K-hop implementation demonstrates multi-layer neighbor traversal patterns.
1.2 Capabilities Requiring Extension
Current architecture limitations:
- LimitFilter only supports sequential limiting, not random sampling.
- No dedicated layered sampling API for GNN scenarios.
- Lack of flexible sampling strategy configuration (random sampling, importance sampling, etc.).
II. Core Design Plan
2.1 New GNN Sampling API Design
2.1.1 Sampling Configuration Interface
import java.io.Serializable;
import java.util.List;

/**
 * GNN subgraph sampling configuration
 */
public interface IGNNSamplingConfig extends Serializable {
    // Sample size configuration per layer (one entry per hop)
    List<Integer> getLayerSampleSizes();
    // Sampling depth (K); expected to equal getLayerSampleSizes().size()
    int getSamplingDepth();
    // Sampling strategy
    SamplingStrategy getStrategy();
    // Edge direction
    EdgeDirection getDirection();
    // Include vertex properties
    boolean includeVertexProperties();
    // Include edge properties
    boolean includeEdgeProperties();
}
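The interface keeps the depth and the per-layer sizes as two separate values, and the usage example in 3.1 calls a GNNSamplingConfig.builder() that is not defined yet. A minimal sketch of such a builder follows, assuming the depth should always equal the number of per-layer sizes (so it is derived rather than stored). The nested SamplingStrategy and EdgeDirection enums are local stand-ins for illustration, not GeaFlow's own types.

```java
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class GNNSamplingConfigSketch {

    // Hypothetical local stand-ins for the proposal's SamplingStrategy and GeaFlow's EdgeDirection.
    enum SamplingStrategy { RANDOM, TOP_K, WEIGHTED_RANDOM, RESERVOIR }
    enum EdgeDirection { IN, OUT, BOTH }

    /** Immutable configuration; depth is derived from the per-layer sizes so the two cannot disagree. */
    static final class GNNSamplingConfig implements Serializable {
        private static final long serialVersionUID = 1L;
        private final List<Integer> layerSampleSizes;
        private final SamplingStrategy strategy;
        private final EdgeDirection direction;
        private final boolean includeVertexProperties;
        private final boolean includeEdgeProperties;

        private GNNSamplingConfig(Builder b) {
            this.layerSampleSizes = Collections.unmodifiableList(new ArrayList<>(b.layerSampleSizes));
            this.strategy = b.strategy;
            this.direction = b.direction;
            this.includeVertexProperties = b.includeVertexProperties;
            this.includeEdgeProperties = b.includeEdgeProperties;
        }

        static Builder builder() { return new Builder(); }

        List<Integer> getLayerSampleSizes() { return layerSampleSizes; }
        int getSamplingDepth() { return layerSampleSizes.size(); }
        SamplingStrategy getStrategy() { return strategy; }
        EdgeDirection getDirection() { return direction; }
        boolean includeVertexProperties() { return includeVertexProperties; }
        boolean includeEdgeProperties() { return includeEdgeProperties; }
    }

    static final class Builder {
        private List<Integer> layerSampleSizes = new ArrayList<>();
        private Integer samplingDepth; // optional; if set, it must match layerSampleSizes.size()
        private SamplingStrategy strategy = SamplingStrategy.RANDOM;
        private EdgeDirection direction = EdgeDirection.OUT;
        private boolean includeVertexProperties;
        private boolean includeEdgeProperties;

        Builder layerSampleSizes(List<Integer> sizes) { this.layerSampleSizes = new ArrayList<>(sizes); return this; }
        Builder samplingDepth(int depth) { this.samplingDepth = depth; return this; }
        Builder strategy(SamplingStrategy s) { this.strategy = s; return this; }
        Builder direction(EdgeDirection d) { this.direction = d; return this; }
        Builder includeVertexProperties(boolean v) { this.includeVertexProperties = v; return this; }
        Builder includeEdgeProperties(boolean e) { this.includeEdgeProperties = e; return this; }

        GNNSamplingConfig build() {
            if (layerSampleSizes.isEmpty()) {
                throw new IllegalArgumentException("at least one layer sample size is required");
            }
            for (Integer size : layerSampleSizes) {
                if (size == null || size <= 0) {
                    throw new IllegalArgumentException("layer sample sizes must be positive: " + layerSampleSizes);
                }
            }
            // Reject an explicit depth that disagrees with the number of per-layer sizes.
            if (samplingDepth != null && samplingDepth != layerSampleSizes.size()) {
                throw new IllegalArgumentException("samplingDepth " + samplingDepth
                    + " does not match " + layerSampleSizes.size() + " layer sizes");
            }
            return new GNNSamplingConfig(this);
        }
    }
}
```

Deriving the depth avoids an IndexOutOfBoundsException later, when a layered sampler indexes getLayerSampleSizes() by layer number.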
2.1.2 Sampling Strategy Enum
public enum SamplingStrategy {
RANDOM, // Random sampling
TOP_K, // Top K
WEIGHTED_RANDOM, // Weighted random sampling
RESERVOIR // Reservoir sampling
}
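The enum lists WEIGHTED_RANDOM but does not say how weights are applied. One possible realization, offered here as an assumption rather than part of the proposal, is the A-Res weighted reservoir scheme of Efraimidis and Spirakis: each item gets the key u^(1/w) with u drawn uniformly from [0, 1), and the k items with the largest keys form the sample. In GeaFlow the weight would presumably come from an edge property; the sketch below works on plain lists.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;

public class WeightedReservoirSampler {

    /** An item paired with its random key u^(1/w). */
    private static final class Keyed<T> {
        final T item;
        final double key;

        Keyed(T item, double key) {
            this.item = item;
            this.key = key;
        }
    }

    /**
     * A-Res weighted sampling without replacement: keeps the k items with the
     * largest keys u^(1/w), u ~ Uniform[0, 1). Items with weight <= 0 are never chosen.
     */
    public static <T> List<T> sample(List<T> items, List<Double> weights, int k, Random random) {
        if (items.size() != weights.size()) {
            throw new IllegalArgumentException("items and weights must have the same length");
        }
        List<T> result = new ArrayList<>();
        if (k <= 0) {
            return result;
        }
        // Min-heap on key: the head is the weakest member of the current sample.
        PriorityQueue<Keyed<T>> heap =
            new PriorityQueue<>(k, (Keyed<T> a, Keyed<T> b) -> Double.compare(a.key, b.key));
        for (int i = 0; i < items.size(); i++) {
            double w = weights.get(i);
            if (w <= 0) {
                continue;
            }
            double key = Math.pow(random.nextDouble(), 1.0 / w);
            if (heap.size() < k) {
                heap.offer(new Keyed<>(items.get(i), key));
            } else if (key > heap.peek().key) {
                heap.poll(); // evict the smallest key
                heap.offer(new Keyed<>(items.get(i), key));
            }
        }
        for (Keyed<T> entry : heap) {
            result.add(entry.item);
        }
        return result;
    }
}
```

Like plain reservoir sampling, this is a single pass holding at most k entries, so it fits the same edge-iterator model as the RESERVOIR strategy.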
2.1.3 Extend AlgorithmRuntimeContext
Add new methods to the AlgorithmRuntimeContext interface:
/**
* Load edge iterator with sampling configuration
*/
CloseableIterator<RowEdge> loadEdgesWithSampling(
EdgeDirection direction,
IGNNSamplingConfig samplingConfig
);
/**
* Load K-layer neighbor subgraph
*/
GNNSubgraph loadKLayerSubgraph(
IGNNSamplingConfig samplingConfig
);
2.2 Random Sampling Filter Implementation
2.2.1 RandomSamplingEdgeLimit
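Implement the existing IEdgeLimit interface. Note that an edge limit only caps how many edges are read per vertex; on its own it returns the first N edges in storage order, so the random selection itself has to come from a sampling filter such as RandomSamplingFilter below.
public class RandomSamplingEdgeLimit implements IEdgeLimit {
    private final long sampleSize;
    private final EdgeDirection direction;

    public RandomSamplingEdgeLimit(long sampleSize, EdgeDirection direction) {
        this.sampleSize = sampleSize;
        this.direction = direction;
    }

    // Cap incoming edges unless only outgoing edges are sampled (IN and BOTH are capped)
    @Override
    public long inEdgeLimit() {
        return direction == EdgeDirection.OUT ? Long.MAX_VALUE : sampleSize;
    }

    // Cap outgoing edges unless only incoming edges are sampled (OUT and BOTH are capped)
    @Override
    public long outEdgeLimit() {
        return direction == EdgeDirection.IN ? Long.MAX_VALUE : sampleSize;
    }

    @Override
    public LimitType limitType() {
        return LimitType.COMPOSED;
    }
}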
2.2.2 RandomSamplingFilter
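Implement a random sampling filter based on the reservoir algorithm. Because a reservoir can evict an edge it has already kept, the filter cannot accept or reject edges one at a time; instead it collects candidates while a vertex's edges are scanned and returns the final sample once the scan ends:
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class RandomSamplingFilter extends BaseGraphFilter {
    private final IGraphFilter baseFilter;
    private final int sampleSize;
    private final Random random;
    private final List<IEdge> reservoir = new ArrayList<>();
    private int edgeCount;

    public RandomSamplingFilter(IGraphFilter baseFilter, int sampleSize, Random random) {
        this.baseFilter = baseFilter;
        this.sampleSize = sampleSize;
        this.random = random;
    }

    @Override
    public boolean filterEdge(IEdge edge) {
        if (!baseFilter.filterEdge(edge)) {
            return false;
        }
        // Reservoir sampling (Algorithm R): the i-th edge (0-based) enters with probability min(1, sampleSize / (i + 1))
        if (edgeCount < sampleSize) {
            reservoir.add(edge);
        } else {
            int replaceIndex = random.nextInt(edgeCount + 1);
            if (replaceIndex < sampleSize) {
                reservoir.set(replaceIndex, edge);
            }
        }
        edgeCount++;
        // Never emit directly: a kept edge may still be replaced by a later one
        return false;
    }

    // Call after all edges of one vertex are scanned; returns the sample and resets state for the next vertex
    public List<IEdge> drainSample() {
        List<IEdge> sample = new ArrayList<>(reservoir);
        reservoir.clear();
        edgeCount = 0;
        return sample;
    }
}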
2.3 Layered Sampling Iterator
2.3.1 LayeredSamplingIterator
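import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Queue;

public class LayeredSamplingIterator<K, EV>
        implements CloseableIterator<LayeredEdges<K, EV>> {
    private final GraphState<K, ?, EV> graphState;
    private final IGNNSamplingConfig config;
    private final Queue<LayerTask<K>> taskQueue = new ArrayDeque<>();
    private final Map<K, Integer> visitedVertices = new HashMap<>(); // Track visit layer
    private int currentLayer = 0;

    public LayeredSamplingIterator(GraphState<K, ?, EV> graphState,
                                   IGNNSamplingConfig config, K rootId) {
        this.graphState = graphState;
        this.config = config;
        // Seed the first layer with the root and mark it visited so it is never re-expanded
        taskQueue.offer(new LayerTask<>(rootId, 0));
        visitedVertices.put(rootId, 0);
    }

    @Override
    public boolean hasNext() {
        return !taskQueue.isEmpty() && currentLayer < config.getSamplingDepth();
    }

    @Override
    public LayeredEdges<K, EV> next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        List<IEdge<K, EV>> edges = new ArrayList<>();
        int layerSampleSize = config.getLayerSampleSizes().get(currentLayer);
        // Configure sampling using StatePushDown (same limit for every vertex in this layer)
        StatePushDown pushDown = StatePushDown.of()
            .withEdgeLimit(new RandomSamplingEdgeLimit(layerSampleSize, config.getDirection()));
        // Expand only the vertices queued for this layer; vertices offered below belong to the next call
        int tasksInLayer = taskQueue.size();
        for (int i = 0; i < tasksInLayer; i++) {
            LayerTask<K> task = taskQueue.poll();
            CloseableIterator<IEdge<K, EV>> edgeIter =
                graphState.staticGraph().getEdgeIterator(
                    Collections.singletonList(task.vertexId), pushDown);
            try {
                // Sample edges and add newly reached target vertices to the next layer queue
                while (edgeIter.hasNext()) {
                    IEdge<K, EV> edge = edgeIter.next();
                    edges.add(edge);
                    K targetId = edge.getTargetId();
                    if (!visitedVertices.containsKey(targetId)) {
                        taskQueue.offer(new LayerTask<>(targetId, currentLayer + 1));
                        visitedVertices.put(targetId, currentLayer + 1);
                    }
                }
            } finally {
                edgeIter.close();
            }
        }
        currentLayer++;
        return new LayeredEdges<>(currentLayer - 1, edges);
    }

    @Override
    public void close() {
        taskQueue.clear();
    }
}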
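To make the intended layer-by-layer behavior concrete without GeaFlow's GraphState, the following sketch runs the same frontier expansion over an in-memory adjacency map, which is a hypothetical stand-in for the graph state. Each hop expands only the vertices first reached in the previous hop, samples up to fanouts[hop] neighbors uniformly without replacement (a partial Fisher-Yates shuffle), and records edges to already-visited vertices without expanding them again, mirroring the visitedVertices check above.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

public class LayeredNeighborSampler {

    /** Samples up to k elements uniformly without replacement (partial Fisher-Yates on a copy). */
    static List<Long> sampleUniform(List<Long> candidates, int k, Random random) {
        List<Long> copy = new ArrayList<>(candidates);
        int n = Math.min(k, copy.size());
        for (int i = 0; i < n; i++) {
            int j = i + random.nextInt(copy.size() - i);
            Collections.swap(copy, i, j);
        }
        return new ArrayList<>(copy.subList(0, n));
    }

    /**
     * Returns the sampled edges per hop: result.get(h) holds {src, dst} pairs whose
     * source was first reached at hop h. Stops early if a frontier becomes empty.
     */
    public static List<List<long[]>> sample(Map<Long, List<Long>> adjacency, long root,
                                            int[] fanouts, Random random) {
        List<List<long[]>> layers = new ArrayList<>();
        Set<Long> visited = new HashSet<>();
        visited.add(root);
        List<Long> frontier = Collections.singletonList(root);
        for (int hop = 0; hop < fanouts.length && !frontier.isEmpty(); hop++) {
            List<long[]> edges = new ArrayList<>();
            List<Long> nextFrontier = new ArrayList<>();
            for (long src : frontier) {
                List<Long> neighbors = adjacency.getOrDefault(src, Collections.emptyList());
                for (long dst : sampleUniform(neighbors, fanouts[hop], random)) {
                    edges.add(new long[] {src, dst});
                    if (visited.add(dst)) {
                        nextFrontier.add(dst); // expand each vertex at most once
                    }
                }
            }
            layers.add(edges);
            frontier = nextFrontier;
        }
        return layers;
    }
}
```

With fanouts {10, 10}, the root contributes at most 10 first-hop edges and each newly reached vertex at most 10 second-hop edges, so one sample holds at most 10 + 100 edges regardless of vertex degrees.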
2.4 GNN Subgraph Data Structure
2.4.1 GNNSubgraph
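public class GNNSubgraph<K, VV, EV> implements Serializable {
    private final K rootVertexId;
    private final Map<Integer, Set<K>> layeredVertices; // Layer -> vertex set
    private final Map<Integer, List<IEdge<K, EV>>> layeredEdges; // Layer -> edge list
    private final Map<K, IVertex<K, VV>> vertexMap; // Vertex properties

    public GNNSubgraph(K rootVertexId,
                       Map<Integer, Set<K>> layeredVertices,
                       Map<Integer, List<IEdge<K, EV>>> layeredEdges,
                       Map<K, IVertex<K, VV>> vertexMap) {
        this.rootVertexId = rootVertexId;
        this.layeredVertices = layeredVertices;
        this.layeredEdges = layeredEdges;
        this.vertexMap = vertexMap;
    }

    public K getRootVertexId() {
        return rootVertexId;
    }

    // Number of vertex layers stored in the subgraph
    public int getDepth() {
        return layeredVertices.size();
    }

    public Set<K> getVerticesAtLayer(int layer) {
        return layeredVertices.getOrDefault(layer, Collections.emptySet());
    }

    public List<IEdge<K, EV>> getEdgesAtLayer(int layer) {
        return layeredEdges.getOrDefault(layer, Collections.emptyList());
    }

    public IVertex<K, VV> getVertex(K vertexId) {
        return vertexMap.get(vertexId);
    }
}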
III. Usage Examples
3.1 Basic GNN Sampling UDF
@Description(name = "gnn_sampling", description = "GNN subgraph sampling")
public class GNNSamplingAlgorithm implements AlgorithmUserFunction<Object, GNNMessage> {
private AlgorithmRuntimeContext<Object, GNNMessage> context;
private IGNNSamplingConfig samplingConfig;
@Override
public void init(AlgorithmRuntimeContext<Object, GNNMessage> context, Object[] params) {
this.context = context;
// Configuration: sample 10 neighbors per layer, depth of 2
this.samplingConfig = GNNSamplingConfig.builder()
.layerSampleSizes(Arrays.asList(10, 10))
.samplingDepth(2)
.strategy(SamplingStrategy.RANDOM)
.direction(EdgeDirection.OUT)
.includeVertexProperties(true)
.includeEdgeProperties(true)
.build();
}
@Override
public void process(RowVertex vertex, Optional<Row> updatedValues,
Iterator<GNNMessage> messages) {
if (context.getCurrentIterationId() == 1L) {
// First iteration: load K-layer subgraph
GNNSubgraph subgraph = context.loadKLayerSubgraph(samplingConfig);
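// Emit the vertex ID and the serialized subgraph for the GNN training component;
// serialize() is assumed to return a String matching the StringType column (it is not yet defined on GNNSubgraph)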
context.take(ObjectRow.create(
vertex.getId(),
subgraph.serialize()
));
}
}
@Override
public void finish(RowVertex vertex, Optional<Row> newValue) {
// Cleanup resources
}
@Override
public StructType getOutputType(GraphSchema graphSchema) {
return new StructType(
new TableField("vertex_id", graphSchema.getIdType(), false),
new TableField("subgraph", StringType.INSTANCE, false)
);
}
}
3.2 Progressive Sampling Using Iterator
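@Override
public void process(RowVertex vertex, Optional<Row> updatedValues,
                    Iterator<GNNMessage> messages) {
    // Load layer by layer using an iterator to bound memory usage.
    // As written, each pass reads the current vertex's own edges; vertices further out are reached
    // through loadKLayerSubgraph or by messaging sampled neighbors in later iterations.
    for (int layer = 0; layer < samplingConfig.getSamplingDepth(); layer++) {
        // Current layer sample size
        int sampleSize = samplingConfig.getLayerSampleSizes().get(layer);
        // withLayerSampleSize(...) is an assumed helper returning a copy of the config for this layer;
        // it is not declared on IGNNSamplingConfig
        CloseableIterator<RowEdge> edgeIter = context.loadEdgesWithSampling(
            EdgeDirection.OUT,
            samplingConfig.withLayerSampleSize(sampleSize));
        List<RowEdge> sampledEdges = new ArrayList<>();
        try {
            // Control load count in UDF: stop reading once this layer's quota is reached
            int count = 0;
            while (edgeIter.hasNext() && count < sampleSize) {
                sampledEdges.add(edgeIter.next());
                count++;
            }
        } finally {
            edgeIter.close();
        }
        // Process current layer sampling results (processLayerEdges is a user-defined helper)
        processLayerEdges(layer, sampledEdges);
    }
}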
@kaori-seasons My suggestion is that we implement this as a general sampling UDF, not only for GNN.
On the other hand, GeaFlow already supports the edge iterator, and with it users can easily implement their own sampling UDFs. PR
@qingwen220 Regarding the UDF you mentioned, I have written a demo use case; you can take a look when you have time.
@Description(name = "neighbor_sampling", description = "K-hop neighbor sampling with configurable strategies")
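public class NeighborSamplingUDF implements AlgorithmUserFunction<Object, SamplingMessage> {
    private AlgorithmRuntimeContext<Object, SamplingMessage> context;
    private int[] layerSizes; // e.g., [10, 5] for 2-hop sampling with 10 and 5 samples
    private int depth;
    private EdgeDirection direction;

    @Override
    public void init(AlgorithmRuntimeContext<Object, SamplingMessage> context, Object[] params) {
        this.context = context;
        // Assumes the array literal arrives as int[]; adapt if the DSL passes another array type
        this.layerSizes = (int[]) params[0];
        this.depth = layerSizes.length;
        // The query passes the direction as a string such as 'OUT', so parse it instead of casting
        this.direction = EdgeDirection.valueOf(String.valueOf(params[1]).toUpperCase());
    }

    @Override
    public void process(RowVertex vertex, Optional<Row> updatedValues,
                        Iterator<SamplingMessage> messages) {
        // Stream the current vertex's edges through the existing edge iterator
        CloseableIterator<RowEdge> edgeIter = context.loadEdgesIterator(direction);
        List<RowEdge> sampledEdges;
        try {
            // First-hop sample; deeper hops would use layerSizes[1..depth-1]
            sampledEdges = reservoirSample(edgeIter, layerSizes[0]);
        } finally {
            edgeIter.close();
        }
        // Output the sample or send it to the next layer (createOutput is a user-defined helper)
        context.take(createOutput(vertex, sampledEdges));
    }

    @Override
    public void finish(RowVertex vertex, Optional<Row> newValue) {
        // No per-vertex cleanup needed
    }

    @Override
    public StructType getOutputType(GraphSchema graphSchema) {
        // Columns match the YIELD clause below; neighbors is assumed to be serialized as a string
        return new StructType(
            new TableField("vertex_id", graphSchema.getIdType(), false),
            new TableField("neighbors", StringType.INSTANCE, false)
        );
    }
}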
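The demo calls reservoirSample(edgeIter, layerSizes[0]) without defining it. A possible implementation of that helper is standard reservoir sampling (Algorithm R), the same technique as RandomSamplingFilter in 2.2.2: it reads the iterator once and keeps at most k edges in memory, which suits edge streams whose length is unknown up front. The extra Random parameter and the ReservoirSampling class name are assumptions added for illustration and reproducible tests; the two-argument call in the UDF could delegate to it with a per-task Random instance.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

public final class ReservoirSampling {

    private ReservoirSampling() {
    }

    /**
     * Algorithm R: returns a uniform random sample of up to k items from an
     * iterator of unknown length, holding at most k items in memory.
     */
    public static <T> List<T> reservoirSample(Iterator<T> iter, int k, Random random) {
        List<T> reservoir = new ArrayList<>(Math.max(k, 0));
        if (k <= 0) {
            return reservoir;
        }
        int seen = 0;
        while (iter.hasNext()) {
            T item = iter.next();
            if (seen < k) {
                // Fill the reservoir with the first k items
                reservoir.add(item);
            } else {
                // Keep the (seen + 1)-th item with probability k / (seen + 1)
                int j = random.nextInt(seen + 1); // uniform in [0, seen]
                if (j < k) {
                    reservoir.set(j, item);
                }
            }
            seen++;
        }
        return reservoir;
    }
}
```

Inside process() this would be called before the iterator is closed. The int counter limits a single vertex to roughly 2^31 scanned edges, which should be ample per vertex.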
Usage
-- Create the sampling function
CREATE FUNCTION neighbor_sample AS 'com.example.NeighborSamplingUDF';
-- Use in graph query
MATCH (v:Person)
CALL neighbor_sample(v, [10, 5], 'OUT')
YIELD vertex_id, neighbors
RETURN vertex_id, neighbors;