llvm-project
llvm-project copied to clipboard
[PatternMatching] Add generic API for matching constants using custom conditions
- [PatternMatching] Add generic API for matching constants using custom conditions
- [InstCombine] Add example usage for new
Checked matcher API
The new API is:
m_CheckedInt(Lambda)/m_CheckedFp(Lambda)
- Matches non-undef constants such that Lambda(ele) is true for all
elements.
m_CheckedIntAllowUndef(Lambda)/m_CheckedFpAllowUndef(Lambda)
- Matches constants (undef vector elements allowed) such that Lambda(ele) is true for all
non-undef elements.
The goal with these is to be able to replace the common usage of:
match(X, m_APInt(C)) && CustomCheck(C)
with
match(X, m_CheckedInt(C, CustomChecks))
The rationale is that we often ignore non-splat vectors because there are no good APIs to handle them, and it's not worth increasing code complexity for such cases.
The hope is that the API provides a common way of handling scalars, splat vectors, and non-splat vectors, essentially making this a non-issue.
@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-ir
Author: None (goldsteinn)
Changes
- [PatternMatching] Add generic API for matching constants using custom conditions
- [InstCombine] Add example usage for new
Checked matcher API
Patch is 26.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85676.diff
4 Files Affected:
- (modified) llvm/include/llvm/IR/PatternMatch.h (+80-11)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+30-36)
- (modified) llvm/test/Transforms/InstCombine/signed-truncation-check.ll (+6-24)
- (modified) llvm/unittests/IR/PatternMatch.cpp (+240)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 382009d9df785d..4333d3e6e8da2a 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -346,7 +346,7 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
/// This helper class is used to match constant scalars, vector splats,
/// and fixed width vectors that satisfy a specified predicate.
/// For fixed width vector constants, undefined elements are ignored.
-template <typename Predicate, typename ConstantVal>
+template <typename Predicate, typename ConstantVal, bool AllowUndefs>
struct cstval_pred_ty : public Predicate {
template <typename ITy> bool match(ITy *V) {
if (const auto *CV = dyn_cast<ConstantVal>(V))
@@ -369,8 +369,11 @@ struct cstval_pred_ty : public Predicate {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
return false;
- if (isa<UndefValue>(Elt))
+ if (isa<UndefValue>(Elt)) {
+ if (!AllowUndefs)
+ return false;
continue;
+ }
auto *CV = dyn_cast<ConstantVal>(Elt);
if (!CV || !this->isValue(CV->getValue()))
return false;
@@ -384,16 +387,17 @@ struct cstval_pred_ty : public Predicate {
};
/// specialization of cstval_pred_ty for ConstantInt
-template <typename Predicate>
-using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
+template <typename Predicate, bool AllowUndefs = true>
+using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowUndefs>;
/// specialization of cstval_pred_ty for ConstantFP
-template <typename Predicate>
-using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
+template <typename Predicate, bool AllowUndefs = true>
+using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP, AllowUndefs>;
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APInt.
-template <typename Predicate> struct api_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct api_pred_ty : public Predicate {
const APInt *&Res;
api_pred_ty(const APInt *&R) : Res(R) {}
@@ -406,7 +410,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
}
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
- if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
+ if (auto *CI =
+ dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs)))
if (this->isValue(CI->getValue())) {
Res = &CI->getValue();
return true;
@@ -419,7 +424,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APFloat.
/// Undefs are allowed in splat vector constants.
-template <typename Predicate> struct apf_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct apf_pred_ty : public Predicate {
const APFloat *&Res;
apf_pred_ty(const APFloat *&R) : Res(R) {}
@@ -432,8 +438,8 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
}
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
- if (auto *CI = dyn_cast_or_null<ConstantFP>(
- C->getSplatValue(/* AllowUndef */ true)))
+ if (auto *CI =
+ dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndefs)))
if (this->isValue(CI->getValue())) {
Res = &CI->getValue();
return true;
@@ -452,6 +458,69 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
//
///////////////////////////////////////////////////////////////////////////////
+template <typename APTy> struct custom_checkfn {
+ function_ref<bool(const APTy &)> CheckFn;
+ bool isValue(const APTy &C) { return CheckFn(C); }
+};
+
+// Match and integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cst_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
+ return cst_pred_ty<custom_checkfn<APInt>, false>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
+ api_pred_ty<custom_checkfn<APInt>, false> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match and integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cst_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(function_ref<bool(const APInt &)> CheckFn) {
+ return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(const APInt *&V,
+ function_ref<bool(const APInt &)> CheckFn) {
+ api_pred_ty<custom_checkfn<APInt>> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match and float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
+ return cstfp_pred_ty<custom_checkfn<APFloat>, false>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
+ apf_pred_ty<custom_checkfn<APFloat>, false> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match and float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(function_ref<bool(const APFloat &)> CheckFn) {
+ return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(const APFloat *&V,
+ function_ref<bool(const APFloat &)> CheckFn) {
+ apf_pred_ty<custom_checkfn<APFloat>> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
struct is_any_apint {
bool isValue(const APInt &C) { return true; }
};
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0dce0077bf1588..711294e4635579 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6347,57 +6347,51 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
case ICmpInst::ICMP_ULT: {
if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A <u C -> A == C-1 if min(A)+1 == C
- if (*CmpC == Op0Min + 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- // X <u C --> X == 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- Constant::getNullValue(Op1->getType()));
- }
+ // A <u C -> A == C-1 if min(A)+1 == C
+ if (match(Op1, m_SpecificInt(Op0Min + 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Min));
+ // X <u C --> X == 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+ return Op0Known.countMinTrailingZeros() >= C.ceilLogBase2();
+ })))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ Constant::getNullValue(Op1->getType()));
break;
}
case ICmpInst::ICMP_UGT: {
if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A >u C -> A == C+1 if max(a)-1 == C
- if (*CmpC == Op0Max - 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- // X >u C --> X != 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
- return new ICmpInst(ICmpInst::ICMP_NE, Op0,
- Constant::getNullValue(Op1->getType()));
- }
+ // A >u C -> A == C+1 if max(a)-1 == C
+ if (match(Op1, m_SpecificInt(Op0Max - 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Max));
+ // X >u C --> X != 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+ return Op0Known.countMinTrailingZeros() >= C.getActiveBits();
+ })))
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0,
+ Constant::getNullValue(Op1->getType()));
break;
}
case ICmpInst::ICMP_SLT: {
if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- }
+ // A <s C -> A == C-1 if min(A)+1 == C
+ if (match(Op1, m_SpecificInt(Op0Min + 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Min));
break;
}
case ICmpInst::ICMP_SGT: {
if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- }
+ // A >s C -> A == C+1 if max(A)-1 == C
+ if (match(Op1, m_SpecificInt(Op0Max - 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Max));
break;
}
}
diff --git a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
index 208e166b2c8760..465235fb08d383 100644
--- a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
+++ b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
@@ -212,10 +212,7 @@ define <3 x i1> @positive_vec_undef0(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef1(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -227,10 +224,7 @@ define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef2(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -242,10 +236,7 @@ define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef3(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -257,10 +248,7 @@ define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef4(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -272,10 +260,7 @@ define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef5(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -287,10 +272,7 @@ define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef6(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef6(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 533a30bfba45dd..de361c70804c3e 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -572,6 +572,169 @@ TEST_F(PatternMatchTest, BitCast) {
EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
}
+TEST_F(PatternMatchTest, CheckedInt) {
+ Type *I8Ty = IRB.getInt8Ty();
+ const APInt *Res = nullptr;
+
+ auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
+ auto CheckTrue = [](const APInt &) { return true; };
+ auto CheckFalse = [](const APInt &) { return false; };
+ auto CheckNonZero = [](const APInt &C) { return !C.isZero(); };
+ auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); };
+
+ auto DoScalarCheck = [&](int8_t Val) {
+ APInt APVal(8, Val);
+ Constant *C = ConstantInt::get(I8Ty, Val);
+
+ Res = nullptr;
+ EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
+ EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
+ EXPECT_EQ(*Res, APVal);
+
+ Res = nullptr;
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
+
+ Res = nullptr;
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
+ if (CheckUgt1(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedIntAllowUndef(CheckUgt1).match(C));
+ EXPECT_EQ(CheckUgt1(APVal),
+ m_CheckedIntAllowUndef(Res, CheckUgt1).match(C));
+ if (CheckUgt1(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
+ EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
+ if (CheckNonZero(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckNonZero(APVal),
+ m_CheckedIntAllowUndef(CheckNonZero).match(C));
+ EXPECT_EQ(CheckNonZero(APVal),
+ m_CheckedIntAllowUndef(Res, CheckNonZero).match(C));
+ if (CheckNonZero(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
+ if (CheckPow2(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedIntAllowUndef(CheckPow2).match(C));
+ EXPECT_EQ(CheckPow2(APVal),
+ m_CheckedIntAllowUndef(Res, CheckPow2).match(C));
+ if (CheckPow2(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+ };
+
+ DoScalarCheck(0);
+ DoScalarCheck(1);
+ DoScalarCheck(2);
+ DoScalarCheck(3);
+
+ EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
+ function_ref<bool(const APInt &)> CheckFn,
+ bool UndefAsPoison) {
+ SmallVector<Constant *> VecElems;
+ std::optional<bool> Okay;
+ bool AllSame = true;
+ bool HasUndef = false;
+ std::optional<APInt> First;
+ for (const std::optional<int8_t> &Val : Vals) {
+ if (!Val.has_value()) {
+ VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty)
+ : UndefValue::get(I8Ty));
+ HasUndef = true;
+ } else {
+ if (!Okay.has_value())
+ Okay = true;
+ APInt APVal(8, *Val);
+ if (!First.has_value())
+ First = APVal;
+ else
+ AllSame &= First->eq(APVal);
+ Okay = *Okay && CheckFn(APVal);
+ VecElems.push_back(ConstantInt::get(I8Ty, *Val));
+ }
+ }
+
+ Constant *C = ConstantVector::get(VecElems);
+ EXPECT_EQ(!HasUndef && Okay.value_or(false),
+ m_CheckedInt(CheckFn).match(C));
+ EXPECT_EQ(Okay.value_or(false), m_CheckedIntAllowUndef(CheckFn).match(C));
+
+ Res = nullptr;
+ bool Expec = !HasUndef && AllSame && Okay.value_or(false);
+ EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
+ if (Expec) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, *First);
+ }
+
+ Res = nullptr;
+ Expec = AllSame && Okay.value_or(f...
[truncated]
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
You can test this locally with the following command:
git-clang-format --diff a12622543de15df45fb9ad64e8ab723289d55169 436c5a2925171e6384e508893c96347c401b94ed -- llvm/include/llvm/IR/PatternMatch.h llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp llvm/unittests/IR/PatternMatch.cpp
View the diff from clang-format here.
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index d5a4a6a056..a436d19c0d 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -657,7 +657,6 @@ TEST_F(PatternMatchTest, CheckedInt) {
EXPECT_NE(Res, nullptr);
EXPECT_EQ(*Res, APVal);
}
-
};
DoScalarCheck(0);
ping
Please check the compile-time impact of this change.
Please check the compile-time impact of this change.
https://llvm-compile-time-tracker.com/compare.php?from=817f453aa576286aaca0a6b0244e6ab08516b80c&to=05871ee629407932bc6576224dcc2ae99db473fa&stat=instructions:u
I think there is basically no impact: maybe a slight regression at stage2-O0 and a slight improvement at stage2-O3, but both look to be within the error range and are a bit scattered.
ping
ping
I'd like to defer this until after https://github.com/llvm/llvm-project/pull/88217, which would address my primary concern with this (the ever present undef footgun).
I'd like to defer this until after #88217, which would address my primary concern with this (the ever present undef footgun).
Actually, is there any issue with getting just the default APIs (those not supporting undef) in now? I don't imagine they will get in your way. If you think they will be a problem, though, this can of course be deferred.
rebased