/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/precompiled_kernels.h"

#include <memory>
#include <string>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/gpu/asm_compiler.h"

#if TENSORFLOW_USE_ROCM
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
namespace stream_executor {
namespace gpu {

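// rocm_MakeBatchPointers is defined in ROCm-only StreamExecutor code and has
// no public header, hence the extern declaration here. It performs the same
// base + i * stride pointer fixup as the CUDA kernel embedded later in this
// file.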
extern void rocm_MakeBatchPointers(void* stream, char* base, int stride, int n,
                                   void** ptrs_out);

}  // namespace gpu
}  // namespace stream_executor
#endif

namespace xla {
namespace gpu {
namespace {

// GPU kernel to populate an array of pointers:
//
//   [base + stride * i for i in range(n)].
//
// Generated from the following CUDA code.
//
// extern "C" {
// __global__ void __xla_MakeBatchPointers(char* base, int stride,
//                                         int n, void** ptrs_out) {
//   int idx = threadIdx.x + blockIdx.x * blockDim.x;
//   if (idx >= n) return;
//   ptrs_out[idx] = base + idx * stride;
// }
// }
constexpr const char* kMakeBatchPointersPtx = R"(
.version 4.2
.target sm_35
.address_size 64

.visible .entry __xla_MakeBatchPointers(
        .param .u64 __xla_MakeBatchPointers_param_0,
        .param .u32 __xla_MakeBatchPointers_param_1,
        .param .u32 __xla_MakeBatchPointers_param_2,
        .param .u64 __xla_MakeBatchPointers_param_3
)
{
        .reg .pred      %p<2>;
        .reg .b32       %r<8>;
        .reg .b64       %rd<8>;

        ld.param.u32    %r2, [__xla_MakeBatchPointers_param_2];
        mov.u32         %r3, %tid.x;
        mov.u32         %r4, %ctaid.x;
        mov.u32         %r5, %ntid.x;
        mad.lo.s32      %r6, %r4, %r5, %r3;
        setp.ge.s32     %p1, %r6, %r2;
        @%p1 bra        LBB0_2;
        ld.param.u64    %rd3, [__xla_MakeBatchPointers_param_0];
        ld.param.u64    %rd4, [__xla_MakeBatchPointers_param_3];
        cvta.to.global.u64      %rd5, %rd4;
        ld.param.u32    %r1, [__xla_MakeBatchPointers_param_1];
        mul.wide.s32    %rd6, %r6, 8;
        add.s64         %rd1, %rd5, %rd6;
        mul.lo.s32      %r7, %r6, %r1;
        cvt.s64.s32     %rd7, %r7;
        add.s64         %rd2, %rd3, %rd7;
        st.global.u64   [%rd1], %rd2;
LBB0_2:
        ret;
}
)";
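// Note: the PTX above targets sm_35, but PTX is forward compatible; it is
// recompiled for the actual device either by ptxas (via
// CompileGpuAsmOrGetCached below) or, failing that, by the driver's JIT.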

// Lazily compiles a PTX kernel, once per StreamExecutor.
//
// Thread-safe.
template <typename... KernelArgs>
class LazyKernel {
 public:
  LazyKernel(absl::string_view kernel_name, const char* ptx,
             const se::GpuAsmOpts& asm_opts)
      : kernel_name_(kernel_name), ptx_(ptx), asm_opts_(asm_opts) {}

  StatusOr<se::TypedKernel<KernelArgs...>*> Get(
      se::StreamExecutor* stream_exec) {
    absl::MutexLock lock(&mu_);

    auto result = kernels_.emplace(stream_exec, nullptr);
    if (result.second) {
      absl::Span<const uint8_t> compiled_ptx;
      StatusOr<absl::Span<const uint8_t>> compiled_ptx_or =
          se::CompileGpuAsmOrGetCached(stream_exec->device_ordinal(), ptx_,
                                       asm_opts_);
      if (compiled_ptx_or.ok()) {
        compiled_ptx = std::move(compiled_ptx_or).value();
      } else {
        static absl::once_flag logged_once;
        absl::call_once(logged_once, [&]() {
          LOG(WARNING)
              << compiled_ptx_or.status().ToString()
              << "\nRelying on driver to perform ptx compilation. "
              << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda "
              << " or modifying $PATH can be used to set the location of ptxas."
              << "\nThis message will only be logged once.";
        });
      }

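      // If compilation failed above, compiled_ptx stays empty and
      // CreateTypedKernel falls back to the PTX text, which the driver
      // JIT-compiles when the kernel is loaded.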
      auto kernel = stream_exec->CreateTypedKernel<KernelArgs...>(
          kernel_name_, ptx_, compiled_ptx);
      if (kernel.ok()) {
        result.first->second = *std::move(kernel);
      } else {
        kernels_.erase(result.first);
        return kernel.status();
      }
    }
    return result.first->second.get();
  }

 private:
  std::string kernel_name_;
  const char* ptx_;
  se::GpuAsmOpts asm_opts_;

  absl::Mutex mu_;

  // A map keyed on StreamExecutor* is ok because StreamExecutors are never
  // destroyed.
  absl::flat_hash_map<se::StreamExecutor*,
                      std::unique_ptr<se::TypedKernel<KernelArgs...>>>
      kernels_ ABSL_GUARDED_BY(mu_);
};
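
// Example (sketch): instantiating LazyKernel for a hypothetical kernel taking
// an int and a device pointer would look like
//
//   static auto* kernel = new LazyKernel<int, se::DeviceMemoryBase>(
//       "__xla_SomeKernel", kSomeKernelPtx, asm_opts);
//   TF_ASSIGN_OR_RETURN(auto* typed_kernel, kernel->Get(stream->parent()));
//
// "__xla_SomeKernel" and kSomeKernelPtx are made-up names; the only real user
// in this file is MakeBatchPointers below.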

}  // anonymous namespace

Status MakeBatchPointers(se::Stream* stream, const se::GpuAsmOpts& asm_opts,
                         se::DeviceMemoryBase base_ptr, int stride_bytes, int n,
                         se::DeviceMemoryBase ptrs_out) {
#if TENSORFLOW_USE_ROCM
  stream_executor::gpu::rocm_MakeBatchPointers(
      se::gpu::AsGpuStreamValue(stream),
      reinterpret_cast<char*>(base_ptr.opaque()), stride_bytes, n,
      reinterpret_cast<void**>(ptrs_out.opaque()));
#else
  static auto* lazy_kernel =
      new LazyKernel<se::DeviceMemoryBase /*base_ptr*/, int /*stride_bytes*/,
                     int /*n*/, se::DeviceMemoryBase /*ptrs_out*/>(
          "__xla_MakeBatchPointers", kMakeBatchPointersPtx, asm_opts);

  TF_ASSIGN_OR_RETURN(auto kernel, lazy_kernel->Get(stream->parent()));

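  // Launch one thread per output pointer: blocks of 128 threads and
  // CeilOfRatio(n, kThreads) blocks. Threads with idx >= n return early inside
  // the kernel.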
  constexpr int kThreads = 128;
  TF_RETURN_IF_ERROR(
      stream->ThenLaunch(se::ThreadDim(kThreads, 1, 1),
                         se::BlockDim(CeilOfRatio(n, kThreads), 1, 1), *kernel,
                         base_ptr, stride_bytes, n, ptrs_out));
#endif
  return OkStatus();
}
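
// Typical use (illustrative, not part of this file): given n equally sized
// buffers laid out back to back every `stride_bytes` bytes starting at
// `base_ptr`, a caller fills `ptrs_out` with the n device pointers that
// pointer-array ("batched") GPU APIs expect:
//
//   TF_RETURN_IF_ERROR(MakeBatchPointers(stream, asm_opts, base_ptr,
//                                        stride_bytes, n, ptrs_out));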

}  // namespace gpu
}  // namespace xla