/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/transforms/remapper/pass.h"

#include <memory>
#include <string>
#include <type_traits>
#include <utility>

#include "mlir/Dialect/PDL/IR/PDL.h"  // from @llvm-project
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/ir/dialect.h"
#include "tensorflow/core/ir/tf_op_wrapper.h"
#include "tensorflow/core/transforms/pass_detail.h"
#include "tensorflow/core/transforms/remapper/remapping_helper.h"
#include "tensorflow/core/transforms/utils/pdll/utils.h"
#include "tensorflow/core/transforms/utils/utils.h"
#include "tensorflow/core/util/util.h"

namespace mlir {
namespace tfg {
namespace mkl {
#include "tensorflow/core/transforms/remapper/pdll/MklPDLLPatterns.h.inc"
}  // namespace mkl

// Convert Sigmoid+Mul to Swish
// Mul(x, Sigmoid(x)) --> _MklSwish(x)
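// Schematically (pseudo-IR):
//   %s = tfg.Sigmoid(%x)
//   %y = tfg.Mul(%x, %s)   // or tfg.Mul(%s, %x)
// is rewritten to
//   %y = tfg._MklSwish(%x)
// Control operands of the Mul are carried over to the new op.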
class MatchMulSigmoid : public RewritePattern {
 public:
  explicit MatchMulSigmoid(MLIRContext *context)
      : RewritePattern("tfg.Mul", PatternBenefit(/*benefit=*/1), context),
        sigmoid_name_("tfg.Sigmoid", context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // The fusion only applies to f32/bf16 Mul ops placed on CPU.
    TypeAttr dtype_attr = op->getAttrOfType<TypeAttr>("T");
    if (!dtype_attr || (!dtype_attr.getValue().isa<Float32Type>() &&
                        !dtype_attr.getValue().isa<BFloat16Type>())) {
      return failure();
    }

    if (!util::OpHasDevice(op, tensorflow::DEVICE_CPU)) return failure();

    TFOp mul_wrapper(op);

    Value sigmoid = op->getOperand(0);
    Value x = op->getOperand(1);

    auto sigmoidOperandEqToX = [&](Value sigmoid, Value x) {
      Operation *op = sigmoid.getDefiningOp();
      return op && op->getName() == sigmoid_name_ && op->getOperand(0) == x;
    };

    if (!sigmoidOperandEqToX(sigmoid, x)) {
      // Mul is commutative, so the Sigmoid may be either operand. Swap them
      // and check again.
      std::swap(sigmoid, x);
      if (!sigmoidOperandEqToX(sigmoid, x)) return failure();
    }

    SmallVector<Value> operands;
    // Set up the non-control operand.
    operands.push_back(x);
    // Control operands come after regular operands.
    llvm::append_range(operands, mul_wrapper.getControlOperands());

    Operation *new_op =
        rewriter.create(op->getLoc(), rewriter.getStringAttr("tfg._MklSwish"),
                        operands, op->getResultTypes(), op->getAttrs());
    rewriter.replaceOp(op, new_op->getResults());

    return success();
  }

 private:
  // Caches the OperationName handle for "tfg.Sigmoid" so matching avoids a
  // string comparison.
  OperationName sigmoid_name_;
};

// This enum class is used as a template parameter and maps each item to the
// corresponding tfg op name.
// TODO(intel-tf): Add more items as needed.
enum class OpKind { Relu, Relu6, Elu, LeakyRelu, Tanh, Sigmoid };

inline std::string GetTfgOpName(OpKind op_kind) {
  switch (op_kind) {
    case OpKind::Relu:
      return "tfg.Relu";
    case OpKind::Relu6:
      return "tfg.Relu6";
    case OpKind::Elu:
      return "tfg.Elu";
    case OpKind::LeakyRelu:
      return "tfg.LeakyRelu";
    case OpKind::Tanh:
      return "tfg.Tanh";
    case OpKind::Sigmoid:
      return "tfg.Sigmoid";
    default:
      return "tfg.NoOp";
  }
}

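// Common base for the remapper rewrite patterns below. It stores the
// OpPropertyHelper that patterns use to query op properties (dialect
// handles, data types, device placement, etc.).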
class RemapperPatternBase : public RewritePattern {
 public:
  RemapperPatternBase(StringRef opName, OpPropertyHelper &helper,
                      PatternBenefit benefit = PatternBenefit(1))
      : RewritePattern(opName, benefit, helper.getDialect()->getContext()),
        helper_(helper) {}
  RemapperPatternBase(MatchAnyOpTypeTag tag, OpPropertyHelper &helper,
                      PatternBenefit benefit = PatternBenefit(1))
      : RewritePattern(tag, benefit, helper.getDialect()->getContext()),
        helper_(helper) {}

 protected:
  OpPropertyHelper helper_;
};

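// Builds the OperationState for a fused Contraction + BiasAdd op: the fused
// op takes (input, filter, bias) plus the control operands of both original
// ops, and inherits the contraction's attributes. Returns nullptr if the
// contraction is not a supported kind.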
static std::unique_ptr<OperationState> GetContractionBiasAddOpState(
    OpBuilder &builder, const OpPropertyHelper &helper,
    Operation *contraction_op, Operation *bias_add_op) {
  // The fused op name depends on the original contraction op name.
  std::string fused_op_name;
  if (helper.getDialect()->IsConv2D(contraction_op)) {
    fused_op_name = "tfg._FusedConv2D";
  } else if (helper.getDialect()->IsMatMul(contraction_op)) {
    fused_op_name = "tfg._FusedMatMul";
  } else if (helper.getDialect()->IsDepthwiseConv2dNative(contraction_op)) {
    fused_op_name = "tfg._FusedDepthwiseConv2dNative";
  } else if (helper.getDialect()->IsConv3D(contraction_op)) {
    fused_op_name = "tfg._FusedConv3D";
  } else {
    return nullptr;
  }

  SmallVector<Location> fused_locs{contraction_op->getLoc(),
                                   bias_add_op->getLoc()};
  auto state = std::make_unique<OperationState>(builder.getFusedLoc(fused_locs),
                                                fused_op_name);
  SmallVector<Value> operands;
  Value input = contraction_op->getOperand(0);
  Value filter = contraction_op->getOperand(1);
  Value bias = bias_add_op->getOperand(1);
  operands.push_back(input);
  operands.push_back(filter);
  operands.push_back(bias);
  state->addOperands(operands);
  state->addOperands(TFOp(contraction_op).getControlOperands());
  state->addOperands(TFOp(bias_add_op).getControlOperands());
  state->addTypes(bias_add_op->getResultTypes());
  state->attributes = contraction_op->getAttrs();
  state->attributes.set("fused_ops", builder.getStrArrayAttr({"BiasAdd"}));
  state->attributes.set("num_args", builder.getI32IntegerAttr(1));
  // Default values for epsilon and leakyrelu_alpha.
  state->attributes.set("epsilon", builder.getF32FloatAttr(0.0001));
  state->attributes.set("leakyrelu_alpha", builder.getF32FloatAttr(0.2));
  return state;
}

// Contraction + BiasAdd
// TODO(intel-tf): Support Contraction + {Add, AddV2} fusion when it has the
// same semantics as Contraction + BiasAdd.
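// Schematically (pseudo-IR), taking Conv2D as the contraction:
//   %0 = tfg.Conv2D(%input, %filter)
//   %1 = tfg.BiasAdd(%0, %bias)
// is rewritten to
//   %1 = tfg._FusedConv2D(%input, %filter, %bias)
//            {fused_ops = ["BiasAdd"], num_args = 1, ...}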
class ContractionBiasAddRewriter : public RemapperPatternBase {
 public:
  explicit ContractionBiasAddRewriter(OpPropertyHelper &helper)
      : RemapperPatternBase("tfg.BiasAdd", helper, PatternBenefit(1)) {}

  // Constructor used by derived pattern rewriter classes that may have a
  // different root operation name. Currently, a pattern is matched from the
  // root op back to its inputs.
  explicit ContractionBiasAddRewriter(StringRef op_name,
                                      OpPropertyHelper &helper,
                                      PatternBenefit benefit)
      : RemapperPatternBase(op_name, helper, benefit) {}

  using Pattern = ContractionBiasAdd;

  bool matchPattern(Operation *op, Pattern &pattern) const {
    // Do not allow control flow on the BiasAdd.
    if (helper_.HasControlOperandsOrResultUsers(op)) return false;
    Operation *contraction_op = op->getOperand(0).getDefiningOp();
    if (!helper_.IsContraction(contraction_op) ||
        helper_.HasControlOperandsOrResultUsers(contraction_op) ||
        !helper_.HaveSameDataType(op, contraction_op) ||
        !helper_.HasAtMostOneUserOfResult0(contraction_op)) {
      return false;
    }
    pattern.contraction = contraction_op;
    pattern.bias_add = op;
    return true;
  }

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    Pattern pattern;
    if (!matchPattern(op, pattern)) return failure();
    if (!helper_.IsDeviceCompatible(pattern)) return failure();
    std::unique_ptr<OperationState> state = GetContractionBiasAddOpState(
        rewriter, helper_, pattern.contraction, pattern.bias_add);
    Operation *fused_op = rewriter.create(*state);
    TFOp(fused_op).setName(TFOp(op).nameAttr());
    rewriter.replaceOp(op, fused_op->getResults());
    return success();
  }
};

// BasePattern + Activation
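// Schematically (pseudo-IR), with ContractionBiasAdd as the base pattern:
//   %0 = tfg.Conv2D(%input, %filter)
//   %1 = tfg.BiasAdd(%0, %bias)
//   %2 = tfg.Relu(%1)
// is rewritten to
//   %2 = tfg._FusedConv2D(%input, %filter, %bias)
//            {fused_ops = ["BiasAdd", "Relu"], ...}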
template <typename BasePatternRewriter, OpKind activation>
class BasePatternActivationRewriter : public BasePatternRewriter {
 public:
  explicit BasePatternActivationRewriter(OpPropertyHelper &helper)
      : BasePatternRewriter(GetTfgOpName(activation), helper,
                            PatternBenefit(1)) {}

  using BasePattern = typename BasePatternRewriter::Pattern;
  using Pattern = std::conditional_t<
      std::is_same<BasePatternRewriter, ContractionBiasAddRewriter>::value,
      ContractionBiasAddActivation, void>;

  bool matchPattern(Operation *op, BasePattern &base_pattern,
                    Pattern &pattern) const {
    // Although template instantiation guarantees that only a valid activation
    // is set as the root operation, a sanity check is added here.
    if (this->helper_.getDialect()->IsNoOp(op)) return false;
    if (this->helper_.HasControlOperandsOrResultUsers(op)) return false;

    // TODO(intel-tf): Add support for more patterns.
    if constexpr (std::is_same<BasePattern, ContractionBiasAdd>::value &&
                  std::is_same<Pattern, ContractionBiasAddActivation>::value) {
      Operation *bias_add_op = op->getOperand(0).getDefiningOp();
      if (!this->helper_.getDialect()->IsBiasAdd(bias_add_op) ||
          !this->helper_.HaveSameDataType(op, bias_add_op) ||
          !this->helper_.HasAtMostOneUserOfResult0(bias_add_op)) {
        return false;
      }
      if (!BasePatternRewriter::matchPattern(bias_add_op, base_pattern)) {
        return false;
      }
      pattern.contraction = base_pattern.contraction;
      pattern.bias_add = base_pattern.bias_add;
      pattern.activation = op;
      return true;
    }

    return false;
  }

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    BasePattern base_pattern;
    Pattern pattern;
    if (!matchPattern(op, base_pattern, pattern)) return failure();
    if constexpr (!std::is_same<BasePatternRewriter,
                                ContractionBiasAddRewriter>::value) {
      return failure();
    }
    if (!this->helper_.IsDeviceCompatible(pattern)) return failure();
    Operation *&contraction_op = pattern.contraction;
    Operation *&bias_add_op = pattern.bias_add;
    Operation *&activation_op = pattern.activation;
    const std::string activation_op_name =
        activation_op->getName().stripDialect().str();
    // Currently, the supported activations are:
    //   _FusedMatMul: Relu, Relu6, Elu, LeakyRelu, Tanh, and Sigmoid
    //   _Fused*Conv*: Relu, Relu6, Elu, and LeakyRelu
    if ((activation_op_name == "Tanh" || activation_op_name == "Sigmoid") &&
        !this->helper_.getDialect()->IsMatMul(contraction_op)) {
      return failure();
    }

    std::unique_ptr<OperationState> state = GetContractionBiasAddOpState(
        rewriter, this->helper_, contraction_op, bias_add_op);
    SmallVector<Location> fused_locs{state->location, activation_op->getLoc()};
    state->location = rewriter.getFusedLoc(fused_locs);
    state->attributes.set(
        "fused_ops", rewriter.getStrArrayAttr({"BiasAdd", activation_op_name}));
    if (this->helper_.getDialect()->IsLeakyRelu(activation_op)) {
      state->attributes.set("leakyrelu_alpha", activation_op->getAttr("alpha"));
    }
    Operation *fused_op = rewriter.create(*state);
    TFOp(fused_op).setName(TFOp(op).nameAttr());
    rewriter.replaceOp(op, fused_op->getResults());
    return success();
  }
};

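// Helper that instantiates PatternT once per listed op kind and adds all of
// the instantiations to the pattern set, e.g.
//   InsertPatterns<ContractionBiasAddActivationRewriter, OpKind::Relu,
//                  OpKind::Relu6>(patterns, helper_);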
template <template <OpKind> class PatternT, OpKind... op_kinds,
          typename... Args>
static void InsertPatterns(RewritePatternSet &patterns, Args &&...args) {
  patterns.insert<PatternT<op_kinds>...>(std::forward<Args>(args)...);
}

template <OpKind activation>
using ContractionBiasAddActivationRewriter =
    BasePatternActivationRewriter<ContractionBiasAddRewriter, activation>;

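// The remapper pass: collects the rewrite patterns above (plus the
// PDLL-generated ones) at initialization, then greedily applies them in
// runOnOperation.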
class Remapper : public RemapperBase<Remapper> {
 public:
  Remapper() = default;
  explicit Remapper(bool enable_onednn_patterns, bool xla_auto_clustering) {
    enable_onednn_patterns_ = enable_onednn_patterns;
    xla_auto_clustering_ = xla_auto_clustering;
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<pdl::PDLDialect, pdl_interp::PDLInterpDialect>();
  }

  LogicalResult initialize(MLIRContext *context) override {
    helper_ = OpPropertyHelper(context->getOrLoadDialect<TFGraphDialect>(),
                               enable_onednn_patterns_, xla_auto_clustering_);
    RewritePatternSet patterns(context);
    populateRemapperPatterns(context, patterns);
    RegisterPDLLUtils(patterns);
    final_patterns_ = std::move(patterns);
    return success();
  }

  void runOnOperation() override;

 private:
  void populateRemapperPatterns(MLIRContext *context,
                                RewritePatternSet &patterns) {
    if (verify_pdll_patterns_only_) {
      populateRemapperPDLLPatterns(patterns);
      return;
    }
    if (enable_onednn_patterns_) {
      patterns.insert<MatchMulSigmoid>(context);
      // TODO(chiahungduan): Currently, the only pattern implemented in PDLL
      // is the same as `MatchMulSigmoid`. Remove one of them once it is
      // decided which is preferred.
      populateRemapperPDLLPatterns(patterns);
    }
    patterns.insert<ContractionBiasAddRewriter>(helper_);
    // Insert one pattern rewriter per supported activation op via template
    // instantiation.
    InsertPatterns<ContractionBiasAddActivationRewriter, OpKind::Relu,
                   OpKind::Relu6, OpKind::Elu, OpKind::LeakyRelu, OpKind::Tanh,
                   OpKind::Sigmoid>(patterns, helper_);
  }

  void populateRemapperPDLLPatterns(RewritePatternSet &patterns) {
    mkl::populateGeneratedPDLLPatterns(patterns);
  }

  FrozenRewritePatternSet final_patterns_;
  OpPropertyHelper helper_;
};

void Remapper::runOnOperation() {
  if (failed(applyPatternsAndFoldGreedily(getOperation(), final_patterns_))) {
    signalPassFailure();
  }
}

std::unique_ptr<Pass> CreateRemapperPass(bool enable_onednn_patterns,
                                         bool xla_auto_clustering) {
  return std::make_unique<Remapper>(enable_onednn_patterns,
                                    xla_auto_clustering);
}

}  // namespace tfg
}  // namespace mlir
