Refactor Node to an enum-based design for a type-safe IR
Motivation
The ONNX-IR layer serves as a clean intermediate representation consumed by burn-import and other downstream tools. The IR should be:
- Type-safe: No runtime downcasting or type erasure
- Self-contained: Fully processed and ready for consumption without re-interpretation
- Explicit: Node types and their configurations are clear at compile-time
- Efficient: Zero-cost abstractions with optimal memory layout
Currently, the Node struct stores configuration via trait objects (Box<dyn NodeConfig>), which requires runtime downcasting and introduces unnecessary indirection. While we're removing the stored config as a short-term fix (see #XXXX), an enum-based design would provide a more robust long-term solution.
Current Design
struct Node {
    pub name: String,
    pub node_type: NodeType, // Redundant with config type
    pub inputs: Vec<TensorOrScalar>,
    pub outputs: Vec<Argument>,
    // (config field being removed in current PR)
}
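For context, here is a hedged sketch of the access pattern the trait-object config forces on consumers; the exact NodeConfig trait shape and the read_axis helper are illustrative assumptions, not the actual burn code:

use std::any::Any;

// Sketch of the current pattern: the config is stored as Box<dyn NodeConfig>
// next to node_type, and consumers must downcast it at runtime.
trait NodeConfig: Any {
    fn as_any(&self) -> &dyn Any;
}

struct ArgMaxConfig {
    axis: i64,
}

impl NodeConfig for ArgMaxConfig {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

// Nothing ties this downcast to node_type; a mismatch only shows up at runtime.
fn read_axis(config: &dyn NodeConfig) -> Option<i64> {
    config.as_any().downcast_ref::<ArgMaxConfig>().map(|c| c.axis)
}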
Issues:
- The node_type field is redundant: it stores the operation's type separately from its config
- Generic operations require matching on the node_type discriminant
- There is no compile-time guarantee that the config corresponds to the node type
Proposed Design
enum Node {
    ArgMax {
        name: String,
        inputs: Vec<TensorOrScalar>,
        outputs: Vec<Argument>,
        config: ArgMaxConfig,
    },
    Conv {
        name: String,
        inputs: Vec<TensorOrScalar>,
        outputs: Vec<Argument>,
        config: ConvConfig,
    },
    Add {
        name: String,
        inputs: Vec<TensorOrScalar>,
        outputs: Vec<Argument>,
        // No config needed
    },
    // ... one variant per supported op
}
impl Node {
    // Helper methods for common fields
    pub fn name(&self) -> &str {
        match self {
            Node::ArgMax { name, .. }
            | Node::Conv { name, .. }
            | Node::Add { name, .. } => name,
        }
    }

    pub fn inputs(&self) -> &[TensorOrScalar] {
        match self {
            Node::ArgMax { inputs, .. }
            | Node::Conv { inputs, .. }
            | Node::Add { inputs, .. } => inputs,
        }
    }

    // Similar for outputs
}
Benefits
- Eliminates Type Redundancy
  - The enum variant discriminant IS the node type
  - No separate node_type field needed
  - Single source of truth for what operation this is
- Compile-Time Type Safety
  - No trait objects (Box<dyn NodeConfig>)
  - No runtime downcasting with as_any() and downcast_ref()
  - Config type is guaranteed to match operation type
  - Impossible to have mismatched config/operation pairs
- Zero Runtime Overhead
  - No vtable indirection
  - No heap allocation for configs
  - Better cache locality (all data inline)
  - Optimal memory layout
- Exhaustive Matching
  - Compiler forces handling all operation types
  - Can't accidentally forget to implement a case
  - Refactoring is safer (adding/removing ops triggers compiler errors)
- Cleaner IR Semantics
  - Downstream consumers (burn-import) get fully typed, processed nodes (see the sketch after this list)
  - No need to know about processors or extraction logic
  - IR is truly "intermediate" - fully resolved and ready to consume
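As an illustration of the last two points, here is a rough sketch of how a consumer such as burn-import could dispatch over the enum; the codegen function and the config fields used here (kernel_size, axis) are assumptions for the example, not actual burn-import code:

// Assumes the Node enum from "Proposed Design" above, with config structs like:
//   struct ConvConfig { kernel_size: [usize; 2], /* ... */ }
//   struct ArgMaxConfig { axis: i64 }
fn codegen(node: &Node) -> String {
    match node {
        // `config` is statically typed as ConvConfig; no downcast needed.
        Node::Conv { name, config, .. } => {
            format!("let {name} = conv(kernel = {:?});", config.kernel_size)
        }
        Node::ArgMax { name, config, .. } => {
            format!("let {name} = argmax(axis = {});", config.axis)
        }
        Node::Add { name, .. } => format!("let {name} = add();"),
        // No `_` arm: adding a new op variant is a compile error until
        // every consumer handles it.
    }
}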
There is also a drawback when all enum variants carry the same fields: the enum duplicates name, inputs, and outputs in every variant, whereas a struct would let those common fields be defined once, globally.
I minimized the duplicated definitions with a single macro, and it worked out to be clean in the end. In burn-import, I don't even need to use the accessors.
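The actual macro isn't shown here, but as a rough sketch of the idea (the macro name and shape below are illustrative assumptions), a single declarative macro can generate the shared accessors from the list of variants:

// Illustrative only: generates the common accessors for every variant in one place.
// Assumes the Node enum, TensorOrScalar, and Argument types from above.
macro_rules! impl_common_accessors {
    ($($variant:ident),* $(,)?) => {
        impl Node {
            pub fn name(&self) -> &str {
                match self {
                    $(Node::$variant { name, .. } => name,)*
                }
            }

            pub fn inputs(&self) -> &[TensorOrScalar] {
                match self {
                    $(Node::$variant { inputs, .. } => inputs,)*
                }
            }

            pub fn outputs(&self) -> &[Argument] {
                match self {
                    $(Node::$variant { outputs, .. } => outputs,)*
                }
            }
        }
    };
}

// One line per supported op; adding a variant here is the only duplication left.
impl_common_accessors!(ArgMax, Conv, Add);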
OK. Refactoring is done and all tests are passing.
Here is an example of ONNX-IR with typed config:
OnnxGraph {
    nodes: [
        Conv2d {
            name: "conv2d1",
            inputs: [
                Argument {
                    name: "input",
                    ty: Tensor(TensorType { dtype: F32, rank: 4, static_shape: Some([2, 4, 10, 15]) }),
                    value_source: Dynamic,
                },
                Argument {
                    name: "",
                    ty: Tensor(TensorType { dtype: F32, rank: 4, static_shape: Some([6, 2, 3, 5]) }),
                    value_source: Static(0),
                },
                Argument {
                    name: "",
                    ty: Tensor(TensorType { dtype: F32, rank: 1, static_shape: Some([6]) }),
                    value_source: Static(1),
                },
            ],
            outputs: [
                Argument {
                    name: "conv2d1_out1",
                    ty: Tensor(TensorType { dtype: F32, rank: 4, static_shape: None }),
                    value_source: Dynamic,
                },
            ],
            config: Conv2dConfig {
                channels: [4, 6],
                kernel_size: [3, 5],
                stride: [2, 1],
                padding: Explicit(4, 2),
                dilation: [3, 1],
                groups: 2,
                bias: true,
            },
        },
    ],
    inputs: [
        Argument {
            name: "input",
            ty: Tensor(TensorType { dtype: F32, rank: 4, static_shape: Some([2, 4, 10, 15]) }),
            value_source: Dynamic,
        },
    ],
    outputs: [
        Argument {
            name: "conv2d1_out1",
            ty: Tensor(TensorType { dtype: F32, rank: 4, static_shape: None }),
            value_source: Dynamic,
        },
    ],
}
Code generated:
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv2d1: Conv2d<B>,
    phantom: core::marker::PhantomData<B>,
    device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Default for Model<B> {
    fn default() -> Self {
        Self::from_file("./out/conv2d", &Default::default())
    }
}

impl<B: Backend> Model<B> {
    pub fn from_file(file: &str, device: &B::Device) -> Self {
        let record = burn::record::PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
            .load(file.into(), device)
            .expect("Record file to exist.");
        Self::new(device).load_record(record)
    }
}

impl<B: Backend> Model<B> {
    #[allow(unused_variables)]
    pub fn new(device: &B::Device) -> Self {
        let conv2d1 = Conv2dConfig::new([4, 6], [3, 5])
            .with_stride([2, 1])
            .with_padding(PaddingConfig2d::Explicit(4, 2))
            .with_dilation([3, 1])
            .with_groups(2)
            .with_bias(true)
            .init(device);
        Self {
            conv2d1,
            phantom: core::marker::PhantomData,
            device: burn::module::Ignored(device.clone()),
        }
    }

    #[allow(clippy::let_and_return, clippy::approx_constant)]
    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let conv2d1_out1 = self.conv2d1.forward(input);
        conv2d1_out1
    }
}