xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/serve_slices_op.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "fcp/tensorflow/serve_slices_registry.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
29 
30 namespace fcp {
31 
32 namespace {
33 
34 REGISTER_OP("ServeSlices")
35     .Attr("NumTensorsInServerVal: int")
36     .Attr("ServerValType: list(type)")
37     .Input("callback_token: string")
38     .Input("server_val: ServerValType")
39     .Input("max_key: int32")
40     .Input("select_fn_initialize_op: string")
41     .Input(
42         "select_fn_server_val_input_tensor_names: NumTensorsInServerVal * "
43         "string")
44     .Input("select_fn_key_input_tensor_name: string")
45     .Input("select_fn_filename_input_tensor_name: string")
46     .Input("select_fn_target_tensor_name: string")
47     .Output("served_at_id: string")
48     .SetIsStateful()
49     .SetShapeFn(tensorflow::shape_inference::ScalarShape);
50 
51 template <class T>
get_scalar_input(tensorflow::OpKernelContext * context,tensorflow::StringPiece name,T * scalar_out)52 tensorflow::Status get_scalar_input(tensorflow::OpKernelContext* context,
53                                     tensorflow::StringPiece name,
54                                     T* scalar_out) {
55   const tensorflow::Tensor* tensor;
56   TF_RETURN_IF_ERROR(context->input(name, &tensor));
57   *scalar_out = tensor->scalar<T>()();
58   return tensorflow::OkStatus();
59 }
60 
get_arbitrary_input_list_as_tensor_vector(tensorflow::OpKernelContext * context,tensorflow::StringPiece name,std::vector<tensorflow::Tensor> * out)61 tensorflow::Status get_arbitrary_input_list_as_tensor_vector(
62     tensorflow::OpKernelContext* context, tensorflow::StringPiece name,
63     std::vector<tensorflow::Tensor>* out) {
64   tensorflow::OpInputList input_list;
65   TF_RETURN_IF_ERROR(context->input_list(name, &input_list));
66   out->reserve(input_list.size());
67   for (const tensorflow::Tensor& tensor : input_list) {
68     out->push_back(tensor);
69   }
70   return tensorflow::OkStatus();
71 }
72 
get_string_list_input(tensorflow::OpKernelContext * context,tensorflow::StringPiece name,std::vector<std::string> * out)73 tensorflow::Status get_string_list_input(tensorflow::OpKernelContext* context,
74                                          tensorflow::StringPiece name,
75                                          std::vector<std::string>* out) {
76   tensorflow::OpInputList input_list;
77   TF_RETURN_IF_ERROR(context->input_list(name, &input_list));
78   out->reserve(input_list.size());
79   for (const tensorflow::Tensor& tensor : input_list) {
80     out->emplace_back(tensor.scalar<tensorflow::tstring>()());
81   }
82   return tensorflow::OkStatus();
83 }
84 
85 // ServeSlices op-kernel.
86 //
87 // The ServeSlicesOp registers values present on a federated computation server
88 // to be sliced and served to clients for a `federated_select`
89 //
90 // Inputs:
91 //   callback_token: The ID of the C++ callback to invoke in order to register
92 //   the
93 //     given value. Callbacks must first be registered using
94 //     `register_serve_slices_callback`.
95 //   server_val: A series of arbitrary-typed tensors from which slices may be
96 //     generated using a selection function (referred to as `select_fn`).
97 //     These tensors must be passed into the `select_fn` by writing them to the
98 //     placeholder tensors named by `select_fn_server_val_input_names`, which
99 //     must contain exactly one tensor name for each tensor in `server_val`.
100 //   max_key: An integer indicating the maximum slice index which may be
101 //     requested. Slice indices start at zero and may go up to `max_key`
102 //     (inclusive).
103 //   select_fn_initialize_op: An op to run before each call to `select_fn` in
104 //     order to reinitialize any state `select_fn` may contain.
105 //   select_fn_server_val_input_tensor_names: A list of names of the tensors
106 //     that make up the `server_val` portion of the inputs to `select_fn`. Must
107 //     be the same length as the number of tensors in `server_val`.
108 //   select_fn_key_input_tensor_name: The name of the tensor that is the `key`
109 //     input to `select_fn`.
110 //   select_fn_filename_input_tensor_name: The name of the placeholder tensor
111 //     that is the `filename` input to `select_fn`. The `filename` is used to
112 //     specify where the resulting slice should be written.
113 //   select_fn_target_tensor_name: The name of the `target` tensor to run which
114 //     will result in `select_fn`'s output being written to `filename`.
115 //
116 // Outputs:
117 //   served_at_id: A string ID under which the resulting slices will be served.
118 //     This can then be provided to the `FetchSlicesOp` running on clients.
119 class ServeSlicesOp : public tensorflow::OpKernel {
120  public:
ServeSlicesOp(tensorflow::OpKernelConstruction * context)121   explicit ServeSlicesOp(tensorflow::OpKernelConstruction* context)
122       : OpKernel(context) {}
123 
Compute(tensorflow::OpKernelContext * context)124   void Compute(tensorflow::OpKernelContext* context) override {
125     tensorflow::tstring callback_token_tensor;
126     OP_REQUIRES_OK(context, get_scalar_input(context, "callback_token",
127                                              &callback_token_tensor));
128     absl::Span<char const> callback_token_bytes = callback_token_tensor;
129     OP_REQUIRES(context, callback_token_bytes.size() == kRandomTokenSizeInBytes,
130                 tensorflow::errors::InvalidArgument(absl::StrFormat(
131                     "Tokens have a fixed size. Expected: %d; Actual %d",
132                     kRandomTokenSizeInBytes, callback_token_bytes.size())));
133     RandomToken callback_token = RandomToken::FromBytes(callback_token_bytes);
134 
135     std::vector<tensorflow::Tensor> server_val;
136     OP_REQUIRES_OK(context, get_arbitrary_input_list_as_tensor_vector(
137                                 context, "server_val", &server_val));
138 
139     int32_t max_key;
140     OP_REQUIRES_OK(context, get_scalar_input(context, "max_key", &max_key));
141 
142     tensorflow::tstring select_fn_initialize_op;
143     OP_REQUIRES_OK(context, get_scalar_input(context, "select_fn_initialize_op",
144                                              &select_fn_initialize_op));
145 
146     std::vector<std::string> select_fn_server_val_input_tensor_names;
147     OP_REQUIRES_OK(context,
148                    get_string_list_input(
149                        context, "select_fn_server_val_input_tensor_names",
150                        &select_fn_server_val_input_tensor_names));
151 
152     tensorflow::tstring select_fn_key_input_tensor_name;
153     OP_REQUIRES_OK(context,
154                    get_scalar_input(context, "select_fn_key_input_tensor_name",
155                                     &select_fn_key_input_tensor_name));
156 
157     tensorflow::tstring select_fn_filename_input_tensor_name;
158     OP_REQUIRES_OK(context, get_scalar_input(
159                                 context, "select_fn_filename_input_tensor_name",
160                                 &select_fn_filename_input_tensor_name));
161 
162     tensorflow::tstring select_fn_target_tensor_name;
163     OP_REQUIRES_OK(context,
164                    get_scalar_input(context, "select_fn_target_tensor_name",
165                                     &select_fn_target_tensor_name));
166 
167     std::optional<std::shared_ptr<ServeSlicesCallback>> callback =
168         get_serve_slices_callback(callback_token);
169     OP_REQUIRES(context, callback.has_value(),
170                 tensorflow::errors::InvalidArgument(
171                     absl::StrCat("No `ServeSlices` callback found for token ",
172                                  callback_token.ToPrintableString())));
173     std::string served_at_id =
174         (**callback)(callback_token, std::move(server_val), max_key,
175                      std::move(select_fn_initialize_op),
176                      std::move(select_fn_server_val_input_tensor_names),
177                      std::move(select_fn_key_input_tensor_name),
178                      std::move(select_fn_filename_input_tensor_name),
179                      std::move(select_fn_target_tensor_name));
180 
181     tensorflow::Tensor* output_tensor;
182     OP_REQUIRES_OK(context, context->allocate_output(0, {}, &output_tensor));
183     output_tensor->scalar<tensorflow::tstring>()() = std::move(served_at_id);
184   }
185 };
186 
187 REGISTER_KERNEL_BUILDER(Name("ServeSlices").Device(tensorflow::DEVICE_CPU),
188                         ServeSlicesOp);
189 
190 }  // namespace
191 
192 }  // namespace fcp
193