xref: /aosp_15_r20/external/libtextclassifier/native/utils/grammar/semantics/composer.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/grammar/semantics/composer.h"
18 
19 #include "utils/base/status_macros.h"
20 #include "utils/grammar/semantics/evaluators/arithmetic-eval.h"
21 #include "utils/grammar/semantics/evaluators/compose-eval.h"
22 #include "utils/grammar/semantics/evaluators/const-eval.h"
23 #include "utils/grammar/semantics/evaluators/constituent-eval.h"
24 #include "utils/grammar/semantics/evaluators/merge-values-eval.h"
25 #include "utils/grammar/semantics/evaluators/parse-number-eval.h"
26 #include "utils/grammar/semantics/evaluators/span-eval.h"
27 
28 namespace libtextclassifier3::grammar {
29 namespace {
30 
31 // Gathers all constituents of a rule and index them.
32 // The constituents are numbered in the rule construction. But consituents could
33 // be in optional parts of the rule and might not be present in a match.
34 // This finds all constituents that are present in a match and allows to
35 // retrieve them by their index.
GatherConstituents(const ParseTree * root)36 std::unordered_map<int, const ParseTree*> GatherConstituents(
37     const ParseTree* root) {
38   std::unordered_map<int, const ParseTree*> constituents;
39   Traverse(root, [root, &constituents](const ParseTree* node) {
40     switch (node->type) {
41       case ParseTree::Type::kMapping:
42         TC3_CHECK(node->IsUnaryRule());
43         constituents[static_cast<const MappingNode*>(node)->id] =
44             node->unary_rule_rhs();
45         return false;
46       case ParseTree::Type::kDefault:
47         // Continue traversal.
48         return true;
49       default:
50         // Don't continue the traversal if we are not at the root node.
51         // This could e.g. be an assertion node.
52         return (node == root);
53     }
54   });
55   return constituents;
56 }
57 
58 }  // namespace
59 
SemanticComposer(const reflection::Schema * semantic_values_schema)60 SemanticComposer::SemanticComposer(
61     const reflection::Schema* semantic_values_schema) {
62   evaluators_.emplace(SemanticExpression_::Expression_ArithmeticExpression,
63                       std::make_unique<ArithmeticExpressionEvaluator>(this));
64   evaluators_.emplace(SemanticExpression_::Expression_ConstituentExpression,
65                       std::make_unique<ConstituentEvaluator>());
66   evaluators_.emplace(SemanticExpression_::Expression_ParseNumberExpression,
67                       std::make_unique<ParseNumberEvaluator>(this));
68   evaluators_.emplace(SemanticExpression_::Expression_SpanAsStringExpression,
69                       std::make_unique<SpanAsStringEvaluator>());
70   if (semantic_values_schema != nullptr) {
71     // Register semantic functions.
72     evaluators_.emplace(
73         SemanticExpression_::Expression_ComposeExpression,
74         std::make_unique<ComposeEvaluator>(this, semantic_values_schema));
75     evaluators_.emplace(
76         SemanticExpression_::Expression_ConstValueExpression,
77         std::make_unique<ConstEvaluator>(semantic_values_schema));
78     evaluators_.emplace(
79         SemanticExpression_::Expression_MergeValueExpression,
80         std::make_unique<MergeValuesEvaluator>(this, semantic_values_schema));
81   }
82 }
83 
Eval(const TextContext & text_context,const Derivation & derivation,UnsafeArena * arena) const84 StatusOr<const SemanticValue*> SemanticComposer::Eval(
85     const TextContext& text_context, const Derivation& derivation,
86     UnsafeArena* arena) const {
87   if (!derivation.parse_tree->IsUnaryRule() ||
88       derivation.parse_tree->unary_rule_rhs()->type !=
89           ParseTree::Type::kExpression) {
90     return nullptr;
91   }
92   return Eval(text_context,
93               static_cast<const SemanticExpressionNode*>(
94                   derivation.parse_tree->unary_rule_rhs()),
95               arena);
96 }
97 
Eval(const TextContext & text_context,const SemanticExpressionNode * derivation,UnsafeArena * arena) const98 StatusOr<const SemanticValue*> SemanticComposer::Eval(
99     const TextContext& text_context, const SemanticExpressionNode* derivation,
100     UnsafeArena* arena) const {
101   // Evaluate constituents.
102   EvalContext context{&text_context, derivation};
103   for (const auto& [constituent_index, constituent] :
104        GatherConstituents(derivation)) {
105     if (constituent->type == ParseTree::Type::kExpression) {
106       TC3_ASSIGN_OR_RETURN(
107           context.rule_constituents[constituent_index],
108           Eval(text_context,
109                static_cast<const SemanticExpressionNode*>(constituent), arena));
110     } else {
111       // Just use the text of the constituent if no semantic expression was
112       // defined.
113       context.rule_constituents[constituent_index] = SemanticValue::Create(
114           text_context.Span(constituent->codepoint_span), arena);
115     }
116   }
117   return Apply(context, derivation->expression, arena);
118 }
119 
Apply(const EvalContext & context,const SemanticExpression * expression,UnsafeArena * arena) const120 StatusOr<const SemanticValue*> SemanticComposer::Apply(
121     const EvalContext& context, const SemanticExpression* expression,
122     UnsafeArena* arena) const {
123   const auto handler_it = evaluators_.find(expression->expression_type());
124   if (handler_it == evaluators_.end()) {
125     return Status(StatusCode::INVALID_ARGUMENT,
126                   std::string("Unhandled expression type: ") +
127                       EnumNameExpression(expression->expression_type()));
128   }
129   return handler_it->second->Apply(context, expression, arena);
130 }
131 
132 }  // namespace libtextclassifier3::grammar
133