1 /* 2 * Copyright 2021 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <string> 17 18 #include "fcp/protos/plan.pb.h" 19 #include "google/protobuf/any.pb.h" 20 #include "tensorflow/core/framework/common_shape_fns.h" 21 #include "tensorflow/core/framework/dataset.h" 22 #include "tensorflow/core/framework/op.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/lib/core/errors.h" 26 27 namespace fcp { 28 29 using ::google::internal::federated::plan::ExampleSelector; 30 using ::tensorflow::DEVICE_CPU; 31 using ::tensorflow::OpKernel; 32 using ::tensorflow::OpKernelConstruction; 33 using ::tensorflow::OpKernelContext; 34 using ::tensorflow::Tensor; 35 using ::tensorflow::data::ParseScalarArgument; 36 37 /** 38 * ExampleSelectorFuserOp op-kernel. 39 * 40 * ExampleSelectorFuser fills the resumption token field for an existing 41 * ExampleSelector protobuf message. The resumption token field is an Any proto 42 * which can be any user defined protobuf message. The user needs to provide the 43 * type url and content for the resumption token. 44 * 45 * Inputs: 46 * example_selector: A string scalar encodes an ExampleSelector protobuf 47 * message. 48 * resumption_token_type_url: String scalar. The type_url for the resumption 49 * token. 50 * resumption_token_content: String scalar. The bytes for the resumption 51 * token message. 52 * 53 * Output: 54 * A string tensor contains the fused ExampleSelector message serialized to 55 * string. 56 */ 57 class ExampleSelectorFuserOp : public OpKernel { 58 public: ExampleSelectorFuserOp(OpKernelConstruction * context)59 explicit ExampleSelectorFuserOp(OpKernelConstruction* context) 60 : OpKernel(context) {} 61 Compute(OpKernelContext * ctx)62 void Compute(OpKernelContext* ctx) override { 63 tensorflow::tstring example_selector_str; 64 OP_REQUIRES_OK(ctx, ParseScalarArgument<tensorflow::tstring>( 65 ctx, "example_selector", &example_selector_str)); 66 tensorflow::tstring resumption_token_type_url_str; 67 OP_REQUIRES_OK(ctx, ParseScalarArgument<tensorflow::tstring>( 68 ctx, "resumption_token_type_url", 69 &resumption_token_type_url_str)); 70 tensorflow::tstring resumption_token_content_str; 71 OP_REQUIRES_OK(ctx, ParseScalarArgument<tensorflow::tstring>( 72 ctx, "resumption_token_content", 73 &resumption_token_content_str)); 74 ExampleSelector example_selector; 75 if (!example_selector.ParseFromString( 76 std::string(example_selector_str.data()))) { 77 ctx->SetStatus(tensorflow::Status( 78 // Remove the cast after TF 2.12 is released and used in FCP. 79 static_cast<tensorflow::errors::Code>( 80 absl::StatusCode::kInvalidArgument), 81 tensorflow::StringPiece("Cannot parse ExampleSelector"))); 82 return; 83 } 84 example_selector.mutable_resumption_token()->set_type_url( 85 std::string(resumption_token_type_url_str.data())); 86 example_selector.mutable_resumption_token()->set_value( 87 std::string(resumption_token_content_str.data())); 88 89 Tensor* output_tensor = nullptr; 90 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output_tensor)); 91 output_tensor->flat<tensorflow::tstring>()(0) = 92 example_selector.SerializeAsString(); 93 } 94 }; 95 96 REGISTER_OP("ExampleSelectorFuser") 97 .Input("example_selector: string") 98 .Input("resumption_token_type_url: string") 99 .Input("resumption_token_content: string") 100 .Output("fused_example_selector: string") 101 .SetShapeFn(tensorflow::shape_inference::ScalarShape); 102 REGISTER_KERNEL_BUILDER(Name("ExampleSelectorFuser").Device(DEVICE_CPU), 103 ExampleSelectorFuserOp); 104 } // namespace fcp 105