1*14675a02SAndroid Build Coastguard Worker /* 2*14675a02SAndroid Build Coastguard Worker * Copyright 2022 Google LLC 3*14675a02SAndroid Build Coastguard Worker * 4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*14675a02SAndroid Build Coastguard Worker * 8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0 9*14675a02SAndroid Build Coastguard Worker * 10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*14675a02SAndroid Build Coastguard Worker * limitations under the License. 15*14675a02SAndroid Build Coastguard Worker */ 16*14675a02SAndroid Build Coastguard Worker #include <cstdint> 17*14675a02SAndroid Build Coastguard Worker #include <functional> 18*14675a02SAndroid Build Coastguard Worker #include <memory> 19*14675a02SAndroid Build Coastguard Worker #include <string> 20*14675a02SAndroid Build Coastguard Worker #include <utility> 21*14675a02SAndroid Build Coastguard Worker #include <vector> 22*14675a02SAndroid Build Coastguard Worker 23*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h" 24*14675a02SAndroid Build Coastguard Worker #include "fcp/dictionary/dictionary.h" 25*14675a02SAndroid Build Coastguard Worker #include "fcp/dictionary/dictionary.pb.h" 26*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h" 27*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h" 28*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/common_shape_fns.h" 29*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/op.h" 30*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/op_kernel.h" 31*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/public/version.h" 32*14675a02SAndroid Build Coastguard Worker 33*14675a02SAndroid Build Coastguard Worker namespace tf = tensorflow; 34*14675a02SAndroid Build Coastguard Worker 35*14675a02SAndroid Build Coastguard Worker namespace fcp { 36*14675a02SAndroid Build Coastguard Worker namespace tensorflow { 37*14675a02SAndroid Build Coastguard Worker 38*14675a02SAndroid Build Coastguard Worker using fcp::dictionary::Dictionary; 39*14675a02SAndroid Build Coastguard Worker using fcp::dictionary::DictionaryDescription; 40*14675a02SAndroid Build Coastguard Worker 41*14675a02SAndroid Build Coastguard Worker namespace { 42*14675a02SAndroid Build Coastguard Worker 43*14675a02SAndroid Build Coastguard Worker // Base class for ops that work with a Dictionary. 44*14675a02SAndroid Build Coastguard Worker // 45*14675a02SAndroid Build Coastguard Worker // Subclasses need to provide Compute and register appropriately using 46*14675a02SAndroid Build Coastguard Worker // REGISTER_OP. 47*14675a02SAndroid Build Coastguard Worker class AbstractDictionaryOp : public tf::OpKernel { 48*14675a02SAndroid Build Coastguard Worker public: AbstractDictionaryOp(tf::OpKernelConstruction * context,int32_t num_expected_inputs)49*14675a02SAndroid Build Coastguard Worker explicit AbstractDictionaryOp(tf::OpKernelConstruction* context, 50*14675a02SAndroid Build Coastguard Worker int32_t num_expected_inputs) 51*14675a02SAndroid Build Coastguard Worker : tf::OpKernel(context), num_expected_inputs_(num_expected_inputs) { 52*14675a02SAndroid Build Coastguard Worker std::string dictionary_description_string; 53*14675a02SAndroid Build Coastguard Worker OP_REQUIRES_OK(context, 54*14675a02SAndroid Build Coastguard Worker context->GetAttr("dictionary_description_proto", 55*14675a02SAndroid Build Coastguard Worker &dictionary_description_string)); 56*14675a02SAndroid Build Coastguard Worker 57*14675a02SAndroid Build Coastguard Worker DictionaryDescription parsed_dictionary_description; 58*14675a02SAndroid Build Coastguard Worker OP_REQUIRES(context, 59*14675a02SAndroid Build Coastguard Worker parsed_dictionary_description.ParseFromString( 60*14675a02SAndroid Build Coastguard Worker dictionary_description_string), 61*14675a02SAndroid Build Coastguard Worker tf::errors::InvalidArgument( 62*14675a02SAndroid Build Coastguard Worker "Cannot parse provided DictionaryDescription.")); 63*14675a02SAndroid Build Coastguard Worker 64*14675a02SAndroid Build Coastguard Worker if (parsed_dictionary_description.has_vocabulary()) { 65*14675a02SAndroid Build Coastguard Worker // Fully specified dictionary. 66*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::unique_ptr<Dictionary>> dictionary( 67*14675a02SAndroid Build Coastguard Worker Dictionary::Create(parsed_dictionary_description)); 68*14675a02SAndroid Build Coastguard Worker OP_REQUIRES(context, dictionary.ok(), 69*14675a02SAndroid Build Coastguard Worker tf::errors::InvalidArgument(dictionary.status().ToString())); 70*14675a02SAndroid Build Coastguard Worker dictionary_ = *std::move(dictionary); 71*14675a02SAndroid Build Coastguard Worker parsed_dictionary_description.clear_vocabulary(); // Save space. 72*14675a02SAndroid Build Coastguard Worker } 73*14675a02SAndroid Build Coastguard Worker dictionary_description_ = parsed_dictionary_description; 74*14675a02SAndroid Build Coastguard Worker } 75*14675a02SAndroid Build Coastguard Worker Compute(tf::OpKernelContext * context)76*14675a02SAndroid Build Coastguard Worker void Compute(tf::OpKernelContext* context) override { 77*14675a02SAndroid Build Coastguard Worker FCP_CHECK(num_expected_inputs_ == context->num_inputs()); 78*14675a02SAndroid Build Coastguard Worker 79*14675a02SAndroid Build Coastguard Worker // Use the dictionary_ constructed at setup. 80*14675a02SAndroid Build Coastguard Worker OP_REQUIRES(context, dictionary_ != nullptr, 81*14675a02SAndroid Build Coastguard Worker tf::errors::InvalidArgument( 82*14675a02SAndroid Build Coastguard Worker "DictionaryDescription does not contain a vocabulary. ")); 83*14675a02SAndroid Build Coastguard Worker absl::Status status = DoCompute(context, *dictionary_); 84*14675a02SAndroid Build Coastguard Worker OP_REQUIRES(context, status.ok(), 85*14675a02SAndroid Build Coastguard Worker tf::errors::InvalidArgument(std::string(status.message()))); 86*14675a02SAndroid Build Coastguard Worker } 87*14675a02SAndroid Build Coastguard Worker 88*14675a02SAndroid Build Coastguard Worker protected: 89*14675a02SAndroid Build Coastguard Worker // Computes using the given dictionary. 90*14675a02SAndroid Build Coastguard Worker virtual absl::Status DoCompute(tf::OpKernelContext* context, 91*14675a02SAndroid Build Coastguard Worker const Dictionary& dictionary) = 0; 92*14675a02SAndroid Build Coastguard Worker 93*14675a02SAndroid Build Coastguard Worker private: 94*14675a02SAndroid Build Coastguard Worker DictionaryDescription dictionary_description_; 95*14675a02SAndroid Build Coastguard Worker std::unique_ptr<Dictionary> dictionary_; 96*14675a02SAndroid Build Coastguard Worker const int32_t num_expected_inputs_; 97*14675a02SAndroid Build Coastguard Worker }; 98*14675a02SAndroid Build Coastguard Worker 99*14675a02SAndroid Build Coastguard Worker } // namespace 100*14675a02SAndroid Build Coastguard Worker 101*14675a02SAndroid Build Coastguard Worker class DictionarySize : public AbstractDictionaryOp { 102*14675a02SAndroid Build Coastguard Worker public: DictionarySize(tf::OpKernelConstruction * context)103*14675a02SAndroid Build Coastguard Worker explicit DictionarySize(tf::OpKernelConstruction* context) 104*14675a02SAndroid Build Coastguard Worker : AbstractDictionaryOp(context, 0 /* num_expected_inputs */) {} 105*14675a02SAndroid Build Coastguard Worker 106*14675a02SAndroid Build Coastguard Worker protected: DoCompute(tf::OpKernelContext * context,const Dictionary & dictionary)107*14675a02SAndroid Build Coastguard Worker absl::Status DoCompute(tf::OpKernelContext* context, 108*14675a02SAndroid Build Coastguard Worker const Dictionary& dictionary) override { 109*14675a02SAndroid Build Coastguard Worker tf::Tensor* size_tensor; 110*14675a02SAndroid Build Coastguard Worker auto status = 111*14675a02SAndroid Build Coastguard Worker context->allocate_output(0, tf::TensorShape({}), &size_tensor); 112*14675a02SAndroid Build Coastguard Worker if (!status.ok()) { 113*14675a02SAndroid Build Coastguard Worker #if TF_GRAPH_DEF_VERSION < 1467 114*14675a02SAndroid Build Coastguard Worker return absl::InternalError(status.error_message()); 115*14675a02SAndroid Build Coastguard Worker #else 116*14675a02SAndroid Build Coastguard Worker return absl::InternalError(status.message()); 117*14675a02SAndroid Build Coastguard Worker #endif 118*14675a02SAndroid Build Coastguard Worker } 119*14675a02SAndroid Build Coastguard Worker size_tensor->flat<int64_t>()(0) = dictionary.Size(); 120*14675a02SAndroid Build Coastguard Worker return absl::OkStatus(); 121*14675a02SAndroid Build Coastguard Worker } 122*14675a02SAndroid Build Coastguard Worker }; 123*14675a02SAndroid Build Coastguard Worker 124*14675a02SAndroid Build Coastguard Worker REGISTER_KERNEL_BUILDER(Name("DictionarySize").Device(tf::DEVICE_CPU), 125*14675a02SAndroid Build Coastguard Worker DictionarySize); 126*14675a02SAndroid Build Coastguard Worker REGISTER_OP("DictionarySize") 127*14675a02SAndroid Build Coastguard Worker .Output("size: int64") 128*14675a02SAndroid Build Coastguard Worker .Attr("dictionary_description_proto: string = ''") 129*14675a02SAndroid Build Coastguard Worker .SetShapeFn(::tensorflow::shape_inference::ScalarShape) 130*14675a02SAndroid Build Coastguard Worker .Doc(R"doc( 131*14675a02SAndroid Build Coastguard Worker Returns the number of ids in the given dictionary. 132*14675a02SAndroid Build Coastguard Worker 133*14675a02SAndroid Build Coastguard Worker The dictionary should be fully specified at construction time via the 134*14675a02SAndroid Build Coastguard Worker dictionary_description_proto. 135*14675a02SAndroid Build Coastguard Worker 136*14675a02SAndroid Build Coastguard Worker dictionary_description_proto: A `DictionaryDescription` as a string. 137*14675a02SAndroid Build Coastguard Worker )doc"); 138*14675a02SAndroid Build Coastguard Worker 139*14675a02SAndroid Build Coastguard Worker class DictionaryLookup : public AbstractDictionaryOp { 140*14675a02SAndroid Build Coastguard Worker public: DictionaryLookup(tf::OpKernelConstruction * context)141*14675a02SAndroid Build Coastguard Worker explicit DictionaryLookup(tf::OpKernelConstruction* context) 142*14675a02SAndroid Build Coastguard Worker : AbstractDictionaryOp(context, 1 /* num_expected_inputs */) {} 143*14675a02SAndroid Build Coastguard Worker 144*14675a02SAndroid Build Coastguard Worker protected: DoCompute(tf::OpKernelContext * context,const Dictionary & dictionary)145*14675a02SAndroid Build Coastguard Worker absl::Status DoCompute(tf::OpKernelContext* context, 146*14675a02SAndroid Build Coastguard Worker const Dictionary& dictionary) override { 147*14675a02SAndroid Build Coastguard Worker const tf::Tensor& token_tensor = context->input(0); 148*14675a02SAndroid Build Coastguard Worker tf::Tensor* ids_tensor; 149*14675a02SAndroid Build Coastguard Worker auto status = 150*14675a02SAndroid Build Coastguard Worker context->allocate_output(0, token_tensor.shape(), &ids_tensor); 151*14675a02SAndroid Build Coastguard Worker if (!status.ok()) { 152*14675a02SAndroid Build Coastguard Worker #if TF_GRAPH_DEF_VERSION < 1467 153*14675a02SAndroid Build Coastguard Worker return absl::InternalError(status.error_message()); 154*14675a02SAndroid Build Coastguard Worker #else 155*14675a02SAndroid Build Coastguard Worker return absl::InternalError(status.message()); 156*14675a02SAndroid Build Coastguard Worker #endif 157*14675a02SAndroid Build Coastguard Worker } 158*14675a02SAndroid Build Coastguard Worker 159*14675a02SAndroid Build Coastguard Worker if (token_tensor.dtype() != tf::DataType::DT_STRING) { 160*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError("Expected input of 'tokens'."); 161*14675a02SAndroid Build Coastguard Worker } 162*14675a02SAndroid Build Coastguard Worker if (ids_tensor->dtype() != tf::DataType::DT_INT64) { 163*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError("Expected output of 'ids'."); 164*14675a02SAndroid Build Coastguard Worker } 165*14675a02SAndroid Build Coastguard Worker if (token_tensor.shape() != ids_tensor->shape()) { 166*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError("Wrong shape for ids_tensor"); 167*14675a02SAndroid Build Coastguard Worker } 168*14675a02SAndroid Build Coastguard Worker const auto tokens_flat = token_tensor.flat<tf::tstring>(); 169*14675a02SAndroid Build Coastguard Worker auto ids_flat = ids_tensor->flat<int64_t>(); 170*14675a02SAndroid Build Coastguard Worker const int64_t num_tokens = tokens_flat.size(); 171*14675a02SAndroid Build Coastguard Worker for (int i = 0; i < num_tokens; ++i) { 172*14675a02SAndroid Build Coastguard Worker ids_flat(i) = dictionary.TokenToId(tokens_flat(i)); 173*14675a02SAndroid Build Coastguard Worker } 174*14675a02SAndroid Build Coastguard Worker return absl::OkStatus(); 175*14675a02SAndroid Build Coastguard Worker } 176*14675a02SAndroid Build Coastguard Worker }; 177*14675a02SAndroid Build Coastguard Worker 178*14675a02SAndroid Build Coastguard Worker REGISTER_KERNEL_BUILDER(Name("DictionaryLookup").Device(tf::DEVICE_CPU), 179*14675a02SAndroid Build Coastguard Worker DictionaryLookup); 180*14675a02SAndroid Build Coastguard Worker REGISTER_OP("DictionaryLookup") 181*14675a02SAndroid Build Coastguard Worker .Input("tokens: string") 182*14675a02SAndroid Build Coastguard Worker .Output("token_ids: int64") 183*14675a02SAndroid Build Coastguard Worker .Attr("dictionary_description_proto: string = ''") 184*14675a02SAndroid Build Coastguard Worker .SetShapeFn(::tensorflow::shape_inference::UnchangedShape) 185*14675a02SAndroid Build Coastguard Worker .Doc(R"doc( 186*14675a02SAndroid Build Coastguard Worker Maps each string to an id by lookup in the dictionary. 187*14675a02SAndroid Build Coastguard Worker 188*14675a02SAndroid Build Coastguard Worker The dictionary should be fully specified at construction time via the 189*14675a02SAndroid Build Coastguard Worker dictionary_description_proto. Output has the same shape as input. 190*14675a02SAndroid Build Coastguard Worker 191*14675a02SAndroid Build Coastguard Worker tokens: A `Tensor` of strings to lookup in the dictionary. 192*14675a02SAndroid Build Coastguard Worker dictionary_description_proto: A `DictionaryDescription` as a string. 193*14675a02SAndroid Build Coastguard Worker )doc"); 194*14675a02SAndroid Build Coastguard Worker 195*14675a02SAndroid Build Coastguard Worker class DictionaryReverseLookup : public AbstractDictionaryOp { 196*14675a02SAndroid Build Coastguard Worker public: DictionaryReverseLookup(tf::OpKernelConstruction * context)197*14675a02SAndroid Build Coastguard Worker explicit DictionaryReverseLookup(tf::OpKernelConstruction* context) 198*14675a02SAndroid Build Coastguard Worker : AbstractDictionaryOp(context, 1 /* num_expected_inputs */) {} 199*14675a02SAndroid Build Coastguard Worker 200*14675a02SAndroid Build Coastguard Worker protected: DoCompute(tf::OpKernelContext * context,const Dictionary & dictionary)201*14675a02SAndroid Build Coastguard Worker absl::Status DoCompute(tf::OpKernelContext* context, 202*14675a02SAndroid Build Coastguard Worker const Dictionary& dictionary) override { 203*14675a02SAndroid Build Coastguard Worker const tf::Tensor& ids_tensor = context->input(0); 204*14675a02SAndroid Build Coastguard Worker tf::Tensor* token_tensor; 205*14675a02SAndroid Build Coastguard Worker auto status = 206*14675a02SAndroid Build Coastguard Worker context->allocate_output(0, ids_tensor.shape(), &token_tensor); 207*14675a02SAndroid Build Coastguard Worker if (!status.ok()) { 208*14675a02SAndroid Build Coastguard Worker #if TF_GRAPH_DEF_VERSION < 1467 209*14675a02SAndroid Build Coastguard Worker return absl::InternalError(status.error_message()); 210*14675a02SAndroid Build Coastguard Worker #else 211*14675a02SAndroid Build Coastguard Worker return absl::InternalError(status.message()); 212*14675a02SAndroid Build Coastguard Worker #endif 213*14675a02SAndroid Build Coastguard Worker } 214*14675a02SAndroid Build Coastguard Worker 215*14675a02SAndroid Build Coastguard Worker if (token_tensor->dtype() != tf::DataType::DT_STRING) { 216*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError("Expected input of 'tokens'."); 217*14675a02SAndroid Build Coastguard Worker } 218*14675a02SAndroid Build Coastguard Worker if (ids_tensor.dtype() != tf::DataType::DT_INT64) { 219*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError("Expected output of 'ids'."); 220*14675a02SAndroid Build Coastguard Worker } 221*14675a02SAndroid Build Coastguard Worker if (ids_tensor.shape() != token_tensor->shape()) { 222*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError("Wrong shape for token_tensor"); 223*14675a02SAndroid Build Coastguard Worker } 224*14675a02SAndroid Build Coastguard Worker 225*14675a02SAndroid Build Coastguard Worker const auto ids_flat = ids_tensor.flat<int64_t>(); 226*14675a02SAndroid Build Coastguard Worker auto tokens_flat = token_tensor->flat<tf::tstring>(); 227*14675a02SAndroid Build Coastguard Worker const int64_t num_tokens = ids_flat.size(); 228*14675a02SAndroid Build Coastguard Worker for (int i = 0; i < num_tokens; ++i) { 229*14675a02SAndroid Build Coastguard Worker tokens_flat(i) = dictionary.IdToToken(ids_flat(i)); 230*14675a02SAndroid Build Coastguard Worker } 231*14675a02SAndroid Build Coastguard Worker return absl::OkStatus(); 232*14675a02SAndroid Build Coastguard Worker } 233*14675a02SAndroid Build Coastguard Worker }; 234*14675a02SAndroid Build Coastguard Worker 235*14675a02SAndroid Build Coastguard Worker REGISTER_KERNEL_BUILDER(Name("DictionaryReverseLookup").Device(tf::DEVICE_CPU), 236*14675a02SAndroid Build Coastguard Worker DictionaryReverseLookup); 237*14675a02SAndroid Build Coastguard Worker REGISTER_OP("DictionaryReverseLookup") 238*14675a02SAndroid Build Coastguard Worker .Input("token_ids: int64") 239*14675a02SAndroid Build Coastguard Worker .Output("tokens: string") 240*14675a02SAndroid Build Coastguard Worker .Attr("dictionary_description_proto: string = ''") 241*14675a02SAndroid Build Coastguard Worker .SetShapeFn(::tensorflow::shape_inference::UnchangedShape) 242*14675a02SAndroid Build Coastguard Worker .Doc(R"doc( 243*14675a02SAndroid Build Coastguard Worker Maps each id to its string by performing a reverse lookup in the dictionary. 244*14675a02SAndroid Build Coastguard Worker 245*14675a02SAndroid Build Coastguard Worker The dictionary should be fully specified at construction time via the 246*14675a02SAndroid Build Coastguard Worker dictionary_description_proto. Output has the same shape as input. 247*14675a02SAndroid Build Coastguard Worker 248*14675a02SAndroid Build Coastguard Worker token_ids: A `Tensor` of int64 ids to lookup in the dictionary. 249*14675a02SAndroid Build Coastguard Worker dictionary_description_proto: A `DictionaryDescription` as a string. 250*14675a02SAndroid Build Coastguard Worker )doc"); 251*14675a02SAndroid Build Coastguard Worker } // namespace tensorflow 252*14675a02SAndroid Build Coastguard Worker } // namespace fcp 253