1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
22 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/Builders.h" // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
27 #include "mlir/IR/Diagnostics.h" // from @llvm-project
28 #include "mlir/IR/Location.h" // from @llvm-project
29 #include "mlir/IR/MLIRContext.h" // from @llvm-project
30 #include "mlir/IR/Matchers.h" // from @llvm-project
31 #include "mlir/IR/Operation.h" // from @llvm-project
32 #include "mlir/IR/PatternMatch.h" // from @llvm-project
33 #include "mlir/Pass/Pass.h" // from @llvm-project
34 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
35 #include "mlir/Support/LLVM.h" // from @llvm-project
36 #include "mlir/Support/LogicalResult.h" // from @llvm-project
37 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
38 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
39 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
40 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
44 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
45 #include "tensorflow/core/framework/node_def.pb.h"
46 #include "tensorflow/core/framework/node_def_util.h"
47 #include "tensorflow/core/framework/op.h"
48 #include "tensorflow/core/framework/op_def.pb.h"
49 #include "tensorflow/core/platform/errors.h"
50 #include "tensorflow/core/platform/protobuf.h"
51
52 // The pass lifts TFLite Flex custom ops into TF dialect operations.
53 // Note: this pass is experimental, so not guaranteed to work with all Flex ops.
54
55 namespace mlir {
56 namespace TFL {
57 namespace {
58 #define GEN_PASS_CLASSES
59 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
60
using ::tensorflow::StatusOr;

// Name prefix that marks a TFL custom op as a wrapped TensorFlow (Flex) op;
// the remainder of `custom_code` after this prefix is the TF op name.
constexpr StringRef kFlexOpNamePrefix = "Flex";
64
65 // Pattern that converts TFL::CustomOp that encodes a Flex op into a TF dialect
66 // operation.
67 class LiftFlexCustomOp : public OpRewritePattern<TFL::CustomOp> {
68 public:
69 using OpRewritePattern<TFL::CustomOp>::OpRewritePattern;
70
matchAndRewrite(TFL::CustomOp op,PatternRewriter & rewriter) const71 LogicalResult matchAndRewrite(TFL::CustomOp op,
72 PatternRewriter& rewriter) const override {
73 if (!op.custom_code().startswith(kFlexOpNamePrefix)) {
74 return failure();
75 }
76
77 llvm::StringRef tf_op_name =
78 op.custom_code().substr(kFlexOpNamePrefix.size());
79 const std::string tf_op_full_name = llvm::Twine("tf.", tf_op_name).str();
80
81 // Create the TF op
82 OperationState op_state(op.getLoc(), tf_op_full_name);
83 op_state.addOperands(op.getOperands());
84 op_state.addTypes(op.getResultTypes());
85
86 SmallVector<NamedAttribute, 2> attrs;
87 std::string parsed_op_name;
88 tensorflow::NodeDef node_def;
89 if (failed(ParseCustomOption(op.custom_option().getValue(), op.getLoc(),
90 parsed_op_name, attrs, node_def))) {
91 return failure();
92 }
93 if (parsed_op_name != tf_op_name) {
94 return op.emitOpError(
95 "TF op names in 'custom_code' and 'custom_option' don't match");
96 }
97 const tensorflow::OpDef* op_def;
98
99 // This will fail only if the op is not a registered TensorFlow op.
100 if (!tensorflow::OpRegistry::Global()
101 ->LookUpOpDef(parsed_op_name, &op_def)
102 .ok()) {
103 op->emitError() << "can't find registered TF op for " << parsed_op_name
104 << ". Please make sure the op library for "
105 << parsed_op_name << " is linked properly";
106 return failure();
107 }
108 op_state.addAttributes(attrs);
109
110 Operation* tf_op = rewriter.create(op_state);
111 rewriter.replaceOp(op, tf_op->getResults());
112
113 // Special type fixes for TF Resource Tensors that are casted to
114 // Int32 tensor during MLIR->TFLite flatbuffer conversion.
115 // TODO(b/146131919): correct handling of resource type
116 if (auto tensor_array_v3_op = dyn_cast<TF::TensorArrayV3Op>(tf_op)) {
117 Value handle = tensor_array_v3_op.handle();
118 auto handle_type = handle.getType().cast<TensorType>();
119 if (handle_type.getElementType().isInteger(/*width=*/32)) {
120 Type resource_tensor_type =
121 handle_type.clone(TF::ResourceType::get(rewriter.getContext()));
122 handle.setType(resource_tensor_type);
123 }
124 }
125
126 // Special type fixes for scalar tensor types.
127 // TFLite flatbuffer schema doesn't distinguish scalar tensor shapes
128 // and unranked tensor shapes (i.e. they are both represented as an empty
129 // INT32 list), see b/138865275. MLIR importer conservatively treats them as
130 // unranked tensor types. Here we set them to scalar tensor types when it is
131 // safe.
132 if (auto tensor_array_v3_op = dyn_cast<TF::TensorArrayV3Op>(tf_op)) {
133 // The "flow" in TensorArrayV3 is always a scalar float tensor.
134 // https://www.tensorflow.org/api_docs/python/tf/raw_ops/TensorArrayWriteV3
135 Value flow = tensor_array_v3_op.flow();
136 Type scalar_f32_tensor_type =
137 RankedTensorType::get(/*shape=*/{}, rewriter.getF32Type());
138 flow.setType(scalar_f32_tensor_type);
139 }
140
141 // Sets operand_segment_sizes or result_segment_sizes attribute to the op.
142 // Referenced from tensorflow::ImporterBase::CreateOperation
143
144 const auto set_segment_sizes_attr =
145 [&](const tensorflow::NameRangeMap& arg_ranges,
146 const tensorflow::protobuf::RepeatedPtrField<
147 tensorflow::OpDef::ArgDef>& args,
148 llvm::StringRef attr_name) {
149 std::vector<int32_t> values;
150 values.reserve(args.size());
151 for (const auto& arg : args) {
152 auto range = arg_ranges.at(arg.name());
153 values.push_back(
154 range.second - range.first);
155 }
156 auto attr_value = mlir::DenseI32ArrayAttr::get(tf_op->getContext(), values);
157 tf_op->setAttr(attr_name, attr_value);
158 };
159 if (tf_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>() ||
160 tf_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
161 // The op has multiple variadic operands or results.
162 // Calculate operand and result segment sizes using the OpDef.
163 tensorflow::NameRangeMap input_ranges, output_ranges;
164 // This will fail only if the OpDef is syntactically invalid.
165 if (!NameRangesForNode(node_def, *op_def, &input_ranges, &output_ranges)
166 .ok()) {
167 tf_op->emitError("malformed opdef");
168 return failure();
169 }
170 if (tf_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
171 // Add derived "operand_segment_sizes" attr to the created operation.
172 // TODO(b/146937733): Don't use <void> here.
173 set_segment_sizes_attr(input_ranges, op_def->input_arg(),
174 mlir::OpTrait::AttrSizedOperandSegments<
175 void>::getOperandSegmentSizeAttr());
176 }
177
178 if (tf_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
179 // Add derived "result_segment_sizes" attr to the created operation.
180 // TODO(b/146937733): Don't use <void> here.
181 set_segment_sizes_attr(output_ranges, op_def->output_arg(),
182 mlir::OpTrait::AttrSizedResultSegments<
183 void>::getResultSegmentSizeAttr());
184 }
185 }
186
187 return success();
188 }
189
190 private:
191 // Parses TFLite Flex op's `custom_options` and returns the TF
192 // `op_name` and TF dialect MLIR op attributes.
ParseCustomOption(StringRef custom_options,Location loc,std::string & op_name,SmallVectorImpl<NamedAttribute> & attributes,tensorflow::NodeDef & node_def)193 static LogicalResult ParseCustomOption(
194 StringRef custom_options, Location loc, std::string& op_name,
195 SmallVectorImpl<NamedAttribute>& attributes,
196 tensorflow::NodeDef& node_def) {
197 // The flexbuffer contains a vector where the first elements is the
198 // op name and the second is a serialized NodeDef.
199 const flexbuffers::Vector& v =
200 flexbuffers::GetRoot(
201 reinterpret_cast<const uint8_t*>(custom_options.data()),
202 custom_options.size())
203 .AsVector();
204
205 op_name = v[0].AsString().str();
206
207 if (!node_def.ParseFromString(v[1].AsString().str())) {
208 return emitError(
209 loc, "failed to parse 'custom_options' data into a valid NodeDef");
210 }
211
212 OpBuilder builder(loc.getContext());
213 for (const auto& name_and_value : node_def.attr()) {
214 const std::string& attr_name = name_and_value.first;
215 const tensorflow::AttrValue& attr_value = name_and_value.second;
216 StatusOr<Attribute> mlir_attr =
217 tensorflow::ConvertAttributeValue(attr_value, &builder);
218 if (!mlir_attr.ok()) {
219 return emitError(loc, mlir_attr.status().error_message());
220 }
221 attributes.push_back(builder.getNamedAttr(attr_name, *mlir_attr));
222 }
223 return success();
224 }
225 };
226
227 class LiftTfliteFlexOpsPass
228 : public LiftTfliteFlexOpsPassBase<LiftTfliteFlexOpsPass> {
getDependentDialects(DialectRegistry & registry) const229 void getDependentDialects(DialectRegistry& registry) const override {
230 registry.insert<TF::TensorFlowDialect>();
231 }
232
233 public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LiftTfliteFlexOpsPass)234 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LiftTfliteFlexOpsPass)
235
236 void runOnOperation() override {
237 MLIRContext* context = &getContext();
238 func::FuncOp func = getOperation();
239
240 mlir::RewritePatternSet patterns(context);
241 patterns.add<LiftFlexCustomOp>(context);
242 if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
243 signalPassFailure();
244 return;
245 }
246 }
247 };
248
249 } // namespace
250
CreateLiftTfliteFlexOpsPass()251 std::unique_ptr<OperationPass<func::FuncOp>> CreateLiftTfliteFlexOpsPass() {
252 return std::make_unique<LiftTfliteFlexOpsPass>();
253 }
254 } // namespace TFL
255 } // namespace mlir
256