xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/python_api_dispatcher_wrapper.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Python bindings for tensorflow/python/framework/python_api_dispatcher.h.
16 
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "tensorflow/python/framework/python_api_dispatcher.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"

// Standard headers come after pybind11 (which includes Python.h): the Python
// C API requires Python.h to be included before any standard header.
#include <memory>
#include <string>
#include <utility>
#include <vector>
22 
23 namespace py = pybind11;
24 
25 using tensorflow::py_dispatch::PyInstanceChecker;
26 using tensorflow::py_dispatch::PyListChecker;
27 using tensorflow::py_dispatch::PySignatureChecker;
28 using tensorflow::py_dispatch::PythonAPIDispatcher;
29 using tensorflow::py_dispatch::PyTypeChecker;
30 using tensorflow::py_dispatch::PyUnionChecker;
31 
32 namespace {
33 
Dispatch(PythonAPIDispatcher * self,py::handle args,py::handle kwargs)34 py::object Dispatch(PythonAPIDispatcher* self, py::handle args,
35                     py::handle kwargs) {
36   auto result = self->Dispatch(args.ptr(), kwargs.ptr());
37   if (result == nullptr) {
38     throw py::error_already_set();
39   } else {
40     return py::reinterpret_steal<py::object>(result.release());
41   }
42 }
43 
MakePythonAPIDispatcher(const std::string & api_name,const std::vector<std::string> & arg_names,py::handle defaults)44 PythonAPIDispatcher MakePythonAPIDispatcher(
45     const std::string& api_name, const std::vector<std::string>& arg_names,
46     py::handle defaults) {
47   std::vector<const char*> name_strs;
48   name_strs.reserve(arg_names.size());
49   for (const auto& name : arg_names) {
50     name_strs.push_back(name.c_str());
51   }
52   absl::Span<const char*> arg_names_span(name_strs);
53   if (defaults.ptr() == Py_None) {
54     return PythonAPIDispatcher(api_name, arg_names_span, {});
55   } else {
56     tensorflow::Safe_PyObjectPtr fast_defaults(
57         PySequence_Fast(defaults.ptr(), "defaults is not a sequence"));
58     if (!fast_defaults) {
59       throw py::error_already_set();
60     }
61     return PythonAPIDispatcher(
62         api_name, arg_names_span,
63         absl::MakeSpan(PySequence_Fast_ITEMS(fast_defaults.get()),
64                        PySequence_Fast_GET_SIZE(fast_defaults.get())));
65   }
66 }
67 
68 }  // namespace
69 
PYBIND11_MODULE(_pywrap_python_api_dispatcher,m)70 PYBIND11_MODULE(_pywrap_python_api_dispatcher, m) {
71   py::enum_<PyTypeChecker::MatchType>(m, "MatchType")
72       .value("NO_MATCH", PyTypeChecker::MatchType::NO_MATCH)
73       .value("MATCH", PyTypeChecker::MatchType::MATCH)
74       .value("MATCH_DISPATCHABLE", PyTypeChecker::MatchType::MATCH_DISPATCHABLE)
75       .export_values();
76 
77   py::class_<PyTypeChecker, std::shared_ptr<PyTypeChecker>>(m, "PyTypeChecker")
78       .def("Check", [](PyTypeChecker* self,
79                        py::handle value) { return self->Check(value.ptr()); })
80       .def("cost", &PyTypeChecker::cost)
81       .def("cache_size",
82            [](PyTypeChecker* self) {
83              return static_cast<PyInstanceChecker*>(self)->cache_size();
84            })
85       .def("__repr__", [](PyTypeChecker* self) {
86         return absl::StrCat("<PyTypeChecker ", self->DebugString(), ">");
87       });
88 
89   py::class_<PySignatureChecker>(m, "PySignatureChecker")
90       .def(py::init<
91            std::vector<std::pair<int, std::shared_ptr<PyTypeChecker>>>>())
92       .def("CheckCanonicalizedArgs",
93            [](PySignatureChecker* self, py::tuple args) {
94              tensorflow::Safe_PyObjectPtr seq(PySequence_Fast(args.ptr(), ""));
95              PyObject** items = PySequence_Fast_ITEMS(seq.get());
96              int n = PySequence_Fast_GET_SIZE(seq.get());
97              return self->CheckCanonicalizedArgs(absl::MakeSpan(items, n));
98            })
99       .def("__repr__", [](PySignatureChecker* self) {
100         return absl::StrCat("<PySignatureChecker ", self->DebugString(), ">");
101       });
102 
103   py::class_<PythonAPIDispatcher>(m, "PythonAPIDispatcher")
104       .def(py::init(&MakePythonAPIDispatcher))
105       .def("Register",
106            [](PythonAPIDispatcher* self, PySignatureChecker signature_checker,
107               py::handle func) {
108              return self->Register(signature_checker, func.ptr());
109            })
110       .def("Dispatch", &Dispatch)
111       .def("Unregister",
112            [](PythonAPIDispatcher* self, py::handle func) {
113              return self->Unregister(func.ptr());
114            })
115       .def("__repr__", &PythonAPIDispatcher::DebugString);
116 
117   m.def("MakeInstanceChecker", [](py::args py_classes) {
118     std::vector<PyObject*> py_classes_vector;
119     py_classes_vector.reserve(py_classes.size());
120     for (auto& cls : py_classes) {
121       if (!PyType_Check(cls.ptr())) {
122         throw py::type_error("`*py_classes` must be a tuple of types.");
123       }
124       py_classes_vector.push_back(cls.ptr());
125     }
126     return std::shared_ptr<PyTypeChecker>(
127         std::make_shared<PyInstanceChecker>(py_classes_vector));
128   });
129   m.def("MakeListChecker", [](std::shared_ptr<PyTypeChecker> elt_type) {
130     return std::shared_ptr<PyTypeChecker>(
131         std::make_shared<PyListChecker>(elt_type));
132   });
133   m.def("MakeUnionChecker",
134         [](const std::vector<std::shared_ptr<PyTypeChecker>>& options) {
135           return std::shared_ptr<PyTypeChecker>(
136               std::make_shared<PyUnionChecker>(options));
137         });
138   m.def("register_dispatchable_type", [](py::handle py_class) {
139     if (!tensorflow::py_dispatch::RegisterDispatchableType(py_class.ptr())) {
140       throw py::error_already_set();
141     } else {
142       return py_class;
143     }
144   });
145 }
146