xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/utils/nms_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
17 
18 #include <string>
19 
20 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
21 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
22 
23 namespace mlir {
24 namespace TFL {
25 
26 namespace {
27 
28 // TODO(b/162842801): Consolidate all util definitions of kTFImplements.
29 constexpr char kTFImplements[] = "tf._implements";
30 constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
31 constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
32 
CustomOption(OpBuilder * builder,const std::string & content)33 inline ConstBytesAttr CustomOption(OpBuilder* builder,
34                                    const std::string& content) {
35   return ConstBytesAttr::get(builder->getContext(),
36                              StringRef(content.data(), content.size()));
37 }
38 
39 }  // namespace
40 
RewriteFunc()41 void ConvertNMSPaddedFunc::RewriteFunc() {
42   func_->setAttr(kTFImplements,
43                  StringAttr::get(func_.getContext(), kTfNMSPadded));
44   Value boxes = func_.getArgument(0);
45   Value scores = func_.getArgument(1);
46   Value max_output_size = func_.getArgument(2);
47   Value iou_threshold = func_.getArgument(3);
48   Value score_threshold = func_.getArgument(4);
49   auto output_type0 = func_.getFunctionType().getResult(0);
50   auto output_type1 = func_.getFunctionType().getResult(1);
51 
52   OpBuilder builder(func_.getBody());
53   auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
54       func_.getLoc(), output_type0, output_type1, boxes, scores,
55       max_output_size, iou_threshold, score_threshold);
56 
57   builder.create<mlir::func::ReturnOp>(func_.getLoc(), op.getResults());
58 }
59 
VerifySignature()60 LogicalResult ConvertNMSPaddedFunc::VerifySignature() {
61   // Verify high-level function signature.
62   // Relevant argument characteristics are checked by the TFL op definition.
63   if (func_.getNumArguments() < 5) {
64     return func_.emitWarning()
65            << "Invalid number of arguments to "
66               "non_max_suppression_padded_v2 (need at least 5): "
67            << func_.getNumArguments();
68   }
69   if (func_.getFunctionType().getNumResults() != 2) {
70     return func_.emitWarning() << "Invalid number of results from "
71                                   "non_max_suppression_padded_v2 (need 2): "
72                                << func_.getFunctionType().getNumResults();
73   }
74   // The TFLite fused op does not support batching yet.
75   // TODO(b/158709815): Add support for batches with padded NMS.
76   auto boxes_type =
77       func_.getFunctionType().getInput(0).dyn_cast<RankedTensorType>();
78   if (boxes_type == nullptr || !boxes_type.hasRank() ||
79       boxes_type.getRank() != 2) {
80     return func_.emitWarning() << "TFLite does not support batched input for "
81                                   "non_max_suppression_padded";
82   }
83   return success();
84 }
85 
RewriteFunc()86 LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() {
87   func_.eraseBody();
88   func_.addEntryBlock();
89   func_->setAttr(kTFImplements,
90                  StringAttr::get(func_.getContext(), kCustomSSDPostprocessing));
91 
92   OpBuilder builder(func_.getBody());
93   std::string custom_option_buffer;
94   if (failed(CreateNMSCustomOptions(func_, attr_.getAttrs(),
95                                     custom_option_buffer))) {
96     return failure();
97   }
98   auto op = builder.create<CustomOp>(
99       func_.getLoc(), func_.getFunctionType().getResults(),
100       func_.getArguments(), kCustomSSDPostprocessing,
101       CustomOption(&builder, custom_option_buffer));
102   builder.create<func::ReturnOp>(func_.getLoc(), op.getResults());
103 
104   return success();
105 }
106 
CreateNMSCustomOptions(func::FuncOp func,DictionaryAttr attrs,std::string & custom_option_buffer)107 LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions(
108     func::FuncOp func, DictionaryAttr attrs,
109     std::string& custom_option_buffer) {
110   flexbuffers::Builder fbb;
111   size_t start_map = fbb.StartMap();
112 
113   if (failed(AddIntAttr(func, attrs, "max_detections", &fbb)) ||
114       failed(AddIntAttr(func, attrs, "max_classes_per_detection", &fbb)) ||
115       failed(AddIntAttr(func, attrs, "num_classes", &fbb)) ||
116       failed(AddFloatAttr(func, attrs, "nms_score_threshold", &fbb)) ||
117       failed(AddFloatAttr(func, attrs, "nms_iou_threshold", &fbb)) ||
118       failed(AddFloatAttr(func, attrs, "y_scale", &fbb)) ||
119       failed(AddFloatAttr(func, attrs, "x_scale", &fbb)) ||
120       failed(AddFloatAttr(func, attrs, "h_scale", &fbb)) ||
121       failed(AddFloatAttr(func, attrs, "w_scale", &fbb)))
122     return failure();
123   auto use_regular_nms =
124       attrs.get("use_regular_nms").dyn_cast_or_null<BoolAttr>();
125   if (!use_regular_nms) {
126     return func.emitError()
127            << "use_regular_nms attribute is not set or not a bool";
128   }
129   fbb.Int("use_regular_nms", use_regular_nms.getValue());
130 
131   fbb.EndMap(start_map);
132   fbb.Finish();
133   custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
134   return success();
135 }
136 
AddIntAttr(func::FuncOp func,DictionaryAttr attrs,const std::string & attribute,flexbuffers::Builder * builder)137 LogicalResult ConvertSSDPostProcessFunc::AddIntAttr(
138     func::FuncOp func, DictionaryAttr attrs, const std::string& attribute,
139     flexbuffers::Builder* builder) {
140   auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
141   if (!int_attr) {
142     return func.emitError()
143            << attribute.c_str() << " attribute is not set or not an integer";
144   }
145   builder->Int(attribute.c_str(), int_attr.getInt());
146   return success();
147 }
148 
AddFloatAttr(func::FuncOp func,DictionaryAttr attrs,const std::string & attribute,flexbuffers::Builder * builder)149 LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr(
150     func::FuncOp func, DictionaryAttr attrs, const std::string& attribute,
151     flexbuffers::Builder* builder) {
152   auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
153   if (!float_attr) {
154     return func.emitError()
155            << attribute.c_str() << " attribute is not set or not a float";
156   }
157   builder->Float(attribute.c_str(), float_attr.getValue().convertToFloat());
158   return success();
159 }
160 
HasIntAttr(func::FuncOp func,DictionaryAttr attrs,const std::string & attribute)161 LogicalResult ConvertSSDPostProcessFunc::HasIntAttr(
162     func::FuncOp func, DictionaryAttr attrs, const std::string& attribute) {
163   auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
164   if (!int_attr) {
165     return func.emitWarning()
166            << attribute.c_str() << " attribute is not set or not an integer";
167   }
168   return success();
169 }
170 
HasFloatAttr(func::FuncOp func,DictionaryAttr attrs,const std::string & attribute)171 LogicalResult ConvertSSDPostProcessFunc::HasFloatAttr(
172     func::FuncOp func, DictionaryAttr attrs, const std::string& attribute) {
173   auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
174   if (!float_attr) {
175     return func.emitWarning()
176            << attribute.c_str() << " attribute is not set or not a float";
177   }
178   return success();
179 }
180 
VerifySignature()181 LogicalResult ConvertSSDPostProcessFunc::VerifySignature() {
182   // Verify high-level function signature.
183   if (func_.getNumArguments() != 3) {
184     return func_.emitWarning()
185            << "Invalid number of arguments to " << kCustomSSDPostprocessing
186            << ": " << func_.getNumArguments();
187   }
188   if (func_.getFunctionType().getNumResults() != 4) {
189     return func_.emitWarning()
190            << "Invalid number of results from " << kCustomSSDPostprocessing
191            << ": " << func_.getFunctionType().getNumResults();
192   }
193 
194   auto attrs = attr_.getAttrs();
195   if (failed(HasIntAttr(func_, attrs, "max_detections")) ||
196       failed(HasIntAttr(func_, attrs, "max_classes_per_detection")) ||
197       failed(HasIntAttr(func_, attrs, "num_classes")) ||
198       failed(HasFloatAttr(func_, attrs, "nms_score_threshold")) ||
199       failed(HasFloatAttr(func_, attrs, "nms_iou_threshold")) ||
200       failed(HasFloatAttr(func_, attrs, "y_scale")) ||
201       failed(HasFloatAttr(func_, attrs, "x_scale")) ||
202       failed(HasFloatAttr(func_, attrs, "h_scale")) ||
203       failed(HasFloatAttr(func_, attrs, "w_scale"))) {
204     return failure();
205   }
206   return success();
207 }
208 
209 }  // namespace TFL
210 }  // namespace mlir
211