xref: /aosp_15_r20/external/federated-compute/fcp/dictionary/dictionary.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2022 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/dictionary/dictionary.h"
17*14675a02SAndroid Build Coastguard Worker 
18*14675a02SAndroid Build Coastguard Worker #include <algorithm>
19*14675a02SAndroid Build Coastguard Worker #include <cstdint>
20*14675a02SAndroid Build Coastguard Worker #include <memory>
21*14675a02SAndroid Build Coastguard Worker #include <string>
22*14675a02SAndroid Build Coastguard Worker #include <utility>
23*14675a02SAndroid Build Coastguard Worker #include <vector>
24*14675a02SAndroid Build Coastguard Worker 
25*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
26*14675a02SAndroid Build Coastguard Worker #include "fcp/dictionary/dictionary.pb.h"
27*14675a02SAndroid Build Coastguard Worker #include "absl/container/node_hash_map.h"
28*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
29*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h"
30*14675a02SAndroid Build Coastguard Worker #include "absl/strings/string_view.h"
31*14675a02SAndroid Build Coastguard Worker 
32*14675a02SAndroid Build Coastguard Worker namespace fcp {
33*14675a02SAndroid Build Coastguard Worker namespace dictionary {
34*14675a02SAndroid Build Coastguard Worker 
35*14675a02SAndroid Build Coastguard Worker // Bidirectional map defined as hash_map from strings to int32_t paired with
36*14675a02SAndroid Build Coastguard Worker // a vector of those keys for reverse lookup.
37*14675a02SAndroid Build Coastguard Worker typedef std::pair<absl::node_hash_map<std::string, int32_t>,
38*14675a02SAndroid Build Coastguard Worker                   std::vector<std::string>>
39*14675a02SAndroid Build Coastguard Worker     HashVectorBimap;
40*14675a02SAndroid Build Coastguard Worker 
41*14675a02SAndroid Build Coastguard Worker namespace {
42*14675a02SAndroid Build Coastguard Worker 
43*14675a02SAndroid Build Coastguard Worker // Map a string to an ID, using a bidirectional map (an std::pair containing
44*14675a02SAndroid Build Coastguard Worker // two data structures for string -> int and for int -> string lookups).
MapLookup(const HashVectorBimap & bimap,const std::string & tag)45*14675a02SAndroid Build Coastguard Worker int32_t MapLookup(const HashVectorBimap& bimap, const std::string& tag) {
46*14675a02SAndroid Build Coastguard Worker   auto map_idx = bimap.first.find(tag);
47*14675a02SAndroid Build Coastguard Worker   return map_idx == bimap.first.end() ? Dictionary::kNotFound : map_idx->second;
48*14675a02SAndroid Build Coastguard Worker }
49*14675a02SAndroid Build Coastguard Worker // Lookup a token given its ID.
MapReverseLookup(const HashVectorBimap & bimap,int32_t id)50*14675a02SAndroid Build Coastguard Worker std::string MapReverseLookup(const HashVectorBimap& bimap, int32_t id) {
51*14675a02SAndroid Build Coastguard Worker   if (id < 0 || id >= bimap.second.size()) {
52*14675a02SAndroid Build Coastguard Worker     return "";
53*14675a02SAndroid Build Coastguard Worker   }
54*14675a02SAndroid Build Coastguard Worker   return bimap.second[id];
55*14675a02SAndroid Build Coastguard Worker }
56*14675a02SAndroid Build Coastguard Worker 
57*14675a02SAndroid Build Coastguard Worker // Return the size of an stl-like data structure.
GetSize(const HashVectorBimap & bimap)58*14675a02SAndroid Build Coastguard Worker int32_t GetSize(const HashVectorBimap& bimap) {
59*14675a02SAndroid Build Coastguard Worker   return static_cast<int32_t>(bimap.first.size());
60*14675a02SAndroid Build Coastguard Worker }
61*14675a02SAndroid Build Coastguard Worker 
GetMaxSpecialId(const DictionaryDescription::SpecialIds & special_ids)62*14675a02SAndroid Build Coastguard Worker int32_t GetMaxSpecialId(const DictionaryDescription::SpecialIds& special_ids) {
63*14675a02SAndroid Build Coastguard Worker   int32_t max_special_id = -1;
64*14675a02SAndroid Build Coastguard Worker   max_special_id = std::max(max_special_id, special_ids.bos());
65*14675a02SAndroid Build Coastguard Worker   max_special_id = std::max(max_special_id, special_ids.eos());
66*14675a02SAndroid Build Coastguard Worker   max_special_id = std::max(max_special_id, special_ids.unk());
67*14675a02SAndroid Build Coastguard Worker   return max_special_id;
68*14675a02SAndroid Build Coastguard Worker }
69*14675a02SAndroid Build Coastguard Worker 
70*14675a02SAndroid Build Coastguard Worker // Dictionary implementation powered by templated utility functions above.
71*14675a02SAndroid Build Coastguard Worker template <typename Bimap>
72*14675a02SAndroid Build Coastguard Worker class DictionaryImpl : public Dictionary {
73*14675a02SAndroid Build Coastguard Worker  public:
DictionaryImpl(std::unique_ptr<Bimap> bimap,const DictionaryDescription::SpecialIds & special_ids,const DictionaryDescription::OutputBlocklistIds & output_blocklist_ids)74*14675a02SAndroid Build Coastguard Worker   DictionaryImpl(
75*14675a02SAndroid Build Coastguard Worker       std::unique_ptr<Bimap> bimap,
76*14675a02SAndroid Build Coastguard Worker       const DictionaryDescription::SpecialIds& special_ids,
77*14675a02SAndroid Build Coastguard Worker       const DictionaryDescription::OutputBlocklistIds& output_blocklist_ids)
78*14675a02SAndroid Build Coastguard Worker       : bimap_(std::move(bimap)),
79*14675a02SAndroid Build Coastguard Worker         special_ids_(special_ids),
80*14675a02SAndroid Build Coastguard Worker         max_special_id_(GetMaxSpecialId(special_ids)) {
81*14675a02SAndroid Build Coastguard Worker     // Validate special ids.
82*14675a02SAndroid Build Coastguard Worker     FCP_CHECK(special_ids.has_bos() == (special_ids.bos() >= 0));
83*14675a02SAndroid Build Coastguard Worker     FCP_CHECK(special_ids.has_eos() == (special_ids.eos() >= 0));
84*14675a02SAndroid Build Coastguard Worker     FCP_CHECK(special_ids.has_unk() == (special_ids.unk() >= 0));
85*14675a02SAndroid Build Coastguard Worker 
86*14675a02SAndroid Build Coastguard Worker     // Token numbering starts at max(special_ids) + 1.
87*14675a02SAndroid Build Coastguard Worker     output_blocklist_ids_.reserve(max_special_id_ + 1 +
88*14675a02SAndroid Build Coastguard Worker                                   output_blocklist_ids.id_size());
89*14675a02SAndroid Build Coastguard Worker     for (int32_t id = 0; id <= max_special_id_; ++id) {
90*14675a02SAndroid Build Coastguard Worker       output_blocklist_ids_.push_back(id);
91*14675a02SAndroid Build Coastguard Worker     }
92*14675a02SAndroid Build Coastguard Worker     for (int32_t id : output_blocklist_ids.id()) {
93*14675a02SAndroid Build Coastguard Worker       output_blocklist_ids_.push_back(id);
94*14675a02SAndroid Build Coastguard Worker     }
95*14675a02SAndroid Build Coastguard Worker   }
96*14675a02SAndroid Build Coastguard Worker 
Size() const97*14675a02SAndroid Build Coastguard Worker   int32_t Size() const override {
98*14675a02SAndroid Build Coastguard Worker     return GetSize(*bimap_) + max_special_id_ + 1;
99*14675a02SAndroid Build Coastguard Worker   }
100*14675a02SAndroid Build Coastguard Worker 
TokenToId(const std::string & tag) const101*14675a02SAndroid Build Coastguard Worker   int32_t TokenToId(const std::string& tag) const override {
102*14675a02SAndroid Build Coastguard Worker     int32_t id = MapLookup(*bimap_, tag);
103*14675a02SAndroid Build Coastguard Worker     if (id == kNotFound) {
104*14675a02SAndroid Build Coastguard Worker       return special_ids_.unk();
105*14675a02SAndroid Build Coastguard Worker     } else {
106*14675a02SAndroid Build Coastguard Worker       return id + max_special_id_ + 1;
107*14675a02SAndroid Build Coastguard Worker     }
108*14675a02SAndroid Build Coastguard Worker   }
109*14675a02SAndroid Build Coastguard Worker 
IdToToken(int32_t id) const110*14675a02SAndroid Build Coastguard Worker   std::string IdToToken(int32_t id) const override {
111*14675a02SAndroid Build Coastguard Worker     return MapReverseLookup(*bimap_, id - (max_special_id_ + 1));
112*14675a02SAndroid Build Coastguard Worker   }
113*14675a02SAndroid Build Coastguard Worker 
IsSpecialId(int32_t token_id) const114*14675a02SAndroid Build Coastguard Worker   bool IsSpecialId(int32_t token_id) const override {
115*14675a02SAndroid Build Coastguard Worker     return token_id <= max_special_id_;
116*14675a02SAndroid Build Coastguard Worker   }
117*14675a02SAndroid Build Coastguard Worker 
GetSortedOutputBlocklistIds() const118*14675a02SAndroid Build Coastguard Worker   const std::vector<int32_t>& GetSortedOutputBlocklistIds() const override {
119*14675a02SAndroid Build Coastguard Worker     return output_blocklist_ids_;
120*14675a02SAndroid Build Coastguard Worker   }
121*14675a02SAndroid Build Coastguard Worker 
GetSpecialIds() const122*14675a02SAndroid Build Coastguard Worker   const DictionaryDescription::SpecialIds& GetSpecialIds() const override {
123*14675a02SAndroid Build Coastguard Worker     return special_ids_;
124*14675a02SAndroid Build Coastguard Worker   }
125*14675a02SAndroid Build Coastguard Worker 
126*14675a02SAndroid Build Coastguard Worker  private:
127*14675a02SAndroid Build Coastguard Worker   const std::unique_ptr<Bimap> bimap_;
128*14675a02SAndroid Build Coastguard Worker   const DictionaryDescription::SpecialIds special_ids_;
129*14675a02SAndroid Build Coastguard Worker   int32_t max_special_id_;
130*14675a02SAndroid Build Coastguard Worker   std::vector<int32_t> output_blocklist_ids_;
131*14675a02SAndroid Build Coastguard Worker };
132*14675a02SAndroid Build Coastguard Worker 
IsOutputBlocklistIdsSortedAndUnique(const DictionaryDescription & description)133*14675a02SAndroid Build Coastguard Worker absl::Status IsOutputBlocklistIdsSortedAndUnique(
134*14675a02SAndroid Build Coastguard Worker     const DictionaryDescription& description) {
135*14675a02SAndroid Build Coastguard Worker   // All blocklist ids must be greater than max_special_id.
136*14675a02SAndroid Build Coastguard Worker   const int32_t max_special_id = GetMaxSpecialId(description.special_ids());
137*14675a02SAndroid Build Coastguard Worker 
138*14675a02SAndroid Build Coastguard Worker   // Make sure output blocklist IDs are sorted in ascending order and unique.
139*14675a02SAndroid Build Coastguard Worker   if (description.has_output_blocklist_ids()) {
140*14675a02SAndroid Build Coastguard Worker     for (int i = 0; i < description.output_blocklist_ids().id_size(); i++) {
141*14675a02SAndroid Build Coastguard Worker       if (description.output_blocklist_ids().id(i) <= max_special_id) {
142*14675a02SAndroid Build Coastguard Worker         return absl::InvalidArgumentError(
143*14675a02SAndroid Build Coastguard Worker             "output_blocklist_ids should not overlap with special ids");
144*14675a02SAndroid Build Coastguard Worker       }
145*14675a02SAndroid Build Coastguard Worker       if (!(i == 0 || description.output_blocklist_ids().id(i) >
146*14675a02SAndroid Build Coastguard Worker                           description.output_blocklist_ids().id(i - 1))) {
147*14675a02SAndroid Build Coastguard Worker         return absl::InvalidArgumentError(
148*14675a02SAndroid Build Coastguard Worker             "output_blocklist_ids not unique or sorted");
149*14675a02SAndroid Build Coastguard Worker       }
150*14675a02SAndroid Build Coastguard Worker     }
151*14675a02SAndroid Build Coastguard Worker   }
152*14675a02SAndroid Build Coastguard Worker   return absl::OkStatus();
153*14675a02SAndroid Build Coastguard Worker }
154*14675a02SAndroid Build Coastguard Worker 
155*14675a02SAndroid Build Coastguard Worker }  // anonymous namespace
156*14675a02SAndroid Build Coastguard Worker 
Create(const DictionaryDescription & description)157*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::unique_ptr<Dictionary>> Dictionary::Create(
158*14675a02SAndroid Build Coastguard Worker     const DictionaryDescription& description) {
159*14675a02SAndroid Build Coastguard Worker   if (!description.has_vocabulary()) {
160*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(
161*14675a02SAndroid Build Coastguard Worker         "Cannot create a dictionary that does not have vocabulary set");
162*14675a02SAndroid Build Coastguard Worker   }
163*14675a02SAndroid Build Coastguard Worker   // Make sure output blocklist IDs are sorted in ascending order and unique.
164*14675a02SAndroid Build Coastguard Worker   FCP_RETURN_IF_ERROR(IsOutputBlocklistIdsSortedAndUnique(description));
165*14675a02SAndroid Build Coastguard Worker 
166*14675a02SAndroid Build Coastguard Worker   if (description.vocabulary().has_index()) {
167*14675a02SAndroid Build Coastguard Worker     auto bimap = std::make_unique<HashVectorBimap>();
168*14675a02SAndroid Build Coastguard Worker     int i = 0;
169*14675a02SAndroid Build Coastguard Worker     bimap->second.reserve(description.vocabulary().index().token_size());
170*14675a02SAndroid Build Coastguard Worker     for (const std::string& token : description.vocabulary().index().token()) {
171*14675a02SAndroid Build Coastguard Worker       FCP_CHECK(!token.empty());
172*14675a02SAndroid Build Coastguard Worker       bimap->first[token] = i++;
173*14675a02SAndroid Build Coastguard Worker       bimap->second.push_back(token);
174*14675a02SAndroid Build Coastguard Worker     }
175*14675a02SAndroid Build Coastguard Worker     return std::unique_ptr<Dictionary>(new DictionaryImpl<HashVectorBimap>(
176*14675a02SAndroid Build Coastguard Worker         std::move(bimap), description.special_ids(),
177*14675a02SAndroid Build Coastguard Worker     description.output_blocklist_ids()));
178*14675a02SAndroid Build Coastguard Worker   } else {
179*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(
180*14675a02SAndroid Build Coastguard Worker         "Invalid DictionaryDescription: no vocabulary specified.");
181*14675a02SAndroid Build Coastguard Worker   }
182*14675a02SAndroid Build Coastguard Worker }
183*14675a02SAndroid Build Coastguard Worker }  // namespace dictionary
184*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
185