[MLIR][Math] add canonicalize-f32-promotion pass #92482
base: main
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. If you have received no comments on your PR for a week, you can request a review by adding a comment "Ping". If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-math

Author: Ivy Zhang (crazydemo)

Changes

The current `legalize-to-f32` pass does f32 promotion for every op belonging to the illegal op list. Once there are some consecutive illegal ops, `legalize-to-f32` will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal ops.

This pass is to eliminate the redundant truncf/extf pairs.

Full diff: https://github.com/llvm/llvm-project/pull/92482.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index e2c513047c77a..f150ff6f944d2 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -17,6 +17,7 @@ namespace math {
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
#define GEN_PASS_DECL_MATHLEGALIZETOF32
+#define GEN_PASS_DECL_MATHCANONICALIZEF32PROMOTION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index e870e714bfda5..5bf5eb45f921a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -36,4 +36,51 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}
+def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
+ let summary = "Eliminate redundant truncf/extf pairs";
+ let description = [{
+ `legalize-to-f32` pass does f32 promotion for every op belonging to the
+ illegal op list. Once there are some consecutive illegal ops, `legalize-to-f32`
+ will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
+ ops.
+
+ This pass is to eliminate the redundant truncf/extf pairs to improve
+ performance.
+
+ However, this pass may introduce numerical difference as the `f32->bf16` rounding
+ is eliminated.
+
+ Example:
+
+ ```mlir
+ // the initial func
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = math.absf %arg0 : vector<32xbf16>
+ %1 = math.sin %0 : vector<32xbf16>
+ return %1 : vector<32xbf16>
+ }
+ // after legalize-to-f32
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+ %1 = math.absf %0 : vector<32xf32>
+ %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
+ %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
+ %4 = math.sin %3 : vector<32xf32>
+ %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
+ return %5 : vector<32xbf16>
+ }
+ // after canonicalize-f32-promotion
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+ %1 = math.absf %0 : vector<32xf32>
+ %2 = math.sin %1 : vector<32xf32>
+ %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
+ return %3 : vector<32xbf16>
+ }
+ ```
+
+ }];
+ let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
+}
+
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 2a5b4fbcb5271..0d39d14925d23 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
LegalizeToF32.cpp
+ CanonicalizeF32Promotion.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
new file mode 100644
index 0000000000000..b9b43a0887f14
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -0,0 +1,72 @@
+//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements removing redundant extf/truncf pairs inserted from
+// LegalizeToF32.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+namespace {
+
+struct CanonicalizeF32PromotionRewritePattern final
+ : OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ if (auto innertruncop = op.getOperand().getDefiningOp<arith::TruncFOp>()) {
+ if (auto truncinput = innertruncop.getOperand()) {
+ auto outter_type = op.getType();
+ auto intermediate_type = innertruncop.getType();
+ auto inner_type = truncinput.getType();
+ if (outter_type.isa<ShapedType>()) {
+ outter_type = op.getType().cast<ShapedType>().getElementType();
+ intermediate_type =
+ innertruncop.getType().cast<ShapedType>().getElementType();
+ inner_type = truncinput.getType().cast<ShapedType>().getElementType();
+ }
+ if (outter_type.isF32() &&
+ (intermediate_type.isF16() || intermediate_type.isBF16()) &&
+ inner_type.isF32()) {
+ rewriter.replaceOp(op, {truncinput});
+ }
+ } else
+ return failure();
+ } else
+ return failure();
+ return success();
+ }
+};
+
+struct MathCanonicalizeF32Promotion final
+ : math::impl::MathCanonicalizeF32PromotionBase<
+ MathCanonicalizeF32Promotion> {
+ using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase;
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
+ FrozenRewritePatternSet patternSet(std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+ signalPassFailure();
+ }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
new file mode 100644
index 0000000000000..127eece98cf79
--- /dev/null
+++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s
+
+// CHECK-LABEL: @sequences
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences(%arg0: bf16) -> bf16 {
+ %0 = math.absf %arg0 : bf16
+ %1 = math.sin %0 : bf16
+ return %1 : bf16
+}
+
+// CHECK-LABEL: @eliminatecastoncastf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
+ %0 = arith.truncf %arg0 : f32 to f16
+ %1 = arith.extf %0 : f16 to f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @eliminatecastoncastbf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
+ %0 = arith.truncf %arg0 : f32 to bf16
+ %1 = arith.extf %0 : bf16 to f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = math.absf %arg0 : vector<32x32x32xbf16>
+ %1 = math.sin %0 : vector<32x32x32xbf16>
+ return %1 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+ %0 = math.absf %arg0 : vector<32x32x32xf16>
+ %1 = math.sin %0 : vector<32x32x32xf16>
+ return %1 : vector<32x32x32xf16>
+}
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
+// CHECK: [[ADDF:%.+]] = arith.addf
+// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = math.absf %arg0 : vector<32x32x32xbf16>
+ %1 = math.sin %0 : vector<32x32x32xbf16>
+ %2 = math.cos %0 : vector<32x32x32xbf16>
+ %3 = arith.addf %1, %2 : vector<32x32x32xbf16>
+ return %3 : vector<32x32x32xbf16>
+}
FYI @krzysz00
This pass is to eliminate the redundant truncf/extf pairs to improve
performance.

However, this pass may introduce numerical difference as the `f32->bf16` rounding
is eliminated.
Would these be valid patterns as canonicalization when fast-math is enabled?
This pass is independent of fast-math. If fast-math is enabled and we have such redundant truncf/extf pairs, they will still be removed from the IR.
I think you missed my point: I was asking about THE canonicalization patterns, not a separate pass.
I don't think they're valid under fastmath. See discussion at 3cf8535 for a particular case of this in a different context
Given the context of how that ultimately got resolved, what I'd want to see is updating the math legalizer to strip out trunc/ext pairs that it creates at time of creation.
I don't think they're valid under fastmath.
Can you elaborate? I don't follow.
Assuming you meant to send a link to #88486 (comment) instead, there you wrote:
This rewrite has caused per-element result errors around 1e-2 (if I remember right 16.25 vs 16.3125 or the like)
But numerical differences are actually in scope for fast-math...
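For concreteness, the per-element error mentioned above (16.25 vs 16.3125) can be reproduced outside MLIR. The sketch below is a plain-Python illustration, simulating a bf16 round trip by truncating an f32 bit pattern; the round-to-nearest-even behavior is an assumption about the rounding mode, not something stated in this thread:

```python
import struct

def f32_to_bf16_roundtrip(x: float) -> float:
    """Simulate arith.truncf f32->bf16 followed by arith.extf bf16->f32."""
    # Reinterpret the f32 value as its 32-bit pattern.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bf16 keeps the top 16 bits; apply round-to-nearest-even on the dropped half.
    lower, upper = bits & 0xFFFF, bits >> 16
    if lower > 0x8000 or (lower == 0x8000 and (upper & 1)):
        upper += 1
    # Extend back to f32: the bf16 bits become the high half, mantissa tail zero.
    return struct.unpack("<f", struct.pack("<I", upper << 16))[0]

print(f32_to_bf16_roundtrip(16.3125))  # 16.25 -- the rounding the pass would skip
print(f32_to_bf16_roundtrip(1.0))      # 1.0 -- exactly representable, unchanged
```

Eliminating the trunc/ext pair means the second op sees 16.3125 instead of 16.25, which is exactly the ~1e-2 class of difference under discussion.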
No, I did mean to send a link to the commit - the discussion moved there.
What I mean to say is that

%y = arith.truncf %arg0 : f32 to f16
%z = arith.extf %y : f16 to f32
return %z : f32

can't be simplified to return %arg0, even under fastmath, because that is an explicit, user-specified desire to lose precision.
However, rewrites (like arith-emulate-unsupported-floats, which probably should also get this intermediate-preservation treatment) which would introduce such a colliding truncf/extf pair are allowed to not do that and keep a higher intermediate precision. This is allowed without fastmath.
These are, to my current understanding, the C, and thus the LLVM, semantics.
If we didn't have this wider ecosystem that knows what exactly the fastmath flags mean, I'd buy that you could do this rewrite everywhere under contract... but I'm pretty sure that flag doesn't allow eliminating arbitrary trunc/ext pairs from input.
can't be simplified to return %arg0, even under fastmath, because that is an explicit, user-specified desire to lose precision.
Sorry, but I don't quite see why fast-math does not allow this in LLVM IR?
I don't really buy the "user-specified desire" argument, because these IR patterns appear in LLVM after quite a churn of optimization (possibly only through inlining, etc.).
If the user really wants something, they shouldn't decorate it with fast-math! (That is the point: fast-math expresses the opposite desire from what you're describing, actually.)
I think the main thing with LLVM IR is that fptrunc and fpext can't carry fastmath flags, last I checked.
Now, if they could, I'd absolutely permit
func.func @f(%arg0: f32) -> f32 {
%0 = arith.truncf contract %arg0 : f32 to f16
%1 = arith.extf contract %0 : f16 to f32
return %1 : f32
}
to fold to
func.func @f(%arg0: f32) -> f32 {
return %arg0 : f32
}
So, from an MLIR perspective, if we give extend and truncate the ability to carry fastmath flags (if they don't already have it), it'd make a lot of sense to
- Add that folding, and
- Change the various unsupported-float emulation/legalization passes to (optionally, or maybe by default) stick a contract on their truncates and extensions.
Could you make this an edit to
Thanks for your comment. I add a
So, I've got quibbles with this approach too. I've put in a suggestion for the code I think you want to use here.
@@ -109,4 +135,14 @@ void LegalizeToF32Pass::runOnOperation() {
   math::populateLegalizeToF32Patterns(patterns, typeConverter);
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     return signalPassFailure();
+
+  if (useCanonicalizeF32Promotion) {
I don't really like this approach.
How about going up to matchAndRewrite() and doing this:
SmallVector<Value> extendedWidthOperands(operands);
for (auto [extended, original] :
     llvm::zip_equal(extendedWidthOperands, op->getOperands())) {
  // Match a trunc/ext pair. The inelegant version is:
  if (auto trunc = extended.getDefiningOp<arith::TruncFOp>()) {
    auto maybeOriginal = trunc.getIn().getDefiningOp<arith::ExtFOp>();
    if (maybeOriginal && maybeOriginal.getIn() == original)
      extended = original;
  }
}
convertOpResultTypes(..., extendedWidthOperands, ...);
Now, you don't need a pass option, and all you're doing is "if this is the extension of the truncation of my original argument, use that original argument instead".
- The pass option is for users who are concerned about the numerical difference. With the option, they can easily switch the optimization on / off.
- I really appreciate your one-stage approach, which directly modifies matchAndRewrite() to determine whether to insert extf / truncf at the time the current op is hit. However, if users create extf / truncf pairs explicitly in the IR, then those pairs cannot be optimized in this way. I think the two-stage approach can handle such cases more easily.
I propose the one stage approach because it doesn't optimize explicit truncf / extf pairs
Explicitly rewriting away all truncf/extf pairs shouldn't be hiding in a type legalization. The legalization can, using the one stage approach, refrain from creating such pairs to improve numerical precision, but it should not eliminate existing ones.
Maybe we can apply the one-stage approach in the legalization pass, and create another pass for something like graph-simplification use. @ZhennanQin
@krzysz00 May I know what's the difference between the existing truncf / extf pairs and the auto-generated ones? Why can we only eliminate the truncf / extf generated from legalize-to-f32, but not those from any other passes? Would you please provide a use scenario?
BTW, I think we could have different conversion targets for
Re bf16 ... one change I've wanted to make is to (factoring some logic out of arith-emulate-unsupported-floats) let callers of this pass allowlist some short types. That is, I'd want to let people legalize all the short types except f16 ... or all the short types except f16 and bf16, and so on. Maybe that'd help your usecase? I'm opposed to trying to hardcode a list of f16-supported ops because that's target-dependent.
Currently, I am focusing on And as you are focusing on shorter types, I think we have no conflicts in implementing. I think hard-coding a list of f16-supported ops is reasonable after the target description is ready in MLIR. Here's an RFC about adding target description into MLIR.
For bf16 legalization in
It's entirely possible that we should combine those two passes into one and put it ... I haven't the faintest idea where. Maybe just move emulate-unsupported-floats into
There will be passes other than
So we have
I think the second option meets our need and has smaller changes.
Re your most recent comment, I'd want both. And as to the truncf/extf pairs, the reason I don't want f32 legalization to delete existing ones is that I have exactly one truncf/extf pair I want to keep around for its effect. The code conceptually goes like this:
func @test_kernel(%a: tensor<...xbf16>, %b: tensor<...xbf16>) -> tensor<...xbf16>
func @ref_kernel(%a: tensor<...xf32>, %b: tensor<...xf32>) -> tensor<...xf32>
func @main() {
%cTest = launch @test_kernel(...) : tensor<... x bf16>
%cRef = launch @ref_kernel(...) : tensor<... x f32>
%cTestF32 = arith.extf %cTest : tensor<... x bf16> to tensor<... x f32>
// This is the truncf/extf pair I want to preserve. Note that, though I write it on tensors, by the time any sort of
// math legalization is running, this is all scalar elementwise code.
// The truncate/extend replicates the fact that @test_kernel accumulates in f32 and rounds down to bf16
%cRefMatchingTestRounding = arith.truncf %cRef : tensor<...xf32> to tensor<...xbf16>
%cRefF32 = arith.extf %cRefMatchingTestRounding : tensor<...xbf16> to tensor<...xf32>
call @verify_nearness(%cTestF32, %cRefF32)
}
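The role of that load-bearing trunc/ext pair can be sketched outside MLIR. The following Python is a hypothetical harness (verify_nearness, the tolerance, and the bf16 simulation via f32 bit patterns are all stand-ins for whatever the real verification code does): the f32 reference only matches the bf16 test kernel after being rounded through bf16, which is exactly what the preserved truncf/extf pair expresses.

```python
import struct

def round_through_bf16(x: float) -> float:
    """Simulate truncf f32->bf16 (round-to-nearest-even) followed by extf back to f32."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lower, upper = bits & 0xFFFF, bits >> 16
    if lower > 0x8000 or (lower == 0x8000 and (upper & 1)):
        upper += 1
    return struct.unpack("<f", struct.pack("<I", upper << 16))[0]

def verify_nearness(test, ref, tol=1e-6):
    # Compare the test result against the reference *after* matching its rounding,
    # mirroring the truncf/extf pair applied to %cRef above.
    return all(abs(t - round_through_bf16(r)) <= tol for t, r in zip(test, ref))

ref = [16.3125]                                    # what @ref_kernel computes in f32
test = [round_through_bf16(r) for r in ref]        # what @test_kernel produces via bf16
print(verify_nearness(test, ref))                  # passes only because of the rounding step
```

If a legalization pass silently deleted the truncf/extf pair on the reference path, the comparison would be against the unrounded f32 value and would spuriously fail, which is the scenario being defended here.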
The current legalize-to-f32 pass does f32 promotion for every op belonging to the illegal op list. Once there are some consecutive illegal ops, legalize-to-f32 will insert redundant arith.truncf and arith.extf pairs between the illegal ops.

This pass is to eliminate the redundant truncf/extf pairs to improve performance. However, this pass may introduce numerical difference as the f32->bf16 rounding is eliminated.