1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef FCP_TENSORFLOW_SERVE_SLICES_REGISTRY_H_
18 #define FCP_TENSORFLOW_SERVE_SLICES_REGISTRY_H_
19
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "fcp/tensorflow/host_object.h"
25
// `tensorflow::Tensor` is only forward-declared here so that this header does
// not pull in the TensorFlow framework: this target is a dependency of custom
// ops, and such dependencies cannot depend on the full framework.
namespace tensorflow {
class Tensor;
}  // namespace tensorflow
34
35 namespace fcp {
36
37 // A callback to invoke when the `ServeSlices` custom op is called.
38 //
39 // Callbacks are responsible for ensuring that the provided `server_val` is
40 // sliced up using the provided selection function (`select_fn`) and that the
41 // resulting slices are made available to clients.
42 //
43 // May be invoked from other threads by the TensorFlow runtime.
44 //
45 // Inputs:
46 // callback_token: The random token associated with this callback by the
47 // `HostObjectRegistration` returned by
48 // `register_serve_slices_callback(...)`.
49 // server_val: A series of arbitrary-typed tensors from which slices may be
50 // generated using a selection function (referred to as `select_fn`).
51 // These tensors must be passed into the `select_fn` by writing them to the
52 // placeholder tensors named by `select_fn_server_val_input_names`, which
53 // must contain exactly one tensor name for each tensor in `server_val`.
54 // max_key: An integer indicating the maximum slice index which may be
55 // requested. Slice indices start at zero and may go up to `max_key`
56 // (inclusive).
57 // select_fn_initialize_op: An op to run before each call to `select_fn` in
58 // order to reinitialize any state `select_fn` may contain.
59 // select_fn_server_val_input_tensor_names: A list of names of the tensors
60 // that make up the `server_val` portion of the inputs to `select_fn`. Must
61 // be the same length as the number of tensors in `server_val`.
62 // select_fn_key_input_tensor_name: The name of the tensor that is the `key`
63 // input to `select_fn`.
64 // select_fn_filename_input_tensor_name: The name of the placeholder tensor
65 // that is the `filename` input to `select_fn`. The `filename` is used to
66 // specify where the resulting slice should be written.
67 // select_fn_target_tensor_name: The name of the `target` tensor to run which
68 // will result in `select_fn`'s output being written to `filename`.
69 //
70 // Outputs:
71 // served_at_id: A string ID under which the resulting slices will be served.
72 // This can then be provided to the `FetchSlicesOp` running on clients.
73 using ServeSlicesCallback = std::function<std::string(
74 /*callback_token=*/RandomToken,
75 /*server_val=*/std::vector<tensorflow::Tensor>,
76 /*max_key=*/int32_t,
77 /*select_fn_initialize_op=*/std::string,
78 /*select_fn_server_val_input_tensor_names=*/std::vector<std::string>,
79 /*select_fn_key_input_tensor_name=*/absl::string_view,
80 /*select_fn_filename_input_tensor_name=*/absl::string_view,
81 /*select_fn_target_tensor_name=*/absl::string_view)>;
82
83 // Registers a callback to be invoked by the `ServeSlices` op.
84 //
85 // Inputs:
86 // callback: The callback to register.
87 //
88 // Outputs:
89 // A `HostObjectRegistration` value which owns the association of the callback
90 // with the global callback registry. When this object is destroyed, the
91 // callback will be unregistered. To refer to this callback in other methods,
92 // use the `token()` method on this object.
register_serve_slices_callback(ServeSlicesCallback callback)93 inline HostObjectRegistration register_serve_slices_callback(
94 ServeSlicesCallback callback) {
95 return HostObjectRegistry<ServeSlicesCallback>::Register(
96 std::make_shared<ServeSlicesCallback>(std::move(callback)));
97 }
98
99 // Returns the callback registered with the given `token` if one exists.
100 inline std::optional<std::shared_ptr<ServeSlicesCallback>>
get_serve_slices_callback(RandomToken token)101 get_serve_slices_callback(RandomToken token) {
102 return HostObjectRegistry<ServeSlicesCallback>::TryLookup(token);
103 }
104
105 } // namespace fcp
106
107 #endif // FCP_TENSORFLOW_SERVE_SLICES_REGISTRY_H_
108