xref: /aosp_15_r20/external/cronet/third_party/icu/source/common/mlbe.cpp (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // © 2022 and later: Unicode, Inc. and others.
2 // License & terms of use: http://www.unicode.org/copyright.html
3 
4 #include "unicode/utypes.h"
5 
6 #if !UCONFIG_NO_BREAK_ITERATION
7 
8 #include "cmemory.h"
9 #include "mlbe.h"
10 #include "uassert.h"
11 #include "ubrkimpl.h"
12 #include "unicode/resbund.h"
13 #include "unicode/udata.h"
14 #include "unicode/utf16.h"
15 #include "uresimp.h"
16 #include "util.h"
17 #include "uvectr32.h"
18 
19 U_NAMESPACE_BEGIN
20 
// Index in fModel of the first table of each feature family. evaluateBreakpoint() reads six
// tables starting at kUWStart (UW1-UW6), three starting at kBWStart (BW1-BW3) and four starting
// at kTWStart (TW1-TW4); loadMLModel() fills the tables in that same order.
enum class ModelIndex {
    kUWStart = 0,
    kBWStart = 6,
    kTWStart = 9
};
22 
MlBreakEngine(const UnicodeSet & digitOrOpenPunctuationOrAlphabetSet,const UnicodeSet & closePunctuationSet,UErrorCode & status)23 MlBreakEngine::MlBreakEngine(const UnicodeSet &digitOrOpenPunctuationOrAlphabetSet,
24                              const UnicodeSet &closePunctuationSet, UErrorCode &status)
25     : fDigitOrOpenPunctuationOrAlphabetSet(digitOrOpenPunctuationOrAlphabetSet),
26       fClosePunctuationSet(closePunctuationSet),
27       fNegativeSum(0) {
28     if (U_FAILURE(status)) {
29         return;
30     }
31     loadMLModel(status);
32 }
33 
~MlBreakEngine()34 MlBreakEngine::~MlBreakEngine() {}
35 
divideUpRange(UText * inText,int32_t rangeStart,int32_t rangeEnd,UVector32 & foundBreaks,const UnicodeString & inString,const LocalPointer<UVector32> & inputMap,UErrorCode & status) const36 int32_t MlBreakEngine::divideUpRange(UText *inText, int32_t rangeStart, int32_t rangeEnd,
37                                      UVector32 &foundBreaks, const UnicodeString &inString,
38                                      const LocalPointer<UVector32> &inputMap,
39                                      UErrorCode &status) const {
40     if (U_FAILURE(status)) {
41         return 0;
42     }
43     if (rangeStart >= rangeEnd) {
44         status = U_ILLEGAL_ARGUMENT_ERROR;
45         return 0;
46     }
47 
48     UVector32 boundary(inString.countChar32() + 1, status);
49     if (U_FAILURE(status)) {
50         return 0;
51     }
52     int32_t numBreaks = 0;
53     int32_t codePointLength = inString.countChar32();
54     // The ML algorithm groups six char and evaluates whether the 4th char is a breakpoint.
55     // In each iteration, it evaluates the 4th char and then moves forward one char like a sliding
56     // window. Initially, the first six values in the indexList are [-1, -1, 0, 1, 2, 3]. After
57     // moving forward, finally the last six values in the indexList are
58     // [length-4, length-3, length-2, length-1, -1, -1]. The "+4" here means four extra "-1".
59     int32_t indexSize = codePointLength + 4;
60     int32_t *indexList = (int32_t *)uprv_malloc(indexSize * sizeof(int32_t));
61     if (indexList == nullptr) {
62         status = U_MEMORY_ALLOCATION_ERROR;
63         return 0;
64     }
65     int32_t numCodeUnits = initIndexList(inString, indexList, status);
66 
67     // Add a break for the start.
68     boundary.addElement(0, status);
69     numBreaks++;
70     if (U_FAILURE(status)) return 0;
71 
72     for (int32_t idx = 0; idx + 1 < codePointLength && U_SUCCESS(status); idx++) {
73         numBreaks =
74             evaluateBreakpoint(inString, indexList, idx, numCodeUnits, numBreaks, boundary, status);
75         if (idx + 4 < codePointLength) {
76             indexList[idx + 6] = numCodeUnits;
77             numCodeUnits += U16_LENGTH(inString.char32At(indexList[idx + 6]));
78         }
79     }
80     uprv_free(indexList);
81 
82     if (U_FAILURE(status)) return 0;
83 
84     // Add a break for the end if there is not one there already.
85     if (boundary.lastElementi() != inString.countChar32()) {
86         boundary.addElement(inString.countChar32(), status);
87         numBreaks++;
88     }
89 
90     int32_t prevCPPos = -1;
91     int32_t prevUTextPos = -1;
92     int32_t correctedNumBreaks = 0;
93     for (int32_t i = 0; i < numBreaks; i++) {
94         int32_t cpPos = boundary.elementAti(i);
95         int32_t utextPos = inputMap.isValid() ? inputMap->elementAti(cpPos) : cpPos + rangeStart;
96         U_ASSERT(cpPos > prevCPPos);
97         U_ASSERT(utextPos >= prevUTextPos);
98 
99         if (utextPos > prevUTextPos) {
100             if (utextPos != rangeStart ||
101                 (utextPos > 0 &&
102                  fClosePunctuationSet.contains(utext_char32At(inText, utextPos - 1)))) {
103                 foundBreaks.push(utextPos, status);
104                 correctedNumBreaks++;
105             }
106         } else {
107             // Normalization expanded the input text, the dictionary found a boundary
108             // within the expansion, giving two boundaries with the same index in the
109             // original text. Ignore the second. See ticket #12918.
110             --numBreaks;
111         }
112         prevCPPos = cpPos;
113         prevUTextPos = utextPos;
114     }
115     (void)prevCPPos;  // suppress compiler warnings about unused variable
116 
117     UChar32 nextChar = utext_char32At(inText, rangeEnd);
118     if (!foundBreaks.isEmpty() && foundBreaks.peeki() == rangeEnd) {
119         // In phrase breaking, there has to be a breakpoint between Cj character and
120         // the number/open punctuation.
121         // E.g. る文字「そうだ、京都」->る▁文字▁「そうだ、▁京都」-> breakpoint between 字 and「
122         // E.g. 乗車率90%程度だろうか -> 乗車▁率▁90%▁程度だろうか -> breakpoint between 率 and 9
123         // E.g. しかもロゴがUnicode! -> しかも▁ロゴが▁Unicode!-> breakpoint between が and U
124         if (!fDigitOrOpenPunctuationOrAlphabetSet.contains(nextChar)) {
125             foundBreaks.popi();
126             correctedNumBreaks--;
127         }
128     }
129 
130     return correctedNumBreaks;
131 }
132 
evaluateBreakpoint(const UnicodeString & inString,int32_t * indexList,int32_t startIdx,int32_t numCodeUnits,int32_t numBreaks,UVector32 & boundary,UErrorCode & status) const133 int32_t MlBreakEngine::evaluateBreakpoint(const UnicodeString &inString, int32_t *indexList,
134                                           int32_t startIdx, int32_t numCodeUnits, int32_t numBreaks,
135                                           UVector32 &boundary, UErrorCode &status) const {
136     if (U_FAILURE(status)) {
137         return numBreaks;
138     }
139     int32_t start = 0, end = 0;
140     int32_t score = fNegativeSum;
141 
142     for (int i = 0; i < 6; i++) {
143         // UW1 ~ UW6
144         start = startIdx + i;
145         if (indexList[start] != -1) {
146             end = (indexList[start + 1] != -1) ? indexList[start + 1] : numCodeUnits;
147             score += fModel[static_cast<int32_t>(ModelIndex::kUWStart) + i].geti(
148                 inString.tempSubString(indexList[start], end - indexList[start]));
149         }
150     }
151     for (int i = 0; i < 3; i++) {
152         // BW1 ~ BW3
153         start = startIdx + i + 1;
154         if (indexList[start] != -1 && indexList[start + 1] != -1) {
155             end = (indexList[start + 2] != -1) ? indexList[start + 2] : numCodeUnits;
156             score += fModel[static_cast<int32_t>(ModelIndex::kBWStart) + i].geti(
157                 inString.tempSubString(indexList[start], end - indexList[start]));
158         }
159     }
160     for (int i = 0; i < 4; i++) {
161         // TW1 ~ TW4
162         start = startIdx + i;
163         if (indexList[start] != -1 && indexList[start + 1] != -1 && indexList[start + 2] != -1) {
164             end = (indexList[start + 3] != -1) ? indexList[start + 3] : numCodeUnits;
165             score += fModel[static_cast<int32_t>(ModelIndex::kTWStart) + i].geti(
166                 inString.tempSubString(indexList[start], end - indexList[start]));
167         }
168     }
169 
170     if (score > 0) {
171         boundary.addElement(startIdx + 1, status);
172         numBreaks++;
173     }
174     return numBreaks;
175 }
176 
initIndexList(const UnicodeString & inString,int32_t * indexList,UErrorCode & status) const177 int32_t MlBreakEngine::initIndexList(const UnicodeString &inString, int32_t *indexList,
178                                      UErrorCode &status) const {
179     if (U_FAILURE(status)) {
180         return 0;
181     }
182     int32_t index = 0;
183     int32_t length = inString.countChar32();
184     // Set all (lenght+4) items inside indexLength to -1 presuming -1 is 4 bytes of 0xff.
185     uprv_memset(indexList, 0xff, (length + 4) * sizeof(int32_t));
186     if (length > 0) {
187         indexList[2] = 0;
188         index = U16_LENGTH(inString.char32At(0));
189         if (length > 1) {
190             indexList[3] = index;
191             index += U16_LENGTH(inString.char32At(index));
192             if (length > 2) {
193                 indexList[4] = index;
194                 index += U16_LENGTH(inString.char32At(index));
195                 if (length > 3) {
196                     indexList[5] = index;
197                     index += U16_LENGTH(inString.char32At(index));
198                 }
199             }
200         }
201     }
202     return index;
203 }
204 
loadMLModel(UErrorCode & error)205 void MlBreakEngine::loadMLModel(UErrorCode &error) {
206     // BudouX's model consists of thirteen categories, each of which is make up of pairs of the
207     // feature and its score. As integrating it into jaml.txt, we define thirteen kinds of key and
208     // value to represent the feature and the corresponding score respectively.
209 
210     if (U_FAILURE(error)) return;
211 
212     UnicodeString key;
213     StackUResourceBundle stackTempBundle;
214     ResourceDataValue modelKey;
215 
216     LocalUResourceBundlePointer rbp(ures_openDirect(U_ICUDATA_BRKITR, "jaml", &error));
217     UResourceBundle *rb = rbp.getAlias();
218     if (U_FAILURE(error)) return;
219 
220     int32_t index = 0;
221     initKeyValue(rb, "UW1Keys", "UW1Values", fModel[index++], error);
222     initKeyValue(rb, "UW2Keys", "UW2Values", fModel[index++], error);
223     initKeyValue(rb, "UW3Keys", "UW3Values", fModel[index++], error);
224     initKeyValue(rb, "UW4Keys", "UW4Values", fModel[index++], error);
225     initKeyValue(rb, "UW5Keys", "UW5Values", fModel[index++], error);
226     initKeyValue(rb, "UW6Keys", "UW6Values", fModel[index++], error);
227     initKeyValue(rb, "BW1Keys", "BW1Values", fModel[index++], error);
228     initKeyValue(rb, "BW2Keys", "BW2Values", fModel[index++], error);
229     initKeyValue(rb, "BW3Keys", "BW3Values", fModel[index++], error);
230     initKeyValue(rb, "TW1Keys", "TW1Values", fModel[index++], error);
231     initKeyValue(rb, "TW2Keys", "TW2Values", fModel[index++], error);
232     initKeyValue(rb, "TW3Keys", "TW3Values", fModel[index++], error);
233     initKeyValue(rb, "TW4Keys", "TW4Values", fModel[index++], error);
234     fNegativeSum /= 2;
235 }
236 
initKeyValue(UResourceBundle * rb,const char * keyName,const char * valueName,Hashtable & model,UErrorCode & error)237 void MlBreakEngine::initKeyValue(UResourceBundle *rb, const char *keyName, const char *valueName,
238                                  Hashtable &model, UErrorCode &error) {
239     int32_t keySize = 0;
240     int32_t valueSize = 0;
241     int32_t stringLength = 0;
242     UnicodeString key;
243     StackUResourceBundle stackTempBundle;
244     ResourceDataValue modelKey;
245 
246     // get modelValues
247     LocalUResourceBundlePointer modelValue(ures_getByKey(rb, valueName, nullptr, &error));
248     const int32_t *value = ures_getIntVector(modelValue.getAlias(), &valueSize, &error);
249     if (U_FAILURE(error)) return;
250 
251     // get modelKeys
252     ures_getValueWithFallback(rb, keyName, stackTempBundle.getAlias(), modelKey, error);
253     ResourceArray stringArray = modelKey.getArray(error);
254     keySize = stringArray.getSize();
255     if (U_FAILURE(error)) return;
256 
257     for (int32_t idx = 0; idx < keySize; idx++) {
258         stringArray.getValue(idx, modelKey);
259         key = UnicodeString(modelKey.getString(stringLength, error));
260         if (U_SUCCESS(error)) {
261             U_ASSERT(idx < valueSize);
262             fNegativeSum -= value[idx];
263             model.puti(key, value[idx], error);
264         }
265     }
266 }
267 
268 U_NAMESPACE_END
269 
270 #endif /* #if !UCONFIG_NO_BREAK_ITERATION */
271