1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
17
18 #include <string>
19
20 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
21 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
22
23 namespace mlir {
24 namespace TFL {
25
26 namespace {
27
28 // TODO(b/162842801): Consolidate all util definitions of kTFImplements.
29 constexpr char kTFImplements[] = "tf._implements";
30 constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
31 constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
32
CustomOption(OpBuilder * builder,const std::string & content)33 inline ConstBytesAttr CustomOption(OpBuilder* builder,
34 const std::string& content) {
35 return ConstBytesAttr::get(builder->getContext(),
36 StringRef(content.data(), content.size()));
37 }
38
39 } // namespace
40
RewriteFunc()41 void ConvertNMSPaddedFunc::RewriteFunc() {
42 func_->setAttr(kTFImplements,
43 StringAttr::get(func_.getContext(), kTfNMSPadded));
44 Value boxes = func_.getArgument(0);
45 Value scores = func_.getArgument(1);
46 Value max_output_size = func_.getArgument(2);
47 Value iou_threshold = func_.getArgument(3);
48 Value score_threshold = func_.getArgument(4);
49 auto output_type0 = func_.getFunctionType().getResult(0);
50 auto output_type1 = func_.getFunctionType().getResult(1);
51
52 OpBuilder builder(func_.getBody());
53 auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
54 func_.getLoc(), output_type0, output_type1, boxes, scores,
55 max_output_size, iou_threshold, score_threshold);
56
57 builder.create<mlir::func::ReturnOp>(func_.getLoc(), op.getResults());
58 }
59
VerifySignature()60 LogicalResult ConvertNMSPaddedFunc::VerifySignature() {
61 // Verify high-level function signature.
62 // Relevant argument characteristics are checked by the TFL op definition.
63 if (func_.getNumArguments() < 5) {
64 return func_.emitWarning()
65 << "Invalid number of arguments to "
66 "non_max_suppression_padded_v2 (need at least 5): "
67 << func_.getNumArguments();
68 }
69 if (func_.getFunctionType().getNumResults() != 2) {
70 return func_.emitWarning() << "Invalid number of results from "
71 "non_max_suppression_padded_v2 (need 2): "
72 << func_.getFunctionType().getNumResults();
73 }
74 // The TFLite fused op does not support batching yet.
75 // TODO(b/158709815): Add support for batches with padded NMS.
76 auto boxes_type =
77 func_.getFunctionType().getInput(0).dyn_cast<RankedTensorType>();
78 if (boxes_type == nullptr || !boxes_type.hasRank() ||
79 boxes_type.getRank() != 2) {
80 return func_.emitWarning() << "TFLite does not support batched input for "
81 "non_max_suppression_padded";
82 }
83 return success();
84 }
85
RewriteFunc()86 LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() {
87 func_.eraseBody();
88 func_.addEntryBlock();
89 func_->setAttr(kTFImplements,
90 StringAttr::get(func_.getContext(), kCustomSSDPostprocessing));
91
92 OpBuilder builder(func_.getBody());
93 std::string custom_option_buffer;
94 if (failed(CreateNMSCustomOptions(func_, attr_.getAttrs(),
95 custom_option_buffer))) {
96 return failure();
97 }
98 auto op = builder.create<CustomOp>(
99 func_.getLoc(), func_.getFunctionType().getResults(),
100 func_.getArguments(), kCustomSSDPostprocessing,
101 CustomOption(&builder, custom_option_buffer));
102 builder.create<func::ReturnOp>(func_.getLoc(), op.getResults());
103
104 return success();
105 }
106
CreateNMSCustomOptions(func::FuncOp func,DictionaryAttr attrs,std::string & custom_option_buffer)107 LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions(
108 func::FuncOp func, DictionaryAttr attrs,
109 std::string& custom_option_buffer) {
110 flexbuffers::Builder fbb;
111 size_t start_map = fbb.StartMap();
112
113 if (failed(AddIntAttr(func, attrs, "max_detections", &fbb)) ||
114 failed(AddIntAttr(func, attrs, "max_classes_per_detection", &fbb)) ||
115 failed(AddIntAttr(func, attrs, "num_classes", &fbb)) ||
116 failed(AddFloatAttr(func, attrs, "nms_score_threshold", &fbb)) ||
117 failed(AddFloatAttr(func, attrs, "nms_iou_threshold", &fbb)) ||
118 failed(AddFloatAttr(func, attrs, "y_scale", &fbb)) ||
119 failed(AddFloatAttr(func, attrs, "x_scale", &fbb)) ||
120 failed(AddFloatAttr(func, attrs, "h_scale", &fbb)) ||
121 failed(AddFloatAttr(func, attrs, "w_scale", &fbb)))
122 return failure();
123 auto use_regular_nms =
124 attrs.get("use_regular_nms").dyn_cast_or_null<BoolAttr>();
125 if (!use_regular_nms) {
126 return func.emitError()
127 << "use_regular_nms attribute is not set or not a bool";
128 }
129 fbb.Int("use_regular_nms", use_regular_nms.getValue());
130
131 fbb.EndMap(start_map);
132 fbb.Finish();
133 custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
134 return success();
135 }
136
AddIntAttr(func::FuncOp func,DictionaryAttr attrs,const std::string & attribute,flexbuffers::Builder * builder)137 LogicalResult ConvertSSDPostProcessFunc::AddIntAttr(
138 func::FuncOp func, DictionaryAttr attrs, const std::string& attribute,
139 flexbuffers::Builder* builder) {
140 auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
141 if (!int_attr) {
142 return func.emitError()
143 << attribute.c_str() << " attribute is not set or not an integer";
144 }
145 builder->Int(attribute.c_str(), int_attr.getInt());
146 return success();
147 }
148
AddFloatAttr(func::FuncOp func,DictionaryAttr attrs,const std::string & attribute,flexbuffers::Builder * builder)149 LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr(
150 func::FuncOp func, DictionaryAttr attrs, const std::string& attribute,
151 flexbuffers::Builder* builder) {
152 auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
153 if (!float_attr) {
154 return func.emitError()
155 << attribute.c_str() << " attribute is not set or not a float";
156 }
157 builder->Float(attribute.c_str(), float_attr.getValue().convertToFloat());
158 return success();
159 }
160
HasIntAttr(func::FuncOp func,DictionaryAttr attrs,const std::string & attribute)161 LogicalResult ConvertSSDPostProcessFunc::HasIntAttr(
162 func::FuncOp func, DictionaryAttr attrs, const std::string& attribute) {
163 auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
164 if (!int_attr) {
165 return func.emitWarning()
166 << attribute.c_str() << " attribute is not set or not an integer";
167 }
168 return success();
169 }
170
HasFloatAttr(func::FuncOp func,DictionaryAttr attrs,const std::string & attribute)171 LogicalResult ConvertSSDPostProcessFunc::HasFloatAttr(
172 func::FuncOp func, DictionaryAttr attrs, const std::string& attribute) {
173 auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
174 if (!float_attr) {
175 return func.emitWarning()
176 << attribute.c_str() << " attribute is not set or not a float";
177 }
178 return success();
179 }
180
VerifySignature()181 LogicalResult ConvertSSDPostProcessFunc::VerifySignature() {
182 // Verify high-level function signature.
183 if (func_.getNumArguments() != 3) {
184 return func_.emitWarning()
185 << "Invalid number of arguments to " << kCustomSSDPostprocessing
186 << ": " << func_.getNumArguments();
187 }
188 if (func_.getFunctionType().getNumResults() != 4) {
189 return func_.emitWarning()
190 << "Invalid number of results from " << kCustomSSDPostprocessing
191 << ": " << func_.getFunctionType().getNumResults();
192 }
193
194 auto attrs = attr_.getAttrs();
195 if (failed(HasIntAttr(func_, attrs, "max_detections")) ||
196 failed(HasIntAttr(func_, attrs, "max_classes_per_detection")) ||
197 failed(HasIntAttr(func_, attrs, "num_classes")) ||
198 failed(HasFloatAttr(func_, attrs, "nms_score_threshold")) ||
199 failed(HasFloatAttr(func_, attrs, "nms_iou_threshold")) ||
200 failed(HasFloatAttr(func_, attrs, "y_scale")) ||
201 failed(HasFloatAttr(func_, attrs, "x_scale")) ||
202 failed(HasFloatAttr(func_, attrs, "h_scale")) ||
203 failed(HasFloatAttr(func_, attrs, "w_scale"))) {
204 return failure();
205 }
206 return success();
207 }
208
209 } // namespace TFL
210 } // namespace mlir
211