xref: /aosp_15_r20/external/libtextclassifier/native/annotator/vocab/vocab-annotator-impl.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/vocab/vocab-annotator-impl.h"
18 
19 #include "annotator/feature-processor.h"
20 #include "annotator/model_generated.h"
21 #include "utils/base/logging.h"
22 #include "utils/optional.h"
23 #include "utils/strings/numbers.h"
24 
25 namespace libtextclassifier3 {
26 
VocabAnnotator(std::unique_ptr<VocabLevelTable> vocab_level_table,const std::vector<Locale> & triggering_locales,const FeatureProcessor & feature_processor,const UniLib & unilib,const VocabModel * model)27 VocabAnnotator::VocabAnnotator(
28     std::unique_ptr<VocabLevelTable> vocab_level_table,
29     const std::vector<Locale>& triggering_locales,
30     const FeatureProcessor& feature_processor, const UniLib& unilib,
31     const VocabModel* model)
32     : vocab_level_table_(std::move(vocab_level_table)),
33       triggering_locales_(triggering_locales),
34       feature_processor_(feature_processor),
35       unilib_(unilib),
36       model_(model) {}
37 
Create(const VocabModel * model,const FeatureProcessor & feature_processor,const UniLib & unilib)38 std::unique_ptr<VocabAnnotator> VocabAnnotator::Create(
39     const VocabModel* model, const FeatureProcessor& feature_processor,
40     const UniLib& unilib) {
41   std::unique_ptr<VocabLevelTable> vocab_lebel_table =
42       VocabLevelTable::Create(model);
43   if (vocab_lebel_table == nullptr) {
44     TC3_LOG(ERROR) << "Failed to create vocab level table.";
45     return nullptr;
46   }
47   std::vector<Locale> triggering_locales;
48   if (model->triggering_locales() &&
49       !ParseLocales(model->triggering_locales()->c_str(),
50                     &triggering_locales)) {
51     TC3_LOG(ERROR) << "Could not parse model supported locales.";
52     return nullptr;
53   }
54 
55   return std::unique_ptr<VocabAnnotator>(
56       new VocabAnnotator(std::move(vocab_lebel_table), triggering_locales,
57                          feature_processor, unilib, model));
58 }
59 
Annotate(const UnicodeText & context,const std::vector<Locale> detected_text_language_tags,bool trigger_on_beginner_words,std::vector<AnnotatedSpan> * results) const60 bool VocabAnnotator::Annotate(
61     const UnicodeText& context,
62     const std::vector<Locale> detected_text_language_tags,
63     bool trigger_on_beginner_words, std::vector<AnnotatedSpan>* results) const {
64   if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
65     return true;
66   }
67   std::vector<Token> tokens = feature_processor_.Tokenize(context);
68   for (const Token& token : tokens) {
69     ClassificationResult classification_result;
70     CodepointSpan stripped_span;
71     bool found = ClassifyTextInternal(
72         context, {token.start, token.end}, detected_text_language_tags,
73         trigger_on_beginner_words, &classification_result, &stripped_span);
74     if (found) {
75       results->push_back(AnnotatedSpan{stripped_span, {classification_result}});
76     }
77   }
78   return true;
79 }
80 
ClassifyText(const UnicodeText & context,CodepointSpan click,const std::vector<Locale> detected_text_language_tags,bool trigger_on_beginner_words,ClassificationResult * result) const81 bool VocabAnnotator::ClassifyText(
82     const UnicodeText& context, CodepointSpan click,
83     const std::vector<Locale> detected_text_language_tags,
84     bool trigger_on_beginner_words, ClassificationResult* result) const {
85   CodepointSpan stripped_span;
86   return ClassifyTextInternal(context, click, detected_text_language_tags,
87                               trigger_on_beginner_words, result,
88                               &stripped_span);
89 }
90 
ClassifyTextInternal(const UnicodeText & context,const CodepointSpan click,const std::vector<Locale> detected_text_language_tags,bool trigger_on_beginner_words,ClassificationResult * classification_result,CodepointSpan * classified_span) const91 bool VocabAnnotator::ClassifyTextInternal(
92     const UnicodeText& context, const CodepointSpan click,
93     const std::vector<Locale> detected_text_language_tags,
94     bool trigger_on_beginner_words, ClassificationResult* classification_result,
95     CodepointSpan* classified_span) const {
96   if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
97     return false;
98   }
99   if (vocab_level_table_ == nullptr) {
100     return false;
101   }
102 
103   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
104                                     triggering_locales_,
105                                     /*default_value=*/false)) {
106     return false;
107   }
108   const CodepointSpan stripped_span =
109       feature_processor_.StripBoundaryCodepoints(context,
110                                                  {click.first, click.second});
111   const UnicodeText stripped_token = UnicodeText::Substring(
112       context, stripped_span.first, stripped_span.second, /*do_copy=*/false);
113   const std::string lower_token =
114       unilib_.ToLowerText(stripped_token).ToUTF8String();
115 
116   const Optional<LookupResult> result = vocab_level_table_->Lookup(lower_token);
117   if (!result.has_value()) {
118     return false;
119   }
120   if (result.value().do_not_trigger_in_upper_case &&
121       unilib_.IsUpper(*stripped_token.begin())) {
122     TC3_VLOG(INFO) << "Not trigger define: proper noun in upper case.";
123     return false;
124   }
125   if (result.value().beginner_level && !trigger_on_beginner_words) {
126     TC3_VLOG(INFO) << "Not trigger define: for beginner only.";
127     return false;
128   }
129   *classification_result =
130       ClassificationResult("dictionary", model_->target_classification_score(),
131                            model_->priority_score());
132   *classified_span = stripped_span;
133 
134   return true;
135 }
136 }  // namespace libtextclassifier3
137