1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
17
#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_split.h"
24 #include "absl/strings/string_view.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/Casting.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
29 #include "mlir/IR/Attributes.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
32 #include "mlir/IR/Location.h" // from @llvm-project
33 #include "mlir/IR/Operation.h" // from @llvm-project
34 #include "mlir/IR/OperationSupport.h" // from @llvm-project
35 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
36 #include "mlir/Support/DebugStringHelper.h" // from @llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
41 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
43 #include "tensorflow/compiler/mlir/tensorflow/utils/location_utils.h"
44 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
45 #include "tensorflow/compiler/xla/status_macros.h"
46 #include "tensorflow/core/framework/attr_value.pb.h"
47 #include "tensorflow/core/framework/graph.pb.h"
48 #include "tensorflow/core/framework/graph_to_functiondef.h"
49 #include "tensorflow/core/framework/node_def.pb.h"
50 #include "tensorflow/core/framework/node_def_util.h"
51 #include "tensorflow/core/framework/op.h"
52 #include "tensorflow/core/framework/tensor.pb.h"
53 #include "tensorflow/core/framework/tensor_shape.pb.h"
54 #include "tensorflow/core/framework/types.pb.h"
55 #include "tensorflow/core/graph/algorithm.h"
56 #include "tensorflow/core/graph/graph.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/platform/protobuf.h"
59
60 namespace tensorflow {
61 namespace {
// Returns the process-wide set of recognized TensorFlow op-name prefixes.
// The set is intentionally leaked (never destroyed) so it stays valid for
// the lifetime of the process, including during static destruction.
std::set<std::string>* GlobalOpPrefixes() {
  static std::set<std::string>* const global_op_prefixes =
      new std::set<std::string>{"tf.", "tf_executor."};
  return global_op_prefixes;
}
72
73 // Converts a location to the debug information for the node def.
ConvertLocation(mlir::Location inst_loc,llvm::StringRef node_name,NodeDef::ExperimentalDebugInfo * debug_info)74 Status ConvertLocation(mlir::Location inst_loc, llvm::StringRef node_name,
75 NodeDef::ExperimentalDebugInfo* debug_info) {
76 mlir::Location unwrapped_inst_loc = GetLocationWithoutOpType(inst_loc);
77
78 if (auto call_site = unwrapped_inst_loc.dyn_cast<mlir::CallSiteLoc>()) {
79 if (auto name_loc = GetLocationWithoutOpType(call_site.getCallee())
80 .dyn_cast<mlir::NameLoc>()) {
81 llvm::StringRef original_node_name, original_func_name;
82 std::tie(original_node_name, original_func_name) =
83 name_loc.getName().strref().split('@');
84 // The location points to the current node def.
85 if (node_name == original_node_name && original_func_name.empty()) {
86 return OkStatus();
87 }
88 debug_info->add_original_node_names(original_node_name.str());
89 if (!original_func_name.empty()) {
90 debug_info->add_original_func_names(original_func_name.str());
91 }
92 }
93 } else if (auto fused = unwrapped_inst_loc.dyn_cast<mlir::FusedLoc>()) {
94 auto locations = fused.getLocations();
95 if (locations.size() <= 1)
96 return errors::InvalidArgument("expected experimental debuf info.");
97 // skip the first one, which is the name of the node_def.
98 for (int i = 0, end = locations.size() - 1; i < end; ++i) {
99 TF_RETURN_IF_ERROR(ConvertLocation(locations[i], node_name, debug_info));
100 }
101 }
102 return OkStatus();
103 }
104
ConvertAttribute(const mlir::BoolAttr & attr,AttrValue * value)105 Status ConvertAttribute(const mlir::BoolAttr& attr, AttrValue* value) {
106 value->set_b(attr.getValue());
107 return OkStatus();
108 }
109
ConvertAttribute(const mlir::IntegerAttr & attr,AttrValue * value)110 Status ConvertAttribute(const mlir::IntegerAttr& attr, AttrValue* value) {
111 value->set_i(attr.getInt());
112 return OkStatus();
113 }
114
ConvertAttribute(const mlir::FloatAttr & attr,AttrValue * value)115 Status ConvertAttribute(const mlir::FloatAttr& attr, AttrValue* value) {
116 value->set_f(attr.getValueAsDouble());
117 return OkStatus();
118 }
119
ConvertAttribute(const mlir::ElementsAttr & attr,AttrValue * value)120 Status ConvertAttribute(const mlir::ElementsAttr& attr, AttrValue* value) {
121 return ConvertToTensorProto(attr, value->mutable_tensor());
122 }
123
ConvertAttribute(const mlir::TF::PlaceholderAttr & attr,AttrValue * value)124 Status ConvertAttribute(const mlir::TF::PlaceholderAttr& attr,
125 AttrValue* value) {
126 value->set_placeholder(attr.getValue().str());
127 return OkStatus();
128 }
129
ConvertAttribute(const mlir::TF::ShapeAttr & attr,AttrValue * value)130 Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) {
131 SetTensorShapeProto(attr, value->mutable_shape());
132 return OkStatus();
133 }
134
ConvertAttribute(const mlir::FlatSymbolRefAttr & attr,AttrValue * value)135 Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
136 value->mutable_func()->set_name(attr.getValue().str());
137 return OkStatus();
138 }
139
ConvertAttribute(const mlir::TF::FuncAttr & attr,bool remove_ref_type,AttrValue * value)140 Status ConvertAttribute(const mlir::TF::FuncAttr& attr, bool remove_ref_type,
141 AttrValue* value) {
142 TF_RETURN_IF_ERROR(
143 ConvertAttribute(attr.getName().cast<mlir::FlatSymbolRefAttr>(), value));
144 TF_RETURN_IF_ERROR(ConvertAttributes(attr.getAttrs().getValue(),
145 /*attrs_to_ignore=*/{}, remove_ref_type,
146 value->mutable_func()->mutable_attr()));
147 return OkStatus();
148 }
149
ConvertAttribute(const mlir::StringAttr & attr,AttrValue * value)150 Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
151 absl::string_view attr_value(attr.getValue().data(), attr.getValue().size());
152 switch (mangling_util::GetMangledKind(attr_value)) {
153 case mangling_util::MangledKind::kUnknown: {
154 value->set_s(std::string(attr_value));
155 return OkStatus();
156 }
157 case mangling_util::MangledKind::kDataType: {
158 DataType dtype;
159 TF_RETURN_IF_ERROR(mangling_util::DemangleDataType(attr_value, &dtype));
160 value->set_type(dtype);
161 return OkStatus();
162 }
163 case mangling_util::MangledKind::kTensorShape:
164 TF_RETURN_IF_ERROR(
165 mangling_util::DemangleShape(attr_value, value->mutable_shape()));
166 return OkStatus();
167 default:
168 return errors::Unimplemented("Mangled string couldn't be handled!");
169 }
170 return OkStatus();
171 }
172
ConvertAttribute(mlir::Type type,bool remove_ref_type,AttrValue * value)173 Status ConvertAttribute(mlir::Type type, bool remove_ref_type,
174 AttrValue* value) {
175 DataType dtype;
176 TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype));
177 if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype);
178 value->set_type(dtype);
179 return OkStatus();
180 }
181
ConvertAttribute(const mlir::TypeAttr & type,bool remove_ref_type,AttrValue * value)182 Status ConvertAttribute(const mlir::TypeAttr& type, bool remove_ref_type,
183 AttrValue* value) {
184 return ConvertAttribute(type.getValue(), remove_ref_type, value);
185 }
186
// Converts a unit attribute. A UnitAttr carries no payload, so the AttrValue
// is deliberately cleared and left empty.
Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
  value->clear_value();
  return OkStatus();
}
191
// Converts an MLIR array attribute into the `list` field of an AttrValue,
// dispatching each element on its attribute kind. TF attribute lists are
// homogeneous, so for any given input only one of the parallel repeated
// fields (b/i/f/s/type/shape/tensor/func) is expected to be populated.
// `remove_ref_type` is forwarded only to the nested type conversion.
Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type,
                        AttrValue* value) {
  auto* list = value->mutable_list();
  for (mlir::Attribute a : attr.getValue()) {
    if (auto attr = a.dyn_cast<mlir::BoolAttr>()) {
      list->add_b(attr.getValue());
    } else if (auto attr = a.dyn_cast<mlir::IntegerAttr>()) {
      list->add_i(attr.getInt());
    } else if (auto attr = a.dyn_cast<mlir::FloatAttr>()) {
      list->add_f(attr.getValueAsDouble());
    } else if (auto attr = a.dyn_cast<mlir::StringAttr>()) {
      // A string element may be a mangled DataType or shape; convert it
      // first, then route the result to the matching list field.
      AttrValue nested_value;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value));
      switch (nested_value.value_case()) {
        case AttrValue::kS:
          list->add_s(nested_value.s());
          break;
        case AttrValue::kType:
          list->add_type(nested_value.type());
          break;
        case AttrValue::kShape:
          *list->add_shape() = nested_value.shape();
          break;
        default:
          return errors::Unimplemented("Unhandled nested attribute!");
      }
    } else if (auto attr = a.dyn_cast<mlir::ElementsAttr>()) {
      TensorProto tensor;
      TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
      *list->add_tensor() = tensor;
    } else if (auto attr = a.dyn_cast<mlir::FlatSymbolRefAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_func() = attr_val.func();
    } else if (auto attr = a.dyn_cast<mlir::TypeAttr>()) {
      AttrValue attr_val;
      // For type attributes, we only propagate the element type.
      mlir::Type elt_type = attr.getValue();
      if (auto shaped_type = elt_type.dyn_cast<mlir::ShapedType>()) {
        elt_type = shaped_type.getElementType();
      }
      TF_RETURN_IF_ERROR(
          ConvertAttribute(elt_type, remove_ref_type, &attr_val));
      list->add_type(attr_val.type());
    } else if (auto attr = a.dyn_cast<mlir::TF::ShapeAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_shape() = attr_val.shape();
    } else if (auto attr = a.dyn_cast<mlir::ArrayAttr>()) {
      // A nested array must contain only integers; it is exported as a
      // rank-1 int64 tensor.
      std::vector<int64_t> vals;
      for (mlir::Attribute a : attr.getValue()) {
        auto i = a.dyn_cast<mlir::IntegerAttr>();
        if (!i)
          return errors::Unimplemented(
              "Expected 64-bit integer array attributes!");
        vals.push_back(i.getInt());
      }
      mlir::OpBuilder builder(attr.getContext());
      mlir::TensorType ty =
          mlir::RankedTensorType::get(vals.size(), builder.getIntegerType(64));
      TensorProto tensor;
      TF_RETURN_IF_ERROR(ConvertToTensorProto(
          mlir::DenseIntElementsAttr::get(ty, vals), &tensor));
      *list->add_tensor() = tensor;
    } else {
      return errors::Unimplemented("Unhandled attribute!");
    }
  }
  return OkStatus();
}
262
263 // Returns true if the executor/control dialect op should map to Ref node in
264 // TensorFlow Graph. For control dialect NextIteration it uses the 1st operand
265 // type. For executor dialect NextIteration it uses the 2nd operand type. For
266 // all others (Enter/Exit/Merge/Switch), if the output type is ref, they
267 // correspond to the Ref equivalent op in TF Graph.
IsRefTypeControlOp(mlir::Operation * op)268 static bool IsRefTypeControlOp(mlir::Operation* op) {
269 if (auto next_iter_sink =
270 llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(op))
271 return mlir::getElementTypeOrSelf(next_iter_sink.input().getType())
272 .isa<mlir::TF::TensorFlowRefType>();
273
274 auto op_name_or_status = GetTensorFlowOpName(op->getName().getStringRef());
275 if (!op_name_or_status.ok()) return false;
276
277 auto op_name = std::move(op_name_or_status).value();
278 if (op_name.equals("NextIteration"))
279 return mlir::getElementTypeOrSelf(op->getOperand(0).getType())
280 .isa<mlir::TF::TensorFlowRefType>();
281
282 if (op_name.equals("Enter") || op_name.equals("Exit") ||
283 op_name.equals("Switch") || op_name.equals("Merge")) {
284 return getElementTypeOrSelf(op->getResult(0).getType())
285 .isa<mlir::TF::TensorFlowRefType>();
286 }
287 return false;
288 }
289
290 } // anonymous namespace
291
GetTensorFlowOpName(llvm::StringRef op_name)292 StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
293 // When being converted to MLIR, some prefixes and suffixes are added to the
294 // operation types, and we have to remove them when converting the
295 // operations back to a graph:
296 // - "tf." or "tf_executor." : every operation type has this prefix.
297 // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
298 // don't need to consider ".source"/".Source" because the nodes with this
299 // suffix are skipped by the caller and will not be added to the graph.
300 auto prefixes = GlobalOpPrefixes();
301 if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) {
302 return op_name.consume_front(prefix);
303 })) {
304 return errors::FailedPrecondition("op node '", op_name.str(),
305 "' was not a TF op!");
306 }
307 // Control dialect NextIteration sink ends with ".sink" and Executor dialect
308 // NextIteration sink ends with ".Sink".
309 if (!op_name.consume_back(".sink")) op_name.consume_back(".Sink");
310 return op_name;
311 }
312
// Builds a NodeDef for `inst` with node name `name`: resolves the TF op name
// (legacy-call target, "Ref"-variant, or "Stateless"-variant as applicable),
// pre-allocates one empty input per operand, and copies the device attribute
// and experimental debug info. Other attributes are converted separately.
// Side effect: for legacy call instructions, the "f" attribute is removed
// from `inst` after being folded into the op name.
StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
    mlir::Operation* inst, llvm::StringRef name) {
  auto node_def = std::make_unique<NodeDef>();
  // Note: we do not use NodeBuilder or NodeDefBuilder as that would require
  // mapping back from the inputs to the input arguments.

  llvm::SmallString<64> op_name;
  if (IsLegacyCallInstruction(inst)) {
    // The op_name is the name of the function.
    op_name.append(inst->getAttrOfType<mlir::SymbolRefAttr>("f")
                       .getLeafReference()
                       .getValue());
    // Remove the attribute from the instruction as it is already converted to
    // op_name.
    auto attr_id = mlir::StringAttr::get(inst->getContext(), "f");
    inst->removeAttr(attr_id);
  } else {
    // Some control flow ops in TensorFlow Graph have their respective "Ref" ops
    // as well. For example there is Enter and RefEnter op. RefEnter forwards
    // the input ref buffer to output. However both Enter and RefEnter are
    // mapped to tf_executor::EnterOp during import. Check if it is a Ref op to
    // correctly map to the TensorFlow Graph op.
    if (IsRefTypeControlOp(inst)) op_name = "Ref";
    TF_ASSIGN_OR_RETURN(auto tf_name,
                        GetTensorFlowOpName(inst->getName().getStringRef()));
    op_name.append(tf_name);
  }

  node_def->set_name(name.str());
  node_def->set_op(std::string(op_name.str()));

  // Update NodeDef constructed out of an MLIR Case/If/While op to map it to
  // either TensorFlow StatelessX or X op depending on the additional attribute.
  if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst)) {
    auto stateless = inst->getAttrOfType<mlir::BoolAttr>("is_stateless");
    if (stateless && stateless.getValue())
      *node_def->mutable_op() = "Stateless" + node_def->op();
  }

  // Add inputs to the NodeDef based on the number of operands. This is required
  // as later when edges are added to the Node using Graph::AddEdge the
  // associated NodeDef is not updated.
  for (int i = 0, e = inst->getNumOperands(); i < e; ++i) {
    node_def->add_input();
  }
  if (auto attr = inst->getAttrOfType<mlir::StringAttr>("device")) {
    node_def->set_device(std::string(attr.getValue()));
  }

  // Add the node debug info.
  TF_RETURN_IF_ERROR(ConvertLocation(
      inst->getLoc(), name, node_def->mutable_experimental_debug_info()));

  return node_def;
}
368
// Converts the MLIR named attributes `attrs` into entries of `values`,
// skipping "name", "device", and anything in `attrs_to_ignore`. Symbol-ref
// and TF func attributes are collected first into `func_call_attrs` so that
// later attributes with dotted names ("func.attr") can be merged into the
// attr map of the matching function-call attribute before everything is
// copied into `values`.
Status ConvertAttributes(
    const llvm::ArrayRef<mlir::NamedAttribute> attrs,
    const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
    bool remove_ref_type, AttrValueMap* values) {
  AttrValueMap func_call_attrs;
  for (const mlir::NamedAttribute& named_attr : attrs) {
    auto name_strref = named_attr.getName().str();
    auto attr = named_attr.getValue();
    absl::string_view name(name_strref.data(), name_strref.size());
    if (name == "name" || name == "device" || attrs_to_ignore.contains(name)) {
      // The name, device spec of a TF op or function are not stored as
      // AttrValue inside NodeDef, but we model them using attribute inside
      // MLIR. So we need to ignore them when going back to AttrValue here.
      continue;
    }
    if (mangling_util::IsMangledAttributeName(name)) {
      // In MLIR, attributes for functions requires dialect prefix. We need to
      // remove TF dialect prefix before converting to AttrValue.
      name = mangling_util::DemangleAttributeName(name);
    }
    AttrValue value;
    if (auto symbol_ref = attr.dyn_cast<mlir::SymbolRefAttr>()) {
      TF_RETURN_IF_ERROR(
          ConvertAttribute(symbol_ref.cast<mlir::FlatSymbolRefAttr>(), &value));
      func_call_attrs[string(name)] = value;
      continue;
    }
    if (auto func_attr = attr.dyn_cast<mlir::TF::FuncAttr>()) {
      TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value));
      func_call_attrs[string(name)] = value;
      continue;
    }
    if (attr.isa<mlir::AffineMapAttr>()) {
      // AffineMapAttr is not implemented.
      return errors::Unimplemented("AffineMap attribute (needed for '",
                                   name_strref, "') unimplemented");
    }
    // Dispatch the remaining attribute kinds to the matching ConvertAttribute
    // overload; only Array/Type attributes consume the remove_ref_type flag.
    TF_RETURN_IF_ERROR(
        llvm::TypeSwitch<mlir::Attribute, Status>(attr)
            .Case<mlir::BoolAttr, mlir::IntegerAttr, mlir::FloatAttr,
                  mlir::StringAttr, mlir::ElementsAttr, mlir::UnitAttr,
                  mlir::TF::ShapeAttr, mlir::TF::PlaceholderAttr>(
                [&](auto derived_attr) {
                  return ConvertAttribute(derived_attr, &value);
                })
            .Case<mlir::ArrayAttr, mlir::TypeAttr>([&](auto derived_attr) {
              return ConvertAttribute(derived_attr, remove_ref_type, &value);
            })
            .Default([&](mlir::Attribute) {
              return errors::Unimplemented(
                  "Unhandled attribute kind for attribute '", name_strref,
                  '\'');
            }));

    // According to the NodeDef proto definition, an attribute name from the
    // input TensorFlow GraphDef shouldn't contain '.'. If it does appear in
    // the attribute from MLIR, it is treated as an attribute from function
    // calls.
    // NOTE(review): if `name` consisted solely of '.' characters,
    // `name_tokens` would be empty and `name_tokens[0]` below would be out of
    // bounds; likewise a single-token name that collides with an entry in
    // `func_call_attrs` would index `name_tokens[1]` out of bounds —
    // presumably neither occurs in practice; verify against callers.
    std::vector<string> name_tokens =
        absl::StrSplit(name, '.', absl::SkipEmpty());
    TF_RET_CHECK(name_tokens.size() <= 2);
    auto it = func_call_attrs.find(name_tokens[0]);
    if (it == func_call_attrs.end()) {
      (*values)[string(name)] = value;
    } else {
      (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = value;
    }
  }
  // Emit the collected function-call attributes after the merge pass above.
  for (const auto& it : func_call_attrs) {
    (*values)[it.first] = it.second;
  }
  return OkStatus();
}
442
SetShapeAttribute(absl::string_view name,mlir::ShapedType shaped_type,AttrValueMap * values)443 Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type,
444 AttrValueMap* values) {
445 AttrValue value;
446 SetTensorShapeProto(shaped_type, value.mutable_list()->add_shape());
447
448 auto result = values->insert({string(name), value});
449 if (!result.second) {
450 // This should be extremely rare as it means we are adding the same
451 // attribute multiple times/have some redundancy in representing this
452 // attribute.
453 TensorShapeProto actual_shape = result.first->second.shape();
454 // Just check via string output as we shouldn't get here and if we do they
455 // should be trivially the same, else fail.
456 std::string new_shape_string = value.list().shape(0).ShortDebugString();
457 if (actual_shape.ShortDebugString() != new_shape_string) {
458 return errors::InvalidArgument("Expected ", new_shape_string, " '", name,
459 "' attribute but found ",
460 actual_shape.ShortDebugString());
461 }
462 }
463 return OkStatus();
464 }
465
IsLegacyCallInstruction(mlir::Operation * inst)466 bool IsLegacyCallInstruction(mlir::Operation* inst) {
467 return llvm::dyn_cast<mlir::TF::LegacyCallOp>(inst);
468 }
469
AddTensorFlowOpPrefix(std::string prefix)470 Status AddTensorFlowOpPrefix(std::string prefix) {
471 GlobalOpPrefixes()->insert(prefix);
472 return OkStatus();
473 }
474
475 } // namespace tensorflow
476