xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h"

#include <array>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
29 
30 namespace tflite {
31 namespace task {
32 namespace vision {
33 namespace {
34 
35 using ::absl::StatusCode;
36 using ::tflite::support::CreateStatusWithPayload;
37 using ::tflite::support::StatusOr;
38 using ::tflite::support::TfLiteSupportStatus;
39 
40 // Used to prevent log(<=0.0) in ClampedLog() calls.
41 constexpr float kLogScoreMinimum = 1e-16;
42 
// Returns a logarithm of `x` that stays finite for tiny or non-positive x:
//   x >= threshold: log(x)
//   x <  threshold: 2 * log(threshold) - log(2 * threshold - x)
// This form (a) is anti-symmetric about the threshold and (b) has continuous
// value and first derivative there. It avoids taking the log of values close
// to 0, which can lead to floating point errors. It is better than simple
// clamping because it preserves order for scores less than the threshold.
//
// `threshold` is expected to be strictly positive; callers in this file pass
// kLogScoreMinimum. The computation is carried out in double precision and
// rounded to float once on return.
float ClampedLog(float x, float threshold) {
  const double x_d = static_cast<double>(x);
  const double threshold_d = static_cast<double>(threshold);
  if (x_d < threshold_d) {
    // Mirror log() about the threshold point instead of evaluating it near 0.
    return static_cast<float>(2.0 * std::log(threshold_d) -
                              std::log(2.0 * threshold_d - x_d));
  }
  return static_cast<float>(std::log(x_d));
}
57 
58 // Applies the specified score transformation to the provided score.
59 // Currently supports the following,
60 //   IDENTITY         : f(x) = x
61 //   LOG              : f(x) = log(x)
62 //   INVERSE_LOGISTIC : f(x) = log(x) - log(1-x)
ApplyScoreTransformation(float score,const ScoreTransformation & type)63 float ApplyScoreTransformation(float score, const ScoreTransformation& type) {
64   switch (type) {
65     case ScoreTransformation::kIDENTITY:
66       return score;
67     case ScoreTransformation::kINVERSE_LOGISTIC:
68       return (ClampedLog(score, kLogScoreMinimum) -
69               ClampedLog(1.0 - score, kLogScoreMinimum));
70     case ScoreTransformation::kLOG:
71       return ClampedLog(score, kLogScoreMinimum);
72   }
73 }
74 
75 // Builds a single Sigmoid from the label name and associated CSV file line.
SigmoidFromLabelAndLine(absl::string_view label,absl::string_view line)76 StatusOr<Sigmoid> SigmoidFromLabelAndLine(absl::string_view label,
77                                           absl::string_view line) {
78   std::vector<absl::string_view> str_params = absl::StrSplit(line, ',');
79   if (str_params.size() != 3 && str_params.size() != 4) {
80     return CreateStatusWithPayload(
81         StatusCode::kInvalidArgument,
82         absl::StrFormat("Expected 3 or 4 parameters per line in score "
83                         "calibration file, got %d.",
84                         str_params.size()),
85         TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError);
86   }
87   std::vector<float> float_params(4);
88   for (int i = 0; i < str_params.size(); ++i) {
89     if (!absl::SimpleAtof(str_params[i], &float_params[i])) {
90       return CreateStatusWithPayload(
91           StatusCode::kInvalidArgument,
92           absl::StrFormat(
93               "Could not parse score calibration parameter as float: %s.",
94               str_params[i]),
95           TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError);
96     }
97   }
98   Sigmoid sigmoid;
99   sigmoid.label = std::string(label);
100   sigmoid.scale = float_params[0];
101   sigmoid.slope = float_params[1];
102   sigmoid.offset = float_params[2];
103   if (str_params.size() == 4) {
104     sigmoid.min_uncalibrated_score = float_params[3];
105   }
106   return sigmoid;
107 }
108 
109 // Converts a tflite::ScoreTransformationType to its
110 // tflite::task::vision::ScoreTransformation equivalent.
ConvertScoreTransformationType(tflite::ScoreTransformationType type)111 ScoreTransformation ConvertScoreTransformationType(
112     tflite::ScoreTransformationType type) {
113   switch (type) {
114     case tflite::ScoreTransformationType_IDENTITY:
115       return ScoreTransformation::kIDENTITY;
116     case tflite::ScoreTransformationType_LOG:
117       return ScoreTransformation::kLOG;
118     case tflite::ScoreTransformationType_INVERSE_LOGISTIC:
119       return ScoreTransformation::kINVERSE_LOGISTIC;
120   }
121 }
122 
123 }  // namespace
124 
operator <<(std::ostream & os,const Sigmoid & s)125 std::ostream& operator<<(std::ostream& os, const Sigmoid& s) {
126   os << s.label << "," << s.slope << "," << s.offset << "," << s.scale;
127   if (s.min_uncalibrated_score.has_value()) {
128     os << "," << s.min_uncalibrated_score.value();
129   }
130   return os;
131 }
132 
ScoreCalibration()133 ScoreCalibration::ScoreCalibration() {}
~ScoreCalibration()134 ScoreCalibration::~ScoreCalibration() {}
135 
InitializeFromParameters(const SigmoidCalibrationParameters & params)136 absl::Status ScoreCalibration::InitializeFromParameters(
137     const SigmoidCalibrationParameters& params) {
138   sigmoid_parameters_ = std::move(params);
139   // Fill in the map from label -> sigmoid.
140   sigmoid_parameters_map_.clear();
141   for (const auto& sigmoid : sigmoid_parameters_.sigmoid) {
142     sigmoid_parameters_map_.insert_or_assign(sigmoid.label, sigmoid);
143   }
144   return absl::OkStatus();
145 }
146 
ComputeCalibratedScore(const std::string & label,float uncalibrated_score) const147 float ScoreCalibration::ComputeCalibratedScore(const std::string& label,
148                                                float uncalibrated_score) const {
149   absl::optional<Sigmoid> sigmoid = FindSigmoidParameters(label);
150   if (!sigmoid.has_value() ||
151       (sigmoid.value().min_uncalibrated_score.has_value() &&
152        uncalibrated_score < sigmoid.value().min_uncalibrated_score.value())) {
153     return sigmoid_parameters_.default_score;
154   }
155 
156   float transformed_score = ApplyScoreTransformation(
157       uncalibrated_score, sigmoid_parameters_.score_transformation);
158   float scale_shifted_score =
159       transformed_score * sigmoid.value().slope + sigmoid.value().offset;
160 
161   // For numerical stability use 1 / (1+exp(-x)) when scale_shifted_score >= 0
162   // and exp(x) / (1+exp(x)) when scale_shifted_score < 0.
163   if (scale_shifted_score >= 0.0) {
164     return sigmoid.value().scale /
165            (1.0 + std::exp(static_cast<double>(-scale_shifted_score)));
166   } else {
167     float score_exp = std::exp(static_cast<double>(scale_shifted_score));
168     return sigmoid.value().scale * score_exp / (1.0 + score_exp);
169   }
170 }
171 
FindSigmoidParameters(const std::string & label) const172 absl::optional<Sigmoid> ScoreCalibration::FindSigmoidParameters(
173     const std::string& label) const {
174   auto it = sigmoid_parameters_map_.find(label);
175   if (it != sigmoid_parameters_map_.end()) {
176     return it->second;
177   } else if (sigmoid_parameters_.default_sigmoid.has_value()) {
178     return sigmoid_parameters_.default_sigmoid.value();
179   }
180   return absl::nullopt;
181 }
182 
BuildSigmoidCalibrationParams(const tflite::ScoreCalibrationOptions & score_calibration_options,absl::string_view score_calibration_file,const std::vector<LabelMapItem> & label_map_items)183 StatusOr<SigmoidCalibrationParameters> BuildSigmoidCalibrationParams(
184     const tflite::ScoreCalibrationOptions& score_calibration_options,
185     absl::string_view score_calibration_file,
186     const std::vector<LabelMapItem>& label_map_items) {
187   // Split file lines and perform sanity checks.
188   if (score_calibration_file.empty()) {
189     return CreateStatusWithPayload(
190         StatusCode::kInvalidArgument,
191         "Expected non-empty score calibration file.");
192   }
193   std::vector<absl::string_view> lines =
194       absl::StrSplit(score_calibration_file, '\n');
195   if (label_map_items.size() != lines.size()) {
196     return CreateStatusWithPayload(
197         StatusCode::kInvalidArgument,
198         absl::StrFormat("Mismatch between number of labels (%d) and score "
199                         "calibration parameters (%d).",
200                         label_map_items.size(), lines.size()),
201         TfLiteSupportStatus::kMetadataNumLabelsMismatchError);
202   }
203   // Initialize SigmoidCalibrationParameters with its class-agnostic parameters.
204   SigmoidCalibrationParameters sigmoid_params = {};
205   sigmoid_params.score_transformation = ConvertScoreTransformationType(
206       score_calibration_options.score_transformation());
207   sigmoid_params.default_score = score_calibration_options.default_score();
208   std::vector<Sigmoid> sigmoid_vector;
209   // Fill sigmoids for each class with parameters in the file.
210   for (int i = 0; i < label_map_items.size(); ++i) {
211     if (lines[i].empty()) {
212       continue;
213     }
214     ASSIGN_OR_RETURN(Sigmoid sigmoid, SigmoidFromLabelAndLine(
215                                           label_map_items[i].name, lines[i]));
216     sigmoid_vector.emplace_back(std::move(sigmoid));
217   }
218   sigmoid_params.sigmoid = std::move(sigmoid_vector);
219 
220   return sigmoid_params;
221 }
222 
223 }  // namespace vision
224 }  // namespace task
225 }  // namespace tflite
226