xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/example_selector_fuser_op.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <string>
17 
18 #include "fcp/protos/plan.pb.h"
19 #include "google/protobuf/any.pb.h"
20 #include "tensorflow/core/framework/common_shape_fns.h"
21 #include "tensorflow/core/framework/dataset.h"
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 
27 namespace fcp {
28 
29 using ::google::internal::federated::plan::ExampleSelector;
30 using ::tensorflow::DEVICE_CPU;
31 using ::tensorflow::OpKernel;
32 using ::tensorflow::OpKernelConstruction;
33 using ::tensorflow::OpKernelContext;
34 using ::tensorflow::Tensor;
35 using ::tensorflow::data::ParseScalarArgument;
36 
37 /**
38  * ExampleSelectorFuserOp op-kernel.
39  *
40  * ExampleSelectorFuser fills the resumption token field for an existing
41  * ExampleSelector protobuf message. The resumption token field is an Any proto
42  * which can be any user defined protobuf message. The user needs to provide the
43  * type url and content for the resumption token.
44  *
45  * Inputs:
46  *   example_selector: A string scalar encodes an ExampleSelector protobuf
47  *   message.
48  *   resumption_token_type_url: String scalar. The type_url for the resumption
49  *   token.
50  *   resumption_token_content: String scalar.  The bytes for the resumption
51  *   token message.
52  *
53  * Output:
54  *   A string tensor contains the fused ExampleSelector message serialized to
55  * string.
56  */
57 class ExampleSelectorFuserOp : public OpKernel {
58  public:
ExampleSelectorFuserOp(OpKernelConstruction * context)59   explicit ExampleSelectorFuserOp(OpKernelConstruction* context)
60       : OpKernel(context) {}
61 
Compute(OpKernelContext * ctx)62   void Compute(OpKernelContext* ctx) override {
63     tensorflow::tstring example_selector_str;
64     OP_REQUIRES_OK(ctx, ParseScalarArgument<tensorflow::tstring>(
65                             ctx, "example_selector", &example_selector_str));
66     tensorflow::tstring resumption_token_type_url_str;
67     OP_REQUIRES_OK(ctx, ParseScalarArgument<tensorflow::tstring>(
68                             ctx, "resumption_token_type_url",
69                             &resumption_token_type_url_str));
70     tensorflow::tstring resumption_token_content_str;
71     OP_REQUIRES_OK(ctx, ParseScalarArgument<tensorflow::tstring>(
72                             ctx, "resumption_token_content",
73                             &resumption_token_content_str));
74     ExampleSelector example_selector;
75     if (!example_selector.ParseFromString(
76             std::string(example_selector_str.data()))) {
77       ctx->SetStatus(tensorflow::Status(
78           // Remove the cast after TF 2.12 is released and used in FCP.
79           static_cast<tensorflow::errors::Code>(
80               absl::StatusCode::kInvalidArgument),
81           tensorflow::StringPiece("Cannot parse ExampleSelector")));
82       return;
83     }
84     example_selector.mutable_resumption_token()->set_type_url(
85         std::string(resumption_token_type_url_str.data()));
86     example_selector.mutable_resumption_token()->set_value(
87         std::string(resumption_token_content_str.data()));
88 
89     Tensor* output_tensor = nullptr;
90     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output_tensor));
91     output_tensor->flat<tensorflow::tstring>()(0) =
92         example_selector.SerializeAsString();
93   }
94 };
95 
// Declares the "ExampleSelectorFuser" op: three string inputs
// (example_selector, resumption_token_type_url, resumption_token_content) and
// one string output (fused_example_selector). Shape inference uses
// tensorflow::shape_inference::ScalarShape.
96 REGISTER_OP("ExampleSelectorFuser")
97     .Input("example_selector: string")
98     .Input("resumption_token_type_url: string")
99     .Input("resumption_token_content: string")
100     .Output("fused_example_selector: string")
101     .SetShapeFn(tensorflow::shape_inference::ScalarShape);
// Registers ExampleSelectorFuserOp as the kernel for the "ExampleSelectorFuser"
// op on DEVICE_CPU.
102 REGISTER_KERNEL_BUILDER(Name("ExampleSelectorFuser").Device(DEVICE_CPU),
103                         ExampleSelectorFuserOp);
104 }  // namespace fcp
105