1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/grammar/semantics/composer.h"
18
19 #include "utils/base/status_macros.h"
20 #include "utils/grammar/semantics/evaluators/arithmetic-eval.h"
21 #include "utils/grammar/semantics/evaluators/compose-eval.h"
22 #include "utils/grammar/semantics/evaluators/const-eval.h"
23 #include "utils/grammar/semantics/evaluators/constituent-eval.h"
24 #include "utils/grammar/semantics/evaluators/merge-values-eval.h"
25 #include "utils/grammar/semantics/evaluators/parse-number-eval.h"
26 #include "utils/grammar/semantics/evaluators/span-eval.h"
27
28 namespace libtextclassifier3::grammar {
29 namespace {
30
31 // Gathers all constituents of a rule and index them.
32 // The constituents are numbered in the rule construction. But consituents could
33 // be in optional parts of the rule and might not be present in a match.
34 // This finds all constituents that are present in a match and allows to
35 // retrieve them by their index.
GatherConstituents(const ParseTree * root)36 std::unordered_map<int, const ParseTree*> GatherConstituents(
37 const ParseTree* root) {
38 std::unordered_map<int, const ParseTree*> constituents;
39 Traverse(root, [root, &constituents](const ParseTree* node) {
40 switch (node->type) {
41 case ParseTree::Type::kMapping:
42 TC3_CHECK(node->IsUnaryRule());
43 constituents[static_cast<const MappingNode*>(node)->id] =
44 node->unary_rule_rhs();
45 return false;
46 case ParseTree::Type::kDefault:
47 // Continue traversal.
48 return true;
49 default:
50 // Don't continue the traversal if we are not at the root node.
51 // This could e.g. be an assertion node.
52 return (node == root);
53 }
54 });
55 return constituents;
56 }
57
58 } // namespace
59
SemanticComposer(const reflection::Schema * semantic_values_schema)60 SemanticComposer::SemanticComposer(
61 const reflection::Schema* semantic_values_schema) {
62 evaluators_.emplace(SemanticExpression_::Expression_ArithmeticExpression,
63 std::make_unique<ArithmeticExpressionEvaluator>(this));
64 evaluators_.emplace(SemanticExpression_::Expression_ConstituentExpression,
65 std::make_unique<ConstituentEvaluator>());
66 evaluators_.emplace(SemanticExpression_::Expression_ParseNumberExpression,
67 std::make_unique<ParseNumberEvaluator>(this));
68 evaluators_.emplace(SemanticExpression_::Expression_SpanAsStringExpression,
69 std::make_unique<SpanAsStringEvaluator>());
70 if (semantic_values_schema != nullptr) {
71 // Register semantic functions.
72 evaluators_.emplace(
73 SemanticExpression_::Expression_ComposeExpression,
74 std::make_unique<ComposeEvaluator>(this, semantic_values_schema));
75 evaluators_.emplace(
76 SemanticExpression_::Expression_ConstValueExpression,
77 std::make_unique<ConstEvaluator>(semantic_values_schema));
78 evaluators_.emplace(
79 SemanticExpression_::Expression_MergeValueExpression,
80 std::make_unique<MergeValuesEvaluator>(this, semantic_values_schema));
81 }
82 }
83
Eval(const TextContext & text_context,const Derivation & derivation,UnsafeArena * arena) const84 StatusOr<const SemanticValue*> SemanticComposer::Eval(
85 const TextContext& text_context, const Derivation& derivation,
86 UnsafeArena* arena) const {
87 if (!derivation.parse_tree->IsUnaryRule() ||
88 derivation.parse_tree->unary_rule_rhs()->type !=
89 ParseTree::Type::kExpression) {
90 return nullptr;
91 }
92 return Eval(text_context,
93 static_cast<const SemanticExpressionNode*>(
94 derivation.parse_tree->unary_rule_rhs()),
95 arena);
96 }
97
Eval(const TextContext & text_context,const SemanticExpressionNode * derivation,UnsafeArena * arena) const98 StatusOr<const SemanticValue*> SemanticComposer::Eval(
99 const TextContext& text_context, const SemanticExpressionNode* derivation,
100 UnsafeArena* arena) const {
101 // Evaluate constituents.
102 EvalContext context{&text_context, derivation};
103 for (const auto& [constituent_index, constituent] :
104 GatherConstituents(derivation)) {
105 if (constituent->type == ParseTree::Type::kExpression) {
106 TC3_ASSIGN_OR_RETURN(
107 context.rule_constituents[constituent_index],
108 Eval(text_context,
109 static_cast<const SemanticExpressionNode*>(constituent), arena));
110 } else {
111 // Just use the text of the constituent if no semantic expression was
112 // defined.
113 context.rule_constituents[constituent_index] = SemanticValue::Create(
114 text_context.Span(constituent->codepoint_span), arena);
115 }
116 }
117 return Apply(context, derivation->expression, arena);
118 }
119
Apply(const EvalContext & context,const SemanticExpression * expression,UnsafeArena * arena) const120 StatusOr<const SemanticValue*> SemanticComposer::Apply(
121 const EvalContext& context, const SemanticExpression* expression,
122 UnsafeArena* arena) const {
123 const auto handler_it = evaluators_.find(expression->expression_type());
124 if (handler_it == evaluators_.end()) {
125 return Status(StatusCode::INVALID_ARGUMENT,
126 std::string("Unhandled expression type: ") +
127 EnumNameExpression(expression->expression_type()));
128 }
129 return handler_it->second->Apply(context, expression, arena);
130 }
131
132 } // namespace libtextclassifier3::grammar
133