
Refactor node to Enum-based design for type-safe IR

antimora opened this issue 1 month ago • 3 comments

Motivation

The ONNX-IR layer serves as a clean intermediate representation consumed by burn-import and other downstream tools. The IR should be:

  • Type-safe: No runtime downcasting or type erasure
  • Self-contained: Fully processed and ready for consumption without re-interpretation
  • Explicit: Node types and their configurations are clear at compile-time
  • Efficient: Zero-cost abstractions with optimal memory layout

Currently, the Node struct stores configuration via trait objects (Box<dyn NodeConfig>), which requires runtime downcasting and introduces unnecessary indirection. While we're removing the stored config as a short-term fix (see #XXXX), an enum-based design would provide a more robust long-term solution.

Current Design

struct Node {
    pub name: String,
    pub node_type: NodeType,  // Redundant with config type
    pub inputs: Vec<TensorOrScalar>,
    pub outputs: Vec<Argument>,
    // (config field being removed in current PR)
}

Issues:

  • node_type field is redundant - stores type information separately
  • Generic operations require matching on node_type discriminant
  • No compile-time guarantee about config/node type correspondence (see the sketch below)
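
Roughly, this is what the trait-object pattern forces on consumers today (a simplified sketch; the actual onnx-ir trait, helpers, and config fields may differ):

use std::any::Any;

// Simplified sketch of the current trait-object approach; names and
// fields are illustrative, not the exact onnx-ir definitions.
trait NodeConfig: Any {
    fn as_any(&self) -> &dyn Any;
}

struct ConvConfig {
    kernel_size: [usize; 2],
}

impl NodeConfig for ConvConfig {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

enum NodeType {
    Conv,
    Add,
}

struct Node {
    node_type: NodeType,
    config: Box<dyn NodeConfig>, // the field being removed as the short-term fix
}

fn conv_kernel_size(node: &Node) -> [usize; 2] {
    // Callers must check node_type first, then downcast at runtime;
    // nothing ties the config type to the node_type discriminant.
    assert!(matches!(node.node_type, NodeType::Conv));
    node.config
        .as_any()
        .downcast_ref::<ConvConfig>()
        .expect("Conv node should carry a ConvConfig")
        .kernel_size
}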

Proposed Design

enum Node {
    ArgMax {
        name: String,
        inputs: Vec<TensorOrScalar>,
        outputs: Vec<Argument>,
        config: ArgMaxConfig,
    },
    Conv {
        name: String,
        inputs: Vec<TensorOrScalar>,
        outputs: Vec<Argument>,
        config: ConvConfig,
    },
    Add {
        name: String,
        inputs: Vec<TensorOrScalar>,
        outputs: Vec<Argument>,
        // No config needed
    },
    // ... one variant per supported op
}

impl Node {
    // Helper methods for common fields
    pub fn name(&self) -> &str {
        match self {
            Node::ArgMax { name, .. } |
            Node::Conv { name, .. } |
            Node::Add { name, .. } => name,
        }
    }

    pub fn inputs(&self) -> &[TensorOrScalar] {
        match self {
            Node::ArgMax { inputs, .. } |
            Node::Conv { inputs, .. } |
            Node::Add { inputs, .. } => inputs,
        }
    }

    // Similar for outputs
}

Benefits

  1. Eliminates Type Redundancy
  • The enum variant discriminant IS the node type
  • No separate node_type field needed
  • Single source of truth for what operation this is
  2. Compile-Time Type Safety
  • No trait objects (Box<dyn NodeConfig>)
  • No runtime downcasting with as_any() and downcast_ref()
  • Config type is guaranteed to match operation type
  • Impossible to have mismatched config/operation pairs
  3. Zero Runtime Overhead
  • No vtable indirection
  • No heap allocation for configs
  • Better cache locality (all data inline)
  • Optimal memory layout
  4. Exhaustive Matching
  • Compiler forces handling all operation types
  • Can't accidentally forget to implement a case
  • Refactoring is safer (adding/removing ops triggers compiler errors)
  5. Cleaner IR Semantics
  • Downstream consumers (burn-import) get fully typed, processed nodes (see the consumer sketch after this list)
  • No need to know about processors or extraction logic
  • IR is truly "intermediate" - fully resolved and ready to consume
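
As a rough sketch of the last two points, a downstream consumer such as burn-import can match exhaustively on the typed variants (config field names like kernel_size and axis are placeholders here, not the actual onnx-ir fields):

// Sketch only: assumes the proposed Node enum above; burn-import's real
// code generation is more involved, and the config fields are illustrative.
fn codegen(node: &Node) -> String {
    match node {
        Node::Conv { name, config, .. } => {
            // `config` is statically a ConvConfig here; no downcast needed.
            format!("let {name} = conv(kernel = {:?});", config.kernel_size)
        }
        Node::ArgMax { name, config, .. } => {
            format!("let {name} = argmax(axis = {});", config.axis)
        }
        Node::Add { name, .. } => format!("let {name} = add();"),
        // Adding a new variant to Node makes this match a compile error
        // until the new operation is handled here.
    }
}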

antimora · Nov 06 '25 21:11

There is also a drawback when all enum variants have the same fields: it leads to a lot of duplication, since things like name, inputs, and outputs could instead be defined once globally.

nathanielsimard · Nov 14 '25 00:11

> There is also a drawback when all enum variants have the same fields: it leads to a lot of duplication, since things like name, inputs, and outputs could instead be defined once globally.

I minimized definitions via a single macro. It worked out to be clean in the end. In burn-import, I don't even use accessors.
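
For illustration, a hypothetical define_nodes! macro along these lines (not the exact macro from the refactoring) can generate the variants and the shared accessors from a single definition:

// Hypothetical macro, shown only to illustrate how the per-variant
// boilerplate can be generated in one place.
macro_rules! define_nodes {
    ($($variant:ident $({ config: $config:ty })?),* $(,)?) => {
        pub enum Node {
            $(
                $variant {
                    name: String,
                    inputs: Vec<TensorOrScalar>,
                    outputs: Vec<Argument>,
                    $(config: $config,)?
                },
            )*
        }

        impl Node {
            pub fn name(&self) -> &str {
                match self {
                    $(Node::$variant { name, .. } => name,)*
                }
            }

            pub fn inputs(&self) -> &[TensorOrScalar] {
                match self {
                    $(Node::$variant { inputs, .. } => inputs,)*
                }
            }
        }
    };
}

// Usage: list each op once, with its config type when it has one.
define_nodes! {
    ArgMax { config: ArgMaxConfig },
    Conv { config: ConvConfig },
    Add,
}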

antimora · Nov 14 '25 00:11

OK. Refactoring is done and all tests are passing.

Here is an example of ONNX-IR with typed config:

OnnxGraph {
    nodes: [
        Conv2d {
            name: "conv2d1",
            inputs: [
                Argument {
                    name: "input",
                    ty: Tensor(
                        TensorType {
                            dtype: F32,
                            rank: 4,
                            static_shape: Some(
                                [
                                    2,
                                    4,
                                    10,
                                    15,
                                ],
                            ),
                        },
                    ),
                    value_source: Dynamic,
                },
                Argument {
                    name: "",
                    ty: Tensor(
                        TensorType {
                            dtype: F32,
                            rank: 4,
                            static_shape: Some(
                                [
                                    6,
                                    2,
                                    3,
                                    5,
                                ],
                            ),
                        },
                    ),
                    value_source: Static(
                        0,
                    ),
                },
                Argument {
                    name: "",
                    ty: Tensor(
                        TensorType {
                            dtype: F32,
                            rank: 1,
                            static_shape: Some(
                                [
                                    6,
                                ],
                            ),
                        },
                    ),
                    value_source: Static(
                        1,
                    ),
                },
            ],
            outputs: [
                Argument {
                    name: "conv2d1_out1",
                    ty: Tensor(
                        TensorType {
                            dtype: F32,
                            rank: 4,
                            static_shape: None,
                        },
                    ),
                    value_source: Dynamic,
                },
            ],
            config: Conv2dConfig {
                channels: [
                    4,
                    6,
                ],
                kernel_size: [
                    3,
                    5,
                ],
                stride: [
                    2,
                    1,
                ],
                padding: Explicit(
                    4,
                    2,
                ),
                dilation: [
                    3,
                    1,
                ],
                groups: 2,
                bias: true,
            },
        },
    ],
    inputs: [
        Argument {
            name: "input",
            ty: Tensor(
                TensorType {
                    dtype: F32,
                    rank: 4,
                    static_shape: Some(
                        [
                            2,
                            4,
                            10,
                            15,
                        ],
                    ),
                },
            ),
            value_source: Dynamic,
        },
    ],
    outputs: [
        Argument {
            name: "conv2d1_out1",
            ty: Tensor(
                TensorType {
                    dtype: F32,
                    rank: 4,
                    static_shape: None,
                },
            ),
            value_source: Dynamic,
        },
    ],
}

Code generated:

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv2d1: Conv2d<B>,
    phantom: core::marker::PhantomData<B>,
    device: burn::module::Ignored<B::Device>,
}


impl<B: Backend> Default for Model<B> {
    fn default() -> Self {
        Self::from_file("./out/conv2d", &Default::default())
    }
}

impl<B: Backend> Model<B> {
    pub fn from_file(file: &str, device: &B::Device) -> Self {
        let record = burn::record::PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
            .load(file.into(), device)
            .expect("Record file to exist.");
        Self::new(device).load_record(record)
    }
}

impl<B: Backend> Model<B> {
    #[allow(unused_variables)]
    pub fn new(device: &B::Device) -> Self {
        let conv2d1 = Conv2dConfig::new([4, 6], [3, 5])
            .with_stride([2, 1])
            .with_padding(PaddingConfig2d::Explicit(4, 2))
            .with_dilation([3, 1])
            .with_groups(2)
            .with_bias(true)
            .init(device);
        Self {
            conv2d1,
            phantom: core::marker::PhantomData,
            device: burn::module::Ignored(device.clone()),
        }
    }

    #[allow(clippy::let_and_return, clippy::approx_constant)]
    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let conv2d1_out1 = self.conv2d1.forward(input);
        conv2d1_out1
    }
}
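
For completeness, using the generated model looks roughly like this (the NdArray backend, the zero-filled input, and the main function are just for illustration; the record file at ./out/conv2d must exist):

use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    // Illustration only: assumes the generated `Model` above is in scope and
    // the record file referenced by `from_file` was produced by burn-import.
    let device = Default::default();
    let model = Model::<NdArray>::from_file("./out/conv2d", &device);

    // Input matches the graph input: [batch = 2, channels = 4, h = 10, w = 15].
    let input = Tensor::<NdArray, 4>::zeros([2, 4, 10, 15], &device);
    let output = model.forward(input);
    println!("{:?}", output.dims());
}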

antimora · Nov 14 '25 01:11