xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/dictionary_ops.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2022 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker #include <cstdint>
17*14675a02SAndroid Build Coastguard Worker #include <functional>
18*14675a02SAndroid Build Coastguard Worker #include <memory>
19*14675a02SAndroid Build Coastguard Worker #include <string>
20*14675a02SAndroid Build Coastguard Worker #include <utility>
21*14675a02SAndroid Build Coastguard Worker #include <vector>
22*14675a02SAndroid Build Coastguard Worker 
23*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
24*14675a02SAndroid Build Coastguard Worker #include "fcp/dictionary/dictionary.h"
25*14675a02SAndroid Build Coastguard Worker #include "fcp/dictionary/dictionary.pb.h"
26*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
27*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h"
28*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/common_shape_fns.h"
29*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/op.h"
30*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/op_kernel.h"
31*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/public/version.h"
32*14675a02SAndroid Build Coastguard Worker 
33*14675a02SAndroid Build Coastguard Worker namespace tf = tensorflow;
34*14675a02SAndroid Build Coastguard Worker 
35*14675a02SAndroid Build Coastguard Worker namespace fcp {
36*14675a02SAndroid Build Coastguard Worker namespace tensorflow {
37*14675a02SAndroid Build Coastguard Worker 
38*14675a02SAndroid Build Coastguard Worker using fcp::dictionary::Dictionary;
39*14675a02SAndroid Build Coastguard Worker using fcp::dictionary::DictionaryDescription;
40*14675a02SAndroid Build Coastguard Worker 
41*14675a02SAndroid Build Coastguard Worker namespace {
42*14675a02SAndroid Build Coastguard Worker 
43*14675a02SAndroid Build Coastguard Worker // Base class for ops that work with a Dictionary.
44*14675a02SAndroid Build Coastguard Worker //
45*14675a02SAndroid Build Coastguard Worker // Subclasses need to provide Compute and register appropriately using
46*14675a02SAndroid Build Coastguard Worker // REGISTER_OP.
47*14675a02SAndroid Build Coastguard Worker class AbstractDictionaryOp : public tf::OpKernel {
48*14675a02SAndroid Build Coastguard Worker  public:
AbstractDictionaryOp(tf::OpKernelConstruction * context,int32_t num_expected_inputs)49*14675a02SAndroid Build Coastguard Worker   explicit AbstractDictionaryOp(tf::OpKernelConstruction* context,
50*14675a02SAndroid Build Coastguard Worker                                 int32_t num_expected_inputs)
51*14675a02SAndroid Build Coastguard Worker       : tf::OpKernel(context), num_expected_inputs_(num_expected_inputs) {
52*14675a02SAndroid Build Coastguard Worker     std::string dictionary_description_string;
53*14675a02SAndroid Build Coastguard Worker     OP_REQUIRES_OK(context,
54*14675a02SAndroid Build Coastguard Worker                    context->GetAttr("dictionary_description_proto",
55*14675a02SAndroid Build Coastguard Worker                                     &dictionary_description_string));
56*14675a02SAndroid Build Coastguard Worker 
57*14675a02SAndroid Build Coastguard Worker     DictionaryDescription parsed_dictionary_description;
58*14675a02SAndroid Build Coastguard Worker     OP_REQUIRES(context,
59*14675a02SAndroid Build Coastguard Worker                 parsed_dictionary_description.ParseFromString(
60*14675a02SAndroid Build Coastguard Worker                     dictionary_description_string),
61*14675a02SAndroid Build Coastguard Worker                 tf::errors::InvalidArgument(
62*14675a02SAndroid Build Coastguard Worker                     "Cannot parse provided DictionaryDescription."));
63*14675a02SAndroid Build Coastguard Worker 
64*14675a02SAndroid Build Coastguard Worker     if (parsed_dictionary_description.has_vocabulary()) {
65*14675a02SAndroid Build Coastguard Worker       // Fully specified dictionary.
66*14675a02SAndroid Build Coastguard Worker       absl::StatusOr<std::unique_ptr<Dictionary>> dictionary(
67*14675a02SAndroid Build Coastguard Worker           Dictionary::Create(parsed_dictionary_description));
68*14675a02SAndroid Build Coastguard Worker       OP_REQUIRES(context, dictionary.ok(),
69*14675a02SAndroid Build Coastguard Worker                   tf::errors::InvalidArgument(dictionary.status().ToString()));
70*14675a02SAndroid Build Coastguard Worker       dictionary_ = *std::move(dictionary);
71*14675a02SAndroid Build Coastguard Worker       parsed_dictionary_description.clear_vocabulary();  // Save space.
72*14675a02SAndroid Build Coastguard Worker     }
73*14675a02SAndroid Build Coastguard Worker     dictionary_description_ = parsed_dictionary_description;
74*14675a02SAndroid Build Coastguard Worker   }
75*14675a02SAndroid Build Coastguard Worker 
Compute(tf::OpKernelContext * context)76*14675a02SAndroid Build Coastguard Worker   void Compute(tf::OpKernelContext* context) override {
77*14675a02SAndroid Build Coastguard Worker     FCP_CHECK(num_expected_inputs_ == context->num_inputs());
78*14675a02SAndroid Build Coastguard Worker 
79*14675a02SAndroid Build Coastguard Worker     // Use the dictionary_ constructed at setup.
80*14675a02SAndroid Build Coastguard Worker     OP_REQUIRES(context, dictionary_ != nullptr,
81*14675a02SAndroid Build Coastguard Worker                 tf::errors::InvalidArgument(
82*14675a02SAndroid Build Coastguard Worker                 "DictionaryDescription does not contain a vocabulary. "));
83*14675a02SAndroid Build Coastguard Worker     absl::Status status = DoCompute(context, *dictionary_);
84*14675a02SAndroid Build Coastguard Worker     OP_REQUIRES(context, status.ok(),
85*14675a02SAndroid Build Coastguard Worker                 tf::errors::InvalidArgument(std::string(status.message())));
86*14675a02SAndroid Build Coastguard Worker   }
87*14675a02SAndroid Build Coastguard Worker 
88*14675a02SAndroid Build Coastguard Worker  protected:
89*14675a02SAndroid Build Coastguard Worker   // Computes using the given dictionary.
90*14675a02SAndroid Build Coastguard Worker   virtual absl::Status DoCompute(tf::OpKernelContext* context,
91*14675a02SAndroid Build Coastguard Worker                                  const Dictionary& dictionary) = 0;
92*14675a02SAndroid Build Coastguard Worker 
93*14675a02SAndroid Build Coastguard Worker  private:
94*14675a02SAndroid Build Coastguard Worker   DictionaryDescription dictionary_description_;
95*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<Dictionary> dictionary_;
96*14675a02SAndroid Build Coastguard Worker   const int32_t num_expected_inputs_;
97*14675a02SAndroid Build Coastguard Worker };
98*14675a02SAndroid Build Coastguard Worker 
99*14675a02SAndroid Build Coastguard Worker }  // namespace
100*14675a02SAndroid Build Coastguard Worker 
101*14675a02SAndroid Build Coastguard Worker class DictionarySize : public AbstractDictionaryOp {
102*14675a02SAndroid Build Coastguard Worker  public:
DictionarySize(tf::OpKernelConstruction * context)103*14675a02SAndroid Build Coastguard Worker   explicit DictionarySize(tf::OpKernelConstruction* context)
104*14675a02SAndroid Build Coastguard Worker       : AbstractDictionaryOp(context, 0 /* num_expected_inputs */) {}
105*14675a02SAndroid Build Coastguard Worker 
106*14675a02SAndroid Build Coastguard Worker  protected:
DoCompute(tf::OpKernelContext * context,const Dictionary & dictionary)107*14675a02SAndroid Build Coastguard Worker   absl::Status DoCompute(tf::OpKernelContext* context,
108*14675a02SAndroid Build Coastguard Worker                          const Dictionary& dictionary) override {
109*14675a02SAndroid Build Coastguard Worker     tf::Tensor* size_tensor;
110*14675a02SAndroid Build Coastguard Worker     auto status =
111*14675a02SAndroid Build Coastguard Worker         context->allocate_output(0, tf::TensorShape({}), &size_tensor);
112*14675a02SAndroid Build Coastguard Worker     if (!status.ok()) {
113*14675a02SAndroid Build Coastguard Worker #if TF_GRAPH_DEF_VERSION < 1467
114*14675a02SAndroid Build Coastguard Worker       return absl::InternalError(status.error_message());
115*14675a02SAndroid Build Coastguard Worker #else
116*14675a02SAndroid Build Coastguard Worker       return absl::InternalError(status.message());
117*14675a02SAndroid Build Coastguard Worker #endif
118*14675a02SAndroid Build Coastguard Worker     }
119*14675a02SAndroid Build Coastguard Worker     size_tensor->flat<int64_t>()(0) = dictionary.Size();
120*14675a02SAndroid Build Coastguard Worker     return absl::OkStatus();
121*14675a02SAndroid Build Coastguard Worker   }
122*14675a02SAndroid Build Coastguard Worker };
123*14675a02SAndroid Build Coastguard Worker 
124*14675a02SAndroid Build Coastguard Worker REGISTER_KERNEL_BUILDER(Name("DictionarySize").Device(tf::DEVICE_CPU),
125*14675a02SAndroid Build Coastguard Worker                         DictionarySize);
126*14675a02SAndroid Build Coastguard Worker REGISTER_OP("DictionarySize")
127*14675a02SAndroid Build Coastguard Worker     .Output("size: int64")
128*14675a02SAndroid Build Coastguard Worker     .Attr("dictionary_description_proto: string = ''")
129*14675a02SAndroid Build Coastguard Worker     .SetShapeFn(::tensorflow::shape_inference::ScalarShape)
130*14675a02SAndroid Build Coastguard Worker     .Doc(R"doc(
131*14675a02SAndroid Build Coastguard Worker Returns the number of ids in the given dictionary.
132*14675a02SAndroid Build Coastguard Worker 
133*14675a02SAndroid Build Coastguard Worker The dictionary should be fully specified at construction time via the
134*14675a02SAndroid Build Coastguard Worker dictionary_description_proto.
135*14675a02SAndroid Build Coastguard Worker 
136*14675a02SAndroid Build Coastguard Worker dictionary_description_proto: A `DictionaryDescription` as a string.
137*14675a02SAndroid Build Coastguard Worker )doc");
138*14675a02SAndroid Build Coastguard Worker 
139*14675a02SAndroid Build Coastguard Worker class DictionaryLookup : public AbstractDictionaryOp {
140*14675a02SAndroid Build Coastguard Worker  public:
DictionaryLookup(tf::OpKernelConstruction * context)141*14675a02SAndroid Build Coastguard Worker   explicit DictionaryLookup(tf::OpKernelConstruction* context)
142*14675a02SAndroid Build Coastguard Worker       : AbstractDictionaryOp(context, 1 /* num_expected_inputs */) {}
143*14675a02SAndroid Build Coastguard Worker 
144*14675a02SAndroid Build Coastguard Worker  protected:
DoCompute(tf::OpKernelContext * context,const Dictionary & dictionary)145*14675a02SAndroid Build Coastguard Worker   absl::Status DoCompute(tf::OpKernelContext* context,
146*14675a02SAndroid Build Coastguard Worker                          const Dictionary& dictionary) override {
147*14675a02SAndroid Build Coastguard Worker     const tf::Tensor& token_tensor = context->input(0);
148*14675a02SAndroid Build Coastguard Worker     tf::Tensor* ids_tensor;
149*14675a02SAndroid Build Coastguard Worker     auto status =
150*14675a02SAndroid Build Coastguard Worker         context->allocate_output(0, token_tensor.shape(), &ids_tensor);
151*14675a02SAndroid Build Coastguard Worker     if (!status.ok()) {
152*14675a02SAndroid Build Coastguard Worker #if TF_GRAPH_DEF_VERSION < 1467
153*14675a02SAndroid Build Coastguard Worker       return absl::InternalError(status.error_message());
154*14675a02SAndroid Build Coastguard Worker #else
155*14675a02SAndroid Build Coastguard Worker       return absl::InternalError(status.message());
156*14675a02SAndroid Build Coastguard Worker #endif
157*14675a02SAndroid Build Coastguard Worker     }
158*14675a02SAndroid Build Coastguard Worker 
159*14675a02SAndroid Build Coastguard Worker     if (token_tensor.dtype() != tf::DataType::DT_STRING) {
160*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError("Expected input of 'tokens'.");
161*14675a02SAndroid Build Coastguard Worker     }
162*14675a02SAndroid Build Coastguard Worker     if (ids_tensor->dtype() != tf::DataType::DT_INT64) {
163*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError("Expected output of 'ids'.");
164*14675a02SAndroid Build Coastguard Worker     }
165*14675a02SAndroid Build Coastguard Worker     if (token_tensor.shape() != ids_tensor->shape()) {
166*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError("Wrong shape for ids_tensor");
167*14675a02SAndroid Build Coastguard Worker     }
168*14675a02SAndroid Build Coastguard Worker     const auto tokens_flat = token_tensor.flat<tf::tstring>();
169*14675a02SAndroid Build Coastguard Worker     auto ids_flat = ids_tensor->flat<int64_t>();
170*14675a02SAndroid Build Coastguard Worker     const int64_t num_tokens = tokens_flat.size();
171*14675a02SAndroid Build Coastguard Worker     for (int i = 0; i < num_tokens; ++i) {
172*14675a02SAndroid Build Coastguard Worker       ids_flat(i) = dictionary.TokenToId(tokens_flat(i));
173*14675a02SAndroid Build Coastguard Worker     }
174*14675a02SAndroid Build Coastguard Worker     return absl::OkStatus();
175*14675a02SAndroid Build Coastguard Worker   }
176*14675a02SAndroid Build Coastguard Worker };
177*14675a02SAndroid Build Coastguard Worker 
178*14675a02SAndroid Build Coastguard Worker REGISTER_KERNEL_BUILDER(Name("DictionaryLookup").Device(tf::DEVICE_CPU),
179*14675a02SAndroid Build Coastguard Worker                         DictionaryLookup);
180*14675a02SAndroid Build Coastguard Worker REGISTER_OP("DictionaryLookup")
181*14675a02SAndroid Build Coastguard Worker     .Input("tokens: string")
182*14675a02SAndroid Build Coastguard Worker     .Output("token_ids: int64")
183*14675a02SAndroid Build Coastguard Worker     .Attr("dictionary_description_proto: string = ''")
184*14675a02SAndroid Build Coastguard Worker     .SetShapeFn(::tensorflow::shape_inference::UnchangedShape)
185*14675a02SAndroid Build Coastguard Worker     .Doc(R"doc(
186*14675a02SAndroid Build Coastguard Worker Maps each string to an id by lookup in the dictionary.
187*14675a02SAndroid Build Coastguard Worker 
188*14675a02SAndroid Build Coastguard Worker The dictionary should be fully specified at construction time via the
189*14675a02SAndroid Build Coastguard Worker dictionary_description_proto. Output has the same shape as input.
190*14675a02SAndroid Build Coastguard Worker 
191*14675a02SAndroid Build Coastguard Worker tokens: A `Tensor` of strings to lookup in the dictionary.
192*14675a02SAndroid Build Coastguard Worker dictionary_description_proto: A `DictionaryDescription` as a string.
193*14675a02SAndroid Build Coastguard Worker )doc");
194*14675a02SAndroid Build Coastguard Worker 
195*14675a02SAndroid Build Coastguard Worker class DictionaryReverseLookup : public AbstractDictionaryOp {
196*14675a02SAndroid Build Coastguard Worker  public:
DictionaryReverseLookup(tf::OpKernelConstruction * context)197*14675a02SAndroid Build Coastguard Worker   explicit DictionaryReverseLookup(tf::OpKernelConstruction* context)
198*14675a02SAndroid Build Coastguard Worker       : AbstractDictionaryOp(context, 1 /* num_expected_inputs */) {}
199*14675a02SAndroid Build Coastguard Worker 
200*14675a02SAndroid Build Coastguard Worker  protected:
DoCompute(tf::OpKernelContext * context,const Dictionary & dictionary)201*14675a02SAndroid Build Coastguard Worker   absl::Status DoCompute(tf::OpKernelContext* context,
202*14675a02SAndroid Build Coastguard Worker                          const Dictionary& dictionary) override {
203*14675a02SAndroid Build Coastguard Worker     const tf::Tensor& ids_tensor = context->input(0);
204*14675a02SAndroid Build Coastguard Worker     tf::Tensor* token_tensor;
205*14675a02SAndroid Build Coastguard Worker     auto status =
206*14675a02SAndroid Build Coastguard Worker         context->allocate_output(0, ids_tensor.shape(), &token_tensor);
207*14675a02SAndroid Build Coastguard Worker     if (!status.ok()) {
208*14675a02SAndroid Build Coastguard Worker #if TF_GRAPH_DEF_VERSION < 1467
209*14675a02SAndroid Build Coastguard Worker       return absl::InternalError(status.error_message());
210*14675a02SAndroid Build Coastguard Worker #else
211*14675a02SAndroid Build Coastguard Worker       return absl::InternalError(status.message());
212*14675a02SAndroid Build Coastguard Worker #endif
213*14675a02SAndroid Build Coastguard Worker     }
214*14675a02SAndroid Build Coastguard Worker 
215*14675a02SAndroid Build Coastguard Worker     if (token_tensor->dtype() != tf::DataType::DT_STRING) {
216*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError("Expected input of 'tokens'.");
217*14675a02SAndroid Build Coastguard Worker     }
218*14675a02SAndroid Build Coastguard Worker     if (ids_tensor.dtype() != tf::DataType::DT_INT64) {
219*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError("Expected output of 'ids'.");
220*14675a02SAndroid Build Coastguard Worker     }
221*14675a02SAndroid Build Coastguard Worker     if (ids_tensor.shape() != token_tensor->shape()) {
222*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError("Wrong shape for token_tensor");
223*14675a02SAndroid Build Coastguard Worker     }
224*14675a02SAndroid Build Coastguard Worker 
225*14675a02SAndroid Build Coastguard Worker     const auto ids_flat = ids_tensor.flat<int64_t>();
226*14675a02SAndroid Build Coastguard Worker     auto tokens_flat = token_tensor->flat<tf::tstring>();
227*14675a02SAndroid Build Coastguard Worker     const int64_t num_tokens = ids_flat.size();
228*14675a02SAndroid Build Coastguard Worker     for (int i = 0; i < num_tokens; ++i) {
229*14675a02SAndroid Build Coastguard Worker       tokens_flat(i) = dictionary.IdToToken(ids_flat(i));
230*14675a02SAndroid Build Coastguard Worker     }
231*14675a02SAndroid Build Coastguard Worker     return absl::OkStatus();
232*14675a02SAndroid Build Coastguard Worker   }
233*14675a02SAndroid Build Coastguard Worker };
234*14675a02SAndroid Build Coastguard Worker 
235*14675a02SAndroid Build Coastguard Worker REGISTER_KERNEL_BUILDER(Name("DictionaryReverseLookup").Device(tf::DEVICE_CPU),
236*14675a02SAndroid Build Coastguard Worker                         DictionaryReverseLookup);
237*14675a02SAndroid Build Coastguard Worker REGISTER_OP("DictionaryReverseLookup")
238*14675a02SAndroid Build Coastguard Worker     .Input("token_ids: int64")
239*14675a02SAndroid Build Coastguard Worker     .Output("tokens: string")
240*14675a02SAndroid Build Coastguard Worker     .Attr("dictionary_description_proto: string = ''")
241*14675a02SAndroid Build Coastguard Worker     .SetShapeFn(::tensorflow::shape_inference::UnchangedShape)
242*14675a02SAndroid Build Coastguard Worker     .Doc(R"doc(
243*14675a02SAndroid Build Coastguard Worker Maps each id to its string by performing a reverse lookup in the dictionary.
244*14675a02SAndroid Build Coastguard Worker 
245*14675a02SAndroid Build Coastguard Worker The dictionary should be fully specified at construction time via the
246*14675a02SAndroid Build Coastguard Worker dictionary_description_proto. Output has the same shape as input.
247*14675a02SAndroid Build Coastguard Worker 
248*14675a02SAndroid Build Coastguard Worker token_ids: A `Tensor` of int64 ids to lookup in the dictionary.
249*14675a02SAndroid Build Coastguard Worker dictionary_description_proto: A `DictionaryDescription` as a string.
250*14675a02SAndroid Build Coastguard Worker )doc");
251*14675a02SAndroid Build Coastguard Worker }  // namespace tensorflow
252*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
253