
missed fold, x fmul C1 > C2 => x > C2

Open zhengyang92 opened this issue 11 months ago • 6 comments

https://alive2.llvm.org/ce/z/UmK7vQ https://godbolt.org/z/a6Gqfvxv8

define i1 @src(float %0) {
  %t1 = fmul float %0, 0x3FF0CCCCC0000000
  %t2 = fcmp olt float %t1, 0x3FE20418A0000000
  ret i1 %t2
}

define i1 @tgt(float %0) {
  %ole = fcmp ole float %0, 0x3FE12878E0000000
  ret i1 %ole
}
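
The pair above can also be spot-checked around the threshold without Alive2. This is a rough sketch in Python, assuming round-to-nearest-even single precision emulated through `struct`. The helper names (`f32`, `from_llvm_hex`, `step32`, `src`, `tgt`) are made up for this illustration, and it only samples a few neighbouring floats, so it is not a proof:

```python
import struct

def f32(x):
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def from_llvm_hex(h):
    """Decode an LLVM float literal, which is printed as a 64-bit double bit pattern."""
    return struct.unpack(">d", bytes.fromhex(h))[0]

def step32(x, k):
    """Move k float32 steps away from x (assumes x is a positive, finite float32)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits + k))[0]

C1 = from_llvm_hex("3FF0CCCCC0000000")  # multiplier in @src
C2 = from_llvm_hex("3FE20418A0000000")  # bound in @src
C3 = from_llvm_hex("3FE12878E0000000")  # bound in @tgt

def src(x):
    # fmul then fcmp olt; the double product of two float32 values is exact,
    # so f32() performs the single rounding that fmul would.
    return f32(x * C1) < C2

def tgt(x):
    # fcmp ole against the folded constant
    return x <= C3

# Compare both functions on the floats immediately around C3.
neighbours = [step32(C3, k) for k in range(-4, 5)]
agree = all(src(x) == tgt(x) for x in neighbours)
print(agree)
```

The interesting inputs are `C3` itself and the next float above it, since that is where the comparison result flips.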

generalization may be a little tricky for this one.

zhengyang92 • Mar 14 '24 17:03

@jayfoad @regehr

zhengyang92 • Mar 14 '24 17:03

generalization may be a little tricky for this one.

What is the generalized pattern?

dtcxzyw • Mar 14 '24 17:03

Mathematically it would be X * C1 > C2 --> X > C3 where C3 = C2 / C1 (I'm assuming that C1 is positive for simplicity but it can be tweaked to work for negative C1 too). Two problems:

  • The division might overflow or underflow.
  • Rounding might mean that the equivalence is not exact. I am not a numerical analyst but I would suggest detecting that by plugging values of X back into the original expression X * C1 > C2. With X = C3 it should return false and with X = nextafter(C3, +inf) it should return true.
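
The plug-back test described above could look roughly like this. It is a sketch only, assuming float32 with round-to-nearest-even, positive finite constants, and a hypothetical helper name `plug_back_ok`; the overflow handling in `f32` is an approximation of what hardware would do right at the edge of the range:

```python
import math
import struct

def f32(x):
    """Round a double to float32; treat out-of-range values as infinity (approximation)."""
    try:
        return struct.unpack("<f", struct.pack("<f", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def next_up32(x):
    """Next float32 above x (assumes x is positive and finite)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits + 1))[0]

SMALLEST_NORMAL = 2.0 ** -126

def plug_back_ok(c1, c2):
    """Check whether X * C1 > C2 can become X > C3 with C3 = C2 / C1.

    Returns None when the division overflows or leaves the normal range,
    otherwise True/False for the two plug-back probes.
    """
    c3 = f32(c2 / c1)
    if math.isinf(c3) or c3 < SMALLEST_NORMAL:
        return None  # first problem: overflow or underflow of the division
    at_c3 = f32(c3 * c1) > c2                 # should be False
    above_c3 = f32(next_up32(c3) * c1) > c2   # should be True
    return (not at_c3) and above_c3

print(plug_back_ok(2.0, 8.0))
```

A `False` result here only says that the plain quotient is not the right threshold; it does not say that no suitable C3 exists.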

jayfoad • Mar 14 '24 18:03

And this optimization should respect #pragma STDC FENV_ROUND, by the way. (The optimization should be disabled for #pragma STDC FENV_ROUND FE_DYNAMIC.)

Explorer09 • Mar 15 '24 08:03

And this optimization should respect #pragma STDC FENV_ROUND, by the way. (The optimization should be disabled for #pragma STDC FENV_ROUND FE_DYNAMIC.)

Some cases are probably safe regardless of rounding mode, like X * 2 > 8 --> X > 4. Perhaps all cases where the result of the division C2 / C1 is exact?
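
Whether C2 / C1 is exact can be tested with rational arithmetic. A small sketch, assuming float32 constants and a quotient that stays finite; `f32` and `quotient_is_exact` are names invented here. It only detects exactness and does not show that exact quotients are safe under every rounding mode:

```python
import struct
from fractions import Fraction

def f32(x):
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def quotient_is_exact(c1, c2):
    """True if the float32 quotient C2 / C1 equals the mathematical quotient."""
    q = f32(c2 / c1)
    # Fraction(float) is exact, so this compares q * c1 with c2 without rounding.
    return Fraction(q) * Fraction(c1) == Fraction(c2)

print(quotient_is_exact(2.0, 8.0))  # the X * 2 > 8 case from the comment
print(quotient_is_exact(3.0, 1.0))  # 1/3 has no finite binary expansion
```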

jayfoad • Mar 15 '24 09:03

By the way, there's a typo in the title. It should be "x fmul C1 > C2 => x > C3".

Explorer09 • Mar 24 '24 14:03

Rounding might mean that the equivalence is not exact. I am not a numerical analyst but I would suggest detecting that by plugging values of X back into the original expression X * C1 > C2. With X = C3 it should return false and with X = nextafter(C3, +inf) it should return true.

I thought a single check like this would be enough to make it hold (maybe with some weird rounding), but knowing floats I threw it into Z3, and it does not work in the general case :O (for a few reasons). It does seem to be pretty close to a working solution though :> (despite my having tried some far too complicated things in the process :,D)

Simplifying the problem first to finding all solutions x of x * C1 = C2, we can at least prove that for any given (non-subnormal!!) float there are at most 2 solutions.
You can show this by showing there's no x such that x * y = (x + 2) * y.

For a format with n bits in the mantissa (including the leading 1 bit), let s = 2^n and treat x and y as integer mantissas.
We can represent floating-point multiplication without the exponent (due to the lack of subnormals) as
mul(x, y) = round(x*y / 2^(floor(log2(x*y)) - n + 1))
Due to the restriction on the mantissa, we know s/2 <= x, y < s.
Getting the divisor just requires finding log2(x) + log2(y) first:
n - 1 <= log2(x), log2(y) < n, so 2n - 2 <= log2(x) + log2(y) < 2n.
So floor(log2(x) + log2(y)) is either 2n - 2 or 2n - 1.
If it's 2n - 2, then the divisor = 2^(n-1) = s/2.
If it's 2n - 1, then the divisor = 2^n = s.
For any x*y, we know s^2/4 <= x*y < s^2.
We can then find s^2/4 + s <= (x+2)*y, and more precisely (x+2)*y = x*y + 2*y >= x*y + s, because y >= s/2.
This offset of s is enough to increment the rounded mantissa by at least 1 with the same rounding, even if the divisor was equal to s.
Because the divisors are guaranteed to be <= s, this means that x and x+2 never give the same product, so there's no normal (x, y) which can have more than 2 solutions!
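
The two-solution bound can be brute-forced on toy formats as a sanity check. The sketch below assumes round-to-nearest-even and integer mantissas in [s/2, s) with n bits. It compares the full rounded result (exponent and mantissa, renormalized when rounding carries out), which is a slightly stricter reading than the exponent-free `mul` above. The function names are hypothetical:

```python
from collections import Counter

def rounded_product(x, y, n):
    """Round the product of two n-bit mantissas back to n bits (ties to even).

    Returns (exponent, mantissa) so that results in different binades stay distinct.
    """
    p = x * y
    e = p.bit_length() - 1        # floor(log2(x*y))
    shift = e - n + 1             # log2 of the divisor
    q, r = divmod(p, 1 << shift)
    half = 1 << (shift - 1)
    if r > half or (r == half and q & 1):
        q += 1                    # round to nearest, ties to even
    if q == 1 << n:               # rounding carried into the next binade
        q >>= 1
        e += 1
    return e, q

def max_solutions(n):
    """Largest number of mantissas x that give the same rounded x*y for some y."""
    s = 1 << n
    worst = 0
    for y in range(s // 2, s):
        counts = Counter(rounded_product(x, y, n) for x in range(s // 2, s))
        worst = max(worst, max(counts.values()))
    return worst

for n in range(3, 9):
    print(n, max_solutions(n))
```

Small mantissa widths are cheap to enumerate; this says nothing about subnormals, which the argument above also excludes.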

This is all to say, you actually just need to check next(C2 / C1) and prev(C2 / C1) for normal numbers ^^ (I checked this with Z3 as well, and given how long it took (before crashing from OOM) it probably works, especially since the rounding ensures the quotient is never 2 away from the exact result, but it feels like that's not a very rigorous proof ^^;) Here it is in pseudocode, for clarity on what I'm saying, because I feel like I always leave things ambiguous by accident:

// x * C1 > C2 -> x > C3
C3 = C2 / C1
if (C3 * C1 > C2)
  C3 = prev(C3)
else if (!(next(C3) * C1 > C2))
  C3 = next(C3)
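
A runnable take on the pseudocode, for experimenting. This is a sketch under assumptions that are not in the thread: float32 with round-to-nearest-even, positive normal C1 and C2, and no overflow. `f32`, `step32` and `threshold_gt` are illustrative names:

```python
import struct

def f32(x):
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def step32(x, k):
    """Move k float32 steps away from x (assumes x is a positive, finite float32)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits + k))[0]

def threshold_gt(c1, c2):
    """Pick C3 for the rewrite x * C1 > C2 -> x > C3, following the pseudocode."""
    def original(x):
        # The unfolded comparison: one rounded multiply, then the compare.
        return f32(x * c1) > c2

    c3 = f32(c2 / c1)
    if original(c3):
        c3 = step32(c3, -1)   # prev(C3): quotient was rounded up too far
    elif not original(step32(c3, 1)):
        c3 = step32(c3, 1)    # next(C3): quotient was rounded down too far
    return c3

print(threshold_gt(2.0, 8.0))
```

For the constants in the original report this yields the float just above the bound in `@tgt`, which matches the `ole` form there since `x > next(C)` and `x >= next(C)`-style bounds differ by one step.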

With subnormals you can probably do some more trickery for this, like finding the smallest/largest value with the smallest/largest exponent such that it does/doesn't round, but I think I'll leave this here for now at least... I hope this helps at least a little even if just for reassurance ^^

Geotale • Apr 28 '24 19:04