xref: /aosp_15_r20/external/libtextclassifier/native/lang_id/lang-id.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lang_id/lang-id.h"
18 
19 #include <stdio.h>
20 
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26 
27 #include "lang_id/common/embedding-feature-interface.h"
28 #include "lang_id/common/embedding-network-params.h"
29 #include "lang_id/common/embedding-network.h"
30 #include "lang_id/common/fel/feature-extractor.h"
31 #include "lang_id/common/lite_base/logging.h"
32 #include "lang_id/common/lite_strings/numbers.h"
33 #include "lang_id/common/lite_strings/str-split.h"
34 #include "lang_id/common/lite_strings/stringpiece.h"
35 #include "lang_id/common/math/algorithm.h"
36 #include "lang_id/common/math/softmax.h"
37 #include "lang_id/custom-tokenizer.h"
38 #include "lang_id/features/light-sentence-features.h"
39 // The two features/ headers below are needed only for RegisterClass().
40 #include "lang_id/features/char-ngram-feature.h"
41 #include "lang_id/features/relevant-script-feature.h"
42 #include "lang_id/light-sentence.h"
43 // The two script/ headers below are needed only for RegisterClass().
44 #include "lang_id/script/approx-script.h"
45 #include "lang_id/script/tiny-script-detector.h"
46 
47 namespace libtextclassifier3 {
48 namespace mobile {
49 namespace lang_id {
50 
namespace {
// Default value for the confidence threshold.  If the confidence of the top
// prediction is below this threshold, then FindLanguage() returns
// LangId::kUnknownLanguageCode.  Note: this is just a default value; if the
// TaskSpec from the model specifies a "reliability_thresh" parameter, then we
// use that value instead.  Note: for legacy reasons, our code and comments use
// the terms "confidence", "probability" and "reliability" equivalently.
//
// constexpr (rather than "static const"): the value is a compile-time
// constant, and "static" is redundant inside an anonymous namespace.
constexpr float kDefaultConfidenceThreshold = 0.50f;
}  // namespace
60 
61 // Class that performs all work behind LangId.
62 class LangIdImpl {
63  public:
LangIdImpl(std::unique_ptr<ModelProvider> model_provider)64   explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
65       : model_provider_(std::move(model_provider)),
66         lang_id_brain_interface_("language_identifier") {
67     // Note: in the code below, we set valid_ to true only if all initialization
68     // steps completed successfully.  Otherwise, we return early, leaving valid_
69     // to its default value false.
70     if (!model_provider_ || !model_provider_->is_valid()) {
71       SAFTM_LOG(ERROR) << "Invalid model provider";
72       return;
73     }
74 
75     auto *nn_params = model_provider_->GetNnParams();
76     if (!nn_params) {
77       SAFTM_LOG(ERROR) << "No NN params";
78       return;
79     }
80     network_.reset(new EmbeddingNetwork(nn_params));
81 
82     languages_ = model_provider_->GetLanguages();
83     if (languages_.empty()) {
84       SAFTM_LOG(ERROR) << "No known languages";
85       return;
86     }
87 
88     TaskContext context = *model_provider_->GetTaskContext();
89     if (!Setup(&context)) {
90       SAFTM_LOG(ERROR) << "Unable to Setup() LangId";
91       return;
92     }
93     if (!Init(&context)) {
94       SAFTM_LOG(ERROR) << "Unable to Init() LangId";
95       return;
96     }
97     valid_ = true;
98   }
99 
FindLanguage(StringPiece text) const100   std::string FindLanguage(StringPiece text) const {
101     LangIdResult lang_id_result;
102     FindLanguages(text, &lang_id_result, /* max_results = */ 1);
103     if (lang_id_result.predictions.empty()) {
104       return LangId::kUnknownLanguageCode;
105     }
106 
107     const std::string &language = lang_id_result.predictions[0].first;
108     const float probability = lang_id_result.predictions[0].second;
109     SAFTM_DLOG(INFO) << "Predicted " << language
110                      << " with prob: " << probability << " for \"" << text
111                      << "\"";
112 
113     // Find confidence threshold for language.
114     float threshold = default_threshold_;
115     auto it = per_lang_thresholds_.find(language);
116     if (it != per_lang_thresholds_.end()) {
117       threshold = it->second;
118     }
119     if (probability < threshold) {
120       SAFTM_DLOG(INFO) << "  below threshold => "
121                        << LangId::kUnknownLanguageCode;
122       return LangId::kUnknownLanguageCode;
123     }
124     return language;
125   }
126 
FindLanguages(StringPiece text,LangIdResult * result,int max_results) const127   void FindLanguages(StringPiece text, LangIdResult *result,
128                      int max_results) const {
129     if (result == nullptr) return;
130 
131     if (max_results <= 0) {
132       max_results = languages_.size();
133     }
134     result->predictions.clear();
135     if (!is_valid() || (max_results == 0)) {
136       result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
137       return;
138     }
139 
140     // Tokenize the input text (this also does some pre-processing, like
141     // removing ASCII digits, punctuation, etc).
142     LightSentence sentence;
143     tokenizer_.Tokenize(text, &sentence);
144 
145     // Test input size here, after pre-processing removed irrelevant chars.
146     if (IsTooShort(sentence)) {
147       result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
148       return;
149     }
150 
151     // Extract features from the tokenized text.
152     std::vector<FeatureVector> features =
153         lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
154 
155     // Run feed-forward neural network to compute scores (softmax logits).
156     std::vector<float> scores;
157     network_->ComputeFinalScores(features, &scores);
158 
159     if (max_results == 1) {
160       // Optimization for the case when the user wants only the top result.
161       // Computing argmax is faster than the general top-k code.
162       int prediction_id = GetArgMax(scores);
163       const std::string language = GetLanguageForSoftmaxLabel(prediction_id);
164       float probability = ComputeSoftmaxProbability(scores, prediction_id);
165       result->predictions.emplace_back(language, probability);
166     } else {
167       // Compute and sort softmax in descending order by probability and convert
168       // IDs to language code strings.  When probabilities are equal, we sort by
169       // language code string in ascending order.
170       const std::vector<float> softmax = ComputeSoftmax(scores);
171       const std::vector<int> indices = GetTopKIndices(max_results, softmax);
172       for (const int index : indices) {
173         result->predictions.emplace_back(GetLanguageForSoftmaxLabel(index),
174                                          softmax[index]);
175       }
176     }
177   }
178 
is_valid() const179   bool is_valid() const { return valid_; }
180 
GetModelVersion() const181   int GetModelVersion() const { return model_version_; }
182 
183   // Returns a property stored in the model file.
184   template <typename T, typename R>
GetProperty(const std::string & property,T default_value) const185   R GetProperty(const std::string &property, T default_value) const {
186     return model_provider_->GetTaskContext()->Get(property, default_value);
187   }
188 
189   // Perform any necessary static initialization.
190   // This function is thread-safe.
191   // It's also safe to call this function multiple times.
192   //
193   // We explicitly call RegisterClass() rather than relying on alwayslink=1 in
194   // the BUILD file, because the build process for some users of this code
195   // doesn't support any equivalent to alwayslink=1 (in particular the
196   // Firebase C++ SDK build uses a Kokoro-based CMake build).  While it might
197   // be possible to add such support, avoiding the need for an equivalent to
198   // alwayslink=1 is preferable because it avoids unnecessarily bloating code
199   // size in apps that link against this code but don't use it.
RegisterClasses()200   static void RegisterClasses() {
201     static bool initialized = []() -> bool {
202       libtextclassifier3::mobile::ApproxScriptDetector::RegisterClass();
203       libtextclassifier3::mobile::lang_id::ContinuousBagOfNgramsFunction::RegisterClass();
204       libtextclassifier3::mobile::lang_id::TinyScriptDetector::RegisterClass();
205       libtextclassifier3::mobile::lang_id::RelevantScriptFeature::RegisterClass();
206       return true;
207     }();
208     (void)initialized;  // Variable used only for initializer's side effects.
209   }
210 
211  private:
Setup(TaskContext * context)212   bool Setup(TaskContext *context) {
213     tokenizer_.Setup(context);
214     if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;
215 
216     min_text_size_in_bytes_ = context->Get("min_text_size_in_bytes", 0);
217     default_threshold_ =
218         context->Get("reliability_thresh", kDefaultConfidenceThreshold);
219 
220     // Parse task parameter "per_lang_reliability_thresholds", fill
221     // per_lang_thresholds_.
222     const std::string thresholds_str =
223         context->Get("per_lang_reliability_thresholds", "");
224     std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
225     for (const auto &token : tokens) {
226       if (token.empty()) continue;
227       std::vector<StringPiece> parts = LiteStrSplit(token, '=');
228       float threshold = 0.0f;
229       if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
230         per_lang_thresholds_[std::string(parts[0])] = threshold;
231       } else {
232         SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
233       }
234     }
235     model_version_ = context->Get("model_version", model_version_);
236     return true;
237   }
238 
Init(TaskContext * context)239   bool Init(TaskContext *context) {
240     return lang_id_brain_interface_.InitForProcessing(context);
241   }
242 
243   // Returns language code for a softmax label.  See comments for languages_
244   // field.  If label is out of range, returns LangId::kUnknownLanguageCode.
GetLanguageForSoftmaxLabel(int label) const245   std::string GetLanguageForSoftmaxLabel(int label) const {
246     if ((label >= 0) && (static_cast<size_t>(label) < languages_.size())) {
247       return languages_[label];
248     } else {
249       SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
250                        << languages_.size() << ")";
251       return LangId::kUnknownLanguageCode;
252     }
253   }
254 
IsTooShort(const LightSentence & sentence) const255   bool IsTooShort(const LightSentence &sentence) const {
256     int text_size = 0;
257     for (const std::string &token : sentence) {
258       // Each token has the form ^...$: we subtract 2 because we want to count
259       // only the real text, not the chars added by us.
260       text_size += token.size() - 2;
261     }
262     return text_size < min_text_size_in_bytes_;
263   }
264 
265   std::unique_ptr<ModelProvider> model_provider_;
266 
267   TokenizerForLangId tokenizer_;
268 
269   EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
270       lang_id_brain_interface_;
271 
272   // Neural network to use for scoring.
273   std::unique_ptr<EmbeddingNetwork> network_;
274 
275   // True if this object is ready to perform language predictions.
276   bool valid_ = false;
277 
278   // The model returns LangId::kUnknownLanguageCode for input text that has
279   // fewer than min_text_size_in_bytes_ bytes (excluding ASCII whitespaces,
280   // digits, and punctuation).
281   int min_text_size_in_bytes_ = 0;
282 
283   // Only predictions with a probability (confidence) above this threshold are
284   // reported.  Otherwise, we report LangId::kUnknownLanguageCode.
285   float default_threshold_ = kDefaultConfidenceThreshold;
286 
287   std::unordered_map<std::string, float> per_lang_thresholds_;
288 
289   // Recognized languages: softmax label i means languages_[i] (something like
290   // "en", "fr", "ru", etc).
291   std::vector<std::string> languages_;
292 
293   // Version of the model used by this LangIdImpl object.  Zero means that the
294   // model version could not be determined.
295   int model_version_ = 0;
296 };
297 
// Code reported whenever the language cannot be determined reliably.
const char LangId::kUnknownLanguageCode[] = "und";

LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
    : pimpl_(new LangIdImpl(std::move(model_provider))) {
  // NOTE(review): RegisterClasses() runs AFTER the LangIdImpl constructor
  // (the member initializer above), which already executed Setup()/Init() —
  // presumably those consult the registries only lazily; confirm this
  // ordering is intentional.
  LangIdImpl::RegisterClasses();
}

// Defaulted here in the .cc (not in the header) so the unique_ptr<LangIdImpl>
// pimpl_ is destroyed where LangIdImpl is a complete type.
LangId::~LangId() = default;
306 
FindLanguage(const char * data,size_t num_bytes) const307 std::string LangId::FindLanguage(const char *data, size_t num_bytes) const {
308   StringPiece text(data, num_bytes);
309   return pimpl_->FindLanguage(text);
310 }
311 
FindLanguages(const char * data,size_t num_bytes,LangIdResult * result,int max_results) const312 void LangId::FindLanguages(const char *data, size_t num_bytes,
313                            LangIdResult *result, int max_results) const {
314   SAFTM_DCHECK(result) << "LangIdResult must not be null.";
315   StringPiece text(data, num_bytes);
316   pimpl_->FindLanguages(text, result, max_results);
317 }
318 
is_valid() const319 bool LangId::is_valid() const { return pimpl_->is_valid(); }
320 
GetModelVersion() const321 int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }
322 
GetFloatProperty(const std::string & property,float default_value) const323 float LangId::GetFloatProperty(const std::string &property,
324                                float default_value) const {
325   return pimpl_->GetProperty<float, float>(property, default_value);
326 }
327 
328 }  // namespace lang_id
329 }  // namespace mobile
}  // namespace libtextclassifier3
331