1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"

#include <memory>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
#include "tensorflow/compiler/mlir/utils/string_container_utils.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
36
37 namespace tensorflow {
38
39 namespace {
40
41 // Sets type list attribute with the given `name` to the given `types`. If the
42 // attribute already exists with a different value, returns an error.
43 template <typename ContainerT,
44 typename = typename std::enable_if<
45 std::is_same<mlir::Type, decltype(*std::declval<ContainerT>()
46 .begin())>::value>::type>
SetTypeAttribute(absl::string_view name,ContainerT types,AttrValueMap * values)47 Status SetTypeAttribute(absl::string_view name, ContainerT types,
48 AttrValueMap* values) {
49 AttrValue value;
50 auto& type_list = *value.mutable_list();
51 for (auto type : types) {
52 DataType dtype;
53 TF_RETURN_IF_ERROR(ConvertScalarTypeToDataType(type, &dtype));
54 type_list.add_type(dtype);
55 }
56
57 auto result = values->insert({string(name), value});
58 assert(result.second && "cannot have multiple attributes with the same name");
59 (void)result;
60
61 return OkStatus();
62 }
63
64 // Sets shape list attribute with the given `name` to the given `shapes`. If the
65 // attribute already exists then this will just retain the set value.
66 template <typename ContainerT,
67 typename = typename std::enable_if<std::is_same<
68 llvm::Optional<llvm::ArrayRef<int64_t>>,
69 decltype(*std::declval<ContainerT>().begin())>::value>::type>
SetShapeAttribute(absl::string_view name,ContainerT shapes,AttrValueMap * values)70 void SetShapeAttribute(absl::string_view name, ContainerT shapes,
71 AttrValueMap* values) {
72 AttrValue value;
73 auto& shape_list = *value.mutable_list();
74 for (const llvm::Optional<llvm::ArrayRef<int64_t>>& shape : shapes) {
75 TensorShapeProto& tshape = *shape_list.add_shape();
76 if (shape.has_value()) {
77 for (int64_t dim : *shape) tshape.add_dim()->set_size(dim);
78 } else {
79 tshape.set_unknown_rank(true);
80 }
81 }
82
83 // If shape is already set, override it. This can happen if we import
84 // without shape inference enabled and so couldn't be removed on import and
85 // are not explicitly dropped later.
86 (*values)[string(name)] = value;
87 }
88
89 // Collects all the unregistered attributes for an TF dialect operation.
90 // Attributes "name" and "device" are not included because they are not part
91 // of an TF op attributes.
GetUnregisteredAttrs(mlir::Operation * inst,const tensorflow::OpRegistrationData * op_reg_data,absl::flat_hash_set<absl::string_view> * attrs_to_ignore)92 Status GetUnregisteredAttrs(
93 mlir::Operation* inst, const tensorflow::OpRegistrationData* op_reg_data,
94 absl::flat_hash_set<absl::string_view>* attrs_to_ignore) {
95 if (!op_reg_data) {
96 // This is likely a function call node, so we should continue.
97 return OkStatus();
98 }
99
100 // Collect all the registered attributes.
101 llvm::DenseSet<llvm::StringRef> registered_attrs;
102 registered_attrs.insert("name");
103 registered_attrs.insert("device");
104 for (const auto& attr_def : op_reg_data->op_def.attr()) {
105 registered_attrs.insert(attr_def.name());
106 }
107 // Attributes are not in the registered attributes set will be ignored.
108 for (auto& attr : inst->getAttrs()) {
109 if (registered_attrs.find(attr.getName()) == registered_attrs.end()) {
110 attrs_to_ignore->insert(
111 absl::string_view(attr.getName().data(), attr.getName().size()));
112 }
113 }
114 return OkStatus();
115 }
116
117 // Collects all attribute names to ignore in an MLIR operation when exporting to
118 // a TensorFlow NodeDef.
GetAttributesToIgnore(mlir::Operation * inst,mlir::DictionaryAttr derived_attrs,const tensorflow::OpRegistrationData * op_reg_data,bool ignore_unregistered_attrs)119 StatusOr<absl::flat_hash_set<absl::string_view>> GetAttributesToIgnore(
120 mlir::Operation* inst, mlir::DictionaryAttr derived_attrs,
121 const tensorflow::OpRegistrationData* op_reg_data,
122 bool ignore_unregistered_attrs) {
123 // The elements are owned by the MLIRContext.
124 absl::flat_hash_set<absl::string_view> attrs_to_ignore;
125
126 // We ignore attributes attached to the operation when there is already a
127 // derived attribute defined in ODS.
128 if (derived_attrs) {
129 for (auto derived_attr : derived_attrs) {
130 attrs_to_ignore.insert(
131 mlir::StringRefToView(derived_attr.getName().strref()));
132 }
133 }
134
135 if (ignore_unregistered_attrs) {
136 TF_RETURN_IF_ERROR(
137 GetUnregisteredAttrs(inst, op_reg_data, &attrs_to_ignore));
138 }
139
140 if (inst->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
141 // TODO(b/146937733): Don't use <void> here.
142 llvm::StringRef attr_name = mlir::OpTrait::AttrSizedOperandSegments<
143 void>::getOperandSegmentSizeAttr();
144 attrs_to_ignore.insert(attr_name.data());
145 }
146
147 if (inst->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
148 // TODO(b/146937733): Don't use <void> here.
149 llvm::StringRef attr_name = mlir::OpTrait::AttrSizedResultSegments<
150 void>::getResultSegmentSizeAttr();
151 attrs_to_ignore.insert(attr_name.data());
152 }
153
154 if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst))
155 attrs_to_ignore.insert("is_stateless");
156
157 if (llvm::isa<mlir::TF::WhileOp>(inst))
158 attrs_to_ignore.insert("shape_invariant");
159
160 return attrs_to_ignore;
161 }
162
163 // Populates all derived attributes of a MLIR operation in a proto
164 // map<string, AttrValue>.
PopulateDerivedAttributes(mlir::Operation * inst,llvm::StringRef name,mlir::DictionaryAttr derived_attrs,bool ignore_unregistered_attrs,AttrValueMap * attributes)165 Status PopulateDerivedAttributes(mlir::Operation* inst, llvm::StringRef name,
166 mlir::DictionaryAttr derived_attrs,
167 bool ignore_unregistered_attrs,
168 AttrValueMap* attributes) {
169 if (derived_attrs) {
170 TF_RETURN_WITH_CONTEXT_IF_ERROR(
171 ConvertAttributes(derived_attrs.getValue(), /*attrs_to_ignore=*/{},
172 /*remove_ref_type=*/true, attributes),
173 "while converting derived attributes for node: ",
174 mlir::StringRefToView(name));
175 }
176
177 // Here we only add the shapes for the leading values with ShapedType,
178 // assuming values with non-ShapedType are put at the end of the result.
179 if (!ignore_unregistered_attrs && inst->getNumResults() > 0) {
180 auto values = inst->getResults();
181 auto begin = values.begin();
182 auto end = values.begin();
183 while (end != values.end() && (*end).getType().isa<mlir::ShapedType>())
184 end++;
185 if (begin != end) {
186 mlir::TF::ResultShapeRange output_shapes = {
187 mlir::TF::ResultShapeIterator(begin),
188 mlir::TF::ResultShapeIterator(end)};
189 SetShapeAttribute("_output_shapes", output_shapes, attributes);
190 }
191 }
192
193 return OkStatus();
194 }
195
196 } // namespace
197
GetAttrValuesFromOperation(mlir::Operation * inst,llvm::StringRef name,const tensorflow::OpRegistrationData * op_reg_data,bool ignore_unregistered_attrs,AttrValueMap * attributes)198 Status GetAttrValuesFromOperation(
199 mlir::Operation* inst, llvm::StringRef name,
200 const tensorflow::OpRegistrationData* op_reg_data,
201 bool ignore_unregistered_attrs, AttrValueMap* attributes) {
202 mlir::DictionaryAttr derived_attrs = nullptr;
203 if (auto interface = llvm::dyn_cast<mlir::DerivedAttributeOpInterface>(inst))
204 derived_attrs = interface.materializeDerivedAttributes();
205 TF_ASSIGN_OR_RETURN(auto attrs_to_ignore,
206 GetAttributesToIgnore(inst, derived_attrs, op_reg_data,
207 ignore_unregistered_attrs));
208 TF_RETURN_WITH_CONTEXT_IF_ERROR(
209 ConvertAttributes(inst->getAttrs(), attrs_to_ignore,
210 /*remove_ref_type=*/false, attributes),
211 "while converting attributes for node: ", mlir::StringRefToView(name));
212 TF_RETURN_IF_ERROR(PopulateDerivedAttributes(
213 inst, name, derived_attrs, ignore_unregistered_attrs, attributes));
214
215 // Explicitly handle XlaHostCompute op which has required function attribute
216 // in TensorFlow op def but it could have an empty value to represent missing
217 // functions. This value can't be represented using MLIR SymbolRefAttr and
218 // instead uses optional symbol ref attribute.
219 //
220 // TODO(b/182315488): Remove custom handling by finding a better
221 // representation in MLIR for empty function names. One option could be to use
222 // TensorFlow op defs to figure out function attributes that are missing in
223 // MLIR. This will also require some trait to identify optional attributes in
224 // MLIR.
225 constexpr char kShapeInferenceGraph[] = "shape_inference_graph";
226 if (mlir::isa<mlir::TF::XlaHostComputeOp>(inst) &&
227 !inst->hasAttr(kShapeInferenceGraph) &&
228 !attrs_to_ignore.contains(kShapeInferenceGraph)) {
229 AttrValue value;
230 value.mutable_func()->set_name("");
231 (*attributes)[kShapeInferenceGraph] = value;
232 }
233 return OkStatus();
234 }
235
ConvertTFDialectOpToNodeDef(mlir::Operation * inst,llvm::StringRef name,bool ignore_unregistered_attrs)236 StatusOr<std::unique_ptr<NodeDef>> ConvertTFDialectOpToNodeDef(
237 mlir::Operation* inst, llvm::StringRef name,
238 bool ignore_unregistered_attrs) {
239 TF_ASSIGN_OR_RETURN(auto node_def, GetOperationNodeDef(inst, name));
240 TF_ASSIGN_OR_RETURN(auto op_name,
241 GetTensorFlowOpName(inst->getName().getStringRef()));
242 const tensorflow::OpRegistrationData* op_reg_data =
243 tensorflow::OpRegistry::Global()->LookUp(op_name.str());
244 TF_RETURN_IF_ERROR(GetAttrValuesFromOperation(inst, name, op_reg_data,
245 ignore_unregistered_attrs,
246 node_def->mutable_attr()));
247 return node_def;
248 }
249
250 } // namespace tensorflow
251