1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/None.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
27 #include "mlir/IR/Attributes.h" // from @llvm-project
28 #include "mlir/IR/Builders.h" // from @llvm-project
29 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
32 #include "mlir/IR/Location.h" // from @llvm-project
33 #include "mlir/IR/MLIRContext.h" // from @llvm-project
34 #include "mlir/IR/Operation.h" // from @llvm-project
35 #include "mlir/IR/SymbolTable.h" // from @llvm-project
36 #include "mlir/IR/Visitors.h" // from @llvm-project
37 #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
38 #include "mlir/Pass/Pass.h" // from @llvm-project
39 #include "mlir/Support/LLVM.h" // from @llvm-project
40 #include "mlir/Support/LogicalResult.h" // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
42 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
43 #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
44 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
45 #include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h"
46 #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
49
50 namespace mlir {
51 namespace TFL {
52 namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

// Attribute name looked up (as a StringAttr) to find Keras-style composite
// functions, e.g. "lstm_<hash>".
constexpr char kTFAPIImplements[] = "tf.api_implements";
// Prefix of `tf.api_implements` values handled by the TF.Text fusion path.
constexpr char kTFTextAPIPrefix[] = "tftext:";
// `tf._implements` value for the SSD post-processing composite function.
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
// `tf._implements` value for padded non-max-suppression.
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
// `tf._implements` value for the TF Addons MaxUnpooling2D composite.
constexpr char kCustomMaxUnpooling[] = "addons:MaxUnpooling2D";
// `tf._implements` value for the TF Addons DenseImageWarp composite.
constexpr char kCustomDenseImageWarp[] = "addons:DenseImageWarp";
// Dictionary key that marks a function as fusable into a generic
// tfl.custom op (see ConvertTflFusableOp).
constexpr char kTFLFusableOp[] = "tfl_fusable_op";

using mlir::TF::FuncAttr;
65
CustomOption(OpBuilder * builder,const std::string & content)66 inline ConstBytesAttr CustomOption(OpBuilder* builder,
67 const std::string& content) {
68 return ConstBytesAttr::get(builder->getContext(),
69 StringRef(content.data(), content.size()));
70 }
71
CreateTflFusableOpCustomOptions(ArrayRef<std::pair<StringRef,Attribute>> attrs,OpBuilder * builder,std::string & custom_option_buffer)72 LogicalResult CreateTflFusableOpCustomOptions(
73 ArrayRef<std::pair<StringRef, Attribute>> attrs, OpBuilder* builder,
74 std::string& custom_option_buffer) {
75 // There is something worth noting in the ordering of the custom op option:
76 // At the MLIR level, all the option is ordered alphabetcially, so there is
77 // no way for us to retrieve the original order, so please make sure you are
78 // reading custom option from dictionary rather than depending on the order.
79 flexbuffers::Builder fbb;
80 size_t start_map = fbb.StartMap();
81
82 for (auto attr : attrs) {
83 if (auto float_attr = attr.second.dyn_cast_or_null<FloatAttr>()) {
84 fbb.Float(attr.first.data(), float_attr.getValue().convertToFloat());
85 } else if (auto int_attr = attr.second.dyn_cast_or_null<IntegerAttr>()) {
86 fbb.Int(attr.first.data(), int_attr.getInt());
87 } else if (auto bool_attr = attr.second.dyn_cast_or_null<BoolAttr>()) {
88 fbb.Bool(attr.first.data(), bool_attr.getValue());
89 } else {
90 // TODO(b/201482289): support other data types.
91 return failure();
92 }
93 }
94
95 fbb.EndMap(start_map);
96 fbb.Finish();
97 custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
98 return success();
99 }
100
101 // Convert func annotated with `tfl_fusable_op` attribute to tfl custom op.
ConvertTflFusableOp(func::FuncOp func,StringRef custom_op_name,ArrayRef<std::pair<StringRef,Attribute>> attrs)102 LogicalResult ConvertTflFusableOp(
103 func::FuncOp func, StringRef custom_op_name,
104 ArrayRef<std::pair<StringRef, Attribute>> attrs) {
105 func.eraseBody();
106 func.addEntryBlock();
107
108 OpBuilder builder(func.getBody());
109 std::string custom_option_buffer;
110 if (failed(CreateTflFusableOpCustomOptions(attrs, &builder,
111 custom_option_buffer))) {
112 return failure();
113 }
114
115 auto tfl_fusable_op = builder.create<TFL::CustomOp>(
116 func->getLoc(), func.getFunctionType().getResults(), func.getArguments(),
117 custom_op_name, CustomOption(&builder, custom_option_buffer));
118 builder.create<func::ReturnOp>(func->getLoc(), tfl_fusable_op.getResults());
119 return success();
120 }
121
122 // Abstracts the conversion of the embedded lookup composite function.
// Abstracts the conversion of the embedded lookup composite function: the
// composite function's body is replaced by a single tfl.embedding_lookup op.
class ConvertEmbeddedLookupFunc {
 public:
  explicit ConvertEmbeddedLookupFunc(func::FuncOp func) : func_(func) {}

  // Rewrites the (already re-created, empty) function body to an
  // EmbeddingLookupOp + return, and tags the function with
  // tf._implements = "embedding_lookup".
  void RewriteFunc() {
    func_->setAttr(kTFImplements,
                   StringAttr::get(func_.getContext(), "embedding_lookup"));
    // Argument 1 feeds the lookup operand and argument 0 the value operand
    // of the EmbeddingLookupOp.
    Value lookup = func_.getArgument(1);
    Value value = func_.getArgument(0);
    auto output_type = func_.getFunctionType().getResult(0);

    OpBuilder builder(func_.getBody());
    auto op = builder.create<mlir::TFL::EmbeddingLookupOp>(
        func_.getLoc(), output_type, lookup, value);

    builder.create<mlir::func::ReturnOp>(func_.getLoc(), op.getResult());
  }

  // Verifies the composite signature: exactly 2 arguments and 1 result.
  // Emits a warning (and returns failure) otherwise.
  LogicalResult VerifySignature() {
    if (func_.getNumArguments() != 2) {
      return func_.emitWarning()
             << "Invalid number of arguments in the embedding "
                "matmul composite function";
    }
    if (func_.getFunctionType().getNumResults() != 1) {
      return func_.emitWarning() << "Invalid number of results in the "
                                    "embedding matmul composite function";
    }
    return success();
  }

 private:
  func::FuncOp func_;
};
157
// Pass that rewrites composite functions — functions annotated with
// tf._implements or tf.api_implements — into fused TFLite ops.
class PrepareCompositeFunctionsPass
    : public PrepareCompositeFunctionsPassBase<PrepareCompositeFunctionsPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    // The rewrites create TFL ops, so the TFLite dialect must be loaded.
    registry.insert<TFL::TensorFlowLiteDialect>();
  }

 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareCompositeFunctionsPass)

  explicit PrepareCompositeFunctionsPass() {}

 private:
  // TODO(b/160915525): Consolidate FuncAttr and StringAttr into one.
  // Handles tf._implements carried as a plain string attribute.
  void ConvertTFImplements(func::FuncOp func, StringAttr attr);
  // Handles tf._implements carried as a FuncAttr (with an attribute dict).
  void ConvertTFImplementsWithAttributes(func::FuncOp func, FuncAttr attr);
  // Handles tf.api_implements (currently Keras LSTM fusion only).
  void ConvertTFAPIImplements(func::FuncOp func, StringAttr attr,
                              ModuleOp module);
  void runOnOperation() override;
};
177
CheckFusableLayerNormalizedLstmCellSimple(func::FuncOp lstm_func)178 LogicalResult CheckFusableLayerNormalizedLstmCellSimple(
179 func::FuncOp lstm_func) {
180 for (int i = 0; i < 5; ++i) {
181 auto input = lstm_func.getArgument(i);
182 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
183 if (!input_type) {
184 lstm_func.emitWarning(
185 "we cannot fuse this lstm func because all the inputs have not "
186 "ranked tensor type.");
187 return failure();
188 }
189 }
190
191 return success();
192 }
193
CheckFusableLstmCellSimple(func::FuncOp lstm_func)194 LogicalResult CheckFusableLstmCellSimple(func::FuncOp lstm_func) {
195 for (int i = 0; i < 4; ++i) {
196 auto input = lstm_func.getArgument(i);
197 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
198 if (!input_type) {
199 lstm_func.emitWarning(
200 "we cannot fuse this lstm func because all the inputs have not "
201 "ranked tensor type.");
202 return failure();
203 }
204 }
205
206 return success();
207 }
208
CheckOutputConsumer(Operation * call_op,int expected_num_outputs,llvm::DenseSet<int> expected_consumer_indices)209 LogicalResult CheckOutputConsumer(
210 Operation* call_op, int expected_num_outputs,
211 llvm::DenseSet<int> expected_consumer_indices) {
212 const int num_results = call_op->getNumResults();
213 if (num_results != expected_num_outputs) return failure();
214
215 for (int i = 0; i < expected_num_outputs; ++i) {
216 auto it = expected_consumer_indices.find(i);
217 if (it == expected_consumer_indices.end()) {
218 // Unexpected consumer.
219 if (!call_op->getResult(i).use_empty()) return failure();
220 }
221 }
222 return success();
223 }
224
// Returns success if the Keras LSTM composite `lstm_func` is safe to fuse:
// all call sites consume at most outputs 0/1, there is no mask input, the
// first six inputs are ranked, and the state/weight inputs have static shape.
LogicalResult CheckFusableKerasLstm(func::FuncOp lstm_func, ModuleOp module) {
  // Walk every other function in the module looking for calls to `lstm_func`.
  for (auto func : module.getOps<func::FuncOp>()) {
    if (func == lstm_func) continue;
    auto result = func.walk([&](CallOpInterface op) {
      if (dyn_cast<func::FuncOp>(op.resolveCallable()) == lstm_func) {
        // Keras LSTM have 5 outputs.
        // We should make sure only the first or the second output are
        // consumed.
        if (failed(CheckOutputConsumer(op.getOperation(), 5, {0, 1})))
          return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });

    // An interrupted walk means some call site consumes a disallowed output.
    if (result.wasInterrupted()) return failure();
  }
  // Current UnidirectionalSequenceLSTMOp doesn't support mask input.
  // NOTE(review): a 7th argument is taken to be the mask — confirm against
  // the Keras LSTM composite signature.
  if (lstm_func.getNumArguments() == 7) return failure();

  // We should know the batch size in advance for the lstm fusion.
  // A good indicator of batch size is both cell state and input state (indices
  // 1 & 2) have fixed shape and other input tensors should have ranked tensor
  // types.
  for (int i = 0; i < 6; ++i) {
    auto input = lstm_func.getArgument(i);
    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!input_type) {
      lstm_func.emitWarning(
          "we cannot fuse this lstm func because all the inputs have not "
          "ranked tensor type.");
      return failure();
    }
    switch (i) {
      case 1:  // output_init_state
      case 2:  // hidden_init_state
        // The init states carry the batch dimension, so they must be static.
        if (!input_type.hasStaticShape()) {
          lstm_func.emitWarning(
              "we cannot fuse this lstm func because the batch size is not "
              "fixed, please consider setting fixed batch size like "
              "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
              "lite/examples/experimental_new_converter/"
              "Keras_LSTM_fusion_Codelab.ipynb");
          return failure();
        }
        break;
      case 3:  // weight
      case 4:  // recurrent_kernel
      case 5:  // bias
        // Weights and biases must also be fully static to fuse.
        if (!input_type.hasStaticShape()) {
          lstm_func.emitWarning(
              "we cannot fuse this lstm func because the weight & bias are not "
              "fixed, please consider setting fixed batch size like "
              "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
              "lite/examples/experimental_new_converter/"
              "Keras_LSTM_fusion_Codelab.ipynb");
          return failure();
        }
        break;
      default:
        // No op.
        break;
    }
  }

  return success();
}
291
// Dispatches on the string value of a tf._implements attribute and rewrites
// the composite function into the corresponding fused TFLite op. Functions
// that fail their fusability/signature check are left untouched; rewrite
// failures abort the pass via signalPassFailure().
void PrepareCompositeFunctionsPass::ConvertTFImplements(func::FuncOp func,
                                                        StringAttr attr) {
  if (attr.getValue() == "embedding_matmul") {
    // Convert the composite embedding_matmul function body to a
    // TFLite fused embedding_lookup op.
    ConvertEmbeddedLookupFunc convert_embedded_lookup(func);
    if (failed(convert_embedded_lookup.VerifySignature())) return;
    func.eraseBody();
    func.addEntryBlock();
    convert_embedded_lookup.RewriteFunc();
  } else if (attr.getValue() == mlir::TFL::kLstmCellSimple) {
    // Check if the lstm cell simple can be fused, if not, we just don't do
    // anything.
    if (failed(CheckFusableLstmCellSimple(func))) return;
    func.eraseBody();
    func.addEntryBlock();
    ConvertLSTMCellSimpleToFusedLSTM convert_lstm_cell_simple(func);
    if (failed(convert_lstm_cell_simple.RewriteFunc())) {
      return signalPassFailure();
    }
  } else if (attr.getValue() == mlir::TFL::kLayerNormalizedLstmCellSimple) {
    // Check if the layer normalized lstm cell simple can be fused, if not, we
    // just don't do anything.
    if (failed(CheckFusableLayerNormalizedLstmCellSimple(func))) return;
    func.eraseBody();
    func.addEntryBlock();
    ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
        convert_layer_norm_lstm_cell_simple(func);
    if (failed(convert_layer_norm_lstm_cell_simple.RewriteFunc())) {
      return signalPassFailure();
    }
  } else if (attr.getValue() == kTfNMSPadded) {
    // Padded non-max-suppression: verify, then rebuild the body.
    ConvertNMSPaddedFunc convert_nms_padded(func);
    if (failed(convert_nms_padded.VerifySignature())) return;
    func.eraseBody();
    func.addEntryBlock();
    convert_nms_padded.RewriteFunc();
  } else if (attr.getValue() == kCustomDenseImageWarp) {
    // DenseImageWarp rewrites in place; no eraseBody/addEntryBlock here.
    ConvertDenseImageWarpFunc image_warping(func);
    if (failed(image_warping.VerifySignature())) return;
    if (failed(image_warping.RewriteFunc())) {
      return signalPassFailure();
    }
  }
}
337
ConvertTFImplementsWithAttributes(func::FuncOp func,FuncAttr attr)338 void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes(
339 func::FuncOp func, FuncAttr attr) {
340 StringRef api_name = attr.getName().getLeafReference().getValue();
341 bool enable_fuse_tftext =
342 tfl_fuse_tftext_ || IsTFTextRegistered(tensorflow::OpRegistry::Global());
343 if (api_name.startswith(kTFTextAPIPrefix) && enable_fuse_tftext) {
344 if (failed(ConvertTFTextAPI(func, api_name, attr))) {
345 return signalPassFailure();
346 }
347 } else if (api_name == kCustomSSDPostprocessing) {
348 ConvertSSDPostProcessFunc convert_ssd_postprocess(func, attr);
349 if (failed(convert_ssd_postprocess.VerifySignature())) return;
350 if (failed(convert_ssd_postprocess.RewriteFunc())) {
351 return signalPassFailure();
352 }
353 } else if (api_name == kCustomMaxUnpooling) {
354 ConvertMaxUnpoolingFunc max_unpooling(func, attr);
355 if (failed(max_unpooling.VerifySignature())) return;
356 if (failed(max_unpooling.RewriteFunc())) {
357 return signalPassFailure();
358 }
359 } else {
360 // We will look for the `tfl_fusable_op` attribute and fuse as a custom op.
361 DictionaryAttr dict_attr = attr.getAttrs();
362
363 SmallVector<std::pair<StringRef, Attribute>, 4> attributes;
364 bool tfl_fusable_op = false;
365 for (auto attr_item : dict_attr) {
366 // Push other attributes except the TFLFusableOp.
367 if (attr_item.getName() == kTFLFusableOp &&
368 attr_item.getValue().dyn_cast<BoolAttr>().getValue()) {
369 tfl_fusable_op = true;
370 } else {
371 attributes.push_back({attr_item.getName(), attr_item.getValue()});
372 }
373 }
374
375 if (!tfl_fusable_op) return;
376
377 if (failed(ConvertTflFusableOp(func, api_name, attributes))) {
378 func->emitError(absl::StrCat("failed to fuse for op: ", api_name.str()));
379 return signalPassFailure();
380 }
381 }
382 }
383
ConvertTFAPIImplements(func::FuncOp func,StringAttr attr,ModuleOp module)384 void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(func::FuncOp func,
385 StringAttr attr,
386 ModuleOp module) {
387 // Keras lstm tf.api_implements usually has attribute like "lstm_abcde91...".
388 // TODO(b/147436982): we need to make sure that only the
389 // outputs(full sequence) is used, not the last_output, not the new_states.
390 // We will discard everything except the outputs.
391 // And the outputs is in the shape of [batch, time, units].
392 if (attr.getValue().startswith("lstm_")) {
393 // Check if the keras lstm can be fused, if not, we just don't do anything.
394 if (failed(CheckFusableKerasLstm(func, module))) return;
395 func.eraseBody();
396 func.addEntryBlock();
397 OpBuilder builder(func.getBody());
398 if (failed(ConvertKerasLSTMLayer(func, &builder)))
399 return signalPassFailure();
400 }
401 }
402
runOnOperation()403 void PrepareCompositeFunctionsPass::runOnOperation() {
404 auto module = getOperation();
405 for (auto func : module.getOps<func::FuncOp>()) {
406 // We have three kinds of implements:
407 // 1) tf._implements, with string attributes.
408 // 2) tf._implements, with proto attributes.
409 // 3) tf.api_implements.
410 // We need to handle them separately.
411 auto tf_implements_attr_str =
412 func->getAttrOfType<StringAttr>(kTFImplements);
413 if (tf_implements_attr_str) {
414 ConvertTFImplements(func, tf_implements_attr_str);
415 continue;
416 }
417
418 auto tf_implements_attr = func->getAttrOfType<FuncAttr>(kTFImplements);
419 if (tf_implements_attr) {
420 ConvertTFImplementsWithAttributes(func, tf_implements_attr);
421 continue;
422 }
423
424 auto tf_api_implements_attr =
425 func->getAttrOfType<StringAttr>(kTFAPIImplements);
426 if (tf_api_implements_attr) {
427 // TODO(b/147536816): Keras lstm should set up the correct attributes.
428 ConvertTFAPIImplements(func, tf_api_implements_attr, module);
429 }
430 }
431 }
432 } // namespace
433
// Creates an instance of the TFLite dialect PrepareCompositeFunctions pass.
std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
  return std::make_unique<PrepareCompositeFunctionsPass>();
}
437
438 } // namespace TFL
439 } // namespace mlir
440