/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "annotator/translate/translate.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include <algorithm>
20*993b0882SAndroid Build Coastguard Worker #include <memory>
21*993b0882SAndroid Build Coastguard Worker
22*993b0882SAndroid Build Coastguard Worker #include "annotator/collections.h"
23*993b0882SAndroid Build Coastguard Worker #include "annotator/entity-data_generated.h"
24*993b0882SAndroid Build Coastguard Worker #include "annotator/model_generated.h"
25*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h"
26*993b0882SAndroid Build Coastguard Worker #include "lang_id/lang-id-wrapper.h"
27*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
28*993b0882SAndroid Build Coastguard Worker #include "utils/i18n/locale.h"
29*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unicodetext.h"
30*993b0882SAndroid Build Coastguard Worker #include "lang_id/lang-id.h"
31*993b0882SAndroid Build Coastguard Worker
32*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
33*993b0882SAndroid Build Coastguard Worker
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,const std::string & user_familiar_language_tags,ClassificationResult * classification_result) const34*993b0882SAndroid Build Coastguard Worker bool TranslateAnnotator::ClassifyText(
35*993b0882SAndroid Build Coastguard Worker const UnicodeText& context, CodepointSpan selection_indices,
36*993b0882SAndroid Build Coastguard Worker const std::string& user_familiar_language_tags,
37*993b0882SAndroid Build Coastguard Worker ClassificationResult* classification_result) const {
38*993b0882SAndroid Build Coastguard Worker if (!(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
39*993b0882SAndroid Build Coastguard Worker return false;
40*993b0882SAndroid Build Coastguard Worker }
41*993b0882SAndroid Build Coastguard Worker
42*993b0882SAndroid Build Coastguard Worker std::vector<TranslateAnnotator::LanguageConfidence> confidences;
43*993b0882SAndroid Build Coastguard Worker if (options_->algorithm() ==
44*993b0882SAndroid Build Coastguard Worker TranslateAnnotatorOptions_::Algorithm::Algorithm_BACKOFF) {
45*993b0882SAndroid Build Coastguard Worker if (options_->backoff_options() == nullptr) {
46*993b0882SAndroid Build Coastguard Worker TC3_LOG(WARNING) << "No backoff options specified. Returning.";
47*993b0882SAndroid Build Coastguard Worker return false;
48*993b0882SAndroid Build Coastguard Worker }
49*993b0882SAndroid Build Coastguard Worker confidences = BackoffDetectLanguages(context, selection_indices);
50*993b0882SAndroid Build Coastguard Worker }
51*993b0882SAndroid Build Coastguard Worker
52*993b0882SAndroid Build Coastguard Worker if (confidences.empty()) {
53*993b0882SAndroid Build Coastguard Worker return false;
54*993b0882SAndroid Build Coastguard Worker }
55*993b0882SAndroid Build Coastguard Worker
56*993b0882SAndroid Build Coastguard Worker std::vector<Locale> user_familiar_languages;
57*993b0882SAndroid Build Coastguard Worker if (!ParseLocales(user_familiar_language_tags, &user_familiar_languages)) {
58*993b0882SAndroid Build Coastguard Worker TC3_LOG(WARNING) << "Couldn't parse the user-understood languages.";
59*993b0882SAndroid Build Coastguard Worker return false;
60*993b0882SAndroid Build Coastguard Worker }
61*993b0882SAndroid Build Coastguard Worker if (user_familiar_languages.empty()) {
62*993b0882SAndroid Build Coastguard Worker TC3_VLOG(INFO) << "user_familiar_languages is not set, not suggesting "
63*993b0882SAndroid Build Coastguard Worker "translate action.";
64*993b0882SAndroid Build Coastguard Worker return false;
65*993b0882SAndroid Build Coastguard Worker }
66*993b0882SAndroid Build Coastguard Worker bool user_can_understand_language_of_text = false;
67*993b0882SAndroid Build Coastguard Worker for (const Locale& locale : user_familiar_languages) {
68*993b0882SAndroid Build Coastguard Worker if (locale.Language() == confidences[0].language) {
69*993b0882SAndroid Build Coastguard Worker user_can_understand_language_of_text = true;
70*993b0882SAndroid Build Coastguard Worker break;
71*993b0882SAndroid Build Coastguard Worker }
72*993b0882SAndroid Build Coastguard Worker }
73*993b0882SAndroid Build Coastguard Worker
74*993b0882SAndroid Build Coastguard Worker if (!user_can_understand_language_of_text) {
75*993b0882SAndroid Build Coastguard Worker classification_result->collection = Collections::Translate();
76*993b0882SAndroid Build Coastguard Worker classification_result->score = options_->score();
77*993b0882SAndroid Build Coastguard Worker classification_result->priority_score = options_->priority_score();
78*993b0882SAndroid Build Coastguard Worker classification_result->serialized_entity_data =
79*993b0882SAndroid Build Coastguard Worker CreateSerializedEntityData(confidences);
80*993b0882SAndroid Build Coastguard Worker return true;
81*993b0882SAndroid Build Coastguard Worker }
82*993b0882SAndroid Build Coastguard Worker
83*993b0882SAndroid Build Coastguard Worker return false;
84*993b0882SAndroid Build Coastguard Worker }
85*993b0882SAndroid Build Coastguard Worker
CreateSerializedEntityData(const std::vector<TranslateAnnotator::LanguageConfidence> & confidences) const86*993b0882SAndroid Build Coastguard Worker std::string TranslateAnnotator::CreateSerializedEntityData(
87*993b0882SAndroid Build Coastguard Worker const std::vector<TranslateAnnotator::LanguageConfidence>& confidences)
88*993b0882SAndroid Build Coastguard Worker const {
89*993b0882SAndroid Build Coastguard Worker EntityDataT entity_data;
90*993b0882SAndroid Build Coastguard Worker entity_data.translate.reset(new EntityData_::TranslateT());
91*993b0882SAndroid Build Coastguard Worker
92*993b0882SAndroid Build Coastguard Worker for (const LanguageConfidence& confidence : confidences) {
93*993b0882SAndroid Build Coastguard Worker EntityData_::Translate_::LanguagePredictionResultT*
94*993b0882SAndroid Build Coastguard Worker language_prediction_result =
95*993b0882SAndroid Build Coastguard Worker new EntityData_::Translate_::LanguagePredictionResultT();
96*993b0882SAndroid Build Coastguard Worker language_prediction_result->language_tag = confidence.language;
97*993b0882SAndroid Build Coastguard Worker language_prediction_result->confidence_score = confidence.confidence;
98*993b0882SAndroid Build Coastguard Worker entity_data.translate->language_prediction_results.emplace_back(
99*993b0882SAndroid Build Coastguard Worker language_prediction_result);
100*993b0882SAndroid Build Coastguard Worker }
101*993b0882SAndroid Build Coastguard Worker flatbuffers::FlatBufferBuilder builder;
102*993b0882SAndroid Build Coastguard Worker FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
103*993b0882SAndroid Build Coastguard Worker return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
104*993b0882SAndroid Build Coastguard Worker builder.GetSize());
105*993b0882SAndroid Build Coastguard Worker }
106*993b0882SAndroid Build Coastguard Worker
107*993b0882SAndroid Build Coastguard Worker std::vector<TranslateAnnotator::LanguageConfidence>
BackoffDetectLanguages(const UnicodeText & context,CodepointSpan selection_indices) const108*993b0882SAndroid Build Coastguard Worker TranslateAnnotator::BackoffDetectLanguages(
109*993b0882SAndroid Build Coastguard Worker const UnicodeText& context, CodepointSpan selection_indices) const {
110*993b0882SAndroid Build Coastguard Worker const float penalize_ratio = options_->backoff_options()->penalize_ratio();
111*993b0882SAndroid Build Coastguard Worker const int min_text_size = options_->backoff_options()->min_text_size();
112*993b0882SAndroid Build Coastguard Worker if (selection_indices.second - selection_indices.first < min_text_size &&
113*993b0882SAndroid Build Coastguard Worker penalize_ratio <= 0) {
114*993b0882SAndroid Build Coastguard Worker return {};
115*993b0882SAndroid Build Coastguard Worker }
116*993b0882SAndroid Build Coastguard Worker
117*993b0882SAndroid Build Coastguard Worker const UnicodeText entity =
118*993b0882SAndroid Build Coastguard Worker UnicodeText::Substring(context, selection_indices.first,
119*993b0882SAndroid Build Coastguard Worker selection_indices.second, /*do_copy=*/false);
120*993b0882SAndroid Build Coastguard Worker const std::vector<std::pair<std::string, float>> lang_id_result =
121*993b0882SAndroid Build Coastguard Worker langid::GetPredictions(langid_model_, entity.data(), entity.size_bytes());
122*993b0882SAndroid Build Coastguard Worker
123*993b0882SAndroid Build Coastguard Worker const float more_text_score_ratio =
124*993b0882SAndroid Build Coastguard Worker 1.0f - options_->backoff_options()->subject_text_score_ratio();
125*993b0882SAndroid Build Coastguard Worker std::vector<std::pair<std::string, float>> more_lang_id_results;
126*993b0882SAndroid Build Coastguard Worker if (more_text_score_ratio >= 0) {
127*993b0882SAndroid Build Coastguard Worker const UnicodeText entity_with_context = TokenAlignedSubstringAroundSpan(
128*993b0882SAndroid Build Coastguard Worker context, selection_indices, min_text_size);
129*993b0882SAndroid Build Coastguard Worker more_lang_id_results =
130*993b0882SAndroid Build Coastguard Worker langid::GetPredictions(langid_model_, entity_with_context.data(),
131*993b0882SAndroid Build Coastguard Worker entity_with_context.size_bytes());
132*993b0882SAndroid Build Coastguard Worker }
133*993b0882SAndroid Build Coastguard Worker
134*993b0882SAndroid Build Coastguard Worker const float subject_text_score_ratio =
135*993b0882SAndroid Build Coastguard Worker options_->backoff_options()->subject_text_score_ratio();
136*993b0882SAndroid Build Coastguard Worker
137*993b0882SAndroid Build Coastguard Worker std::map<std::string, float> result_map;
138*993b0882SAndroid Build Coastguard Worker for (const auto& [language, score] : lang_id_result) {
139*993b0882SAndroid Build Coastguard Worker result_map[language] = subject_text_score_ratio * score;
140*993b0882SAndroid Build Coastguard Worker }
141*993b0882SAndroid Build Coastguard Worker for (const auto& [language, score] : more_lang_id_results) {
142*993b0882SAndroid Build Coastguard Worker result_map[language] += more_text_score_ratio * score * penalize_ratio;
143*993b0882SAndroid Build Coastguard Worker }
144*993b0882SAndroid Build Coastguard Worker
145*993b0882SAndroid Build Coastguard Worker std::vector<TranslateAnnotator::LanguageConfidence> result;
146*993b0882SAndroid Build Coastguard Worker result.reserve(result_map.size());
147*993b0882SAndroid Build Coastguard Worker for (const auto& [key, value] : result_map) {
148*993b0882SAndroid Build Coastguard Worker result.push_back({key, value});
149*993b0882SAndroid Build Coastguard Worker }
150*993b0882SAndroid Build Coastguard Worker
151*993b0882SAndroid Build Coastguard Worker std::stable_sort(result.begin(), result.end(),
152*993b0882SAndroid Build Coastguard Worker [](const TranslateAnnotator::LanguageConfidence& a,
153*993b0882SAndroid Build Coastguard Worker const TranslateAnnotator::LanguageConfidence& b) {
154*993b0882SAndroid Build Coastguard Worker return a.confidence > b.confidence;
155*993b0882SAndroid Build Coastguard Worker });
156*993b0882SAndroid Build Coastguard Worker return result;
157*993b0882SAndroid Build Coastguard Worker }
158*993b0882SAndroid Build Coastguard Worker
159*993b0882SAndroid Build Coastguard Worker UnicodeText::const_iterator
FindIndexOfNextWhitespaceOrPunctuation(const UnicodeText & text,int start_index,int direction) const160*993b0882SAndroid Build Coastguard Worker TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation(
161*993b0882SAndroid Build Coastguard Worker const UnicodeText& text, int start_index, int direction) const {
162*993b0882SAndroid Build Coastguard Worker TC3_CHECK(direction == 1 || direction == -1);
163*993b0882SAndroid Build Coastguard Worker auto it = text.begin();
164*993b0882SAndroid Build Coastguard Worker std::advance(it, start_index);
165*993b0882SAndroid Build Coastguard Worker while (it > text.begin() && it < text.end()) {
166*993b0882SAndroid Build Coastguard Worker if (unilib_->IsWhitespace(*it) || unilib_->IsPunctuation(*it)) {
167*993b0882SAndroid Build Coastguard Worker break;
168*993b0882SAndroid Build Coastguard Worker }
169*993b0882SAndroid Build Coastguard Worker std::advance(it, direction);
170*993b0882SAndroid Build Coastguard Worker }
171*993b0882SAndroid Build Coastguard Worker return it;
172*993b0882SAndroid Build Coastguard Worker }
173*993b0882SAndroid Build Coastguard Worker
// Returns a non-copying substring of `text` that covers roughly
// `minimum_length` codepoints around the span `indices`, widened outwards to
// the nearest whitespace/punctuation on both sides.
//
//   * If `text` has fewer than `minimum_length` codepoints, the whole text is
//     returned.
//   * If the span is already at least `minimum_length` long, exactly the span
//     is returned.
//   * Otherwise a window of `minimum_length` codepoints is placed around the
//     span (clamped to the text bounds) and its edges are pushed outwards to
//     the next whitespace or punctuation.
UnicodeText TranslateAnnotator::TokenAlignedSubstringAroundSpan(
    const UnicodeText& text, CodepointSpan indices, int minimum_length) const {
  const int total_codepoints = text.size_codepoints();
  if (total_codepoints < minimum_length) {
    return UnicodeText(text, /*do_copy=*/false);
  }

  const int span_begin = indices.first;
  const int span_end = indices.second;
  const int span_length = span_end - span_begin;
  if (span_length >= minimum_length) {
    return UnicodeText::Substring(text, span_begin, span_end,
                                  /*do_copy=*/false);
  }

  // Put half of the missing length in front of the span, then clamp the
  // window so that it stays inside the text.
  const int padding_before = (minimum_length - span_length) / 2;
  const int latest_window_begin = total_codepoints - minimum_length;
  const int window_begin =
      std::max(0, std::min(span_begin - padding_before, latest_window_begin));
  const int window_end =
      std::min(total_codepoints, window_begin + minimum_length);

  auto substring_begin =
      FindIndexOfNextWhitespaceOrPunctuation(text, window_begin, -1);
  const auto substring_end =
      FindIndexOfNextWhitespaceOrPunctuation(text, window_end, 1);

  // After the backward scan `substring_begin` sits on a whitespace or
  // punctuation codepoint unless it reached the beginning of the text. If it
  // is whitespace, step past it so the result starts on actual text; a
  // punctuation codepoint is left in place.
  const bool begins_on_whitespace = substring_begin != substring_end &&
                                    unilib_->IsWhitespace(*substring_begin);
  if (begins_on_whitespace) {
    ++substring_begin;
  }

  return UnicodeText::Substring(substring_begin, substring_end,
                                /*do_copy=*/false);
}
206*993b0882SAndroid Build Coastguard Worker
207*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
208