/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/transforms/remapper/pass.h"

#include <memory>
#include <string>
#include <type_traits>
#include <utility>

#include "mlir/Dialect/PDL/IR/PDL.h"  // from @llvm-project
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/ir/dialect.h"
#include "tensorflow/core/ir/tf_op_wrapper.h"
#include "tensorflow/core/transforms/pass_detail.h"
#include "tensorflow/core/transforms/remapper/remapping_helper.h"
#include "tensorflow/core/transforms/utils/pdll/utils.h"
#include "tensorflow/core/transforms/utils/utils.h"
#include "tensorflow/core/util/util.h"

namespace mlir {
namespace tfg {
namespace mkl {
#include "tensorflow/core/transforms/remapper/pdll/MklPDLLPatterns.h.inc"
}  // namespace mkl

// Convert Sigmoid+Mul to Swish
// Mul(x, Sigmoid(x)) --> _MklSwish(x)
class MatchMulSigmoid : public RewritePattern {
 public:
  explicit MatchMulSigmoid(MLIRContext *context)
      : RewritePattern("tfg.Mul", PatternBenefit(/*benefit=*/1), context),
        sigmoid_name_("tfg.Sigmoid", context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    TypeAttr dtype_attr = op->getAttrOfType<TypeAttr>("T");
    // Bail out if the op carries no dtype attribute.
    if (!dtype_attr) return failure();
    if (!dtype_attr.getValue().isa<Float32Type>() &&
        !dtype_attr.getValue().isa<BFloat16Type>()) {
      return failure();
    }

    if (!util::OpHasDevice(op, tensorflow::DEVICE_CPU)) return failure();

    TFOp mul_wrapper(op);

    Value sigmoid = op->getOperand(0);
    Value x = op->getOperand(1);

    auto sigmoidOperandEqToX = [&](Value sigmoid, Value x) {
      Operation *op = sigmoid.getDefiningOp();
      return op && op->getName() == sigmoid_name_ && op->getOperand(0) == x;
    };

    if (!sigmoidOperandEqToX(sigmoid, x)) {
      // Mul is commutative, so the Sigmoid may be either operand. Swap them
      // and check again.
      std::swap(sigmoid, x);
      if (!sigmoidOperandEqToX(sigmoid, x)) return failure();
    }

    SmallVector<Value> operands;
    // Set up the non-control operand.
    operands.push_back(x);
    // Control operands come after regular operands.
    llvm::append_range(operands, mul_wrapper.getControlOperands());

    Operation *new_op =
        rewriter.create(op->getLoc(), rewriter.getStringAttr("tfg._MklSwish"),
                        operands, op->getResultTypes(), op->getAttrs());
    rewriter.replaceOp(op, new_op->getResults());

    return success();
  }

 private:
  // Caches the operation name handle to avoid repeated string comparisons.
  OperationName sigmoid_name_;
};
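
// For illustration, a hedged sketch of the rewrite above in textual TFG form
// (the value names here are invented for the example):
//
//   %sig, %ctl0 = Sigmoid(%x) {T = f32}
//   %mul, %ctl1 = Mul(%x, %sig) {T = f32}
//
// becomes a single node that keeps Mul's result types, attributes, and
// control operands:
//
//   %swish, %ctl = _MklSwish(%x) {T = f32}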

// This enum class is used as a template parameter; each item aliases a tfg op
// name.
// TODO(intel-tf): Add more items as needed.
enum class OpKind { Relu, Relu6, Elu, LeakyRelu, Tanh, Sigmoid };

inline std::string GetTfgOpName(OpKind op_kind) {
  switch (op_kind) {
    case OpKind::Relu:
      return "tfg.Relu";
    case OpKind::Relu6:
      return "tfg.Relu6";
    case OpKind::Elu:
      return "tfg.Elu";
    case OpKind::LeakyRelu:
      return "tfg.LeakyRelu";
    case OpKind::Tanh:
      return "tfg.Tanh";
    case OpKind::Sigmoid:
      return "tfg.Sigmoid";
    default:
      return "tfg.NoOp";
  }
}

class RemapperPatternBase : public RewritePattern {
 public:
  RemapperPatternBase(StringRef opName, OpPropertyHelper &helper,
                      PatternBenefit benefit = PatternBenefit(1))
      : RewritePattern(opName, benefit, helper.getDialect()->getContext()),
        helper_(helper) {}
  RemapperPatternBase(MatchAnyOpTypeTag tag, OpPropertyHelper &helper,
                      PatternBenefit benefit = PatternBenefit(1))
      : RewritePattern(tag, benefit, helper.getDialect()->getContext()),
        helper_(helper) {}

 protected:
  OpPropertyHelper helper_;
};

static std::unique_ptr<OperationState> GetContractionBiasAddOpState(
    OpBuilder &builder, const OpPropertyHelper &helper,
    Operation *contraction_op, Operation *bias_add_op) {
  // The fused op name depends on the original contraction op name.
  std::string fused_op_name;
  if (helper.getDialect()->IsConv2D(contraction_op)) {
    fused_op_name = "tfg._FusedConv2D";
  } else if (helper.getDialect()->IsMatMul(contraction_op)) {
    fused_op_name = "tfg._FusedMatMul";
  } else if (helper.getDialect()->IsDepthwiseConv2dNative(contraction_op)) {
    fused_op_name = "tfg._FusedDepthwiseConv2dNative";
  } else if (helper.getDialect()->IsConv3D(contraction_op)) {
    fused_op_name = "tfg._FusedConv3D";
  } else {
    return nullptr;
  }

  SmallVector<Location> fused_locs{contraction_op->getLoc(),
                                   bias_add_op->getLoc()};
  auto state = std::make_unique<OperationState>(builder.getFusedLoc(fused_locs),
                                                fused_op_name);
  SmallVector<Value> operands;
  Value input = contraction_op->getOperand(0);
  Value filter = contraction_op->getOperand(1);
  Value bias = bias_add_op->getOperand(1);
  operands.push_back(input);
  operands.push_back(filter);
  operands.push_back(bias);
  state->addOperands(operands);
  state->addOperands(TFOp(contraction_op).getControlOperands());
  state->addOperands(TFOp(bias_add_op).getControlOperands());
  state->addTypes(bias_add_op->getResultTypes());
  state->attributes = contraction_op->getAttrs();
  state->attributes.set("fused_ops", builder.getStrArrayAttr({"BiasAdd"}));
  state->attributes.set("num_args", builder.getI32IntegerAttr(1));
  // Default values for epsilon and leakyrelu_alpha.
  state->attributes.set("epsilon", builder.getF32FloatAttr(0.0001));
  state->attributes.set("leakyrelu_alpha", builder.getF32FloatAttr(0.2));
  return state;
}
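
// For illustration, a hedged sketch of the OperationState built above, taking
// Conv2D as the contraction (the attribute values shown are the defaults set
// here; everything else is copied from the matched ops):
//
//   _FusedConv2D(input, filter, bias, <control operands of both ops>)
//       {fused_ops = ["BiasAdd"], num_args = 1,
//        epsilon = 1e-4, leakyrelu_alpha = 0.2,
//        ...all attributes inherited from the original Conv2D}
//
// with the result types taken from the matched BiasAdd.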

// Contraction + BiasAdd
// TODO(intel-tf): Support Contraction + {Add, AddV2} fusion when it has the
// same semantics as Contraction + BiasAdd.
class ContractionBiasAddRewriter : public RemapperPatternBase {
 public:
  explicit ContractionBiasAddRewriter(OpPropertyHelper &helper)
      : RemapperPatternBase("tfg.BiasAdd", helper, PatternBenefit(1)) {}

  // Constructor used by derived pattern rewriter classes that may have a
  // different root operation name. Currently, the pattern is matched from the
  // root op to its inputs.
  explicit ContractionBiasAddRewriter(StringRef op_name,
                                      OpPropertyHelper &helper,
                                      PatternBenefit benefit)
      : RemapperPatternBase(op_name, helper, benefit) {}

  using Pattern = ContractionBiasAdd;

  bool matchPattern(Operation *op, Pattern &pattern) const {
    // Disallow control operands and control result users on BiasAdd.
    if (helper_.HasControlOperandsOrResultUsers(op)) return false;
    // The input may be a block argument with no defining op.
    Operation *contraction_op = op->getOperand(0).getDefiningOp();
    if (!contraction_op || !helper_.IsContraction(contraction_op) ||
        helper_.HasControlOperandsOrResultUsers(contraction_op) ||
        !helper_.HaveSameDataType(op, contraction_op) ||
        !helper_.HasAtMostOneUserOfResult0(contraction_op)) {
      return false;
    }
    pattern.contraction = contraction_op;
    pattern.bias_add = op;
    return true;
  }

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    Pattern pattern;
    if (!matchPattern(op, pattern)) return failure();
    if (!helper_.IsDeviceCompatible(pattern)) return failure();
    std::unique_ptr<OperationState> state = GetContractionBiasAddOpState(
        rewriter, helper_, pattern.contraction, pattern.bias_add);
    Operation *fused_op = rewriter.create(*state);
    TFOp(fused_op).setName(TFOp(op).nameAttr());
    rewriter.replaceOp(op, fused_op->getResults());
    return success();
  }
};
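
// For illustration, a hedged sketch of the rewrite above in textual TFG form
// (names invented for the example):
//
//   %conv, %ctl0 = Conv2D(%input, %filter) {T = f32, ...}
//   %out, %ctl1 = BiasAdd(%conv, %bias) {T = f32, ...}
//
// becomes
//
//   %out, %ctl = _FusedConv2D(%input, %filter, %bias)
//                    {fused_ops = ["BiasAdd"], num_args = 1, ...}
//
// The fused node takes over BiasAdd's node name and result types.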

// BasePattern + Activation
template <typename BasePatternRewriter, OpKind activation>
class BasePatternActivationRewriter : public BasePatternRewriter {
 public:
  explicit BasePatternActivationRewriter(OpPropertyHelper &helper)
      : BasePatternRewriter(GetTfgOpName(activation), helper,
                            PatternBenefit(1)) {}

  using BasePattern = typename BasePatternRewriter::Pattern;
  using Pattern = std::conditional_t<
      std::is_same<BasePatternRewriter, ContractionBiasAddRewriter>::value,
      ContractionBiasAddActivation, void>;

  bool matchPattern(Operation *op, BasePattern &base_pattern,
                    Pattern &pattern) const {
    // Although template instantiation guarantees that only a valid activation
    // is set as the root operation, a sanity check is added here.
    if (this->helper_.getDialect()->IsNoOp(op)) return false;
    if (this->helper_.HasControlOperandsOrResultUsers(op)) return false;

    // TODO(intel-tf): Add support for more patterns.
    if constexpr (std::is_same<BasePattern, ContractionBiasAdd>::value &&
                  std::is_same<Pattern, ContractionBiasAddActivation>::value) {
      // The input may be a block argument with no defining op.
      Operation *bias_add_op = op->getOperand(0).getDefiningOp();
      if (!bias_add_op || !this->helper_.getDialect()->IsBiasAdd(bias_add_op) ||
          !this->helper_.HaveSameDataType(op, bias_add_op) ||
          !this->helper_.HasAtMostOneUserOfResult0(bias_add_op)) {
        return false;
      }
      if (!BasePatternRewriter::matchPattern(bias_add_op, base_pattern)) {
        return false;
      }
      pattern.contraction = base_pattern.contraction;
      pattern.bias_add = base_pattern.bias_add;
      pattern.activation = op;
      return true;
    }

    return false;
  }

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    BasePattern base_pattern;
    Pattern pattern;
    if (!matchPattern(op, base_pattern, pattern)) return failure();
    if constexpr (!std::is_same<BasePatternRewriter,
                                ContractionBiasAddRewriter>::value) {
      return failure();
    }
    if (!this->helper_.IsDeviceCompatible(pattern)) return failure();
    Operation *&contraction_op = pattern.contraction;
    Operation *&bias_add_op = pattern.bias_add;
    Operation *&activation_op = pattern.activation;
    const std::string activation_op_name =
        activation_op->getName().stripDialect().str();
    // Currently, the supported activations are:
    //    _FusedMatMul: Relu, Relu6, Elu, LeakyRelu, Tanh, and Sigmoid
    //    _Fused*Conv*: Relu, Relu6, Elu, and LeakyRelu
    if ((activation_op_name == "Tanh" || activation_op_name == "Sigmoid") &&
        !this->helper_.getDialect()->IsMatMul(contraction_op)) {
      return failure();
    }

    std::unique_ptr<OperationState> state = GetContractionBiasAddOpState(
        rewriter, this->helper_, contraction_op, bias_add_op);
    SmallVector<Location> fused_locs{state->location, activation_op->getLoc()};
    state->location = rewriter.getFusedLoc(fused_locs);
    state->attributes.set(
        "fused_ops", rewriter.getStrArrayAttr({"BiasAdd", activation_op_name}));
    if (this->helper_.getDialect()->IsLeakyRelu(activation_op)) {
      state->attributes.set("leakyrelu_alpha", activation_op->getAttr("alpha"));
    }
    Operation *fused_op = rewriter.create(*state);
    TFOp(fused_op).setName(TFOp(op).nameAttr());
    rewriter.replaceOp(op, fused_op->getResults());
    return success();
  }
};
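
// For illustration, a hedged sketch of the activation fusion above in textual
// TFG form (names invented for the example):
//
//   %mm, %ctl0 = MatMul(%a, %b) {T = f32, ...}
//   %ba, %ctl1 = BiasAdd(%mm, %bias) {T = f32, ...}
//   %out, %ctl2 = Relu(%ba) {T = f32}
//
// becomes
//
//   %out, %ctl = _FusedMatMul(%a, %b, %bias)
//                    {fused_ops = ["BiasAdd", "Relu"], num_args = 1, ...}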

template <template <OpKind> class PatternT, OpKind... op_kinds,
          typename... Args>
static void InsertPatterns(RewritePatternSet &patterns, Args &&...args) {
  patterns.insert<PatternT<op_kinds>...>(std::forward<Args>(args)...);
}
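
// For example, the call in populateRemapperPatterns below expands roughly to:
//
//   patterns.insert<ContractionBiasAddActivationRewriter<OpKind::Relu>,
//                   ContractionBiasAddActivationRewriter<OpKind::Relu6>,
//                   /* ...one instantiation per listed OpKind... */>(helper_);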

template <OpKind activation>
using ContractionBiasAddActivationRewriter =
    BasePatternActivationRewriter<ContractionBiasAddRewriter, activation>;

class Remapper : public RemapperBase<Remapper> {
 public:
  Remapper() = default;
  explicit Remapper(bool enable_onednn_patterns, bool xla_auto_clustering) {
    enable_onednn_patterns_ = enable_onednn_patterns;
    xla_auto_clustering_ = xla_auto_clustering;
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<pdl::PDLDialect, pdl_interp::PDLInterpDialect>();
  }

  LogicalResult initialize(MLIRContext *context) override {
    helper_ = OpPropertyHelper(context->getOrLoadDialect<TFGraphDialect>(),
                               enable_onednn_patterns_, xla_auto_clustering_);
    RewritePatternSet patterns(context);
    populateRemapperPatterns(context, patterns);
    RegisterPDLLUtils(patterns);
    final_patterns_ = std::move(patterns);
    return success();
  }

  void runOnOperation() override;

 private:
  void populateRemapperPatterns(MLIRContext *context,
                                RewritePatternSet &patterns) {
    if (verify_pdll_patterns_only_) {
      populateRemapperPDLLPatterns(patterns);
      return;
    }
    if (enable_onednn_patterns_) {
      patterns.insert<MatchMulSigmoid>(context);
      // TODO(chiahungduan): Currently, the only pattern implemented in PDLL is
      // the same as `MatchMulSigmoid`. Remove one of them once it is decided
      // which one is preferred.
      populateRemapperPDLLPatterns(patterns);
    }
    patterns.insert<ContractionBiasAddRewriter>(helper_);
    // Insert multiple pattern rewriters, one template instantiation per
    // activation op.
    InsertPatterns<ContractionBiasAddActivationRewriter, OpKind::Relu,
                   OpKind::Relu6, OpKind::Elu, OpKind::LeakyRelu, OpKind::Tanh,
                   OpKind::Sigmoid>(patterns, helper_);
  }

  void populateRemapperPDLLPatterns(RewritePatternSet &patterns) {
    mkl::populateGeneratedPDLLPatterns(patterns);
  }

  FrozenRewritePatternSet final_patterns_;
  OpPropertyHelper helper_;
};

void Remapper::runOnOperation() {
  if (failed(applyPatternsAndFoldGreedily(getOperation(), final_patterns_))) {
    signalPassFailure();
  }
}

std::unique_ptr<Pass> CreateRemapperPass(bool enable_onednn_patterns,
                                         bool xla_auto_clustering) {
  return std::make_unique<Remapper>(enable_onednn_patterns,
                                    xla_auto_clustering);
}
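
// For illustration, a hedged usage sketch (the surrounding pipeline setup is
// assumed, not part of this file; whether the pass anchors on the module or a
// nested op follows its definition in pass_detail.h):
//
//   mlir::PassManager pm(context);
//   pm.addPass(CreateRemapperPass(/*enable_onednn_patterns=*/true,
//                                 /*xla_auto_clustering=*/false));
//   (void)pm.run(module);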

}  // namespace tfg
}  // namespace mlir