1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h"
16
#include <cmath>
#include <cstddef>
#include <memory>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
29
30 namespace tflite {
31 namespace task {
32 namespace vision {
33 namespace {
34
35 using ::absl::StatusCode;
36 using ::tflite::support::CreateStatusWithPayload;
37 using ::tflite::support::StatusOr;
38 using ::tflite::support::TfLiteSupportStatus;
39
// Threshold handed to ClampedLog(): inputs smaller than this take the
// reflected-log branch, so the logarithm is never evaluated at a value
// <= 0.0.
constexpr float kLogScoreMinimum = static_cast<float>(1e-16);
42
43 // Returns the following, depending on x:
44 // x => threshold: log(x)
45 // x < threshold: 2 * log(thresh) - log(2 * thresh - x)
46 // This form (a) is anti-symmetric about the threshold and (b) has continuous
47 // value and first derivative. This is done to prevent taking the log of values
48 // close to 0 which can lead to floating point errors and is better than simple
49 // clamping since it preserves order for scores less than the threshold.
ClampedLog(float x,float threshold)50 float ClampedLog(float x, float threshold) {
51 if (x < threshold) {
52 return 2.0 * std::log(static_cast<double>(threshold)) -
53 log(2.0 * threshold - x);
54 }
55 return std::log(static_cast<double>(x));
56 }
57
58 // Applies the specified score transformation to the provided score.
59 // Currently supports the following,
60 // IDENTITY : f(x) = x
61 // LOG : f(x) = log(x)
62 // INVERSE_LOGISTIC : f(x) = log(x) - log(1-x)
ApplyScoreTransformation(float score,const ScoreTransformation & type)63 float ApplyScoreTransformation(float score, const ScoreTransformation& type) {
64 switch (type) {
65 case ScoreTransformation::kIDENTITY:
66 return score;
67 case ScoreTransformation::kINVERSE_LOGISTIC:
68 return (ClampedLog(score, kLogScoreMinimum) -
69 ClampedLog(1.0 - score, kLogScoreMinimum));
70 case ScoreTransformation::kLOG:
71 return ClampedLog(score, kLogScoreMinimum);
72 }
73 }
74
75 // Builds a single Sigmoid from the label name and associated CSV file line.
SigmoidFromLabelAndLine(absl::string_view label,absl::string_view line)76 StatusOr<Sigmoid> SigmoidFromLabelAndLine(absl::string_view label,
77 absl::string_view line) {
78 std::vector<absl::string_view> str_params = absl::StrSplit(line, ',');
79 if (str_params.size() != 3 && str_params.size() != 4) {
80 return CreateStatusWithPayload(
81 StatusCode::kInvalidArgument,
82 absl::StrFormat("Expected 3 or 4 parameters per line in score "
83 "calibration file, got %d.",
84 str_params.size()),
85 TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError);
86 }
87 std::vector<float> float_params(4);
88 for (int i = 0; i < str_params.size(); ++i) {
89 if (!absl::SimpleAtof(str_params[i], &float_params[i])) {
90 return CreateStatusWithPayload(
91 StatusCode::kInvalidArgument,
92 absl::StrFormat(
93 "Could not parse score calibration parameter as float: %s.",
94 str_params[i]),
95 TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError);
96 }
97 }
98 Sigmoid sigmoid;
99 sigmoid.label = std::string(label);
100 sigmoid.scale = float_params[0];
101 sigmoid.slope = float_params[1];
102 sigmoid.offset = float_params[2];
103 if (str_params.size() == 4) {
104 sigmoid.min_uncalibrated_score = float_params[3];
105 }
106 return sigmoid;
107 }
108
109 // Converts a tflite::ScoreTransformationType to its
110 // tflite::task::vision::ScoreTransformation equivalent.
ConvertScoreTransformationType(tflite::ScoreTransformationType type)111 ScoreTransformation ConvertScoreTransformationType(
112 tflite::ScoreTransformationType type) {
113 switch (type) {
114 case tflite::ScoreTransformationType_IDENTITY:
115 return ScoreTransformation::kIDENTITY;
116 case tflite::ScoreTransformationType_LOG:
117 return ScoreTransformation::kLOG;
118 case tflite::ScoreTransformationType_INVERSE_LOGISTIC:
119 return ScoreTransformation::kINVERSE_LOGISTIC;
120 }
121 }
122
123 } // namespace
124
operator <<(std::ostream & os,const Sigmoid & s)125 std::ostream& operator<<(std::ostream& os, const Sigmoid& s) {
126 os << s.label << "," << s.slope << "," << s.offset << "," << s.scale;
127 if (s.min_uncalibrated_score.has_value()) {
128 os << "," << s.min_uncalibrated_score.value();
129 }
130 return os;
131 }
132
ScoreCalibration()133 ScoreCalibration::ScoreCalibration() {}
~ScoreCalibration()134 ScoreCalibration::~ScoreCalibration() {}
135
InitializeFromParameters(const SigmoidCalibrationParameters & params)136 absl::Status ScoreCalibration::InitializeFromParameters(
137 const SigmoidCalibrationParameters& params) {
138 sigmoid_parameters_ = std::move(params);
139 // Fill in the map from label -> sigmoid.
140 sigmoid_parameters_map_.clear();
141 for (const auto& sigmoid : sigmoid_parameters_.sigmoid) {
142 sigmoid_parameters_map_.insert_or_assign(sigmoid.label, sigmoid);
143 }
144 return absl::OkStatus();
145 }
146
ComputeCalibratedScore(const std::string & label,float uncalibrated_score) const147 float ScoreCalibration::ComputeCalibratedScore(const std::string& label,
148 float uncalibrated_score) const {
149 absl::optional<Sigmoid> sigmoid = FindSigmoidParameters(label);
150 if (!sigmoid.has_value() ||
151 (sigmoid.value().min_uncalibrated_score.has_value() &&
152 uncalibrated_score < sigmoid.value().min_uncalibrated_score.value())) {
153 return sigmoid_parameters_.default_score;
154 }
155
156 float transformed_score = ApplyScoreTransformation(
157 uncalibrated_score, sigmoid_parameters_.score_transformation);
158 float scale_shifted_score =
159 transformed_score * sigmoid.value().slope + sigmoid.value().offset;
160
161 // For numerical stability use 1 / (1+exp(-x)) when scale_shifted_score >= 0
162 // and exp(x) / (1+exp(x)) when scale_shifted_score < 0.
163 if (scale_shifted_score >= 0.0) {
164 return sigmoid.value().scale /
165 (1.0 + std::exp(static_cast<double>(-scale_shifted_score)));
166 } else {
167 float score_exp = std::exp(static_cast<double>(scale_shifted_score));
168 return sigmoid.value().scale * score_exp / (1.0 + score_exp);
169 }
170 }
171
FindSigmoidParameters(const std::string & label) const172 absl::optional<Sigmoid> ScoreCalibration::FindSigmoidParameters(
173 const std::string& label) const {
174 auto it = sigmoid_parameters_map_.find(label);
175 if (it != sigmoid_parameters_map_.end()) {
176 return it->second;
177 } else if (sigmoid_parameters_.default_sigmoid.has_value()) {
178 return sigmoid_parameters_.default_sigmoid.value();
179 }
180 return absl::nullopt;
181 }
182
BuildSigmoidCalibrationParams(const tflite::ScoreCalibrationOptions & score_calibration_options,absl::string_view score_calibration_file,const std::vector<LabelMapItem> & label_map_items)183 StatusOr<SigmoidCalibrationParameters> BuildSigmoidCalibrationParams(
184 const tflite::ScoreCalibrationOptions& score_calibration_options,
185 absl::string_view score_calibration_file,
186 const std::vector<LabelMapItem>& label_map_items) {
187 // Split file lines and perform sanity checks.
188 if (score_calibration_file.empty()) {
189 return CreateStatusWithPayload(
190 StatusCode::kInvalidArgument,
191 "Expected non-empty score calibration file.");
192 }
193 std::vector<absl::string_view> lines =
194 absl::StrSplit(score_calibration_file, '\n');
195 if (label_map_items.size() != lines.size()) {
196 return CreateStatusWithPayload(
197 StatusCode::kInvalidArgument,
198 absl::StrFormat("Mismatch between number of labels (%d) and score "
199 "calibration parameters (%d).",
200 label_map_items.size(), lines.size()),
201 TfLiteSupportStatus::kMetadataNumLabelsMismatchError);
202 }
203 // Initialize SigmoidCalibrationParameters with its class-agnostic parameters.
204 SigmoidCalibrationParameters sigmoid_params = {};
205 sigmoid_params.score_transformation = ConvertScoreTransformationType(
206 score_calibration_options.score_transformation());
207 sigmoid_params.default_score = score_calibration_options.default_score();
208 std::vector<Sigmoid> sigmoid_vector;
209 // Fill sigmoids for each class with parameters in the file.
210 for (int i = 0; i < label_map_items.size(); ++i) {
211 if (lines[i].empty()) {
212 continue;
213 }
214 ASSIGN_OR_RETURN(Sigmoid sigmoid, SigmoidFromLabelAndLine(
215 label_map_items[i].name, lines[i]));
216 sigmoid_vector.emplace_back(std::move(sigmoid));
217 }
218 sigmoid_params.sigmoid = std::move(sigmoid_vector);
219
220 return sigmoid_params;
221 }
222
223 } // namespace vision
224 } // namespace task
225 } // namespace tflite
226