Op fusion needs special support from ONNXBroadcastOpShapeHelper
ONNXBroadcastOpShapeHelper is used to generate the access expressions for loads, similar to how it is used in the element-wise op lowering. But the two uses differ in two ways.
- Constructor. In the element-wise op lowering, ONNXBroadcastOpShapeHelper takes all the remapped inputs of the current op (the handles to the converted inputs of MemRef type, provided as the operands or OpAdaptor parameter of matchAndRewrite) as parameters for initialization. For a fusible op, we have to call getRemappedValue() or getRemappedValues() on the rewriter to get the converted inputs. We can do that for all inputs except the one coming from the defining op in our fusion list: using it causes a runtime error related to SSA. In fact, we do not need that input, because its converted value has just been generated; I guess that is the reason. However, MLIR did not report an error at that point even though the result of getRemappedValues() was checked. Anyway, what should be put there? In op fusion, we only have the Value inside the loop nest, not the whole MemRef Value. Can a null Value, or the output Value, be passed if they have the same shape? The output Value may have a different element type and a different tensor shape, which is discussed in the second bullet.
- Allowed shapes for op fusion. Ops can be fused only when their loop iteration spaces are the same. There is, however, a corner case where their shapes differ by an extra leading dimension of size 1. For example, when onnx.Sqrt and onnx.Add are fused:
%3 = "onnx.Constant dense<[...]> : ()->tensor<1x24xf32>
%1 = "onnx.Sqrt"(%0) : (tensor<24xf32>) -> tensor<24xf32>
%2 = "onnx.Add"(%1, %3) : tensor<24xf32>, tensor<1x24xf32>) -> tensor<1x24xf32>
If such a corner case is allowed, we have to notice that:
- The shape of the alloc Value (the output Value) comes from the root op, so it will be memref<24xf32>. It requires some work to make the shape match the output of the last op in the fusion chain, because we do not have the converted inputs of the fused op for the ShapeHelper to handle a dynamic shape.
- The access expression may need extra 0s added to the loop index vector (see the sketch after this list). I am afraid this is a new case not considered in the existing code.
Should we consider this corner case for fusion or not?
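To make the two points concrete, here is a hypothetical sketch of the naive fused loop for the example above. It is written with affine/memref ops for readability rather than the krnl dialect that onnx-mlir actually lowers to, and the function name and buffers are only illustrative:

func.func @fused_sqrt_add(%arg0: memref<24xf32>, %cst: memref<1x24xf32>) -> memref<24xf32> {
  // The loop nest and the alloc are derived from the root op (onnx.Sqrt),
  // so the output buffer is memref<24xf32> even though onnx.Add produces 1x24xf32.
  %out = memref.alloc() : memref<24xf32>
  affine.for %i = 0 to 24 {
    %x = affine.load %arg0[%i] : memref<24xf32>
    %s = math.sqrt %x : f32
    // Broadcast access into the 1x24 constant: the loop index vector (%i)
    // must be extended with a leading 0.
    %c = affine.load %cst[0, %i] : memref<1x24xf32>
    %a = arith.addf %s, %c : f32
    affine.store %a, %out[%i] : memref<24xf32>
  }
  return %out : memref<24xf32>
}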
It requires some work to make the shape match the output of the last op
Approach 1
memref.subview can be used for this. Quoting from the spec:
A subview operation may additionally reduce the rank of the resulting view by removing dimensions that are statically known to be of size 1.
Subviews are basically a software construct that lets you access a subset (or all) of a memref. They are later eliminated by simply folding the offsets/strides/missing size-1 dims in a later normalization pass. They cost nothing.
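For example, assuming the 1x24 constant has already been lowered to a memref value %cst, a rank-reducing subview could present it with the same rank as the 24xf32 value (a sketch only; the exact result type accepted by the verifier may require a strided layout or a cast):

// Rank-reducing subview: the size-1 leading dimension is dropped, so the
// 1x24 buffer can be indexed with the same 1-D loop index as the root op.
%view = memref.subview %cst[0, 0] [1, 24] [1, 1] : memref<1x24xf32> to memref<24xf32>
%c = affine.load %view[%i] : memref<24xf32>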
Approach 2 (preferred, as it also handles the broadcast situations)
Scan all the fused ops in turn. For each op, generate the memory loads (using the broadcasting support), except for the "fused dependence". Store all the values in an array. When emitting the code, just pick up the values from the array.
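A sketch of the IR this could produce for the Sqrt/Add example (again with affine/memref for illustration, and assuming the loop nest and the alloc follow the last op's 1x24 output): every operand of every fused op is loaded through the usual broadcast access expression, while the Sqrt result is simply the SSA value kept in the array, so no load is emitted for the fused dependence.

%out = memref.alloc() : memref<1x24xf32>
affine.for %i = 0 to 1 {
  affine.for %j = 0 to 24 {
    // Load for onnx.Sqrt's input: the broadcast support drops the leading dim.
    %x = affine.load %arg0[%j] : memref<24xf32>
    %s = math.sqrt %x : f32    // recorded in the array of produced values
    // Load for onnx.Add's constant input: a plain access, no broadcast needed.
    %c = affine.load %cst[%i, %j] : memref<1x24xf32>
    // The other Add operand is picked up from the array (%s), not loaded.
    %a = arith.addf %s, %c : f32
    affine.store %a, %out[%i, %j] : memref<1x24xf32>
  }
}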
Approach 3 (maybe even better).
Generate a single fused loop; generate the code as normal (including the memref allocs before the loop and the loads/stores into the temporary arrays). Note that only the last alloc should be kept (plus any other alloc whose result is reused later outside of the fused loop). Then perform a cleanup pass where all unnecessary allocs are removed, as well as the load/store chains.
The reason I like approach 3 is that it has minimal impact on the code gen, and having a cleanup pass that eliminates all the stores and replaces each load by substituting the stored value where the load is used... seems pretty robust and feasible. The loads/stores could even be decorated with a special attribute, so that the cleanup pass does not even need to be aware of what was fused...
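A sketch of approach 3 on the same example (affine/memref again, and purely hypothetical; the actual lowering and cleanup pass would live in onnx-mlir's krnl-based pipeline): the code is first generated as usual with a temporary buffer for the Sqrt result, and the cleanup pass then removes the unneeded alloc and forwards the stored value to its load.

// Step 1: code generated as normal inside one fused loop.
%tmp = memref.alloc() : memref<24xf32>     // intermediate result of onnx.Sqrt
%out = memref.alloc() : memref<1x24xf32>   // final result of onnx.Add
affine.for %j = 0 to 24 {
  %x = affine.load %arg0[%j] : memref<24xf32>
  %s = math.sqrt %x : f32
  affine.store %s, %tmp[%j] : memref<24xf32>
  %l = affine.load %tmp[%j] : memref<24xf32>
  %c = affine.load %cst[0, %j] : memref<1x24xf32>
  %a = arith.addf %l, %c : f32
  affine.store %a, %out[0, %j] : memref<1x24xf32>
}
// Step 2: the cleanup removes the store/load pair and the now-dead alloc,
// because %tmp is not used outside the fused loop.
%out2 = memref.alloc() : memref<1x24xf32>
affine.for %j = 0 to 24 {
  %x = affine.load %arg0[%j] : memref<24xf32>
  %s = math.sqrt %x : f32
  %c = affine.load %cst[0, %j] : memref<1x24xf32>
  %a = arith.addf %s, %c : f32
  affine.store %a, %out2[0, %j] : memref<1x24xf32>
}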