1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstdio>
17 #include <iostream>
18 #include <string>
19 
20 #include "llvm/ADT/StringRef.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
25 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Support/LLVM.h"  // from @llvm-project
28 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
32 
33 namespace mlir {
34 
35 namespace TF {
36 
37 namespace {
38 
39 // Note: This implements the fusions performed in the old Remapper Grappler
40 // pass. That pass has specific cases for GPU and based on different
41 // target configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR
42 // pass covers (some of) the general CPU case and at the moment does not account
43 // for any target-specific configurations.
44 
45 // This pass is being ported over from the Grappler Remapper pass based on
46 // need/usage. File a bug to request porting over additional fusions.
47 
48 // TODO(b/158265178): Support GPU-specific fusions.
49 // TODO(b/158266710): Support CPU MKL configurations.
50 
// Optimizes TF computations by fusing subgraphs/nodes onto more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
struct FusedKernelMatcherPass
    : public FusedKernelMatcherPassBase<FusedKernelMatcherPass> {
  // Greedily applies the contraction+BiasAdd(+activation) fusion patterns
  // defined below to the function this pass runs on.
  void runOnOperation() override;
};
58 
IsActivationFunction(Operation * op)59 bool IsActivationFunction(Operation *op) {
60   return isa<EluOp, ReluOp, Relu6Op>(op);
61 }
62 
63 // Finds and returns an activation op that uses the result of `op`. If there are
64 // multiple such activations, one is returned (with no guarantee as to which
65 // one). If there are no activation functions that use the output, returns
66 // nullptr.
GetActivation(Value op)67 Operation *GetActivation(Value op) {
68   for (auto &use : op.getUses()) {
69     if (IsActivationFunction(use.getOwner())) return use.getOwner();
70   }
71   return nullptr;
72 }
73 
74 // Finds and returns a BiasAdd that uses the result of `op` as the `value`
75 // input. If there are multiple such BiasAdds, one is returned (with no
76 // guarantee as to which one). If there are no BiasAdds that use the output,
77 // returns a null BiasAddOp.
GetBiasAdd(Value op)78 BiasAddOp GetBiasAdd(Value op) {
79   for (auto &use : op.getUses()) {
80     auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
81     // If it's a BiasAdd, check that the conv op is the first input.
82     if (bias_add && bias_add.value() == op) return bias_add;
83   }
84   // No BiasAddOps found among uses.
85   return BiasAddOp();
86 }
87 
// Performs a fusion of the following pattern(s), if possible:
//   <Contraction> + BiasAdd + <Activation> -> <FusedContraction>
//
// Note that fusion with activation is preferred, but a contraction and BiasAdd
// can also be replaced by a _FusedConv2D if there is no other activation
// function.
// i.e., this class also supports the following fusion:
//   <Contraction> + BiasAdd -> <FusedContraction>
//
// TODO(b/158266331): Support fusing activation chains of arbitrary length.
template <typename SrcOpT, typename FusedOpT>
class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
 public:
  using OpRewritePattern<SrcOpT>::OpRewritePattern;
  // Class users should override this method if there are any op-specific
  // compatibility requirements between the contraction op and the BiasAdd op.
  // The default implementation accepts all pairs.
  virtual bool AreFuseCompatible(SrcOpT contraction_op, BiasAddOp bias_add,
                                 PatternRewriter &rewriter) const {
    return true;
  }

  // Class users should override this method if there are any op-specific
  // compatibility requirements for devices. The default implementation
  // accepts every device.
  virtual bool IsDeviceCompatible(SrcOpT contraction_op, BiasAddOp bias_add,
                                  PatternRewriter &rewriter) const {
    return true;
  }

  // Matches <contraction> -> BiasAdd [-> activation] and replaces the chain
  // with a single FusedOpT op carrying a `fused_ops` attribute that names the
  // ops which were folded in.
  LogicalResult matchAndRewrite(SrcOpT contraction,
                                PatternRewriter &rewriter) const override {
    auto context = rewriter.getContext();

    // We do support fusion only if the contraction operation is inside one of
    // the expected operations with regions. Other operations can have semantics
    // that is not compatible with fusion (e.g. region compilation).
    if (!isa<func::FuncOp, IfOp, WhileOp>(contraction->getParentOp())) {
      return rewriter.notifyMatchFailure(
          contraction,
          "fused operation must be nested inside a function, If or While");
    }

    // If the contraction is used in multiple places, fusing it will only create
    // more contraction nodes, which is slower.
    if (!contraction.getResult().hasOneUse())
      return rewriter.notifyMatchFailure(contraction,
                                         "result is used by multiple ops");

    BiasAddOp bias_add = GetBiasAdd(contraction.getResult());
    if (!bias_add) {
      return rewriter.notifyMatchFailure(
          contraction, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
    }

    // Subclass hook: op-specific fusion compatibility (e.g. matching data
    // formats, supported element types).
    if (!AreFuseCompatible(contraction, bias_add, rewriter)) {
      return rewriter.notifyMatchFailure(
          contraction, "cannot fuse with the subsequent BiasAdd op");
    }

    // Subclass hook: device-specific fusion support.
    if (!IsDeviceCompatible(contraction, bias_add, rewriter)) {
      return rewriter.notifyMatchFailure(
          contraction,
          "cannot fuse with the subsequent op as it's not supported by the "
          "target device.");
    }

    // Accumulate the locations of all ops being fused so the result carries a
    // FusedLoc, and record the fused op names (dialect prefix stripped, e.g.
    // "BiasAdd") for the `fused_ops` attribute.
    SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
    SmallVector<Attribute, 2> fused_ops{StringAttr::get(
        context, bias_add.getOperation()->getName().stripDialect())};

    // BiasAdd may or may not feed into an activation function.
    auto activation = GetActivation(bias_add);

    // If there is an activation, only fuse it if this is the only op to use the
    // result of the BiasAdd.
    bool fuse_activation = activation && bias_add.output().hasOneUse();
    Type result_type;

    // Include info about the activation function if applicable.
    if (fuse_activation) {
      locations.push_back(activation->getLoc());
      fused_ops.push_back(
          StringAttr::get(context, activation->getName().stripDialect()));
      result_type = activation->getResultTypes().front();
    } else {
      result_type = bias_add.getResult().getType();
    }

    auto fused_loc = rewriter.getFusedLoc(locations);

    // The fused contraction has the same operands as the original contraction
    // with `bias` from the BiasAddOp appended.
    SmallVector<Value, 4> operands(contraction.operand_begin(),
                                   contraction.operand_end());
    operands.push_back(bias_add.bias());

    // The fused contraction has the same attributes as the original
    // contraction, with two additions: the list of ops which have been fused
    // together; epsilon (only with FusedBatchNorm).
    std::vector<NamedAttribute> attrs = contraction->getAttrs();
    ArrayAttr fused_ops_attr = ArrayAttr::get(context, fused_ops);
    attrs.push_back(
        NamedAttribute(StringAttr::get(context, "fused_ops"), fused_ops_attr));
    // Epsilon is used only in fusions with the FusedBatchNorm op, so we zero it
    // here.
    Attribute epsilon = rewriter.getF32FloatAttr(0);
    attrs.push_back(
        NamedAttribute(StringAttr::get(context, "epsilon"), epsilon));

    // Insert fused operation right before the BiasAdd operation to guarantee
    // that bias value dominates the fused operation. We already verified that
    // original operation has a single use, so this is safe to do.
    // (`bias_add` is known non-null here; the check is purely defensive.)
    auto *bias_add_op = bias_add.getOperation();
    if (bias_add_op) rewriter.setInsertionPoint(bias_add_op);

    Value fused_op = rewriter.create<FusedOpT>(fused_loc, result_type,
                                               ValueRange(operands), attrs);
    // Replace the last op of the matched chain; the intermediate contraction
    // and BiasAdd become dead and are cleaned up by the rewrite driver.
    auto op_to_replace = fuse_activation ? activation : bias_add;
    rewriter.replaceOp(op_to_replace, ValueRange({fused_op}));
    return success();
  }
};
209 
// Name of the op attribute carrying the TF device assignment string.
const char kDeviceAttr[] = "device";
// Device type string produced by DeviceNameUtils for GPU placements.
const char kDeviceGpu[] = "GPU";
212 
GetDevice(mlir::Operation * op)213 llvm::Optional<std::string> GetDevice(mlir::Operation *op) {
214   mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
215   if (!device || device.getValue().empty()) {
216     return llvm::None;
217   }
218   const std::string device_name = device.str();
219   tensorflow::DeviceNameUtils::ParsedName parsed_name;
220   if (!tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name)) {
221     return llvm::None;
222   }
223   if (!parsed_name.has_type) {
224     return llvm::None;
225   }
226   return parsed_name.type;
227 }
228 
IsGpuDevice(mlir::Operation * op)229 bool IsGpuDevice(mlir::Operation *op) {
230   llvm::Optional<std::string> device = GetDevice(op);
231   if (!device) return false;
232   return *device == kDeviceGpu;
233 }
234 
235 // Performs a fusion of the following pattern(s), if possible:
236 //   Conv2D + BiasAdd + <Activation> -> _FusedConv2D
237 class FuseConv2DBiasAdd
238     : public FuseContractionWithBiasAdd<Conv2DOp, _FusedConv2DOp> {
239  public:
240   using FuseContractionWithBiasAdd<Conv2DOp,
241                                    _FusedConv2DOp>::FuseContractionWithBiasAdd;
242   // Verify that the Conv2D and BiasAdd data formats match. This is necessary
243   // for the ops to fuse correctly, the fused Conv2D op has one data format
244   // attribute which is shared.
AreFuseCompatible(Conv2DOp conv,BiasAddOp bias_add,PatternRewriter & rewriter) const245   bool AreFuseCompatible(Conv2DOp conv, BiasAddOp bias_add,
246                          PatternRewriter &rewriter) const override {
247     // Verify that the data formats match and are valid for fusion.
248     if (conv.data_format() != bias_add.data_format()) {
249       (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
250         diag << "data format does not match Conv2D data format ("
251              << bias_add.data_format() << " vs " << conv.data_format() << ")";
252       });
253       return false;
254     }
255     // Verify the data type is supported.
256     if (!conv.T().isF32() && !conv.T().isF64()) {
257       (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
258         diag << "supported data types for _FusedConv2D are float and double, "
259              << " but got " << conv.T();
260       });
261       return false;
262     }
263     return true;
264   }
265 
IsDeviceCompatible(Conv2DOp conv,BiasAddOp bias_add,PatternRewriter & rewriter) const266   bool IsDeviceCompatible(Conv2DOp conv, BiasAddOp bias_add,
267                           PatternRewriter &rewriter) const override {
268     // Currently, GPU only supports Conv2D+BiasAdd+Relu fusion.
269     if (IsGpuDevice(conv)) {
270       auto activation = GetActivation(bias_add);
271       if (!activation || activation->getName().stripDialect() != "Relu" ||
272           !bias_add.output().hasOneUse()) {
273         (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
274           diag << "GPU only supports Conv2D+BiasAdd+Relu fusion";
275         });
276         return false;
277       }
278     }
279     return true;
280   }
281 };
282 
283 // Performs a fusion of the following pattern(s), if possible:
284 //   MatMulOp + BiasAdd + <Activation> -> _FusedMatMulOp
285 class FuseMatMulBiasAdd
286     : public FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp> {
287   using FuseContractionWithBiasAdd<MatMulOp,
288                                    _FusedMatMulOp>::FuseContractionWithBiasAdd;
289 
AreFuseCompatible(MatMulOp matmul,BiasAddOp bias_add,PatternRewriter & rewriter) const290   bool AreFuseCompatible(MatMulOp matmul, BiasAddOp bias_add,
291                          PatternRewriter &rewriter) const override {
292     // FusedMatMul kernel supports limited set of data types.
293     if (!matmul.T().isF32() && !matmul.T().isBF16()) {
294       (void)rewriter.notifyMatchFailure(matmul, [&](Diagnostic &diag) {
295         diag << "supported data types for _FusedMatMul are float and bfloat16, "
296              << " but got " << matmul.T();
297       });
298       return false;
299     }
300     return true;
301   }
302 
IsDeviceCompatible(MatMulOp matmul,BiasAddOp bias_add,PatternRewriter & rewriter) const303   bool IsDeviceCompatible(MatMulOp matmul, BiasAddOp bias_add,
304                           PatternRewriter &rewriter) const override {
305     if (IsGpuDevice(matmul)) {
306       (void)rewriter.notifyMatchFailure(matmul, [&](Diagnostic &diag) {
307         diag << "_FusedMatMul is not supported by GPU";
308       });
309       return false;
310     }
311     return true;
312   }
313 };
314 
runOnOperation()315 void FusedKernelMatcherPass::runOnOperation() {
316   RewritePatternSet patterns(&getContext());
317   auto func = getOperation();
318   patterns.add<FuseConv2DBiasAdd, FuseMatMulBiasAdd>(&getContext());
319 
320   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
321 }
322 
323 }  // namespace
324 
// Creates an instance of the fused kernel matcher pass; exposed so the pass
// can be registered and added to function-level pass pipelines.
std::unique_ptr<OperationPass<func::FuncOp>> CreateFusedKernelMatcherPass() {
  return std::make_unique<FusedKernelMatcherPass>();
}
328 
329 }  // namespace TF
330 
331 }  // namespace mlir
332