xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
17 
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_split.h"
24 #include "absl/strings/string_view.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/Casting.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
29 #include "mlir/IR/Attributes.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
32 #include "mlir/IR/Location.h"  // from @llvm-project
33 #include "mlir/IR/Operation.h"  // from @llvm-project
34 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
35 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
36 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
41 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
43 #include "tensorflow/compiler/mlir/tensorflow/utils/location_utils.h"
44 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
45 #include "tensorflow/compiler/xla/status_macros.h"
46 #include "tensorflow/core/framework/attr_value.pb.h"
47 #include "tensorflow/core/framework/graph.pb.h"
48 #include "tensorflow/core/framework/graph_to_functiondef.h"
49 #include "tensorflow/core/framework/node_def.pb.h"
50 #include "tensorflow/core/framework/node_def_util.h"
51 #include "tensorflow/core/framework/op.h"
52 #include "tensorflow/core/framework/tensor.pb.h"
53 #include "tensorflow/core/framework/tensor_shape.pb.h"
54 #include "tensorflow/core/framework/types.pb.h"
55 #include "tensorflow/core/graph/algorithm.h"
56 #include "tensorflow/core/graph/graph.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/platform/protobuf.h"
59 
60 namespace tensorflow {
61 namespace {
// Returns the process-wide set of TensorFlow op-name prefixes ("tf.",
// "tf_executor.", plus any registered via AddTensorFlowOpPrefix). The set is
// leaked intentionally so it outlives all users.
std::set<std::string>* GlobalOpPrefixes() {
  static std::set<std::string>* const global_op_prefixes =
      new std::set<std::string>({"tf.", "tf_executor."});
  return global_op_prefixes;
}
72 
73 // Converts a location to the debug information for the node def.
ConvertLocation(mlir::Location inst_loc,llvm::StringRef node_name,NodeDef::ExperimentalDebugInfo * debug_info)74 Status ConvertLocation(mlir::Location inst_loc, llvm::StringRef node_name,
75                        NodeDef::ExperimentalDebugInfo* debug_info) {
76   mlir::Location unwrapped_inst_loc = GetLocationWithoutOpType(inst_loc);
77 
78   if (auto call_site = unwrapped_inst_loc.dyn_cast<mlir::CallSiteLoc>()) {
79     if (auto name_loc = GetLocationWithoutOpType(call_site.getCallee())
80                             .dyn_cast<mlir::NameLoc>()) {
81       llvm::StringRef original_node_name, original_func_name;
82       std::tie(original_node_name, original_func_name) =
83           name_loc.getName().strref().split('@');
84       // The location points to the current node def.
85       if (node_name == original_node_name && original_func_name.empty()) {
86         return OkStatus();
87       }
88       debug_info->add_original_node_names(original_node_name.str());
89       if (!original_func_name.empty()) {
90         debug_info->add_original_func_names(original_func_name.str());
91       }
92     }
93   } else if (auto fused = unwrapped_inst_loc.dyn_cast<mlir::FusedLoc>()) {
94     auto locations = fused.getLocations();
95     if (locations.size() <= 1)
96       return errors::InvalidArgument("expected experimental debuf info.");
97     // skip the first one, which is the name of the node_def.
98     for (int i = 0, end = locations.size() - 1; i < end; ++i) {
99       TF_RETURN_IF_ERROR(ConvertLocation(locations[i], node_name, debug_info));
100     }
101   }
102   return OkStatus();
103 }
104 
ConvertAttribute(const mlir::BoolAttr & attr,AttrValue * value)105 Status ConvertAttribute(const mlir::BoolAttr& attr, AttrValue* value) {
106   value->set_b(attr.getValue());
107   return OkStatus();
108 }
109 
ConvertAttribute(const mlir::IntegerAttr & attr,AttrValue * value)110 Status ConvertAttribute(const mlir::IntegerAttr& attr, AttrValue* value) {
111   value->set_i(attr.getInt());
112   return OkStatus();
113 }
114 
ConvertAttribute(const mlir::FloatAttr & attr,AttrValue * value)115 Status ConvertAttribute(const mlir::FloatAttr& attr, AttrValue* value) {
116   value->set_f(attr.getValueAsDouble());
117   return OkStatus();
118 }
119 
// Exports an MLIR elements attribute (a constant tensor) into the `tensor`
// field of the AttrValue.
Status ConvertAttribute(const mlir::ElementsAttr& attr, AttrValue* value) {
  return ConvertToTensorProto(attr, value->mutable_tensor());
}
123 
ConvertAttribute(const mlir::TF::PlaceholderAttr & attr,AttrValue * value)124 Status ConvertAttribute(const mlir::TF::PlaceholderAttr& attr,
125                         AttrValue* value) {
126   value->set_placeholder(attr.getValue().str());
127   return OkStatus();
128 }
129 
ConvertAttribute(const mlir::TF::ShapeAttr & attr,AttrValue * value)130 Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) {
131   SetTensorShapeProto(attr, value->mutable_shape());
132   return OkStatus();
133 }
134 
ConvertAttribute(const mlir::FlatSymbolRefAttr & attr,AttrValue * value)135 Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
136   value->mutable_func()->set_name(attr.getValue().str());
137   return OkStatus();
138 }
139 
// Exports a TF FuncAttr (a function symbol plus its bound attributes) as a
// func-valued AttrValue: the symbol becomes func.name, and the attached
// attributes are converted into func.attr.
Status ConvertAttribute(const mlir::TF::FuncAttr& attr, bool remove_ref_type,
                        AttrValue* value) {
  // Sets value->func().name() first; the attribute conversion below then
  // populates the same func message's attr map.
  TF_RETURN_IF_ERROR(
      ConvertAttribute(attr.getName().cast<mlir::FlatSymbolRefAttr>(), value));
  TF_RETURN_IF_ERROR(ConvertAttributes(attr.getAttrs().getValue(),
                                       /*attrs_to_ignore=*/{}, remove_ref_type,
                                       value->mutable_func()->mutable_attr()));
  return OkStatus();
}
149 
ConvertAttribute(const mlir::StringAttr & attr,AttrValue * value)150 Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
151   absl::string_view attr_value(attr.getValue().data(), attr.getValue().size());
152   switch (mangling_util::GetMangledKind(attr_value)) {
153     case mangling_util::MangledKind::kUnknown: {
154       value->set_s(std::string(attr_value));
155       return OkStatus();
156     }
157     case mangling_util::MangledKind::kDataType: {
158       DataType dtype;
159       TF_RETURN_IF_ERROR(mangling_util::DemangleDataType(attr_value, &dtype));
160       value->set_type(dtype);
161       return OkStatus();
162     }
163     case mangling_util::MangledKind::kTensorShape:
164       TF_RETURN_IF_ERROR(
165           mangling_util::DemangleShape(attr_value, value->mutable_shape()));
166       return OkStatus();
167     default:
168       return errors::Unimplemented("Mangled string couldn't be handled!");
169   }
170   return OkStatus();
171 }
172 
// Converts an MLIR type to a DataType stored in the `type` field of the
// AttrValue.
Status ConvertAttribute(mlir::Type type, bool remove_ref_type,
                        AttrValue* value) {
  DataType dtype;
  TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype));
  // NOTE(review): ref-ness is stripped unconditionally here — the
  // `remove_ref_type` parameter is never consulted. Confirm whether callers
  // depend on refs being removed even when remove_ref_type is false before
  // "fixing" this.
  if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype);
  value->set_type(dtype);
  return OkStatus();
}
181 
// Converts a type attribute by forwarding its wrapped mlir::Type to the
// mlir::Type overload above.
Status ConvertAttribute(const mlir::TypeAttr& type, bool remove_ref_type,
                        AttrValue* value) {
  return ConvertAttribute(type.getValue(), remove_ref_type, value);
}
186 
// A unit attribute carries no payload: represent it as an AttrValue with its
// value field cleared (an empty AttrValue).
Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
  value->clear_value();
  return OkStatus();
}
191 
// Exports an MLIR array attribute as an AttrValue list, dispatching each
// element on its concrete attribute kind. Elements that have no supported
// mapping produce an Unimplemented error.
Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type,
                        AttrValue* value) {
  auto* list = value->mutable_list();
  for (mlir::Attribute a : attr.getValue()) {
    if (auto attr = a.dyn_cast<mlir::BoolAttr>()) {
      list->add_b(attr.getValue());
    } else if (auto attr = a.dyn_cast<mlir::IntegerAttr>()) {
      list->add_i(attr.getInt());
    } else if (auto attr = a.dyn_cast<mlir::FloatAttr>()) {
      list->add_f(attr.getValueAsDouble());
    } else if (auto attr = a.dyn_cast<mlir::StringAttr>()) {
      // A string element may demangle to a plain string, a DataType, or a
      // shape (see the StringAttr overload); route the converted value to
      // the matching repeated field.
      AttrValue nested_value;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value));
      switch (nested_value.value_case()) {
        case AttrValue::kS:
          list->add_s(nested_value.s());
          break;
        case AttrValue::kType:
          list->add_type(nested_value.type());
          break;
        case AttrValue::kShape:
          *list->add_shape() = nested_value.shape();
          break;
        default:
          return errors::Unimplemented("Unhandled nested attribute!");
      }
    } else if (auto attr = a.dyn_cast<mlir::ElementsAttr>()) {
      TensorProto tensor;
      TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
      *list->add_tensor() = tensor;
    } else if (auto attr = a.dyn_cast<mlir::FlatSymbolRefAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_func() = attr_val.func();
    } else if (auto attr = a.dyn_cast<mlir::TypeAttr>()) {
      AttrValue attr_val;
      // For type attributes, we only propagate the element type.
      mlir::Type elt_type = attr.getValue();
      if (auto shaped_type = elt_type.dyn_cast<mlir::ShapedType>()) {
        elt_type = shaped_type.getElementType();
      }
      TF_RETURN_IF_ERROR(
          ConvertAttribute(elt_type, remove_ref_type, &attr_val));
      list->add_type(attr_val.type());
    } else if (auto attr = a.dyn_cast<mlir::TF::ShapeAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_shape() = attr_val.shape();
    } else if (auto attr = a.dyn_cast<mlir::ArrayAttr>()) {
      // A nested array is only supported when every element is an integer
      // attribute; it is exported as a rank-1 int64 tensor.
      std::vector<int64_t> vals;
      for (mlir::Attribute a : attr.getValue()) {
        auto i = a.dyn_cast<mlir::IntegerAttr>();
        if (!i)
          return errors::Unimplemented(
              "Expected 64-bit integer array attributes!");
        vals.push_back(i.getInt());
      }
      mlir::OpBuilder builder(attr.getContext());
      mlir::TensorType ty =
          mlir::RankedTensorType::get(vals.size(), builder.getIntegerType(64));
      TensorProto tensor;
      TF_RETURN_IF_ERROR(ConvertToTensorProto(
          mlir::DenseIntElementsAttr::get(ty, vals), &tensor));
      *list->add_tensor() = tensor;
    } else {
      return errors::Unimplemented("Unhandled attribute!");
    }
  }
  return OkStatus();
}
262 
263 // Returns true if the executor/control dialect op should map to Ref node in
264 // TensorFlow Graph. For control dialect NextIteration it uses the 1st operand
265 // type. For executor dialect NextIteration it uses the 2nd operand type. For
266 // all others (Enter/Exit/Merge/Switch), if the output type is ref, they
267 // correspond to the Ref equivalent op in TF Graph.
IsRefTypeControlOp(mlir::Operation * op)268 static bool IsRefTypeControlOp(mlir::Operation* op) {
269   if (auto next_iter_sink =
270           llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(op))
271     return mlir::getElementTypeOrSelf(next_iter_sink.input().getType())
272         .isa<mlir::TF::TensorFlowRefType>();
273 
274   auto op_name_or_status = GetTensorFlowOpName(op->getName().getStringRef());
275   if (!op_name_or_status.ok()) return false;
276 
277   auto op_name = std::move(op_name_or_status).value();
278   if (op_name.equals("NextIteration"))
279     return mlir::getElementTypeOrSelf(op->getOperand(0).getType())
280         .isa<mlir::TF::TensorFlowRefType>();
281 
282   if (op_name.equals("Enter") || op_name.equals("Exit") ||
283       op_name.equals("Switch") || op_name.equals("Merge")) {
284     return getElementTypeOrSelf(op->getResult(0).getType())
285         .isa<mlir::TF::TensorFlowRefType>();
286   }
287   return false;
288 }
289 
290 }  // anonymous namespace
291 
GetTensorFlowOpName(llvm::StringRef op_name)292 StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
293   // When being converted to MLIR, some prefixes and suffixes are added to the
294   // operation types, and we have to remove them when converting the
295   // operations back to a graph:
296   // - "tf." or "tf_executor." : every operation type has this prefix.
297   // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
298   // don't need to consider ".source"/".Source" because the nodes with this
299   // suffix are skipped by the caller and will not be added to the graph.
300   auto prefixes = GlobalOpPrefixes();
301   if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) {
302         return op_name.consume_front(prefix);
303       })) {
304     return errors::FailedPrecondition("op node '", op_name.str(),
305                                       "' was not a TF op!");
306   }
307   // Control dialect NextIteration sink ends with ".sink" and Executor dialect
308   // NextIteration sink ends with ".Sink".
309   if (!op_name.consume_back(".sink")) op_name.consume_back(".Sink");
310   return op_name;
311 }
312 
// Builds a NodeDef for `inst` named `name`: derives the TF op name, adds one
// empty input slot per operand, copies the "device" attribute, and attaches
// debug info recovered from the instruction's location. For legacy call ops
// this MUTATES `inst` by removing its "f" attribute.
StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
    mlir::Operation* inst, llvm::StringRef name) {
  auto node_def = std::make_unique<NodeDef>();
  // Note: we do not use NodeBuilder or NodeDefBuilder as that would require
  // mapping back from the inputs to the input arguments.

  llvm::SmallString<64> op_name;
  if (IsLegacyCallInstruction(inst)) {
    // The op_name is the name of the function.
    op_name.append(inst->getAttrOfType<mlir::SymbolRefAttr>("f")
                       .getLeafReference()
                       .getValue());
    // Remove the attribute from the instruction as it is already converted to
    // op_name.
    auto attr_id = mlir::StringAttr::get(inst->getContext(), "f");
    inst->removeAttr(attr_id);
  } else {
    // Some control flow ops in TensorFlow Graph have their respective "Ref" ops
    // as well. For example there is Enter and RefEnter op. RefEnter forwards
    // the input ref buffer to output. However both Enter and RefEnter are
    // mapped to tf_executor::EnterOp during import. Check if it is a Ref op to
    // correctly map to the TensorFlow Graph op.
    if (IsRefTypeControlOp(inst)) op_name = "Ref";
    TF_ASSIGN_OR_RETURN(auto tf_name,
                        GetTensorFlowOpName(inst->getName().getStringRef()));
    // Appends after the optional "Ref" prefix set above, e.g. "RefEnter".
    op_name.append(tf_name);
  }

  node_def->set_name(name.str());
  node_def->set_op(std::string(op_name.str()));

  // Update NodeDef constructed out of an MLIR Case/If/While op to map it to
  // either TensorFlow StatelessX or X op depending on the additional attribute.
  if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst)) {
    auto stateless = inst->getAttrOfType<mlir::BoolAttr>("is_stateless");
    if (stateless && stateless.getValue())
      *node_def->mutable_op() = "Stateless" + node_def->op();
  }

  // Add inputs to the NodeDef based on the number of operands. This is required
  // as later when edges are added to the Node using Graph::AddEdge the
  // associated NodeDef is not updated.
  for (int i = 0, e = inst->getNumOperands(); i < e; ++i) {
    node_def->add_input();
  }
  if (auto attr = inst->getAttrOfType<mlir::StringAttr>("device")) {
    node_def->set_device(std::string(attr.getValue()));
  }

  // Add the node debug info.
  TF_RETURN_IF_ERROR(ConvertLocation(
      inst->getLoc(), name, node_def->mutable_experimental_debug_info()));

  return node_def;
}
368 
// Converts the MLIR attribute list `attrs` into NodeDef-style AttrValues in
// `values`. "name"/"device" and entries in `attrs_to_ignore` are skipped
// (they are modeled as attributes in MLIR but are not AttrValues in a
// NodeDef). Dotted names ("f.attr") are folded into the func-valued
// attribute "f" collected in `func_call_attrs`.
Status ConvertAttributes(
    const llvm::ArrayRef<mlir::NamedAttribute> attrs,
    const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
    bool remove_ref_type, AttrValueMap* values) {
  AttrValueMap func_call_attrs;
  for (const mlir::NamedAttribute& named_attr : attrs) {
    auto name_strref = named_attr.getName().str();
    auto attr = named_attr.getValue();
    absl::string_view name(name_strref.data(), name_strref.size());
    if (name == "name" || name == "device" || attrs_to_ignore.contains(name)) {
      // The name, device spec of a TF op or function are not stored as
      // AttrValue inside NodeDef, but we model them using attribute inside
      // MLIR. So we need to ignore them when going back to AttrValue here.
      continue;
    }
    if (mangling_util::IsMangledAttributeName(name)) {
      // In MLIR, attributes for functions requires dialect prefix. We need to
      // remove TF dialect prefix before converting to AttrValue.
      name = mangling_util::DemangleAttributeName(name);
    }
    AttrValue value;
    if (auto symbol_ref = attr.dyn_cast<mlir::SymbolRefAttr>()) {
      // Function-valued attributes are collected separately so dotted names
      // handled below can be merged into them, then emitted at the end.
      TF_RETURN_IF_ERROR(
          ConvertAttribute(symbol_ref.cast<mlir::FlatSymbolRefAttr>(), &value));
      func_call_attrs[string(name)] = value;
      continue;
    }
    if (auto func_attr = attr.dyn_cast<mlir::TF::FuncAttr>()) {
      TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value));
      func_call_attrs[string(name)] = value;
      continue;
    }
    if (attr.isa<mlir::AffineMapAttr>()) {
      // AffineMapAttr is not implemented.
      return errors::Unimplemented("AffineMap attribute (needed for '",
                                   name_strref, "') unimplemented");
    }
    TF_RETURN_IF_ERROR(
        llvm::TypeSwitch<mlir::Attribute, Status>(attr)
            .Case<mlir::BoolAttr, mlir::IntegerAttr, mlir::FloatAttr,
                  mlir::StringAttr, mlir::ElementsAttr, mlir::UnitAttr,
                  mlir::TF::ShapeAttr, mlir::TF::PlaceholderAttr>(
                [&](auto derived_attr) {
                  return ConvertAttribute(derived_attr, &value);
                })
            .Case<mlir::ArrayAttr, mlir::TypeAttr>([&](auto derived_attr) {
              return ConvertAttribute(derived_attr, remove_ref_type, &value);
            })
            .Default([&](mlir::Attribute) {
              return errors::Unimplemented(
                  "Unhandled attribute kind for attribute '", name_strref,
                  '\'');
            }));

    // According to the NodeDef proto definition, an attribute name from the
    // input TensorFlow GraphDef shouldn't contain '.'. If it does appear in
    // the attribute from MLIR, it is treated as an attribute from function
    // calls.
    std::vector<string> name_tokens =
        absl::StrSplit(name, '.', absl::SkipEmpty());
    TF_RET_CHECK(name_tokens.size() <= 2);
    // NOTE(review): if `name` contains no '.' yet collides with a previously
    // collected func-valued attribute, name_tokens[1] below would be out of
    // range — confirm callers never produce that combination.
    auto it = func_call_attrs.find(name_tokens[0]);
    if (it == func_call_attrs.end()) {
      (*values)[string(name)] = value;
    } else {
      (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = value;
    }
  }
  for (const auto& it : func_call_attrs) {
    (*values)[it.first] = it.second;
  }
  return OkStatus();
}
442 
SetShapeAttribute(absl::string_view name,mlir::ShapedType shaped_type,AttrValueMap * values)443 Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type,
444                          AttrValueMap* values) {
445   AttrValue value;
446   SetTensorShapeProto(shaped_type, value.mutable_list()->add_shape());
447 
448   auto result = values->insert({string(name), value});
449   if (!result.second) {
450     // This should be extremely rare as it means we are adding the same
451     // attribute multiple times/have some redundancy in representing this
452     // attribute.
453     TensorShapeProto actual_shape = result.first->second.shape();
454     // Just check via string output as we shouldn't get here and if we do they
455     // should be trivially the same, else fail.
456     std::string new_shape_string = value.list().shape(0).ShortDebugString();
457     if (actual_shape.ShortDebugString() != new_shape_string) {
458       return errors::InvalidArgument("Expected ", new_shape_string, " '", name,
459                                      "' attribute but found ",
460                                      actual_shape.ShortDebugString());
461     }
462   }
463   return OkStatus();
464 }
465 
IsLegacyCallInstruction(mlir::Operation * inst)466 bool IsLegacyCallInstruction(mlir::Operation* inst) {
467   return llvm::dyn_cast<mlir::TF::LegacyCallOp>(inst);
468 }
469 
AddTensorFlowOpPrefix(std::string prefix)470 Status AddTensorFlowOpPrefix(std::string prefix) {
471   GlobalOpPrefixes()->insert(prefix);
472   return OkStatus();
473 }
474 
475 }  // namespace tensorflow
476