/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string>

#include "absl/strings/str_cat.h"
#include "fcp/protos/federated_api.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"

namespace fcp {

using ::google::internal::federatedml::v2::TaskEligibilityInfo;
using ::google::internal::federatedml::v2::TaskWeight;

/**
 * CreateTaskEligibilityInfo op-kernel. Converts a set of input tensors into a
 * `TaskEligibilityInfo` proto serialized into a string tensor.
 *
 * This op is used to generate `TaskEligibilityInfo` protos from a model at
 * runtime, since TF Mobile does not support the standard TensorFlow ops for
 * encoding/decoding protos.
 */
class CreateTaskEligibilityInfoOp : public tensorflow::OpKernel {
 public:
  explicit CreateTaskEligibilityInfoOp(
      tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(tensorflow::OpKernelContext* ctx) override {
    // Note: We use the tensorflow::data::ParseScalar/VectorArgument helpers
    // here, even though this op isn't strictly related to our tf.Dataset
    // integration. The helpers are public though, and we already use them in
    // our ExternalDataset implementation, so we might as well use them here
    // too.

    // Parse/validate the input arguments.
    tensorflow::int64 version;
    OP_REQUIRES_OK(
        ctx, tensorflow::data::ParseScalarArgument(ctx, "version", &version));
    std::vector<tensorflow::tstring> task_names;
    OP_REQUIRES_OK(ctx, tensorflow::data::ParseVectorArgument(ctx, "task_names",
                                                              &task_names));
    std::vector<float> task_weights;
    OP_REQUIRES_OK(ctx, tensorflow::data::ParseVectorArgument(
                            ctx, "task_weights", &task_weights));
    OP_REQUIRES(ctx, task_names.size() == task_weights.size(),
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "task_names length must match task_weights length: ",
                    task_names.size(), " vs. ", task_weights.size())));

    // Create the output proto, based on the inputs.
    TaskEligibilityInfo eligibility_info;
    eligibility_info.set_version(version);
    // Create a `TaskWeight` message for each pair of `task_names` and
    // `task_weights` elements.
    auto task_weight_it = task_weights.cbegin();
    for (const tensorflow::tstring& task_name : task_names) {
      float task_weight = *task_weight_it++;
      TaskWeight* task_weight_proto = eligibility_info.add_task_weights();
      task_weight_proto->set_task_name(std::string(task_name));
      task_weight_proto->set_weight(task_weight);
    }

    // Place the serialized output proto into the output tensor.
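    // Illustrative example (hypothetical inputs, not part of the original
    // source): with version=1, task_names=["task_a", "task_b"] and
    // task_weights=[1.0, 0.5], the serialized scalar string output decodes to
    // the following `TaskEligibilityInfo` in proto text format:
    //
    //   version: 1
    //   task_weights { task_name: "task_a" weight: 1.0 }
    //   task_weights { task_name: "task_b" weight: 0.5 }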
    tensorflow::Tensor* output_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("output", tensorflow::TensorShape({}),
                                        &output_tensor));
    output_tensor->scalar<tensorflow::tstring>()() =
        eligibility_info.SerializeAsString();
  }
};

REGISTER_OP("CreateTaskEligibilityInfo")
    .Input("version: int64")
    .Input("task_names: string")
    .Input("task_weights: float32")
    .Output("output: string")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(
    Name("CreateTaskEligibilityInfo").Device(tensorflow::DEVICE_CPU),
    CreateTaskEligibilityInfoOp);

}  // namespace fcp