1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/None.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
32 #include "mlir/IR/Location.h"  // from @llvm-project
33 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
34 #include "mlir/IR/Operation.h"  // from @llvm-project
35 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
36 #include "mlir/IR/Visitors.h"  // from @llvm-project
37 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
38 #include "mlir/Pass/Pass.h"  // from @llvm-project
39 #include "mlir/Support/LLVM.h"  // from @llvm-project
40 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
42 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
43 #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
44 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
45 #include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h"
46 #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
49 
50 namespace mlir {
51 namespace TFL {
52 namespace {
53 #define GEN_PASS_CLASSES
54 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
55 
// Function attribute naming the Keras/API composite this function implements.
constexpr char kTFAPIImplements[] = "tf.api_implements";
// Prefix used by TF.Text composite function API names.
constexpr char kTFTextAPIPrefix[] = "tftext:";
// API name of the SSD post-processing composite (becomes a TFL custom op).
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
// API name of the padded non-max-suppression composite.
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
// API name of the TF Addons MaxUnpooling2D composite.
constexpr char kCustomMaxUnpooling[] = "addons:MaxUnpooling2D";
// API name of the TF Addons DenseImageWarp composite.
constexpr char kCustomDenseImageWarp[] = "addons:DenseImageWarp";
// Boolean attribute marking a composite to be fused as a generic TFL custom
// op (see ConvertTflFusableOp below).
constexpr char kTFLFusableOp[] = "tfl_fusable_op";

using mlir::TF::FuncAttr;
65 
CustomOption(OpBuilder * builder,const std::string & content)66 inline ConstBytesAttr CustomOption(OpBuilder* builder,
67                                    const std::string& content) {
68   return ConstBytesAttr::get(builder->getContext(),
69                              StringRef(content.data(), content.size()));
70 }
71 
CreateTflFusableOpCustomOptions(ArrayRef<std::pair<StringRef,Attribute>> attrs,OpBuilder * builder,std::string & custom_option_buffer)72 LogicalResult CreateTflFusableOpCustomOptions(
73     ArrayRef<std::pair<StringRef, Attribute>> attrs, OpBuilder* builder,
74     std::string& custom_option_buffer) {
75   // There is something worth noting in the ordering of the custom op option:
76   // At the MLIR level, all the option is ordered alphabetcially, so there is
77   // no way for us to retrieve the original order, so please make sure you are
78   // reading custom option from dictionary rather than depending on the order.
79   flexbuffers::Builder fbb;
80   size_t start_map = fbb.StartMap();
81 
82   for (auto attr : attrs) {
83     if (auto float_attr = attr.second.dyn_cast_or_null<FloatAttr>()) {
84       fbb.Float(attr.first.data(), float_attr.getValue().convertToFloat());
85     } else if (auto int_attr = attr.second.dyn_cast_or_null<IntegerAttr>()) {
86       fbb.Int(attr.first.data(), int_attr.getInt());
87     } else if (auto bool_attr = attr.second.dyn_cast_or_null<BoolAttr>()) {
88       fbb.Bool(attr.first.data(), bool_attr.getValue());
89     } else {
90       // TODO(b/201482289): support other data types.
91       return failure();
92     }
93   }
94 
95   fbb.EndMap(start_map);
96   fbb.Finish();
97   custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
98   return success();
99 }
100 
101 // Convert func annotated with `tfl_fusable_op` attribute to tfl custom op.
ConvertTflFusableOp(func::FuncOp func,StringRef custom_op_name,ArrayRef<std::pair<StringRef,Attribute>> attrs)102 LogicalResult ConvertTflFusableOp(
103     func::FuncOp func, StringRef custom_op_name,
104     ArrayRef<std::pair<StringRef, Attribute>> attrs) {
105   func.eraseBody();
106   func.addEntryBlock();
107 
108   OpBuilder builder(func.getBody());
109   std::string custom_option_buffer;
110   if (failed(CreateTflFusableOpCustomOptions(attrs, &builder,
111                                              custom_option_buffer))) {
112     return failure();
113   }
114 
115   auto tfl_fusable_op = builder.create<TFL::CustomOp>(
116       func->getLoc(), func.getFunctionType().getResults(), func.getArguments(),
117       custom_op_name, CustomOption(&builder, custom_option_buffer));
118   builder.create<func::ReturnOp>(func->getLoc(), tfl_fusable_op.getResults());
119   return success();
120 }
121 
122 // Abstracts the conversion of the embedded lookup composite function.
123 class ConvertEmbeddedLookupFunc {
124  public:
ConvertEmbeddedLookupFunc(func::FuncOp func)125   explicit ConvertEmbeddedLookupFunc(func::FuncOp func) : func_(func) {}
126 
RewriteFunc()127   void RewriteFunc() {
128     func_->setAttr(kTFImplements,
129                    StringAttr::get(func_.getContext(), "embedding_lookup"));
130     Value lookup = func_.getArgument(1);
131     Value value = func_.getArgument(0);
132     auto output_type = func_.getFunctionType().getResult(0);
133 
134     OpBuilder builder(func_.getBody());
135     auto op = builder.create<mlir::TFL::EmbeddingLookupOp>(
136         func_.getLoc(), output_type, lookup, value);
137 
138     builder.create<mlir::func::ReturnOp>(func_.getLoc(), op.getResult());
139   }
140 
VerifySignature()141   LogicalResult VerifySignature() {
142     if (func_.getNumArguments() != 2) {
143       return func_.emitWarning()
144              << "Invalid number of arguments in the embedding "
145                 "matmul composite function";
146     }
147     if (func_.getFunctionType().getNumResults() != 1) {
148       return func_.emitWarning() << "Invalid number of results in the "
149                                     "embedding matmul composite function";
150     }
151     return success();
152   }
153 
154  private:
155   func::FuncOp func_;
156 };
157 
// Pass that rewrites "composite" functions -- functions annotated with
// `tf._implements` or `tf.api_implements` -- into fused TFLite ops
// (embedding lookup, LSTM variants, NMS, SSD post-processing, custom ops).
class PrepareCompositeFunctionsPass
    : public PrepareCompositeFunctionsPassBase<PrepareCompositeFunctionsPass> {
  // The rewrites create TFL ops, so the TFL dialect must be registered.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<TFL::TensorFlowLiteDialect>();
  }

 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareCompositeFunctionsPass)

  explicit PrepareCompositeFunctionsPass() {}

 private:
  // TODO(b/160915525): Consolidate FuncAttr and StringAttr into one.
  // Handles `tf._implements` carrying a plain string attribute.
  void ConvertTFImplements(func::FuncOp func, StringAttr attr);
  // Handles `tf._implements` carrying a FuncAttr (name + attribute dict).
  void ConvertTFImplementsWithAttributes(func::FuncOp func, FuncAttr attr);
  // Handles `tf.api_implements` (currently only Keras LSTM, "lstm_...").
  void ConvertTFAPIImplements(func::FuncOp func, StringAttr attr,
                              ModuleOp module);
  void runOnOperation() override;
};
177 
CheckFusableLayerNormalizedLstmCellSimple(func::FuncOp lstm_func)178 LogicalResult CheckFusableLayerNormalizedLstmCellSimple(
179     func::FuncOp lstm_func) {
180   for (int i = 0; i < 5; ++i) {
181     auto input = lstm_func.getArgument(i);
182     auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
183     if (!input_type) {
184       lstm_func.emitWarning(
185           "we cannot fuse this lstm func because all the inputs have not "
186           "ranked tensor type.");
187       return failure();
188     }
189   }
190 
191   return success();
192 }
193 
CheckFusableLstmCellSimple(func::FuncOp lstm_func)194 LogicalResult CheckFusableLstmCellSimple(func::FuncOp lstm_func) {
195   for (int i = 0; i < 4; ++i) {
196     auto input = lstm_func.getArgument(i);
197     auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
198     if (!input_type) {
199       lstm_func.emitWarning(
200           "we cannot fuse this lstm func because all the inputs have not "
201           "ranked tensor type.");
202       return failure();
203     }
204   }
205 
206   return success();
207 }
208 
CheckOutputConsumer(Operation * call_op,int expected_num_outputs,llvm::DenseSet<int> expected_consumer_indices)209 LogicalResult CheckOutputConsumer(
210     Operation* call_op, int expected_num_outputs,
211     llvm::DenseSet<int> expected_consumer_indices) {
212   const int num_results = call_op->getNumResults();
213   if (num_results != expected_num_outputs) return failure();
214 
215   for (int i = 0; i < expected_num_outputs; ++i) {
216     auto it = expected_consumer_indices.find(i);
217     if (it == expected_consumer_indices.end()) {
218       // Unexpected consumer.
219       if (!call_op->getResult(i).use_empty()) return failure();
220     }
221   }
222   return success();
223 }
224 
// Checks whether the Keras LSTM composite `lstm_func` can be fused into a
// TFL UnidirectionalSequenceLSTM op. Fusion is allowed only when:
//   * every caller in `module` consumes at most outputs 0 and 1 of the 5
//     LSTM outputs,
//   * the function does not take the 7-argument (mask) form,
//   * the first 6 inputs have ranked tensor types, and
//   * the init-state, weight and bias inputs have fully static shapes.
LogicalResult CheckFusableKerasLstm(func::FuncOp lstm_func, ModuleOp module) {
  for (auto func : module.getOps<func::FuncOp>()) {
    if (func == lstm_func) continue;
    auto result = func.walk([&](CallOpInterface op) {
      if (dyn_cast<func::FuncOp>(op.resolveCallable()) == lstm_func) {
        // Keras LSTM have 5 outputs.
        // We should make sure only the first or the second output are
        // consumed.
        if (failed(CheckOutputConsumer(op.getOperation(), 5, {0, 1})))
          return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });

    if (result.wasInterrupted()) return failure();
  }
  // Current UnidirectionalSequenceLSTMOp doesn't support mask input.
  if (lstm_func.getNumArguments() == 7) return failure();

  // We should know the batch size in advance for the lstm fusion.
  // A good indicator of batch size is both cell state and input state (indices
  // 1 & 2) have fixed shape and other input tensors should have ranked tensor
  // types.
  for (int i = 0; i < 6; ++i) {
    auto input = lstm_func.getArgument(i);
    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!input_type) {
      lstm_func.emitWarning(
          "we cannot fuse this lstm func because all the inputs have not "
          "ranked tensor type.");
      return failure();
    }
    switch (i) {
      case 1:  // output_init_state
      case 2:  // hidden_init_state
        // Static state shapes pin down the batch size.
        if (!input_type.hasStaticShape()) {
          lstm_func.emitWarning(
              "we cannot fuse this lstm func because the batch size is not "
              "fixed, please consider setting fixed batch size like "
              "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
              "lite/examples/experimental_new_converter/"
              "Keras_LSTM_fusion_Codelab.ipynb");
          return failure();
        }
        break;
      case 3:  // weight
      case 4:  // recurrent_kernel
      case 5:  // bias
        if (!input_type.hasStaticShape()) {
          lstm_func.emitWarning(
              "we cannot fuse this lstm func because the weight & bias are not "
              "fixed, please consider setting fixed batch size like "
              "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
              "lite/examples/experimental_new_converter/"
              "Keras_LSTM_fusion_Codelab.ipynb");
          return failure();
        }
        break;
      default:
        // No op.
        break;
    }
  }

  return success();
}
291 
// Dispatches `tf._implements` string annotations to the matching fusion:
// embedding_matmul, (layer-normalized) LSTM cell simple, padded NMS, and
// DenseImageWarp. Unrecognized values are silently left untouched; failed
// precondition checks also leave the function unchanged.
void PrepareCompositeFunctionsPass::ConvertTFImplements(func::FuncOp func,
                                                        StringAttr attr) {
  if (attr.getValue() == "embedding_matmul") {
    // Convert the composite embedding_matmul function body to a
    // TFLite fused embedding_lookup op.
    ConvertEmbeddedLookupFunc convert_embedded_lookup(func);
    if (failed(convert_embedded_lookup.VerifySignature())) return;
    func.eraseBody();
    func.addEntryBlock();
    convert_embedded_lookup.RewriteFunc();
  } else if (attr.getValue() == mlir::TFL::kLstmCellSimple) {
    // Check if the lstm cell simple can be fused, if not, we just don't do
    // anything.
    if (failed(CheckFusableLstmCellSimple(func))) return;
    func.eraseBody();
    func.addEntryBlock();
    ConvertLSTMCellSimpleToFusedLSTM convert_lstm_cell_simple(func);
    if (failed(convert_lstm_cell_simple.RewriteFunc())) {
      return signalPassFailure();
    }
  } else if (attr.getValue() == mlir::TFL::kLayerNormalizedLstmCellSimple) {
    // Check if the layer normalized lstm cell simple can be fused, if not, we
    // just don't do anything.
    if (failed(CheckFusableLayerNormalizedLstmCellSimple(func))) return;
    func.eraseBody();
    func.addEntryBlock();
    ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
        convert_layer_norm_lstm_cell_simple(func);
    if (failed(convert_layer_norm_lstm_cell_simple.RewriteFunc())) {
      return signalPassFailure();
    }
  } else if (attr.getValue() == kTfNMSPadded) {
    ConvertNMSPaddedFunc convert_nms_padded(func);
    if (failed(convert_nms_padded.VerifySignature())) return;
    func.eraseBody();
    func.addEntryBlock();
    convert_nms_padded.RewriteFunc();
  } else if (attr.getValue() == kCustomDenseImageWarp) {
    ConvertDenseImageWarpFunc image_warping(func);
    if (failed(image_warping.VerifySignature())) return;
    if (failed(image_warping.RewriteFunc())) {
      return signalPassFailure();
    }
  }
}
337 
ConvertTFImplementsWithAttributes(func::FuncOp func,FuncAttr attr)338 void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes(
339     func::FuncOp func, FuncAttr attr) {
340   StringRef api_name = attr.getName().getLeafReference().getValue();
341   bool enable_fuse_tftext =
342       tfl_fuse_tftext_ || IsTFTextRegistered(tensorflow::OpRegistry::Global());
343   if (api_name.startswith(kTFTextAPIPrefix) && enable_fuse_tftext) {
344     if (failed(ConvertTFTextAPI(func, api_name, attr))) {
345       return signalPassFailure();
346     }
347   } else if (api_name == kCustomSSDPostprocessing) {
348     ConvertSSDPostProcessFunc convert_ssd_postprocess(func, attr);
349     if (failed(convert_ssd_postprocess.VerifySignature())) return;
350     if (failed(convert_ssd_postprocess.RewriteFunc())) {
351       return signalPassFailure();
352     }
353   } else if (api_name == kCustomMaxUnpooling) {
354     ConvertMaxUnpoolingFunc max_unpooling(func, attr);
355     if (failed(max_unpooling.VerifySignature())) return;
356     if (failed(max_unpooling.RewriteFunc())) {
357       return signalPassFailure();
358     }
359   } else {
360     // We will look for the `tfl_fusable_op` attribute and fuse as a custom op.
361     DictionaryAttr dict_attr = attr.getAttrs();
362 
363     SmallVector<std::pair<StringRef, Attribute>, 4> attributes;
364     bool tfl_fusable_op = false;
365     for (auto attr_item : dict_attr) {
366       // Push other attributes except the TFLFusableOp.
367       if (attr_item.getName() == kTFLFusableOp &&
368           attr_item.getValue().dyn_cast<BoolAttr>().getValue()) {
369         tfl_fusable_op = true;
370       } else {
371         attributes.push_back({attr_item.getName(), attr_item.getValue()});
372       }
373     }
374 
375     if (!tfl_fusable_op) return;
376 
377     if (failed(ConvertTflFusableOp(func, api_name, attributes))) {
378       func->emitError(absl::StrCat("failed to fuse for op: ", api_name.str()));
379       return signalPassFailure();
380     }
381   }
382 }
383 
ConvertTFAPIImplements(func::FuncOp func,StringAttr attr,ModuleOp module)384 void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(func::FuncOp func,
385                                                            StringAttr attr,
386                                                            ModuleOp module) {
387   // Keras lstm tf.api_implements usually has attribute like "lstm_abcde91...".
388   // TODO(b/147436982): we need to make sure that only the
389   // outputs(full sequence) is used, not the last_output, not the new_states.
390   // We will discard everything except the outputs.
391   // And the outputs is in the shape of [batch, time, units].
392   if (attr.getValue().startswith("lstm_")) {
393     // Check if the keras lstm can be fused, if not, we just don't do anything.
394     if (failed(CheckFusableKerasLstm(func, module))) return;
395     func.eraseBody();
396     func.addEntryBlock();
397     OpBuilder builder(func.getBody());
398     if (failed(ConvertKerasLSTMLayer(func, &builder)))
399       return signalPassFailure();
400   }
401 }
402 
runOnOperation()403 void PrepareCompositeFunctionsPass::runOnOperation() {
404   auto module = getOperation();
405   for (auto func : module.getOps<func::FuncOp>()) {
406     // We have three kinds of implements:
407     // 1) tf._implements, with string attributes.
408     // 2) tf._implements, with proto attributes.
409     // 3) tf.api_implements.
410     // We need to handle them separately.
411     auto tf_implements_attr_str =
412         func->getAttrOfType<StringAttr>(kTFImplements);
413     if (tf_implements_attr_str) {
414       ConvertTFImplements(func, tf_implements_attr_str);
415       continue;
416     }
417 
418     auto tf_implements_attr = func->getAttrOfType<FuncAttr>(kTFImplements);
419     if (tf_implements_attr) {
420       ConvertTFImplementsWithAttributes(func, tf_implements_attr);
421       continue;
422     }
423 
424     auto tf_api_implements_attr =
425         func->getAttrOfType<StringAttr>(kTFAPIImplements);
426     if (tf_api_implements_attr) {
427       // TODO(b/147536816): Keras lstm should set up the correct attributes.
428       ConvertTFAPIImplements(func, tf_api_implements_attr, module);
429     }
430   }
431 }
432 }  // namespace
433 
// Creates an instance of the TFLite dialect PrepareCompositeFunctions pass.
std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
  return std::make_unique<PrepareCompositeFunctionsPass>();
}
437 
438 }  // namespace TFL
439 }  // namespace mlir
440