xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
22 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
28 #include "mlir/IR/Location.h"  // from @llvm-project
29 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
30 #include "mlir/IR/Matchers.h"  // from @llvm-project
31 #include "mlir/IR/Operation.h"  // from @llvm-project
32 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
33 #include "mlir/Pass/Pass.h"  // from @llvm-project
34 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
35 #include "mlir/Support/LLVM.h"  // from @llvm-project
36 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
37 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
39 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
40 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
44 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
45 #include "tensorflow/core/framework/node_def.pb.h"
46 #include "tensorflow/core/framework/node_def_util.h"
47 #include "tensorflow/core/framework/op.h"
48 #include "tensorflow/core/framework/op_def.pb.h"
49 #include "tensorflow/core/platform/errors.h"
50 #include "tensorflow/core/platform/protobuf.h"
51 
52 // The pass lifts TFLite Flex custom ops into TF dialect operations.
53 // Note: this pass is experimental, so not guaranteed to work with all Flex ops.
54 
55 namespace mlir {
56 namespace TFL {
57 namespace {
58 #define GEN_PASS_CLASSES
59 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
60 
61 using ::tensorflow::StatusOr;
62 
// Prefix of a TFL::CustomOp `custom_code` that marks the op as a Flex op. The
// text following the prefix is used as the name of the wrapped TF op.
63 constexpr StringRef kFlexOpNamePrefix = "Flex";
64 
65 // Pattern that converts TFL::CustomOp that encodes a Flex op into a TF dialect
66 // operation.
67 class LiftFlexCustomOp : public OpRewritePattern<TFL::CustomOp> {
68  public:
69   using OpRewritePattern<TFL::CustomOp>::OpRewritePattern;
70 
matchAndRewrite(TFL::CustomOp op,PatternRewriter & rewriter) const71   LogicalResult matchAndRewrite(TFL::CustomOp op,
72                                 PatternRewriter& rewriter) const override {
73     if (!op.custom_code().startswith(kFlexOpNamePrefix)) {
74       return failure();
75     }
76 
77     llvm::StringRef tf_op_name =
78         op.custom_code().substr(kFlexOpNamePrefix.size());
79     const std::string tf_op_full_name = llvm::Twine("tf.", tf_op_name).str();
80 
81     // Create the TF op
82     OperationState op_state(op.getLoc(), tf_op_full_name);
83     op_state.addOperands(op.getOperands());
84     op_state.addTypes(op.getResultTypes());
85 
86     SmallVector<NamedAttribute, 2> attrs;
87     std::string parsed_op_name;
88     tensorflow::NodeDef node_def;
89     if (failed(ParseCustomOption(op.custom_option().getValue(), op.getLoc(),
90                                  parsed_op_name, attrs, node_def))) {
91       return failure();
92     }
93     if (parsed_op_name != tf_op_name) {
94       return op.emitOpError(
95           "TF op names in 'custom_code' and 'custom_option' don't match");
96     }
97     const tensorflow::OpDef* op_def;
98 
99     // This will fail only if the op is not a registered TensorFlow op.
100     if (!tensorflow::OpRegistry::Global()
101              ->LookUpOpDef(parsed_op_name, &op_def)
102              .ok()) {
103       op->emitError() << "can't find registered TF op for " << parsed_op_name
104                       << ". Please make sure the op library for "
105                       << parsed_op_name << " is linked properly";
106       return failure();
107     }
108     op_state.addAttributes(attrs);
109 
110     Operation* tf_op = rewriter.create(op_state);
111     rewriter.replaceOp(op, tf_op->getResults());
112 
113     // Special type fixes for TF Resource Tensors that are casted to
114     // Int32 tensor during MLIR->TFLite flatbuffer conversion.
115     // TODO(b/146131919): correct handling of resource type
116     if (auto tensor_array_v3_op = dyn_cast<TF::TensorArrayV3Op>(tf_op)) {
117       Value handle = tensor_array_v3_op.handle();
118       auto handle_type = handle.getType().cast<TensorType>();
119       if (handle_type.getElementType().isInteger(/*width=*/32)) {
120         Type resource_tensor_type =
121             handle_type.clone(TF::ResourceType::get(rewriter.getContext()));
122         handle.setType(resource_tensor_type);
123       }
124     }
125 
126     // Special type fixes for scalar tensor types.
127     // TFLite flatbuffer schema doesn't distinguish scalar tensor shapes
128     // and unranked tensor shapes (i.e. they are both represented as an empty
129     // INT32 list), see b/138865275. MLIR importer conservatively treats them as
130     // unranked tensor types. Here we set them to scalar tensor types when it is
131     // safe.
132     if (auto tensor_array_v3_op = dyn_cast<TF::TensorArrayV3Op>(tf_op)) {
133       // The "flow" in TensorArrayV3 is always a scalar float tensor.
134       // https://www.tensorflow.org/api_docs/python/tf/raw_ops/TensorArrayWriteV3
135       Value flow = tensor_array_v3_op.flow();
136       Type scalar_f32_tensor_type =
137           RankedTensorType::get(/*shape=*/{}, rewriter.getF32Type());
138       flow.setType(scalar_f32_tensor_type);
139     }
140 
141     // Sets operand_segment_sizes or result_segment_sizes attribute to the op.
142     // Referenced from tensorflow::ImporterBase::CreateOperation
143 
144     const auto set_segment_sizes_attr =
145         [&](const tensorflow::NameRangeMap& arg_ranges,
146             const tensorflow::protobuf::RepeatedPtrField<
147                 tensorflow::OpDef::ArgDef>& args,
148             llvm::StringRef attr_name) {
149           std::vector<int32_t> values;
150           values.reserve(args.size());
151           for (const auto& arg : args) {
152             auto range = arg_ranges.at(arg.name());
153             values.push_back(
154                 range.second - range.first);
155           }
156           auto attr_value = mlir::DenseI32ArrayAttr::get(tf_op->getContext(), values);
157           tf_op->setAttr(attr_name, attr_value);
158         };
159     if (tf_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>() ||
160         tf_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
161       // The op has multiple variadic operands or results.
162       // Calculate operand and result segment sizes using the OpDef.
163       tensorflow::NameRangeMap input_ranges, output_ranges;
164       // This will fail only if the OpDef is syntactically invalid.
165       if (!NameRangesForNode(node_def, *op_def, &input_ranges, &output_ranges)
166                .ok()) {
167         tf_op->emitError("malformed opdef");
168         return failure();
169       }
170       if (tf_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
171         // Add derived "operand_segment_sizes" attr to the created operation.
172         // TODO(b/146937733): Don't use <void> here.
173         set_segment_sizes_attr(input_ranges, op_def->input_arg(),
174                                mlir::OpTrait::AttrSizedOperandSegments<
175                                    void>::getOperandSegmentSizeAttr());
176       }
177 
178       if (tf_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
179         // Add derived "result_segment_sizes" attr to the created operation.
180         // TODO(b/146937733): Don't use <void> here.
181         set_segment_sizes_attr(output_ranges, op_def->output_arg(),
182                                mlir::OpTrait::AttrSizedResultSegments<
183                                    void>::getResultSegmentSizeAttr());
184       }
185     }
186 
187     return success();
188   }
189 
190  private:
191   // Parses TFLite Flex op's `custom_options` and returns the TF
192   // `op_name` and TF dialect MLIR op attributes.
ParseCustomOption(StringRef custom_options,Location loc,std::string & op_name,SmallVectorImpl<NamedAttribute> & attributes,tensorflow::NodeDef & node_def)193   static LogicalResult ParseCustomOption(
194       StringRef custom_options, Location loc, std::string& op_name,
195       SmallVectorImpl<NamedAttribute>& attributes,
196       tensorflow::NodeDef& node_def) {
197     // The flexbuffer contains a vector where the first elements is the
198     // op name and the second is a serialized NodeDef.
199     const flexbuffers::Vector& v =
200         flexbuffers::GetRoot(
201             reinterpret_cast<const uint8_t*>(custom_options.data()),
202             custom_options.size())
203             .AsVector();
204 
205     op_name = v[0].AsString().str();
206 
207     if (!node_def.ParseFromString(v[1].AsString().str())) {
208       return emitError(
209           loc, "failed to parse 'custom_options' data into a valid NodeDef");
210     }
211 
212     OpBuilder builder(loc.getContext());
213     for (const auto& name_and_value : node_def.attr()) {
214       const std::string& attr_name = name_and_value.first;
215       const tensorflow::AttrValue& attr_value = name_and_value.second;
216       StatusOr<Attribute> mlir_attr =
217           tensorflow::ConvertAttributeValue(attr_value, &builder);
218       if (!mlir_attr.ok()) {
219         return emitError(loc, mlir_attr.status().error_message());
220       }
221       attributes.push_back(builder.getNamedAttr(attr_name, *mlir_attr));
222     }
223     return success();
224   }
225 };
226 
227 class LiftTfliteFlexOpsPass
228     : public LiftTfliteFlexOpsPassBase<LiftTfliteFlexOpsPass> {
getDependentDialects(DialectRegistry & registry) const229   void getDependentDialects(DialectRegistry& registry) const override {
230     registry.insert<TF::TensorFlowDialect>();
231   }
232 
233  public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LiftTfliteFlexOpsPass)234   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LiftTfliteFlexOpsPass)
235 
236   void runOnOperation() override {
237     MLIRContext* context = &getContext();
238     func::FuncOp func = getOperation();
239 
240     mlir::RewritePatternSet patterns(context);
241     patterns.add<LiftFlexCustomOp>(context);
242     if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
243       signalPassFailure();
244       return;
245     }
246   }
247 };
248 
249 }  // namespace
250 
CreateLiftTfliteFlexOpsPass()251 std::unique_ptr<OperationPass<func::FuncOp>> CreateLiftTfliteFlexOpsPass() {
252   return std::make_unique<LiftTfliteFlexOpsPass>();
253 }
254 }  // namespace TFL
255 }  // namespace mlir
256