xref: /aosp_15_r20/external/libtextclassifier/native/utils/sentencepiece/encoder.h (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
18*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
19*993b0882SAndroid Build Coastguard Worker 
20*993b0882SAndroid Build Coastguard Worker #include <vector>
21*993b0882SAndroid Build Coastguard Worker 
22*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
23*993b0882SAndroid Build Coastguard Worker #include "utils/container/string-set.h"
24*993b0882SAndroid Build Coastguard Worker #include "utils/strings/stringpiece.h"
25*993b0882SAndroid Build Coastguard Worker 
26*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
27*993b0882SAndroid Build Coastguard Worker 
28*993b0882SAndroid Build Coastguard Worker // Encoder to segment/tokenize strings into pieces such that the sum of the
29*993b0882SAndroid Build Coastguard Worker // scores of the pieces used is maximized.
30*993b0882SAndroid Build Coastguard Worker class Encoder {
31*993b0882SAndroid Build Coastguard Worker  public:
32*993b0882SAndroid Build Coastguard Worker   // pieces: the list of valid sentence pieces represented as a string set, e.g.
33*993b0882SAndroid Build Coastguard Worker   //     a trie.
34*993b0882SAndroid Build Coastguard Worker   // num_pieces: the number of pieces in the trie.
35*993b0882SAndroid Build Coastguard Worker   // pieces_scores: the scores of the individual pieces.
36*993b0882SAndroid Build Coastguard Worker   // start_code: code that is used as encoding of the start of input.
37*993b0882SAndroid Build Coastguard Worker   // end_code: code that is used as encoding of the end of input.
38*993b0882SAndroid Build Coastguard Worker   // encoding_offset: value added to the sentence piece ids to make them
39*993b0882SAndroid Build Coastguard Worker   //     not interesecting with start_code and end_code.
40*993b0882SAndroid Build Coastguard Worker   // unknown_code: code that is used for out-of-dictionary characters.
41*993b0882SAndroid Build Coastguard Worker   // unknown_score: the penality score associated with the unknown code.
42*993b0882SAndroid Build Coastguard Worker   Encoder(const StringSet* pieces, const int num_pieces,
43*993b0882SAndroid Build Coastguard Worker           const float* pieces_scores, int start_code = 0, int end_code = 1,
44*993b0882SAndroid Build Coastguard Worker           int encoding_offset = 2, int unknown_code = -1,
45*993b0882SAndroid Build Coastguard Worker           float unknown_score = 0.f)
num_pieces_(num_pieces)46*993b0882SAndroid Build Coastguard Worker       : num_pieces_(num_pieces),
47*993b0882SAndroid Build Coastguard Worker         scores_(pieces_scores),
48*993b0882SAndroid Build Coastguard Worker         pieces_(pieces),
49*993b0882SAndroid Build Coastguard Worker         start_code_(start_code),
50*993b0882SAndroid Build Coastguard Worker         end_code_(end_code),
51*993b0882SAndroid Build Coastguard Worker         encoding_offset_(encoding_offset),
52*993b0882SAndroid Build Coastguard Worker         unknown_code_(unknown_code),
53*993b0882SAndroid Build Coastguard Worker         unknown_score_(unknown_score) {}
54*993b0882SAndroid Build Coastguard Worker 
55*993b0882SAndroid Build Coastguard Worker   // Segment the input so that the total score of the pieces used is maximized.
56*993b0882SAndroid Build Coastguard Worker   // This is a simplified implementation of the general Viterbi algorithm,
57*993b0882SAndroid Build Coastguard Worker   // assuming independence between individual pieces.
58*993b0882SAndroid Build Coastguard Worker   bool Encode(StringPiece normalized_text,
59*993b0882SAndroid Build Coastguard Worker               std::vector<int>* encoded_text) const;
60*993b0882SAndroid Build Coastguard Worker 
61*993b0882SAndroid Build Coastguard Worker  private:
62*993b0882SAndroid Build Coastguard Worker   // State in the dynamic programming algorithm.
63*993b0882SAndroid Build Coastguard Worker   struct SegmentationEntry {
64*993b0882SAndroid Build Coastguard Worker     // Accumulated score.
65*993b0882SAndroid Build Coastguard Worker     float score;
66*993b0882SAndroid Build Coastguard Worker 
67*993b0882SAndroid Build Coastguard Worker     // Position before last piece.
68*993b0882SAndroid Build Coastguard Worker     int previous_pos;
69*993b0882SAndroid Build Coastguard Worker 
70*993b0882SAndroid Build Coastguard Worker     // Last piece used.
71*993b0882SAndroid Build Coastguard Worker     int piece_id;
72*993b0882SAndroid Build Coastguard Worker 
73*993b0882SAndroid Build Coastguard Worker     // Total number of pieces used.
74*993b0882SAndroid Build Coastguard Worker     int num_pieces;
75*993b0882SAndroid Build Coastguard Worker   };
76*993b0882SAndroid Build Coastguard Worker 
77*993b0882SAndroid Build Coastguard Worker   const int num_pieces_;
78*993b0882SAndroid Build Coastguard Worker   const float* scores_;
79*993b0882SAndroid Build Coastguard Worker   const StringSet* pieces_;
80*993b0882SAndroid Build Coastguard Worker   const int start_code_;
81*993b0882SAndroid Build Coastguard Worker   const int end_code_;
82*993b0882SAndroid Build Coastguard Worker   const int encoding_offset_;
83*993b0882SAndroid Build Coastguard Worker   const int unknown_code_;
84*993b0882SAndroid Build Coastguard Worker   const int unknown_score_;
85*993b0882SAndroid Build Coastguard Worker };
86*993b0882SAndroid Build Coastguard Worker 
87*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
88*993b0882SAndroid Build Coastguard Worker 
89*993b0882SAndroid Build Coastguard Worker #endif  // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
90