xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
17 
18 #include <memory>
19 
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/string_view.h"
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/IR/Attributes.h"  // from @llvm-project
26 #include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
29 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
30 #include "tensorflow/compiler/mlir/utils/string_container_utils.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/tensor_shape.pb.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/core/status.h"
36 
37 namespace tensorflow {
38 
39 namespace {
40 
41 // Sets type list attribute with the given `name` to the given `types`. If the
42 // attribute already exists with a different value, returns an error.
43 template <typename ContainerT,
44           typename = typename std::enable_if<
45               std::is_same<mlir::Type, decltype(*std::declval<ContainerT>()
46                                                      .begin())>::value>::type>
SetTypeAttribute(absl::string_view name,ContainerT types,AttrValueMap * values)47 Status SetTypeAttribute(absl::string_view name, ContainerT types,
48                         AttrValueMap* values) {
49   AttrValue value;
50   auto& type_list = *value.mutable_list();
51   for (auto type : types) {
52     DataType dtype;
53     TF_RETURN_IF_ERROR(ConvertScalarTypeToDataType(type, &dtype));
54     type_list.add_type(dtype);
55   }
56 
57   auto result = values->insert({string(name), value});
58   assert(result.second && "cannot have multiple attributes with the same name");
59   (void)result;
60 
61   return OkStatus();
62 }
63 
64 // Sets shape list attribute with the given `name` to the given `shapes`. If the
65 // attribute already exists then this will just retain the set value.
66 template <typename ContainerT,
67           typename = typename std::enable_if<std::is_same<
68               llvm::Optional<llvm::ArrayRef<int64_t>>,
69               decltype(*std::declval<ContainerT>().begin())>::value>::type>
SetShapeAttribute(absl::string_view name,ContainerT shapes,AttrValueMap * values)70 void SetShapeAttribute(absl::string_view name, ContainerT shapes,
71                        AttrValueMap* values) {
72   AttrValue value;
73   auto& shape_list = *value.mutable_list();
74   for (const llvm::Optional<llvm::ArrayRef<int64_t>>& shape : shapes) {
75     TensorShapeProto& tshape = *shape_list.add_shape();
76     if (shape.has_value()) {
77       for (int64_t dim : *shape) tshape.add_dim()->set_size(dim);
78     } else {
79       tshape.set_unknown_rank(true);
80     }
81   }
82 
83   // If shape is already set, override it. This can happen if we import
84   // without shape inference enabled and so couldn't be removed on import and
85   // are not explicitly dropped later.
86   (*values)[string(name)] = value;
87 }
88 
89 // Collects all the unregistered attributes for an TF dialect operation.
90 // Attributes "name" and "device" are not included because they are not part
91 // of an TF op attributes.
GetUnregisteredAttrs(mlir::Operation * inst,const tensorflow::OpRegistrationData * op_reg_data,absl::flat_hash_set<absl::string_view> * attrs_to_ignore)92 Status GetUnregisteredAttrs(
93     mlir::Operation* inst, const tensorflow::OpRegistrationData* op_reg_data,
94     absl::flat_hash_set<absl::string_view>* attrs_to_ignore) {
95   if (!op_reg_data) {
96     // This is likely a function call node, so we should continue.
97     return OkStatus();
98   }
99 
100   // Collect all the registered attributes.
101   llvm::DenseSet<llvm::StringRef> registered_attrs;
102   registered_attrs.insert("name");
103   registered_attrs.insert("device");
104   for (const auto& attr_def : op_reg_data->op_def.attr()) {
105     registered_attrs.insert(attr_def.name());
106   }
107   // Attributes are not in the registered attributes set will be ignored.
108   for (auto& attr : inst->getAttrs()) {
109     if (registered_attrs.find(attr.getName()) == registered_attrs.end()) {
110       attrs_to_ignore->insert(
111           absl::string_view(attr.getName().data(), attr.getName().size()));
112     }
113   }
114   return OkStatus();
115 }
116 
117 // Collects all attribute names to ignore in an MLIR operation when exporting to
118 // a TensorFlow NodeDef.
GetAttributesToIgnore(mlir::Operation * inst,mlir::DictionaryAttr derived_attrs,const tensorflow::OpRegistrationData * op_reg_data,bool ignore_unregistered_attrs)119 StatusOr<absl::flat_hash_set<absl::string_view>> GetAttributesToIgnore(
120     mlir::Operation* inst, mlir::DictionaryAttr derived_attrs,
121     const tensorflow::OpRegistrationData* op_reg_data,
122     bool ignore_unregistered_attrs) {
123   // The elements are owned by the MLIRContext.
124   absl::flat_hash_set<absl::string_view> attrs_to_ignore;
125 
126   // We ignore attributes attached to the operation when there is already a
127   // derived attribute defined in ODS.
128   if (derived_attrs) {
129     for (auto derived_attr : derived_attrs) {
130       attrs_to_ignore.insert(
131           mlir::StringRefToView(derived_attr.getName().strref()));
132     }
133   }
134 
135   if (ignore_unregistered_attrs) {
136     TF_RETURN_IF_ERROR(
137         GetUnregisteredAttrs(inst, op_reg_data, &attrs_to_ignore));
138   }
139 
140   if (inst->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
141     // TODO(b/146937733): Don't use <void> here.
142     llvm::StringRef attr_name = mlir::OpTrait::AttrSizedOperandSegments<
143         void>::getOperandSegmentSizeAttr();
144     attrs_to_ignore.insert(attr_name.data());
145   }
146 
147   if (inst->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
148     // TODO(b/146937733): Don't use <void> here.
149     llvm::StringRef attr_name = mlir::OpTrait::AttrSizedResultSegments<
150         void>::getResultSegmentSizeAttr();
151     attrs_to_ignore.insert(attr_name.data());
152   }
153 
154   if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst))
155     attrs_to_ignore.insert("is_stateless");
156 
157   if (llvm::isa<mlir::TF::WhileOp>(inst))
158     attrs_to_ignore.insert("shape_invariant");
159 
160   return attrs_to_ignore;
161 }
162 
163 // Populates all derived attributes of a MLIR operation in a proto
164 // map<string, AttrValue>.
PopulateDerivedAttributes(mlir::Operation * inst,llvm::StringRef name,mlir::DictionaryAttr derived_attrs,bool ignore_unregistered_attrs,AttrValueMap * attributes)165 Status PopulateDerivedAttributes(mlir::Operation* inst, llvm::StringRef name,
166                                  mlir::DictionaryAttr derived_attrs,
167                                  bool ignore_unregistered_attrs,
168                                  AttrValueMap* attributes) {
169   if (derived_attrs) {
170     TF_RETURN_WITH_CONTEXT_IF_ERROR(
171         ConvertAttributes(derived_attrs.getValue(), /*attrs_to_ignore=*/{},
172                           /*remove_ref_type=*/true, attributes),
173         "while converting derived attributes for node: ",
174         mlir::StringRefToView(name));
175   }
176 
177   // Here we only add the shapes for the leading values with ShapedType,
178   // assuming values with non-ShapedType are put at the end of the result.
179   if (!ignore_unregistered_attrs && inst->getNumResults() > 0) {
180     auto values = inst->getResults();
181     auto begin = values.begin();
182     auto end = values.begin();
183     while (end != values.end() && (*end).getType().isa<mlir::ShapedType>())
184       end++;
185     if (begin != end) {
186       mlir::TF::ResultShapeRange output_shapes = {
187           mlir::TF::ResultShapeIterator(begin),
188           mlir::TF::ResultShapeIterator(end)};
189       SetShapeAttribute("_output_shapes", output_shapes, attributes);
190     }
191   }
192 
193   return OkStatus();
194 }
195 
196 }  // namespace
197 
GetAttrValuesFromOperation(mlir::Operation * inst,llvm::StringRef name,const tensorflow::OpRegistrationData * op_reg_data,bool ignore_unregistered_attrs,AttrValueMap * attributes)198 Status GetAttrValuesFromOperation(
199     mlir::Operation* inst, llvm::StringRef name,
200     const tensorflow::OpRegistrationData* op_reg_data,
201     bool ignore_unregistered_attrs, AttrValueMap* attributes) {
202   mlir::DictionaryAttr derived_attrs = nullptr;
203   if (auto interface = llvm::dyn_cast<mlir::DerivedAttributeOpInterface>(inst))
204     derived_attrs = interface.materializeDerivedAttributes();
205   TF_ASSIGN_OR_RETURN(auto attrs_to_ignore,
206                       GetAttributesToIgnore(inst, derived_attrs, op_reg_data,
207                                             ignore_unregistered_attrs));
208   TF_RETURN_WITH_CONTEXT_IF_ERROR(
209       ConvertAttributes(inst->getAttrs(), attrs_to_ignore,
210                         /*remove_ref_type=*/false, attributes),
211       "while converting attributes for node: ", mlir::StringRefToView(name));
212   TF_RETURN_IF_ERROR(PopulateDerivedAttributes(
213       inst, name, derived_attrs, ignore_unregistered_attrs, attributes));
214 
215   //  Explicitly handle XlaHostCompute op which has required function attribute
216   //  in TensorFlow op def but it could have an empty value to represent missing
217   //  functions. This value can't be represented using MLIR SymbolRefAttr and
218   //  instead uses optional symbol ref attribute.
219   //
220   // TODO(b/182315488): Remove custom handling by finding a better
221   // representation in MLIR for empty function names. One option could be to use
222   // TensorFlow op defs to figure out function attributes that are missing in
223   // MLIR. This will also require some trait to identify optional attributes in
224   // MLIR.
225   constexpr char kShapeInferenceGraph[] = "shape_inference_graph";
226   if (mlir::isa<mlir::TF::XlaHostComputeOp>(inst) &&
227       !inst->hasAttr(kShapeInferenceGraph) &&
228       !attrs_to_ignore.contains(kShapeInferenceGraph)) {
229     AttrValue value;
230     value.mutable_func()->set_name("");
231     (*attributes)[kShapeInferenceGraph] = value;
232   }
233   return OkStatus();
234 }
235 
ConvertTFDialectOpToNodeDef(mlir::Operation * inst,llvm::StringRef name,bool ignore_unregistered_attrs)236 StatusOr<std::unique_ptr<NodeDef>> ConvertTFDialectOpToNodeDef(
237     mlir::Operation* inst, llvm::StringRef name,
238     bool ignore_unregistered_attrs) {
239   TF_ASSIGN_OR_RETURN(auto node_def, GetOperationNodeDef(inst, name));
240   TF_ASSIGN_OR_RETURN(auto op_name,
241                       GetTensorFlowOpName(inst->getName().getStringRef()));
242   const tensorflow::OpRegistrationData* op_reg_data =
243       tensorflow::OpRegistry::Global()->LookUp(op_name.str());
244   TF_RETURN_IF_ERROR(GetAttrValuesFromOperation(inst, name, op_reg_data,
245                                                 ignore_unregistered_attrs,
246                                                 node_def->mutable_attr()));
247   return node_def;
248 }
249 
250 }  // namespace tensorflow
251