1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_CLASSIFICATION_HEAD_ITEM_H_ 16 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_CLASSIFICATION_HEAD_ITEM_H_ 17 18 #include <string> 19 #include <vector> 20 21 #include "absl/memory/memory.h" 22 #include "absl/strings/string_view.h" 23 #include "tensorflow_lite_support/cc/port/statusor.h" 24 #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" 25 #include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h" 26 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" 27 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" 28 29 namespace tflite { 30 namespace task { 31 namespace vision { 32 33 // A single classifier head for an image classifier model, associated with a 34 // corresponding output tensor. 35 struct ClassificationHead { ClassificationHeadClassificationHead36 ClassificationHead() : score_threshold(0) {} 37 ClassificationHeadClassificationHead38 explicit ClassificationHead( 39 const std::vector<tflite::task::vision::LabelMapItem>&& label_map_items) 40 : label_map_items(label_map_items), score_threshold(0) {} 41 42 // An optional name that usually indicates what this set of classes represent, 43 // e.g. "flowers". 44 std::string name; 45 // The label map representing the list of supported classes, aka labels. 46 // 47 // This must be in direct correspondence with the associated output tensor, 48 // i.e.: 49 // 50 // - The number of classes must match with the dimension of the corresponding 51 // output tensor, 52 // - The i-th item in the label map is assumed to correspond to the i-th 53 // output value in the output tensor. 54 // 55 // This requires to put in place dedicated sanity checks before running 56 // inference. 57 std::vector<tflite::task::vision::LabelMapItem> label_map_items; 58 // Recommended score threshold typically in [0,1[. Classification results with 59 // a score below this value are considered low-confidence and should be 60 // rejected from returned results. 61 float score_threshold; 62 // Optional score calibration parameters (one set of parameters per class in 63 // the label map). This is primarily meant for multi-label classifiers made of 64 // independent sigmoids. 65 // 66 // Such parameters are usually tuned so that calibrated scores can be compared 67 // to a default threshold common to all classes to achieve a given amount of 68 // precision. 69 // 70 // Example: 60% precision for threshold = 0.5. 71 absl::optional<tflite::task::vision::SigmoidCalibrationParameters> 72 calibration_params; 73 }; 74 75 // Builds a classification head using the provided metadata extractor, for the 76 // given output tensor metadata. Returns an error in case the head cannot be 77 // built (e.g. missing associated file for score calibration parameters). 78 // 79 // Optionally it is possible to specify which locale should be used (e.g. "en") 80 // to fill the label map display names, if any, and provided the corresponding 81 // associated file is present in the metadata. If no locale is specified, or if 82 // there is no associated file for the provided locale, display names are just 83 // left empty and no error is returned. 84 // 85 // E.g. (metatada displayed in JSON format below): 86 // 87 // ... 88 // "associated_files": [ 89 // { 90 // "name": "labels.txt", 91 // "type": "TENSOR_AXIS_LABELS" 92 // }, 93 // { 94 // "name": "labels-en.txt", 95 // "type": "TENSOR_AXIS_LABELS", 96 // "locale": "en" 97 // }, 98 // ... 99 // 100 // See metadata schema TENSOR_AXIS_LABELS for more details. 101 tflite::support::StatusOr<ClassificationHead> BuildClassificationHead( 102 const tflite::metadata::ModelMetadataExtractor& metadata_extractor, 103 const tflite::TensorMetadata& output_tensor_metadata, 104 absl::string_view display_names_locale = absl::string_view()); 105 106 } // namespace vision 107 } // namespace task 108 } // namespace tflite 109 110 #endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_CLASSIFICATION_HEAD_ITEM_H_ 111