1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h"
17
18 #include <memory>
19
20 #include "absl/base/call_once.h"
21 #include "absl/container/fixed_array.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
25 #include "tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.h"
26 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h"
27 #include "tensorflow/compiler/xla/stream_executor/kernel.h"
28 #include "tensorflow/compiler/xla/stream_executor/kernel_spec.h"
29 #include "tensorflow/compiler/xla/stream_executor/stream.h"
30 #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"
31 #include "tensorflow/core/framework/allocator.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/status.h"
34
35 namespace stream_executor {
36
37 // Rounds the value up to a multiple of the divisor by first calling CeilOfRatio
38 // then multiplying by the divisor. For example: RoundUpToNearest(13, 8) => 16
39 template <typename T>
RoundUpToNearest(T value,T divisor)40 static T RoundUpToNearest(T value, T divisor) {
41 return tensorflow::MathUtil::CeilOfRatio(value, divisor) * divisor;
42 }
43
// The user buffer is padded with 0 to kRhsRedzoneAlign - 1 bytes of "slop" so
// that the RHS redzone after it starts at an offset that is a multiple of
// kRhsRedzoneAlign. This lets the bulk of the RHS redzone be filled with
// Stream::ThenMemset32, which requires 4-byte-aligned address and size.
constexpr int64_t kRhsRedzoneAlign = 4;

// Shorthand for the status type returned by the redzone checks in this file.
using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus;
49
RedzoneAllocator(Stream * stream,DeviceMemoryAllocator * memory_allocator,GpuAsmOpts ptx_compilation_opts,int64_t memory_limit,int64_t redzone_size,uint8 redzone_pattern)50 RedzoneAllocator::RedzoneAllocator(Stream* stream,
51 DeviceMemoryAllocator* memory_allocator,
52 GpuAsmOpts ptx_compilation_opts,
53 int64_t memory_limit, int64_t redzone_size,
54 uint8 redzone_pattern)
55 : device_ordinal_(stream->parent()->device_ordinal()),
56 stream_(stream),
57 memory_limit_(memory_limit),
58 redzone_size_(RoundUpToNearest(
59 redzone_size,
60 static_cast<int64_t>(tensorflow::Allocator::kAllocatorAlignment))),
61 redzone_pattern_(redzone_pattern),
62 memory_allocator_(memory_allocator),
63 gpu_compilation_opts_(ptx_compilation_opts) {}
64
AllocateBytes(int64_t byte_size)65 port::StatusOr<DeviceMemory<uint8>> RedzoneAllocator::AllocateBytes(
66 int64_t byte_size) {
67 CHECK_GE(byte_size, 0) << "byte_size must be positive.";
68 if (byte_size > GetMemoryLimitInBytes()) {
69 return port::Status(
70 port::error::RESOURCE_EXHAUSTED,
71 absl::StrFormat(
72 "Allocating %d bytes exceeds the memory limit of %d bytes.",
73 byte_size, GetMemoryLimitInBytes()));
74 }
75
76 int64_t rhs_slop = RoundUpToNearest(byte_size, kRhsRedzoneAlign) - byte_size;
77 TF_ASSIGN_OR_RETURN(
78 OwningDeviceMemory allocated_buffer,
79 memory_allocator_->Allocate(device_ordinal_,
80 byte_size + 2 * redzone_size_ + rhs_slop,
81 /*retry_on_failure=*/false));
82 allocated_bytes_excluding_redzones_ += byte_size;
83
84 static_assert(sizeof(uint8) == 1, "Unexpected size");
85 DeviceMemory<uint8> allocated_buffer_memory(*allocated_buffer);
86
87 DeviceMemory<uint8> lhs_redzone = stream_->parent()->GetSubBuffer(
88 &allocated_buffer_memory, 0, redzone_size_);
89
90 DeviceMemory<uint8> data_chunk = stream_->parent()->GetSubBuffer(
91 &allocated_buffer_memory, redzone_size_, byte_size);
92
93 // Split up the RHS redzone into two pieces:
94 // - 0 to kRhsRedzoneAlign bytes adjacent to the user buffer, followed by
95 // - redzone_size_ bytes.
96 // We do this because Stream::ThenMemset32 requires the buffer address and
97 // size to be aligned to 4 bytes.
98 DeviceMemory<uint8> rhs_redzone_slop = stream_->parent()->GetSubBuffer(
99 &allocated_buffer_memory, redzone_size_ + byte_size, rhs_slop);
100
101 DeviceMemory<uint8> rhs_redzone_nonslop = stream_->parent()->GetSubBuffer(
102 &allocated_buffer_memory, redzone_size_ + byte_size + rhs_slop,
103 redzone_size_);
104
105 uint8 pattern_arr[] = {redzone_pattern_, redzone_pattern_, redzone_pattern_,
106 redzone_pattern_};
107 uint32 pattern32;
108 std::memcpy(&pattern32, pattern_arr, sizeof(pattern32));
109 stream_->ThenMemset32(&lhs_redzone, pattern32, redzone_size_);
110 if (rhs_slop != 0) {
111 stream_->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop);
112 }
113 stream_->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_);
114
115 allocated_buffers_.emplace_back(std::move(allocated_buffer), byte_size);
116 return data_chunk;
117 }
118
// PTX blob for the function which checks that every byte in
// input_buffer (length is buffer_length) is equal to redzone_pattern.
//
// Each thread checks one byte; on mismatch it atomically increments the
// counter pointed to by out_mismatch_cnt_ptr.
//
// Generated from (the last parameter is named out_mismatch_cnt_ptr in the PTX
// below):
//   __global__ void redzone_checker(unsigned char* input_buffer,
//                                   unsigned char redzone_pattern,
//                                   unsigned long long buffer_length,
//                                   int* out_mismatched_ptr) {
//     unsigned long long idx = threadIdx.x + blockIdx.x * blockDim.x;
//     if (idx >= buffer_length) return;
//     if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1);
//   }
//
// Notes on the generated code:
//  - The thread index is computed with 32-bit arithmetic (mad.lo.s32) and only
//    then widened to 64 bits (cvt.u64.u32).
//  - The counter update is a 32-bit atomic add (atom.global.add.u32), while
//    the host side allocates, zeroes and reads back a 64-bit counter. This
//    presumably relies on the counter starting at zero and on a little-endian
//    layout where the low 32 bits come first — NOTE(review): confirm.
//
// Code must compile for the oldest GPU XLA may be compiled for.
static const char* redzone_checker_ptx = R"(
.version 4.2
.target sm_30
.address_size 64

.visible .entry redzone_checker(
.param .u64 input_buffer,
.param .u8 redzone_pattern,
.param .u64 buffer_length,
.param .u64 out_mismatch_cnt_ptr
)
{
.reg .pred %p<3>;
.reg .b16 %rs<3>;
.reg .b32 %r<6>;
.reg .b64 %rd<8>;

ld.param.u64 %rd6, [buffer_length];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.s32 %r4, %r3, %r2, %r1;
cvt.u64.u32 %rd3, %r4;
setp.ge.u64 %p1, %rd3, %rd6;
@%p1 bra LBB6_3;
ld.param.u8 %rs1, [redzone_pattern];
ld.param.u64 %rd4, [input_buffer];
cvta.to.global.u64 %rd2, %rd4;
add.s64 %rd7, %rd2, %rd3;
ld.global.u8 %rs2, [%rd7];
setp.eq.s16 %p2, %rs2, %rs1;
@%p2 bra LBB6_3;
ld.param.u64 %rd5, [out_mismatch_cnt_ptr];
cvta.to.global.u64 %rd1, %rd5;
atom.global.add.u32 %r5, [%rd1], 1;
LBB6_3:
ret;
}
)";
174
// Host-side signature of the redzone_checker kernel in redzone_checker_ptx.
// The PTX has to be launched with these types in this order, matching its
// .param list:
//   input_buffer         -> DeviceMemory<uint8>
//   redzone_pattern      -> uint8
//   buffer_length        -> uint64_t
//   out_mismatch_cnt_ptr -> DeviceMemory<uint64_t>
using ComparisonKernelT =
    TypedKernel<DeviceMemory<uint8>, uint8, uint64_t, DeviceMemory<uint64_t>>;
179
180 // Check that redzones weren't overwritten on a host.
181 //
182 // Slower, but gives a more useful error message.
CheckRedzoneHost(DeviceMemoryBase redzone,DeviceMemoryBase user_allocation,absl::string_view name,Stream * stream,uint8 redzone_pattern)183 static port::StatusOr<RedzoneCheckStatus> CheckRedzoneHost(
184 DeviceMemoryBase redzone, DeviceMemoryBase user_allocation,
185 absl::string_view name, Stream* stream, uint8 redzone_pattern) {
186 uint64_t size = redzone.size();
187 auto redzone_data = std::make_unique<uint8[]>(size);
188 TF_RETURN_IF_ERROR(stream->ThenMemcpy(redzone_data.get(), redzone, size)
189 .BlockHostUntilDone());
190
191 std::array<uint8, sizeof(uint64_t)> pattern_arr;
192 pattern_arr.fill(redzone_pattern);
193 uint64_t pattern64;
194 std::memcpy(&pattern64, pattern_arr.data(), sizeof(uint64_t));
195
196 int64_t i;
197 for (i = 0; i + 7 < size; i += sizeof(uint64_t)) {
198 uint64_t rz_value = *reinterpret_cast<uint64_t*>(&redzone_data[i]);
199 if (rz_value != pattern64) {
200 return RedzoneCheckStatus(name, user_allocation.opaque(), i, pattern64,
201 rz_value);
202 }
203 }
204 for (; i < size; ++i) {
205 uint8 rz_value = redzone_data[i];
206 if (rz_value != redzone_pattern) {
207 return RedzoneCheckStatus(name, user_allocation.opaque(), i,
208 redzone_pattern, rz_value);
209 }
210 }
211 return RedzoneCheckStatus::OK();
212 }
213
214 // Run the redzone checker on the provided buffer redzone.
215 //
216 // Increment out_param if mismatch occurs.
RunRedzoneChecker(Stream * stream,const DeviceMemory<uint8> & redzone,uint8 redzone_pattern,const DeviceMemory<uint64_t> & out_param,const ComparisonKernelT & comparison_kernel)217 static port::Status RunRedzoneChecker(
218 Stream* stream, const DeviceMemory<uint8>& redzone, uint8 redzone_pattern,
219 const DeviceMemory<uint64_t>& out_param,
220 const ComparisonKernelT& comparison_kernel) {
221 StreamExecutor* executor = stream->parent();
222
223 int64_t num_elements = redzone.size();
224 int64_t threads_per_block = std::min(
225 executor->GetDeviceDescription().threads_per_block_limit(), num_elements);
226 int64_t block_count =
227 tensorflow::MathUtil::CeilOfRatio(num_elements, threads_per_block);
228
229 TF_RETURN_IF_ERROR(stream->ThenLaunch(
230 ThreadDim(threads_per_block), BlockDim(block_count), comparison_kernel,
231 redzone, redzone_pattern, redzone.size(), out_param));
232 return ::tensorflow::OkStatus();
233 }
234
235 // Since we reuse the same buffer for multiple checks, we re-initialize redzone
236 // with a NaN pattern after a failed check.
237 //
238 // This function is blocking, since redzone failing is a rare event.
ReinitializeRedzone(Stream * stream,DeviceMemoryBase redzone,uint8 redzone_pattern)239 static port::Status ReinitializeRedzone(Stream* stream,
240 DeviceMemoryBase redzone,
241 uint8 redzone_pattern) {
242 absl::FixedArray<uint8> redzone_array(redzone.size());
243 redzone_array.fill(redzone_pattern);
244 stream->ThenMemcpy(&redzone, redzone_array.data(), redzone.size());
245 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
246 return ::tensorflow::OkStatus();
247 }
248
249 // Check redzones around the user allocation.
250 //
251 // Precondition: the memory pointed out by out_param is zeroed.
CheckRedzonesForBuffer(Stream * stream,DeviceMemoryBase memory,const DeviceMemory<uint64_t> & out_param,const ComparisonKernelT & comparison_kernel,int64_t user_allocation_size,uint64_t redzone_size,uint8 redzone_pattern)252 static port::StatusOr<RedzoneCheckStatus> CheckRedzonesForBuffer(
253 Stream* stream, DeviceMemoryBase memory,
254 const DeviceMemory<uint64_t>& out_param,
255 const ComparisonKernelT& comparison_kernel, int64_t user_allocation_size,
256 uint64_t redzone_size, uint8 redzone_pattern) {
257 StreamExecutor* executor = stream->parent();
258 int64_t rhs_slop =
259 RoundUpToNearest<int64_t>(user_allocation_size, kRhsRedzoneAlign) -
260 user_allocation_size;
261 CHECK_EQ(memory.size(), user_allocation_size + rhs_slop + 2 * redzone_size);
262
263 DeviceMemory<uint8> buffer_uint8(memory);
264 DeviceMemory<uint8> lhs_redzone =
265 executor->GetSubBuffer(&buffer_uint8, 0,
266 /*element_count=*/redzone_size);
267 DeviceMemory<uint8> user_allocation =
268 executor->GetSubBuffer(&buffer_uint8, redzone_size,
269 /*element_count=*/user_allocation_size);
270 DeviceMemory<uint8> rhs_redzone =
271 executor->GetSubBuffer(&buffer_uint8, redzone_size + user_allocation_size,
272 /*element_count=*/redzone_size + rhs_slop);
273
274 TF_RETURN_IF_ERROR(RunRedzoneChecker(stream, lhs_redzone, redzone_pattern,
275 out_param, comparison_kernel));
276 TF_RETURN_IF_ERROR(RunRedzoneChecker(stream, rhs_redzone, redzone_pattern,
277 out_param, comparison_kernel));
278 int64_t result;
279 CHECK_EQ(out_param.size(), sizeof(result));
280 stream->ThenMemcpy(&result, out_param, sizeof(result));
281 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
282
283 if (result != 0) {
284 TF_ASSIGN_OR_RETURN(RedzoneCheckStatus lhs_check,
285 CheckRedzoneHost(lhs_redzone, user_allocation, "LHS",
286 stream, redzone_pattern));
287 TF_ASSIGN_OR_RETURN(RedzoneCheckStatus rhs_check,
288 CheckRedzoneHost(rhs_redzone, user_allocation, "RHS",
289 stream, redzone_pattern));
290
291 CHECK(!lhs_check.ok() || !rhs_check.ok())
292 << "Mismatched results with host and device comparison";
293
294 TF_RETURN_IF_ERROR(
295 ReinitializeRedzone(stream, lhs_redzone, redzone_pattern));
296 TF_RETURN_IF_ERROR(
297 ReinitializeRedzone(stream, rhs_redzone, redzone_pattern));
298 return !lhs_check.ok() ? lhs_check : rhs_check;
299 }
300
301 return RedzoneCheckStatus::OK();
302 }
303
CheckRedzones() const304 port::StatusOr<RedzoneCheckStatus> RedzoneAllocator::CheckRedzones() const {
305 StreamExecutor* executor = stream_->parent();
306
307 absl::Span<const uint8> compiled_ptx = {};
308 port::StatusOr<absl::Span<const uint8>> compiled_ptx_or =
309 CompileGpuAsmOrGetCached(executor->device_ordinal(), redzone_checker_ptx,
310 gpu_compilation_opts_);
311 if (compiled_ptx_or.ok()) {
312 compiled_ptx = compiled_ptx_or.ValueOrDie();
313 } else {
314 static absl::once_flag ptxas_not_found_logged;
315 absl::call_once(ptxas_not_found_logged, [&]() {
316 LOG(WARNING) << compiled_ptx_or.status().ToString()
317 << "\nRelying on driver to perform ptx compilation. "
318 << "\nModify $PATH to customize ptxas location."
319 << "\nThis message will be only logged once.";
320 });
321 }
322
323 ScopedDeviceMemory<uint64_t> out_param =
324 executor->AllocateOwnedScalar<uint64_t>();
325 stream_->ThenMemZero(out_param.ptr(), sizeof(uint64_t));
326
327 #if GOOGLE_CUDA
328 TF_ASSIGN_OR_RETURN(
329 std::shared_ptr<ComparisonKernelT> loaded_kernel,
330 (LoadKernelOrGetPtr<DeviceMemory<uint8>, uint8, uint64_t,
331 DeviceMemory<uint64_t>>(
332 executor, "redzone_checker", redzone_checker_ptx, compiled_ptx)));
333 #else
334 TF_ASSIGN_OR_RETURN(
335 std::unique_ptr<ComparisonKernelT> loaded_kernel,
336 (executor->CreateTypedKernel<DeviceMemory<uint8>, uint8, uint64_t,
337 DeviceMemory<uint64_t>>(
338 "redzone_checker", redzone_checker_ptx, compiled_ptx)));
339 #endif // GOOGLE_CUDA
340
341 for (const auto& buf_and_size : allocated_buffers_) {
342 TF_ASSIGN_OR_RETURN(
343 RedzoneCheckStatus redzone_status,
344 CheckRedzonesForBuffer(stream_, *buf_and_size.first, out_param.cref(),
345 *loaded_kernel, buf_and_size.second,
346 redzone_size_, redzone_pattern_));
347 if (!redzone_status.ok()) {
348 return redzone_status;
349 }
350 }
351
352 return RedzoneCheckStatus::OK();
353 }
354
RedzoneFailureMsg() const355 std::string RedzoneCheckStatus::RedzoneFailureMsg() const {
356 return absl::StrFormat(
357 "Redzone mismatch in %s redzone of buffer %p at offset %d; "
358 "expected %08x but was %08x.",
359 buffer_name, user_buffer_address, offset, expected_value, actual_value);
360 }
361
362 } // namespace stream_executor
363