xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/python/serve_slices_registry.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/tensorflow/serve_slices_registry.h"
18 
19 #include <pybind11/functional.h>
20 #include <pybind11/pybind11.h>
21 
22 #include <functional>
23 #include <optional>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/strings/string_view.h"
29 #include "fcp/base/random_token.h"
30 #include "fcp/tensorflow/host_object.h"
31 #include "pybind11_abseil/absl_casters.h"
32 #include "tensorflow/core/framework/tensor.h"
33 #include "tensorflow/core/framework/tensor.pb.h"
34 
35 namespace pybind11::detail {
36 
37 // Type caster converting a Tensor from C++ to Python.
38 template <>
39 struct type_caster<tensorflow::Tensor> {
40   PYBIND11_TYPE_CASTER(tensorflow::Tensor, const_name("Tensor"));
41 
castpybind11::detail::type_caster42   static handle cast(const tensorflow::Tensor& tensor, return_value_policy,
43                      handle) {
44     // We'd ideally use tensorflow::TensorToNdarray, but that function isn't
45     // available to code running in custom ops. Instead, we pass the Tensor
46     // as a serialized proto and convert to an ndarray in Python.
47     tensorflow::TensorProto proto;
48     if (tensor.dtype() == tensorflow::DT_STRING) {
49       // Strings encoded using AsProtoTensorContent are incompatible with
50       // tf.make_ndarray.
51       tensor.AsProtoField(&proto);
52     } else {
53       tensor.AsProtoTensorContent(&proto);
54     }
55     std::string serialized = proto.SerializeAsString();
56     return PyBytes_FromStringAndSize(serialized.data(), serialized.size());
57   }
58 };
59 
60 }  // namespace pybind11::detail
61 
62 namespace {
63 
64 namespace py = ::pybind11;
65 
66 // A variant of fcp::ServeSlicesCallback with Python-friendly types.
67 using ServeSlicesCallback = std::function<std::string(
68     /*callback_token=*/py::bytes,
69     /*server_val=*/std::vector<tensorflow::Tensor>,
70     /*max_key=*/int32_t,
71     /*select_fn_initialize_op=*/std::string,
72     /*select_fn_server_val_input_tensor_names=*/std::vector<std::string>,
73     /*select_fn_key_input_tensor_name=*/absl::string_view,
74     /*select_fn_filename_input_tensor_name=*/absl::string_view,
75     /*select_fn_target_tensor_name=*/absl::string_view)>;
76 
77 // A fcp::HostObjectRegistration wrapper allowing use as a context manager.
78 class ServeSlicesCallbackRegistration {
79  public:
ServeSlicesCallbackRegistration(ServeSlicesCallback callback)80   explicit ServeSlicesCallbackRegistration(ServeSlicesCallback callback)
81       : callback_(std::move(callback)) {}
82 
enter()83   py::bytes enter() {
84     registration_ = fcp::register_serve_slices_callback(
85         [this](fcp::RandomToken callback_token,
86                std::vector<tensorflow::Tensor> server_val, int32_t max_key,
87                std::string select_fn_initialize_op,
88                std::vector<std::string> select_fn_server_val_input_tensor_names,
89                absl::string_view select_fn_key_input_tensor_name,
90                absl::string_view select_fn_filename_input_tensor_name,
91                absl::string_view select_fn_target_tensor_name) {
92           // The GIL isn't normally held in the context of ServeSlicesCallbacks,
93           // which are typically invoked from the ServeSlices TensorFlow op.
94           py::gil_scoped_acquire acquire;
95           return callback_(callback_token.ToString(), std::move(server_val),
96                            max_key, std::move(select_fn_initialize_op),
97                            std::move(select_fn_server_val_input_tensor_names),
98                            select_fn_key_input_tensor_name,
99                            select_fn_filename_input_tensor_name,
100                            select_fn_target_tensor_name);
101         });
102     return registration_->token().ToString();
103   }
104 
exit(py::object,py::object,py::object)105   void exit(py::object, py::object, py::object) { registration_.reset(); }
106 
107  private:
108   ServeSlicesCallback callback_;
109   std::optional<fcp::HostObjectRegistration> registration_;
110 };
111 
PYBIND11_MODULE(_serve_slices_op,m)112 PYBIND11_MODULE(_serve_slices_op, m) {
113   py::class_<ServeSlicesCallbackRegistration>(m,
114                                               "ServeSlicesCallbackRegistration")
115       .def("__enter__", &ServeSlicesCallbackRegistration::enter)
116       .def("__exit__", &ServeSlicesCallbackRegistration::exit);
117 
118   m.def(
119       "register_serve_slices_callback",
120       [](ServeSlicesCallback callback) {
121         return ServeSlicesCallbackRegistration(std::move(callback));
122       },
123       py::return_value_policy::move);
124 }
125 
126 }  // namespace
127