xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/dictionary_ops.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <cstdint>
17 #include <functional>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "fcp/base/monitoring.h"
24 #include "fcp/dictionary/dictionary.h"
25 #include "fcp/dictionary/dictionary.pb.h"
26 #include "absl/status/status.h"
27 #include "absl/status/statusor.h"
28 #include "tensorflow/core/framework/common_shape_fns.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/public/version.h"
32 
33 namespace tf = tensorflow;
34 
35 namespace fcp {
36 namespace tensorflow {
37 
38 using fcp::dictionary::Dictionary;
39 using fcp::dictionary::DictionaryDescription;
40 
41 namespace {
42 
43 // Base class for ops that work with a Dictionary.
44 //
45 // Subclasses need to provide Compute and register appropriately using
46 // REGISTER_OP.
47 class AbstractDictionaryOp : public tf::OpKernel {
48  public:
AbstractDictionaryOp(tf::OpKernelConstruction * context,int32_t num_expected_inputs)49   explicit AbstractDictionaryOp(tf::OpKernelConstruction* context,
50                                 int32_t num_expected_inputs)
51       : tf::OpKernel(context), num_expected_inputs_(num_expected_inputs) {
52     std::string dictionary_description_string;
53     OP_REQUIRES_OK(context,
54                    context->GetAttr("dictionary_description_proto",
55                                     &dictionary_description_string));
56 
57     DictionaryDescription parsed_dictionary_description;
58     OP_REQUIRES(context,
59                 parsed_dictionary_description.ParseFromString(
60                     dictionary_description_string),
61                 tf::errors::InvalidArgument(
62                     "Cannot parse provided DictionaryDescription."));
63 
64     if (parsed_dictionary_description.has_vocabulary()) {
65       // Fully specified dictionary.
66       absl::StatusOr<std::unique_ptr<Dictionary>> dictionary(
67           Dictionary::Create(parsed_dictionary_description));
68       OP_REQUIRES(context, dictionary.ok(),
69                   tf::errors::InvalidArgument(dictionary.status().ToString()));
70       dictionary_ = *std::move(dictionary);
71       parsed_dictionary_description.clear_vocabulary();  // Save space.
72     }
73     dictionary_description_ = parsed_dictionary_description;
74   }
75 
Compute(tf::OpKernelContext * context)76   void Compute(tf::OpKernelContext* context) override {
77     FCP_CHECK(num_expected_inputs_ == context->num_inputs());
78 
79     // Use the dictionary_ constructed at setup.
80     OP_REQUIRES(context, dictionary_ != nullptr,
81                 tf::errors::InvalidArgument(
82                 "DictionaryDescription does not contain a vocabulary. "));
83     absl::Status status = DoCompute(context, *dictionary_);
84     OP_REQUIRES(context, status.ok(),
85                 tf::errors::InvalidArgument(std::string(status.message())));
86   }
87 
88  protected:
89   // Computes using the given dictionary.
90   virtual absl::Status DoCompute(tf::OpKernelContext* context,
91                                  const Dictionary& dictionary) = 0;
92 
93  private:
94   DictionaryDescription dictionary_description_;
95   std::unique_ptr<Dictionary> dictionary_;
96   const int32_t num_expected_inputs_;
97 };
98 
99 }  // namespace
100 
101 class DictionarySize : public AbstractDictionaryOp {
102  public:
DictionarySize(tf::OpKernelConstruction * context)103   explicit DictionarySize(tf::OpKernelConstruction* context)
104       : AbstractDictionaryOp(context, 0 /* num_expected_inputs */) {}
105 
106  protected:
DoCompute(tf::OpKernelContext * context,const Dictionary & dictionary)107   absl::Status DoCompute(tf::OpKernelContext* context,
108                          const Dictionary& dictionary) override {
109     tf::Tensor* size_tensor;
110     auto status =
111         context->allocate_output(0, tf::TensorShape({}), &size_tensor);
112     if (!status.ok()) {
113 #if TF_GRAPH_DEF_VERSION < 1467
114       return absl::InternalError(status.error_message());
115 #else
116       return absl::InternalError(status.message());
117 #endif
118     }
119     size_tensor->flat<int64_t>()(0) = dictionary.Size();
120     return absl::OkStatus();
121   }
122 };
123 
124 REGISTER_KERNEL_BUILDER(Name("DictionarySize").Device(tf::DEVICE_CPU),
125                         DictionarySize);
126 REGISTER_OP("DictionarySize")
127     .Output("size: int64")
128     .Attr("dictionary_description_proto: string = ''")
129     .SetShapeFn(::tensorflow::shape_inference::ScalarShape)
130     .Doc(R"doc(
131 Returns the number of ids in the given dictionary.
132 
133 The dictionary should be fully specified at construction time via the
134 dictionary_description_proto.
135 
136 dictionary_description_proto: A `DictionaryDescription` as a string.
137 )doc");
138 
139 class DictionaryLookup : public AbstractDictionaryOp {
140  public:
DictionaryLookup(tf::OpKernelConstruction * context)141   explicit DictionaryLookup(tf::OpKernelConstruction* context)
142       : AbstractDictionaryOp(context, 1 /* num_expected_inputs */) {}
143 
144  protected:
DoCompute(tf::OpKernelContext * context,const Dictionary & dictionary)145   absl::Status DoCompute(tf::OpKernelContext* context,
146                          const Dictionary& dictionary) override {
147     const tf::Tensor& token_tensor = context->input(0);
148     tf::Tensor* ids_tensor;
149     auto status =
150         context->allocate_output(0, token_tensor.shape(), &ids_tensor);
151     if (!status.ok()) {
152 #if TF_GRAPH_DEF_VERSION < 1467
153       return absl::InternalError(status.error_message());
154 #else
155       return absl::InternalError(status.message());
156 #endif
157     }
158 
159     if (token_tensor.dtype() != tf::DataType::DT_STRING) {
160       return absl::InvalidArgumentError("Expected input of 'tokens'.");
161     }
162     if (ids_tensor->dtype() != tf::DataType::DT_INT64) {
163       return absl::InvalidArgumentError("Expected output of 'ids'.");
164     }
165     if (token_tensor.shape() != ids_tensor->shape()) {
166       return absl::InvalidArgumentError("Wrong shape for ids_tensor");
167     }
168     const auto tokens_flat = token_tensor.flat<tf::tstring>();
169     auto ids_flat = ids_tensor->flat<int64_t>();
170     const int64_t num_tokens = tokens_flat.size();
171     for (int i = 0; i < num_tokens; ++i) {
172       ids_flat(i) = dictionary.TokenToId(tokens_flat(i));
173     }
174     return absl::OkStatus();
175   }
176 };
177 
178 REGISTER_KERNEL_BUILDER(Name("DictionaryLookup").Device(tf::DEVICE_CPU),
179                         DictionaryLookup);
180 REGISTER_OP("DictionaryLookup")
181     .Input("tokens: string")
182     .Output("token_ids: int64")
183     .Attr("dictionary_description_proto: string = ''")
184     .SetShapeFn(::tensorflow::shape_inference::UnchangedShape)
185     .Doc(R"doc(
186 Maps each string to an id by lookup in the dictionary.
187 
188 The dictionary should be fully specified at construction time via the
189 dictionary_description_proto. Output has the same shape as input.
190 
191 tokens: A `Tensor` of strings to lookup in the dictionary.
192 dictionary_description_proto: A `DictionaryDescription` as a string.
193 )doc");
194 
195 class DictionaryReverseLookup : public AbstractDictionaryOp {
196  public:
DictionaryReverseLookup(tf::OpKernelConstruction * context)197   explicit DictionaryReverseLookup(tf::OpKernelConstruction* context)
198       : AbstractDictionaryOp(context, 1 /* num_expected_inputs */) {}
199 
200  protected:
DoCompute(tf::OpKernelContext * context,const Dictionary & dictionary)201   absl::Status DoCompute(tf::OpKernelContext* context,
202                          const Dictionary& dictionary) override {
203     const tf::Tensor& ids_tensor = context->input(0);
204     tf::Tensor* token_tensor;
205     auto status =
206         context->allocate_output(0, ids_tensor.shape(), &token_tensor);
207     if (!status.ok()) {
208 #if TF_GRAPH_DEF_VERSION < 1467
209       return absl::InternalError(status.error_message());
210 #else
211       return absl::InternalError(status.message());
212 #endif
213     }
214 
215     if (token_tensor->dtype() != tf::DataType::DT_STRING) {
216       return absl::InvalidArgumentError("Expected input of 'tokens'.");
217     }
218     if (ids_tensor.dtype() != tf::DataType::DT_INT64) {
219       return absl::InvalidArgumentError("Expected output of 'ids'.");
220     }
221     if (ids_tensor.shape() != token_tensor->shape()) {
222       return absl::InvalidArgumentError("Wrong shape for token_tensor");
223     }
224 
225     const auto ids_flat = ids_tensor.flat<int64_t>();
226     auto tokens_flat = token_tensor->flat<tf::tstring>();
227     const int64_t num_tokens = ids_flat.size();
228     for (int i = 0; i < num_tokens; ++i) {
229       tokens_flat(i) = dictionary.IdToToken(ids_flat(i));
230     }
231     return absl::OkStatus();
232   }
233 };
234 
235 REGISTER_KERNEL_BUILDER(Name("DictionaryReverseLookup").Device(tf::DEVICE_CPU),
236                         DictionaryReverseLookup);
237 REGISTER_OP("DictionaryReverseLookup")
238     .Input("token_ids: int64")
239     .Output("tokens: string")
240     .Attr("dictionary_description_proto: string = ''")
241     .SetShapeFn(::tensorflow::shape_inference::UnchangedShape)
242     .Doc(R"doc(
243 Maps each id to its string by performing a reverse lookup in the dictionary.
244 
245 The dictionary should be fully specified at construction time via the
246 dictionary_description_proto. Output has the same shape as input.
247 
248 token_ids: A `Tensor` of int64 ids to lookup in the dictionary.
249 dictionary_description_proto: A `DictionaryDescription` as a string.
250 )doc");
251 }  // namespace tensorflow
252 }  // namespace fcp
253