/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/precompiled_kernels.h"

#include <string>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/gpu/asm_compiler.h"

#if TENSORFLOW_USE_ROCM
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
namespace stream_executor {
namespace gpu {

extern void rocm_MakeBatchPointers(void* stream, char* base, int stride, int n,
                                   void** ptrs_out);

}  // namespace gpu
}  // namespace stream_executor
#endif

namespace xla {
namespace gpu {
namespace {

// GPU kernel to populate an array of pointers:
//
//   [base + stride * i for i in range(n)].
//
// Generated from the following CUDA code.
//
// extern "C" {
// __global__ void __xla_MakeBatchPointers(char* base, int stride,
//                                         int n, void** ptrs_out) {
//   int idx = threadIdx.x + blockIdx.x * blockDim.x;
//   if (idx >= n) return;
//   ptrs_out[idx] = base + idx * stride;
// }
// }
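//
// For example (values chosen for illustration only): with base = 0x1000,
// stride = 0x20, and n = 3, the kernel writes {0x1000, 0x1020, 0x1040} to
// ptrs_out.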
constexpr const char* kMakeBatchPointersPtx = R"(
.version 4.2
.target sm_35
.address_size 64

.visible .entry __xla_MakeBatchPointers(
    .param .u64 __xla_MakeBatchPointers_param_0,
    .param .u32 __xla_MakeBatchPointers_param_1,
    .param .u32 __xla_MakeBatchPointers_param_2,
    .param .u64 __xla_MakeBatchPointers_param_3
)
{
    .reg .pred %p<2>;
    .reg .b32 %r<8>;
    .reg .b64 %rd<8>;

    ld.param.u32 %r2, [__xla_MakeBatchPointers_param_2];
    mov.u32 %r3, %tid.x;
    mov.u32 %r4, %ctaid.x;
    mov.u32 %r5, %ntid.x;
    mad.lo.s32 %r6, %r4, %r5, %r3;
    setp.ge.s32 %p1, %r6, %r2;
    @%p1 bra LBB0_2;
    ld.param.u64 %rd3, [__xla_MakeBatchPointers_param_0];
    ld.param.u64 %rd4, [__xla_MakeBatchPointers_param_3];
    cvta.to.global.u64 %rd5, %rd4;
    ld.param.u32 %r1, [__xla_MakeBatchPointers_param_1];
    mul.wide.s32 %rd6, %r6, 8;
    add.s64 %rd1, %rd5, %rd6;
    mul.lo.s32 %r7, %r6, %r1;
    cvt.s64.s32 %rd7, %r7;
    add.s64 %rd2, %rd3, %rd7;
    st.global.u64 [%rd1], %rd2;
LBB0_2:
    ret;
}
)";

// Lazily compiles the given PTX kernel, once per StreamExecutor.
//
// Thread-safe.
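//
// Usage sketch (illustrative only; the kernel name, PTX constant, and argument
// types below are placeholders; MakeBatchPointers further down is the real
// call site in this file):
//
//   static auto* lazy = new LazyKernel<se::DeviceMemoryBase, int>(
//       "my_kernel", kMyKernelPtx, asm_opts);
//   TF_ASSIGN_OR_RETURN(auto kernel, lazy->Get(stream->parent()));
//   // *kernel can then be passed to stream->ThenLaunch(...).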
template <typename... KernelArgs>
class LazyKernel {
 public:
  LazyKernel(absl::string_view kernel_name, const char* ptx,
             const se::GpuAsmOpts& asm_opts)
      : kernel_name_(kernel_name), ptx_(ptx), asm_opts_(asm_opts) {}

  StatusOr<se::TypedKernel<KernelArgs...>*> Get(
      se::StreamExecutor* stream_exec) {
    absl::MutexLock lock(&mu_);

    auto result = kernels_.emplace(stream_exec, nullptr);
    if (result.second) {
      absl::Span<const uint8_t> compiled_ptx;
      StatusOr<absl::Span<const uint8_t>> compiled_ptx_or =
          se::CompileGpuAsmOrGetCached(stream_exec->device_ordinal(), ptx_,
                                       asm_opts_);
      if (compiled_ptx_or.ok()) {
        compiled_ptx = std::move(compiled_ptx_or).value();
      } else {
        static absl::once_flag logged_once;
        absl::call_once(logged_once, [&]() {
          LOG(WARNING)
              << compiled_ptx_or.status().ToString()
              << "\nRelying on driver to perform ptx compilation. "
              << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda "
              << "or modifying $PATH can be used to set the location of ptxas."
              << "\nThis message will only be logged once.";
        });
      }

      auto kernel = stream_exec->CreateTypedKernel<KernelArgs...>(
          kernel_name_, ptx_, compiled_ptx);
      if (kernel.ok()) {
        result.first->second = *std::move(kernel);
      } else {
        kernels_.erase(result.first);
        return kernel.status();
      }
    }
    return result.first->second.get();
  }

 private:
  std::string kernel_name_;
  const char* ptx_;
  se::GpuAsmOpts asm_opts_;

  absl::Mutex mu_;

  // A map keyed on StreamExecutor* is ok because StreamExecutors are never
  // destroyed.
  absl::flat_hash_map<se::StreamExecutor*,
                      std::unique_ptr<se::TypedKernel<KernelArgs...>>>
      kernels_ ABSL_GUARDED_BY(mu_);
};

}  // anonymous namespace

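// Fills `ptrs_out` on the device with the batch of pointers
//
//   ptrs_out[i] = base_ptr + i * stride_bytes   for i in [0, n).
//
// On ROCm this dispatches to a prebuilt HIP kernel; otherwise it lazily
// compiles the PTX above and launches it on `stream`.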
Status MakeBatchPointers(se::Stream* stream, const se::GpuAsmOpts& asm_opts,
                         se::DeviceMemoryBase base_ptr, int stride_bytes, int n,
                         se::DeviceMemoryBase ptrs_out) {
#if TENSORFLOW_USE_ROCM
  stream_executor::gpu::rocm_MakeBatchPointers(
      se::gpu::AsGpuStreamValue(stream),
      reinterpret_cast<char*>(base_ptr.opaque()), stride_bytes, n,
      reinterpret_cast<void**>(ptrs_out.opaque()));
#else
  static auto* lazy_kernel =
      new LazyKernel<se::DeviceMemoryBase /*base_ptr*/, int /*stride_bytes*/,
                     int /*n*/, se::DeviceMemoryBase /*ptrs_out*/>(
          "__xla_MakeBatchPointers", kMakeBatchPointersPtx, asm_opts);

  TF_ASSIGN_OR_RETURN(auto kernel, lazy_kernel->Get(stream->parent()));

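  // Launch one thread per output pointer, rounding the block count up so that
  // all n entries are covered.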
  constexpr int kThreads = 128;
  TF_RETURN_IF_ERROR(
      stream->ThenLaunch(se::ThreadDim(kThreads, 1, 1),
                         se::BlockDim(CeilOfRatio(n, kThreads), 1, 1), *kernel,
                         base_ptr, stride_bytes, n, ptrs_out));
#endif
  return OkStatus();
}

}  // namespace gpu
}  // namespace xla