xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/serve_slices_registry.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_TENSORFLOW_SERVE_SLICES_REGISTRY_H_
18 #define FCP_TENSORFLOW_SERVE_SLICES_REGISTRY_H_
19 
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "fcp/tensorflow/host_object.h"
25 
// `tensorflow::Tensor` is only forward-declared here so that this header does
// not pull in the TensorFlow framework: this target is a dependency of custom
// ops, and such dependencies cannot depend on the full framework.
namespace tensorflow {
class Tensor;
}  // namespace tensorflow
34 
35 namespace fcp {
36 
37 // A callback to invoke when the `ServeSlices` custom op is called.
38 //
39 // Callbacks are responsible for ensuring that the provided `server_val` is
40 // sliced up using the provided selection function (`select_fn`) and that the
41 // resulting slices are made available to clients.
42 //
43 // May be invoked from other threads by the TensorFlow runtime.
44 //
45 // Inputs:
46 //   callback_token: The random token associated with this callback by the
47 //     `HostObjectRegistration` returned by
48 //     `register_serve_slices_callback(...)`.
49 //   server_val: A series of arbitrary-typed tensors from which slices may be
50 //     generated using a selection function (referred to as `select_fn`).
51 //     These tensors must be passed into the `select_fn` by writing them to the
52 //     placeholder tensors named by `select_fn_server_val_input_names`, which
53 //     must contain exactly one tensor name for each tensor in `server_val`.
54 //   max_key: An integer indicating the maximum slice index which may be
55 //     requested. Slice indices start at zero and may go up to `max_key`
56 //     (inclusive).
57 //   select_fn_initialize_op: An op to run before each call to `select_fn` in
58 //     order to reinitialize any state `select_fn` may contain.
59 //   select_fn_server_val_input_tensor_names: A list of names of the tensors
60 //     that make up the `server_val` portion of the inputs to `select_fn`. Must
61 //     be the same length as the number of tensors in `server_val`.
62 //   select_fn_key_input_tensor_name: The name of the tensor that is the `key`
63 //     input to `select_fn`.
64 //   select_fn_filename_input_tensor_name: The name of the placeholder tensor
65 //     that is the `filename` input to `select_fn`. The `filename` is used to
66 //     specify where the resulting slice should be written.
67 //   select_fn_target_tensor_name: The name of the `target` tensor to run which
68 //     will result in `select_fn`'s output being written to `filename`.
69 //
70 // Outputs:
71 //   served_at_id: A string ID under which the resulting slices will be served.
72 //     This can then be provided to the `FetchSlicesOp` running on clients.
73 using ServeSlicesCallback = std::function<std::string(
74     /*callback_token=*/RandomToken,
75     /*server_val=*/std::vector<tensorflow::Tensor>,
76     /*max_key=*/int32_t,
77     /*select_fn_initialize_op=*/std::string,
78     /*select_fn_server_val_input_tensor_names=*/std::vector<std::string>,
79     /*select_fn_key_input_tensor_name=*/absl::string_view,
80     /*select_fn_filename_input_tensor_name=*/absl::string_view,
81     /*select_fn_target_tensor_name=*/absl::string_view)>;
82 
83 // Registers a callback to be invoked by the `ServeSlices` op.
84 //
85 // Inputs:
86 //   callback: The callback to register.
87 //
88 // Outputs:
89 //   A `HostObjectRegistration` value which owns the association of the callback
90 //   with the global callback registry. When this object is destroyed, the
91 //   callback will be unregistered. To refer to this callback in other methods,
92 //   use the `token()` method on this object.
register_serve_slices_callback(ServeSlicesCallback callback)93 inline HostObjectRegistration register_serve_slices_callback(
94     ServeSlicesCallback callback) {
95   return HostObjectRegistry<ServeSlicesCallback>::Register(
96       std::make_shared<ServeSlicesCallback>(std::move(callback)));
97 }
98 
99 // Returns the callback registered with the given `token` if one exists.
100 inline std::optional<std::shared_ptr<ServeSlicesCallback>>
get_serve_slices_callback(RandomToken token)101 get_serve_slices_callback(RandomToken token) {
102   return HostObjectRegistry<ServeSlicesCallback>::TryLookup(token);
103 }
104 
105 }  // namespace fcp
106 
107 #endif  // FCP_TENSORFLOW_SERVE_SLICES_REGISTRY_H_
108