1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "fcp/tensorflow/serve_slices_registry.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
29
30 namespace fcp {
31
32 namespace {
33
// Registers the `ServeSlices` op, implemented on CPU by `ServeSlicesOp`.
//
// Attrs:
//   NumTensorsInServerVal: the number of string tensors in the
//     `select_fn_server_val_input_tensor_names` list input.
//   ServerValType: the dtypes of the tensors in the `server_val` list input,
//     one entry per tensor.
// NOTE(review): nothing in this registration requires `NumTensorsInServerVal`
// to equal the length of `ServerValType`; confirm that mismatches are rejected
// somewhere before the callback receives the two lists.
//
// The op is marked stateful because running it invokes a registered C++
// callback (presumably so TensorFlow never constant-folds or deduplicates it —
// confirm). Its single output, `served_at_id`, is a scalar (`ScalarShape`).
REGISTER_OP("ServeSlices")
    .Attr("NumTensorsInServerVal: int")
    .Attr("ServerValType: list(type)")
    .Input("callback_token: string")
    .Input("server_val: ServerValType")
    .Input("max_key: int32")
    .Input("select_fn_initialize_op: string")
    .Input(
        "select_fn_server_val_input_tensor_names: NumTensorsInServerVal * "
        "string")
    .Input("select_fn_key_input_tensor_name: string")
    .Input("select_fn_filename_input_tensor_name: string")
    .Input("select_fn_target_tensor_name: string")
    .Output("served_at_id: string")
    .SetIsStateful()
    .SetShapeFn(tensorflow::shape_inference::ScalarShape);
50
51 template <class T>
get_scalar_input(tensorflow::OpKernelContext * context,tensorflow::StringPiece name,T * scalar_out)52 tensorflow::Status get_scalar_input(tensorflow::OpKernelContext* context,
53 tensorflow::StringPiece name,
54 T* scalar_out) {
55 const tensorflow::Tensor* tensor;
56 TF_RETURN_IF_ERROR(context->input(name, &tensor));
57 *scalar_out = tensor->scalar<T>()();
58 return tensorflow::OkStatus();
59 }
60
get_arbitrary_input_list_as_tensor_vector(tensorflow::OpKernelContext * context,tensorflow::StringPiece name,std::vector<tensorflow::Tensor> * out)61 tensorflow::Status get_arbitrary_input_list_as_tensor_vector(
62 tensorflow::OpKernelContext* context, tensorflow::StringPiece name,
63 std::vector<tensorflow::Tensor>* out) {
64 tensorflow::OpInputList input_list;
65 TF_RETURN_IF_ERROR(context->input_list(name, &input_list));
66 out->reserve(input_list.size());
67 for (const tensorflow::Tensor& tensor : input_list) {
68 out->push_back(tensor);
69 }
70 return tensorflow::OkStatus();
71 }
72
get_string_list_input(tensorflow::OpKernelContext * context,tensorflow::StringPiece name,std::vector<std::string> * out)73 tensorflow::Status get_string_list_input(tensorflow::OpKernelContext* context,
74 tensorflow::StringPiece name,
75 std::vector<std::string>* out) {
76 tensorflow::OpInputList input_list;
77 TF_RETURN_IF_ERROR(context->input_list(name, &input_list));
78 out->reserve(input_list.size());
79 for (const tensorflow::Tensor& tensor : input_list) {
80 out->emplace_back(tensor.scalar<tensorflow::tstring>()());
81 }
82 return tensorflow::OkStatus();
83 }
84
85 // ServeSlices op-kernel.
86 //
87 // The ServeSlicesOp registers values present on a federated computation server
88 // to be sliced and served to clients for a `federated_select`
89 //
90 // Inputs:
91 // callback_token: The ID of the C++ callback to invoke in order to register
92 // the
93 // given value. Callbacks must first be registered using
94 // `register_serve_slices_callback`.
95 // server_val: A series of arbitrary-typed tensors from which slices may be
96 // generated using a selection function (referred to as `select_fn`).
97 // These tensors must be passed into the `select_fn` by writing them to the
98 // placeholder tensors named by `select_fn_server_val_input_names`, which
99 // must contain exactly one tensor name for each tensor in `server_val`.
100 // max_key: An integer indicating the maximum slice index which may be
101 // requested. Slice indices start at zero and may go up to `max_key`
102 // (inclusive).
103 // select_fn_initialize_op: An op to run before each call to `select_fn` in
104 // order to reinitialize any state `select_fn` may contain.
105 // select_fn_server_val_input_tensor_names: A list of names of the tensors
106 // that make up the `server_val` portion of the inputs to `select_fn`. Must
107 // be the same length as the number of tensors in `server_val`.
108 // select_fn_key_input_tensor_name: The name of the tensor that is the `key`
109 // input to `select_fn`.
110 // select_fn_filename_input_tensor_name: The name of the placeholder tensor
111 // that is the `filename` input to `select_fn`. The `filename` is used to
112 // specify where the resulting slice should be written.
113 // select_fn_target_tensor_name: The name of the `target` tensor to run which
114 // will result in `select_fn`'s output being written to `filename`.
115 //
116 // Outputs:
117 // served_at_id: A string ID under which the resulting slices will be served.
118 // This can then be provided to the `FetchSlicesOp` running on clients.
119 class ServeSlicesOp : public tensorflow::OpKernel {
120 public:
ServeSlicesOp(tensorflow::OpKernelConstruction * context)121 explicit ServeSlicesOp(tensorflow::OpKernelConstruction* context)
122 : OpKernel(context) {}
123
Compute(tensorflow::OpKernelContext * context)124 void Compute(tensorflow::OpKernelContext* context) override {
125 tensorflow::tstring callback_token_tensor;
126 OP_REQUIRES_OK(context, get_scalar_input(context, "callback_token",
127 &callback_token_tensor));
128 absl::Span<char const> callback_token_bytes = callback_token_tensor;
129 OP_REQUIRES(context, callback_token_bytes.size() == kRandomTokenSizeInBytes,
130 tensorflow::errors::InvalidArgument(absl::StrFormat(
131 "Tokens have a fixed size. Expected: %d; Actual %d",
132 kRandomTokenSizeInBytes, callback_token_bytes.size())));
133 RandomToken callback_token = RandomToken::FromBytes(callback_token_bytes);
134
135 std::vector<tensorflow::Tensor> server_val;
136 OP_REQUIRES_OK(context, get_arbitrary_input_list_as_tensor_vector(
137 context, "server_val", &server_val));
138
139 int32_t max_key;
140 OP_REQUIRES_OK(context, get_scalar_input(context, "max_key", &max_key));
141
142 tensorflow::tstring select_fn_initialize_op;
143 OP_REQUIRES_OK(context, get_scalar_input(context, "select_fn_initialize_op",
144 &select_fn_initialize_op));
145
146 std::vector<std::string> select_fn_server_val_input_tensor_names;
147 OP_REQUIRES_OK(context,
148 get_string_list_input(
149 context, "select_fn_server_val_input_tensor_names",
150 &select_fn_server_val_input_tensor_names));
151
152 tensorflow::tstring select_fn_key_input_tensor_name;
153 OP_REQUIRES_OK(context,
154 get_scalar_input(context, "select_fn_key_input_tensor_name",
155 &select_fn_key_input_tensor_name));
156
157 tensorflow::tstring select_fn_filename_input_tensor_name;
158 OP_REQUIRES_OK(context, get_scalar_input(
159 context, "select_fn_filename_input_tensor_name",
160 &select_fn_filename_input_tensor_name));
161
162 tensorflow::tstring select_fn_target_tensor_name;
163 OP_REQUIRES_OK(context,
164 get_scalar_input(context, "select_fn_target_tensor_name",
165 &select_fn_target_tensor_name));
166
167 std::optional<std::shared_ptr<ServeSlicesCallback>> callback =
168 get_serve_slices_callback(callback_token);
169 OP_REQUIRES(context, callback.has_value(),
170 tensorflow::errors::InvalidArgument(
171 absl::StrCat("No `ServeSlices` callback found for token ",
172 callback_token.ToPrintableString())));
173 std::string served_at_id =
174 (**callback)(callback_token, std::move(server_val), max_key,
175 std::move(select_fn_initialize_op),
176 std::move(select_fn_server_val_input_tensor_names),
177 std::move(select_fn_key_input_tensor_name),
178 std::move(select_fn_filename_input_tensor_name),
179 std::move(select_fn_target_tensor_name));
180
181 tensorflow::Tensor* output_tensor;
182 OP_REQUIRES_OK(context, context->allocate_output(0, {}, &output_tensor));
183 output_tensor->scalar<tensorflow::tstring>()() = std::move(served_at_id);
184 }
185 };
186
// Registers `ServeSlicesOp` as the CPU kernel for the `ServeSlices` op.
REGISTER_KERNEL_BUILDER(Name("ServeSlices").Device(tensorflow::DEVICE_CPU),
                        ServeSlicesOp);
189
190 } // namespace
191
192 } // namespace fcp
193