/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements the lowering of trigonometric standard ops to
// approximations.

#include <array>
#include <utility>

#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

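// Base pattern for lowering an op to an approximation computed in f32: f16
// operands are extended to f32, the derived pattern materializes the
// approximation via emitApproximation, and the result is truncated back to
// f16 if needed. f64 ops are not rewritten because precision might matter.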
template <typename OpTy>
class ApproximateOnExtendedF32Lowering : public OpRewritePattern<OpTy> {
public:
  explicit ApproximateOnExtendedF32Lowering(MLIRContext *ctx)
      : OpRewritePattern<OpTy>(ctx, /*benefit=*/100) {}

  virtual Value emitApproximation(ValueRange, Location,
                                  PatternRewriter &) const = 0;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto rawArgs = op.getOperation()->getOperands();

    // Supports only f16 and f32 for now.
    if (!op.getType().isF16() && !op.getType().isF32()) return failure();

    // Extend operands to f32 if needed and possible.
    SmallVector<Value, 2> f32Args;
    f32Args.reserve(rawArgs.size());
    for (Value arg : rawArgs) {
      // Similar to XLA, do not rewrite f64 as precision might matter.
      Type argTy = arg.getType();
      if (argTy.isF64()) return failure();

      if (argTy.isF16())
        arg = rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), arg);

      // If we still do not have f32, fail.
      if (!arg.getType().isF32()) return failure();

      f32Args.push_back(arg);
    }

    Value result = emitApproximation(f32Args, loc, rewriter);
    assert(result.getType().isF32() && "Expect f32 intermediate result.");

    // Truncate back if needed.
    if (op.getType().isF16())
      result =
          rewriter.create<arith::TruncFOp>(loc, rewriter.getF16Type(), result);

    rewriter.replaceOp(op, {result});
    return success();
  }
};

// This approximation resembles the Eigen implementation and, on top of it,
// uses a constant approximation for the +/-1 limits.
// https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Core/MathFunctionsImpl.h
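// The approximation is the odd rational function
//   tanh(x) ~= x * P(x^2) / Q(x^2),
// where P and Q are the polynomials given by numeratorCoeffs and
// denominatorCoeffs below. Inputs with |x| > 7.90531110763549805 are clamped
// to +/-1 and inputs with |x| < 0.0004 use the identity tanh(x) ~= x.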
class ApproximateTanhLowering
    : public ApproximateOnExtendedF32Lowering<math::TanhOp> {
public:
  explicit ApproximateTanhLowering(MLIRContext *ctx)
      : ApproximateOnExtendedF32Lowering<math::TanhOp>(ctx) {}

  // Emits the fast tanh approximation that is also used by XLA.
  Value emitApproximation(ValueRange args, Location loc,
                          PatternRewriter &rewriter) const override {
    Value input = args.front();
    assert(input.getType().isF32());
    static constexpr std::array<float, 7> numeratorCoeffs{
        -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
        5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
        4.89352455891786e-03f};
    static constexpr std::array<float, 4> denominatorCoeffs{
        1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
        4.89352518554385e-03f};

    // Materialize polynomial approximation.
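    // Both polynomials are evaluated in x^2 using Horner's scheme; the final
    // multiplication by x makes the numerator an odd polynomial in x.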
    Value inputSquared = rewriter.create<arith::MulFOp>(loc, input, input);
    Value numerator = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getF32FloatAttr(numeratorCoeffs[0]));
    for (int64_t i = 1; i < static_cast<int64_t>(numeratorCoeffs.size()); i++) {
      numerator = rewriter.create<arith::AddFOp>(
          loc, rewriter.create<arith::MulFOp>(loc, inputSquared, numerator),
          rewriter.create<arith::ConstantOp>(
              loc, rewriter.getF32FloatAttr(numeratorCoeffs[i])));
    }
    numerator = rewriter.create<arith::MulFOp>(loc, input, numerator);
    Value denominator = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getF32FloatAttr(denominatorCoeffs[0]));
    for (int64_t i = 1; i < static_cast<int64_t>(denominatorCoeffs.size());
         i++) {
      denominator = rewriter.create<arith::AddFOp>(
          loc, rewriter.create<arith::MulFOp>(loc, inputSquared, denominator),
          rewriter.create<arith::ConstantOp>(
              loc, rewriter.getF32FloatAttr(denominatorCoeffs[i])));
    }
    Value approx = rewriter.create<arith::DivFOp>(loc, numerator, denominator);

    // For small values of |x|, we can approximate tanh(x) = x. For extremely
    // small values of x (|x| < 1e-37), the other approximation would evaluate
    // tanh(x) = 0.
    constexpr float kUseIdentityApprox = 0.0004;
    Value absInput = rewriter.create<math::AbsFOp>(loc, input);
    Value useIdentityApprox = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, absInput,
        rewriter.create<arith::ConstantOp>(
            loc, rewriter.getF32FloatAttr(kUseIdentityApprox)));
    approx =
        rewriter.create<arith::SelectOp>(loc, useIdentityApprox, input, approx);

    // For very large/small values, use the constant approximations +1/-1; NaN
    // inputs are propagated unchanged.
    Value tooLargeInput = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UGT, input,
        rewriter.create<arith::ConstantOp>(
            loc, rewriter.getF32FloatAttr(7.90531110763549805f)));
    Value tooSmallInput = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ULT, input,
        rewriter.create<arith::ConstantOp>(
            loc, rewriter.getF32FloatAttr(-7.90531110763549805f)));
    Value inputIsNan = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNE, input, input);
    approx = rewriter.create<arith::SelectOp>(
        loc, tooLargeInput,
        rewriter.create<arith::ConstantOp>(loc, rewriter.getF32FloatAttr(1.0)),
        approx);
    approx = rewriter.create<arith::SelectOp>(
        loc, tooSmallInput,
        rewriter.create<arith::ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0)),
        approx);
    approx = rewriter.create<arith::SelectOp>(loc, inputIsNan, input, approx);

    return approx;
  }
};

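// Pass that greedily applies the approximation patterns above (currently only
// the tanh approximation) to the operations of a function.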
struct LegalizeTrigonometricToApproximationPass
    : public LegalizeTanhToApproximationPassBase<
          LegalizeTrigonometricToApproximationPass> {
  /// Perform the lowering of standard dialect operations to approximations.
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateTrigonometricToApproximationPatterns(&getContext(), &patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

} // anonymous namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createLegalizeTrigonometricToApproximationPass() {
  return std::make_unique<LegalizeTrigonometricToApproximationPass>();
}

void populateTrigonometricToApproximationPatterns(mlir::MLIRContext *context,
                                                  RewritePatternSet *patterns) {
  // clang-format off
  patterns->add<ApproximateTanhLowering>(context);
  // clang-format on
}

} // namespace mhlo
} // namespace mlir