burn
Node to Enum-based design for type-safe IR
This PR refactors the ONNX-IR Node from a struct-based design to an enum-based design, providing compile-time type safety and eliminating runtime overhead. This addresses the architectural issues described in #3988.
Checklist
- [x] Confirmed that the `cargo run-checks` command has been executed.
- [x] Made sure the book is up to date with changes in this PR.
Related Issues/PRs
Fixes #3988
Changes
What Changed
Before: `Node` was a struct with a `node_type` discriminant field and trait object-based configuration (`Box<dyn NodeConfig>`), requiring runtime downcasting and creating type redundancy.
After: `Node` is now an enum where each variant represents a specific operation type with its associated configuration inline.
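To make the contrast concrete, here is a minimal, hypothetical sketch of the two designs. The type and field names below are illustrative only, not the actual onnx-ir definitions:

```rust
// Hypothetical, simplified illustration of the refactor; the real
// onnx-ir types carry more fields (inputs, outputs, etc.).

// Before (sketch): a single struct with a runtime discriminant, where
// reading the config required downcasting a boxed trait object:
//   struct Node { node_type: NodeType, config: Box<dyn NodeConfig>, .. }

// After: an enum whose variants carry their configs inline.
#[derive(Debug, PartialEq)]
pub struct Conv2dConfig {
    pub kernel_size: [usize; 2],
}

#[derive(Debug)]
pub enum Node {
    Conv2d { name: String, config: Conv2dConfig },
    Relu { name: String },
}

// Accessing the config is now a total, statically checked match,
// with no downcast and no trait-object indirection.
pub fn conv2d_kernel(node: &Node) -> Option<[usize; 2]> {
    match node {
        Node::Conv2d { config, .. } => Some(config.kernel_size),
        Node::Relu { .. } => None,
    }
}

fn main() {
    let node = Node::Conv2d {
        name: "conv2d1".into(),
        config: Conv2dConfig { kernel_size: [3, 5] },
    };
    assert_eq!(conv2d_kernel(&node), Some([3, 5]));
}
```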
Key Improvements
- Type Safety: Eliminates runtime downcasting by making node types explicit at compile-time
- Zero Runtime Overhead: Removes vtable indirection and heap allocations for configurations
- Single Source of Truth: The enum variant discriminant serves as the node type; no redundant `node_type` field
- Exhaustive Matching: Compiler enforces handling all operation types, making refactoring safer
- Cleaner IR Semantics: Downstream consumers (burn-import) receive fully typed, processed nodes ready for consumption
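A hedged sketch of the "single source of truth" and "exhaustive matching" points: the variant itself identifies the operation, so a consumer can derive the type name from a match that the compiler keeps exhaustive (the tiny enum below is illustrative, not the real `Node`):

```rust
// Illustrative only: a tiny enum standing in for the real Node enum.
#[derive(Debug)]
pub enum Node {
    Conv2d { name: String },
    Relu { name: String },
    Cast { name: String },
}

// The variant discriminant *is* the node type; no separate
// `node_type` field can drift out of sync. If a new variant is
// added, this match stops compiling until the variant is handled.
pub fn node_type(node: &Node) -> &'static str {
    match node {
        Node::Conv2d { .. } => "Conv2d",
        Node::Relu { .. } => "Relu",
        Node::Cast { .. } => "Cast",
    }
}

fn main() {
    let n = Node::Relu { name: "relu1".into() };
    assert_eq!(node_type(&n), "Relu");
}
```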
Architecture
Each operation now has its own enum variant containing:
- Common fields: `name`, `inputs`, `outputs`
- Operation-specific config (where needed)
Helper methods provide ergonomic access to common fields across all variants while maintaining full type safety.
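Such helper accessors can be sketched with or-patterns, which bind a shared field once across variants (again a simplified, hypothetical model, not the actual implementation):

```rust
// Illustrative sketch of helper accessors over enum variants with
// common fields; real onnx-ir nodes have more variants and fields.
#[derive(Debug)]
pub struct Argument {
    pub name: String,
}

#[derive(Debug)]
pub enum Node {
    Conv2d { name: String, inputs: Vec<Argument>, outputs: Vec<Argument> },
    Relu { name: String, inputs: Vec<Argument>, outputs: Vec<Argument> },
}

impl Node {
    // Or-patterns bind the shared field once, so common accessors
    // stay cheap and fully typed across all variants.
    pub fn name(&self) -> &str {
        match self {
            Node::Conv2d { name, .. } | Node::Relu { name, .. } => name,
        }
    }

    pub fn inputs(&self) -> &[Argument] {
        match self {
            Node::Conv2d { inputs, .. } | Node::Relu { inputs, .. } => inputs,
        }
    }
}

fn main() {
    let node = Node::Relu {
        name: "relu1".into(),
        inputs: vec![Argument { name: "x".into() }],
        outputs: vec![],
    };
    assert_eq!(node.name(), "relu1");
    assert_eq!(node.inputs().len(), 1);
}
```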
Testing
All existing tests are passing
OK. Refactoring is done and all tests are passing.
Here is an example of ONNX-IR with typed config:
OnnxGraph {
nodes: [
Conv2d {
name: "conv2d1",
inputs: [
Argument {
name: "input",
ty: Tensor(
TensorType {
dtype: F32,
rank: 4,
static_shape: Some(
[
2,
4,
10,
15,
],
),
},
),
value_source: Dynamic,
},
Argument {
name: "",
ty: Tensor(
TensorType {
dtype: F32,
rank: 4,
static_shape: Some(
[
6,
2,
3,
5,
],
),
},
),
value_source: Static(
0,
),
},
Argument {
name: "",
ty: Tensor(
TensorType {
dtype: F32,
rank: 1,
static_shape: Some(
[
6,
],
),
},
),
value_source: Static(
1,
),
},
],
outputs: [
Argument {
name: "conv2d1_out1",
ty: Tensor(
TensorType {
dtype: F32,
rank: 4,
static_shape: None,
},
),
value_source: Dynamic,
},
],
config: Conv2dConfig {
channels: [
4,
6,
],
kernel_size: [
3,
5,
],
stride: [
2,
1,
],
padding: Explicit(
4,
2,
),
dilation: [
3,
1,
],
groups: 2,
bias: true,
},
},
],
inputs: [
Argument {
name: "input",
ty: Tensor(
TensorType {
dtype: F32,
rank: 4,
static_shape: Some(
[
2,
4,
10,
15,
],
),
},
),
value_source: Dynamic,
},
],
outputs: [
Argument {
name: "conv2d1_out1",
ty: Tensor(
TensorType {
dtype: F32,
rank: 4,
static_shape: None,
},
),
value_source: Dynamic,
},
],
}
Code generated:
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
conv2d1: Conv2d<B>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Default for Model<B> {
fn default() -> Self {
Self::from_file("./out/conv2d", &Default::default())
}
}
impl<B: Backend> Model<B> {
pub fn from_file(file: &str, device: &B::Device) -> Self {
let record = burn::record::PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
.load(file.into(), device)
.expect("Record file to exist.");
Self::new(device).load_record(record)
}
}
impl<B: Backend> Model<B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let conv2d1 = Conv2dConfig::new([4, 6], [3, 5])
.with_stride([2, 1])
.with_padding(PaddingConfig2d::Explicit(4, 2))
.with_dilation([3, 1])
.with_groups(2)
.with_bias(true)
.init(device);
Self {
conv2d1,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let conv2d1_out1 = self.conv2d1.forward(input);
conv2d1_out1
}
}
Codecov Report
:x: Patch coverage is 88.59987% with 355 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 66.99%. Comparing base (cc0f22a) to head (cbec4c6).
:warning: Report is 1 commit behind head on main.
Additional details and impacted files
@@ Coverage Diff @@
## main #4019 +/- ##
==========================================
- Coverage 67.03% 66.99% -0.04%
==========================================
Files 1216 1221 +5
Lines 149805 150429 +624
==========================================
+ Hits 100427 100786 +359
- Misses 49378 49643 +265
:umbrella: View full report in Codecov by Sentry.
Slight change of plans: we will have dedicated structs inside the Node enum variants.
Current Structure
// In src/ir/node.rs - Node enum with struct variants
pub enum Node {
ArgMax {
name: String,
inputs: Vec<Argument>,
outputs: Vec<Argument>,
config: ArgMaxConfig,
},
Cast {
name: String,
inputs: Vec<Argument>,
outputs: Vec<Argument>,
config: CastConfig,
},
// ... etc
}
Desired Structure
// In src/node/argmax.rs - NEW public struct
pub struct ArgMaxNode {
pub name: String,
pub inputs: Vec<Argument>,
pub outputs: Vec<Argument>,
pub config: ArgMaxConfig,
}
// In src/node/cast.rs - NEW public struct
pub struct CastNode {
pub name: String,
pub inputs: Vec<Argument>,
pub outputs: Vec<Argument>,
pub config: CastConfig,
}
// In src/ir/node.rs - Node enum with tuple variants wrapping the structs
pub enum Node {
ArgMax(ArgMaxNode),
Cast(CastNode),
// ... etc
}
Key Changes
- Create dedicated structs (`ArgMaxNode`, `CastNode`, etc.) for each operation
- Define these structs in their operation files (`src/node/argmax.rs`, `src/node/cast.rs`)
- Make these structs public so they can be exported via `pub use node::*;`
- Change the `Node` enum from struct variants to tuple variants that wrap these structs
- Enable trait implementations on the specific node types (e.g., `impl Display for ArgMaxNode`)
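The trait-implementation point can be sketched as follows. This is a hypothetical, minimal model: `ArgMaxConfig` here has only an `axis` field, and the `Display` format string is invented for illustration:

```rust
use std::fmt;

// Hypothetical config; the real ArgMaxConfig differs.
#[derive(Debug)]
pub struct ArgMaxConfig {
    pub axis: i64,
}

// A dedicated struct means traits can be implemented on it directly,
// which struct variants of an enum do not allow.
#[derive(Debug)]
pub struct ArgMaxNode {
    pub name: String,
    pub config: ArgMaxConfig,
}

pub enum Node {
    ArgMax(ArgMaxNode),
}

impl fmt::Display for ArgMaxNode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "ArgMax(name={}, axis={})", self.name, self.config.axis)
    }
}

fn main() {
    let node = Node::ArgMax(ArgMaxNode {
        name: "argmax1".into(),
        config: ArgMaxConfig { axis: 1 },
    });
    // The tuple variant hands back the fully typed struct directly.
    match &node {
        Node::ArgMax(inner) => {
            assert_eq!(inner.to_string(), "ArgMax(name=argmax1, axis=1)");
        }
    }
}
```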
Benefits
- Can implement operation-specific traits on node types
- Better code organization (each node type lives with its processor/config)
- More flexible API (users can work with ArgMaxNode directly if needed)