1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/tensorflow/serve_slices_registry.h"
18
19 #include <pybind11/functional.h>
20 #include <pybind11/pybind11.h>
21
22 #include <functional>
23 #include <optional>
24 #include <string>
25 #include <utility>
26 #include <vector>
27
28 #include "absl/strings/string_view.h"
29 #include "fcp/base/random_token.h"
30 #include "fcp/tensorflow/host_object.h"
31 #include "pybind11_abseil/absl_casters.h"
32 #include "tensorflow/core/framework/tensor.h"
33 #include "tensorflow/core/framework/tensor.pb.h"
34
35 namespace pybind11::detail {
36
37 // Type caster converting a Tensor from C++ to Python.
38 template <>
39 struct type_caster<tensorflow::Tensor> {
40 PYBIND11_TYPE_CASTER(tensorflow::Tensor, const_name("Tensor"));
41
castpybind11::detail::type_caster42 static handle cast(const tensorflow::Tensor& tensor, return_value_policy,
43 handle) {
44 // We'd ideally use tensorflow::TensorToNdarray, but that function isn't
45 // available to code running in custom ops. Instead, we pass the Tensor
46 // as a serialized proto and convert to an ndarray in Python.
47 tensorflow::TensorProto proto;
48 if (tensor.dtype() == tensorflow::DT_STRING) {
49 // Strings encoded using AsProtoTensorContent are incompatible with
50 // tf.make_ndarray.
51 tensor.AsProtoField(&proto);
52 } else {
53 tensor.AsProtoTensorContent(&proto);
54 }
55 std::string serialized = proto.SerializeAsString();
56 return PyBytes_FromStringAndSize(serialized.data(), serialized.size());
57 }
58 };
59
60 } // namespace pybind11::detail
61
62 namespace {
63
64 namespace py = ::pybind11;
65
66 // A variant of fcp::ServeSlicesCallback with Python-friendly types.
67 using ServeSlicesCallback = std::function<std::string(
68 /*callback_token=*/py::bytes,
69 /*server_val=*/std::vector<tensorflow::Tensor>,
70 /*max_key=*/int32_t,
71 /*select_fn_initialize_op=*/std::string,
72 /*select_fn_server_val_input_tensor_names=*/std::vector<std::string>,
73 /*select_fn_key_input_tensor_name=*/absl::string_view,
74 /*select_fn_filename_input_tensor_name=*/absl::string_view,
75 /*select_fn_target_tensor_name=*/absl::string_view)>;
76
77 // A fcp::HostObjectRegistration wrapper allowing use as a context manager.
78 class ServeSlicesCallbackRegistration {
79 public:
ServeSlicesCallbackRegistration(ServeSlicesCallback callback)80 explicit ServeSlicesCallbackRegistration(ServeSlicesCallback callback)
81 : callback_(std::move(callback)) {}
82
enter()83 py::bytes enter() {
84 registration_ = fcp::register_serve_slices_callback(
85 [this](fcp::RandomToken callback_token,
86 std::vector<tensorflow::Tensor> server_val, int32_t max_key,
87 std::string select_fn_initialize_op,
88 std::vector<std::string> select_fn_server_val_input_tensor_names,
89 absl::string_view select_fn_key_input_tensor_name,
90 absl::string_view select_fn_filename_input_tensor_name,
91 absl::string_view select_fn_target_tensor_name) {
92 // The GIL isn't normally held in the context of ServeSlicesCallbacks,
93 // which are typically invoked from the ServeSlices TensorFlow op.
94 py::gil_scoped_acquire acquire;
95 return callback_(callback_token.ToString(), std::move(server_val),
96 max_key, std::move(select_fn_initialize_op),
97 std::move(select_fn_server_val_input_tensor_names),
98 select_fn_key_input_tensor_name,
99 select_fn_filename_input_tensor_name,
100 select_fn_target_tensor_name);
101 });
102 return registration_->token().ToString();
103 }
104
exit(py::object,py::object,py::object)105 void exit(py::object, py::object, py::object) { registration_.reset(); }
106
107 private:
108 ServeSlicesCallback callback_;
109 std::optional<fcp::HostObjectRegistration> registration_;
110 };
111
PYBIND11_MODULE(_serve_slices_op,m)112 PYBIND11_MODULE(_serve_slices_op, m) {
113 py::class_<ServeSlicesCallbackRegistration>(m,
114 "ServeSlicesCallbackRegistration")
115 .def("__enter__", &ServeSlicesCallbackRegistration::enter)
116 .def("__exit__", &ServeSlicesCallbackRegistration::exit);
117
118 m.def(
119 "register_serve_slices_callback",
120 [](ServeSlicesCallback callback) {
121 return ServeSlicesCallbackRegistration(std::move(callback));
122 },
123 py::return_value_policy::move);
124 }
125
126 } // namespace
127