/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"

#include "absl/algorithm/container.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/interpreter.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
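
// Example usage (a minimal sketch for illustration only; it assumes a model
// with TFLite Metadata at the hypothetical path "model.tflite" and a
// `FrameBuffer` named `frame_buffer` built elsewhere):
//
//   ImageClassifierOptions options;
//   options.mutable_model_file_with_metadata()->set_file_name("model.tflite");
//   options.set_max_results(3);
//   ASSIGN_OR_RETURN(std::unique_ptr<ImageClassifier> classifier,
//                    ImageClassifier::CreateFromOptions(options));
//   ASSIGN_OR_RETURN(ClassificationResult result,
//                    classifier->Classify(frame_buffer));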

namespace tflite {
namespace task {
namespace vision {

namespace {

using ::absl::StatusCode;
using ::tflite::metadata::ModelMetadataExtractor;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::task::core::AssertAndReturnTypedTensor;
using ::tflite::task::core::TaskAPIFactory;
using ::tflite::task::core::TfLiteEngine;

// Default score value used as a fallback for classes that (1) have no score
// calibration data or (2) have a very low uncalibrated score, i.e. one lower
// than the `min_uncalibrated_score` threshold.
//
// (1) This happens when the ScoreCalibration does not cover all the classes
// listed in the label map. This can be used to enforce the blacklisting of
// given classes so that they are never returned.
//
// (2) `min_uncalibrated_score` is an optional threshold provided as part of
// the calibration data. It is used to mitigate false alarms on some classes.
//
// In both cases, a class that gets assigned a score of -1 is never returned as
// it gets discarded by the `score_threshold` check (see post-processing logic).
constexpr float kDefaultCalibratedScore = -1.0f;

// Calibrated scores should be in the [0, 1] range, otherwise an error is
// returned at post-processing time.
constexpr float kMinCalibratedScore = 0.0f;
constexpr float kMaxCalibratedScore = 1.0f;

}  // namespace

/* static */
StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::CreateFromOptions(
    const ImageClassifierOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  RETURN_IF_ERROR(SanityCheckOptions(options));

  // Copy options to ensure the ExternalFile outlives the constructed object.
  auto options_copy = absl::make_unique<ImageClassifierOptions>(options);

  ASSIGN_OR_RETURN(auto image_classifier,
                   TaskAPIFactory::CreateFromExternalFileProto<ImageClassifier>(
                       &options_copy->model_file_with_metadata(),
                       std::move(resolver), options_copy->num_threads()));

  RETURN_IF_ERROR(image_classifier->Init(std::move(options_copy)));

  return image_classifier;
}

/* static */
absl::Status ImageClassifier::SanityCheckOptions(
    const ImageClassifierOptions& options) {
  if (!options.has_model_file_with_metadata()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Missing mandatory `model_file_with_metadata` field",
        TfLiteSupportStatus::kInvalidArgumentError);
  }
  if (options.max_results() == 0) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Invalid `max_results` option: value must be != 0",
        TfLiteSupportStatus::kInvalidArgumentError);
  }
  if (options.score_threshold() < 0 || options.score_threshold() >= 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat(
            "`score_threshold` out of range: %f. Valid range is [0,1[.",
            options.score_threshold()),
        TfLiteSupportStatus::kInvalidArgumentError);
  }
  if (options.class_name_whitelist_size() > 0 &&
      options.class_name_blacklist_size() > 0) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "`class_name_whitelist` and `class_name_blacklist` are mutually "
        "exclusive options.",
        TfLiteSupportStatus::kInvalidArgumentError);
  }
  if (options.num_threads() == 0 || options.num_threads() < -1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "`num_threads` must be greater than 0 or equal to -1.",
        TfLiteSupportStatus::kInvalidArgumentError);
  }
  return absl::OkStatus();
}

absl::Status ImageClassifier::Init(
    std::unique_ptr<ImageClassifierOptions> options) {
  // Set options.
  options_ = std::move(options);

  // Perform pre-initialization actions (by default, sets the process engine
  // for image pre-processing to kLibyuv).
  RETURN_IF_ERROR(PreInit());

  // Sanity check and set inputs and outputs.
  RETURN_IF_ERROR(CheckAndSetInputs());
  RETURN_IF_ERROR(CheckAndSetOutputs());

  // Initialize class whitelisting/blacklisting, if any.
  RETURN_IF_ERROR(CheckAndSetClassNameSet());

  // Perform final initialization (by default, initialize score calibration
  // parameters, if any).
  RETURN_IF_ERROR(PostInit());

  return absl::OkStatus();
}

absl::Status ImageClassifier::PreInit() {
  SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
  return absl::OkStatus();
}

absl::Status ImageClassifier::PostInit() { return InitScoreCalibrations(); }

absl::Status ImageClassifier::CheckAndSetOutputs() {
  num_outputs_ = TfLiteEngine::OutputCount(engine_->interpreter());

  // Perform sanity checks and extract metadata.
  const ModelMetadataExtractor* metadata_extractor =
      engine_->metadata_extractor();

  const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
      output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata();

  // Loop over output tensors metadata, if any.
  // Note: models with no output tensor metadata at all are supported.
  if (output_tensor_metadata != nullptr) {
    int num_output_tensors = output_tensor_metadata->size();

    if (num_outputs_ != num_output_tensors) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Mismatch between number of output tensors (%d) and "
                          "output tensors metadata (%d).",
                          num_outputs_, num_output_tensors),
          TfLiteSupportStatus::kMetadataInconsistencyError);
    }

    for (int i = 0; i < num_output_tensors; ++i) {
      const tflite::TensorMetadata* output_tensor =
          output_tensor_metadata->Get(i);

      ASSIGN_OR_RETURN(
          ClassificationHead head,
          BuildClassificationHead(*metadata_extractor, *output_tensor,
                                  options_->display_names_locale()));

      classification_heads_.emplace_back(std::move(head));
    }
  }

  // If classifier heads are not set, build default ones based on model
  // introspection. This happens if a model with partial or no metadata was
  // provided through the `model_file_with_metadata` options field.
  if (classification_heads_.empty()) {
    classification_heads_.reserve(num_outputs_);
    for (int output_index = 0; output_index < num_outputs_; ++output_index) {
      classification_heads_.emplace_back(ClassificationHead{});
    }
  }

  if (num_outputs_ != classification_heads_.size()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Got %d classifier head(s), expected %d according to "
                        "the label map.",
                        num_outputs_, classification_heads_.size()),
        TfLiteSupportStatus::kMetadataInconsistencyError);
  }

  int num_quantized_outputs = 0;
  for (int i = 0; i < num_outputs_; ++i) {
    const TfLiteTensor* output_tensor =
        TfLiteEngine::GetOutput(engine_->interpreter(), i);
    const int num_dimensions = output_tensor->dims->size;
    if (num_dimensions == 4) {
      if (output_tensor->dims->data[1] != 1 ||
          output_tensor->dims->data[2] != 1) {
        return CreateStatusWithPayload(
            StatusCode::kInvalidArgument,
            absl::StrFormat("Unexpected WxH sizes for output index %d: got "
                            "%dx%d, expected 1x1.",
                            i, output_tensor->dims->data[2],
                            output_tensor->dims->data[1]),
            TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
      }
    } else if (num_dimensions != 2) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat(
              "Unexpected number of dimensions for output index %d: got %dD, "
              "expected either 2D (BxN with B=1) or 4D (BxHxWxN with B=1, W=1, "
              "H=1).",
              i, num_dimensions),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
    if (output_tensor->dims->data[0] != 1) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("The output array is expected to have a batch size "
                          "of 1. Got %d for output index %d.",
                          output_tensor->dims->data[0], i),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
    int num_classes = output_tensor->dims->data[num_dimensions - 1];
    // If label map is not set, build a default one based on model
    // introspection. This happens if a model with partial or no metadata was
    // provided through the `model_file_with_metadata` options field.
    if (classification_heads_[i].label_map_items.empty()) {
      classification_heads_[i].label_map_items.reserve(num_classes);
      for (int class_index = 0; class_index < num_classes; ++class_index) {
        classification_heads_[i].label_map_items.emplace_back(LabelMapItem{});
      }
    }
    int num_label_map_items = classification_heads_[i].label_map_items.size();
    if (num_classes != num_label_map_items) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Got %d class(es) for output index %d, expected %d "
                          "according to the label map.",
                          output_tensor->dims->data[num_dimensions - 1], i,
                          num_label_map_items),
          TfLiteSupportStatus::kMetadataInconsistencyError);
    }
    if (output_tensor->type == kTfLiteUInt8) {
      num_quantized_outputs++;
    } else if (output_tensor->type != kTfLiteFloat32) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Type mismatch for output tensor %s. Requested one "
                          "of these types: kTfLiteUint8/kTfLiteFloat32, got %s.",
                          output_tensor->name,
                          TfLiteTypeGetName(output_tensor->type)),
          TfLiteSupportStatus::kInvalidOutputTensorTypeError);
    }
  }

  if (num_quantized_outputs > 0 && num_quantized_outputs != num_outputs_) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Got %d quantized output(s), expected %d (i.e. all "
                        "provided outputs must be quantized).",
                        num_quantized_outputs, num_outputs_),
        TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }
  has_uint8_outputs_ = (num_quantized_outputs > 0);

  return absl::OkStatus();
}

absl::Status ImageClassifier::CheckAndSetClassNameSet() {
  // Exit early if no blacklist/whitelist.
  if (options_->class_name_blacklist_size() == 0 &&
      options_->class_name_whitelist_size() == 0) {
    return absl::OkStatus();
  }

  // Before processing the class name whitelist or blacklist from the input
  // options, create a set with _all_ known class names from the label map(s).
  absl::flat_hash_set<std::string> all_class_names;
  int head_index = 0;
  for (const auto& head : classification_heads_) {
    absl::flat_hash_set<std::string> head_class_names;
    for (const auto& item : head.label_map_items) {
      if (!item.name.empty()) {
        head_class_names.insert(item.name);
      }
    }
    if (head_class_names.empty()) {
      std::string name = head.name;
      if (name.empty()) {
        name = absl::StrFormat("#%d", head_index);
      }
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat(
              "Using `class_name_whitelist` or `class_name_blacklist` "
              "requires labels to be present but none was found for "
              "classification head: %s",
              name),
          TfLiteSupportStatus::kMetadataMissingLabelsError);
    }
    all_class_names.insert(head_class_names.begin(), head_class_names.end());
    head_index++;
  }

  class_name_set_.is_whitelist = options_->class_name_whitelist_size() > 0;
  const auto& class_names = class_name_set_.is_whitelist
                                ? options_->class_name_whitelist()
                                : options_->class_name_blacklist();

  // Note: duplicate or unknown classes are just ignored.
  class_name_set_.values.clear();
  for (const auto& class_name : class_names) {
    if (!all_class_names.contains(class_name)) {
      continue;
    }
    class_name_set_.values.insert(class_name);
  }

  if (class_name_set_.values.empty()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Invalid class names specified via `class_name_%s`: none match "
            "with model labels.",
            class_name_set_.is_whitelist ? "whitelist" : "blacklist"),
        TfLiteSupportStatus::kInvalidArgumentError);
  }

  return absl::OkStatus();
}

absl::Status ImageClassifier::InitScoreCalibrations() {
  score_calibrations_.clear();
  score_calibrations_.resize(classification_heads_.size());

  for (int i = 0; i < classification_heads_.size(); ++i) {
    if (!classification_heads_[i].calibration_params.has_value()) {
      continue;
    }

    // Use a specific default score instead of the one specified by default in
    // cc/task/vision/utils/score_calibration.h. See `kDefaultCalibratedScore`
    // documentation for more details.
    classification_heads_[i].calibration_params->default_score =
        kDefaultCalibratedScore;

    score_calibrations_[i] = absl::make_unique<ScoreCalibration>();
    if (score_calibrations_[i] == nullptr) {
      return CreateStatusWithPayload(
          StatusCode::kInternal, "Could not create score calibration object.");
    }

    RETURN_IF_ERROR(score_calibrations_[i]->InitializeFromParameters(
        classification_heads_[i].calibration_params.value()));
  }

  return absl::OkStatus();
}

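// Classifies the full frame: the region of interest defaults to a bounding box
// covering the entire input frame buffer.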
StatusOr<ClassificationResult> ImageClassifier::Classify(
    const FrameBuffer& frame_buffer) {
  BoundingBox roi;
  roi.set_width(frame_buffer.dimension().width);
  roi.set_height(frame_buffer.dimension().height);
  return Classify(frame_buffer, roi);
}

StatusOr<ClassificationResult> ImageClassifier::Classify(
    const FrameBuffer& frame_buffer, const BoundingBox& roi) {
  return InferWithFallback(frame_buffer, roi);
}

StatusOr<ClassificationResult> ImageClassifier::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
  if (output_tensors.size() != num_outputs_) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat("Expected %d output tensors, found %d", num_outputs_,
                        output_tensors.size()));
  }

  ClassificationResult result;
  std::vector<std::pair<int, float>> score_pairs;

  for (int i = 0; i < num_outputs_; ++i) {
    auto* classifications = result.add_classifications();
    classifications->set_head_index(i);

    const auto& head = classification_heads_[i];
    score_pairs.clear();
    score_pairs.reserve(head.label_map_items.size());

    const TfLiteTensor* output_tensor = output_tensors[i];
    if (has_uint8_outputs_) {
      const uint8* output_data =
          AssertAndReturnTypedTensor<uint8>(output_tensor);
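      // Dequantize the uint8 scores using the tensor's affine quantization
      // parameters: real_value = scale * (quantized_value - zero_point).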
      for (int j = 0; j < head.label_map_items.size(); ++j) {
        score_pairs.emplace_back(j, output_tensor->params.scale *
                                        (static_cast<int>(output_data[j]) -
                                         output_tensor->params.zero_point));
      }
    } else {
      const float* output_data =
          AssertAndReturnTypedTensor<float>(output_tensor);
      for (int j = 0; j < head.label_map_items.size(); ++j) {
        score_pairs.emplace_back(j, output_data[j]);
      }
    }

    // Optional score calibration.
    if (score_calibrations_[i] != nullptr) {
      for (auto& score_pair : score_pairs) {
        const std::string& class_name =
            head.label_map_items[score_pair.first].name;
        score_pair.second = score_calibrations_[i]->ComputeCalibratedScore(
            class_name, score_pair.second);
        if (score_pair.second > kMaxCalibratedScore) {
          return CreateStatusWithPayload(
              StatusCode::kInternal,
              absl::StrFormat("calibrated score is too high: got %f, expected "
                              "%f as maximum.",
                              score_pair.second, kMaxCalibratedScore));
        }
        if (score_pair.second != kDefaultCalibratedScore &&
            score_pair.second < kMinCalibratedScore) {
          return CreateStatusWithPayload(
              StatusCode::kInternal,
              absl::StrFormat("calibrated score is too low: got %f, expected "
                              "%f as minimum.",
                              score_pair.second, kMinCalibratedScore));
        }
      }
    }

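    // Determine how many results to return (`max_results`, if non-negative,
    // capped by the number of classes) and which score threshold to apply (the
    // options value takes precedence over the one from the model metadata).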
    int num_results =
        options_->max_results() >= 0
            ? std::min(static_cast<int>(head.label_map_items.size()),
                       options_->max_results())
            : head.label_map_items.size();
    float score_threshold = options_->has_score_threshold()
                                ? options_->score_threshold()
                                : head.score_threshold;

    if (class_name_set_.values.empty()) {
      // Partially sort in descending order (higher score is better).
      absl::c_partial_sort(
          score_pairs, score_pairs.begin() + num_results,
          [](const std::pair<int, float>& a, const std::pair<int, float>& b) {
            return a.second > b.second;
          });

      for (int j = 0; j < num_results; ++j) {
        float score = score_pairs[j].second;
        if (score < score_threshold) {
          break;
        }
        auto* cl = classifications->add_classes();
        cl->set_index(score_pairs[j].first);
        cl->set_score(score);
      }
    } else {
      // Sort in descending order (higher score is better).
      absl::c_sort(score_pairs, [](const std::pair<int, float>& a,
                                   const std::pair<int, float>& b) {
        return a.second > b.second;
      });

      for (int j = 0; j < head.label_map_items.size(); ++j) {
        float score = score_pairs[j].second;
        if (score < score_threshold ||
            classifications->classes_size() >= num_results) {
          break;
        }

        const int class_index = score_pairs[j].first;
        const std::string& class_name = head.label_map_items[class_index].name;

        bool class_name_found = class_name_set_.values.contains(class_name);

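        // With a whitelist, skip classes that are not in it; with a blacklist,
        // skip classes that are in it.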
        if ((!class_name_found && class_name_set_.is_whitelist) ||
            (class_name_found && !class_name_set_.is_whitelist)) {
          continue;
        }

        auto* cl = classifications->add_classes();
        cl->set_index(class_index);
        cl->set_score(score);
      }
    }
  }

  RETURN_IF_ERROR(FillResultsFromLabelMaps(&result));

  return result;
}

absl::Status ImageClassifier::FillResultsFromLabelMaps(
    ClassificationResult* result) {
  for (int i = 0; i < result->classifications_size(); ++i) {
    Classifications* classifications = result->mutable_classifications(i);
    int head_index = classifications->head_index();
    if (head_index < 0 || head_index >= classification_heads_.size()) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Invalid head index (%d) with respect to total "
                          "number of classification heads (%d).",
                          head_index, classification_heads_.size()),
          TfLiteSupportStatus::kMetadataInconsistencyError);
    }
    const std::vector<LabelMapItem>& label_map_items =
        classification_heads_[head_index].label_map_items;
    for (int j = 0; j < classifications->classes_size(); ++j) {
      Class* current_class = classifications->mutable_classes(j);
      int current_class_index = current_class->index();
      if (current_class_index < 0 ||
          current_class_index >= label_map_items.size()) {
        return CreateStatusWithPayload(
            StatusCode::kInvalidArgument,
            absl::StrFormat("Invalid class index (%d) with respect to label "
                            "map size (%d) for head #%d.",
                            current_class_index, label_map_items.size(),
                            head_index),
            TfLiteSupportStatus::kMetadataInconsistencyError);
      }
      const std::string& name = label_map_items[current_class_index].name;
      if (!name.empty()) {
        current_class->set_class_name(name);
      }
      const std::string& display_name =
          label_map_items[current_class_index].display_name;
      if (!display_name.empty()) {
        current_class->set_display_name(display_name);
      }
    }
  }
  return absl::OkStatus();
}

}  // namespace vision
}  // namespace task
}  // namespace tflite