xref: /aosp_15_r20/external/tensorflow/tensorflow/python/util/function_parameter_canonicalizer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/util/function_parameter_canonicalizer.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/core/platform/logging.h"
20 #include "tensorflow/core/platform/macros.h"
21 #include "tensorflow/python/lib/core/py_util.h"
22 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
23 
24 namespace {
PyUnicodeAsUtf8Compat(PyObject * obj)25 inline const char* PyUnicodeAsUtf8Compat(PyObject* obj) {
26 #if PY_MAJOR_VERSION < 3
27   return PyString_AS_STRING(obj);
28 #else
29   return PyUnicode_AsUTF8(obj);
30 #endif
31 }
32 
PyUnicodeInternFromStringCompat(const char * str)33 inline PyObject* PyUnicodeInternFromStringCompat(const char* str) {
34 #if PY_MAJOR_VERSION < 3
35   return PyString_InternFromString(str);
36 #else
37   return PyUnicode_InternFromString(str);
38 #endif
39 }
40 
PyUnicodeInternInPlaceCompat(PyObject ** obj)41 inline void PyUnicodeInternInPlaceCompat(PyObject** obj) {
42 #if PY_MAJOR_VERSION < 3
43   PyString_InternInPlace(obj);
44 #else
45   PyUnicode_InternInPlace(obj);
46 #endif
47 }
48 
49 }  // namespace
50 
51 namespace tensorflow {
52 
FunctionParameterCanonicalizer(absl::Span<const char * > arg_names,absl::Span<PyObject * > defaults)53 FunctionParameterCanonicalizer::FunctionParameterCanonicalizer(
54     absl::Span<const char*> arg_names, absl::Span<PyObject*> defaults)
55     : positional_args_size_(arg_names.size() - defaults.size()) {
56   DCheckPyGilState();
57   DCHECK_GE(positional_args_size_, 0);
58 
59   interned_arg_names_.reserve(arg_names.size());
60   for (const char* obj : arg_names)
61     interned_arg_names_.emplace_back(PyUnicodeInternFromStringCompat(obj));
62 
63   DCHECK(AreInternedArgNamesUnique());
64 
65   for (PyObject* obj : defaults) Py_INCREF(obj);
66   defaults_ = std::vector<Safe_PyObjectPtr>(defaults.begin(), defaults.end());
67 }
68 
Canonicalize(PyObject * args,PyObject * kwargs,absl::Span<PyObject * > result)69 bool FunctionParameterCanonicalizer::Canonicalize(
70     PyObject* args, PyObject* kwargs, absl::Span<PyObject*> result) {
71   // TODO(kkb): Closely follow `Python/ceval.c`'s logic and error handling.
72 
73   DCheckPyGilState();
74   DCHECK(PyTuple_CheckExact(args));
75   DCHECK(kwargs == nullptr || PyDict_CheckExact(kwargs));
76   DCHECK_EQ(result.size(), interned_arg_names_.size());
77 
78   const int args_size = Py_SIZE(args);
79   int remaining_positional_args_count = positional_args_size_ - args_size;
80 
81   // Check if the number of input arguments are too many.
82   if (TF_PREDICT_FALSE(args_size > interned_arg_names_.size())) {
83     PyErr_SetString(
84         PyExc_TypeError,
85         absl::StrCat("Too many arguments were given. Expected ",
86                      interned_arg_names_.size(), " but got ", args_size, ".")
87             .c_str());
88     return false;
89   }
90 
91   // Fill positional arguments.
92   for (int i = 0; i < args_size; ++i) result[i] = PyTuple_GET_ITEM(args, i);
93 
94   // Fill default arguments.
95   for (int i = std::max(positional_args_size_, args_size);
96        i < interned_arg_names_.size(); ++i)
97     result[i] = defaults_[i - positional_args_size_].get();
98 
99   // Fill keyword arguments.
100   if (kwargs != nullptr) {
101     PyObject *key, *value;
102     Py_ssize_t pos = 0;
103     while (PyDict_Next(kwargs, &pos, &key, &value)) {
104       std::size_t index = InternedArgNameLinearSearch(key);
105 
106       // Check if key object(argument name) was found in the pre-built intern
107       // string table.
108       if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) {
109         // `key` might not be an interend string, so get the interned string
110         // and try again.  Note: we need to call INCREF before we use
111         // InternInPlace, to prevent the key in the dictionary from being
112         // prematurely deleted in the case where InternInPlace switches `key`
113         // to point at a new object.  We call DECREF(key) once we're done
114         // (which might decref the original key *or* the interned version).
115         Py_INCREF(key);
116         PyUnicodeInternInPlaceCompat(&key);
117         index = InternedArgNameLinearSearch(key);
118         Py_DECREF(key);
119 
120         // Stil not found, then return an error.
121         if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) {
122           PyErr_Format(PyExc_TypeError,
123                        "Got an unexpected keyword argument '%s'",
124                        PyUnicodeAsUtf8Compat(key));
125           return false;
126         }
127       }
128 
129       // Check if the keyword argument overlaps with positional arguments.
130       if (TF_PREDICT_FALSE(index < args_size)) {
131         PyErr_Format(PyExc_TypeError, "Got multiple values for argument '%s'",
132                      PyUnicodeAsUtf8Compat(key));
133         return false;
134       }
135 
136       if (TF_PREDICT_FALSE(index < positional_args_size_))
137         --remaining_positional_args_count;
138 
139       result[index] = value;
140     }
141   }
142 
143   // Check if all the arguments are filled.
144   // Example failure, not enough number of arguments passed: `matmul(x)`
145   if (TF_PREDICT_FALSE(remaining_positional_args_count > 0)) {
146     // TODO(kkb): Report what arguments are missing.
147     PyErr_SetString(PyExc_TypeError, "Missing required positional argument");
148     return false;
149   }
150 
151   return true;
152 }
153 
154 ABSL_MUST_USE_RESULT
155 ABSL_ATTRIBUTE_HOT
InternedArgNameLinearSearch(PyObject * name)156 inline std::size_t FunctionParameterCanonicalizer::InternedArgNameLinearSearch(
157     PyObject* name) {
158   std::size_t result = interned_arg_names_.size();
159 
160   for (std::size_t i = 0; i < interned_arg_names_.size(); ++i)
161     if (TF_PREDICT_FALSE(name == interned_arg_names_[i].get())) return i;
162 
163   return result;
164 }
165 
AreInternedArgNamesUnique()166 bool FunctionParameterCanonicalizer::AreInternedArgNamesUnique() {
167   absl::flat_hash_set<PyObject*> interned_arg_names_set;
168   for (const Safe_PyObjectPtr& obj : interned_arg_names_)
169     interned_arg_names_set.emplace(obj.get());
170 
171   return interned_arg_names_set.size() == interned_arg_names_.size();
172 }
173 }  // namespace tensorflow
174