xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/cc/task/vision/object_detector.cc (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/cc/task/vision/object_detector.h"
17 
18 #include <algorithm>
19 #include <limits>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/lite/c/common.h"
26 #include "tensorflow/lite/interpreter.h"
27 #include "tensorflow_lite_support/cc/common.h"
28 #include "tensorflow_lite_support/cc/port/status_macros.h"
29 #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
30 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
31 #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
32 #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
33 #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
34 #include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
35 #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
36 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
37 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
38 
39 namespace tflite {
40 namespace task {
41 namespace vision {
42 
43 namespace {
44 
45 using ::absl::StatusCode;
46 using ::tflite::BoundingBoxProperties;
47 using ::tflite::ContentProperties;
48 using ::tflite::ContentProperties_BoundingBoxProperties;
49 using ::tflite::EnumNameContentProperties;
50 using ::tflite::ProcessUnit;
51 using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions;
52 using ::tflite::TensorMetadata;
53 using ::tflite::metadata::ModelMetadataExtractor;
54 using ::tflite::support::CreateStatusWithPayload;
55 using ::tflite::support::StatusOr;
56 using ::tflite::support::TfLiteSupportStatus;
57 using ::tflite::task::core::AssertAndReturnTypedTensor;
58 using ::tflite::task::core::TaskAPIFactory;
59 using ::tflite::task::core::TfLiteEngine;
60 
// Expected rank (number of dimensions) of each of the 4 output tensors, listed
// in this order: locations, classes, scores, num_results.
constexpr int kOutputTensorsExpectedDims[] = {3, 2, 2, 1};
64 
GetBoundingBoxProperties(const TensorMetadata & tensor_metadata)65 StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
66     const TensorMetadata& tensor_metadata) {
67   if (tensor_metadata.content() == nullptr ||
68       tensor_metadata.content()->content_properties() == nullptr) {
69     return CreateStatusWithPayload(
70         StatusCode::kInvalidArgument,
71         absl::StrFormat(
72             "Expected BoundingBoxProperties for tensor %s, found none.",
73             tensor_metadata.name() ? tensor_metadata.name()->str() : "#0"),
74         TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
75   }
76 
77   ContentProperties type = tensor_metadata.content()->content_properties_type();
78   if (type != ContentProperties_BoundingBoxProperties) {
79     return CreateStatusWithPayload(
80         StatusCode::kInvalidArgument,
81         absl::StrFormat(
82             "Expected BoundingBoxProperties for tensor %s, found %s.",
83             tensor_metadata.name() ? tensor_metadata.name()->str() : "#0",
84             EnumNameContentProperties(type)),
85         TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
86   }
87 
88   const BoundingBoxProperties* properties =
89       tensor_metadata.content()->content_properties_as_BoundingBoxProperties();
90 
91   // Mobile SSD only supports "BOUNDARIES" bounding box type.
92   if (properties->type() != tflite::BoundingBoxType_BOUNDARIES) {
93     return CreateStatusWithPayload(
94         StatusCode::kInvalidArgument,
95         absl::StrFormat(
96             "Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s",
97             tflite::EnumNameBoundingBoxType(properties->type())),
98         TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
99   }
100 
101   // Mobile SSD only supports "RATIO" coordinates type.
102   if (properties->coordinate_type() != tflite::CoordinateType_RATIO) {
103     return CreateStatusWithPayload(
104         StatusCode::kInvalidArgument,
105         absl::StrFormat(
106             "Mobile SSD only supports CoordinateType RATIO, found %s",
107             tflite::EnumNameCoordinateType(properties->coordinate_type())),
108         TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
109   }
110 
111   // Index is optional, but must contain 4 values if present.
112   if (properties->index() != nullptr && properties->index()->size() != 4) {
113     return CreateStatusWithPayload(
114         StatusCode::kInvalidArgument,
115         absl::StrFormat(
116             "Expected BoundingBoxProperties index to contain 4 values, found "
117             "%d",
118             properties->index()->size()),
119         TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
120   }
121 
122   return properties;
123 }
124 
GetLabelMapIfAny(const ModelMetadataExtractor & metadata_extractor,const TensorMetadata & tensor_metadata,absl::string_view locale)125 StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
126     const ModelMetadataExtractor& metadata_extractor,
127     const TensorMetadata& tensor_metadata, absl::string_view locale) {
128   const std::string labels_filename =
129       ModelMetadataExtractor::FindFirstAssociatedFileName(
130           tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS);
131   if (labels_filename.empty()) {
132     return std::vector<LabelMapItem>();
133   }
134   ASSIGN_OR_RETURN(absl::string_view labels_file,
135                    metadata_extractor.GetAssociatedFile(labels_filename));
136   const std::string display_names_filename =
137       ModelMetadataExtractor::FindFirstAssociatedFileName(
138           tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS,
139           locale);
140   absl::string_view display_names_file = nullptr;
141   if (!display_names_filename.empty()) {
142     ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile(
143                                              display_names_filename));
144   }
145   return BuildLabelMapFromFiles(labels_file, display_names_file);
146 }
147 
GetScoreThreshold(const ModelMetadataExtractor & metadata_extractor,const TensorMetadata & tensor_metadata)148 StatusOr<float> GetScoreThreshold(
149     const ModelMetadataExtractor& metadata_extractor,
150     const TensorMetadata& tensor_metadata) {
151   ASSIGN_OR_RETURN(
152       const ProcessUnit* score_thresholding_process_unit,
153       metadata_extractor.FindFirstProcessUnit(
154           tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions));
155   if (score_thresholding_process_unit == nullptr) {
156     return std::numeric_limits<float>::lowest();
157   }
158   return score_thresholding_process_unit->options_as_ScoreThresholdingOptions()
159       ->global_score_threshold();
160 }
161 
SanityCheckOutputTensors(const std::vector<const TfLiteTensor * > & output_tensors)162 absl::Status SanityCheckOutputTensors(
163     const std::vector<const TfLiteTensor*>& output_tensors) {
164   if (output_tensors.size() != 4) {
165     return CreateStatusWithPayload(
166         StatusCode::kInternal,
167         absl::StrFormat("Expected 4 output tensors, found %d",
168                         output_tensors.size()));
169   }
170 
171   // Get number of results.
172   if (output_tensors[3]->dims->data[0] != 1) {
173     return CreateStatusWithPayload(
174         StatusCode::kInternal,
175         absl::StrFormat(
176             "Expected tensor with dimensions [1] at index 3, found [%d]",
177             output_tensors[3]->dims->data[0]));
178   }
179   int num_results =
180       static_cast<int>(AssertAndReturnTypedTensor<float>(output_tensors[3])[0]);
181 
182   // Check dimensions for the other tensors are correct.
183   if (output_tensors[0]->dims->data[0] != 1 ||
184       output_tensors[0]->dims->data[1] != num_results ||
185       output_tensors[0]->dims->data[2] != 4) {
186     return CreateStatusWithPayload(
187         StatusCode::kInternal,
188         absl::StrFormat(
189             "Expected locations tensor with dimensions [1,%d,4] at index 0, "
190             "found [%d,%d,%d].",
191             num_results, output_tensors[0]->dims->data[0],
192             output_tensors[0]->dims->data[1],
193             output_tensors[0]->dims->data[2]));
194   }
195   if (output_tensors[1]->dims->data[0] != 1 ||
196       output_tensors[1]->dims->data[1] != num_results) {
197     return CreateStatusWithPayload(
198         StatusCode::kInternal,
199         absl::StrFormat(
200             "Expected classes tensor with dimensions [1,%d] at index 1, "
201             "found [%d,%d].",
202             num_results, output_tensors[1]->dims->data[0],
203             output_tensors[1]->dims->data[1]));
204   }
205   if (output_tensors[2]->dims->data[0] != 1 ||
206       output_tensors[2]->dims->data[1] != num_results) {
207     return CreateStatusWithPayload(
208         StatusCode::kInternal,
209         absl::StrFormat(
210             "Expected scores tensor with dimensions [1,%d] at index 2, "
211             "found [%d,%d].",
212             num_results, output_tensors[2]->dims->data[0],
213             output_tensors[2]->dims->data[1]));
214   }
215 
216   return absl::OkStatus();
217 }
218 
219 }  // namespace
220 
221 /* static */
SanityCheckOptions(const ObjectDetectorOptions & options)222 absl::Status ObjectDetector::SanityCheckOptions(
223     const ObjectDetectorOptions& options) {
224   if (!options.has_model_file_with_metadata()) {
225     return CreateStatusWithPayload(
226         StatusCode::kInvalidArgument,
227         "Missing mandatory `model_file_with_metadata` field",
228         TfLiteSupportStatus::kInvalidArgumentError);
229   }
230   if (options.max_results() == 0) {
231     return CreateStatusWithPayload(
232         StatusCode::kInvalidArgument,
233         "Invalid `max_results` option: value must be != 0",
234         TfLiteSupportStatus::kInvalidArgumentError);
235   }
236   if (options.class_name_whitelist_size() > 0 &&
237       options.class_name_blacklist_size() > 0) {
238     return CreateStatusWithPayload(
239         StatusCode::kInvalidArgument,
240         "`class_name_whitelist` and `class_name_blacklist` are mutually "
241         "exclusive options.",
242         TfLiteSupportStatus::kInvalidArgumentError);
243   }
244   if (options.num_threads() == 0 || options.num_threads() < -1) {
245     return CreateStatusWithPayload(
246         StatusCode::kInvalidArgument,
247         "`num_threads` must be greater than 0 or equal to -1.",
248         TfLiteSupportStatus::kInvalidArgumentError);
249   }
250   return absl::OkStatus();
251 }
252 
253 /* static */
CreateFromOptions(const ObjectDetectorOptions & options,std::unique_ptr<tflite::OpResolver> resolver)254 StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::CreateFromOptions(
255     const ObjectDetectorOptions& options,
256     std::unique_ptr<tflite::OpResolver> resolver) {
257   RETURN_IF_ERROR(SanityCheckOptions(options));
258 
259   // Copy options to ensure the ExternalFile outlives the constructed object.
260   auto options_copy = absl::make_unique<ObjectDetectorOptions>(options);
261 
262   ASSIGN_OR_RETURN(auto object_detector,
263                    TaskAPIFactory::CreateFromExternalFileProto<ObjectDetector>(
264                        &options_copy->model_file_with_metadata(),
265                        std::move(resolver), options_copy->num_threads()));
266 
267   RETURN_IF_ERROR(object_detector->Init(std::move(options_copy)));
268 
269   return object_detector;
270 }
271 
Init(std::unique_ptr<ObjectDetectorOptions> options)272 absl::Status ObjectDetector::Init(
273     std::unique_ptr<ObjectDetectorOptions> options) {
274   // Set options.
275   options_ = std::move(options);
276 
277   // Perform pre-initialization actions (by default, sets the process engine for
278   // image pre-processing to kLibyuv as a sane default).
279   RETURN_IF_ERROR(PreInit());
280 
281   // Sanity check and set inputs and outputs.
282   RETURN_IF_ERROR(CheckAndSetInputs());
283   RETURN_IF_ERROR(CheckAndSetOutputs());
284 
285   // Initialize class whitelisting/blacklisting, if any.
286   RETURN_IF_ERROR(CheckAndSetClassIndexSet());
287 
288   return absl::OkStatus();
289 }
290 
PreInit()291 absl::Status ObjectDetector::PreInit() {
292   SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
293   return absl::OkStatus();
294 }
295 
CheckAndSetOutputs()296 absl::Status ObjectDetector::CheckAndSetOutputs() {
297   // First, sanity checks on the model itself.
298   const TfLiteEngine::Interpreter* interpreter = engine_->interpreter();
299   // Check the number of output tensors.
300   if (TfLiteEngine::OutputCount(interpreter) != 4) {
301     return CreateStatusWithPayload(
302         StatusCode::kInvalidArgument,
303         absl::StrFormat("Mobile SSD models are expected to have exactly 4 "
304                         "outputs, found %d",
305                         TfLiteEngine::OutputCount(interpreter)),
306         TfLiteSupportStatus::kInvalidNumOutputTensorsError);
307   }
308   // Check tensor dimensions and batch size.
309   for (int i = 0; i < 4; ++i) {
310     const TfLiteTensor* tensor = TfLiteEngine::GetOutput(interpreter, i);
311     if (tensor->dims->size != kOutputTensorsExpectedDims[i]) {
312       return CreateStatusWithPayload(
313           StatusCode::kInvalidArgument,
314           absl::StrFormat("Output tensor at index %d is expected to "
315                           "have %d dimensions, found %d.",
316                           i, kOutputTensorsExpectedDims[i], tensor->dims->size),
317           TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
318     }
319     if (tensor->dims->data[0] != 1) {
320       return CreateStatusWithPayload(
321           StatusCode::kInvalidArgument,
322           absl::StrFormat("Expected batch size of 1, found %d.",
323                           tensor->dims->data[0]),
324           TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
325     }
326   }
327 
328   // Now, perform sanity checks and extract metadata.
329   const ModelMetadataExtractor* metadata_extractor =
330       engine_->metadata_extractor();
331   // Check that metadata is available.
332   if (metadata_extractor->GetModelMetadata() == nullptr ||
333       metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
334     return CreateStatusWithPayload(StatusCode::kInvalidArgument,
335                                    "Object detection models require TFLite "
336                                    "Model Metadata but none was found",
337                                    TfLiteSupportStatus::kMetadataNotFoundError);
338   }
339   // Check output tensor metadata is present and consistent with model.
340   auto output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata();
341   if (output_tensors_metadata == nullptr ||
342       output_tensors_metadata->size() != 4) {
343     return CreateStatusWithPayload(
344         StatusCode::kInvalidArgument,
345         absl::StrFormat(
346             "Mismatch between number of output tensors (4) and output tensors "
347             "metadata (%d).",
348             output_tensors_metadata == nullptr
349                 ? 0
350                 : output_tensors_metadata->size()),
351         TfLiteSupportStatus::kMetadataInconsistencyError);
352   }
353 
354   // Extract mandatory BoundingBoxProperties for easier access at
355   // post-processing time, performing sanity checks on the fly.
356   ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties,
357                    GetBoundingBoxProperties(*output_tensors_metadata->Get(0)));
358   if (bounding_box_properties->index() == nullptr) {
359     bounding_box_corners_order_ = {0, 1, 2, 3};
360   } else {
361     auto bounding_box_index = bounding_box_properties->index();
362     bounding_box_corners_order_ = {
363         bounding_box_index->Get(0),
364         bounding_box_index->Get(1),
365         bounding_box_index->Get(2),
366         bounding_box_index->Get(3),
367     };
368   }
369 
370   // Build label map (if available) from metadata.
371   ASSIGN_OR_RETURN(
372       label_map_,
373       GetLabelMapIfAny(*metadata_extractor, *output_tensors_metadata->Get(1),
374                        options_->display_names_locale()));
375 
376   // Set score threshold.
377   if (options_->has_score_threshold()) {
378     score_threshold_ = options_->score_threshold();
379   } else {
380     ASSIGN_OR_RETURN(score_threshold_,
381                      GetScoreThreshold(*metadata_extractor,
382                                        *output_tensors_metadata->Get(2)));
383   }
384 
385   return absl::OkStatus();
386 }
387 
CheckAndSetClassIndexSet()388 absl::Status ObjectDetector::CheckAndSetClassIndexSet() {
389   // Exit early if no blacklist/whitelist.
390   if (options_->class_name_blacklist_size() == 0 &&
391       options_->class_name_whitelist_size() == 0) {
392     return absl::OkStatus();
393   }
394   // Label map is mandatory.
395   if (label_map_.empty()) {
396     return CreateStatusWithPayload(
397         StatusCode::kInvalidArgument,
398         "Using `class_name_whitelist` or `class_name_blacklist` requires "
399         "labels to be present in the TFLite Model Metadata but none was found.",
400         TfLiteSupportStatus::kMetadataMissingLabelsError);
401   }
402 
403   class_index_set_.is_whitelist = options_->class_name_whitelist_size() > 0;
404   const auto& class_names = class_index_set_.is_whitelist
405                                 ? options_->class_name_whitelist()
406                                 : options_->class_name_blacklist();
407   class_index_set_.values.clear();
408   for (const auto& class_name : class_names) {
409     int index = -1;
410     for (int i = 0; i < label_map_.size(); ++i) {
411       if (label_map_[i].name == class_name) {
412         index = i;
413         break;
414       }
415     }
416     // Ignore duplicate or unknown classes.
417     if (index < 0 || class_index_set_.values.contains(index)) {
418       continue;
419     }
420     class_index_set_.values.insert(index);
421   }
422 
423   if (class_index_set_.values.empty()) {
424     return CreateStatusWithPayload(
425         StatusCode::kInvalidArgument,
426         absl::StrFormat(
427             "Invalid class names specified via `class_name_%s`: none match "
428             "with model labels.",
429             class_index_set_.is_whitelist ? "whitelist" : "blacklist"),
430         TfLiteSupportStatus::kInvalidArgumentError);
431   }
432 
433   return absl::OkStatus();
434 }
435 
Detect(const FrameBuffer & frame_buffer)436 StatusOr<DetectionResult> ObjectDetector::Detect(
437     const FrameBuffer& frame_buffer) {
438   BoundingBox roi;
439   roi.set_width(frame_buffer.dimension().width);
440   roi.set_height(frame_buffer.dimension().height);
441   // Rely on `Infer` instead of `InferWithFallback` as DetectionPostprocessing
442   // op doesn't support hardware acceleration at the time.
443   return Infer(frame_buffer, roi);
444 }
445 
Postprocess(const std::vector<const TfLiteTensor * > & output_tensors,const FrameBuffer & frame_buffer,const BoundingBox &)446 StatusOr<DetectionResult> ObjectDetector::Postprocess(
447     const std::vector<const TfLiteTensor*>& output_tensors,
448     const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) {
449   // Most of the checks here should never happen, as outputs have been validated
450   // at construction time. Checking nonetheless and returning internal errors if
451   // something bad happens.
452   RETURN_IF_ERROR(SanityCheckOutputTensors(output_tensors));
453 
454   // Get number of available results.
455   const int num_results =
456       static_cast<int>(AssertAndReturnTypedTensor<float>(output_tensors[3])[0]);
457   // Compute number of max results to return.
458   const int max_results = options_->max_results() > 0
459                               ? std::min(options_->max_results(), num_results)
460                               : num_results;
461   // The dimensions of the upright (i.e. rotated according to its orientation)
462   // input frame.
463   FrameBuffer::Dimension upright_input_frame_dimensions =
464       frame_buffer.dimension();
465   if (RequireDimensionSwap(frame_buffer.orientation(),
466                            FrameBuffer::Orientation::kTopLeft)) {
467     upright_input_frame_dimensions.Swap();
468   }
469 
470   const float* locations = AssertAndReturnTypedTensor<float>(output_tensors[0]);
471   const float* classes = AssertAndReturnTypedTensor<float>(output_tensors[1]);
472   const float* scores = AssertAndReturnTypedTensor<float>(output_tensors[2]);
473   DetectionResult results;
474   for (int i = 0; i < num_results; ++i) {
475     const int class_index = static_cast<int>(classes[i]);
476     const float score = scores[i];
477     if (!IsClassIndexAllowed(class_index) || score < score_threshold_) {
478       continue;
479     }
480     Detection* detection = results.add_detections();
481     // Denormalize the bounding box cooordinates in the upright frame
482     // coordinates system, then rotate back from frame_buffer.orientation() to
483     // the unrotated frame of reference coordinates system (i.e. with
484     // orientation = kTopLeft).
485     *detection->mutable_bounding_box() = OrientAndDenormalizeBoundingBox(
486         /*from_left=*/locations[4 * i + bounding_box_corners_order_[0]],
487         /*from_top=*/locations[4 * i + bounding_box_corners_order_[1]],
488         /*from_right=*/locations[4 * i + bounding_box_corners_order_[2]],
489         /*from_bottom=*/locations[4 * i + bounding_box_corners_order_[3]],
490         /*from_orientation=*/frame_buffer.orientation(),
491         /*to_orientation=*/FrameBuffer::Orientation::kTopLeft,
492         /*from_dimension=*/upright_input_frame_dimensions);
493     Class* detection_class = detection->add_classes();
494     detection_class->set_index(class_index);
495     detection_class->set_score(score);
496     if (results.detections_size() == max_results) {
497       break;
498     }
499   }
500 
501   if (!label_map_.empty()) {
502     RETURN_IF_ERROR(FillResultsFromLabelMap(&results));
503   }
504 
505   return results;
506 }
507 
IsClassIndexAllowed(int class_index)508 bool ObjectDetector::IsClassIndexAllowed(int class_index) {
509   if (class_index_set_.values.empty()) {
510     return true;
511   }
512   if (class_index_set_.is_whitelist) {
513     return class_index_set_.values.contains(class_index);
514   } else {
515     return !class_index_set_.values.contains(class_index);
516   }
517 }
518 
FillResultsFromLabelMap(DetectionResult * result)519 absl::Status ObjectDetector::FillResultsFromLabelMap(DetectionResult* result) {
520   for (int i = 0; i < result->detections_size(); ++i) {
521     Detection* detection = result->mutable_detections(i);
522     for (int j = 0; j < detection->classes_size(); ++j) {
523       Class* detection_class = detection->mutable_classes(j);
524       const int index = detection_class->index();
525       if (index >= label_map_.size()) {
526         return CreateStatusWithPayload(
527             StatusCode::kInvalidArgument,
528             absl::StrFormat(
529                 "Label map does not contain enough elements: model returned "
530                 "class index %d but label map only contains %d elements.",
531                 index, label_map_.size()),
532             TfLiteSupportStatus::kMetadataInconsistencyError);
533       }
534       std::string name = label_map_[index].name;
535       if (!name.empty()) {
536         detection_class->set_class_name(name);
537       }
538       std::string display_name = label_map_[index].display_name;
539       if (!display_name.empty()) {
540         detection_class->set_display_name(display_name);
541       }
542     }
543   }
544   return absl::OkStatus();
545 }
546 
547 }  // namespace vision
548 }  // namespace task
549 }  // namespace tflite
550