xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
17 
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/Builders.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
31 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
32 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
33 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
34 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
37 
38 namespace mlir {
39 namespace tf_saved_model {
40 
41 //===----------------------------------------------------------------------===//
42 // Utilities
43 //===----------------------------------------------------------------------===//
44 
IsStrArrayAttr(Attribute attr)45 static bool IsStrArrayAttr(Attribute attr) {
46   auto array = attr.dyn_cast<ArrayAttr>();
47   if (!array) return false;
48 
49   return llvm::all_of(array,
50                       [](Attribute attr) { return attr.isa<StringAttr>(); });
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // TensorFlowSavedModelDialect Op's
55 //===----------------------------------------------------------------------===//
56 
VerifyTensorTypesCompatible(Type t1,Type t2)57 LogicalResult VerifyTensorTypesCompatible(Type t1, Type t2) {
58   if (!t1.isa<TensorType>() || !t2.isa<TensorType>()) {
59     return failure();
60   }
61   return verifyCompatibleShape(t1.cast<TensorType>(), t2.cast<TensorType>());
62 }
63 
verify()64 LogicalResult GlobalTensorOp::verify() {
65   GlobalTensorOp global_tensor = *this;
66   if (failed(VerifyTensorTypesCompatible(global_tensor.type(),
67                                          global_tensor.value().getType()))) {
68     return global_tensor.emitError() << "'type' and 'value' attributes should "
69                                         "have compatible tensor types";
70   }
71   if (!global_tensor.is_mutable()) {
72     if (!global_tensor.type().cast<TensorType>().hasStaticShape()) {
73       return global_tensor.emitError()
74              << "'type' attribute for immutable 'tf_saved_model.global_tensor' "
75                 "should have a static shape";
76     }
77   }
78   return success();
79 }
80 
verify()81 LogicalResult SessionInitializerOp::verify() {
82   SessionInitializerOp session_initializer = *this;
83   mlir::SymbolTable symbol_table(
84       session_initializer->getParentOfType<ModuleOp>());
85 
86   for (auto sym_ref : session_initializer.initializers()) {
87     auto init_func_op = symbol_table.lookup<mlir::func::FuncOp>(
88         sym_ref.cast<FlatSymbolRefAttr>().getValue());
89 
90     if (!init_func_op)
91       return session_initializer.emitOpError()
92              << "the initializer function does not exist";
93 
94     if (!init_func_op.getFunctionType().getResults().empty())
95       return session_initializer.emitOpError()
96              << "the initializer function should have no output";
97 
98     auto exported_names = GetExportedNames(init_func_op);
99 
100     if (exported_names.empty())
101       return session_initializer.emitOpError()
102              << "the initializer function should be exported";
103 
104     if (exported_names.size() != 1)
105       return session_initializer.emitOpError()
106              << "the initializer function should have only one exported names";
107   }
108 
109   return success();
110 }
111 
112 }  // namespace tf_saved_model
113 }  // namespace mlir
114 
115 #define GET_OP_CLASSES
116 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
117 
118 namespace mlir {
119 namespace tf_saved_model {
120 
121 //===----------------------------------------------------------------------===//
122 // TensorFlowSavedModelDialect Dialect
123 //===----------------------------------------------------------------------===//
124 
TensorFlowSavedModelDialect(MLIRContext * context)125 TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context)
126     : Dialect(/*name=*/"tf_saved_model", context,
127               TypeID::get<TensorFlowSavedModelDialect>()) {
128   // The TensorFlow Dialect is needed in the verifier and other routines
129   // associated to this dialect. It makes little sense anyway to use the
130   // SavedModel dialect without the TensorFlow Dialect.
131   context->loadDialect<TF::TensorFlowDialect>();
132 
133   addOperations<
134 #define GET_OP_LIST
135 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
136       >();
137 }
138 
VerifyIndexPath(Operation * op,NamedAttribute named_attr)139 static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) {
140   auto attr = named_attr.getValue().dyn_cast<ArrayAttr>();
141   if (!attr) {
142     return op->emitError()
143            << "'tf_saved_model.index_path' attribute should be an ArrayAttr";
144   }
145   for (auto element : attr) {
146     if (element.isa<StringAttr>()) {
147       continue;
148     }
149     if (auto integer = element.dyn_cast<IntegerAttr>()) {
150       if (integer.getValue().getBitWidth() == 64) {
151         continue;
152       }
153     }
154     return op->emitError() << "'tf_saved_model.index_path' elements should "
155                               "be strings or 64-bit integers";
156   }
157   return mlir::success();
158 }
159 
GetBoundInputArgTypeFor(mlir::Operation * op)160 Type GetBoundInputArgTypeFor(mlir::Operation *op) {
161   if (auto global_tensor = llvm::dyn_cast<GlobalTensorOp>(op)) {
162     auto type = global_tensor.type().cast<TensorType>();
163     return RankedTensorType::get(
164         {}, TF::ResourceType::get({type}, type.getContext()));
165   }
166 
167   if (auto asset = llvm::dyn_cast<AssetOp>(op)) {
168     return RankedTensorType::get({}, TF::StringType::get(asset.getContext()));
169   }
170 
171   op->emitError() << "unknown symbol operation";
172   return {};
173 }
174 
VerifyBoundInputArgType(Operation * op_for_diagnostics,Type arg_type,mlir::Operation * symbol_op)175 static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics,
176                                              Type arg_type,
177                                              mlir::Operation *symbol_op) {
178   auto expected_type = GetBoundInputArgTypeFor(symbol_op);
179   if (!expected_type) return failure();
180 
181   if (arg_type != expected_type) {
182     return op_for_diagnostics->emitError()
183            << "bound input with type " << arg_type << " expected to have type "
184            << expected_type;
185   }
186   return success();
187 }
188 
verifyRegionArgAttribute(Operation * op,unsigned region_index,unsigned arg_index,NamedAttribute named_attr)189 LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute(
190     Operation *op, unsigned region_index, unsigned arg_index,
191     NamedAttribute named_attr) {
192   if (named_attr.getName() == "tf_saved_model.bound_input") {
193     if (!named_attr.getValue().isa<FlatSymbolRefAttr>()) {
194       return op->emitError() << "'tf_saved_model.bound_input' attribute should "
195                                 "be a FlatSymbolRefAttr";
196     }
197     auto symbol_name =
198         named_attr.getValue().cast<FlatSymbolRefAttr>().getValue();
199     auto module = op->getParentOfType<ModuleOp>();
200     mlir::Operation *symbol_op = module.lookupSymbol(symbol_name);
201     if (!symbol_op) {
202       return op->emitError() << "'tf_saved_model.bound_input' attribute must "
203                                 "reference a valid symbol, got invalid symbol '"
204                              << symbol_name << "'";
205     }
206     auto arg_type = cast<func::FuncOp>(op).getArgument(arg_index).getType();
207     return VerifyBoundInputArgType(op, arg_type, symbol_op);
208   }
209   if (named_attr.getName() == "tf_saved_model.index_path") {
210     return VerifyIndexPath(op, named_attr);
211   }
212 
213   return op->emitError() << "unknown tf_saved_model dialect arg attribute '"
214                          << named_attr.getName().getValue() << "'";
215 }
216 
verifyRegionResultAttribute(Operation * op,unsigned region_index,unsigned result_index,NamedAttribute named_attr)217 LogicalResult TensorFlowSavedModelDialect::verifyRegionResultAttribute(
218     Operation *op, unsigned region_index, unsigned result_index,
219     NamedAttribute named_attr) {
220   if (named_attr.getName() == "tf_saved_model.index_path") {
221     return VerifyIndexPath(op, named_attr);
222   }
223 
224   return op->emitError() << "unknown tf_saved_model dialect result attribute '"
225                          << named_attr.getName().getValue() << "'";
226 }
227 
HasAnyTfSavedModelArgAttr(func::FuncOp func)228 static bool HasAnyTfSavedModelArgAttr(func::FuncOp func) {
229   for (int i = 0, e = func.getNumArguments(); i < e; i++) {
230     if (func.getArgAttr(i, "tf_saved_model.index_path") ||
231         func.getArgAttr(i, "tf_saved_model.bound_input")) {
232       return true;
233     }
234   }
235   for (int i = 0, e = func.getNumResults(); i < e; i++) {
236     if (func.getResultAttr(i, "tf_saved_model.index_path") ||
237         func.getResultAttr(i, "tf_saved_model.bound_input")) {
238       return true;
239     }
240   }
241   return false;
242 }
243 
VerifySavedModelModule(ModuleOp module,TensorFlowSavedModelDialect * dialect)244 static LogicalResult VerifySavedModelModule(
245     ModuleOp module, TensorFlowSavedModelDialect *dialect) {
246   auto exported_names_ident =
247       StringAttr::get(dialect->getContext(), "tf_saved_model.exported_names");
248   // Check that there are no duplicated exported_names.
249   DenseMap<StringRef, Operation *> exported_name_to_op;
250   for (auto &op : module) {
251     auto attr = op.getAttr(exported_names_ident);
252     if (!attr) continue;
253     // If this verifier is called before we verify the
254     // 'tf_saved_model.exported_names' attribute, then it might be invalid.
255     // Forward to the dialect's verification to establish that precondition.
256     if (failed(dialect->verifyOperationAttribute(
257             &op, {exported_names_ident, attr}))) {
258       return failure();
259     }
260     for (auto str : attr.cast<ArrayAttr>()) {
261       auto exported_name = str.cast<StringAttr>().getValue();
262       auto p = exported_name_to_op.insert({exported_name, &op});
263       if (!p.second) {
264         return op.emitError()
265             .append("duplicate exported name '", exported_name, "'")
266             .attachNote(p.first->getSecond()->getLoc())
267             .append("previously seen here");
268       }
269     }
270   }
271   for (auto func : module.getOps<func::FuncOp>()) {
272     const bool is_exported = IsExported(func);
273 
274     if (is_exported && !func.isPublic()) {
275       return func.emitError()
276              << "exported function @" << func.getName() << " should be public";
277     }
278 
279     if (!is_exported && func.isPublic()) {
280       return func.emitError() << "non-exported function @" << func.getName()
281                               << " should be private";
282     }
283     if (!is_exported && HasAnyTfSavedModelArgAttr(func)) {
284       return func.emitError() << "can only apply 'tf_saved_model' argument "
285                                  "attributes to exported functions";
286     }
287   }
288 
289   auto session_initializers = module.getOps<SessionInitializerOp>();
290   if (!session_initializers.empty() &&
291       !llvm::hasSingleElement(session_initializers)) {
292     return (*++session_initializers.begin()).emitError()
293            << "there must be no more than one session_initializer op";
294   }
295 
296   auto is_init = [&session_initializers](mlir::func::FuncOp func) {
297     if (session_initializers.empty()) return false;
298     auto init_syms = (*session_initializers.begin()).initializers();
299     return std::any_of(
300         init_syms.begin(), init_syms.end(), [&](Attribute sym_ref) {
301           return sym_ref.cast<FlatSymbolRefAttr>().getValue() == func.getName();
302         });
303   };
304 
305   SymbolTable symbol_table(module);
306   auto symbol_uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
307   if (!symbol_uses.has_value()) {
308     return module.emitError() << "modules with 'tf_saved_model.semantics' must "
309                                  "have analyzable symbol uses";
310   }
311   for (auto symbol_use : *symbol_uses) {
312     auto func = symbol_table.lookupNearestSymbolFrom<func::FuncOp>(
313         symbol_use.getUser(), symbol_use.getSymbolRef());
314     if (func && IsExported(func)) {
315       // If it is an init function, then it can be used by the unique
316       // session_initializer op.
317       if (is_init(func) &&
318           llvm::isa<SessionInitializerOp>(symbol_use.getUser()))
319         continue;
320 
321       return symbol_use.getUser()
322           ->emitError("exported function cannot be internally referenced")
323           .attachNote(func.getLoc())
324           .append("references this exported function");
325     }
326   }
327   return success();
328 }
329 
VerifyExportedFunc(func::FuncOp func)330 LogicalResult VerifyExportedFunc(func::FuncOp func) {
331   bool reached_bound_inputs = false;
332   auto module = func->getParentOfType<ModuleOp>();
333   for (int i = 0, e = func.getNumArguments(); i < e; i++) {
334     if (func.getArgAttr(i, "tf_saved_model.bound_input")) {
335       reached_bound_inputs = true;
336       continue;
337     }
338     if (func.getArgAttr(i, "tf_saved_model.index_path")) {
339       if (reached_bound_inputs) {
340         return func.emitError()
341                << "all 'tf_saved_model.index_path' arg attributes should "
342                   "precede all 'tf_saved_model.bound_input' arg attributes";
343       }
344       continue;
345     }
346     if (func.getArgAttr(i, "tf.resource_name")) {
347       if (module->getAttr("tf_saved_model.under_construction")) continue;
348       return func.emitError() << "'tf.resource_name' attribute is not allowed "
349                                  "unless it is being under construction";
350     }
351     return func.emitError()
352            << "all arguments should have 'tf_saved_model.index_path', "
353               "'tf_saved_model.bound_input' or 'tf.resource_name' attributes";
354   }
355   llvm::SmallDenseSet<StringRef, 8> unique_bound_inputs;
356   for (int i = 0, e = func.getNumArguments(); i < e; i++) {
357     if (auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
358             i, "tf_saved_model.bound_input")) {
359       if (!unique_bound_inputs.insert(attr.getValue()).second) {
360         if (module->getAttr("tf_saved_model.under_construction")) continue;
361         return func.emitError()
362                << "duplicate 'tf_saved_model.bound_input' binding";
363       }
364     }
365   }
366 
367   for (int i = 0, e = func.getNumResults(); i < e; i++) {
368     if (!func.getResultAttr(i, "tf_saved_model.index_path")) {
369       return func.emitError() << "all results should have "
370                                  "'tf_saved_model.index_path' attributes";
371     }
372   }
373 
374   return success();
375 }
376 
verifyOperationAttribute(Operation * op,NamedAttribute named_attr)377 LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute(
378     Operation *op, NamedAttribute named_attr) {
379   if (named_attr.getName() == "tf_saved_model.exported_names") {
380     if (!isa<func::FuncOp, GlobalTensorOp>(op)) {
381       return op->emitError() << "'tf_saved_model.exported_names' must be on a "
382                                 "'func' or 'tf_saved_model.global_tensor' op";
383     }
384     if (!IsStrArrayAttr(named_attr.getValue())) {
385       return op->emitError()
386              << "'tf_saved_model.exported_names' must be an array of strings";
387     }
388     if (!op->getParentOp()->getAttr("tf_saved_model.semantics")) {
389       return op->emitError()
390              << "'tf_saved_model.exported_names' must be on an op "
391                 "whose immediate parent has attribute "
392                 "'tf_saved_model.semantics'";
393     }
394     if (auto func = dyn_cast<func::FuncOp>(op)) {
395       if (failed(VerifyExportedFunc(func))) {
396         return failure();
397       }
398     }
399     return success();
400   }
401   if (named_attr.getName() == "tf_saved_model.semantics") {
402     auto module = dyn_cast<ModuleOp>(op);
403     if (!module) {
404       return op->emitError() << "'tf_saved_model.semantics' must "
405                                 "be on a module op";
406     }
407     return VerifySavedModelModule(module, this);
408   }
409   if (named_attr.getName() == "tf_saved_model.under_construction") {
410     return success();
411   }
412 
413   return op->emitError() << "unknown tf_saved_model dialect attribute '"
414                          << named_attr.getName().getValue() << "'";
415 }
416 
GetExportedNames(Operation * op)417 SmallVector<StringRef, 2> GetExportedNames(Operation *op) {
418   SmallVector<StringRef, 2> ret;
419   auto exported_names =
420       op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names");
421   if (exported_names) {
422     for (auto name : exported_names) {
423       ret.push_back(name.cast<StringAttr>().getValue());
424     }
425   }
426   return ret;
427 }
428 
IsExported(Operation * op)429 bool IsExported(Operation *op) {
430   auto exported_names =
431       op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names");
432   return exported_names && !exported_names.empty();
433 }
434 
HasTfSavedModelSemantics(ModuleOp module)435 bool HasTfSavedModelSemantics(ModuleOp module) {
436   return module->getAttr("tf_saved_model.semantics") != nullptr;
437 }
438 
LookupBoundInput(func::FuncOp func,int arg_index,const SymbolTable & symbol_table)439 Operation *LookupBoundInput(func::FuncOp func, int arg_index,
440                             const SymbolTable &symbol_table) {
441   auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
442       arg_index, "tf_saved_model.bound_input");
443   if (!attr) return nullptr;
444   return symbol_table.lookup(attr.getValue());
445 }
446 
GetSessionInitializerOp(mlir::ModuleOp op)447 SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op) {
448   auto initializers = op.getOps<SessionInitializerOp>();
449   if (initializers.empty()) return {};
450   return *initializers.begin();
451 }
452 
453 class OptimizeSessionInitializerPattern
454     : public OpRewritePattern<SessionInitializerOp> {
455  public:
456   using OpRewritePattern::OpRewritePattern;
457 
matchAndRewrite(SessionInitializerOp op,PatternRewriter & rewriter) const458   LogicalResult matchAndRewrite(SessionInitializerOp op,
459                                 PatternRewriter &rewriter) const override {
460     SymbolTable symbol_table(op->getParentOfType<ModuleOp>());
461 
462     SmallVector<func::FuncOp, 2> to_remove;
463     SmallVector<mlir::Attribute, 2> to_keep;
464     for (auto sym_ref : op.initializers()) {
465       auto init_func_op = symbol_table.lookup<mlir::func::FuncOp>(
466           sym_ref.cast<FlatSymbolRefAttr>().getValue());
467 
468       // The init function can only be referenced from the SessionInitializerOp.
469       // And there is at most one SessionInitializerOp in the module. So if both
470       // ops have no other uses or have one NoOp only, they can be simply
471       // erased.
472       auto &operations = init_func_op.front().getOperations();
473       if ((operations.size() == 1 &&
474            operations.front().hasTrait<OpTrait::IsTerminator>()) ||
475           (operations.size() == 2 &&
476            dyn_cast<mlir::TF::NoOp>(operations.front()) &&
477            operations.back().hasTrait<OpTrait::IsTerminator>())) {
478         to_remove.push_back(init_func_op);
479       } else {
480         to_keep.push_back(sym_ref);
481       }
482     }
483 
484     for (auto func_op : to_remove) rewriter.eraseOp(func_op);
485 
486     if (to_keep.empty())
487       rewriter.eraseOp(op);
488     else
489       op->setAttr("initializers", rewriter.getArrayAttr(to_keep));
490 
491     return success();
492   }
493 };
494 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)495 void SessionInitializerOp::getCanonicalizationPatterns(
496     RewritePatternSet &results, MLIRContext *context) {
497   results.add<OptimizeSessionInitializerPattern>(context);
498 }
499 
GetSessionInitializerExportedName(ModuleOp op)500 SmallVector<StringRef, 2> GetSessionInitializerExportedName(ModuleOp op) {
501   auto session_initializer_op = GetSessionInitializerOp(op);
502   if (!session_initializer_op) return {};
503 
504   SymbolTable symbol_table(op);
505 
506   SmallVector<StringRef, 2> results;
507   for (auto sym_ref : session_initializer_op.initializers()) {
508     auto init_func_op = symbol_table.lookup<mlir::func::FuncOp>(
509         sym_ref.cast<FlatSymbolRefAttr>().getValue());
510     auto exported_names = GetExportedNames(init_func_op);
511     assert(exported_names.size() == 1);
512     results.push_back(exported_names[0]);
513   }
514 
515   return results;
516 }
517 
518 }  // namespace tf_saved_model
519 }  // namespace mlir
520