onnx-mlir
Hybrid ONNX-to-ONNX transformation pass
Includes a new pattern-based shape inference implementation.
So far it only combines shape inference and canonicalization, but the plan is to also include constant propagation and decomposition so that a single pass with all the patterns can cascade shape inference and all of these transforms.
I encountered a model with a long dependency chain between shape inference, canonicalization, and constant propagation, which requires a very large number of repeated passes with the current pass infrastructure.
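Roughly, the hybrid pass can look like the following minimal sketch (not the PR's actual code). It uses MLIR's greedy rewrite driver; `InferShapesPattern` is a hypothetical stand-in for the pattern-based shape inference in this PR, and it assumes onnx-mlir's ShapeInferenceOpInterface with an `inferShapes(doShapeInference)` method.

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
// ... plus the onnx-mlir header that declares ShapeInferenceOpInterface.

// Hypothetical stand-in for the PR's pattern-based shape inference: ask an op
// implementing the (assumed) ShapeInferenceOpInterface to refine its result
// types, and report failure when nothing changed so the driver terminates.
struct InferShapesPattern
    : public mlir::OpInterfaceRewritePattern<mlir::ShapeInferenceOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::ShapeInferenceOpInterface op,
                  mlir::PatternRewriter &rewriter) const override {
    auto oldTypes = llvm::to_vector(op->getResultTypes());
    rewriter.startRootUpdate(op);
    // Empty callback: nested regions are ignored in this sketch. Assumes
    // inferShapes leaves the op untouched when it fails.
    if (mlir::failed(op.inferShapes([](mlir::Region &) {})) ||
        llvm::to_vector(op->getResultTypes()) == oldTypes) {
      rewriter.cancelRootUpdate(op);
      return mlir::failure();
    }
    rewriter.finalizeRootUpdate(op);
    return mlir::success();
  }
};

struct HybridTransformPass
    : public mlir::PassWrapper<HybridTransformPass,
                               mlir::OperationPass<mlir::func::FuncOp>> {
  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);
    // Collect every op's canonicalization patterns, as the canonicalizer
    // pass does, and add the shape inference pattern on top. Constant
    // propagation and decomposition patterns would be added here as well.
    for (mlir::RegisteredOperationName name : ctx->getRegisteredOperations())
      name.getCanonicalizationPatterns(patterns, ctx);
    patterns.add<InferShapesPattern>(ctx);
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
            getOperation(), std::move(patterns))))
      signalPassFailure();
  }
};
```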
Tweaked a bunch of lit tests to work with the new shape inference. The tweaks fall into two categories:
- The new pattern-based shape inference removes dead code, namely ops whose results are not used. See e.g. `test_lstm_no_results` in onnx_shape_inference.mlir. This is a behavior that comes with `applyPatternsAndFoldGreedily`.
- Shape errors that were ignored by the old shape inference, which skipped ops that already had a static shape. See e.g. `test_if_simple` in onnx_lowering.mlir.
Signed-off-by: Soren Lassen [email protected]
The special problem with shape inference is that there is a binding for the block arguments and a merging at the points where control flow joins. I am not sure whether MLIR can reach the fixed point for the transformation of the whole graph if we just put all the rewriting rules together and apply them greedily on the function op. Will constant propagation have a similar issue if it crosses the boundary of a block? That's another extension we can consider.
> The special problem with shape inference is that there is a binding for the block arguments and a merging at the points where control flow joins. I am not sure whether MLIR can reach the fixed point for the transformation of the whole graph if we just put all the rewriting rules together and apply them greedily on the function op.
I was assuming that we could always do shape inference in a single top-down pass; please let me know if you have a counterexample.
> Will constant propagation have a similar issue if it crosses the boundary of a block?
are you thinking of if/loop/scan subregions - propagating a constant into or out of a subregion, or inlining the subregion if the conditional is constant? (I haven't thought through those situations)
Shape inference for regions
Regions can come from IfOp, LoopOp, or FunctionOp. There can be more than one control flow edge going into a region. For example, a FunctionOp may have more than one call site, and the Loop body region has one edge from outside the loop and one from the loop body return. Let's focus on how shape inference is performed for LoopOp:
1. Bind the region arguments to the loop inputs: S_input => S_arguments
2. Perform shape inference on the loop body based on S_arguments
3. At the end of the loop body, we get the shape of the return, denoted S_return
4. Merge the argument shapes with S_return => S_arguments
5. Repeat from step 2 until S_arguments reaches a fixed point
The semi-lattice structure of the shape info under the merge operation guarantees that the loop terminates. If constant propagation is done together with shape inference, could it happen that some constant from an unstable intermediate stage gets propagated and cannot be recovered? For example, suppose there is a constant-shaped tensor in the input but the final result for the corresponding argument is not a constant-shaped tensor. In the first iteration of shape inference on the loop body, we may replace the ShapeOp with a constant and propagate it all the way. At the second iteration, we have no way to recover the ShapeOp, at least with the current implementation.
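To make the merge concrete, here is a minimal sketch of joining two tensor types on the shape semi-lattice (an illustration under the assumptions stated in the comments, not onnx-mlir's actual merge code). Since a join can only keep or lose shape information, repeated merging terminates, which is what the repetition in step 5 relies on.

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"

// Join (merge) of two tensor types on the shape semi-lattice: equal static
// dims are kept, disagreeing dims become dynamic, and a rank mismatch falls
// back to an unranked tensor. Element types of the merged values are assumed
// to agree (as they do for loop-carried tensors).
mlir::Type joinTensorTypes(mlir::Type a, mlir::Type b) {
  mlir::Type elementType = a.cast<mlir::ShapedType>().getElementType();
  auto ta = a.dyn_cast<mlir::RankedTensorType>();
  auto tb = b.dyn_cast<mlir::RankedTensorType>();
  if (!ta || !tb || ta.getRank() != tb.getRank())
    return mlir::UnrankedTensorType::get(elementType);
  llvm::SmallVector<int64_t> dims;
  for (auto [da, db] : llvm::zip(ta.getShape(), tb.getShape()))
    dims.push_back(da == db ? da : mlir::ShapedType::kDynamic);
  return mlir::RankedTensorType::get(dims, elementType);
}
```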
> 5. Repeat from step 2 until S_arguments reaches a fixed point

this repeat is done by the loop in `ONNXOpTransformPass::runOnOperation()`, right?
> > 5. Repeat from step 2 until S_arguments reaches a fixed point
>
> this repeat is done by the loop in `ONNXOpTransformPass::runOnOperation()`, right?
No, we do not have that part yet. For tensor types, the input type and the return type have to be the same, so the loop may only be needed for the onnx.Seq type. A LoopOp may take an empty Seq as initial input. After one iteration, the return may be Seq<Tensor<1xf32>>. The second iteration may return Seq<Tensor<?xf32>> if the inserted element is a 1D tensor other than Tensor<1xf32>. However, I do not know whether we should do shape inference on Seq in such an aggressive way; the motivation is to avoid Seq<Tensor<*xf32>> in onnx-mlir. Perhaps we can assume pre-order traversal is enough for us. What we need to do is to use the dynamic pass on the subregions of FunctionOp, IfOp, or LoopOp.
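For the pre-order traversal, a minimal sketch of a doShapeInference callback over a region could look like the following (it assumes the ShapeInferenceOpInterface and the `inferShapes(doShapeInference)` signature discussed in this thread, and is not the actual onnx-mlir code).

```cpp
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
// ... plus the onnx-mlir header that declares ShapeInferenceOpInterface.

// Pre-order walk: ops are visited top-down within the region, and an op with
// nested regions (If/Loop/Scan) recurses by receiving this same function as
// its doShapeInference callback.
void doShapeInference(mlir::Region &region) {
  for (mlir::Operation &op : region.getOps())
    if (auto shapeOp = llvm::dyn_cast<mlir::ShapeInferenceOpInterface>(&op))
      if (mlir::failed(shapeOp.inferShapes(doShapeInference)))
        op.emitError("shape inference failed");
}
```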
@sorenlassen I went back to review this PR. I agree with the basic idea of defining the hybrid pass to replace most of the ONNX-to-ONNX transformations. Several questions:
- When a fold is defined for an op, it is invoked automatically in the canonicalization pass. Will the same happen when you add the canonicalization patterns into the hybrid pattern set?
- The fold may also be used separately with createOrFold. The current constant propagation code (which will be moved into fold) assumes that shape inference has already been done. To avoid duplicating shape inference for the fold condition, I plan to add shape inference directly into the fold function. Is this a valid decision?
- What's the impact of fold containing shape inference on the hybrid pass?
- How would we proceed? Finish the hybrid pass first and then move the constant propagation into fold?
- The design of shape inference, such as the general interface and the handling of control flow, is not fully clear to me. I am working on it.
@sorenlassen I thought over the problem further. The main obstacle comes from LoopOp. Shape inference for the loop body starts with assigning the types of the loop initial values to the loop body arguments. These types may not yet be the same as the final types of the loop body arguments. Therefore, the flag `inferShapeOnly` has to be true.
Theoretically, the output of the loop body should be merged with the initial values to update the loop body arguments, iterating until the shapes remain the same. Currently, onnx-mlir assumes that the body output of one iteration is the ultimate result. I am afraid it may be wrong; I ran into a problem with an example of a loop over a sequence. For this PR, we can keep the current implementation.
For the hybrid pass, we need to create a nested hybrid transformation pass for the loop body with inferShapeOnly=true for the iteration process (the first pass over the loop body in the current implementation). After it becomes stable, we can run another nested hybrid transformation pass over the loop body with inferShapeOnly=false. I guess that the inferShape of LoopOp does this by traversing the ops in its region for shape inference only.
The invocation of the nested pass should be in the shape inference of LoopOp. That's the counterpart of the current parameter for inferShape. To my understanding, the use of that parameter is caused by the fact that an MLIR pass cannot be defined on a region that is not a FunctionOp, even though there is a return before the terminator. We need to figure out a nice solution for this.
There are several related questions, which are unclear to me and not urgent:
- To me, it seems that shape inference should further update the shape info of the loop body after the types of the loop arguments become stable (after one iteration). But we have not seen any issue without that step in the current implementation. Why?
- With transformations such as constant propagation, the loop body could be transformed. Could the shapes of the loop outputs then be further improved? Should we iterate the loop body for shape inference once more?
- When shape inference is applied to the loop body, inferShapeOnly is ON. Since some transformations may help shape inference, would it help to copy the loop body and then perform the hybrid pass with transformations on the copied body, just to get the types of the loop body outputs?
Comment on the code: the major change is to decompose the current inferShape function for the control flow ops, IfOp and LoopOp, into two pieces. One passes the shape info into the region and the other computes the result types when the control flows join (a minimal sketch of the join step follows the list below).
- When inferShape is called on IfOp, nothing needs to be done.
- When inferShape is called on the return of the THEN or ELSE branch of IfOp, call the "join" operation for IfOp. The join operation for IfOp is the current inferShape minus the two calls to doShapeInference on THEN and ELSE. I assume that the pass manager will traverse into regions; I think that the return pattern in this PR is for this purpose.
- When inferShape is called on LoopOp, the types of the initial inputs are propagated to the loop-carried variables, and the region is traversed to apply inferShape only. This is the initialization and doShapeInference part of the current inferShape function for LoopOp.
- When inferShape is called on the return of LoopOp, call the "join" operation for LoopOp. If a fixed point is not reached, traverse the region again to apply inferShape only. This should be the rest of the current inferShape function for LoopOp.
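A minimal sketch of the join piece for an If-like op, reusing the hypothetical `joinTensorTypes` helper sketched earlier (the real onnx-mlir op and terminator classes are not assumed here):

```cpp
// When both branches have been shape-inferred, each result type of the
// parent op becomes the join of the types yielded by the THEN and ELSE
// terminators (region 0 and region 1 of the If-like op).
void joinIfResults(mlir::Operation *ifOp) {
  mlir::Operation *thenTerm = ifOp->getRegion(0).back().getTerminator();
  mlir::Operation *elseTerm = ifOp->getRegion(1).back().getTerminator();
  for (unsigned i = 0, e = ifOp->getNumResults(); i < e; ++i)
    ifOp->getResult(i).setType(
        joinTensorTypes(thenTerm->getOperand(i).getType(),
                        elseTerm->getOperand(i).getType()));
}
```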
The other parts of the PR are OK. Is it right that inferShapeOnly is for debugging only? Similar control for other patterns could be added too, couldn't it?
@chentong319 if you want to discuss in person I can meet Friday morning March 17 anytime between 8-11am PT if you're available then
> @chentong319 if you want to discuss in person I can meet Friday morning March 17 anytime between 8-11am PT if you're available then
Sorry, I am not available this Friday. Does any time next Monday work for you?
> Does any time next Monday work for you?
yes, I'm available anytime after 1 PT on Monday
> Does any time next Monday work for you?
>
> yes, I'm available anytime after 1 PT on Monday

1 PT Monday: https://ibm.webex.com/meet/chentong90
this PR was completed in PR #2098