[Arith] Provide tighter ConstIntBounds for special cases

Open Lunderberg opened this issue 1 year ago • 1 comments

Expressions of the form (A+B)*C < (A*B)*D can occur when comparing the number of operations required for two different orderings in which matrix multiplications can be performed. Proving or disproving this conditional allows an optimal order of execution to be selected, even for dynamic argument shapes.
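
As one illustration of where this form can come from (the shapes, names, and the three-matrix chain below are assumptions for the example, not taken from the PR): for X of shape [n, a], P of shape [a, r], and Q of shape [r, b], the two evaluation orders cost (a+b)*(n*r) and (a*b)*(n+r) multiply-adds respectively, so choosing between them is a comparison of the form (A+B)*C < (A*B)*D with C = n*r and D = n+r.

```python
def matmul_cost(m, k, n):
    """Multiply-add count for an [m, k] @ [k, n] product."""
    return m * k * n


def cost_left_first(n, a, r, b):
    """Cost of (X @ P) @ Q with X: [n, a], P: [a, r], Q: [r, b]."""
    return matmul_cost(n, a, r) + matmul_cost(n, r, b)


def cost_right_first(n, a, r, b):
    """Cost of X @ (P @ Q) for the same (hypothetical) shapes."""
    return matmul_cost(a, r, b) + matmul_cost(n, a, b)


# Illustrative shapes only: one row of activations, rank-16 factors.
n, a, r, b = 1, 4096, 16, 4096
left = cost_left_first(n, a, r, b)    # equals (a + b) * (n * r)
right = cost_right_first(n, a, r, b)  # equals (a * b) * (n + r)
print(left < right)
```

With static shapes the comparison is a constant; with dynamic shapes it has to be proven from whatever bounds are known on each dimension.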

The default behavior of ConstIntBounds assumes that each term in an expression is independent. For example, the maximum value of (A+B)*C - (A*B)*D is determined by taking the maximum value of (A+B)*C and subtracting the minimum value of (A*B)*D. This algorithm can be applied in all cases, but can provide a bound that is looser than strictly required.
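
The looseness can be seen with a small interval-arithmetic sketch. This is not TVM's implementation; the helper names are hypothetical and the ranges are chosen only for illustration. It bounds each term independently, then compares against the exact range found by enumeration.

```python
from itertools import product


def add_bounds(x, y):
    """Bounds of x + y for inclusive (lo, hi) intervals."""
    return (x[0] + y[0], x[1] + y[1])


def sub_bounds(x, y):
    """Bounds of x - y: max of x minus min of y, and vice versa."""
    return (x[0] - y[1], x[1] - y[0])


def mul_bounds(x, y):
    """Bounds of x * y: extremes occur at endpoint combinations."""
    candidates = [p * q for p, q in product(x, y)]
    return (min(candidates), max(candidates))


def independent_bounds(A, B, C, D):
    """Bound (A+B)*C - (A*B)*D treating every term as independent."""
    lhs = mul_bounds(add_bounds(A, B), C)
    rhs = mul_bounds(mul_bounds(A, B), D)
    return sub_bounds(lhs, rhs)


def exact_bounds(A, B, C, D):
    """Exact range of (A+B)*C - (A*B)*D by brute-force enumeration."""
    ranges = [range(lo, hi + 1) for lo, hi in (A, B, C, D)]
    values = [(a + b) * c - (a * b) * d for a, b, c, d in product(*ranges)]
    return (min(values), max(values))


A = B = (5, 8)
C = D = (1, 2)
print(independent_bounds(A, B, C, D))  # upper bound is positive
print(exact_bounds(A, B, C, D))        # true maximum is negative
```

For these ranges the independent bound cannot prove the expression is negative, even though it is negative for every combination of values, because A and B appear on both sides of the subtraction.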

This commit adds a check for this case in ConstIntBounds, to provide a tighter bound of possible values. When A, B, C, and D are all positive values, as is the case for tensor shapes, the inequality can be written as 1/A + 1/B < D/C. If this inequality holds for the minimum values of A, B, and D, along with the maximum value of C, then it holds for all values.
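
The rearrangement can be written out step by step. The subscripted min/max symbols are notation introduced here for the known bounds of each variable; all quantities are assumed strictly positive, so dividing by A·B·C preserves the direction of the inequality.

```latex
\begin{aligned}
(A + B)\,C < (A\,B)\,D
  &\iff \frac{A + B}{A\,B} < \frac{D}{C}
  \iff \frac{1}{A} + \frac{1}{B} < \frac{D}{C}, \\[4pt]
\frac{1}{A} + \frac{1}{B}
  &\le \frac{1}{A_{\min}} + \frac{1}{B_{\min}}
  < \frac{D_{\min}}{C_{\max}}
  \le \frac{D}{C}.
\end{aligned}
```

The left-hand side is largest at the minimum values of A and B, and the right-hand side is smallest at the minimum D and maximum C, so a single check at those extremes covers every value in the ranges.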

Lunderberg, Feb 16 '24 17:02

One thing that we would like to consider is the overall efficiency. ConstIntBound is one such case where we do not try to do complicated pattern matching and instead rely on the independence property; as a result, it can be called extensively in the inner loops of analysis.

When deciding where to place those detections, it is useful to balance overall efficiency against readability and ease of reasoning (it is a good thing to keep ConstIntBound simple, so that its behavior is predictable).

I think it is helpful to support some of the more complicated patterns. Usually we do that using another form of SubAnalyzer, or by proving with a stronger proof strength (https://github.com/apache/tvm/blob/main/include/tvm/arith/analyzer.h#L75). Some of our previous proofs introduce a different strength that is not explicitly used in the inner loops of other simplifiers. I believe this case is similar.

tqchen, Feb 16 '24 18:02

Good point on checking the performance. I did a benchmark, with results shown in the plot below. The x-axis is the time required to run the analyzer with the BoundUsingReciprocal function disabled, and the y-axis is the time with it enabled. A y=x diagonal line is shown for comparison.

Benchmark settings:
  • The benchmark measures the time to run the analyzer itself. This is the call to analyzer.const_int_bound(expr) for the tests in test_arith_const_int_bound.py, and the call to analyzer.rewrite_simplify(expr) for the tests in test_arith_rewrite_simplify.py.
  • Tests were performed using the pytest-benchmark plugin with default settings. For each benchmark, the number of iterations was selected to take about 1 second total time, and the plot shows the average per-iteration time. The number of iterations was at least 5k for all benchmarks, typically around 30-40k.
  • When disabling PR-16588, the BoundUsingReciprocal function was never entered.
  • Order of each benchmark was randomized to avoid any systematic bias between groups.

image

In the majority of cases, there is effectively no performance difference. In cases where the improved bounds returned from ConstIntBound allow a simplification, the performance is improved as fewer recursive rewrites need to be applied.

Lunderberg, Feb 20 '24 20:02

I still think that, in this case, constraining the behavior of ConstIntBound overall is more predictable and readable.

My main worry is that we open a floodgate of introducing many recursive rewrite patterns into ConstIntBound itself.

  • Do you think it is possible to introduce another subanalyzer for this kind of proof, given that it is specific to LoRA, or to put it behind a stronger proof strength?
  • In all cases, we should take a close look at the proofs and make sure they are correct, since this is at the center of what we are doing.
  • If in the end we end up keeping things in ConstIntBound, I think we should have a comment saying that this was really an exceptional case, and that future changes to ConstIntBound should only be made in exceptional cases.

tqchen, Feb 20 '24 22:02

> My main worry is that we open a floodgate of introducing many recursive rewrite patterns into ConstIntBound itself.

Ah, I think I see where I may have miscommunicated. There isn't actually any recursive rewriting being performed. The pattern matching is used to extract information, and not to perform any rewrites or allocations. Stating that (A+B)*C < (A*B)*D could be rearranged into 1/A + 1/B < D/C was for the derivation of the tighter bounds, and isn't actually performed at runtime.

I absolutely agree that ConstIntBound should be a pure analysis pass, and should not perform any rewrites to the argument.

> Do you think it is possible to introduce another subanalyzer for this kind of proof, given that it is specific to LoRA, or to put it behind a stronger proof strength?

I think it could be moved to the TryProve* methods within the RewriteSimplifier, but hoisting it all the way out to a new subanalyzer probably wouldn't be feasible. It would need to simplify first, then check for the specific comparison being made, and then simplify again, which seems like a lot of overhead.

My worry with introducing arguments for the proof strength is that it adds code complexity, and requires callers to know an additional argument. For cases that have a significant computational overhead, that's worth it to avoid surprising a user. For cases that don't, and would follow the same API, I'd want to avoid having that requirement where possible.

Rather than enabling it by proof strength, what if we start with it behind an extension flag (link)? Whereas more computationally expensive simplifications would always require a user to opt in with a higher proof strength, individually enabled flags could allow new simplifications to be introduced as opt-in, and only later be enabled by default.

> In all cases, we should take a close look at the proofs and make sure they are correct, since this is at the center of what we are doing.

Absolutely agreed. For this case, I compared the range of output values for several np.arange test cases. Long-term, I wonder if it would be good to add fuzzing to the unit tests. As good as the unit tests are at the moment, I always like adding more steps that could catch an incorrect simplification.
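
A minimal sketch of what such a fuzzing check could look like is below. It is not the test from the PR: it uses the Python standard library in place of np.arange, the function names are hypothetical, and the ranges are kept tiny so that exhaustive enumeration stays cheap. For random positive ranges, it compares the reciprocal criterion evaluated at the extremes against a brute-force evaluation of the original inequality.

```python
import random
from fractions import Fraction
from itertools import product


def proves_less_than(A, B, C, D):
    """Check 1/A_min + 1/B_min < D_min/C_max using exact rationals.

    Each argument is an inclusive (lo, hi) range with lo >= 1.
    """
    a_min, b_min, c_max, d_min = A[0], B[0], C[1], D[0]
    return Fraction(1, a_min) + Fraction(1, b_min) < Fraction(d_min, c_max)


def holds_everywhere(A, B, C, D):
    """Brute force: does (a+b)*c < (a*b)*d hold for every value?"""
    ranges = [range(lo, hi + 1) for lo, hi in (A, B, C, D)]
    return all((a + b) * c < (a * b) * d for a, b, c, d in product(*ranges))


def fuzz(trials=200, seed=0, limit=6):
    """Compare the criterion against brute force on random ranges."""
    rng = random.Random(seed)

    def random_range():
        lo = rng.randint(1, limit)
        return (lo, rng.randint(lo, limit))

    for _ in range(trials):
        A, B, C, D = (random_range() for _ in range(4))
        # The extreme point lies inside the ranges, so the criterion
        # should agree with the brute-force answer in both directions.
        assert proves_less_than(A, B, C, D) == holds_everywhere(A, B, C, D)
    return trials


fuzz()
```

A fixed seed keeps any failure reproducible; a property-based testing library could replace the hand-rolled loop.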

> If in the end we end up keeping things in ConstIntBound, I think we should have a comment saying that this was really an exceptional case, and that future changes to ConstIntBound should only be made in exceptional cases.

That makes sense, and I agree that it would be good to comment on it.

Lunderberg, Feb 21 '24 04:02

This PR is now updated to perform the checks for (A+B)*C < (A*B)*D patterns in RewriteSimplifier, gated behind the Extension::kComparisonOfProductAndSum extension flag. This flag is currently disabled by default, with unit tests explicitly enabling the extension.

The updated behavior is sufficient to unblock https://github.com/apache/tvm/pull/16589.

Lunderberg, Feb 27 '24 02:02

Thank you @Lunderberg !

tqchen, Feb 27 '24 19:02