1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "lang_id/lang-id.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include <stdio.h>
20*993b0882SAndroid Build Coastguard Worker
21*993b0882SAndroid Build Coastguard Worker #include <memory>
22*993b0882SAndroid Build Coastguard Worker #include <string>
23*993b0882SAndroid Build Coastguard Worker #include <unordered_map>
24*993b0882SAndroid Build Coastguard Worker #include <utility>
25*993b0882SAndroid Build Coastguard Worker #include <vector>
26*993b0882SAndroid Build Coastguard Worker
27*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/embedding-feature-interface.h"
28*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/embedding-network-params.h"
29*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/embedding-network.h"
30*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/fel/feature-extractor.h"
31*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/lite_base/logging.h"
32*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/lite_strings/numbers.h"
33*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/lite_strings/str-split.h"
34*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/lite_strings/stringpiece.h"
35*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/math/algorithm.h"
36*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/math/softmax.h"
37*993b0882SAndroid Build Coastguard Worker #include "lang_id/custom-tokenizer.h"
38*993b0882SAndroid Build Coastguard Worker #include "lang_id/features/light-sentence-features.h"
39*993b0882SAndroid Build Coastguard Worker // The two features/ headers below are needed only for RegisterClass().
40*993b0882SAndroid Build Coastguard Worker #include "lang_id/features/char-ngram-feature.h"
41*993b0882SAndroid Build Coastguard Worker #include "lang_id/features/relevant-script-feature.h"
42*993b0882SAndroid Build Coastguard Worker #include "lang_id/light-sentence.h"
43*993b0882SAndroid Build Coastguard Worker // The two script/ headers below are needed only for RegisterClass().
44*993b0882SAndroid Build Coastguard Worker #include "lang_id/script/approx-script.h"
45*993b0882SAndroid Build Coastguard Worker #include "lang_id/script/tiny-script-detector.h"
46*993b0882SAndroid Build Coastguard Worker
47*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
48*993b0882SAndroid Build Coastguard Worker namespace mobile {
49*993b0882SAndroid Build Coastguard Worker namespace lang_id {
50*993b0882SAndroid Build Coastguard Worker
namespace {
// Default value for the confidence threshold. If the confidence of the top
// prediction is below this threshold, then FindLanguage() returns
// LangId::kUnknownLanguageCode. Note: this is just a default value; if the
// TaskSpec from the model specifies a "reliability_thresh" parameter, then we
// use that value instead. Note: for legacy reasons, our code and comments use
// the terms "confidence", "probability" and "reliability" equivalently.
//
// constexpr (rather than `static const`): the anonymous namespace already
// gives internal linkage, and constexpr guarantees compile-time initialization.
constexpr float kDefaultConfidenceThreshold = 0.50f;
}  // namespace
60*993b0882SAndroid Build Coastguard Worker
61*993b0882SAndroid Build Coastguard Worker // Class that performs all work behind LangId.
62*993b0882SAndroid Build Coastguard Worker class LangIdImpl {
63*993b0882SAndroid Build Coastguard Worker public:
LangIdImpl(std::unique_ptr<ModelProvider> model_provider)64*993b0882SAndroid Build Coastguard Worker explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
65*993b0882SAndroid Build Coastguard Worker : model_provider_(std::move(model_provider)),
66*993b0882SAndroid Build Coastguard Worker lang_id_brain_interface_("language_identifier") {
67*993b0882SAndroid Build Coastguard Worker // Note: in the code below, we set valid_ to true only if all initialization
68*993b0882SAndroid Build Coastguard Worker // steps completed successfully. Otherwise, we return early, leaving valid_
69*993b0882SAndroid Build Coastguard Worker // to its default value false.
70*993b0882SAndroid Build Coastguard Worker if (!model_provider_ || !model_provider_->is_valid()) {
71*993b0882SAndroid Build Coastguard Worker SAFTM_LOG(ERROR) << "Invalid model provider";
72*993b0882SAndroid Build Coastguard Worker return;
73*993b0882SAndroid Build Coastguard Worker }
74*993b0882SAndroid Build Coastguard Worker
75*993b0882SAndroid Build Coastguard Worker auto *nn_params = model_provider_->GetNnParams();
76*993b0882SAndroid Build Coastguard Worker if (!nn_params) {
77*993b0882SAndroid Build Coastguard Worker SAFTM_LOG(ERROR) << "No NN params";
78*993b0882SAndroid Build Coastguard Worker return;
79*993b0882SAndroid Build Coastguard Worker }
80*993b0882SAndroid Build Coastguard Worker network_.reset(new EmbeddingNetwork(nn_params));
81*993b0882SAndroid Build Coastguard Worker
82*993b0882SAndroid Build Coastguard Worker languages_ = model_provider_->GetLanguages();
83*993b0882SAndroid Build Coastguard Worker if (languages_.empty()) {
84*993b0882SAndroid Build Coastguard Worker SAFTM_LOG(ERROR) << "No known languages";
85*993b0882SAndroid Build Coastguard Worker return;
86*993b0882SAndroid Build Coastguard Worker }
87*993b0882SAndroid Build Coastguard Worker
88*993b0882SAndroid Build Coastguard Worker TaskContext context = *model_provider_->GetTaskContext();
89*993b0882SAndroid Build Coastguard Worker if (!Setup(&context)) {
90*993b0882SAndroid Build Coastguard Worker SAFTM_LOG(ERROR) << "Unable to Setup() LangId";
91*993b0882SAndroid Build Coastguard Worker return;
92*993b0882SAndroid Build Coastguard Worker }
93*993b0882SAndroid Build Coastguard Worker if (!Init(&context)) {
94*993b0882SAndroid Build Coastguard Worker SAFTM_LOG(ERROR) << "Unable to Init() LangId";
95*993b0882SAndroid Build Coastguard Worker return;
96*993b0882SAndroid Build Coastguard Worker }
97*993b0882SAndroid Build Coastguard Worker valid_ = true;
98*993b0882SAndroid Build Coastguard Worker }
99*993b0882SAndroid Build Coastguard Worker
FindLanguage(StringPiece text) const100*993b0882SAndroid Build Coastguard Worker std::string FindLanguage(StringPiece text) const {
101*993b0882SAndroid Build Coastguard Worker LangIdResult lang_id_result;
102*993b0882SAndroid Build Coastguard Worker FindLanguages(text, &lang_id_result, /* max_results = */ 1);
103*993b0882SAndroid Build Coastguard Worker if (lang_id_result.predictions.empty()) {
104*993b0882SAndroid Build Coastguard Worker return LangId::kUnknownLanguageCode;
105*993b0882SAndroid Build Coastguard Worker }
106*993b0882SAndroid Build Coastguard Worker
107*993b0882SAndroid Build Coastguard Worker const std::string &language = lang_id_result.predictions[0].first;
108*993b0882SAndroid Build Coastguard Worker const float probability = lang_id_result.predictions[0].second;
109*993b0882SAndroid Build Coastguard Worker SAFTM_DLOG(INFO) << "Predicted " << language
110*993b0882SAndroid Build Coastguard Worker << " with prob: " << probability << " for \"" << text
111*993b0882SAndroid Build Coastguard Worker << "\"";
112*993b0882SAndroid Build Coastguard Worker
113*993b0882SAndroid Build Coastguard Worker // Find confidence threshold for language.
114*993b0882SAndroid Build Coastguard Worker float threshold = default_threshold_;
115*993b0882SAndroid Build Coastguard Worker auto it = per_lang_thresholds_.find(language);
116*993b0882SAndroid Build Coastguard Worker if (it != per_lang_thresholds_.end()) {
117*993b0882SAndroid Build Coastguard Worker threshold = it->second;
118*993b0882SAndroid Build Coastguard Worker }
119*993b0882SAndroid Build Coastguard Worker if (probability < threshold) {
120*993b0882SAndroid Build Coastguard Worker SAFTM_DLOG(INFO) << " below threshold => "
121*993b0882SAndroid Build Coastguard Worker << LangId::kUnknownLanguageCode;
122*993b0882SAndroid Build Coastguard Worker return LangId::kUnknownLanguageCode;
123*993b0882SAndroid Build Coastguard Worker }
124*993b0882SAndroid Build Coastguard Worker return language;
125*993b0882SAndroid Build Coastguard Worker }
126*993b0882SAndroid Build Coastguard Worker
FindLanguages(StringPiece text,LangIdResult * result,int max_results) const127*993b0882SAndroid Build Coastguard Worker void FindLanguages(StringPiece text, LangIdResult *result,
128*993b0882SAndroid Build Coastguard Worker int max_results) const {
129*993b0882SAndroid Build Coastguard Worker if (result == nullptr) return;
130*993b0882SAndroid Build Coastguard Worker
131*993b0882SAndroid Build Coastguard Worker if (max_results <= 0) {
132*993b0882SAndroid Build Coastguard Worker max_results = languages_.size();
133*993b0882SAndroid Build Coastguard Worker }
134*993b0882SAndroid Build Coastguard Worker result->predictions.clear();
135*993b0882SAndroid Build Coastguard Worker if (!is_valid() || (max_results == 0)) {
136*993b0882SAndroid Build Coastguard Worker result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
137*993b0882SAndroid Build Coastguard Worker return;
138*993b0882SAndroid Build Coastguard Worker }
139*993b0882SAndroid Build Coastguard Worker
140*993b0882SAndroid Build Coastguard Worker // Tokenize the input text (this also does some pre-processing, like
141*993b0882SAndroid Build Coastguard Worker // removing ASCII digits, punctuation, etc).
142*993b0882SAndroid Build Coastguard Worker LightSentence sentence;
143*993b0882SAndroid Build Coastguard Worker tokenizer_.Tokenize(text, &sentence);
144*993b0882SAndroid Build Coastguard Worker
145*993b0882SAndroid Build Coastguard Worker // Test input size here, after pre-processing removed irrelevant chars.
146*993b0882SAndroid Build Coastguard Worker if (IsTooShort(sentence)) {
147*993b0882SAndroid Build Coastguard Worker result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
148*993b0882SAndroid Build Coastguard Worker return;
149*993b0882SAndroid Build Coastguard Worker }
150*993b0882SAndroid Build Coastguard Worker
151*993b0882SAndroid Build Coastguard Worker // Extract features from the tokenized text.
152*993b0882SAndroid Build Coastguard Worker std::vector<FeatureVector> features =
153*993b0882SAndroid Build Coastguard Worker lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
154*993b0882SAndroid Build Coastguard Worker
155*993b0882SAndroid Build Coastguard Worker // Run feed-forward neural network to compute scores (softmax logits).
156*993b0882SAndroid Build Coastguard Worker std::vector<float> scores;
157*993b0882SAndroid Build Coastguard Worker network_->ComputeFinalScores(features, &scores);
158*993b0882SAndroid Build Coastguard Worker
159*993b0882SAndroid Build Coastguard Worker if (max_results == 1) {
160*993b0882SAndroid Build Coastguard Worker // Optimization for the case when the user wants only the top result.
161*993b0882SAndroid Build Coastguard Worker // Computing argmax is faster than the general top-k code.
162*993b0882SAndroid Build Coastguard Worker int prediction_id = GetArgMax(scores);
163*993b0882SAndroid Build Coastguard Worker const std::string language = GetLanguageForSoftmaxLabel(prediction_id);
164*993b0882SAndroid Build Coastguard Worker float probability = ComputeSoftmaxProbability(scores, prediction_id);
165*993b0882SAndroid Build Coastguard Worker result->predictions.emplace_back(language, probability);
166*993b0882SAndroid Build Coastguard Worker } else {
167*993b0882SAndroid Build Coastguard Worker // Compute and sort softmax in descending order by probability and convert
168*993b0882SAndroid Build Coastguard Worker // IDs to language code strings. When probabilities are equal, we sort by
169*993b0882SAndroid Build Coastguard Worker // language code string in ascending order.
170*993b0882SAndroid Build Coastguard Worker const std::vector<float> softmax = ComputeSoftmax(scores);
171*993b0882SAndroid Build Coastguard Worker const std::vector<int> indices = GetTopKIndices(max_results, softmax);
172*993b0882SAndroid Build Coastguard Worker for (const int index : indices) {
173*993b0882SAndroid Build Coastguard Worker result->predictions.emplace_back(GetLanguageForSoftmaxLabel(index),
174*993b0882SAndroid Build Coastguard Worker softmax[index]);
175*993b0882SAndroid Build Coastguard Worker }
176*993b0882SAndroid Build Coastguard Worker }
177*993b0882SAndroid Build Coastguard Worker }
178*993b0882SAndroid Build Coastguard Worker
is_valid() const179*993b0882SAndroid Build Coastguard Worker bool is_valid() const { return valid_; }
180*993b0882SAndroid Build Coastguard Worker
GetModelVersion() const181*993b0882SAndroid Build Coastguard Worker int GetModelVersion() const { return model_version_; }
182*993b0882SAndroid Build Coastguard Worker
183*993b0882SAndroid Build Coastguard Worker // Returns a property stored in the model file.
184*993b0882SAndroid Build Coastguard Worker template <typename T, typename R>
GetProperty(const std::string & property,T default_value) const185*993b0882SAndroid Build Coastguard Worker R GetProperty(const std::string &property, T default_value) const {
186*993b0882SAndroid Build Coastguard Worker return model_provider_->GetTaskContext()->Get(property, default_value);
187*993b0882SAndroid Build Coastguard Worker }
188*993b0882SAndroid Build Coastguard Worker
189*993b0882SAndroid Build Coastguard Worker // Perform any necessary static initialization.
190*993b0882SAndroid Build Coastguard Worker // This function is thread-safe.
191*993b0882SAndroid Build Coastguard Worker // It's also safe to call this function multiple times.
192*993b0882SAndroid Build Coastguard Worker //
193*993b0882SAndroid Build Coastguard Worker // We explicitly call RegisterClass() rather than relying on alwayslink=1 in
194*993b0882SAndroid Build Coastguard Worker // the BUILD file, because the build process for some users of this code
195*993b0882SAndroid Build Coastguard Worker // doesn't support any equivalent to alwayslink=1 (in particular the
196*993b0882SAndroid Build Coastguard Worker // Firebase C++ SDK build uses a Kokoro-based CMake build). While it might
197*993b0882SAndroid Build Coastguard Worker // be possible to add such support, avoiding the need for an equivalent to
198*993b0882SAndroid Build Coastguard Worker // alwayslink=1 is preferable because it avoids unnecessarily bloating code
199*993b0882SAndroid Build Coastguard Worker // size in apps that link against this code but don't use it.
RegisterClasses()200*993b0882SAndroid Build Coastguard Worker static void RegisterClasses() {
201*993b0882SAndroid Build Coastguard Worker static bool initialized = []() -> bool {
202*993b0882SAndroid Build Coastguard Worker libtextclassifier3::mobile::ApproxScriptDetector::RegisterClass();
203*993b0882SAndroid Build Coastguard Worker libtextclassifier3::mobile::lang_id::ContinuousBagOfNgramsFunction::RegisterClass();
204*993b0882SAndroid Build Coastguard Worker libtextclassifier3::mobile::lang_id::TinyScriptDetector::RegisterClass();
205*993b0882SAndroid Build Coastguard Worker libtextclassifier3::mobile::lang_id::RelevantScriptFeature::RegisterClass();
206*993b0882SAndroid Build Coastguard Worker return true;
207*993b0882SAndroid Build Coastguard Worker }();
208*993b0882SAndroid Build Coastguard Worker (void)initialized; // Variable used only for initializer's side effects.
209*993b0882SAndroid Build Coastguard Worker }
210*993b0882SAndroid Build Coastguard Worker
211*993b0882SAndroid Build Coastguard Worker private:
Setup(TaskContext * context)212*993b0882SAndroid Build Coastguard Worker bool Setup(TaskContext *context) {
213*993b0882SAndroid Build Coastguard Worker tokenizer_.Setup(context);
214*993b0882SAndroid Build Coastguard Worker if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;
215*993b0882SAndroid Build Coastguard Worker
216*993b0882SAndroid Build Coastguard Worker min_text_size_in_bytes_ = context->Get("min_text_size_in_bytes", 0);
217*993b0882SAndroid Build Coastguard Worker default_threshold_ =
218*993b0882SAndroid Build Coastguard Worker context->Get("reliability_thresh", kDefaultConfidenceThreshold);
219*993b0882SAndroid Build Coastguard Worker
220*993b0882SAndroid Build Coastguard Worker // Parse task parameter "per_lang_reliability_thresholds", fill
221*993b0882SAndroid Build Coastguard Worker // per_lang_thresholds_.
222*993b0882SAndroid Build Coastguard Worker const std::string thresholds_str =
223*993b0882SAndroid Build Coastguard Worker context->Get("per_lang_reliability_thresholds", "");
224*993b0882SAndroid Build Coastguard Worker std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
225*993b0882SAndroid Build Coastguard Worker for (const auto &token : tokens) {
226*993b0882SAndroid Build Coastguard Worker if (token.empty()) continue;
227*993b0882SAndroid Build Coastguard Worker std::vector<StringPiece> parts = LiteStrSplit(token, '=');
228*993b0882SAndroid Build Coastguard Worker float threshold = 0.0f;
229*993b0882SAndroid Build Coastguard Worker if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
230*993b0882SAndroid Build Coastguard Worker per_lang_thresholds_[std::string(parts[0])] = threshold;
231*993b0882SAndroid Build Coastguard Worker } else {
232*993b0882SAndroid Build Coastguard Worker SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
233*993b0882SAndroid Build Coastguard Worker }
234*993b0882SAndroid Build Coastguard Worker }
235*993b0882SAndroid Build Coastguard Worker model_version_ = context->Get("model_version", model_version_);
236*993b0882SAndroid Build Coastguard Worker return true;
237*993b0882SAndroid Build Coastguard Worker }
238*993b0882SAndroid Build Coastguard Worker
Init(TaskContext * context)239*993b0882SAndroid Build Coastguard Worker bool Init(TaskContext *context) {
240*993b0882SAndroid Build Coastguard Worker return lang_id_brain_interface_.InitForProcessing(context);
241*993b0882SAndroid Build Coastguard Worker }
242*993b0882SAndroid Build Coastguard Worker
243*993b0882SAndroid Build Coastguard Worker // Returns language code for a softmax label. See comments for languages_
244*993b0882SAndroid Build Coastguard Worker // field. If label is out of range, returns LangId::kUnknownLanguageCode.
GetLanguageForSoftmaxLabel(int label) const245*993b0882SAndroid Build Coastguard Worker std::string GetLanguageForSoftmaxLabel(int label) const {
246*993b0882SAndroid Build Coastguard Worker if ((label >= 0) && (static_cast<size_t>(label) < languages_.size())) {
247*993b0882SAndroid Build Coastguard Worker return languages_[label];
248*993b0882SAndroid Build Coastguard Worker } else {
249*993b0882SAndroid Build Coastguard Worker SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
250*993b0882SAndroid Build Coastguard Worker << languages_.size() << ")";
251*993b0882SAndroid Build Coastguard Worker return LangId::kUnknownLanguageCode;
252*993b0882SAndroid Build Coastguard Worker }
253*993b0882SAndroid Build Coastguard Worker }
254*993b0882SAndroid Build Coastguard Worker
IsTooShort(const LightSentence & sentence) const255*993b0882SAndroid Build Coastguard Worker bool IsTooShort(const LightSentence &sentence) const {
256*993b0882SAndroid Build Coastguard Worker int text_size = 0;
257*993b0882SAndroid Build Coastguard Worker for (const std::string &token : sentence) {
258*993b0882SAndroid Build Coastguard Worker // Each token has the form ^...$: we subtract 2 because we want to count
259*993b0882SAndroid Build Coastguard Worker // only the real text, not the chars added by us.
260*993b0882SAndroid Build Coastguard Worker text_size += token.size() - 2;
261*993b0882SAndroid Build Coastguard Worker }
262*993b0882SAndroid Build Coastguard Worker return text_size < min_text_size_in_bytes_;
263*993b0882SAndroid Build Coastguard Worker }
264*993b0882SAndroid Build Coastguard Worker
265*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ModelProvider> model_provider_;
266*993b0882SAndroid Build Coastguard Worker
267*993b0882SAndroid Build Coastguard Worker TokenizerForLangId tokenizer_;
268*993b0882SAndroid Build Coastguard Worker
269*993b0882SAndroid Build Coastguard Worker EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
270*993b0882SAndroid Build Coastguard Worker lang_id_brain_interface_;
271*993b0882SAndroid Build Coastguard Worker
272*993b0882SAndroid Build Coastguard Worker // Neural network to use for scoring.
273*993b0882SAndroid Build Coastguard Worker std::unique_ptr<EmbeddingNetwork> network_;
274*993b0882SAndroid Build Coastguard Worker
275*993b0882SAndroid Build Coastguard Worker // True if this object is ready to perform language predictions.
276*993b0882SAndroid Build Coastguard Worker bool valid_ = false;
277*993b0882SAndroid Build Coastguard Worker
278*993b0882SAndroid Build Coastguard Worker // The model returns LangId::kUnknownLanguageCode for input text that has
279*993b0882SAndroid Build Coastguard Worker // fewer than min_text_size_in_bytes_ bytes (excluding ASCII whitespaces,
280*993b0882SAndroid Build Coastguard Worker // digits, and punctuation).
281*993b0882SAndroid Build Coastguard Worker int min_text_size_in_bytes_ = 0;
282*993b0882SAndroid Build Coastguard Worker
283*993b0882SAndroid Build Coastguard Worker // Only predictions with a probability (confidence) above this threshold are
284*993b0882SAndroid Build Coastguard Worker // reported. Otherwise, we report LangId::kUnknownLanguageCode.
285*993b0882SAndroid Build Coastguard Worker float default_threshold_ = kDefaultConfidenceThreshold;
286*993b0882SAndroid Build Coastguard Worker
287*993b0882SAndroid Build Coastguard Worker std::unordered_map<std::string, float> per_lang_thresholds_;
288*993b0882SAndroid Build Coastguard Worker
289*993b0882SAndroid Build Coastguard Worker // Recognized languages: softmax label i means languages_[i] (something like
290*993b0882SAndroid Build Coastguard Worker // "en", "fr", "ru", etc).
291*993b0882SAndroid Build Coastguard Worker std::vector<std::string> languages_;
292*993b0882SAndroid Build Coastguard Worker
293*993b0882SAndroid Build Coastguard Worker // Version of the model used by this LangIdImpl object. Zero means that the
294*993b0882SAndroid Build Coastguard Worker // model version could not be determined.
295*993b0882SAndroid Build Coastguard Worker int model_version_ = 0;
296*993b0882SAndroid Build Coastguard Worker };
297*993b0882SAndroid Build Coastguard Worker
298*993b0882SAndroid Build Coastguard Worker const char LangId::kUnknownLanguageCode[] = "und";
299*993b0882SAndroid Build Coastguard Worker
LangId(std::unique_ptr<ModelProvider> model_provider)300*993b0882SAndroid Build Coastguard Worker LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
301*993b0882SAndroid Build Coastguard Worker : pimpl_(new LangIdImpl(std::move(model_provider))) {
302*993b0882SAndroid Build Coastguard Worker LangIdImpl::RegisterClasses();
303*993b0882SAndroid Build Coastguard Worker }
304*993b0882SAndroid Build Coastguard Worker
305*993b0882SAndroid Build Coastguard Worker LangId::~LangId() = default;
306*993b0882SAndroid Build Coastguard Worker
FindLanguage(const char * data,size_t num_bytes) const307*993b0882SAndroid Build Coastguard Worker std::string LangId::FindLanguage(const char *data, size_t num_bytes) const {
308*993b0882SAndroid Build Coastguard Worker StringPiece text(data, num_bytes);
309*993b0882SAndroid Build Coastguard Worker return pimpl_->FindLanguage(text);
310*993b0882SAndroid Build Coastguard Worker }
311*993b0882SAndroid Build Coastguard Worker
FindLanguages(const char * data,size_t num_bytes,LangIdResult * result,int max_results) const312*993b0882SAndroid Build Coastguard Worker void LangId::FindLanguages(const char *data, size_t num_bytes,
313*993b0882SAndroid Build Coastguard Worker LangIdResult *result, int max_results) const {
314*993b0882SAndroid Build Coastguard Worker SAFTM_DCHECK(result) << "LangIdResult must not be null.";
315*993b0882SAndroid Build Coastguard Worker StringPiece text(data, num_bytes);
316*993b0882SAndroid Build Coastguard Worker pimpl_->FindLanguages(text, result, max_results);
317*993b0882SAndroid Build Coastguard Worker }
318*993b0882SAndroid Build Coastguard Worker
is_valid() const319*993b0882SAndroid Build Coastguard Worker bool LangId::is_valid() const { return pimpl_->is_valid(); }
320*993b0882SAndroid Build Coastguard Worker
GetModelVersion() const321*993b0882SAndroid Build Coastguard Worker int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }
322*993b0882SAndroid Build Coastguard Worker
GetFloatProperty(const std::string & property,float default_value) const323*993b0882SAndroid Build Coastguard Worker float LangId::GetFloatProperty(const std::string &property,
324*993b0882SAndroid Build Coastguard Worker float default_value) const {
325*993b0882SAndroid Build Coastguard Worker return pimpl_->GetProperty<float, float>(property, default_value);
326*993b0882SAndroid Build Coastguard Worker }
327*993b0882SAndroid Build Coastguard Worker
328*993b0882SAndroid Build Coastguard Worker } // namespace lang_id
329*993b0882SAndroid Build Coastguard Worker } // namespace mobile
}  // namespace libtextclassifier3
331