/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/lang-id.h"

#include <stdio.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "lang_id/common/embedding-feature-interface.h"
#include "lang_id/common/embedding-network-params.h"
#include "lang_id/common/embedding-network.h"
#include "lang_id/common/fel/feature-extractor.h"
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_strings/numbers.h"
#include "lang_id/common/lite_strings/str-split.h"
#include "lang_id/common/lite_strings/stringpiece.h"
#include "lang_id/common/math/algorithm.h"
#include "lang_id/common/math/softmax.h"
#include "lang_id/custom-tokenizer.h"
#include "lang_id/features/light-sentence-features.h"
// The two features/ headers below are needed only for RegisterClass().
#include "lang_id/features/char-ngram-feature.h"
#include "lang_id/features/relevant-script-feature.h"
#include "lang_id/light-sentence.h"
// The two script/ headers below are needed only for RegisterClass().
#include "lang_id/script/approx-script.h"
#include "lang_id/script/tiny-script-detector.h"

namespace libtextclassifier3 {
namespace mobile {
namespace lang_id {

namespace {
// Default value for the confidence threshold.  If the confidence of the top
// prediction is below this threshold, then FindLanguage() returns
// LangId::kUnknownLanguageCode.  Note: this is just a default value; if the
// TaskSpec from the model specifies a "reliability_thresh" parameter, then we
// use that value instead.  Note: for legacy reasons, our code and comments use
// the terms "confidence", "probability", and "reliability" interchangeably.
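// E.g., with the default threshold of 0.50, a top prediction of ("en", 0.42)
// would be reported as LangId::kUnknownLanguageCode.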
static const float kDefaultConfidenceThreshold = 0.50f;
}  // namespace

// Class that performs all work behind LangId.
class LangIdImpl {
 public:
  explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
      : model_provider_(std::move(model_provider)),
        lang_id_brain_interface_("language_identifier") {
    // Note: in the code below, we set valid_ to true only if all
    // initialization steps complete successfully.  Otherwise, we return early,
    // leaving valid_ at its default value, false.
    if (!model_provider_ || !model_provider_->is_valid()) {
      SAFTM_LOG(ERROR) << "Invalid model provider";
      return;
    }

    auto *nn_params = model_provider_->GetNnParams();
    if (!nn_params) {
      SAFTM_LOG(ERROR) << "No NN params";
      return;
    }
    network_.reset(new EmbeddingNetwork(nn_params));

    languages_ = model_provider_->GetLanguages();
    if (languages_.empty()) {
      SAFTM_LOG(ERROR) << "No known languages";
      return;
    }

    TaskContext context = *model_provider_->GetTaskContext();
    if (!Setup(&context)) {
      SAFTM_LOG(ERROR) << "Unable to Setup() LangId";
      return;
    }
    if (!Init(&context)) {
      SAFTM_LOG(ERROR) << "Unable to Init() LangId";
      return;
    }
    valid_ = true;
  }
  std::string FindLanguage(StringPiece text) const {
    LangIdResult lang_id_result;
    FindLanguages(text, &lang_id_result, /* max_results = */ 1);
    if (lang_id_result.predictions.empty()) {
      return LangId::kUnknownLanguageCode;
    }

    const std::string &language = lang_id_result.predictions[0].first;
    const float probability = lang_id_result.predictions[0].second;
    SAFTM_DLOG(INFO) << "Predicted " << language
                     << " with prob: " << probability << " for \"" << text
                     << "\"";

    // Find confidence threshold for language.
    float threshold = default_threshold_;
    auto it = per_lang_thresholds_.find(language);
    if (it != per_lang_thresholds_.end()) {
      threshold = it->second;
    }
    if (probability < threshold) {
      SAFTM_DLOG(INFO) << " below threshold => "
                       << LangId::kUnknownLanguageCode;
      return LangId::kUnknownLanguageCode;
    }
    return language;
  }

  void FindLanguages(StringPiece text, LangIdResult *result,
                     int max_results) const {
    if (result == nullptr) return;

    if (max_results <= 0) {
      max_results = languages_.size();
    }
    result->predictions.clear();
    if (!is_valid() || (max_results == 0)) {
      result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
      return;
    }

    // Tokenize the input text (this also does some pre-processing, like
    // removing ASCII digits, punctuation, etc.).
    LightSentence sentence;
    tokenizer_.Tokenize(text, &sentence);

    // Test the input size here, after pre-processing has removed irrelevant
    // chars.
    if (IsTooShort(sentence)) {
      result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
      return;
    }

    // Extract features from the tokenized text.
    std::vector<FeatureVector> features =
        lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);

    // Run the feed-forward neural network to compute scores (softmax logits).
    std::vector<float> scores;
    network_->ComputeFinalScores(features, &scores);

    if (max_results == 1) {
      // Optimization for the case when the user wants only the top result:
      // computing the argmax is faster than the general top-k code.
      int prediction_id = GetArgMax(scores);
      const std::string language = GetLanguageForSoftmaxLabel(prediction_id);
      float probability = ComputeSoftmaxProbability(scores, prediction_id);
      result->predictions.emplace_back(language, probability);
    } else {
      // Compute the softmax, sort it in descending order by probability, and
      // convert label IDs to language code strings.  When probabilities are
      // equal, we sort by language code string in ascending order.
      const std::vector<float> softmax = ComputeSoftmax(scores);
      const std::vector<int> indices = GetTopKIndices(max_results, softmax);
      for (const int index : indices) {
        result->predictions.emplace_back(GetLanguageForSoftmaxLabel(index),
                                         softmax[index]);
      }
    }
  }

  bool is_valid() const { return valid_; }

  int GetModelVersion() const { return model_version_; }

  // Returns a property stored in the model file.
  template <typename T, typename R>
  R GetProperty(const std::string &property, T default_value) const {
    return model_provider_->GetTaskContext()->Get(property, default_value);
  }

  // Perform any necessary static initialization.
  // This function is thread-safe.
  // It's also safe to call this function multiple times.
  //
  // We explicitly call RegisterClass() rather than relying on alwayslink=1 in
  // the BUILD file, because the build process for some users of this code
  // doesn't support any equivalent to alwayslink=1 (in particular the
  // Firebase C++ SDK build uses a Kokoro-based CMake build).  While it might
  // be possible to add such support, avoiding the need for an equivalent to
  // alwayslink=1 is preferable because it avoids unnecessarily bloating code
  // size in apps that link against this code but don't use it.
  static void RegisterClasses() {
    static bool initialized = []() -> bool {
      libtextclassifier3::mobile::ApproxScriptDetector::RegisterClass();
      libtextclassifier3::mobile::lang_id::ContinuousBagOfNgramsFunction::
          RegisterClass();
      libtextclassifier3::mobile::lang_id::TinyScriptDetector::RegisterClass();
      libtextclassifier3::mobile::lang_id::RelevantScriptFeature::
          RegisterClass();
      return true;
    }();
    (void)initialized;  // Variable used only for initializer's side effects.
  }

 private:
  bool Setup(TaskContext *context) {
    tokenizer_.Setup(context);
    if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;

    min_text_size_in_bytes_ = context->Get("min_text_size_in_bytes", 0);
    default_threshold_ =
        context->Get("reliability_thresh", kDefaultConfidenceThreshold);

    // Parse task parameter "per_lang_reliability_thresholds", fill
    // per_lang_thresholds_.
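    // The expected format, inferred from the parsing below, is a
    // comma-separated list of "language=threshold" pairs, e.g.
    // "en=0.9,ru=0.8".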
    const std::string thresholds_str =
        context->Get("per_lang_reliability_thresholds", "");
    std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
    for (const auto &token : tokens) {
      if (token.empty()) continue;
      std::vector<StringPiece> parts = LiteStrSplit(token, '=');
      float threshold = 0.0f;
      if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
        per_lang_thresholds_[std::string(parts[0])] = threshold;
      } else {
        SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
      }
    }
    model_version_ = context->Get("model_version", model_version_);
    return true;
  }

  bool Init(TaskContext *context) {
    return lang_id_brain_interface_.InitForProcessing(context);
  }

  // Returns language code for a softmax label.  See comments for languages_
  // field.  If label is out of range, returns LangId::kUnknownLanguageCode.
  std::string GetLanguageForSoftmaxLabel(int label) const {
    if ((label >= 0) && (static_cast<size_t>(label) < languages_.size())) {
      return languages_[label];
    } else {
      SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
                       << languages_.size() << ")";
      return LangId::kUnknownLanguageCode;
    }
  }

  bool IsTooShort(const LightSentence &sentence) const {
    int text_size = 0;
    for (const std::string &token : sentence) {
      // Each token has the form ^...$: we subtract 2 because we want to count
      // only the real text, not the chars added by us.
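      // E.g., the token "^hola$" has size 6, of which 4 bytes are real text.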
      text_size += token.size() - 2;
    }
    return text_size < min_text_size_in_bytes_;
  }

  std::unique_ptr<ModelProvider> model_provider_;

  TokenizerForLangId tokenizer_;

  EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
      lang_id_brain_interface_;

  // Neural network to use for scoring.
  std::unique_ptr<EmbeddingNetwork> network_;

  // True if this object is ready to perform language predictions.
  bool valid_ = false;

  // The model returns LangId::kUnknownLanguageCode for input text that has
  // fewer than min_text_size_in_bytes_ bytes (excluding ASCII whitespaces,
  // digits, and punctuation).
  int min_text_size_in_bytes_ = 0;

  // Only predictions with a probability (confidence) above this threshold are
  // reported.  Otherwise, we report LangId::kUnknownLanguageCode.
  float default_threshold_ = kDefaultConfidenceThreshold;

  // Per-language thresholds that override default_threshold_.
  std::unordered_map<std::string, float> per_lang_thresholds_;

  // Recognized languages: softmax label i means languages_[i] (something like
  // "en", "fr", "ru", etc.).
  std::vector<std::string> languages_;

  // Version of the model used by this LangIdImpl object.  Zero means that the
  // model version could not be determined.
  int model_version_ = 0;
};

const char LangId::kUnknownLanguageCode[] = "und";

LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
    : pimpl_(new LangIdImpl(std::move(model_provider))) {
  LangIdImpl::RegisterClasses();
}

LangId::~LangId() = default;

std::string LangId::FindLanguage(const char *data, size_t num_bytes) const {
  StringPiece text(data, num_bytes);
  return pimpl_->FindLanguage(text);
}

void LangId::FindLanguages(const char *data, size_t num_bytes,
                           LangIdResult *result, int max_results) const {
  SAFTM_DCHECK(result) << "LangIdResult must not be null.";
  StringPiece text(data, num_bytes);
  pimpl_->FindLanguages(text, result, max_results);
}
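
// Example usage (a sketch, not part of this file's API; GetModelProvider() is
// a hypothetical factory standing in for whatever ModelProvider subclass the
// client code actually uses):
//
//   std::unique_ptr<ModelProvider> provider = GetModelProvider();
//   LangId lang_id(std::move(provider));
//   if (lang_id.is_valid()) {
//     const std::string text = "Ceci est un exemple de texte.";
//     const std::string lang = lang_id.FindLanguage(text.data(), text.size());
//     // |lang| is a code like "fr", or LangId::kUnknownLanguageCode ("und")
//     // if the top prediction falls below the confidence threshold.
//   }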

bool LangId::is_valid() const { return pimpl_->is_valid(); }

int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }

float LangId::GetFloatProperty(const std::string &property,
                               float default_value) const {
  return pimpl_->GetProperty<float, float>(property, default_value);
}

}  // namespace lang_id
}  // namespace mobile
}  // namespace libtextclassifier3