1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow_lite_support/cc/task/vision/object_detector.h"
17
18 #include <algorithm>
19 #include <limits>
20 #include <vector>
21
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/lite/c/common.h"
26 #include "tensorflow/lite/interpreter.h"
27 #include "tensorflow_lite_support/cc/common.h"
28 #include "tensorflow_lite_support/cc/port/status_macros.h"
29 #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
30 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
31 #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
32 #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
33 #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
34 #include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
35 #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
36 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
37 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
38
39 namespace tflite {
40 namespace task {
41 namespace vision {
42
43 namespace {
44
45 using ::absl::StatusCode;
46 using ::tflite::BoundingBoxProperties;
47 using ::tflite::ContentProperties;
48 using ::tflite::ContentProperties_BoundingBoxProperties;
49 using ::tflite::EnumNameContentProperties;
50 using ::tflite::ProcessUnit;
51 using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions;
52 using ::tflite::TensorMetadata;
53 using ::tflite::metadata::ModelMetadataExtractor;
54 using ::tflite::support::CreateStatusWithPayload;
55 using ::tflite::support::StatusOr;
56 using ::tflite::support::TfLiteSupportStatus;
57 using ::tflite::task::core::AssertAndReturnTypedTensor;
58 using ::tflite::task::core::TaskAPIFactory;
59 using ::tflite::task::core::TfLiteEngine;
60
// The expected number of dimensions of the 4 output tensors, representing in
// that order: locations, classes, scores, num_results. Used by
// CheckAndSetOutputs() to validate the model at construction time.
static constexpr int kOutputTensorsExpectedDims[4] = {3, 2, 2, 1};
64
GetBoundingBoxProperties(const TensorMetadata & tensor_metadata)65 StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
66 const TensorMetadata& tensor_metadata) {
67 if (tensor_metadata.content() == nullptr ||
68 tensor_metadata.content()->content_properties() == nullptr) {
69 return CreateStatusWithPayload(
70 StatusCode::kInvalidArgument,
71 absl::StrFormat(
72 "Expected BoundingBoxProperties for tensor %s, found none.",
73 tensor_metadata.name() ? tensor_metadata.name()->str() : "#0"),
74 TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
75 }
76
77 ContentProperties type = tensor_metadata.content()->content_properties_type();
78 if (type != ContentProperties_BoundingBoxProperties) {
79 return CreateStatusWithPayload(
80 StatusCode::kInvalidArgument,
81 absl::StrFormat(
82 "Expected BoundingBoxProperties for tensor %s, found %s.",
83 tensor_metadata.name() ? tensor_metadata.name()->str() : "#0",
84 EnumNameContentProperties(type)),
85 TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
86 }
87
88 const BoundingBoxProperties* properties =
89 tensor_metadata.content()->content_properties_as_BoundingBoxProperties();
90
91 // Mobile SSD only supports "BOUNDARIES" bounding box type.
92 if (properties->type() != tflite::BoundingBoxType_BOUNDARIES) {
93 return CreateStatusWithPayload(
94 StatusCode::kInvalidArgument,
95 absl::StrFormat(
96 "Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s",
97 tflite::EnumNameBoundingBoxType(properties->type())),
98 TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
99 }
100
101 // Mobile SSD only supports "RATIO" coordinates type.
102 if (properties->coordinate_type() != tflite::CoordinateType_RATIO) {
103 return CreateStatusWithPayload(
104 StatusCode::kInvalidArgument,
105 absl::StrFormat(
106 "Mobile SSD only supports CoordinateType RATIO, found %s",
107 tflite::EnumNameCoordinateType(properties->coordinate_type())),
108 TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
109 }
110
111 // Index is optional, but must contain 4 values if present.
112 if (properties->index() != nullptr && properties->index()->size() != 4) {
113 return CreateStatusWithPayload(
114 StatusCode::kInvalidArgument,
115 absl::StrFormat(
116 "Expected BoundingBoxProperties index to contain 4 values, found "
117 "%d",
118 properties->index()->size()),
119 TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
120 }
121
122 return properties;
123 }
124
GetLabelMapIfAny(const ModelMetadataExtractor & metadata_extractor,const TensorMetadata & tensor_metadata,absl::string_view locale)125 StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
126 const ModelMetadataExtractor& metadata_extractor,
127 const TensorMetadata& tensor_metadata, absl::string_view locale) {
128 const std::string labels_filename =
129 ModelMetadataExtractor::FindFirstAssociatedFileName(
130 tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS);
131 if (labels_filename.empty()) {
132 return std::vector<LabelMapItem>();
133 }
134 ASSIGN_OR_RETURN(absl::string_view labels_file,
135 metadata_extractor.GetAssociatedFile(labels_filename));
136 const std::string display_names_filename =
137 ModelMetadataExtractor::FindFirstAssociatedFileName(
138 tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS,
139 locale);
140 absl::string_view display_names_file = nullptr;
141 if (!display_names_filename.empty()) {
142 ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile(
143 display_names_filename));
144 }
145 return BuildLabelMapFromFiles(labels_file, display_names_file);
146 }
147
GetScoreThreshold(const ModelMetadataExtractor & metadata_extractor,const TensorMetadata & tensor_metadata)148 StatusOr<float> GetScoreThreshold(
149 const ModelMetadataExtractor& metadata_extractor,
150 const TensorMetadata& tensor_metadata) {
151 ASSIGN_OR_RETURN(
152 const ProcessUnit* score_thresholding_process_unit,
153 metadata_extractor.FindFirstProcessUnit(
154 tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions));
155 if (score_thresholding_process_unit == nullptr) {
156 return std::numeric_limits<float>::lowest();
157 }
158 return score_thresholding_process_unit->options_as_ScoreThresholdingOptions()
159 ->global_score_threshold();
160 }
161
// Post-inference sanity checks on the 4 Mobile SSD output tensors, expected
// in this order: locations, classes, scores, num_results. Verifies the tensor
// count and that the dimensions of the first three tensors are consistent
// with the number of results reported by the fourth. Returns kInternal on
// mismatch, since outputs were already validated at construction time and a
// failure here indicates an unexpected runtime state.
absl::Status SanityCheckOutputTensors(
    const std::vector<const TfLiteTensor*>& output_tensors) {
  if (output_tensors.size() != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat("Expected 4 output tensors, found %d",
                        output_tensors.size()));
  }

  // Get number of results.
  if (output_tensors[3]->dims->data[0] != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat(
            "Expected tensor with dimensions [1] at index 3, found [%d]",
            output_tensors[3]->dims->data[0]));
  }
  // num_results is stored as a float scalar by the model; truncate to int.
  int num_results =
      static_cast<int>(AssertAndReturnTypedTensor<float>(output_tensors[3])[0]);

  // Check dimensions for the other tensors are correct.
  // Locations: [1, num_results, 4] (4 box coordinates per detection).
  if (output_tensors[0]->dims->data[0] != 1 ||
      output_tensors[0]->dims->data[1] != num_results ||
      output_tensors[0]->dims->data[2] != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat(
            "Expected locations tensor with dimensions [1,%d,4] at index 0, "
            "found [%d,%d,%d].",
            num_results, output_tensors[0]->dims->data[0],
            output_tensors[0]->dims->data[1],
            output_tensors[0]->dims->data[2]));
  }
  // Classes: [1, num_results].
  if (output_tensors[1]->dims->data[0] != 1 ||
      output_tensors[1]->dims->data[1] != num_results) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat(
            "Expected classes tensor with dimensions [1,%d] at index 1, "
            "found [%d,%d].",
            num_results, output_tensors[1]->dims->data[0],
            output_tensors[1]->dims->data[1]));
  }
  // Scores: [1, num_results].
  if (output_tensors[2]->dims->data[0] != 1 ||
      output_tensors[2]->dims->data[1] != num_results) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat(
            "Expected scores tensor with dimensions [1,%d] at index 2, "
            "found [%d,%d].",
            num_results, output_tensors[2]->dims->data[0],
            output_tensors[2]->dims->data[1]));
  }

  return absl::OkStatus();
}
218
219 } // namespace
220
221 /* static */
SanityCheckOptions(const ObjectDetectorOptions & options)222 absl::Status ObjectDetector::SanityCheckOptions(
223 const ObjectDetectorOptions& options) {
224 if (!options.has_model_file_with_metadata()) {
225 return CreateStatusWithPayload(
226 StatusCode::kInvalidArgument,
227 "Missing mandatory `model_file_with_metadata` field",
228 TfLiteSupportStatus::kInvalidArgumentError);
229 }
230 if (options.max_results() == 0) {
231 return CreateStatusWithPayload(
232 StatusCode::kInvalidArgument,
233 "Invalid `max_results` option: value must be != 0",
234 TfLiteSupportStatus::kInvalidArgumentError);
235 }
236 if (options.class_name_whitelist_size() > 0 &&
237 options.class_name_blacklist_size() > 0) {
238 return CreateStatusWithPayload(
239 StatusCode::kInvalidArgument,
240 "`class_name_whitelist` and `class_name_blacklist` are mutually "
241 "exclusive options.",
242 TfLiteSupportStatus::kInvalidArgumentError);
243 }
244 if (options.num_threads() == 0 || options.num_threads() < -1) {
245 return CreateStatusWithPayload(
246 StatusCode::kInvalidArgument,
247 "`num_threads` must be greater than 0 or equal to -1.",
248 TfLiteSupportStatus::kInvalidArgumentError);
249 }
250 return absl::OkStatus();
251 }
252
253 /* static */
CreateFromOptions(const ObjectDetectorOptions & options,std::unique_ptr<tflite::OpResolver> resolver)254 StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::CreateFromOptions(
255 const ObjectDetectorOptions& options,
256 std::unique_ptr<tflite::OpResolver> resolver) {
257 RETURN_IF_ERROR(SanityCheckOptions(options));
258
259 // Copy options to ensure the ExternalFile outlives the constructed object.
260 auto options_copy = absl::make_unique<ObjectDetectorOptions>(options);
261
262 ASSIGN_OR_RETURN(auto object_detector,
263 TaskAPIFactory::CreateFromExternalFileProto<ObjectDetector>(
264 &options_copy->model_file_with_metadata(),
265 std::move(resolver), options_copy->num_threads()));
266
267 RETURN_IF_ERROR(object_detector->Init(std::move(options_copy)));
268
269 return object_detector;
270 }
271
// Post-construction initialization: takes ownership of the (already
// sanity-checked) options, then validates the model's inputs/outputs against
// its metadata and resolves the optional class name whitelist/blacklist.
// Each step returns early on failure via RETURN_IF_ERROR.
absl::Status ObjectDetector::Init(
    std::unique_ptr<ObjectDetectorOptions> options) {
  // Set options.
  options_ = std::move(options);

  // Perform pre-initialization actions (by default, sets the process engine for
  // image pre-processing to kLibyuv as a sane default).
  RETURN_IF_ERROR(PreInit());

  // Sanity check and set inputs and outputs.
  RETURN_IF_ERROR(CheckAndSetInputs());
  RETURN_IF_ERROR(CheckAndSetOutputs());

  // Initialize class whitelisting/blacklisting, if any. Must run after
  // CheckAndSetOutputs(), which populates label_map_.
  RETURN_IF_ERROR(CheckAndSetClassIndexSet());

  return absl::OkStatus();
}
290
// Pre-initialization hook: selects libyuv as the engine used for image
// pre-processing (e.g. conversions performed by FrameBufferUtils).
absl::Status ObjectDetector::PreInit() {
  SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
  return absl::OkStatus();
}
295
// Validates the model's 4 output tensors (count, rank, batch size) and its
// TFLite Model Metadata, then caches post-processing parameters: the bounding
// box corner ordering, the label map (if any) and the score threshold. Must
// run at Init() time, before any call to Postprocess().
absl::Status ObjectDetector::CheckAndSetOutputs() {
  // First, sanity checks on the model itself.
  const TfLiteEngine::Interpreter* interpreter = engine_->interpreter();
  // Check the number of output tensors.
  if (TfLiteEngine::OutputCount(interpreter) != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Mobile SSD models are expected to have exactly 4 "
                        "outputs, found %d",
                        TfLiteEngine::OutputCount(interpreter)),
        TfLiteSupportStatus::kInvalidNumOutputTensorsError);
  }
  // Check tensor dimensions (against kOutputTensorsExpectedDims) and batch
  // size (must be 1).
  for (int i = 0; i < 4; ++i) {
    const TfLiteTensor* tensor = TfLiteEngine::GetOutput(interpreter, i);
    if (tensor->dims->size != kOutputTensorsExpectedDims[i]) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Output tensor at index %d is expected to "
                          "have %d dimensions, found %d.",
                          i, kOutputTensorsExpectedDims[i], tensor->dims->size),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
    if (tensor->dims->data[0] != 1) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Expected batch size of 1, found %d.",
                          tensor->dims->data[0]),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
  }

  // Now, perform sanity checks and extract metadata.
  const ModelMetadataExtractor* metadata_extractor =
      engine_->metadata_extractor();
  // Check that metadata is available.
  if (metadata_extractor->GetModelMetadata() == nullptr ||
      metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
    return CreateStatusWithPayload(StatusCode::kInvalidArgument,
                                   "Object detection models require TFLite "
                                   "Model Metadata but none was found",
                                   TfLiteSupportStatus::kMetadataNotFoundError);
  }
  // Check output tensor metadata is present and consistent with model.
  auto output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata();
  if (output_tensors_metadata == nullptr ||
      output_tensors_metadata->size() != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Mismatch between number of output tensors (4) and output tensors "
            "metadata (%d).",
            output_tensors_metadata == nullptr
                ? 0
                : output_tensors_metadata->size()),
        TfLiteSupportStatus::kMetadataInconsistencyError);
  }

  // Extract mandatory BoundingBoxProperties for easier access at
  // post-processing time, performing sanity checks on the fly.
  ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties,
                   GetBoundingBoxProperties(*output_tensors_metadata->Get(0)));
  if (bounding_box_properties->index() == nullptr) {
    // No explicit ordering in metadata: assume the default corner order.
    bounding_box_corners_order_ = {0, 1, 2, 3};
  } else {
    auto bounding_box_index = bounding_box_properties->index();
    bounding_box_corners_order_ = {
        bounding_box_index->Get(0),
        bounding_box_index->Get(1),
        bounding_box_index->Get(2),
        bounding_box_index->Get(3),
    };
  }

  // Build label map (if available) from metadata attached to the classes
  // tensor (index 1).
  ASSIGN_OR_RETURN(
      label_map_,
      GetLabelMapIfAny(*metadata_extractor, *output_tensors_metadata->Get(1),
                       options_->display_names_locale()));

  // Set score threshold: an explicit user-provided option takes precedence
  // over the one declared in the scores tensor (index 2) metadata.
  if (options_->has_score_threshold()) {
    score_threshold_ = options_->score_threshold();
  } else {
    ASSIGN_OR_RETURN(score_threshold_,
                     GetScoreThreshold(*metadata_extractor,
                                      *output_tensors_metadata->Get(2)));
  }

  return absl::OkStatus();
}
387
CheckAndSetClassIndexSet()388 absl::Status ObjectDetector::CheckAndSetClassIndexSet() {
389 // Exit early if no blacklist/whitelist.
390 if (options_->class_name_blacklist_size() == 0 &&
391 options_->class_name_whitelist_size() == 0) {
392 return absl::OkStatus();
393 }
394 // Label map is mandatory.
395 if (label_map_.empty()) {
396 return CreateStatusWithPayload(
397 StatusCode::kInvalidArgument,
398 "Using `class_name_whitelist` or `class_name_blacklist` requires "
399 "labels to be present in the TFLite Model Metadata but none was found.",
400 TfLiteSupportStatus::kMetadataMissingLabelsError);
401 }
402
403 class_index_set_.is_whitelist = options_->class_name_whitelist_size() > 0;
404 const auto& class_names = class_index_set_.is_whitelist
405 ? options_->class_name_whitelist()
406 : options_->class_name_blacklist();
407 class_index_set_.values.clear();
408 for (const auto& class_name : class_names) {
409 int index = -1;
410 for (int i = 0; i < label_map_.size(); ++i) {
411 if (label_map_[i].name == class_name) {
412 index = i;
413 break;
414 }
415 }
416 // Ignore duplicate or unknown classes.
417 if (index < 0 || class_index_set_.values.contains(index)) {
418 continue;
419 }
420 class_index_set_.values.insert(index);
421 }
422
423 if (class_index_set_.values.empty()) {
424 return CreateStatusWithPayload(
425 StatusCode::kInvalidArgument,
426 absl::StrFormat(
427 "Invalid class names specified via `class_name_%s`: none match "
428 "with model labels.",
429 class_index_set_.is_whitelist ? "whitelist" : "blacklist"),
430 TfLiteSupportStatus::kInvalidArgumentError);
431 }
432
433 return absl::OkStatus();
434 }
435
Detect(const FrameBuffer & frame_buffer)436 StatusOr<DetectionResult> ObjectDetector::Detect(
437 const FrameBuffer& frame_buffer) {
438 BoundingBox roi;
439 roi.set_width(frame_buffer.dimension().width);
440 roi.set_height(frame_buffer.dimension().height);
441 // Rely on `Infer` instead of `InferWithFallback` as DetectionPostprocessing
442 // op doesn't support hardware acceleration at the time.
443 return Infer(frame_buffer, roi);
444 }
445
// Converts the model's 4 raw output tensors into a DetectionResult proto:
// filters detections by class allow/deny list and score threshold, caps the
// number of returned detections at options_->max_results() (when > 0),
// denormalizes and re-orients bounding boxes, and fills in class/display
// names from the label map when available. The ROI parameter is unused since
// Detect() always passes the full frame.
StatusOr<DetectionResult> ObjectDetector::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) {
  // Most of the checks here should never happen, as outputs have been validated
  // at construction time. Checking nonetheless and returning internal errors if
  // something bad happens.
  RETURN_IF_ERROR(SanityCheckOutputTensors(output_tensors));

  // Get number of available results (stored as a float scalar at index 3).
  const int num_results =
      static_cast<int>(AssertAndReturnTypedTensor<float>(output_tensors[3])[0]);
  // Compute number of max results to return; a non-positive max_results
  // option means "return everything".
  const int max_results = options_->max_results() > 0
                              ? std::min(options_->max_results(), num_results)
                              : num_results;
  // The dimensions of the upright (i.e. rotated according to its orientation)
  // input frame.
  FrameBuffer::Dimension upright_input_frame_dimensions =
      frame_buffer.dimension();
  if (RequireDimensionSwap(frame_buffer.orientation(),
                           FrameBuffer::Orientation::kTopLeft)) {
    upright_input_frame_dimensions.Swap();
  }

  const float* locations = AssertAndReturnTypedTensor<float>(output_tensors[0]);
  const float* classes = AssertAndReturnTypedTensor<float>(output_tensors[1]);
  const float* scores = AssertAndReturnTypedTensor<float>(output_tensors[2]);
  DetectionResult results;
  for (int i = 0; i < num_results; ++i) {
    const int class_index = static_cast<int>(classes[i]);
    const float score = scores[i];
    // Skip detections filtered out by the class index set or below the score
    // threshold; these do not count towards max_results.
    if (!IsClassIndexAllowed(class_index) || score < score_threshold_) {
      continue;
    }
    Detection* detection = results.add_detections();
    // Denormalize the bounding box coordinates in the upright frame
    // coordinates system, then rotate back from frame_buffer.orientation() to
    // the unrotated frame of reference coordinates system (i.e. with
    // orientation = kTopLeft). Each box occupies 4 consecutive floats, indexed
    // through bounding_box_corners_order_ (set from metadata at Init time).
    *detection->mutable_bounding_box() = OrientAndDenormalizeBoundingBox(
        /*from_left=*/locations[4 * i + bounding_box_corners_order_[0]],
        /*from_top=*/locations[4 * i + bounding_box_corners_order_[1]],
        /*from_right=*/locations[4 * i + bounding_box_corners_order_[2]],
        /*from_bottom=*/locations[4 * i + bounding_box_corners_order_[3]],
        /*from_orientation=*/frame_buffer.orientation(),
        /*to_orientation=*/FrameBuffer::Orientation::kTopLeft,
        /*from_dimension=*/upright_input_frame_dimensions);
    Class* detection_class = detection->add_classes();
    detection_class->set_index(class_index);
    detection_class->set_score(score);
    if (results.detections_size() == max_results) {
      break;
    }
  }

  // Attach class/display names when a label map was extracted from metadata.
  if (!label_map_.empty()) {
    RETURN_IF_ERROR(FillResultsFromLabelMap(&results));
  }

  return results;
}
507
IsClassIndexAllowed(int class_index)508 bool ObjectDetector::IsClassIndexAllowed(int class_index) {
509 if (class_index_set_.values.empty()) {
510 return true;
511 }
512 if (class_index_set_.is_whitelist) {
513 return class_index_set_.values.contains(class_index);
514 } else {
515 return !class_index_set_.values.contains(class_index);
516 }
517 }
518
FillResultsFromLabelMap(DetectionResult * result)519 absl::Status ObjectDetector::FillResultsFromLabelMap(DetectionResult* result) {
520 for (int i = 0; i < result->detections_size(); ++i) {
521 Detection* detection = result->mutable_detections(i);
522 for (int j = 0; j < detection->classes_size(); ++j) {
523 Class* detection_class = detection->mutable_classes(j);
524 const int index = detection_class->index();
525 if (index >= label_map_.size()) {
526 return CreateStatusWithPayload(
527 StatusCode::kInvalidArgument,
528 absl::StrFormat(
529 "Label map does not contain enough elements: model returned "
530 "class index %d but label map only contains %d elements.",
531 index, label_map_.size()),
532 TfLiteSupportStatus::kMetadataInconsistencyError);
533 }
534 std::string name = label_map_[index].name;
535 if (!name.empty()) {
536 detection_class->set_class_name(name);
537 }
538 std::string display_name = label_map_[index].display_name;
539 if (!display_name.empty()) {
540 detection_class->set_display_name(display_name);
541 }
542 }
543 }
544 return absl::OkStatus();
545 }
546
547 } // namespace vision
548 } // namespace task
549 } // namespace tflite
550