xref: /aosp_15_r20/external/libtextclassifier/native/lang_id/common/math/softmax.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "lang_id/common/math/softmax.h"

#include <algorithm>
#include <cmath>
#include <vector>

#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/math/fastexp.h"
24 
25 namespace libtextclassifier3 {
26 namespace mobile {
27 
ComputeSoftmaxProbability(const std::vector<float> & scores,int label)28 float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) {
29   if ((label < 0) || (static_cast<size_t>(label) >= scores.size())) {
30     SAFTM_LOG(ERROR) << "label " << label << " outside range "
31                      << "[0, " << scores.size() << ")";
32     return 0.0f;
33   }
34 
35   // Standard softmax formula for label's probability is
36   //
37   //   exp(scores[label]) / sum_i exp(scores[i])
38   //
39   // We compute the mathematically equivalent
40   //
41   //   1 / (1 + sum_{i != label} exp(scores[i] - scores[label]))
42   //
43   // which saves two calls to exp().
44   const float label_score = scores[label];
45   float denominator = 1.0f;  // Contribution of i == label.
46   for (size_t i = 0; i < scores.size(); ++i) {
47     if (static_cast<int>(i) == label) continue;
48     const float delta_score = scores[i] - label_score;
49 
50     // TODO(salcianu): one can optimize the test below, to avoid any float
51     // operation: extract exponent (via bit mask + shift) and check it's >= 4.
52     if (fabs(delta_score) >= 16.0f) {
53       if (delta_score > 0.0f) {
54         // If delta_score >= 16, the denominator (e^delta_score + other positive
55         // terms) is very big and its inverse can be approximated with 0.
56         return 0.0f;
57       } else {
58         // If delta_score <= -16, then e^delta_score < 1.2e-7.  Even if we have
59         // 1000 such labels i, their sum is < 1.2e-4 (which gets summed with
60         // 1.0f for i == label).  Hence, we can approximate each such label with
61         // 0 and skip the call to VeryFastExp and the update to denominator.
62         continue;
63       }
64     }
65 
66     // At this point, delta_score is in (-16.0, 16.0).  For such values, vfexp
67     // works fine: no under/overflows (we have tests for that in fastexp_test).
68     // Also, even for 1000 labels, denominator will not overflow.
69     denominator += VeryFastExp(delta_score);
70   }
71   return 1.0f / denominator;
72 }
73 
ComputeSoftmax(const std::vector<float> & scores,float alpha)74 std::vector<float> ComputeSoftmax(const std::vector<float> &scores,
75                                   float alpha) {
76   std::vector<float> softmax;
77   softmax.reserve(scores.size());
78   if (scores.empty()) {
79     return softmax;
80   }
81 
82   std::vector<float> exp_scores;
83   exp_scores.reserve(scores.size());
84 
85   // Find max value in "scores" vector and rescale to avoid overflows.
86   const float max_score = *std::max_element(scores.begin(), scores.end());
87   float denominator = 0;
88   for (const float score : scores) {
89     // See comments above in ComputeSoftmaxProbability for the reasoning behind
90     // this approximation.
91     const float delta_score = alpha * (score - max_score);
92     const float exp_score = delta_score < -16.0f ? 0 : VeryFastExp(delta_score);
93     exp_scores.push_back(exp_score);
94     denominator += exp_score;
95   }
96 
97   for (size_t i = 0; i < scores.size(); ++i) {
98     softmax.push_back(exp_scores[i] / denominator);
99   }
100   return softmax;
101 }
102 
103 }  // namespace mobile
}  // namespace libtextclassifier3
105