xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/cc/task/vision/core/classification_head.h (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_CLASSIFICATION_HEAD_ITEM_H_
16 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_CLASSIFICATION_HEAD_ITEM_H_
17 
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
#include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h"
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
28 
29 namespace tflite {
30 namespace task {
31 namespace vision {
32 
33 // A single classifier head for an image classifier model, associated with a
34 // corresponding output tensor.
35 struct ClassificationHead {
ClassificationHeadClassificationHead36   ClassificationHead() : score_threshold(0) {}
37 
ClassificationHeadClassificationHead38   explicit ClassificationHead(
39       const std::vector<tflite::task::vision::LabelMapItem>&& label_map_items)
40       : label_map_items(label_map_items), score_threshold(0) {}
41 
42   // An optional name that usually indicates what this set of classes represent,
43   // e.g. "flowers".
44   std::string name;
45   // The label map representing the list of supported classes, aka labels.
46   //
47   // This must be in direct correspondence with the associated output tensor,
48   // i.e.:
49   //
50   // - The number of classes must match with the dimension of the corresponding
51   // output tensor,
52   // - The i-th item in the label map is assumed to correspond to the i-th
53   // output value in the output tensor.
54   //
55   // This requires to put in place dedicated sanity checks before running
56   // inference.
57   std::vector<tflite::task::vision::LabelMapItem> label_map_items;
58   // Recommended score threshold typically in [0,1[. Classification results with
59   // a score below this value are considered low-confidence and should be
60   // rejected from returned results.
61   float score_threshold;
62   // Optional score calibration parameters (one set of parameters per class in
63   // the label map). This is primarily meant for multi-label classifiers made of
64   // independent sigmoids.
65   //
66   // Such parameters are usually tuned so that calibrated scores can be compared
67   // to a default threshold common to all classes to achieve a given amount of
68   // precision.
69   //
70   // Example: 60% precision for threshold = 0.5.
71   absl::optional<tflite::task::vision::SigmoidCalibrationParameters>
72       calibration_params;
73 };
74 
75 // Builds a classification head using the provided metadata extractor, for the
76 // given output tensor metadata. Returns an error in case the head cannot be
77 // built (e.g. missing associated file for score calibration parameters).
78 //
79 // Optionally it is possible to specify which locale should be used (e.g. "en")
80 // to fill the label map display names, if any, and provided the corresponding
81 // associated file is present in the metadata. If no locale is specified, or if
82 // there is no associated file for the provided locale, display names are just
83 // left empty and no error is returned.
84 //
85 // E.g. (metatada displayed in JSON format below):
86 //
87 // ...
88 // "associated_files": [
89 //  {
90 //    "name": "labels.txt",
91 //    "type": "TENSOR_AXIS_LABELS"
92 //  },
93 //  {
94 //    "name": "labels-en.txt",
95 //    "type": "TENSOR_AXIS_LABELS",
96 //    "locale": "en"
97 //  },
98 // ...
99 //
100 // See metadata schema TENSOR_AXIS_LABELS for more details.
101 tflite::support::StatusOr<ClassificationHead> BuildClassificationHead(
102     const tflite::metadata::ModelMetadataExtractor& metadata_extractor,
103     const tflite::TensorMetadata& output_tensor_metadata,
104     absl::string_view display_names_locale = absl::string_view());
105 
106 }  // namespace vision
107 }  // namespace task
108 }  // namespace tflite
109 
110 #endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_CLASSIFICATION_HEAD_ITEM_H_
111