/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdio>
#include <iostream>
#include <string>

#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
// Needed for tensorflow::DeviceNameUtils used in GetDevice() below.
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {

namespace TF {

namespace {

// Note: This implements the fusions performed in the old Remapper Grappler
// pass. That pass has cases that are specific to GPU, as well as cases that
// depend on particular target configurations on both CPU and GPU (Intel MKL,
// ROCm, etc.). This MLIR pass covers (some of) the general CPU case and at the
// moment does not account for any target-specific configurations.

// This pass is being ported over from the Grappler Remapper pass based on
// need/usage. File a bug to request porting over additional fusions.

// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.

// Optimizes TF computations by fusing subgraphs/nodes onto more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
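//
// Note: the textual name under which this pass is registered comes from the
// generated FusedKernelMatcherPassBase (declared via passes_detail.h and the
// corresponding pass .td definition). In tf-opt it is typically exposed as
// -tf-fused-kernel-matcher; the .td definition is authoritative.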
struct FusedKernelMatcherPass
    : public FusedKernelMatcherPassBase<FusedKernelMatcherPass> {
  void runOnOperation() override;
};

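// Returns true if `op` is an activation function supported for fusion
// (currently tf.Elu, tf.Relu, or tf.Relu6).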
bool IsActivationFunction(Operation *op) {
  return isa<EluOp, ReluOp, Relu6Op>(op);
}

// Finds and returns an activation op that uses the result of `op`. If there are
// multiple such activations, one is returned (with no guarantee as to which
// one). If there are no activation functions that use the output, returns
// nullptr.
Operation *GetActivation(Value op) {
  for (auto &use : op.getUses()) {
    if (IsActivationFunction(use.getOwner())) return use.getOwner();
  }
  return nullptr;
}

// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
  for (auto &use : op.getUses()) {
    auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
    // If it's a BiasAdd, check that the contraction result is its `value`
    // (first) input, not its `bias` input.
    if (bias_add && bias_add.value() == op) return bias_add;
  }
  // No BiasAddOps found among uses.
  return BiasAddOp();
}

// Performs a fusion of the following pattern(s), if possible:
//   <Contraction> + BiasAdd + <Activation> -> <FusedContraction>
//
// Note that fusion with an activation is preferred, but a contraction and
// BiasAdd can also be replaced by a fused contraction on its own if no
// suitable activation function follows. i.e., this class also supports the
// following fusion:
//   <Contraction> + BiasAdd -> <FusedContraction>
//
// TODO(b/158266331): Support fusing activation chains of arbitrary length.
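//
// For illustration only (a sketch: types and most attributes are elided, and
// value names are hypothetical), the Conv2D instantiation of this pattern
// rewrites roughly
//
//   %0 = "tf.Conv2D"(%input, %filter) {data_format = "NHWC", ...}
//   %1 = "tf.BiasAdd"(%0, %bias) {data_format = "NHWC"}
//   %2 = "tf.Relu"(%1)
//
// into a single fused op that carries the names of the fused ops and a zeroed
// `epsilon` attribute:
//
//   %2 = "tf._FusedConv2D"(%input, %filter, %bias)
//            {fused_ops = ["BiasAdd", "Relu"], epsilon = 0.0 : f32, ...}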
template <typename SrcOpT, typename FusedOpT>
class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
 public:
  using OpRewritePattern<SrcOpT>::OpRewritePattern;
  // Class users should override this method if there are any op-specific
  // compatibility requirements between the contraction op and the BiasAdd op.
  virtual bool AreFuseCompatible(SrcOpT contraction_op, BiasAddOp bias_add,
                                 PatternRewriter &rewriter) const {
    return true;
  }

  // Class users should override this method if there are any op-specific
  // compatibility requirements for devices.
  virtual bool IsDeviceCompatible(SrcOpT contraction_op, BiasAddOp bias_add,
                                  PatternRewriter &rewriter) const {
    return true;
  }

  LogicalResult matchAndRewrite(SrcOpT contraction,
                                PatternRewriter &rewriter) const override {
    auto context = rewriter.getContext();

    // We support fusion only if the contraction operation is inside one of the
    // expected operations with regions. Other operations can have semantics
    // that are not compatible with fusion (e.g. region compilation).
    if (!isa<func::FuncOp, IfOp, WhileOp>(contraction->getParentOp())) {
      return rewriter.notifyMatchFailure(
          contraction,
          "fused operation must be nested inside a function, If or While");
    }

    // If the contraction is used in multiple places, fusing it will only
    // create more contraction nodes, which is slower.
    if (!contraction.getResult().hasOneUse())
      return rewriter.notifyMatchFailure(contraction,
                                         "result is used by multiple ops");

    BiasAddOp bias_add = GetBiasAdd(contraction.getResult());
    if (!bias_add) {
      return rewriter.notifyMatchFailure(
          contraction, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
    }

    if (!AreFuseCompatible(contraction, bias_add, rewriter)) {
      return rewriter.notifyMatchFailure(
          contraction, "cannot fuse with the subsequent BiasAdd op");
    }

    if (!IsDeviceCompatible(contraction, bias_add, rewriter)) {
      return rewriter.notifyMatchFailure(
          contraction,
          "cannot fuse with the subsequent op as it's not supported by the "
          "target device.");
    }

    SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
    SmallVector<Attribute, 2> fused_ops{StringAttr::get(
        context, bias_add.getOperation()->getName().stripDialect())};

    // BiasAdd may or may not feed into an activation function.
    auto activation = GetActivation(bias_add);

    // If there is an activation, only fuse it if this is the only op to use the
    // result of the BiasAdd.
    bool fuse_activation = activation && bias_add.output().hasOneUse();
    Type result_type;

    // Include info about the activation function if applicable.
    if (fuse_activation) {
      locations.push_back(activation->getLoc());
      fused_ops.push_back(
          StringAttr::get(context, activation->getName().stripDialect()));
      result_type = activation->getResultTypes().front();
    } else {
      result_type = bias_add.getResult().getType();
    }

    auto fused_loc = rewriter.getFusedLoc(locations);

    // The fused contraction has the same operands as the original contraction
    // with `bias` from the BiasAddOp appended.
    SmallVector<Value, 4> operands(contraction.operand_begin(),
                                   contraction.operand_end());
    operands.push_back(bias_add.bias());

    // The fused contraction has the same attributes as the original
    // contraction, with two additions: the list of ops that have been fused
    // together, and an `epsilon` attribute (used only when a FusedBatchNorm is
    // fused).
    std::vector<NamedAttribute> attrs = contraction->getAttrs();
    ArrayAttr fused_ops_attr = ArrayAttr::get(context, fused_ops);
    attrs.push_back(
        NamedAttribute(StringAttr::get(context, "fused_ops"), fused_ops_attr));
    // Epsilon is used only in fusions with the FusedBatchNorm op, so we zero it
    // here.
    Attribute epsilon = rewriter.getF32FloatAttr(0);
    attrs.push_back(
        NamedAttribute(StringAttr::get(context, "epsilon"), epsilon));

    // Insert the fused operation right before the BiasAdd operation to
    // guarantee that the bias value dominates the fused operation. We already
    // verified that the original operation has a single use, so this is safe
    // to do.
    auto *bias_add_op = bias_add.getOperation();
    if (bias_add_op) rewriter.setInsertionPoint(bias_add_op);

    Value fused_op = rewriter.create<FusedOpT>(fused_loc, result_type,
                                               ValueRange(operands), attrs);
    auto op_to_replace = fuse_activation ? activation : bias_add;
    rewriter.replaceOp(op_to_replace, ValueRange({fused_op}));
    return success();
  }
};

const char kDeviceAttr[] = "device";
const char kDeviceGpu[] = "GPU";

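// Returns the device type from the op's `device` attribute, e.g. "GPU" for
// "/job:localhost/replica:0/task:0/device:GPU:0" (the example name is
// illustrative). Returns llvm::None if the attribute is missing, empty, or
// cannot be parsed into a device type.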
llvm::Optional<std::string> GetDevice(mlir::Operation *op) {
  mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
  if (!device || device.getValue().empty()) {
    return llvm::None;
  }
  const std::string device_name = device.str();
  tensorflow::DeviceNameUtils::ParsedName parsed_name;
  if (!tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name)) {
    return llvm::None;
  }
  if (!parsed_name.has_type) {
    return llvm::None;
  }
  return parsed_name.type;
}

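// Returns true if the op is explicitly placed on a GPU device (based on its
// `device` attribute).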
bool IsGpuDevice(mlir::Operation *op) {
  llvm::Optional<std::string> device = GetDevice(op);
  if (!device) return false;
  return *device == kDeviceGpu;
}

// Performs a fusion of the following pattern(s), if possible:
//   Conv2D + BiasAdd + <Activation> -> _FusedConv2D
class FuseConv2DBiasAdd
    : public FuseContractionWithBiasAdd<Conv2DOp, _FusedConv2DOp> {
 public:
  using FuseContractionWithBiasAdd<Conv2DOp,
                                   _FusedConv2DOp>::FuseContractionWithBiasAdd;
  // Verify that the Conv2D and BiasAdd data formats match. This is necessary
  // for the ops to fuse correctly: the fused Conv2D op has a single, shared
  // data format attribute.
  bool AreFuseCompatible(Conv2DOp conv, BiasAddOp bias_add,
                         PatternRewriter &rewriter) const override {
    // Verify that the data formats match and are valid for fusion.
    if (conv.data_format() != bias_add.data_format()) {
      (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
        diag << "BiasAdd data format does not match Conv2D data format ("
             << bias_add.data_format() << " vs " << conv.data_format() << ")";
      });
      return false;
    }
    // Verify that the data type is supported.
    if (!conv.T().isF32() && !conv.T().isF64()) {
      (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
        diag << "supported data types for _FusedConv2D are float and double, "
             << "but got " << conv.T();
      });
      return false;
    }
    return true;
  }

  bool IsDeviceCompatible(Conv2DOp conv, BiasAddOp bias_add,
                          PatternRewriter &rewriter) const override {
    // Currently, GPU only supports Conv2D+BiasAdd+Relu fusion.
    if (IsGpuDevice(conv)) {
      auto activation = GetActivation(bias_add);
      if (!activation || activation->getName().stripDialect() != "Relu" ||
          !bias_add.output().hasOneUse()) {
        (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
          diag << "GPU only supports Conv2D+BiasAdd+Relu fusion";
        });
        return false;
      }
    }
    return true;
  }
};

// Performs a fusion of the following pattern(s), if possible:
//   MatMul + BiasAdd + <Activation> -> _FusedMatMul
class FuseMatMulBiasAdd
    : public FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp> {
  using FuseContractionWithBiasAdd<MatMulOp,
                                   _FusedMatMulOp>::FuseContractionWithBiasAdd;

  bool AreFuseCompatible(MatMulOp matmul, BiasAddOp bias_add,
                         PatternRewriter &rewriter) const override {
    // The FusedMatMul kernel supports a limited set of data types.
    if (!matmul.T().isF32() && !matmul.T().isBF16()) {
      (void)rewriter.notifyMatchFailure(matmul, [&](Diagnostic &diag) {
        diag << "supported data types for _FusedMatMul are float and bfloat16, "
             << "but got " << matmul.T();
      });
      return false;
    }
    return true;
  }

  bool IsDeviceCompatible(MatMulOp matmul, BiasAddOp bias_add,
                          PatternRewriter &rewriter) const override {
    if (IsGpuDevice(matmul)) {
      (void)rewriter.notifyMatchFailure(matmul, [&](Diagnostic &diag) {
        diag << "_FusedMatMul is not supported by GPU";
      });
      return false;
    }
    return true;
  }
};

void FusedKernelMatcherPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  auto func = getOperation();
  patterns.add<FuseConv2DBiasAdd, FuseMatMulBiasAdd>(&getContext());

  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

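// Creates an instance of the pass to run over a function. A minimal usage
// sketch (illustrative only; `context` and `module` are assumed to exist in
// the caller):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::TF::CreateFusedKernelMatcherPass());
//   if (failed(pm.run(module))) { /* handle failure */ }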
std::unique_ptr<OperationPass<func::FuncOp>> CreateFusedKernelMatcherPass() {
  return std::make_unique<FusedKernelMatcherPass>();
}

}  // namespace TF

}  // namespace mlir