/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/base/casts.h"
#include "absl/strings/numbers.h"
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"
#else
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#endif
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/python/callback.h"
#include "tensorflow/compiler/xla/python/exceptions.h"

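// Alias the stream and memcpy primitives used below to either the ROCm or
// CUDA runtime so the callback body is written once for both platforms.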
#if TENSORFLOW_USE_ROCM
#define gpuStreamHandle hipStream_t
#define gpuMemcpyAsync hipMemcpyAsync
#define gpuStreamSynchronize hipStreamSynchronize
#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
#else
#define gpuStreamHandle CUstream
#define gpuMemcpyAsync cudaMemcpyAsync
#define gpuStreamSynchronize cudaStreamSynchronize
#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
#endif

namespace py = pybind11;

namespace xla {

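// XLA custom-call target for Python host callbacks on GPU. `opaque` carries
// the address of the CpuCallback to invoke. Inputs are staged from device to
// host, the Python callable runs under the GIL, and its results are copied
// back into the device output buffers.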
void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers,
                          const char* opaque, size_t opaque_len,
                          XlaCustomCallStatus* status) {
  // Ignore `descriptor` arg to callback
  buffers += 1;
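  // `opaque` encodes the CpuCallback pointer as a decimal string.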
  uint64_t descriptor;
  if (!absl::SimpleAtoi(opaque, &descriptor)) {
    throw xla::XlaRuntimeError("Invalid callback descriptor");
  }
  CpuCallback* callback =
      absl::bit_cast<CpuCallback*>(static_cast<uintptr_t>(descriptor));
  size_t arity = callback->num_args();
  std::vector<void*> host_input_buffers(arity);
  // Copy input GPU buffers to host
  for (size_t i = 0; i < arity; ++i) {
    CpuCallback::Arg arg = callback->args()[i];
    if (arg.type == TOKEN) {
      host_input_buffers[i] = nullptr;
      continue;
    }
    void* buf = new char[arg.size_in_bytes];
    host_input_buffers[i] = buf;
    // TODO(b/238441608): Use pinned memory here to speed up the transfer.
    gpuMemcpyAsync(buf, buffers[i], arg.size_in_bytes, gpuMemcpyDeviceToHost,
                   stream);
  }
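  // Block until the device-to-host copies have completed before handing the
  // buffers to Python.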
  gpuStreamSynchronize(stream);
  py::gil_scoped_acquire gil;
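  // Wrap each host copy in a read-only NumPy array; the capsule frees the
  // buffer when the array is garbage-collected.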
  py::tuple host_input_arrays(arity);
  for (size_t i = 0; i < arity; ++i) {
    CpuCallback::Arg arg = callback->args()[i];
    if (arg.type == TOKEN) {
      host_input_arrays[i] = py::none();
      continue;
    }
    py::capsule base(host_input_buffers[i],
                     [](void* ptr) { delete[] static_cast<char*>(ptr); });
    host_input_arrays[i] =
        py::array(arg.dtype, arg.dims, arg.strides,
                  const_cast<void*>(host_input_buffers[i]), /*base=*/base);
    host_input_arrays[i].attr("flags").attr("writeable") = Py_False;
  }
  std::optional<py::tuple> maybe_result_tuple =
      callback->Call(host_input_arrays, status);
  if (!maybe_result_tuple) {
    return;
  }
  py::tuple result_tuple = maybe_result_tuple.value();
  std::vector<void*> temp_buffers;
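  // Copy each result back to its device output buffer. Output buffers follow
  // the inputs in `buffers`, so result i lives at buffers[arity + i].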
  for (size_t i = 0; i < callback->results().size(); ++i) {
    CpuCallback::Result result = callback->results()[i];
    if (result.type == TOKEN) {
      continue;
    }
    py::object output = py::reinterpret_borrow<py::object>(
        PyTuple_GetItem(result_tuple.ptr(), i));
    py::array array = py::cast<py::array>(std::move(output));
    absl::Span<int64_t const> dims(
        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
    absl::Span<int64_t const> strides(
        reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
    if (strides == result.expected_strides) {
      gpuMemcpyAsync(buffers[arity + i], array.data(), result.size_in_bytes,
                     gpuMemcpyHostToDevice, stream);
    } else {
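      // The returned array's layout does not match the expected strides:
      // transpose into a temporary host buffer before the host-to-device copy.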
      void* temp = new char[result.size_in_bytes];
      temp_buffers.push_back(temp);
      xla::StatusOr<std::shared_ptr<xla::TransposePlan>> plan =
          callback->transpose_cache().GetOrCreate(
              xla::primitive_util::ByteWidth(result.type), dims,
              result.reversed_layout,
              /*input_layout=*/xla::TransposePlan::Striding{strides});
      if (!plan.ok()) {
        throw xla::XlaRuntimeError(plan.status().ToString());
      }
      plan.ValueOrDie()->Execute(array.data(), temp);
      gpuMemcpyAsync(buffers[arity + i], temp, result.size_in_bytes,
                     gpuMemcpyHostToDevice, stream);
    }
  }
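  // Release the GIL while waiting for the host-to-device copies; the temporary
  // transpose buffers can only be freed once those copies have finished.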
  py::gil_scoped_release release;
  gpuStreamSynchronize(stream);
  for (size_t i = 0; i < temp_buffers.size(); ++i) {
    delete[] static_cast<char*>(temp_buffers[i]);
  }
}

}  // namespace xla