diff options
author | Hanhan Wang <hanchung@google.com> | 2021-04-20 07:34:32 -0700 |
---|---|---|
committer | Hanhan Wang <hanchung@google.com> | 2021-04-20 07:35:20 -0700 |
commit | 7b7df8e85eec445389e4b07915f16aa18332719d (patch) | |
tree | 7f28e4dbeae6d44b94fa436ac3e2221622728617 | |
parent | [gn build] reformat all gn files (diff) | |
download | llvm-project-main.tar.gz llvm-project-main.tar.bz2 llvm-project-main.zip |
[mlir][StandardToSPIRV] Add support for lowering std.xor on bool to SPIR-Vmain
std.xor ops on bool are lowered to spv.LogicalNotEqual. For Boolean values, xor
and not-equal are the same thing.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D100817
-rw-r--r-- | mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp | 29 | ||||
-rw-r--r-- | mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir | 4 |
2 files changed, 32 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp index 0196a21f4a69..2a6e7f281860 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -663,6 +663,17 @@ public: ConversionPatternRewriter &rewriter) const override; }; +/// Converts std.xor to SPIR-V operations if the type of source is i1 or vector +/// of i1. +class BoolXOrOpPattern final : public OpConversionPattern<XOrOp> { +public: + using OpConversionPattern<XOrOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -1250,6 +1261,22 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands, return success(); } +LogicalResult +BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + assert(operands.size() == 2); + + if (!isBoolScalarOrVector(operands.front().getType())) + return failure(); + + auto dstType = getTypeConverter()->convertType(xorOp.getType()); + if (!dstType) + return failure(); + rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(xorOp, dstType, + operands); + return success(); +} + //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// @@ -1293,7 +1320,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter, UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>, UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>, UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>, - SignedRemIOpPattern, XOrOpPattern, + SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern, // Comparison patterns BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern, diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir index 0148a0731dc9..fe769482c787 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -224,6 +224,8 @@ func @logical_scalar(%arg0 : i1, %arg1 : i1) { %0 = and %arg0, %arg1 : i1 // CHECK: spv.LogicalOr %1 = or %arg0, %arg1 : i1 + // CHECK: spv.LogicalNotEqual + %2 = xor %arg0, %arg1 : i1 return } @@ -233,6 +235,8 @@ func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { %0 = and %arg0, %arg1 : vector<4xi1> // CHECK: spv.LogicalOr %1 = or %arg0, %arg1 : vector<4xi1> + // CHECK: spv.LogicalNotEqual + %2 = xor %arg0, %arg1 : vector<4xi1> return } |