xref: /aosp_15_r20/external/federated-compute/fcp/dictionary/dictionary_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
/*
 * Copyright 2022 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "fcp/dictionary/dictionary.h"
17 
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 #include "fcp/base/monitoring.h"
23 #include "fcp/dictionary/dictionary.pb.h"
24 #include "fcp/testing/parse_text_proto.h"
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "absl/status/status.h"
28 
29 namespace fcp {
30 namespace dictionary {
31 
32 using ::testing::ElementsAre;
33 
34 class DictionaryTest : public ::testing::Test {};
35 
// Forward lookup on a three-token vocabulary with no special ids configured.
TEST_F(DictionaryTest, TestMapDictionaryLookup) {
  // NOTE(review): the result of Create() is dereferenced without an explicit
  // success check.
  const std::unique_ptr<Dictionary> dict = *Dictionary::Create(PARSE_TEXT_PROTO(
      "vocabulary: < index: < token: 'a' token: 'b' token: 'c' > >"));

  // Vocabulary tokens are expected to be numbered 0, 1, 2 in listing order.
  EXPECT_EQ(dict->TokenToId("a"), 0);
  EXPECT_EQ(dict->TokenToId("b"), 1);
  EXPECT_EQ(dict->TokenToId("c"), 2);

  // A token that is not in the vocabulary is expected to yield kNotFound.
  EXPECT_EQ(dict->TokenToId("d"), Dictionary::kNotFound);
}
45 
// Forward lookup when special ids (unk = 0, bos = 1) are configured.
TEST_F(DictionaryTest, TestMapDictionaryLookupWithUnk) {
  const std::unique_ptr<Dictionary> dict = *Dictionary::Create(
      PARSE_TEXT_PROTO("special_ids: < unk: 0 bos: 1 > "
                       "vocabulary: < index: <"
                       "  token: 'a' token: 'b' token: 'c' > >"));

  // Vocabulary tokens are expected to be numbered after the special ids.
  EXPECT_EQ(dict->TokenToId("a"), 2);
  EXPECT_EQ(dict->TokenToId("b"), 3);
  EXPECT_EQ(dict->TokenToId("c"), 4);

  // Tokens absent from the vocabulary are expected to map to the unk id.
  EXPECT_EQ(dict->TokenToId("d"), 0);
  EXPECT_EQ(dict->TokenToId("e"), 0);

  // The literal strings "<UNK>" and "<BOS>" are not vocabulary entries, so
  // they are expected to map to the unk id as well.
  EXPECT_EQ(dict->TokenToId("<UNK>"), 0);
  EXPECT_EQ(dict->TokenToId("<BOS>"), 0);
}
59 
// Forward lookup when the special ids (unk = 1, bos = 4) leave ids 0, 2 and 3
// unassigned.
TEST_F(DictionaryTest, TestMapDictionaryLookupWithSpecialTokenHoles) {
  const std::unique_ptr<Dictionary> dict = *Dictionary::Create(
      PARSE_TEXT_PROTO("special_ids: < unk: 1 bos: 4 > "
                       "vocabulary: < index: <"
                       "  token: 'a' token: 'b' token: 'c' > >"));

  // The unassigned "holes" (0, 2 and 3) must not be handed out to vocabulary
  // tokens; numbering is expected to begin at max(special_ids) + 1.
  EXPECT_EQ(dict->TokenToId("a"), 5);
  EXPECT_EQ(dict->TokenToId("b"), 6);
  EXPECT_EQ(dict->TokenToId("c"), 7);

  // Anything outside the vocabulary is expected to map to the unk id (1),
  // including the literal strings "<UNK>" and "<BOS>".
  EXPECT_EQ(dict->TokenToId("d"), 1);
  EXPECT_EQ(dict->TokenToId("e"), 1);
  EXPECT_EQ(dict->TokenToId("<UNK>"), 1);
  EXPECT_EQ(dict->TokenToId("<BOS>"), 1);
}
76 
// Reverse lookup (id -> token) on a vocabulary with no special ids.
TEST_F(DictionaryTest, TestMapDictionaryReverseLookup) {
  const std::unique_ptr<Dictionary> dict = *Dictionary::Create(PARSE_TEXT_PROTO(
      "vocabulary: < index: < token: 'a' token: 'b' token: 'c' > >"));

  // Converts a token to its id and immediately back to a token.
  const auto round_trip = [&dict](const char* token) {
    return dict->IdToToken(dict->TokenToId(token));
  };

  // Vocabulary tokens are expected to survive the round trip unchanged.
  EXPECT_EQ(round_trip("a"), "a");
  EXPECT_EQ(round_trip("b"), "b");
  EXPECT_EQ(round_trip("c"), "c");

  // A token outside the vocabulary is expected to come back as "".
  EXPECT_EQ(round_trip("d"), "");

  // Ids that were never assigned are expected to map to "".
  EXPECT_EQ(dict->IdToToken(0xDEADBEEF), "");
  EXPECT_EQ(dict->IdToToken(1337), "");
}
87 
// Reverse lookup (id -> token) when special ids (unk = 0, bos = 1) are
// configured.
TEST_F(DictionaryTest, TestMapDictionaryReverseLookupWithUnk) {
  const std::unique_ptr<Dictionary> dict = *Dictionary::Create(
      PARSE_TEXT_PROTO("special_ids: < unk: 0 bos: 1 > "
                       "vocabulary: < index: <"
                       "  token: 'a' token: 'b' token: 'c' > >"));

  // Converts a token to its id and immediately back to a token.
  const auto round_trip = [&dict](const char* token) {
    return dict->IdToToken(dict->TokenToId(token));
  };

  // Vocabulary tokens are expected to survive the round trip unchanged.
  EXPECT_EQ(round_trip("a"), "a");
  EXPECT_EQ(round_trip("b"), "b");
  EXPECT_EQ(round_trip("c"), "c");

  // A token outside the vocabulary is expected to come back as "" rather
  // than as a printable placeholder.
  EXPECT_EQ(round_trip("d"), "");

  // Ids that were never assigned are expected to map to "".
  EXPECT_EQ(dict->IdToToken(0xDEADBEEF), "");
  EXPECT_EQ(dict->IdToToken(1337), "");
}
100 
// Reverse lookup (id -> token) when the special ids (unk = 1, bos = 4) are
// non-contiguous.
TEST_F(DictionaryTest, TestMapDictionaryReverseLookupWithSpecialTokenHoles) {
  const std::unique_ptr<Dictionary> dict = *Dictionary::Create(
      PARSE_TEXT_PROTO("special_ids: < unk: 1 bos: 4 > "
                       "vocabulary: < index: <"
                       "  token: 'a' token: 'b' token: 'c' > >"));

  // Converts a token to its id and immediately back to a token.
  const auto round_trip = [&dict](const char* token) {
    return dict->IdToToken(dict->TokenToId(token));
  };

  // Vocabulary tokens are expected to survive the round trip unchanged.
  EXPECT_EQ(round_trip("a"), "a");
  EXPECT_EQ(round_trip("b"), "b");
  EXPECT_EQ(round_trip("c"), "c");

  // A token outside the vocabulary is expected to come back as "".
  EXPECT_EQ(round_trip("d"), "");

  // Ids that were never assigned are expected to map to "".
  EXPECT_EQ(dict->IdToToken(0xDEADBEEF), "");
  EXPECT_EQ(dict->IdToToken(1337), "");
}
113 }  // namespace dictionary
114 }  // namespace fcp
115