/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements the lowering of trigonometric ops to approximations.

#include <array>
#include <utility>

#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

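// Base pattern that rewrites an op into an approximation computed in f32:
// f16 operands are extended to f32, the derived class emits the f32
// approximation, and the result is truncated back to f16 if needed. f64 is
// intentionally left untouched so as not to lose precision.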
template <typename OpTy>
class ApproximateOnExtendedF32Lowering : public OpRewritePattern<OpTy> {
 public:
  explicit ApproximateOnExtendedF32Lowering(MLIRContext *ctx)
      : OpRewritePattern<OpTy>(ctx, /*benefit=*/100) {}

  virtual Value emitApproximation(ValueRange, Location,
                                  PatternRewriter &) const = 0;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto rawArgs = op.getOperation()->getOperands();

    // Supports only f16 and f32 for now.
    if (!op.getType().isF16() && !op.getType().isF32()) return failure();

    // Extend operands to f32 if needed and possible.
    SmallVector<Value, 2> f32Args;
    f32Args.reserve(rawArgs.size());
    for (Value arg : rawArgs) {
      // Similar to XLA, do not rewrite f64 as precision might matter.
      Type argTy = arg.getType();
      if (argTy.isF64()) return failure();

      if (argTy.isF16())
        arg = rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), arg);

      // If we still do not have f32, fail.
      if (!arg.getType().isF32()) return failure();

      f32Args.push_back(arg);
    }

    Value result = emitApproximation(f32Args, loc, rewriter);
    assert(result.getType().isF32() && "Expect f32 intermediate result.");

    // Truncate back if needed.
    if (op.getType().isF16())
      result =
          rewriter.create<arith::TruncFOp>(loc, rewriter.getF16Type(), result);

    rewriter.replaceOp(op, {result});
    return success();
  }
};

// This approximation resembles Eigen and realizes a constant approximation for
// the +/-1 limits on top.
// https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Core/MathFunctionsImpl.h
class ApproximateTanhLowering
    : public ApproximateOnExtendedF32Lowering<math::TanhOp> {
 public:
  explicit ApproximateTanhLowering(MLIRContext *ctx)
      : ApproximateOnExtendedF32Lowering<math::TanhOp>(ctx) {}

  // Emits the fast tanh approximation that is also used by XLA.
  Value emitApproximation(ValueRange args, Location loc,
                          PatternRewriter &rewriter) const override {
    Value input = args.front();
    assert(input.getType().isF32());
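    // The core of the approximation is the rational function
    //   tanh(x) ~= x * P(x^2) / Q(x^2),
    // with the numerator and denominator coefficients below listed from the
    // highest-degree term down to the constant term.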
    static constexpr std::array<float, 7> numeratorCoeffs{
        -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
        5.12229709037114e-08f,  1.48572235717979e-05f, 6.37261928875436e-04f,
        4.89352455891786e-03f};
    static constexpr std::array<float, 4> denominatorCoeffs{
        1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
        4.89352518554385e-03f};

    // Materialize polynomial approximation.
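    // Both polynomials are evaluated in x^2 with Horner's scheme, starting
    // from the highest-degree coefficient.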
    Value inputSquared = rewriter.create<arith::MulFOp>(loc, input, input);
    Value numerator = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getF32FloatAttr(numeratorCoeffs[0]));
    for (int64_t i = 1; i < static_cast<int64_t>(numeratorCoeffs.size()); i++) {
      numerator = rewriter.create<arith::AddFOp>(
          loc, rewriter.create<arith::MulFOp>(loc, inputSquared, numerator),
          rewriter.create<arith::ConstantOp>(
              loc, rewriter.getF32FloatAttr(numeratorCoeffs[i])));
    }
    numerator = rewriter.create<arith::MulFOp>(loc, input, numerator);
    Value denominator = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getF32FloatAttr(denominatorCoeffs[0]));
    for (int64_t i = 1; i < static_cast<int64_t>(denominatorCoeffs.size());
         i++) {
      denominator = rewriter.create<arith::AddFOp>(
          loc, rewriter.create<arith::MulFOp>(loc, inputSquared, denominator),
          rewriter.create<arith::ConstantOp>(
              loc, rewriter.getF32FloatAttr(denominatorCoeffs[i])));
    }
    Value approx = rewriter.create<arith::DivFOp>(loc, numerator, denominator);

    // For small values of |x|, we can approximate tanh(x) = x. For extremely
    // small values of x (|x| < 1e-37), the other approximation would evaluate
    // tanh(x) = 0.
    constexpr float kUseIdentityApprox = 0.0004;
    Value absInput = rewriter.create<math::AbsFOp>(loc, input);
    Value useIdentityApprox = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, absInput,
        rewriter.create<arith::ConstantOp>(
            loc, rewriter.getF32FloatAttr(kUseIdentityApprox)));
    approx =
        rewriter.create<arith::SelectOp>(loc, useIdentityApprox, input, approx);

    // For very negative/positive inputs, use the constant approximation -1/1.
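    // Note that UGT/ULT are unordered predicates, so a NaN input would also
    // take these branches; the final select on inputIsNan below restores NaN
    // propagation.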
    Value tooLargeInput = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UGT, input,
        rewriter.create<arith::ConstantOp>(
            loc, rewriter.getF32FloatAttr(7.90531110763549805f)));
    Value tooSmallInput = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ULT, input,
        rewriter.create<arith::ConstantOp>(
            loc, rewriter.getF32FloatAttr(-7.90531110763549805f)));
    Value inputIsNan = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNE, input, input);
    approx = rewriter.create<arith::SelectOp>(
        loc, tooLargeInput,
        rewriter.create<arith::ConstantOp>(loc, rewriter.getF32FloatAttr(1.0)),
        approx);
    approx = rewriter.create<arith::SelectOp>(
        loc, tooSmallInput,
        rewriter.create<arith::ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0)),
        approx);
    approx = rewriter.create<arith::SelectOp>(loc, inputIsNan, input, approx);

    return approx;
  }
};

struct LegalizeTrigonometricToApproximationPass
    : public LegalizeTanhToApproximationPassBase<
          LegalizeTrigonometricToApproximationPass> {
  /// Perform the lowering of standard dialect operations to approximations.
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateTrigonometricToApproximationPatterns(&getContext(), &patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

}  // anonymous namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createLegalizeTrigonometricToApproximationPass() {
  return std::make_unique<LegalizeTrigonometricToApproximationPass>();
}

void populateTrigonometricToApproximationPatterns(mlir::MLIRContext *context,
                                                  RewritePatternSet *patterns) {
  // clang-format off
  patterns->add<ApproximateTanhLowering>(context);
  // clang-format on
}

}  // namespace mhlo
}  // namespace mlir