xref: /aosp_15_r20/external/federated-compute/fcp/dictionary/dictionary_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2022 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/dictionary/dictionary.h"
17*14675a02SAndroid Build Coastguard Worker 
18*14675a02SAndroid Build Coastguard Worker #include <memory>
19*14675a02SAndroid Build Coastguard Worker #include <string>
20*14675a02SAndroid Build Coastguard Worker #include <vector>
21*14675a02SAndroid Build Coastguard Worker 
22*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
23*14675a02SAndroid Build Coastguard Worker #include "fcp/dictionary/dictionary.pb.h"
24*14675a02SAndroid Build Coastguard Worker #include "fcp/testing/parse_text_proto.h"
25*14675a02SAndroid Build Coastguard Worker #include "gmock/gmock.h"
26*14675a02SAndroid Build Coastguard Worker #include "gtest/gtest.h"
27*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
28*14675a02SAndroid Build Coastguard Worker 
29*14675a02SAndroid Build Coastguard Worker namespace fcp {
30*14675a02SAndroid Build Coastguard Worker namespace dictionary {
31*14675a02SAndroid Build Coastguard Worker 
32*14675a02SAndroid Build Coastguard Worker using ::testing::ElementsAre;
33*14675a02SAndroid Build Coastguard Worker 
34*14675a02SAndroid Build Coastguard Worker class DictionaryTest : public ::testing::Test {};
35*14675a02SAndroid Build Coastguard Worker 
// Verifies that, with no special ids configured, vocabulary tokens get ids
// 0..N-1 in declaration order and an out-of-vocabulary token yields
// Dictionary::kNotFound.
TEST_F(DictionaryTest, TestMapDictionaryLookup) {
  // Check Create()'s status before using its value: dereferencing a failed
  // result is undefined behavior, whereas ASSERT_TRUE turns a bad description
  // into a clear test failure that reports the status.
  auto dictionary_or = Dictionary::Create(PARSE_TEXT_PROTO(
      "vocabulary: < index: < token: 'a' token: 'b' token: 'c' > >"));
  ASSERT_TRUE(dictionary_or.ok()) << dictionary_or.status();
  const std::unique_ptr<Dictionary>& dictionary = *dictionary_or;

  EXPECT_EQ(0, dictionary->TokenToId("a"));
  EXPECT_EQ(1, dictionary->TokenToId("b"));
  EXPECT_EQ(2, dictionary->TokenToId("c"));
  EXPECT_EQ(Dictionary::kNotFound, dictionary->TokenToId("d"));
}
45*14675a02SAndroid Build Coastguard Worker 
// Verifies lookup when special ids unk=0 and bos=1 are reserved: vocabulary
// tokens are numbered starting at 2, and any token not in the vocabulary
// (including the literal strings "<UNK>" and "<BOS>") maps to the unk id.
TEST_F(DictionaryTest, TestMapDictionaryLookupWithUnk) {
  // Fail the test cleanly (with the status message) if Create() rejects the
  // description, rather than dereferencing an error result.
  auto dictionary_or = Dictionary::Create(
      PARSE_TEXT_PROTO("special_ids: < unk: 0 bos: 1 > "
                       "vocabulary: < index: <"
                       "  token: 'a' token: 'b' token: 'c' > >"));
  ASSERT_TRUE(dictionary_or.ok()) << dictionary_or.status();
  const std::unique_ptr<Dictionary>& dictionary = *dictionary_or;

  EXPECT_EQ(2, dictionary->TokenToId("a"));
  EXPECT_EQ(3, dictionary->TokenToId("b"));
  EXPECT_EQ(4, dictionary->TokenToId("c"));
  EXPECT_EQ(0, dictionary->TokenToId("d"));
  EXPECT_EQ(0, dictionary->TokenToId("e"));
  // Special-token spellings are not vocabulary entries, so they fall back to
  // unk as well.
  EXPECT_EQ(0, dictionary->TokenToId("<UNK>"));
  EXPECT_EQ(0, dictionary->TokenToId("<BOS>"));
}
59*14675a02SAndroid Build Coastguard Worker 
// Verifies lookup when the special ids leave gaps (unk=1, bos=4): vocabulary
// numbering starts after the largest special id, and unknown tokens map to
// the unk id.
TEST_F(DictionaryTest, TestMapDictionaryLookupWithSpecialTokenHoles) {
  // Fail the test cleanly (with the status message) if Create() rejects the
  // description, rather than dereferencing an error result.
  auto dictionary_or = Dictionary::Create(
      PARSE_TEXT_PROTO("special_ids: < unk: 1 bos: 4 > "
                       "vocabulary: < index: <"
                       "  token: 'a' token: 'b' token: 'c' > >"));
  ASSERT_TRUE(dictionary_or.ok()) << dictionary_or.status();
  const std::unique_ptr<Dictionary>& dictionary = *dictionary_or;

  // Make sure dictionary doesn't use the "holes" in IDs - 0, 2 and 3 - for
  // tokens, but starts numbering tokens with max(special_ids) + 1.
  EXPECT_EQ(5, dictionary->TokenToId("a"));
  EXPECT_EQ(6, dictionary->TokenToId("b"));
  EXPECT_EQ(7, dictionary->TokenToId("c"));
  EXPECT_EQ(1, dictionary->TokenToId("d"));
  EXPECT_EQ(1, dictionary->TokenToId("e"));
  EXPECT_EQ(1, dictionary->TokenToId("<UNK>"));
  EXPECT_EQ(1, dictionary->TokenToId("<BOS>"));
}
76*14675a02SAndroid Build Coastguard Worker 
// Verifies that IdToToken inverts TokenToId for vocabulary tokens and yields
// an empty string for kNotFound and for ids outside the vocabulary.
TEST_F(DictionaryTest, TestMapDictionaryReverseLookup) {
  // Fail the test cleanly (with the status message) if Create() rejects the
  // description, rather than dereferencing an error result.
  auto dictionary_or = Dictionary::Create(PARSE_TEXT_PROTO(
      "vocabulary: < index: < token: 'a' token: 'b' token: 'c' > >"));
  ASSERT_TRUE(dictionary_or.ok()) << dictionary_or.status();
  const std::unique_ptr<Dictionary>& dictionary = *dictionary_or;

  EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("a")), "a");
  EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("b")), "b");
  EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("c")), "c");
  // "d" is out of vocabulary, so this round-trips through kNotFound.
  EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("d")), "");
  EXPECT_EQ(dictionary->IdToToken(0xDEADBEEF), "");
  EXPECT_EQ(dictionary->IdToToken(1337), "");
}
87*14675a02SAndroid Build Coastguard Worker 
// Verifies reverse lookup when special ids unk=0 and bos=1 are reserved:
// vocabulary tokens round-trip, while the unk id (what "d" maps to) and
// unassigned ids all map back to an empty string.
TEST_F(DictionaryTest, TestMapDictionaryReverseLookupWithUnk) {
  // Fail the test cleanly (with the status message) if Create() rejects the
  // description, rather than dereferencing an error result.
  auto dictionary_or = Dictionary::Create(
      PARSE_TEXT_PROTO("special_ids: < unk: 0 bos: 1 > "
                       "vocabulary: < index: <"
                       "  token: 'a' token: 'b' token: 'c' > >"));
  ASSERT_TRUE(dictionary_or.ok()) << dictionary_or.status();
  const std::unique_ptr<Dictionary>& dictionary = *dictionary_or;

  EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("a")), "a");
  EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("b")), "b");
  EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("c")), "c");
  EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("d")), "");
  EXPECT_EQ(dictionary->IdToToken(0xDEADBEEF), "");
  EXPECT_EQ(dictionary->IdToToken(1337), "");
}
100*14675a02SAndroid Build Coastguard Worker 
// Verifies reverse lookup when the special ids leave gaps (unk=1, bos=4):
// vocabulary tokens round-trip, while the unk id and unassigned ids map back
// to an empty string.
TEST_F(DictionaryTest, TestMapDictionaryReverseLookupWithSpecialTokenHoles) {
  // Fail the test cleanly (with the status message) if Create() rejects the
  // description, rather than dereferencing an error result.
  auto dictionary_or = Dictionary::Create(
      PARSE_TEXT_PROTO("special_ids: < unk: 1 bos: 4 > "
                       "vocabulary: < index: <"
                       "  token: 'a' token: 'b' token: 'c' > >"));
  ASSERT_TRUE(dictionary_or.ok()) << dictionary_or.status();
  const std::unique_ptr<Dictionary>& dictionary = *dictionary_or;

  EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("a")), "a");
  EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("b")), "b");
  EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("c")), "c");
  EXPECT_EQ(dictionary->IdToToken(dictionary->TokenToId("d")), "");
  EXPECT_EQ(dictionary->IdToToken(0xDEADBEEF), "");
  EXPECT_EQ(dictionary->IdToToken(1337), "");
}
113*14675a02SAndroid Build Coastguard Worker }  // namespace dictionary
114*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
115