xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/cc/task/vision/image_classifier.cc (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/strings/str_format.h"
20 #include "absl/strings/string_view.h"
21 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
22 #include "tensorflow/lite/interpreter.h"
23 #include "tensorflow_lite_support/cc/common.h"
24 #include "tensorflow_lite_support/cc/port/integral_types.h"
25 #include "tensorflow_lite_support/cc/port/status_macros.h"
26 #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
27 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
28 #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
29 #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
30 #include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
31 #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
32 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
33 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
34 
35 namespace tflite {
36 namespace task {
37 namespace vision {
38 
39 namespace {
40 
41 using ::absl::StatusCode;
42 using ::tflite::metadata::ModelMetadataExtractor;
43 using ::tflite::support::CreateStatusWithPayload;
44 using ::tflite::support::StatusOr;
45 using ::tflite::support::TfLiteSupportStatus;
46 using ::tflite::task::core::AssertAndReturnTypedTensor;
47 using ::tflite::task::core::TaskAPIFactory;
48 using ::tflite::task::core::TfLiteEngine;
49 
// Fallback score assigned by score calibration in two situations:
//
//   (1) The class has no calibration data. This happens when the
//       ScoreCalibration does not cover every class listed in the label map,
//       and can be used on purpose to blacklist given classes so that they
//       are never returned.
//   (2) The class has a very low-confidence uncalibrated score, i.e. one
//       below the `min_uncalibrated_score` threshold. That threshold is an
//       optional part of the calibration data, used to mitigate false alarms
//       on some classes.
//
// Either way, a class scored -1 is never returned: it gets discarded by the
// `score_threshold` check (see post-processing logic).
constexpr float kDefaultCalibratedScore{-1.0f};

// Inclusive bounds for calibrated scores. Any calibrated score outside
// [kMinCalibratedScore, kMaxCalibratedScore] (other than the
// kDefaultCalibratedScore sentinel) is reported as an error at
// post-processing time.
constexpr float kMinCalibratedScore{0.0f};
constexpr float kMaxCalibratedScore{1.0f};

// The sentinel must never be mistaken for a valid calibrated score.
static_assert(kDefaultCalibratedScore < kMinCalibratedScore,
              "kDefaultCalibratedScore must lie outside the valid range.");
69 
70 }  // namespace
71 
72 /* static */
CreateFromOptions(const ImageClassifierOptions & options,std::unique_ptr<tflite::OpResolver> resolver)73 StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::CreateFromOptions(
74     const ImageClassifierOptions& options,
75     std::unique_ptr<tflite::OpResolver> resolver) {
76   RETURN_IF_ERROR(SanityCheckOptions(options));
77 
78   // Copy options to ensure the ExternalFile outlives the constructed object.
79   auto options_copy = absl::make_unique<ImageClassifierOptions>(options);
80 
81   ASSIGN_OR_RETURN(auto image_classifier,
82                    TaskAPIFactory::CreateFromExternalFileProto<ImageClassifier>(
83                        &options_copy->model_file_with_metadata(),
84                        std::move(resolver), options_copy->num_threads()));
85 
86   RETURN_IF_ERROR(image_classifier->Init(std::move(options_copy)));
87 
88   return image_classifier;
89 }
90 
91 /* static */
SanityCheckOptions(const ImageClassifierOptions & options)92 absl::Status ImageClassifier::SanityCheckOptions(
93     const ImageClassifierOptions& options) {
94   if (!options.has_model_file_with_metadata()) {
95     return CreateStatusWithPayload(
96         StatusCode::kInvalidArgument,
97         "Missing mandatory `model_file_with_metadata` field",
98         TfLiteSupportStatus::kInvalidArgumentError);
99   }
100   if (options.max_results() == 0) {
101     return CreateStatusWithPayload(
102         StatusCode::kInvalidArgument,
103         "Invalid `max_results` option: value must be != 0",
104         TfLiteSupportStatus::kInvalidArgumentError);
105   }
106   if (options.score_threshold() < 0 || options.score_threshold() >= 1) {
107     return CreateStatusWithPayload(
108         StatusCode::kInvalidArgument,
109         absl::StrFormat(
110             "`score_threshold` out of range: %f. Valid range is [0,1[.",
111             options.score_threshold()),
112         TfLiteSupportStatus::kInvalidArgumentError);
113   }
114   if (options.class_name_whitelist_size() > 0 &&
115       options.class_name_blacklist_size() > 0) {
116     return CreateStatusWithPayload(
117         StatusCode::kInvalidArgument,
118         "`class_name_whitelist` and `class_name_blacklist` are mutually "
119         "exclusive options.",
120         TfLiteSupportStatus::kInvalidArgumentError);
121   }
122   if (options.num_threads() == 0 || options.num_threads() < -1) {
123     return CreateStatusWithPayload(
124         StatusCode::kInvalidArgument,
125         "`num_threads` must be greater than 0 or equal to -1.",
126         TfLiteSupportStatus::kInvalidArgumentError);
127   }
128   return absl::OkStatus();
129 }
130 
Init(std::unique_ptr<ImageClassifierOptions> options)131 absl::Status ImageClassifier::Init(
132     std::unique_ptr<ImageClassifierOptions> options) {
133   // Set options.
134   options_ = std::move(options);
135 
136   // Perform pre-initialization actions (by default, sets the process engine for
137   // image pre-processing to kLibyuv as a sane default).
138   RETURN_IF_ERROR(PreInit());
139 
140   // Sanity check and set inputs and outputs.
141   RETURN_IF_ERROR(CheckAndSetInputs());
142   RETURN_IF_ERROR(CheckAndSetOutputs());
143 
144   // Initialize class whitelisting/blacklisting, if any.
145   RETURN_IF_ERROR(CheckAndSetClassNameSet());
146 
147   // Perform final initialization (by default, initialize score calibration
148   // parameters, if any).
149   RETURN_IF_ERROR(PostInit());
150 
151   return absl::OkStatus();
152 }
153 
PreInit()154 absl::Status ImageClassifier::PreInit() {
155   SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
156   return absl::OkStatus();
157 }
158 
PostInit()159 absl::Status ImageClassifier::PostInit() { return InitScoreCalibrations(); }
160 
CheckAndSetOutputs()161 absl::Status ImageClassifier::CheckAndSetOutputs() {
162   num_outputs_ = TfLiteEngine::OutputCount(engine_->interpreter());
163 
164   // Perform sanity checks and extract metadata.
165   const ModelMetadataExtractor* metadata_extractor =
166       engine_->metadata_extractor();
167 
168   const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
169       output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata();
170 
171   // Loop over output tensors metadata, if any.
172   // Note: models with no output tensor metadata at all are supported.
173   if (output_tensor_metadata != nullptr) {
174     int num_output_tensors = output_tensor_metadata->size();
175 
176     if (num_outputs_ != num_output_tensors) {
177       return CreateStatusWithPayload(
178           StatusCode::kInvalidArgument,
179           absl::StrFormat("Mismatch between number of output tensors (%d) and "
180                           "output tensors "
181                           "metadata (%d).",
182                           num_outputs_, num_output_tensors),
183           TfLiteSupportStatus::kMetadataInconsistencyError);
184     }
185 
186     for (int i = 0; i < num_output_tensors; ++i) {
187       const tflite::TensorMetadata* output_tensor =
188           output_tensor_metadata->Get(i);
189 
190       ASSIGN_OR_RETURN(
191           ClassificationHead head,
192           BuildClassificationHead(*metadata_extractor, *output_tensor,
193                                   options_->display_names_locale()));
194 
195       classification_heads_.emplace_back(std::move(head));
196     }
197   }
198 
199   // If classifier heads are not set, build default ones based on model
200   // introspection. This happens if a model with partial or no metadata was
201   // provided through the `model_file_with_metadata` options field.
202   if (classification_heads_.empty()) {
203     classification_heads_.reserve(num_outputs_);
204     for (int output_index = 0; output_index < num_outputs_; ++output_index) {
205       classification_heads_.emplace_back(ClassificationHead{});
206     }
207   }
208 
209   if (num_outputs_ != classification_heads_.size()) {
210     return CreateStatusWithPayload(
211         StatusCode::kInvalidArgument,
212         absl::StrFormat("Got %d classifier head(s), expected %d according to "
213                         "the label map.",
214                         num_outputs_, classification_heads_.size()),
215         TfLiteSupportStatus::kMetadataInconsistencyError);
216   }
217 
218   int num_quantized_outputs = 0;
219   for (int i = 0; i < num_outputs_; ++i) {
220     const TfLiteTensor* output_tensor =
221         TfLiteEngine::GetOutput(engine_->interpreter(), i);
222     const int num_dimensions = output_tensor->dims->size;
223     if (num_dimensions == 4) {
224       if (output_tensor->dims->data[1] != 1 ||
225           output_tensor->dims->data[2] != 1) {
226         return CreateStatusWithPayload(
227             StatusCode::kInvalidArgument,
228             absl::StrFormat("Unexpected WxH sizes for output index %d: got "
229                             "%dx%d, expected 1x1.",
230                             i, output_tensor->dims->data[2],
231                             output_tensor->dims->data[1]),
232             TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
233       }
234     } else if (num_dimensions != 2) {
235       return CreateStatusWithPayload(
236           StatusCode::kInvalidArgument,
237           absl::StrFormat(
238               "Unexpected number of dimensions for output index %d: got %dD, "
239               "expected either 2D (BxN with B=1) or 4D (BxHxWxN with B=1, W=1, "
240               "H=1).",
241               i, num_dimensions),
242           TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
243     }
244     if (output_tensor->dims->data[0] != 1) {
245       return CreateStatusWithPayload(
246           StatusCode::kInvalidArgument,
247           absl::StrFormat("The output array is expected to have a batch size "
248                           "of 1. Got %d for output index %d.",
249                           output_tensor->dims->data[0], i),
250           TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
251     }
252     int num_classes = output_tensor->dims->data[num_dimensions - 1];
253     // If label map is not set, build a default one based on model
254     // introspection. This happens if a model with partial or no metadata was
255     // provided through the `model_file_with_metadata` options field.
256     if (classification_heads_[i].label_map_items.empty()) {
257       classification_heads_[i].label_map_items.reserve(num_classes);
258       for (int class_index = 0; class_index < num_classes; ++class_index) {
259         classification_heads_[i].label_map_items.emplace_back(LabelMapItem{});
260       }
261     }
262     int num_label_map_items = classification_heads_[i].label_map_items.size();
263     if (num_classes != num_label_map_items) {
264       return CreateStatusWithPayload(
265           StatusCode::kInvalidArgument,
266           absl::StrFormat("Got %d class(es) for output index %d, expected %d "
267                           "according to the label map.",
268                           output_tensor->dims->data[num_dimensions - 1], i,
269                           num_label_map_items),
270           TfLiteSupportStatus::kMetadataInconsistencyError);
271     }
272     if (output_tensor->type == kTfLiteUInt8) {
273       num_quantized_outputs++;
274     } else if (output_tensor->type != kTfLiteFloat32) {
275       return CreateStatusWithPayload(
276           StatusCode::kInvalidArgument,
277           absl::StrFormat("Type mismatch for output tensor %s. Requested one "
278                           "of these types: "
279                           "kTfLiteUint8/kTfLiteFloat32, got %s.",
280                           output_tensor->name,
281                           TfLiteTypeGetName(output_tensor->type)),
282           TfLiteSupportStatus::kInvalidOutputTensorTypeError);
283     }
284   }
285 
286   if (num_quantized_outputs > 0 && num_quantized_outputs != num_outputs_) {
287     return CreateStatusWithPayload(
288         StatusCode::kInvalidArgument,
289         absl::StrFormat("Got %d quantized output(s), expected %d (i.e. all "
290                         "provided outputs must be quantized).",
291                         num_quantized_outputs, num_outputs_),
292         TfLiteSupportStatus::kInvalidOutputTensorTypeError);
293   }
294   has_uint8_outputs_ = (num_quantized_outputs > 0);
295 
296   return absl::OkStatus();
297 }
298 
CheckAndSetClassNameSet()299 absl::Status ImageClassifier::CheckAndSetClassNameSet() {
300   // Exit early if no blacklist/whitelist.
301   if (options_->class_name_blacklist_size() == 0 &&
302       options_->class_name_whitelist_size() == 0) {
303     return absl::OkStatus();
304   }
305 
306   // Before processing class names whitelist or blacklist from the input options
307   // create a set with _all_ known class names from the label map(s).
308   absl::flat_hash_set<std::string> all_class_names;
309   int head_index = 0;
310   for (const auto& head : classification_heads_) {
311     absl::flat_hash_set<std::string> head_class_names;
312     for (const auto& item : head.label_map_items) {
313       if (!item.name.empty()) {
314         head_class_names.insert(item.name);
315       }
316     }
317     if (head_class_names.empty()) {
318       std::string name = head.name;
319       if (name.empty()) {
320         name = absl::StrFormat("#%d", head_index);
321       }
322       return CreateStatusWithPayload(
323           StatusCode::kInvalidArgument,
324           absl::StrFormat(
325               "Using `class_name_whitelist` or `class_name_blacklist` "
326               "requires labels to be present but none was found for "
327               "classification head: %s",
328               name),
329           TfLiteSupportStatus::kMetadataMissingLabelsError);
330     }
331     all_class_names.insert(head_class_names.begin(), head_class_names.end());
332     head_index++;
333   }
334 
335   class_name_set_.is_whitelist = options_->class_name_whitelist_size() > 0;
336   const auto& class_names = class_name_set_.is_whitelist
337                                 ? options_->class_name_whitelist()
338                                 : options_->class_name_blacklist();
339 
340   // Note: duplicate or unknown classes are just ignored.
341   class_name_set_.values.clear();
342   for (const auto& class_name : class_names) {
343     if (!all_class_names.contains(class_name)) {
344       continue;
345     }
346     class_name_set_.values.insert(class_name);
347   }
348 
349   if (class_name_set_.values.empty()) {
350     return CreateStatusWithPayload(
351         StatusCode::kInvalidArgument,
352         absl::StrFormat(
353             "Invalid class names specified via `class_name_%s`: none match "
354             "with model labels.",
355             class_name_set_.is_whitelist ? "whitelist" : "blacklist"),
356         TfLiteSupportStatus::kInvalidArgumentError);
357   }
358 
359   return absl::OkStatus();
360 }
361 
InitScoreCalibrations()362 absl::Status ImageClassifier::InitScoreCalibrations() {
363   score_calibrations_.clear();
364   score_calibrations_.resize(classification_heads_.size());
365 
366   for (int i = 0; i < classification_heads_.size(); ++i) {
367     if (!classification_heads_[i].calibration_params.has_value()) {
368       continue;
369     }
370 
371     // Use a specific default score instead of the one specified by default in
372     // cc/task/vision/utils/score_calibration.h. See `kDefaultCalibratedScore`
373     // documentation for more details.
374     classification_heads_[i].calibration_params->default_score =
375         kDefaultCalibratedScore;
376 
377     score_calibrations_[i] = absl::make_unique<ScoreCalibration>();
378     if (score_calibrations_[i] == nullptr) {
379       return CreateStatusWithPayload(
380           StatusCode::kInternal, "Could not create score calibration object.");
381     }
382 
383     RETURN_IF_ERROR(score_calibrations_[i]->InitializeFromParameters(
384         classification_heads_[i].calibration_params.value()));
385   }
386 
387   return absl::OkStatus();
388 }
389 
Classify(const FrameBuffer & frame_buffer)390 StatusOr<ClassificationResult> ImageClassifier::Classify(
391     const FrameBuffer& frame_buffer) {
392   BoundingBox roi;
393   roi.set_width(frame_buffer.dimension().width);
394   roi.set_height(frame_buffer.dimension().height);
395   return Classify(frame_buffer, roi);
396 }
397 
Classify(const FrameBuffer & frame_buffer,const BoundingBox & roi)398 StatusOr<ClassificationResult> ImageClassifier::Classify(
399     const FrameBuffer& frame_buffer, const BoundingBox& roi) {
400   return InferWithFallback(frame_buffer, roi);
401 }
402 
Postprocess(const std::vector<const TfLiteTensor * > & output_tensors,const FrameBuffer &,const BoundingBox &)403 StatusOr<ClassificationResult> ImageClassifier::Postprocess(
404     const std::vector<const TfLiteTensor*>& output_tensors,
405     const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
406   if (output_tensors.size() != num_outputs_) {
407     return CreateStatusWithPayload(
408         StatusCode::kInternal,
409         absl::StrFormat("Expected %d output tensors, found %d", num_outputs_,
410                         output_tensors.size()));
411   }
412 
413   ClassificationResult result;
414   std::vector<std::pair<int, float>> score_pairs;
415 
416   for (int i = 0; i < num_outputs_; ++i) {
417     auto* classifications = result.add_classifications();
418     classifications->set_head_index(i);
419 
420     const auto& head = classification_heads_[i];
421     score_pairs.clear();
422     score_pairs.reserve(head.label_map_items.size());
423 
424     const TfLiteTensor* output_tensor = output_tensors[i];
425     if (has_uint8_outputs_) {
426       const uint8* output_data =
427           AssertAndReturnTypedTensor<uint8>(output_tensor);
428       for (int j = 0; j < head.label_map_items.size(); ++j) {
429         score_pairs.emplace_back(j, output_tensor->params.scale *
430                                         (static_cast<int>(output_data[j]) -
431                                          output_tensor->params.zero_point));
432       }
433     } else {
434       const float* output_data =
435           AssertAndReturnTypedTensor<float>(output_tensor);
436       for (int j = 0; j < head.label_map_items.size(); ++j) {
437         score_pairs.emplace_back(j, output_data[j]);
438       }
439     }
440 
441     // Optional score calibration.
442     if (score_calibrations_[i] != nullptr) {
443       for (auto& score_pair : score_pairs) {
444         const std::string& class_name =
445             head.label_map_items[score_pair.first].name;
446         score_pair.second = score_calibrations_[i]->ComputeCalibratedScore(
447             class_name, score_pair.second);
448         if (score_pair.second > kMaxCalibratedScore) {
449           return CreateStatusWithPayload(
450               StatusCode::kInternal,
451               absl::StrFormat("calibrated score is too high: got %f, expected "
452                               "%f as maximum.",
453                               score_pair.second, kMaxCalibratedScore));
454         }
455         if (score_pair.second != kDefaultCalibratedScore &&
456             score_pair.second < kMinCalibratedScore) {
457           return CreateStatusWithPayload(
458               StatusCode::kInternal,
459               absl::StrFormat("calibrated score is too low: got %f, expected "
460                               "%f as minimum.",
461                               score_pair.second, kMinCalibratedScore));
462         }
463       }
464     }
465 
466     int num_results =
467         options_->max_results() >= 0
468             ? std::min(static_cast<int>(head.label_map_items.size()),
469                        options_->max_results())
470             : head.label_map_items.size();
471     float score_threshold = options_->has_score_threshold()
472                                 ? options_->score_threshold()
473                                 : head.score_threshold;
474 
475     if (class_name_set_.values.empty()) {
476       // Partially sort in descending order (higher score is better).
477       absl::c_partial_sort(
478           score_pairs, score_pairs.begin() + num_results,
479           [](const std::pair<int, float>& a, const std::pair<int, float>& b) {
480             return a.second > b.second;
481           });
482 
483       for (int j = 0; j < num_results; ++j) {
484         float score = score_pairs[j].second;
485         if (score < score_threshold) {
486           break;
487         }
488         auto* cl = classifications->add_classes();
489         cl->set_index(score_pairs[j].first);
490         cl->set_score(score);
491       }
492     } else {
493       // Sort in descending order (higher score is better).
494       absl::c_sort(score_pairs, [](const std::pair<int, float>& a,
495                                    const std::pair<int, float>& b) {
496         return a.second > b.second;
497       });
498 
499       for (int j = 0; j < head.label_map_items.size(); ++j) {
500         float score = score_pairs[j].second;
501         if (score < score_threshold ||
502             classifications->classes_size() >= num_results) {
503           break;
504         }
505 
506         const int class_index = score_pairs[j].first;
507         const std::string& class_name = head.label_map_items[class_index].name;
508 
509         bool class_name_found = class_name_set_.values.contains(class_name);
510 
511         if ((!class_name_found && class_name_set_.is_whitelist) ||
512             (class_name_found && !class_name_set_.is_whitelist)) {
513           continue;
514         }
515 
516         auto* cl = classifications->add_classes();
517         cl->set_index(class_index);
518         cl->set_score(score);
519       }
520     }
521   }
522 
523   RETURN_IF_ERROR(FillResultsFromLabelMaps(&result));
524 
525   return result;
526 }
527 
FillResultsFromLabelMaps(ClassificationResult * result)528 absl::Status ImageClassifier::FillResultsFromLabelMaps(
529     ClassificationResult* result) {
530   for (int i = 0; i < result->classifications_size(); ++i) {
531     Classifications* classifications = result->mutable_classifications(i);
532     int head_index = classifications->head_index();
533     if (head_index < 0 || head_index >= classification_heads_.size()) {
534       return CreateStatusWithPayload(
535           StatusCode::kInvalidArgument,
536           absl::StrFormat("Invalid head index (%d) with respect to total "
537                           "number of classification heads (%d).",
538                           head_index, classification_heads_.size()),
539           TfLiteSupportStatus::kMetadataInconsistencyError);
540     }
541     const std::vector<LabelMapItem>& label_map_items =
542         classification_heads_[head_index].label_map_items;
543     for (int j = 0; j < classifications->classes_size(); ++j) {
544       Class* current_class = classifications->mutable_classes(j);
545       int current_class_index = current_class->index();
546       if (current_class_index < 0 ||
547           current_class_index >= label_map_items.size()) {
548         return CreateStatusWithPayload(
549             StatusCode::kInvalidArgument,
550             absl::StrFormat("Invalid class index (%d) with respect to label "
551                             "map size (%d) for head #%d.",
552                             current_class_index, label_map_items.size(),
553                             head_index),
554             TfLiteSupportStatus::kMetadataInconsistencyError);
555       }
556       const std::string& name = label_map_items[current_class_index].name;
557       if (!name.empty()) {
558         current_class->set_class_name(name);
559       }
560       const std::string& display_name =
561           label_map_items[current_class_index].display_name;
562       if (!display_name.empty()) {
563         current_class->set_display_name(display_name);
564       }
565     }
566   }
567   return absl::OkStatus();
568 }
569 
570 }  // namespace vision
571 }  // namespace task
572 }  // namespace tflite
573