xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/task_eligibility_info_ops.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <string>
18 
19 #include "fcp/protos/federated_api.pb.h"
20 #include "tensorflow/core/framework/common_shape_fns.h"
21 #include "tensorflow/core/framework/dataset.h"
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/shape_inference.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/public/version.h"
27 
28 namespace fcp {
29 
30 using ::google::internal::federatedml::v2::TaskEligibilityInfo;
31 using ::google::internal::federatedml::v2::TaskWeight;
32 
33 /**
34  * CreateTaskEligibilityInfo op-kernel. Converts a set of input tensors into a
35  * `TaskEligibilityInfo` proto serialized into a string tensor.
36  *
37  * This op is used to generate `TaskEligibilityInfo` protos from a model at
38  * runtime, since TF Mobile does not support the standard TensorFlow ops for
39  * encoding/decoding protos.
40  */
41 class CreateTaskEligibilityInfoOp : public tensorflow::OpKernel {
42  public:
CreateTaskEligibilityInfoOp(tensorflow::OpKernelConstruction * context)43   explicit CreateTaskEligibilityInfoOp(
44       tensorflow::OpKernelConstruction* context)
45       : OpKernel(context) {}
46 
Compute(tensorflow::OpKernelContext * ctx)47   void Compute(tensorflow::OpKernelContext* ctx) override {
48     // Note: We use the tensorflow::data::ParseScalar/VectorArgument helpers
49     // here, even though this op isn't strictly related to our tf.Dataset
50     // integration. The helpers are public though, and we already use them in
51     // our ExternalDataset implementation, so we might as well use them here
52     // too.
53 
54     // Parse/validate the input arguments.
55     tensorflow::int64 version;
56     OP_REQUIRES_OK(
57         ctx, tensorflow::data::ParseScalarArgument(ctx, "version", &version));
58     std::vector<tensorflow::tstring> task_names;
59     OP_REQUIRES_OK(ctx, tensorflow::data::ParseVectorArgument(ctx, "task_names",
60                                                               &task_names));
61     std::vector<float> task_weights;
62     OP_REQUIRES_OK(ctx, tensorflow::data::ParseVectorArgument(
63                             ctx, "task_weights", &task_weights));
64     OP_REQUIRES(ctx, task_names.size() == task_weights.size(),
65                 tensorflow::errors::InvalidArgument(absl::StrCat(
66                     "task_names length must match task_weights length: ",
67                     task_names.size(), " vs. ", task_weights.size())));
68 
69     // Create the output proto, based on the inputs.
70     TaskEligibilityInfo eligibility_info;
71     eligibility_info.set_version(version);
72     // Create a `TaskWeight` message for each pair of `task_names` and
73     // `task_weights` elements.
74     auto task_weight_it = task_weights.cbegin();
75     for (const tensorflow::tstring& task_name : task_names) {
76       float task_weight = *task_weight_it++;
77       TaskWeight* task_weight_proto = eligibility_info.add_task_weights();
78       task_weight_proto->set_task_name(std::string(task_name));
79       task_weight_proto->set_weight(task_weight);
80     }
81 
82     // Place the serialized output proto into the output tensor.
83     tensorflow::Tensor* output_tensor;
84     OP_REQUIRES_OK(ctx,
85                    ctx->allocate_output("output", tensorflow::TensorShape({}),
86                                         &output_tensor));
87     output_tensor->scalar<tensorflow::tstring>()() =
88         eligibility_info.SerializeAsString();
89   }
90 };
91 
92 REGISTER_OP("CreateTaskEligibilityInfo")
93     .Input("version: int64")
94     .Input("task_names: string")
95     .Input("task_weights: float32")
96     .Output("output: string")
97     .SetShapeFn(tensorflow::shape_inference::ScalarShape);
98 
99 REGISTER_KERNEL_BUILDER(
100     Name("CreateTaskEligibilityInfo").Device(tensorflow::DEVICE_CPU),
101     CreateTaskEligibilityInfoOp);
102 
103 }  // namespace fcp
104