
Refactor Node to an enum-based design for type-safe IR

antimora opened this pull request 3 weeks ago · 2 comments

This PR refactors the ONNX-IR Node from a struct-based design to an enum-based design, providing compile-time type safety and eliminating runtime overhead. This addresses the architectural issues described in #3988.

Checklist

  • [x] Confirmed that cargo run-checks command has been executed.
  • [x] Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Fixes #3988

Changes

What Changed

Before: Node was a struct with a node_type discriminant field and a trait-object-based configuration (Box<dyn NodeConfig>), requiring runtime downcasting and creating type redundancy.

After: Node is now an enum where each variant represents a specific operation type with its associated configuration inline.
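
For reference, here is a minimal sketch of the old struct-based shape; the trait and field names are reconstructed from the description above, not the exact previous API:

use std::any::Any;

// Sketch of the previous design: a single Node struct for every
// operation, identified by a runtime discriminant.
pub trait NodeConfig: Any {
    fn as_any(&self) -> &dyn Any;
}

pub struct Conv2dConfig { /* kernel_size, stride, ... */ }

impl NodeConfig for Conv2dConfig {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

pub struct Node {
    pub node_type: String,           // redundant runtime discriminant
    pub config: Box<dyn NodeConfig>, // heap allocation + vtable indirection
}

// Consumers had to downcast at runtime, which can fail silently:
fn conv2d_config(node: &Node) -> Option<&Conv2dConfig> {
    node.config.as_any().downcast_ref::<Conv2dConfig>()
}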

Key Improvements

  1. Type Safety: Eliminates runtime downcasting by making node types explicit at compile-time
  2. Zero Runtime Overhead: Removes vtable indirection and heap allocations for configurations
  3. Single Source of Truth: The enum variant discriminant serves as the node type—no redundant node_type field
  4. Exhaustive Matching: The compiler enforces handling of all operation types, making refactoring safer (see the sketch after this list)
  5. Cleaner IR Semantics: Downstream consumers (burn-import) receive fully typed, processed nodes ready for consumption
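
A minimal two-variant sketch of points 3 and 4 (the real enum has one variant per ONNX operation, and the config types here are placeholders):

// Two-variant sketch of the enum-based design.
pub struct CastConfig { /* target dtype, ... */ }
pub struct Conv2dConfig { /* channels, kernel_size, ... */ }

pub enum Node {
    Cast { name: String, config: CastConfig },
    Conv2d { name: String, config: Conv2dConfig },
}

fn describe(node: &Node) -> &'static str {
    // No wildcard arm: adding a variant to Node is a compile error
    // here until the new operation is handled explicitly.
    match node {
        Node::Cast { .. } => "type cast",
        Node::Conv2d { .. } => "2D convolution",
    }
}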

Architecture

Each operation now has its own enum variant containing:

  • Common fields: name, inputs, outputs
  • Operation-specific config (where needed)

Helper methods provide ergonomic access to common fields across all variants while maintaining full type safety.
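
Continuing the two-variant sketch above, the helper-method pattern looks roughly like this (the method name is illustrative, not the exact crate API):

impl Node {
    /// Access a common field without exposing the variant.
    pub fn name(&self) -> &str {
        match self {
            Node::Cast { name, .. } => name,
            Node::Conv2d { name, .. } => name,
        }
    }
}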

Testing

All existing tests are passing

antimora · Nov 13 '25 21:11

OK. Refactoring is done and all tests are passing.

Here is an example of the ONNX-IR with a typed config:

OnnxGraph {
    nodes: [
        Conv2d {
            name: "conv2d1",
            inputs: [
                Argument {
                    name: "input",
                    ty: Tensor(
                        TensorType {
                            dtype: F32,
                            rank: 4,
                            static_shape: Some(
                                [
                                    2,
                                    4,
                                    10,
                                    15,
                                ],
                            ),
                        },
                    ),
                    value_source: Dynamic,
                },
                Argument {
                    name: "",
                    ty: Tensor(
                        TensorType {
                            dtype: F32,
                            rank: 4,
                            static_shape: Some(
                                [
                                    6,
                                    2,
                                    3,
                                    5,
                                ],
                            ),
                        },
                    ),
                    value_source: Static(
                        0,
                    ),
                },
                Argument {
                    name: "",
                    ty: Tensor(
                        TensorType {
                            dtype: F32,
                            rank: 1,
                            static_shape: Some(
                                [
                                    6,
                                ],
                            ),
                        },
                    ),
                    value_source: Static(
                        1,
                    ),
                },
            ],
            outputs: [
                Argument {
                    name: "conv2d1_out1",
                    ty: Tensor(
                        TensorType {
                            dtype: F32,
                            rank: 4,
                            static_shape: None,
                        },
                    ),
                    value_source: Dynamic,
                },
            ],
            config: Conv2dConfig {
                channels: [
                    4,
                    6,
                ],
                kernel_size: [
                    3,
                    5,
                ],
                stride: [
                    2,
                    1,
                ],
                padding: Explicit(
                    4,
                    2,
                ),
                dilation: [
                    3,
                    1,
                ],
                groups: 2,
                bias: true,
            },
        },
    ],
    inputs: [
        Argument {
            name: "input",
            ty: Tensor(
                TensorType {
                    dtype: F32,
                    rank: 4,
                    static_shape: Some(
                        [
                            2,
                            4,
                            10,
                            15,
                        ],
                    ),
                },
            ),
            value_source: Dynamic,
        },
    ],
    outputs: [
        Argument {
            name: "conv2d1_out1",
            ty: Tensor(
                TensorType {
                    dtype: F32,
                    rank: 4,
                    static_shape: None,
                },
            ),
            value_source: Dynamic,
        },
    ],
}

Generated code:

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv2d1: Conv2d<B>,
    phantom: core::marker::PhantomData<B>,
    device: burn::module::Ignored<B::Device>,
}


impl<B: Backend> Default for Model<B> {
    fn default() -> Self {
        Self::from_file("./out/conv2d", &Default::default())
    }
}

impl<B: Backend> Model<B> {
    pub fn from_file(file: &str, device: &B::Device) -> Self {
        let record = burn::record::PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
            .load(file.into(), device)
            .expect("Record file to exist.");
        Self::new(device).load_record(record)
    }
}

impl<B: Backend> Model<B> {
    #[allow(unused_variables)]
    pub fn new(device: &B::Device) -> Self {
        let conv2d1 = Conv2dConfig::new([4, 6], [3, 5])
            .with_stride([2, 1])
            .with_padding(PaddingConfig2d::Explicit(4, 2))
            .with_dilation([3, 1])
            .with_groups(2)
            .with_bias(true)
            .init(device);
        Self {
            conv2d1,
            phantom: core::marker::PhantomData,
            device: burn::module::Ignored(device.clone()),
        }
    }

    #[allow(clippy::let_and_return, clippy::approx_constant)]
    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let conv2d1_out1 = self.conv2d1.forward(input);
        conv2d1_out1
    }
}
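
For context, the generated model would be used roughly like this; the backend choice and input shape are illustrative assumptions, not part of the PR:

use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    // Initializes fresh weights rather than loading the record file.
    let model = Model::<NdArray>::new(&device);
    let input = Tensor::zeros([2, 4, 10, 15], &device);
    let output: Tensor<NdArray, 4> = model.forward(input);
    println!("{:?}", output.dims());
}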

antimora · Nov 14 '25 01:11

Codecov Report

Patch coverage is 88.59987% with 355 lines in your changes missing coverage. Please review.
Project coverage is 66.99%. Comparing base (cc0f22a) to head (cbec4c6).
Report is 1 commit behind head on main.

Files with missing lines                     Patch %   Lines
crates/onnx-ir/src/node/unsupported.rs         0.00%   63 Missing
crates/onnx-ir/src/ir/graph.rs                72.88%   16 Missing
crates/onnx-ir/src/node/identity.rs           72.72%   15 Missing
crates/onnx-ir/src/processor.rs               62.50%   15 Missing
crates/onnx-ir/src/node/global_avg_pool.rs    76.00%   12 Missing
crates/onnx-ir/src/node/attention.rs          69.44%   11 Missing
crates/onnx-ir/src/node/elementwise.rs        82.81%   11 Missing
crates/onnx-ir/src/node/reshape.rs            82.60%    8 Missing
crates/onnx-ir/src/ir/argument.rs             46.15%    7 Missing
crates/onnx-ir/tests/infrastructure.rs        76.66%    7 Missing
... and 137 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #4019      +/-   ##
==========================================
- Coverage   67.03%   66.99%   -0.04%     
==========================================
  Files        1216     1221       +5     
  Lines      149805   150429     +624     
==========================================
+ Hits       100427   100786     +359     
- Misses      49378    49643     +265     

View full report in Codecov by Sentry.

codecov[bot] · Nov 14 '25 02:11

Slight change of plans: instead of struct variants, the Node enum will wrap dedicated per-operation structs.

Current Structure

// In src/ir/node.rs - Node enum with struct variants
pub enum Node {
    ArgMax {
        name: String,
        inputs: Vec<Argument>,
        outputs: Vec<Argument>,
        config: ArgMaxConfig,
    },
    Cast {
        name: String,
        inputs: Vec<Argument>,
        outputs: Vec<Argument>,
        config: CastConfig,
    },
    // ... etc
}

Desired Structure

// In src/node/argmax.rs - NEW public struct
pub struct ArgMaxNode {
    pub name: String,
    pub inputs: Vec<Argument>,
    pub outputs: Vec<Argument>,
    pub config: ArgMaxConfig,
}

// In src/node/cast.rs - NEW public struct
pub struct CastNode {
    pub name: String,
    pub inputs: Vec<Argument>,
    pub outputs: Vec<Argument>,
    pub config: CastConfig,
}

// In src/ir/node.rs - Node enum with tuple variants wrapping the structs
pub enum Node {
    ArgMax(ArgMaxNode),
    Cast(CastNode),
    // ... etc
}
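
With tuple variants, the common-field helpers on Node delegate to the inner structs; roughly (the method name is illustrative):

impl Node {
    /// Delegate common-field access to the wrapped struct.
    pub fn name(&self) -> &str {
        match self {
            Node::ArgMax(n) => &n.name,
            Node::Cast(n) => &n.name,
            // ... etc
        }
    }
}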

Key Changes

  1. Create dedicated structs (ArgMaxNode, CastNode, etc.) for each operation
  2. Define these structs in their operation files (src/node/argmax.rs, src/node/cast.rs)
  3. Make these structs public so they can be exported via pub use node::*;
  4. Change Node enum from struct variants to tuple variants that wrap these structs
  5. Enable trait implementations on the specific node types (e.g., impl Display for ArgMaxNode)

Benefits

  • Can implement operation-specific traits on node types (see the sketch after this list)
  • Better code organization (each node type lives with its processor/config)
  • More flexible API (users can work with ArgMaxNode directly if needed)
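
As an illustration of the first benefit, a Display implementation on the ArgMaxNode struct from the sketch above might look like this (the output format is made up):

use core::fmt;

impl fmt::Display for ArgMaxNode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "ArgMax `{}` ({} inputs, {} outputs)",
            self.name,
            self.inputs.len(),
            self.outputs.len()
        )
    }
}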

antimora · Nov 17 '25 04:11