1 /* 2 * Copyright 2022 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <cstdint> 17 #include <functional> 18 #include <memory> 19 #include <string> 20 #include <utility> 21 #include <vector> 22 23 #include "fcp/base/monitoring.h" 24 #include "fcp/dictionary/dictionary.h" 25 #include "fcp/dictionary/dictionary.pb.h" 26 #include "absl/status/status.h" 27 #include "absl/status/statusor.h" 28 #include "tensorflow/core/framework/common_shape_fns.h" 29 #include "tensorflow/core/framework/op.h" 30 #include "tensorflow/core/framework/op_kernel.h" 31 #include "tensorflow/core/public/version.h" 32 33 namespace tf = tensorflow; 34 35 namespace fcp { 36 namespace tensorflow { 37 38 using fcp::dictionary::Dictionary; 39 using fcp::dictionary::DictionaryDescription; 40 41 namespace { 42 43 // Base class for ops that work with a Dictionary. 44 // 45 // Subclasses need to provide DoCompute and register appropriately using 46 // REGISTER_OP.
47 class AbstractDictionaryOp : public tf::OpKernel { 48 public: AbstractDictionaryOp(tf::OpKernelConstruction * context,int32_t num_expected_inputs)49 explicit AbstractDictionaryOp(tf::OpKernelConstruction* context, 50 int32_t num_expected_inputs) 51 : tf::OpKernel(context), num_expected_inputs_(num_expected_inputs) { 52 std::string dictionary_description_string; 53 OP_REQUIRES_OK(context, 54 context->GetAttr("dictionary_description_proto", 55 &dictionary_description_string)); 56 57 DictionaryDescription parsed_dictionary_description; 58 OP_REQUIRES(context, 59 parsed_dictionary_description.ParseFromString( 60 dictionary_description_string), 61 tf::errors::InvalidArgument( 62 "Cannot parse provided DictionaryDescription.")); 63 64 if (parsed_dictionary_description.has_vocabulary()) { 65 // Fully specified dictionary. 66 absl::StatusOr<std::unique_ptr<Dictionary>> dictionary( 67 Dictionary::Create(parsed_dictionary_description)); 68 OP_REQUIRES(context, dictionary.ok(), 69 tf::errors::InvalidArgument(dictionary.status().ToString())); 70 dictionary_ = *std::move(dictionary); 71 parsed_dictionary_description.clear_vocabulary(); // Save space. 72 } 73 dictionary_description_ = parsed_dictionary_description; 74 } 75 Compute(tf::OpKernelContext * context)76 void Compute(tf::OpKernelContext* context) override { 77 FCP_CHECK(num_expected_inputs_ == context->num_inputs()); 78 79 // Use the dictionary_ constructed at setup. 80 OP_REQUIRES(context, dictionary_ != nullptr, 81 tf::errors::InvalidArgument( 82 "DictionaryDescription does not contain a vocabulary. ")); 83 absl::Status status = DoCompute(context, *dictionary_); 84 OP_REQUIRES(context, status.ok(), 85 tf::errors::InvalidArgument(std::string(status.message()))); 86 } 87 88 protected: 89 // Computes using the given dictionary. 
90 virtual absl::Status DoCompute(tf::OpKernelContext* context, 91 const Dictionary& dictionary) = 0; 92 93 private: 94 DictionaryDescription dictionary_description_; 95 std::unique_ptr<Dictionary> dictionary_; 96 const int32_t num_expected_inputs_; 97 }; 98 99 } // namespace 100 101 class DictionarySize : public AbstractDictionaryOp { 102 public: DictionarySize(tf::OpKernelConstruction * context)103 explicit DictionarySize(tf::OpKernelConstruction* context) 104 : AbstractDictionaryOp(context, 0 /* num_expected_inputs */) {} 105 106 protected: DoCompute(tf::OpKernelContext * context,const Dictionary & dictionary)107 absl::Status DoCompute(tf::OpKernelContext* context, 108 const Dictionary& dictionary) override { 109 tf::Tensor* size_tensor; 110 auto status = 111 context->allocate_output(0, tf::TensorShape({}), &size_tensor); 112 if (!status.ok()) { 113 #if TF_GRAPH_DEF_VERSION < 1467 114 return absl::InternalError(status.error_message()); 115 #else 116 return absl::InternalError(status.message()); 117 #endif 118 } 119 size_tensor->flat<int64_t>()(0) = dictionary.Size(); 120 return absl::OkStatus(); 121 } 122 }; 123 124 REGISTER_KERNEL_BUILDER(Name("DictionarySize").Device(tf::DEVICE_CPU), 125 DictionarySize); 126 REGISTER_OP("DictionarySize") 127 .Output("size: int64") 128 .Attr("dictionary_description_proto: string = ''") 129 .SetShapeFn(::tensorflow::shape_inference::ScalarShape) 130 .Doc(R"doc( 131 Returns the number of ids in the given dictionary. 132 133 The dictionary should be fully specified at construction time via the 134 dictionary_description_proto. 135 136 dictionary_description_proto: A `DictionaryDescription` as a string. 
137 )doc"); 138 139 class DictionaryLookup : public AbstractDictionaryOp { 140 public: DictionaryLookup(tf::OpKernelConstruction * context)141 explicit DictionaryLookup(tf::OpKernelConstruction* context) 142 : AbstractDictionaryOp(context, 1 /* num_expected_inputs */) {} 143 144 protected: DoCompute(tf::OpKernelContext * context,const Dictionary & dictionary)145 absl::Status DoCompute(tf::OpKernelContext* context, 146 const Dictionary& dictionary) override { 147 const tf::Tensor& token_tensor = context->input(0); 148 tf::Tensor* ids_tensor; 149 auto status = 150 context->allocate_output(0, token_tensor.shape(), &ids_tensor); 151 if (!status.ok()) { 152 #if TF_GRAPH_DEF_VERSION < 1467 153 return absl::InternalError(status.error_message()); 154 #else 155 return absl::InternalError(status.message()); 156 #endif 157 } 158 159 if (token_tensor.dtype() != tf::DataType::DT_STRING) { 160 return absl::InvalidArgumentError("Expected input of 'tokens'."); 161 } 162 if (ids_tensor->dtype() != tf::DataType::DT_INT64) { 163 return absl::InvalidArgumentError("Expected output of 'ids'."); 164 } 165 if (token_tensor.shape() != ids_tensor->shape()) { 166 return absl::InvalidArgumentError("Wrong shape for ids_tensor"); 167 } 168 const auto tokens_flat = token_tensor.flat<tf::tstring>(); 169 auto ids_flat = ids_tensor->flat<int64_t>(); 170 const int64_t num_tokens = tokens_flat.size(); 171 for (int i = 0; i < num_tokens; ++i) { 172 ids_flat(i) = dictionary.TokenToId(tokens_flat(i)); 173 } 174 return absl::OkStatus(); 175 } 176 }; 177 178 REGISTER_KERNEL_BUILDER(Name("DictionaryLookup").Device(tf::DEVICE_CPU), 179 DictionaryLookup); 180 REGISTER_OP("DictionaryLookup") 181 .Input("tokens: string") 182 .Output("token_ids: int64") 183 .Attr("dictionary_description_proto: string = ''") 184 .SetShapeFn(::tensorflow::shape_inference::UnchangedShape) 185 .Doc(R"doc( 186 Maps each string to an id by lookup in the dictionary. 
187 188 The dictionary should be fully specified at construction time via the 189 dictionary_description_proto. Output has the same shape as input. 190 191 tokens: A `Tensor` of strings to lookup in the dictionary. 192 dictionary_description_proto: A `DictionaryDescription` as a string. 193 )doc"); 194 195 class DictionaryReverseLookup : public AbstractDictionaryOp { 196 public: DictionaryReverseLookup(tf::OpKernelConstruction * context)197 explicit DictionaryReverseLookup(tf::OpKernelConstruction* context) 198 : AbstractDictionaryOp(context, 1 /* num_expected_inputs */) {} 199 200 protected: DoCompute(tf::OpKernelContext * context,const Dictionary & dictionary)201 absl::Status DoCompute(tf::OpKernelContext* context, 202 const Dictionary& dictionary) override { 203 const tf::Tensor& ids_tensor = context->input(0); 204 tf::Tensor* token_tensor; 205 auto status = 206 context->allocate_output(0, ids_tensor.shape(), &token_tensor); 207 if (!status.ok()) { 208 #if TF_GRAPH_DEF_VERSION < 1467 209 return absl::InternalError(status.error_message()); 210 #else 211 return absl::InternalError(status.message()); 212 #endif 213 } 214 215 if (token_tensor->dtype() != tf::DataType::DT_STRING) { 216 return absl::InvalidArgumentError("Expected input of 'tokens'."); 217 } 218 if (ids_tensor.dtype() != tf::DataType::DT_INT64) { 219 return absl::InvalidArgumentError("Expected output of 'ids'."); 220 } 221 if (ids_tensor.shape() != token_tensor->shape()) { 222 return absl::InvalidArgumentError("Wrong shape for token_tensor"); 223 } 224 225 const auto ids_flat = ids_tensor.flat<int64_t>(); 226 auto tokens_flat = token_tensor->flat<tf::tstring>(); 227 const int64_t num_tokens = ids_flat.size(); 228 for (int i = 0; i < num_tokens; ++i) { 229 tokens_flat(i) = dictionary.IdToToken(ids_flat(i)); 230 } 231 return absl::OkStatus(); 232 } 233 }; 234 235 REGISTER_KERNEL_BUILDER(Name("DictionaryReverseLookup").Device(tf::DEVICE_CPU), 236 DictionaryReverseLookup); 237 
// Op definition for DictionaryReverseLookup:
//   - one int64 input named "token_ids";
//   - one string output named "tokens", shaped like the input
//     (UnchangedShape);
//   - a serialized DictionaryDescription attr that defaults to the empty
//     string.
REGISTER_OP("DictionaryReverseLookup")
    .Input("token_ids: int64")
    .Output("tokens: string")
    .Attr("dictionary_description_proto: string = ''")
    .SetShapeFn(tf::shape_inference::UnchangedShape)
    .Doc(R"doc(
Maps each id to its string by performing a reverse lookup in the dictionary.

The dictionary should be fully specified at construction time via the
dictionary_description_proto. Output has the same shape as input.

token_ids: A `Tensor` of int64 ids to lookup in the dictionary.
dictionary_description_proto: A `DictionaryDescription` as a string.
)doc");

}  // namespace tensorflow
}  // namespace fcp