/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/python/lib/core/py_func.h"

#include <Python.h>

// clang-format off
// Must be included first.
#include "tensorflow/python/lib/core/numpy.h"
// clang-format on

#include <array>

#include "numpy/arrayobject.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/python/lib/core/safe_ptr.h"

namespace tensorflow {
namespace {
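
// py_trampoline is a Python callable installed once (typically by the Python
// binding layer) via InitializePyTrampoline(). DoCallPyFunc() below invokes it
// with a (token, device_name, inputs) tuple to dispatch the call to the Python
// function registered under `token`.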
static mutex mu(LINKER_INITIALIZED);
static PyObject* py_trampoline TF_GUARDED_BY(mu) = nullptr;

// Returns the py_trampoline that is used to pass control to the Python
// runtime.
PyObject* GetPyTrampoline() {
  mutex_lock l(mu);
  return py_trampoline;
}

// A call to the registered python function.
struct PyCall {
  // Passed to python runtime to call the python function registered
  // with this "token".
  string token;

  // The device on which Tensors are stored; only used for EagerPyFunc.
  Device* device = nullptr;

  // True if the call is associated with an EagerPyFunc.
  bool eager = false;

  // True if the call is running under eager async mode.
  bool eager_async = false;

  // Inputs and outputs of this function invocation.
  std::vector<Tensor> ins;
  std::vector<Tensor> out;
};

bool IsCPUDevice(const Device* d) {
  return d == nullptr || d->tensorflow_accelerator_device_info() == nullptr;
}

// Given the 'call', prepares the token and inputs as a python tuple
// that is appropriate for calling the trampoline.
Status MakeArgTuple(const PyCall* call, TFE_Context* ctx, PyObject** tuple) {
  int64_t n = call->ins.size();
  PyObject* lst = PyList_New(n);
  CHECK(lst);
  // TFE_TensorHandle assumes that CPU is identified by nullptr.
  //
  // Set device name to be empty if the device is CPU.
  const char* device_name = nullptr;

  if (call->device != nullptr && !IsCPUDevice(call->device))
    device_name = call->device->name().c_str();

  for (int64_t i = 0; i < n; ++i) {
    PyObject* arg = nullptr;
    if (call->eager) {
      Tensor t = call->ins[i];
      arg = EagerTensorFromHandle(tensorflow::wrap(
          tensorflow::unwrap(ctx)->CreateLocalHandleFromTFTensor(t,
                                                                 device_name)));
      if (arg == nullptr) {
        Py_DECREF(lst);
        return errors::Internal("Unable to procure EagerTensor from Tensor.");
      }
    } else {
      Status s = TensorToNdarray(call->ins[i], &arg);
      if (!s.ok()) {
        Py_DECREF(lst);
        return s;
      }
      arg = PyArray_Return(reinterpret_cast<PyArrayObject*>(arg));
    }
    PyList_SetItem(lst, i, arg);
  }
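  // Note: PyList_SetItem above steals the reference to each `arg`, and the "N"
  // format in Py_BuildValue below steals the reference to `lst`, so neither
  // needs an explicit Py_DECREF on the success path.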
  *tuple = Py_BuildValue("(ssN)", call->token.c_str(), device_name, lst);
  CHECK(*tuple);
  return OkStatus();
}
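
// Returns true iff `obj` is a scalar (0-dim, single-element) NumPy array whose
// sole item is Py_None.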
bool IsSingleNone(PyObject* obj) {
  if (!PyArray_Check(obj)) {
    return false;
  }
  PyArrayObject* array_obj = reinterpret_cast<PyArrayObject*>(obj);
  if (PyArray_NDIM(array_obj) != 0 || PyArray_SIZE(array_obj) != 1) {
    return false;
  }
  std::array<npy_intp, 0> indices;
  char* item_ptr =
      static_cast<char*>(PyArray_GetPtr(array_obj, indices.data()));
  PyObject* item = PyArray_GETITEM(array_obj, item_ptr);
  CHECK(item);
  return item == Py_None;
}

// Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
// Validates that `output_tensor` is backed by memory in `expected_device`
// (which is assumed to be a local device, one on which the kernel was
// executed).
//
// It may be nice to copy the tensor to the right device instead of failing if
// it isn't already there. This is left as a future exercise.  The required
// device-copying logic is implemented in Python at the moment.
tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
                                                TFE_Context* ctx,
                                                const Device* expected_device,
                                                const Tensor** output_tensor) {
  tensorflow::TensorHandle* handle = down_cast<tensorflow::TensorHandle*>(
      tensorflow::unwrap(ctx)->TFTensorHandleFromInterface(
          tensorflow::unwrap(EagerTensor_Handle(eager_tensor))));

  Device* actual_device = handle->device();
  TF_RETURN_IF_ERROR(handle->Tensor(output_tensor));
  // actual_device may be nullptr, which implies local CPU.
  if (expected_device == actual_device) return OkStatus();
  const string& expected_device_name = expected_device->attributes().name();
  if (actual_device == nullptr) {
    if (!IsCPUDevice(expected_device)) {
      return errors::Internal(
          "Expected the py_func to return a Tensor backed by memory in ",
          expected_device_name,
          ", but is actually backed by local host memory. This is a bug.");
    }
    return OkStatus();
  }
  // NOTE(ebrevdo): Here we could try comparing "actual_device_name"
  // (actual_device->attributes().name()) to expected_device_name and ensure
  // they're the same.  However, this comparison fails if we create a ClusterDef
  // on localhost, mainly because the Device created by Eager code doesn't match
  // the device created by a session.  In this case, expected_device_name may
  // contain "worker" but the Eager device name contains "localhost".  Since we
  // can't easily access the true underlying device of "worker" here, we are not
  // able to perform a proper comparison.  Furthermore, we can't check
  // IsCPUDevice(actual_device) because the kernel's device may indeed be a
  // GPU device (the python interpreter doesn't use it, however).
  return OkStatus();
}

// Calls the registered py function through the trampoline.
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
  *out_log_on_error = true;
  PyObject* trampoline = GetPyTrampoline();
  if (trampoline == nullptr) {
    return errors::InvalidArgument(
        "Missing py trampoline. Most likely, it is a link error.");
  }

  // Prepare the argument.
  PyObject* args = nullptr;
  std::unique_ptr<EagerExecutor> new_executor = nullptr;
  EagerExecutor* old_executor = nullptr;
  if (call->eager) {
    // See FuncRegistry._ctx.
    TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
        PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
    CHECK_NE(ctx, nullptr);
    TF_RETURN_IF_ERROR(MakeArgTuple(call, ctx, &args));
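    // Swap in a fresh per-call executor so eager ops launched by the Python
    // function honor `eager_async`; the caller's executor is restored below
    // once the call and its pending nodes have completed.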
    new_executor.reset(new EagerExecutor(call->eager_async));
    old_executor = &(tensorflow::unwrap(ctx)->Executor());
    tensorflow::unwrap(ctx)->SetExecutorForThread(new_executor.get());
  } else {
    TF_RETURN_IF_ERROR(MakeArgTuple(call, nullptr, &args));
  }
  CHECK(args);

  // Invokes the trampoline.
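  // NOTE: PyEval_CallObject has been deprecated since Python 3.9;
  // PyObject_CallObject(trampoline, args) is the equivalent modern call.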
  PyObject* result = PyEval_CallObject(trampoline, args);
  Py_DECREF(args);
  Status s = OkStatus();
  if (result == nullptr) {
    if (PyErr_Occurred()) {
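      // Map well-known Python exception types onto the closest TF error codes
      // (e.g. StopIteration -> OutOfRange) so callers receive a meaningful
      // Status instead of a generic failure.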
      if (PyErr_ExceptionMatches(PyExc_ValueError) ||
          PyErr_ExceptionMatches(PyExc_TypeError)) {
        s = errors::InvalidArgument(PyExceptionFetch());
      } else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
        *out_log_on_error = false;
        s = errors::OutOfRange(PyExceptionFetch());
      } else if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
        s = errors::ResourceExhausted(PyExceptionFetch());
      } else if (PyErr_ExceptionMatches(PyExc_NotImplementedError)) {
        s = errors::Unimplemented(PyExceptionFetch());
      } else {
        // TODO(ebrevdo): Check if exception is an OpError and use the
        // OpError.error_code property to map it back in the Status.
        s = errors::Unknown(PyExceptionFetch());
      }
    } else {
      s = errors::Internal("Failed to run py callback ", call->token,
                           ": see error log.");
    }
  }

  TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
      PyObject_GetAttrString(trampoline, "_ctx"), /*name=*/nullptr));
  if (new_executor != nullptr) {
    s.Update(new_executor->WaitForAllPendingNodes());
    tensorflow::unwrap(ctx)->SetExecutorForThread(old_executor);
  }

  TF_RETURN_IF_ERROR(s);

  // Process the return values and convert them to TF Tensors.
  if (PyList_Check(result)) {
    // `result` is a Python list; if this operation is an `EagerPyFunc`, then
    // every item in the list must be an `EagerTensor`; otherwise, every element
    // must be a NumPy array.
    call->out.clear();
    for (int i = 0; i < PyList_Size(result); ++i) {
      Tensor t;
      if (call->eager) {
        const PyObject* item = PyList_GetItem(result, i);
        if (EagerTensor_CheckExact(item)) {
          const Tensor* tensor = nullptr;
          s = ExtractTensorFromEagerTensor(item, ctx, call->device, &tensor);
          if (s.ok()) t = *tensor;
        } else {
          s = errors::FailedPrecondition(
              "Expected EagerTensor, found PyObject of type: ",
              Py_TYPE(item)->tp_name);
        }
      } else {
        s = NdarrayToTensor(PyList_GetItem(result, i), &t);
      }

      if (!s.ok()) {
        break;
      }
      call->out.push_back(t);
    }
  } else if (EagerTensor_CheckExact(result) || result == Py_None) {
    // result is an `EagerTensor` or `None`.
    DCHECK(call->eager);
    if (result != Py_None) {
      const Tensor* t = nullptr;
      s = ExtractTensorFromEagerTensor(result, ctx, call->device, &t);
      if (s.ok()) call->out.push_back(*t);
    }
  } else if (PyArray_Check(result)) {
    // `result` is a NumPy array.
    DCHECK(!call->eager);
    if (!IsSingleNone(result)) {
      Tensor t;
      s = NdarrayToTensor(result, &t);
      if (s.ok()) {
        call->out.push_back(t);
      }
    }
  } else {
    s = errors::Internal("Unexpected PyObject was returned: ",
                         Py_TYPE(result)->tp_name);
  }
  Py_DECREF(result);
  return s;
}

}  // end namespace

void InitializePyTrampoline(PyObject* trampoline) {
  mutex_lock l(mu);
  if (py_trampoline == nullptr) {
    py_trampoline = trampoline;
    Py_INCREF(py_trampoline);
  } else {
    LOG(WARNING) << "InitializePyTrampoline should only be called once";
  }
}

class PyFuncOp : public OpKernel {
 public:
  explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
    eager_ = type_string() == "EagerPyFunc";
    if (eager_) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("is_async", &eager_async_));
    }
  }
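
  // Treat the kernel as expensive: Compute() acquires the GIL and may block on
  // arbitrary Python code, so the executor should schedule it on a worker
  // thread rather than running it inline.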
  bool IsExpensive() override { return true; }

  void Compute(OpKernelContext* ctx) override {
    PyCall call;
    call.token = token_;
    call.eager = eager_;
    if (call.eager) {
      // Eager's C API uses `Device`, whereas `OpKernelContext` stores a
      // `DeviceBase`; attempt to downcast.
      call.device = dynamic_cast<Device*>(ctx->device());
      if (call.device == nullptr) {
        ctx->CtxFailureWithWarning(errors::Internal(
            "Unrecognized device class: ", ctx->device()->name()));
        return;
      }
      call.eager_async = eager_async_;
    }

    VLOG(1) << "PyFuncOp of token " << call.token << " is called.";

    for (int i = 0; i < ctx->num_inputs(); ++i) {
      call.ins.push_back(ctx->input(i));
    }

    // NOTE(mrry): There is a potential time-of-check-to-time-of-use race here,
    // because it is possible that `Py_Finalize()` could be called in another
    // thread between this check and the call to `PyGILState_Ensure()`, which
    // will abort the process if `Py_Finalize()` has been called. A more robust
    // solution would be welcome, but it is not obvious how to make this work
    // using the current Python C API.
    OP_REQUIRES(ctx, Py_IsInitialized(),
                errors::FailedPrecondition(
                    "Python interpreter state is not initialized. "
                    "The process may be terminated."));

    PyGILState_STATE py_threadstate;
    py_threadstate = PyGILState_Ensure();
    bool log_on_error;
    Status s = DoCallPyFunc(&call, &log_on_error);
    // Sometimes py_funcs can be called without a session and leak memory. This
    // ensures we clear the decref cache so this doesn't happen.
    ClearDecrefCache();
    PyGILState_Release(py_threadstate);

    // Ensures that GIL is released even when !s.ok().
    if (!s.ok()) {
      if (log_on_error) {
        ctx->CtxFailureWithWarning(s);
      } else {
        ctx->CtxFailure(s);
      }
      return;
    }

    OP_REQUIRES(ctx, static_cast<int32>(call.out.size()) == ctx->num_outputs(),
                errors::InvalidArgument(token_, " returns ", call.out.size(),
                                        " values, but expects to see ",
                                        ctx->num_outputs(), " values."));
    for (size_t i = 0; i < call.out.size(); ++i) {
      const auto& t = call.out[i];
      OP_REQUIRES(
          ctx, t.dtype() == output_type(i),
          errors::InvalidArgument(i, "-th value returned by ", token_, " is ",
                                  DataTypeString(t.dtype()), ", but expects ",
                                  DataTypeString(output_type(i))));
      ctx->set_output(i, t);
    }
  }

 private:
  string token_;

  // True if and only if this op should execute the python function eagerly,
  // i.e., if and only if the op is an EagerPyFunc.
  bool eager_;

  bool eager_async_;

  TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
};

REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp);
REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp);
REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_CPU), PyFuncOp);

DataType gpu_types[] = {
    // No strings and int32s, no ref types and no resource/variant types.
    DT_FLOAT,      DT_DOUBLE,   DT_UINT8,  DT_INT16,   DT_INT8,
    DT_COMPLEX64,  DT_INT64,    DT_BOOL,   DT_QINT8,   DT_QUINT8,
    DT_QINT32,     DT_BFLOAT16, DT_QINT16, DT_QUINT16, DT_UINT16,
    DT_COMPLEX128, DT_HALF,     DT_UINT32, DT_UINT64,
};
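
// EagerPyFunc is also registered for non-CPU devices so that, under eager
// execution, inputs and outputs can remain on the accelerator; the Python
// callable itself still runs in the host interpreter, operating on
// EagerTensor handles to the device-resident data.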
REGISTER_KERNEL_BUILDER(Name("EagerPyFunc")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint("Tin", gpu_types)
                            .TypeConstraint("Tout", gpu_types),
                        PyFuncOp);

}  // end namespace tensorflow