[MLIR][Math] add canonicalize-f32-promotion pass #92482

Open · wants to merge 6 commits into main
Conversation


@crazydemo crazydemo commented May 17, 2024

The current legalize-to-f32 pass promotes every op on its illegal-op list to f32. When there are consecutive illegal ops, legalize-to-f32 inserts redundant arith.truncf and arith.extf pairs between them.

This pass eliminates the redundant truncf/extf pairs to improve performance. Note, however, that it may introduce numerical differences, since the intermediate f32->bf16 rounding is eliminated.
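To see why removing the pair is not value-preserving, here is a minimal Python sketch of the round trip the pass deletes (helper names are mine, not MLIR's; bf16 is modeled by bit manipulation, assuming round-to-nearest-even truncation and ignoring NaN handling):

```python
import struct

def f32_bits(x):
    """Bit pattern of the f32 nearest to the Python float x."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_f32(b):
    return struct.unpack("<f", struct.pack("<I", b))[0]

def truncf_f32_to_bf16(x):
    """Model arith.truncf f32 -> bf16 (round-to-nearest-even), as 16 raw bits."""
    b = f32_bits(x)
    lsb = (b >> 16) & 1
    return ((b + 0x7FFF + lsb) >> 16) & 0xFFFF

def extf_bf16_to_f32(h):
    """Model arith.extf bf16 -> f32; extension is exact (zero-pad low 16 bits)."""
    return bits_f32((h & 0xFFFF) << 16)

x = bits_f32(f32_bits(0.1))  # the f32 nearest to 0.1
roundtrip = extf_bf16_to_f32(truncf_f32_to_bf16(x))
# roundtrip differs slightly from x: the truncf/extf pair is not a no-op,
# so eliminating it changes results by up to one bf16 ulp per pair.
```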


@llvmbot
Collaborator

llvmbot commented May 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Ivy Zhang (crazydemo)

Changes

The current legalize-to-f32 pass promotes every op on its illegal-op list to f32. When there are consecutive illegal ops, legalize-to-f32 inserts redundant arith.truncf and arith.extf pairs between them.

This pass eliminates the redundant truncf/extf pairs.


Full diff: https://github.com/llvm/llvm-project/pull/92482.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.h (+1)
  • (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.td (+47)
  • (modified) mlir/lib/Dialect/Math/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp (+72)
  • (added) mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir (+74)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index e2c513047c77a..f150ff6f944d2 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -17,6 +17,7 @@ namespace math {
 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
 #define GEN_PASS_DECL_MATHUPLIFTTOFMA
 #define GEN_PASS_DECL_MATHLEGALIZETOF32
+#define GEN_PASS_DECL_MATHCANONICALIZEF32PROMOTION
 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index e870e714bfda5..5bf5eb45f921a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -36,4 +36,51 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
   let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
 }
 
+def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
+  let summary = "Eliminate redundant truncf/extf pairs";
+  let description = [{
+    The `legalize-to-f32` pass promotes every op on its illegal-op list to f32.
+    When there are consecutive illegal ops, `legalize-to-f32` inserts redundant
+    `arith.truncf` and `arith.extf` pairs between them.
+
+    This pass eliminates those redundant truncf/extf pairs to improve
+    performance.
+
+    However, this pass may introduce numerical differences, as the `f32->bf16`
+    rounding is eliminated.
+
+    Example:
+
+    ```mlir
+    // the initial func
+    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+      %0 = math.absf %arg0 : vector<32xbf16>
+      %1 = math.sin %0 : vector<32xbf16>
+      return %1 : vector<32xbf16>
+    }
+    // after legalize-to-f32
+    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+      %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+      %1 = math.absf %0 : vector<32xf32>
+      %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
+      %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
+      %4 = math.sin %3 : vector<32xf32>
+      %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
+      return %5 : vector<32xbf16>
+    }
+    // after canonicalize-f32-promotion
+    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+      %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+      %1 = math.absf %0 : vector<32xf32>
+      %2 = math.sin %1 : vector<32xf32>
+      %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
+      return %3 : vector<32xbf16>
+    }
+    ```
+
+  }];
+  let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
+}
+
 #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 2a5b4fbcb5271..0d39d14925d23 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMathTransforms
   AlgebraicSimplification.cpp
   ExpandPatterns.cpp
   LegalizeToF32.cpp
+  CanonicalizeF32Promotion.cpp
   PolynomialApproximation.cpp
   UpliftToFMA.cpp
 
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
new file mode 100644
index 0000000000000..b9b43a0887f14
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -0,0 +1,72 @@
+//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements removing redundant extf/truncf pairs inserted from
+// LegalizeToF32.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+namespace {
+
+struct CanonicalizeF32PromotionRewritePattern final
+    : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    if (auto innerTruncOp = op.getOperand().getDefiningOp<arith::TruncFOp>()) {
+      Value truncInput = innerTruncOp.getOperand();
+      // Compare element types so scalars and shaped types are handled alike.
+      Type outerType = op.getType();
+      Type intermediateType = innerTruncOp.getType();
+      Type innerType = truncInput.getType();
+      if (isa<ShapedType>(outerType)) {
+        outerType = cast<ShapedType>(outerType).getElementType();
+        intermediateType = cast<ShapedType>(intermediateType).getElementType();
+        innerType = cast<ShapedType>(innerType).getElementType();
+      }
+      // Only fold f32 -> (f16|bf16) -> f32 round trips.
+      if (outerType.isF32() &&
+          (intermediateType.isF16() || intermediateType.isBF16()) &&
+          innerType.isF32()) {
+        rewriter.replaceOp(op, truncInput);
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+};
+
+struct MathCanonicalizeF32Promotion final
+    : math::impl::MathCanonicalizeF32PromotionBase<
+          MathCanonicalizeF32Promotion> {
+  using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase;
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
+    FrozenRewritePatternSet patternSet(std::move(patterns));
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
new file mode 100644
index 0000000000000..127eece98cf79
--- /dev/null
+++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s
+
+// CHECK-LABEL: @sequences
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences(%arg0: bf16) -> bf16 {
+  %0 = math.absf %arg0 : bf16
+  %1 = math.sin %0 : bf16
+  return %1 : bf16
+}
+
+// CHECK-LABEL: @eliminatecastoncastf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 : f32 to f16
+  %1 = arith.extf %0 : f16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @eliminatecastoncastbf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 : f32 to bf16
+  %1 = arith.extf %0 : bf16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xbf16>
+  %1 = math.sin %0 : vector<32x32x32xbf16>
+  return %1 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xf16>
+  %1 = math.sin %0 : vector<32x32x32xf16>
+  return %1 : vector<32x32x32xf16>
+}
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
+// CHECK: [[ADDF:%.+]] = arith.addf
+// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xbf16>
+  %1 = math.sin %0 : vector<32x32x32xbf16>
+  %2 = math.cos %0 : vector<32x32x32xbf16>
+  %3 = arith.addf %1, %2 : vector<32x32x32xbf16>
+  return %3 : vector<32x32x32xbf16>
+}

@crazydemo
Author

FYI @krzysz00

This pass is to eliminate the redundant truncf/extf pairs to improve
performance.

However, this pass may introduce numerical difference as the `f32->bf16` rounding
Collaborator

Would these be valid patterns as canonicalization when fast-math is enabled?

Author

This pass is independent of fast-math. If fast-math is enabled and we have such redundant truncf/extf pairs, they will still be removed from the IR.

Collaborator

I think you missed my point: I was asking about THE canonicalization patterns, not a separate pass.

Contributor

I don't think they're valid under fastmath. See discussion at 3cf8535 for a particular case of this in a different context

Contributor

Given the context of how that ultimately got resolved, what I'd want to see is updating the math legalizer to strip out trunc/ext pairs that it creates at time of creation.

Collaborator

I don't think they're valid under fastmath.

Can you elaborate? I don't follow.
Assuming you meant to send a link to #88486 (comment) instead, there you wrote:

This rewrite has caused per-element result errors around 1e-2 (if I remember right 16.25 vs 16.3125 or the like)

But numerical differences are actually in scope for fast-math...

Contributor

@krzysz00 krzysz00 May 17, 2024

No, I did mean to send a link to the commit - the discussion moved there.

What I mean to say is that

%y = arith.truncf %arg0 : f32 to f16
%z = arith.extf %y : f16 to f32
return %z : f32

can't be simplified to return %arg0, even under fastmath, because that is an explicit, user-specified desire to lose precision.

However, rewrites (like arith-emulate-unsupported-floats, which probably should also get this intermediate preservation treatment) which would introduce such a colliding truncf/extf pair are allowed to not do that and keep a higher intermediate precision. This is allowed without fastmath.

These are, to my current understanding, the C, and thus the LLVM, semantics.

Contributor

If we didn't have this wider ecosystem that knows exactly what the fastmath flags mean, I'd buy that you could do this rewrite everywhere under contract ... but I'm pretty sure that flag doesn't allow eliminating arbitrary trunc/ext pairs from the input.

Collaborator

@joker-eph joker-eph May 17, 2024

can't be simplified to return %arg0, even under fastmath, because that is an explicit, user-specified desire to lose precision.

Sorry, but I don't quite see why fast-math would not allow this in LLVM IR.
I don't really buy the "user-specified desire" argument, because these IR patterns arise in LLVM after quite a lot of optimization churn (possibly only through inlining, etc.).
If users really want something kept, they shouldn't decorate it with fast-math flags (the whole point of fast-math is to express exactly the opposite desire from what you're describing).

Contributor

I think the main thing with LLVM IR is that fptrunc and fpext can't carry fastmath flags, last I checked.

Now, if they could, I'd absolutely permit

func.func @f(%arg0: f32) -> f32 {
    %0 = arith.truncf contract %arg0 : f32 to f16
    %1 = arith.extf contract %0 : f16 to f32
    return %1 : f32
}

to fold to

func.func @f(%arg0: f32) -> f32 {
    return %arg0 : f32
}

So, from an MLIR perspective, if we give extend and truncate the ability to carry fastmath flags (if they don't already have it), it'd make a lot of sense to

  1. Add that folding and
  2. Change the various unsupported float emulation/legalization passes to (optionally - or maybe by default) stick a contract on their truncates and extensions
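The flag-gated fold described above can be modeled with a toy sketch (hypothetical Python model, not MLIR API): the `extf(truncf(x)) -> x` fold fires only when both casts carry a `contract`-style fastmath flag:

```python
from dataclasses import dataclass
from typing import Union

@dataclass
class Cast:
    kind: str                # "truncf" or "extf"
    operand: "Value"
    contract: bool = False   # stand-in for an arith fastmath 'contract' flag

# Leaves are named SSA values like "%arg0".
Value = Union[str, "Cast"]

def try_fold_ext(ext):
    """Fold extf(truncf(x)) -> x, but only when both casts opt in."""
    if not isinstance(ext, Cast) or ext.kind != "extf":
        return ext
    trunc = ext.operand
    if (isinstance(trunc, Cast) and trunc.kind == "truncf"
            and ext.contract and trunc.contract):
        return trunc.operand  # skip the lossy round trip entirely
    return ext                # no flags: the rounding is user-requested, keep it

flagged = Cast("extf", Cast("truncf", "%arg0", contract=True), contract=True)
plain = Cast("extf", Cast("truncf", "%arg0"))
```

With the flags present the pair folds to the original value; without them the IR is left untouched, matching the position that an unflagged trunc/ext pair is an explicit request to lose precision.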

@krzysz00
Contributor

Could you make this an edit to legalize-to-f32 instead? I think reasonable people can agree (as seen with the X86 backend changes) that removing redundant implicit rounding is fine.

@crazydemo
Author

Could you make this an edit to legalize-to-f32 instead? I think reasonable people can agree (as seen with the X86 backend changes) that removing redundant implicit rounding is fine.

Thanks for your comment. I added a use-canonicalize-f32-promotion option to legalize-to-f32 instead.

Contributor

@krzysz00 krzysz00 left a comment

So, I've got quibbles with this approach too. I've put in a suggestion for the code I think you want to use here.

@@ -109,4 +135,14 @@ void LegalizeToF32Pass::runOnOperation() {
math::populateLegalizeToF32Patterns(patterns, typeConverter);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
return signalPassFailure();

if (useCanonicalizeF32Promotion) {
Contributor

I don't really like this approach.

How about going up to matchAndRewrite() and doing this:

SmallVector<Value> extendedWidthOperands(operands);
for (auto [extended, original] :
     llvm::zip_equal(extendedWidthOperands, op->getOperands())) {
  // Match a trunc/ext pair. The inelegant version is:
  if (auto truncOp = extended.getDefiningOp<arith::TruncFOp>()) {
    auto maybeOriginal = truncOp.getIn().getDefiningOp<arith::ExtFOp>();
    if (maybeOriginal && maybeOriginal.getIn() == original)
      extended = original;
  }
}
convertOpResultTypes(..., extendedWidthOperands, ...);

Now, you don't need a pass option, and all you're doing is "if this is the extension of the truncation of my original argument, use that original argument instead".

Author

  1. The pass option is for users who are concerned about the numerical differences. With the option, they can easily switch the optimization on / off.
  2. I really appreciate your one-stage approach, which directly modifies matchAndRewrite() to decide whether to insert extf / truncf at the time the current op is hit. However, if users create extf / truncf pairs explicitly in the IR, those pairs cannot be optimized this way. I think the two-stage approach handles such cases more easily.

Contributor

I proposed the one-stage approach precisely because it doesn't optimize explicit truncf / extf pairs.

Explicitly rewriting away all truncf/extf pairs shouldn't hide inside a type legalization. Using the one-stage approach, the legalization can refrain from creating such pairs to improve numerical precision, but it should not eliminate existing ones.
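The one-stage idea can be illustrated with a toy model (hypothetical Python, where a pipeline of legalized ops is just a list of names): instead of emitting a truncf/extf round trip between consecutive promoted ops and deleting it later, the legalizer simply never emits it:

```python
def naive_legalize(ops):
    # Promote each illegal op independently: ext -> op -> trunc.
    # Consecutive ops end up separated by a truncf/extf pair.
    out = []
    for op in ops:
        out += ["extf", op, "truncf"]
    return out

def one_stage_legalize(ops):
    # Same promotion, but reuse the previous op's f32 result instead of
    # materializing a truncf/extf round trip between consecutive ops.
    out = []
    for op in ops:
        if out and out[-1] == "truncf":
            out.pop()              # the pending truncf is not needed yet...
            out += [op, "truncf"]  # ...feed the f32 value straight through
        else:
            out += ["extf", op, "truncf"]
    return out
```

The one-stage version never creates the pair in the first place, so pre-existing, user-written truncf/extf pairs in the input are untouched.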

Author

Maybe we can apply the one-stage approach in the legalization pass, and create another pass for graph-simplification-style use. @ZhennanQin


@krzysz00 May I ask what the difference is between pre-existing truncf / extf ops and auto-generated ones? Why can we only eliminate the truncf / extf pairs generated by legalize-to-f32, but not those from other passes? Could you please provide a usage scenario?

@crazydemo
Author

BTW, I think we could have different conversion targets for bf16 and f16. Ops like arith::AddFOp, arith::SubFOp, arith::MulFOp, arith::DivFOp, math::AbsFOp, math::CeilOp, math::FloorOp, math::RoundOp, math::SqrtOp, math::RsqrtOp, and math::Exp2Op should be legalized to f32 when their original data type is bf16, but need no legalization when their original data type is f16 and they are compiled on a machine with f16 instruction support. What's your opinion on this? Maybe I can open another PR to discuss this issue. @krzysz00

@krzysz00
Contributor

Re bf16 ... one change I've wanted to make is to let callers of this pass allowlist some short types (factoring some logic out of arith-emulate-unsupported-floats). That is, I'd want to let people legalize all the short types except f16, or all the short types except f16 and bf16, and so on.

Maybe that'd help your usecase?

I'm opposed to trying to hardcode a list of f16-supported ops because that's target-dependent.

@crazydemo
Author

Re bf16 ... one change I've wanted to make is to let callers of this pass allowlist some short types (factoring some logic out of arith-emulate-unsupported-floats). That is, I'd want to let people legalize all the short types except f16, or all the short types except f16 and bf16, and so on.

Maybe that'd help your usecase?

I'm opposed to trying to hardcode a list of f16-supported ops because that's target-dependent.

Currently I am focusing on bf16 legalization, and may extend to f16 legalization in the future. With the current legalize-to-f32 pass, only math ops can be legalized to f32; arith ops like arith::AddFOp will not be legalized, even though they need to be promoted for computation. I think we can at least expand the current op list to support my use case.

And as you are focusing on shorter types, I think our implementations have no conflicts. I think hardcoding a list of f16-supported ops will be reasonable once target descriptions are available in MLIR. Here's an RFC about adding target descriptions to MLIR.

@ZhennanQin

@krzysz00
Contributor

For bf16 legalization in arith, you want the companion of this pass, emulate-unsupported-floats

@krzysz00
Contributor

It's entirely possible that we should combine those two passes into one and put it ... I haven't the faintest idea where.

Maybe just move emulate-unsupported-floats into math and have it support both Arith and Math ... and then do the one stage change to avoid extra extf and truncf pairs

@ZhennanQin

ZhennanQin commented May 23, 2024

It's entirely possible that we should combine those two passes into one and put it ... I haven't the faintest idea where.

Maybe just move emulate-unsupported-floats into math and have it support both Arith and Math ... and then do the one stage change to avoid extra extf and truncf pairs

There will be passes other than legalize-to-f32 and emulate-unsupported-floats that can insert extf and truncf ops. Take matmul_bf16 + bias_add_bf16 as an example: because we want to use f32 as the accumulator data type, the matmul lowering pass replaces matmul_bf16 with an scf.for loop using x86vector.avx.intr.dot + truncf. Then the legalize-to-f32 pass replaces bias_add_bf16 with extf + bias_add_fp32 + truncf, and we want to eliminate the extf / truncf pair between x86vector.avx.intr.dot and bias_add_fp32.

@crazydemo
Author

It's entirely possible that we should combine those two passes into one and put it ... I haven't the faintest idea where.

Maybe just move emulate-unsupported-floats into math and have it support both Arith and Math ... and then do the one stage change to avoid extra extf and truncf pairs

So we have legalize-to-f32 and emulate-unsupported-floats to insert extf/truncf pairs for math and arith ops respectively. All we need is removal of the redundant extf/truncf pairs. We have two choices:

  1. Combine legalize-to-f32 and emulate-unsupported-floats into one pass, and also do the removal in that pass.
  2. Create an independent pass that only does the removal. This pass could live in the Transforms folder.

I think the second option meets our need and requires smaller changes.

@krzysz00
Contributor

Re your most recent comment, I'd want both:

And as to the truncf/extf pairs, the reason I don't want f32 legalization to delete existing ones is that I have exactly one truncf/extf pair I want to keep around for its effect. The code conceptually goes like this:

func @test_kernel(%a: tensor<...xbf16>, %b: tensor<...xbf16>) -> tensor<...xbf16>
func @ref_kernel(%a: tensor<...xf32>, %b: tensor<...xf32>) -> tensor<...xf32>

func @main() {
  %cTest = launch @test_kernel(...) : tensor<... x bf16>
  %cRef = launch @ref_kernel(...) : tensor<... x f32>
  %cTestF32 = arith.extf %cTest : tensor<... x bf16> to tensor<... x f32>
  // This is the truncf/extf pair I want to preserve. Note that, though I write it on tensors, by the time any sort of
  // math legalization is running, this is all scalar elementwise code.
  // The truncate/extend replicates the fact that @test_kernel accumulates in f32 and rounds down to bf16
  %cRefMatchingTestRounding = arith.truncf %cRef : tensor<...xf32> to tensor<...xbf16>
  %cRefF32 = arith.extf %cRefMatchingTestRounding : tensor<...xbf16> to tensor<...xf32>
   call @verify_nearness(%cTestF32, %cRefF32)
}
