Implement memory-efficient ONNX weight loading (lazy/protobuf streaming)
Feature description
Currently, `burn-import`'s ONNX loader (`onnx_ir`) reads all ONNX weights (initializers) into memory up front. For large models this is memory intensive and can cause scalability issues or OOM crashes. Instead, we should move to a strategy where weights and tensors are read from the ONNX/protobuf file only as needed during model loading or code generation, possibly using streaming or lazy protobuf parsing.
Current state
- `onnx_ir`/`burn-import` loads all weight tensors at once, regardless of actual need.
- This behavior is particularly problematic for large models or resource-constrained environments.
- No option exists for memory-mapped or streaming/lazy loading.
Proposal
- Refactor ONNX loading in `burn-import`/`onnx-ir` to support on-demand reading of weights/tensors from the ONNX file (a rough sketch follows this list).
- Use protobuf's streaming API, or a similar mechanism, to avoid loading unnecessary data into memory.
- Consider providing both eager (current) and lazy/streaming modes for backwards compatibility.
- Update codegen and all ONNX operator implementations in `burn-import` to work with on-demand tensor access.
- Document any new APIs or usage considerations for downstream users.
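As a rough illustration of the on-demand idea, the loader could record where each initializer's raw data lives during an initial pass over the file and defer reading the bytes until a node actually needs them. The sketch below uses only `std` I/O; `LazyInitializer` and `read_bytes` are hypothetical names, not existing `onnx-ir` APIs, and an implementation based on protobuf's streaming reader or memory mapping would follow the same shape.

```rust
use std::fs::File;
use std::io::{self, Read, Seek, SeekFrom};
use std::path::PathBuf;

/// Hypothetical handle to an initializer whose raw data has not been read yet.
/// The offset/length would be recorded while scanning the protobuf stream.
pub struct LazyInitializer {
    pub name: String,
    path: PathBuf, // ONNX file (or external-data file) holding the bytes
    offset: u64,   // start of the raw tensor data within that file
    length: usize, // number of bytes to read
}

impl LazyInitializer {
    /// Read the raw tensor bytes only when an operator actually needs them.
    pub fn read_bytes(&self) -> io::Result<Vec<u8>> {
        let mut file = File::open(&self.path)?;
        file.seek(SeekFrom::Start(self.offset))?;
        let mut buf = vec![0u8; self.length];
        file.read_exact(&mut buf)?;
        Ok(buf)
    }
}
```

Decoded tensors could then be dropped as soon as codegen for the corresponding node is finished, keeping peak memory close to the size of the largest single initializer rather than the whole model.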
Feature motivation
- Support importing and working with extremely large ONNX models without OOM errors.
- Reduce the memory footprint for typical ONNX import scenarios.
- Enable use of `burn-import` in environments with constrained memory (e.g., embedded, wasm, CI/CD, cloud).
- Bring `burn-import`'s ONNX handling up to par with best practices in other frameworks (cf. PyTorch, TensorFlow, ONNX Runtime).
(Optional) Suggest a Solution
- Investigate the protobuf parsing used in `onnx-ir`, and refactor to support iterators or readers for initializers/weights.
- Use streaming reads for large tensor data blocks, and only decode weights as needed for each node/operator.
- Consider a trait or abstraction for weight access that can be implemented for both eager and lazy backends (see the sketch after this list).
- Profile and benchmark memory usage before/after.
- Add regression tests for large ONNX models to ensure memory use stays low.
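One possible shape for the eager/lazy abstraction mentioned above; the trait and type names here are hypothetical, not part of the current `burn-import` API:

```rust
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, Read, Seek, SeekFrom};
use std::path::PathBuf;

/// Hypothetical abstraction over how initializer data is obtained.
pub trait WeightSource {
    /// Return the raw bytes of the named initializer.
    fn tensor_bytes(&mut self, name: &str) -> io::Result<Vec<u8>>;
}

fn not_found(name: &str) -> io::Error {
    io::Error::new(io::ErrorKind::NotFound, format!("unknown initializer `{name}`"))
}

/// Eager backend: all initializers decoded up front (current behavior).
pub struct EagerWeights {
    tensors: HashMap<String, Vec<u8>>,
}

impl WeightSource for EagerWeights {
    fn tensor_bytes(&mut self, name: &str) -> io::Result<Vec<u8>> {
        self.tensors.get(name).cloned().ok_or_else(|| not_found(name))
    }
}

/// Lazy backend: remembers only where each initializer lives in the file
/// and reads the bytes on each request.
pub struct LazyWeights {
    path: PathBuf,
    offsets: HashMap<String, (u64, usize)>, // name -> (byte offset, length)
}

impl WeightSource for LazyWeights {
    fn tensor_bytes(&mut self, name: &str) -> io::Result<Vec<u8>> {
        let (offset, len) = *self.offsets.get(name).ok_or_else(|| not_found(name))?;
        let mut file = File::open(&self.path)?;
        file.seek(SeekFrom::Start(offset))?;
        let mut buf = vec![0u8; len];
        file.read_exact(&mut buf)?;
        Ok(buf)
    }
}
```

Codegen could then be written against `WeightSource` alone, so the eager path stays available for backwards compatibility while a lazy path could become the default for large models.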
Context
- Related to scalability and performance limitations in current `burn-import` ONNX support.
- Not directly addressed by any existing open tickets; this is a new proposal.
- For overlapping concerns, see existing issues on ONNX import scalability, async, and backend support.
This would substantially help with https://github.com/tracel-ai/burn/issues/2871. I've been doing some investigation into memory usage over the past few days, and I've found that nearly all of the memory usage comes from loading all of the tensors and keeping them in memory at the same time.
Although not streaming, my new PR (#3615) lays the groundwork for loading a single layer at a time, and might help solve this in the long run.
It is still being refined, but feedback would be appreciated on its interface.
I am currently working on this.
After https://github.com/tracel-ai/burn/pull/3872 is merged, refactor the `Rc<RefCell<GraphState>>` lifetime hack.