xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_SCORE_CALIBRATION_H_
16 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_SCORE_CALIBRATION_H_
17 
18 #include <iostream>
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/status/status.h"
27 #include "absl/strings/string_view.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow_lite_support/cc/port/statusor.h"
30 #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
31 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
32 
33 namespace tflite {
34 namespace task {
35 namespace vision {
36 
37 // Sigmoid structure.
38 struct Sigmoid {
SigmoidSigmoid39   Sigmoid() : scale(1.0) {}
40   Sigmoid(std::string label, float slope, float offset, float scale = 1.0,
41           absl::optional<float> min_uncalibrated_score = absl::nullopt)
labelSigmoid42       : label(label),
43         slope(slope),
44         offset(offset),
45         scale(scale),
46         min_uncalibrated_score(min_uncalibrated_score) {}
47 
48   bool operator==(const Sigmoid& other) const {
49     return label == other.label && slope == other.slope &&
50            offset == other.offset && scale == other.scale &&
51            min_uncalibrated_score == other.min_uncalibrated_score;
52   }
53 
54   // Unique label corresponding to the sigmoid parameters.
55   std::string label;
56   float slope;
57   float offset;
58   float scale;
59   absl::optional<float> min_uncalibrated_score;
60 };
61 
62 std::ostream& operator<<(std::ostream& os, const Sigmoid& s);
63 
// Transformation applied to an uncalibrated score to obtain the
// "transformation score" used prior to sigmoid fitting.
enum class ScoreTransformation {
  // Identity: f(x) = x.
  kIDENTITY,
  // Logarithm: f(x) = log(x).
  kLOG,
  // Inverse logistic (logit): f(x) = log(x) - log(1 - x).
  kINVERSE_LOGISTIC
};
70 
71 // Sigmoid calibration parameters.
72 struct SigmoidCalibrationParameters {
SigmoidCalibrationParametersSigmoidCalibrationParameters73   SigmoidCalibrationParameters()
74       : default_score(0.0),
75         score_transformation(ScoreTransformation::kIDENTITY) {}
76   explicit SigmoidCalibrationParameters(
77       std::vector<Sigmoid> sigmoid,
78       ScoreTransformation score_transformation = ScoreTransformation::kIDENTITY,
79       absl::optional<Sigmoid> default_sigmoid = absl::nullopt,
80       float default_score = 0.0)
sigmoidSigmoidCalibrationParameters81       : sigmoid(sigmoid),
82         default_sigmoid(default_sigmoid),
83         default_score(default_score),
84         score_transformation(score_transformation) {}
85   // A vector of Sigmoid associated to the ScoreCalibration instance.
86   std::vector<Sigmoid> sigmoid;
87   // If set, this sigmoid will be applied to any non-matching labels.
88   absl::optional<Sigmoid> default_sigmoid;
89   // The default score for non-matching labels. Only used if default_sigmoid
90   // isn't set.
91   float default_score;
92   // Function for computing a transformation score prior to sigmoid fitting.
93   ScoreTransformation score_transformation;
94 };
95 
96 // This class is used to calibrate predicted scores so that scores are
97 // comparable across labels. Depending on the particular calibration parameters
98 // being used, the calibrated scores can also be approximately interpreted as a
99 // likelihood of being correct. For a given TF Lite model, such parameters are
100 // typically obtained from TF Lite Metadata (see ScoreCalibrationOptions).
101 class ScoreCalibration {
102  public:
103   ScoreCalibration();
104   ~ScoreCalibration();
105 
106   // Transfers input parameters and construct a label to sigmoid map.
107   absl::Status InitializeFromParameters(
108       const SigmoidCalibrationParameters& params);
109 
110   // Returns a calibrated score given a label string and uncalibrated score. The
111   // calibrated score will be in the range [0.0, 1.0] and can loosely be
112   // interpreted as a likelihood of the label being correct.
113   float ComputeCalibratedScore(const std::string& label,
114                                float uncalibrated_score) const;
115 
116  private:
117   // Finds the sigmoid parameters corresponding to the provided label.
118   absl::optional<Sigmoid> FindSigmoidParameters(const std::string& label) const;
119 
120   // Parameters for internal states.
121   SigmoidCalibrationParameters sigmoid_parameters_;
122 
123   // Maps label strings to the particular sigmoid stored in sigmoid_parameters_.
124   absl::flat_hash_map<std::string, Sigmoid> sigmoid_parameters_map_;
125 };
126 
127 // Builds SigmoidCalibrationParameters using data obtained from TF Lite Metadata
128 // (see ScoreCalibrationOptions in metadata schema).
129 //
130 // The provided `score_calibration_file` represents the contents of the score
131 // calibration associated file (TENSOR_AXIS_SCORE_CALIBRATION), i.e. one set of
132 // parameters (scale, slope, etc) per line. Each line must be in 1:1
133 // correspondence with `label_map_items`, so as to associate each sigmoid to its
134 // corresponding label name. Returns an error if no valid parameters could be
135 // built (e.g. malformed parameters).
136 tflite::support::StatusOr<SigmoidCalibrationParameters>
137 BuildSigmoidCalibrationParams(
138     const tflite::ScoreCalibrationOptions& score_calibration_options,
139     absl::string_view score_calibration_file,
140     const std::vector<LabelMapItem>& label_map_items);
141 
142 }  // namespace vision
143 }  // namespace task
144 }  // namespace tflite
145 
146 #endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_SCORE_CALIBRATION_H_
147