xref: /aosp_15_r20/external/libtextclassifier/native/annotator/translate/translate.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "annotator/translate/translate.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <algorithm>
20*993b0882SAndroid Build Coastguard Worker #include <memory>
21*993b0882SAndroid Build Coastguard Worker 
22*993b0882SAndroid Build Coastguard Worker #include "annotator/collections.h"
23*993b0882SAndroid Build Coastguard Worker #include "annotator/entity-data_generated.h"
24*993b0882SAndroid Build Coastguard Worker #include "annotator/model_generated.h"
25*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h"
26*993b0882SAndroid Build Coastguard Worker #include "lang_id/lang-id-wrapper.h"
27*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
28*993b0882SAndroid Build Coastguard Worker #include "utils/i18n/locale.h"
29*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unicodetext.h"
30*993b0882SAndroid Build Coastguard Worker #include "lang_id/lang-id.h"
31*993b0882SAndroid Build Coastguard Worker 
32*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
33*993b0882SAndroid Build Coastguard Worker 
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,const std::string & user_familiar_language_tags,ClassificationResult * classification_result) const34*993b0882SAndroid Build Coastguard Worker bool TranslateAnnotator::ClassifyText(
35*993b0882SAndroid Build Coastguard Worker     const UnicodeText& context, CodepointSpan selection_indices,
36*993b0882SAndroid Build Coastguard Worker     const std::string& user_familiar_language_tags,
37*993b0882SAndroid Build Coastguard Worker     ClassificationResult* classification_result) const {
38*993b0882SAndroid Build Coastguard Worker   if (!(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
39*993b0882SAndroid Build Coastguard Worker     return false;
40*993b0882SAndroid Build Coastguard Worker   }
41*993b0882SAndroid Build Coastguard Worker 
42*993b0882SAndroid Build Coastguard Worker   std::vector<TranslateAnnotator::LanguageConfidence> confidences;
43*993b0882SAndroid Build Coastguard Worker   if (options_->algorithm() ==
44*993b0882SAndroid Build Coastguard Worker       TranslateAnnotatorOptions_::Algorithm::Algorithm_BACKOFF) {
45*993b0882SAndroid Build Coastguard Worker     if (options_->backoff_options() == nullptr) {
46*993b0882SAndroid Build Coastguard Worker       TC3_LOG(WARNING) << "No backoff options specified. Returning.";
47*993b0882SAndroid Build Coastguard Worker       return false;
48*993b0882SAndroid Build Coastguard Worker     }
49*993b0882SAndroid Build Coastguard Worker     confidences = BackoffDetectLanguages(context, selection_indices);
50*993b0882SAndroid Build Coastguard Worker   }
51*993b0882SAndroid Build Coastguard Worker 
52*993b0882SAndroid Build Coastguard Worker   if (confidences.empty()) {
53*993b0882SAndroid Build Coastguard Worker     return false;
54*993b0882SAndroid Build Coastguard Worker   }
55*993b0882SAndroid Build Coastguard Worker 
56*993b0882SAndroid Build Coastguard Worker   std::vector<Locale> user_familiar_languages;
57*993b0882SAndroid Build Coastguard Worker   if (!ParseLocales(user_familiar_language_tags, &user_familiar_languages)) {
58*993b0882SAndroid Build Coastguard Worker     TC3_LOG(WARNING) << "Couldn't parse the user-understood languages.";
59*993b0882SAndroid Build Coastguard Worker     return false;
60*993b0882SAndroid Build Coastguard Worker   }
61*993b0882SAndroid Build Coastguard Worker   if (user_familiar_languages.empty()) {
62*993b0882SAndroid Build Coastguard Worker     TC3_VLOG(INFO) << "user_familiar_languages is not set, not suggesting "
63*993b0882SAndroid Build Coastguard Worker                       "translate action.";
64*993b0882SAndroid Build Coastguard Worker     return false;
65*993b0882SAndroid Build Coastguard Worker   }
66*993b0882SAndroid Build Coastguard Worker   bool user_can_understand_language_of_text = false;
67*993b0882SAndroid Build Coastguard Worker   for (const Locale& locale : user_familiar_languages) {
68*993b0882SAndroid Build Coastguard Worker     if (locale.Language() == confidences[0].language) {
69*993b0882SAndroid Build Coastguard Worker       user_can_understand_language_of_text = true;
70*993b0882SAndroid Build Coastguard Worker       break;
71*993b0882SAndroid Build Coastguard Worker     }
72*993b0882SAndroid Build Coastguard Worker   }
73*993b0882SAndroid Build Coastguard Worker 
74*993b0882SAndroid Build Coastguard Worker   if (!user_can_understand_language_of_text) {
75*993b0882SAndroid Build Coastguard Worker     classification_result->collection = Collections::Translate();
76*993b0882SAndroid Build Coastguard Worker     classification_result->score = options_->score();
77*993b0882SAndroid Build Coastguard Worker     classification_result->priority_score = options_->priority_score();
78*993b0882SAndroid Build Coastguard Worker     classification_result->serialized_entity_data =
79*993b0882SAndroid Build Coastguard Worker         CreateSerializedEntityData(confidences);
80*993b0882SAndroid Build Coastguard Worker     return true;
81*993b0882SAndroid Build Coastguard Worker   }
82*993b0882SAndroid Build Coastguard Worker 
83*993b0882SAndroid Build Coastguard Worker   return false;
84*993b0882SAndroid Build Coastguard Worker }
85*993b0882SAndroid Build Coastguard Worker 
CreateSerializedEntityData(const std::vector<TranslateAnnotator::LanguageConfidence> & confidences) const86*993b0882SAndroid Build Coastguard Worker std::string TranslateAnnotator::CreateSerializedEntityData(
87*993b0882SAndroid Build Coastguard Worker     const std::vector<TranslateAnnotator::LanguageConfidence>& confidences)
88*993b0882SAndroid Build Coastguard Worker     const {
89*993b0882SAndroid Build Coastguard Worker   EntityDataT entity_data;
90*993b0882SAndroid Build Coastguard Worker   entity_data.translate.reset(new EntityData_::TranslateT());
91*993b0882SAndroid Build Coastguard Worker 
92*993b0882SAndroid Build Coastguard Worker   for (const LanguageConfidence& confidence : confidences) {
93*993b0882SAndroid Build Coastguard Worker     EntityData_::Translate_::LanguagePredictionResultT*
94*993b0882SAndroid Build Coastguard Worker         language_prediction_result =
95*993b0882SAndroid Build Coastguard Worker             new EntityData_::Translate_::LanguagePredictionResultT();
96*993b0882SAndroid Build Coastguard Worker     language_prediction_result->language_tag = confidence.language;
97*993b0882SAndroid Build Coastguard Worker     language_prediction_result->confidence_score = confidence.confidence;
98*993b0882SAndroid Build Coastguard Worker     entity_data.translate->language_prediction_results.emplace_back(
99*993b0882SAndroid Build Coastguard Worker         language_prediction_result);
100*993b0882SAndroid Build Coastguard Worker   }
101*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
102*993b0882SAndroid Build Coastguard Worker   FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
103*993b0882SAndroid Build Coastguard Worker   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
104*993b0882SAndroid Build Coastguard Worker                      builder.GetSize());
105*993b0882SAndroid Build Coastguard Worker }
106*993b0882SAndroid Build Coastguard Worker 
107*993b0882SAndroid Build Coastguard Worker std::vector<TranslateAnnotator::LanguageConfidence>
BackoffDetectLanguages(const UnicodeText & context,CodepointSpan selection_indices) const108*993b0882SAndroid Build Coastguard Worker TranslateAnnotator::BackoffDetectLanguages(
109*993b0882SAndroid Build Coastguard Worker     const UnicodeText& context, CodepointSpan selection_indices) const {
110*993b0882SAndroid Build Coastguard Worker   const float penalize_ratio = options_->backoff_options()->penalize_ratio();
111*993b0882SAndroid Build Coastguard Worker   const int min_text_size = options_->backoff_options()->min_text_size();
112*993b0882SAndroid Build Coastguard Worker   if (selection_indices.second - selection_indices.first < min_text_size &&
113*993b0882SAndroid Build Coastguard Worker       penalize_ratio <= 0) {
114*993b0882SAndroid Build Coastguard Worker     return {};
115*993b0882SAndroid Build Coastguard Worker   }
116*993b0882SAndroid Build Coastguard Worker 
117*993b0882SAndroid Build Coastguard Worker   const UnicodeText entity =
118*993b0882SAndroid Build Coastguard Worker       UnicodeText::Substring(context, selection_indices.first,
119*993b0882SAndroid Build Coastguard Worker                              selection_indices.second, /*do_copy=*/false);
120*993b0882SAndroid Build Coastguard Worker   const std::vector<std::pair<std::string, float>> lang_id_result =
121*993b0882SAndroid Build Coastguard Worker       langid::GetPredictions(langid_model_, entity.data(), entity.size_bytes());
122*993b0882SAndroid Build Coastguard Worker 
123*993b0882SAndroid Build Coastguard Worker   const float more_text_score_ratio =
124*993b0882SAndroid Build Coastguard Worker       1.0f - options_->backoff_options()->subject_text_score_ratio();
125*993b0882SAndroid Build Coastguard Worker   std::vector<std::pair<std::string, float>> more_lang_id_results;
126*993b0882SAndroid Build Coastguard Worker   if (more_text_score_ratio >= 0) {
127*993b0882SAndroid Build Coastguard Worker     const UnicodeText entity_with_context = TokenAlignedSubstringAroundSpan(
128*993b0882SAndroid Build Coastguard Worker         context, selection_indices, min_text_size);
129*993b0882SAndroid Build Coastguard Worker     more_lang_id_results =
130*993b0882SAndroid Build Coastguard Worker         langid::GetPredictions(langid_model_, entity_with_context.data(),
131*993b0882SAndroid Build Coastguard Worker                                entity_with_context.size_bytes());
132*993b0882SAndroid Build Coastguard Worker   }
133*993b0882SAndroid Build Coastguard Worker 
134*993b0882SAndroid Build Coastguard Worker   const float subject_text_score_ratio =
135*993b0882SAndroid Build Coastguard Worker       options_->backoff_options()->subject_text_score_ratio();
136*993b0882SAndroid Build Coastguard Worker 
137*993b0882SAndroid Build Coastguard Worker   std::map<std::string, float> result_map;
138*993b0882SAndroid Build Coastguard Worker   for (const auto& [language, score] : lang_id_result) {
139*993b0882SAndroid Build Coastguard Worker     result_map[language] = subject_text_score_ratio * score;
140*993b0882SAndroid Build Coastguard Worker   }
141*993b0882SAndroid Build Coastguard Worker   for (const auto& [language, score] : more_lang_id_results) {
142*993b0882SAndroid Build Coastguard Worker     result_map[language] += more_text_score_ratio * score * penalize_ratio;
143*993b0882SAndroid Build Coastguard Worker   }
144*993b0882SAndroid Build Coastguard Worker 
145*993b0882SAndroid Build Coastguard Worker   std::vector<TranslateAnnotator::LanguageConfidence> result;
146*993b0882SAndroid Build Coastguard Worker   result.reserve(result_map.size());
147*993b0882SAndroid Build Coastguard Worker   for (const auto& [key, value] : result_map) {
148*993b0882SAndroid Build Coastguard Worker     result.push_back({key, value});
149*993b0882SAndroid Build Coastguard Worker   }
150*993b0882SAndroid Build Coastguard Worker 
151*993b0882SAndroid Build Coastguard Worker   std::stable_sort(result.begin(), result.end(),
152*993b0882SAndroid Build Coastguard Worker                    [](const TranslateAnnotator::LanguageConfidence& a,
153*993b0882SAndroid Build Coastguard Worker                       const TranslateAnnotator::LanguageConfidence& b) {
154*993b0882SAndroid Build Coastguard Worker                      return a.confidence > b.confidence;
155*993b0882SAndroid Build Coastguard Worker                    });
156*993b0882SAndroid Build Coastguard Worker   return result;
157*993b0882SAndroid Build Coastguard Worker }
158*993b0882SAndroid Build Coastguard Worker 
159*993b0882SAndroid Build Coastguard Worker UnicodeText::const_iterator
FindIndexOfNextWhitespaceOrPunctuation(const UnicodeText & text,int start_index,int direction) const160*993b0882SAndroid Build Coastguard Worker TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation(
161*993b0882SAndroid Build Coastguard Worker     const UnicodeText& text, int start_index, int direction) const {
162*993b0882SAndroid Build Coastguard Worker   TC3_CHECK(direction == 1 || direction == -1);
163*993b0882SAndroid Build Coastguard Worker   auto it = text.begin();
164*993b0882SAndroid Build Coastguard Worker   std::advance(it, start_index);
165*993b0882SAndroid Build Coastguard Worker   while (it > text.begin() && it < text.end()) {
166*993b0882SAndroid Build Coastguard Worker     if (unilib_->IsWhitespace(*it) || unilib_->IsPunctuation(*it)) {
167*993b0882SAndroid Build Coastguard Worker       break;
168*993b0882SAndroid Build Coastguard Worker     }
169*993b0882SAndroid Build Coastguard Worker     std::advance(it, direction);
170*993b0882SAndroid Build Coastguard Worker   }
171*993b0882SAndroid Build Coastguard Worker   return it;
172*993b0882SAndroid Build Coastguard Worker }
173*993b0882SAndroid Build Coastguard Worker 
// Returns a non-copying view of `text` around the span `indices`, widened to
// roughly `minimum_length` code points and then extended outwards to the
// nearest whitespace/punctuation boundaries.
//
//   * If the whole text is shorter than `minimum_length`, the whole text is
//     returned.
//   * If the span is already at least `minimum_length` long, exactly the span
//     is returned.
//   * Otherwise a window of `minimum_length` code points is placed around the
//     span (clamped to the text) and both ends are moved outwards to the next
//     whitespace or punctuation code point.
UnicodeText TranslateAnnotator::TokenAlignedSubstringAroundSpan(
    const UnicodeText& text, CodepointSpan indices, int minimum_length) const {
  const int total_codepoints = text.size_codepoints();
  if (total_codepoints < minimum_length) {
    return UnicodeText(text, /*do_copy=*/false);
  }

  const int span_begin = indices.first;
  const int span_end = indices.second;
  const int span_length = span_end - span_begin;
  if (span_length >= minimum_length) {
    return UnicodeText::Substring(text, span_begin, span_end,
                                  /*do_copy=*/false);
  }

  // Put half of the missing length before the span, keeping the window inside
  // the text.
  const int padding_before = (minimum_length - span_length) / 2;
  const int latest_window_begin = total_codepoints - minimum_length;
  const int window_begin =
      std::max(0, std::min(span_begin - padding_before, latest_window_begin));
  const int window_end =
      std::min(total_codepoints, window_begin + minimum_length);

  auto substring_begin =
      FindIndexOfNextWhitespaceOrPunctuation(text, window_begin, -1);
  const auto substring_end =
      FindIndexOfNextWhitespaceOrPunctuation(text, window_end, 1);

  // The backward search stops on a separator unless it ran into the start of
  // the text. A leading whitespace code point is dropped from the result; a
  // leading punctuation code point is kept.
  const bool begins_on_whitespace =
      substring_begin != substring_end &&
      unilib_->IsWhitespace(*substring_begin);
  if (begins_on_whitespace) {
    ++substring_begin;
  }

  return UnicodeText::Substring(substring_begin, substring_end,
                                /*do_copy=*/false);
}
206*993b0882SAndroid Build Coastguard Worker 
207*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
208