1 // Copyright 2022 The TensorFlow Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h"
16
17 #include <cstdint>
18 #include <functional>
19 #include <iterator>
20 #include <memory>
21 #include <numeric>
22 #include <utility>
23
24 #include "llvm/ExecutionEngine/Orc/Mangling.h"
25 #include "mlir/Support/LogicalResult.h" // from @llvm-project
26 #include "tensorflow/compiler/xla/runtime/arguments.h"
27 #include "tensorflow/compiler/xla/runtime/custom_call.h"
28 #include "tensorflow/compiler/xla/runtime/executable.h"
29 #include "tensorflow/compiler/xla/runtime/jit_executable.h"
30 #include "tensorflow/compiler/xla/runtime/type_id.h"
31 #include "tensorflow/compiler/xla/runtime/types.h"
32 #include "tensorflow/compiler/xla/service/custom_call_status_internal.h"
33 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
34 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
35 #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
36 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
37 #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
38 #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
39 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
40 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
41 #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
42 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"
43 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
44 #include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
45 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
46 #include "tensorflow/compiler/xla/service/service_executable_run_options.h"
47 #include "tensorflow/compiler/xla/shape_util.h"
48 #include "tensorflow/compiler/xla/tfrt_utils.h"
49 #include "tensorflow/core/platform/human_readable_json.h"
50 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
51 #include "tensorflow/stream_executor/gpu/gpu_types.h"
52 #include "tfrt/dtype/dtype.h" // from @tf_runtime
53
54 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
55 #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
56 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
57 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
58
59 TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
60 xla::gpu::JitRtKernelsCache);
61 TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
62 xla::gpu::JitRtGemmConfigCache);
63 TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
64 xla::gpu::JitRtCollectiveSupport);
65 TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
66 xla::gpu::JitRtAsyncCollectiveSupport);
67 TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
68 const xla::ServiceExecutableRunOptions);
69 TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
70 const xla::DebugOptions);
71
72 namespace xla {
73 namespace gpu {
74
75 using Eigen::half;
76
77 using llvm::ArrayRef;
78 using llvm::Error;
79 using llvm::Optional;
80
81 using mlir::failure;
82 using mlir::FailureOr;
83 using mlir::LogicalResult;
84 using mlir::StringRef;
85 using mlir::succeeded;
86 using mlir::success;
87
88 using tfrt::MakeStringError;
89
90 using ::xla::runtime::AggregateAttrDef;
91 using ::xla::runtime::AggregateAttrEncoding;
92 using ::xla::runtime::CustomCall;
93 using ::xla::runtime::CustomCallAttrEncodingSet;
94 using ::xla::runtime::DirectCustomCallLibrary;
95 using ::xla::runtime::EnumAttrEncoding;
96 using ::xla::runtime::Executable;
97 using ::xla::runtime::Tagged;
98 using ::xla::runtime::TypeIDNameRegistry;
99
100 namespace se = ::stream_executor;
101 namespace lmhlo_gpu = ::mlir::lmhlo_gpu;
102 namespace mhlo = ::mlir::mhlo;
103
104 // Disable all CustomCall checks in optimized build.
RuntimeChecks()105 static constexpr CustomCall::RuntimeChecks RuntimeChecks() {
106 #if defined(NDEBUG)
107 return CustomCall::RuntimeChecks::kNone;
108 #else
109 return CustomCall::RuntimeChecks::kDefault;
110 #endif
111 }
112
113 // -------------------------------------------------------------------------- //
114
115 // Populate mapping from XLA (SE) enums/structs type id to symbol names.
PopulateXlaTypeIdNames(TypeIDNameRegistry & registry)116 void PopulateXlaTypeIdNames(TypeIDNameRegistry& registry) {
117 registry.Register<Tagged<se::dnn::ActivationMode>>(
118 "__type_id_se_dnn_activation");
119 registry.Register<Tagged<se::cuda::BlasLt::Epilogue>>(
120 "__type_id_se_cublas_lt_epilogue");
121 registry.Register<Tagged<se::fft::Type>>("__type_id_se_fft_type");
122
123 registry.Register<Tagged<DotDimensionNumbers>>(
124 "__type_id_dot_dimension_numbers");
125 registry.Register<Tagged<ConvDimensionNumbers>>(
126 "__type_id_conv_dimension_numbers");
127 registry.Register<Tagged<ConvBackendConfig>>("__type_id_conv_backend_config");
128 }
129
130 // Add custom call arguments and attributes encoding for custom HLO enums and
131 // structs, so that we can pass them to custom calls.
PopulateLmhloToXlaAttrEncoding(CustomCallAttrEncodingSet & encoding)132 void PopulateLmhloToXlaAttrEncoding(CustomCallAttrEncodingSet& encoding) {
133 encoding
134 .Add<EnumAttrEncoding<lmhlo_gpu::ActivationAttr, lmhlo_gpu::Activation,
135 se::dnn::ActivationMode>>(
136 [](lmhlo_gpu::Activation value) -> se::dnn::ActivationMode {
137 return ConvertConvActivationMode(value).value();
138 });
139
140 encoding.Add<EnumAttrEncoding<lmhlo_gpu::CublasLtMatmulEpilogueAttr,
141 lmhlo_gpu::CublasLtMatmulEpilogue,
142 se::cuda::BlasLt::Epilogue>>(
143 [](lmhlo_gpu::CublasLtMatmulEpilogue value)
144 -> se::cuda::BlasLt::Epilogue {
145 return cublas_lt::AsBlasLtEpilogue(value).value();
146 });
147
148 encoding
149 .Add<EnumAttrEncoding<mhlo::FftTypeAttr, mhlo::FftType, se::fft::Type>>(
150 [](mhlo::FftType value) -> se::fft::Type {
151 switch (value) {
152 case mhlo::FftType::FFT:
153 return se::fft::Type::kC2CForward;
154 case mhlo::FftType::IFFT:
155 return se::fft::Type::kC2CInverse;
156 case mhlo::FftType::RFFT:
157 return se::fft::Type::kR2C;
158 case mhlo::FftType::IRFFT:
159 return se::fft::Type::kC2R;
160 default:
161 return se::fft::Type::kInvalid;
162 }
163 });
164
165 using DotDimsAttr = mhlo::DotDimensionNumbersAttr;
166 encoding.Add<
167 xla::runtime::AggregateAttrEncoding<DotDimsAttr, DotDimensionNumbers>>(
168 encoding,
169 xla::runtime::AggregateAttrDef<DotDimsAttr>()
170 .Add("lhs_batch", &DotDimsAttr::getLhsBatchingDimensions)
171 .Add("lhs_contract", &DotDimsAttr::getLhsContractingDimensions)
172 .Add("rhs_batch", &DotDimsAttr::getRhsBatchingDimensions)
173 .Add("rhs_contract", &DotDimsAttr::getRhsContractingDimensions));
174
175 using ConvDimsAttr = mhlo::ConvDimensionNumbersAttr;
176 encoding.Add<
177 xla::runtime::AggregateAttrEncoding<ConvDimsAttr, ConvDimensionNumbers>>(
178 encoding,
179 xla::runtime::AggregateAttrDef<ConvDimsAttr>()
180 .Add("input_batch_dim", &ConvDimsAttr::getInputBatchDimension)
181 .Add("input_feature_dim", &ConvDimsAttr::getInputFeatureDimension)
182 .Add("input_spatial_dims", &ConvDimsAttr::getInputSpatialDimensions)
183 .Add("kernel_in_feature_dim",
184 &ConvDimsAttr::getKernelInputFeatureDimension)
185 .Add("kernel_out_feature_dim",
186 &ConvDimsAttr::getKernelOutputFeatureDimension)
187 .Add("kernel_spatial_dims", &ConvDimsAttr::getKernelSpatialDimensions)
188 .Add("output_batch_dim", &ConvDimsAttr::getOutputBatchDimension)
189 .Add("output_feature_dim", &ConvDimsAttr::getOutputFeatureDimension)
190 .Add("output_spatial_dims",
191 &ConvDimsAttr::getOutputSpatialDimensions));
192
193 using ConvConfigAttr = lmhlo_gpu::ConvolutionBackendConfigAttr;
194 encoding.Add<
195 xla::runtime::AggregateAttrEncoding<ConvConfigAttr, ConvBackendConfig>>(
196 encoding,
197 xla::runtime::AggregateAttrDef<ConvConfigAttr>()
198 .Add("algorithm", &ConvConfigAttr::getAlgorithm)
199 .Add("tensor_ops_enabled", &ConvConfigAttr::getTensorOpsEnabled)
200 .Add("is_cudnn_frontend", &ConvConfigAttr::getIsCudnnFrontend)
201 .Add("knob_ids", &ConvConfigAttr::getKnobIds)
202 .Add("knob_values", &ConvConfigAttr::getKnobValues)
203 .Add("operand_0_layout", &ConvConfigAttr::getOperand_0Layout)
204 .Add("operand_1_layout", &ConvConfigAttr::getOperand_1Layout)
205 .Add("result_layout", &ConvConfigAttr::getResultLayout)
206 .Add("workspace_size", &ConvConfigAttr::getWorkspaceSize));
207 }
208
209 // -------------------------------------------------------------------------- //
210
Get(se::StreamExecutor * executor,const char * data,StringRef name)211 se::KernelBase* JitRtKernelsCache::Get(se::StreamExecutor* executor,
212 const char* data, StringRef name) {
213 Key key(executor, data, name);
214
215 absl::MutexLock lock(&mutex_);
216 auto it = kernels_cache_.find(key);
217 if (it != kernels_cache_.end()) return it->second.get();
218
219 return nullptr;
220 }
221
Set(se::StreamExecutor * executor,const char * data,StringRef name,std::unique_ptr<se::KernelBase> kernel)222 se::KernelBase* JitRtKernelsCache::Set(se::StreamExecutor* executor,
223 const char* data, StringRef name,
224 std::unique_ptr<se::KernelBase> kernel) {
225 Key key(executor, data, name);
226
227 absl::MutexLock lock(&mutex_);
228 auto it = kernels_cache_.find(key);
229 if (it != kernels_cache_.end()) return it->second.get();
230
231 auto emplaced = kernels_cache_.try_emplace(key, std::move(kernel));
232 return emplaced.first->second.get();
233 }
234
235 template <typename MemrefArg>
GetDeviceAddress(MemrefArg & memref)236 static se::DeviceMemoryBase GetDeviceAddress(MemrefArg& memref) {
237 uint64_t size = tfrt::GetHostSize(memref.dtype);
238 for (auto dim : memref.sizes) size *= dim;
239 return se::DeviceMemoryBase(memref.data, size);
240 }
241
GetDeviceAddress(runtime::FlatMemrefView & memref)242 static se::DeviceMemoryBase GetDeviceAddress(runtime::FlatMemrefView& memref) {
243 return se::DeviceMemoryBase(memref.data, memref.size_in_bytes);
244 }
245
246 // -------------------------------------------------------------------------- //
247
Get(int64_t uid)248 const GemmConfig* JitRtGemmConfigCache::Get(int64_t uid) {
249 absl::MutexLock lock(&mutex_);
250 auto it = configs_.find(uid);
251 if (it != configs_.end()) return &it->second;
252 return nullptr;
253 }
254
Set(int64_t uid,GemmConfig config)255 const GemmConfig* JitRtGemmConfigCache::Set(int64_t uid, GemmConfig config) {
256 absl::MutexLock lock(&mutex_);
257 auto it = configs_.find(uid);
258 if (it != configs_.end()) return &it->second;
259
260 auto emplaced = configs_.try_emplace(uid, std::move(config));
261 return &emplaced.first->second;
262 }
263
264 // -------------------------------------------------------------------------- //
265
JitRtAsyncCollectiveSupport(se::Stream * async_comm_stream)266 JitRtAsyncCollectiveSupport::JitRtAsyncCollectiveSupport(
267 se::Stream* async_comm_stream)
268 : async_comm_stream_(async_comm_stream) {}
269
MaybeBlockAfterFirstRun(int32_t uid,int32_t device_ordinal,se::Stream * stream)270 Status JitRtCollectiveSupport::MaybeBlockAfterFirstRun(int32_t uid,
271 int32_t device_ordinal,
272 se::Stream* stream) {
273 bool block = [&] {
274 absl::MutexLock lock(&mutex_);
275 return executed_.try_emplace(Key(uid, device_ordinal), true).second;
276 }();
277 return block ? stream->BlockHostUntilDone() : Status::OK();
278 }
279
PopEvent(int32_t uid,int32_t device_ordinal)280 FailureOr<se::Event> JitRtAsyncCollectiveSupport::PopEvent(
281 int32_t uid, int32_t device_ordinal) {
282 const int64_t key = EventKey(uid, device_ordinal);
283
284 absl::MutexLock lock(&mutex_);
285 auto it = done_events_.find(key);
286 if (it == done_events_.end()) return failure();
287
288 se::Event done_event = std::move(it->second);
289 done_events_.erase(it);
290 return done_event;
291 }
292
PushEvent(int32_t uid,int32_t device_ordinal,se::Event done_event)293 LogicalResult JitRtAsyncCollectiveSupport::PushEvent(int32_t uid,
294 int32_t device_ordinal,
295 se::Event done_event) {
296 const int64_t key = EventKey(uid, device_ordinal);
297
298 absl::MutexLock lock(&mutex_);
299 auto result = done_events_.try_emplace(key, std::move(done_event));
300 if (!result.second) return failure(); // done event has not been consumed
301
302 return success();
303 }
304
305 // -------------------------------------------------------------------------- //
306
ToShape(const runtime::StridedMemrefView & memref)307 static Shape ToShape(const runtime::StridedMemrefView& memref) {
308 PrimitiveType type = TfrtToPrimitiveType(memref.dtype);
309
310 // Recover `minor_to_major` dimensions permutation from strides.
311 auto indexed_strides_range =
312 llvm::map_range(llvm::enumerate(memref.strides), [](auto pair) {
313 return std::pair<int64_t, size_t>{pair.value(), pair.index()};
314 });
315
316 auto indexed_strides = llvm::to_vector(indexed_strides_range);
317 llvm::stable_sort(indexed_strides);
318
319 llvm::SmallVector<int64_t> minor_to_major;
320 minor_to_major.reserve(indexed_strides.size());
321 for (auto& pair : indexed_strides) minor_to_major.push_back(pair.second);
322
323 return ShapeUtil::MakeShapeWithLayout(type, memref.sizes, minor_to_major);
324 }
325
GetGemmConfig(const runtime::StridedMemrefView & lhs,const runtime::StridedMemrefView & rhs,const runtime::StridedMemrefView & out,int64_t algorithm,double alpha_real,double alpha_imag,double beta,ArrayRef<int64_t> lhs_batch,ArrayRef<int64_t> lhs_contract,ArrayRef<int64_t> rhs_batch,ArrayRef<int64_t> rhs_contract)326 static StatusOr<GemmConfig> GetGemmConfig(const runtime::StridedMemrefView& lhs,
327 const runtime::StridedMemrefView& rhs,
328 const runtime::StridedMemrefView& out,
329 int64_t algorithm, double alpha_real,
330 double alpha_imag, double beta,
331 ArrayRef<int64_t> lhs_batch,
332 ArrayRef<int64_t> lhs_contract,
333 ArrayRef<int64_t> rhs_batch,
334 ArrayRef<int64_t> rhs_contract) {
335 return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs),
336 rhs_batch, rhs_contract, ToShape(out), alpha_real,
337 alpha_imag, beta, algorithm,
338 se::blas::kDefaultComputePrecision);
339 }
340
341 // -------------------------------------------------------------------------- //
342
343 #if XLA_ENABLE_XCCL
GetNcclComm(const NcclExecuteParams & params,int64_t group_mode,int64_t op_id,ArrayRef<int64_t> replica_group_offsets,ArrayRef<int64_t> replica_group_values)344 FailureOr<NcclComm::Lock> GetNcclComm(const NcclExecuteParams& params,
345 int64_t group_mode, int64_t op_id,
346 ArrayRef<int64_t> replica_group_offsets,
347 ArrayRef<int64_t> replica_group_values) {
348 // TODO(b/233930690): Pass the attribute below as a nested array.
349 // Pass an array of arrays using two vectors; one specifying all the values
350 // and another specifying the (ending) offsets of each array in the other
351 // vector. Example: [ [10, 20, 30, 40], [50, 60], [70, 80, 90] ] turns into
352 // offsets=[4, 6, 9] values=[10, 20, 30, 40, 50, 60, 70, 80, 90].
353 std::vector<ReplicaGroup> replica_groups;
354 int i = 0;
355 for (int64_t replica_group_end : replica_group_offsets) {
356 ReplicaGroup replica_group;
357 while (i < replica_group_end)
358 replica_group.add_replica_ids(replica_group_values[i++]);
359 replica_groups.push_back(replica_group);
360 }
361
362 auto comm =
363 LockNcclComm(params, replica_groups,
364 static_cast<CollectiveOpGroupMode>(group_mode), op_id);
365 if (comm.ok()) return std::move(comm.value());
366 return failure();
367 }
368 #endif // XLA_ENABLE_XCCL
369
GetDeviceBufferPairs(CustomCall::RemainingArgs & args)370 FailureOr<std::vector<DeviceBufferPair>> GetDeviceBufferPairs(
371 CustomCall::RemainingArgs& args) {
372 // Add MemRef arguments as buffer arguments.
373 const int buffer_pairs = args.size() / 2;
374 std::vector<DeviceBufferPair> device_buffers;
375 device_buffers.reserve(buffer_pairs);
376 for (int i = 0; i < buffer_pairs; ++i) {
377 auto source = args.get<runtime::StridedMemrefView>(i);
378 auto destination = args.get<runtime::StridedMemrefView>(i + buffer_pairs);
379 if (failed(source) || failed(destination)) {
380 // Unsupported argument type.
381 return failure();
382 }
383
384 int element_count = 1;
385 for (int size : source->sizes) element_count *= size;
386 device_buffers.emplace_back(DeviceBufferPair{
387 TfrtToPrimitiveType(source->dtype), element_count,
388 GetDeviceAddress(*source), GetDeviceAddress(*destination)});
389 }
390 return device_buffers;
391 }
392
393 // -------------------------------------------------------------------------- //
394
AsError(Status s)395 Error AsError(Status s) { return MakeStringError(s.error_message()); }
396
397 template <typename T>
AsError(StatusOr<T> & s)398 Error AsError(StatusOr<T>& s) {
399 assert(!s.ok());
400 return AsError(s.status());
401 }
402
403 // -------------------------------------------------------------------------- //
404
405 namespace {
406 struct LaunchFunc {
407 LLVM_ATTRIBUTE_ALWAYS_INLINE
408 Error operator()(const ServiceExecutableRunOptions* run_options,
409 JitRtKernelsCache* kernels_cache, int32_t grid_size_x,
410 int32_t grid_size_y, int32_t grid_size_z,
411 int32_t block_size_x, int32_t block_size_y,
412 int32_t block_size_z, CustomCall::RemainingArgs args,
413 StringRef ptx, StringRef name) const;
414
Handlerxla::gpu::__anon8ea4ed0d0611::LaunchFunc415 static LaunchFunc Handler() { return LaunchFunc(); }
416 };
417 } // namespace
418
operator ()(const ServiceExecutableRunOptions * run_options,JitRtKernelsCache * kernels_cache,int32_t grid_size_x,int32_t grid_size_y,int32_t grid_size_z,int32_t block_size_x,int32_t block_size_y,int32_t block_size_z,CustomCall::RemainingArgs args,StringRef ptx,StringRef name) const419 Error LaunchFunc::operator()(const ServiceExecutableRunOptions* run_options,
420 JitRtKernelsCache* kernels_cache,
421 int32_t grid_size_x, int32_t grid_size_y,
422 int32_t grid_size_z, int32_t block_size_x,
423 int32_t block_size_y, int32_t block_size_z,
424 CustomCall::RemainingArgs args, StringRef ptx,
425 StringRef name) const {
426 se::Stream* stream = run_options->stream();
427 se::StreamExecutor* executor = stream->parent();
428
429 LaunchDimensions launch_dimensions(
430 {grid_size_x, grid_size_y, grid_size_z},
431 {block_size_x, block_size_y, block_size_z});
432
433 se::KernelBase* kernel = kernels_cache->Get(executor, ptx.data(), name);
434
435 // If kernel does not exists create it from the ptx.
436 if (kernel == nullptr) {
437 auto created = CreateKernel(absl::string_view(name.data(), name.size()),
438 args.size(), ptx.data(), {}, executor);
439 if (!created.ok()) return AsError(created);
440
441 kernel =
442 kernels_cache->Set(executor, ptx.data(), name, std::move(*created));
443 }
444
445 VLOG(3) << "Launching " << kernel->name();
446 absl::InlinedVector<se::DeviceMemoryBase, 4> buffer_args;
447 buffer_args.reserve(args.size());
448
449 // Add MemRef arguments as buffer arguments.
450 for (unsigned i = 0; i < args.size(); ++i) {
451 // Simple row major memref passed as shapeless buffer.
452 auto memref = args.get<runtime::FlatMemrefView>(i);
453 if (succeeded(memref)) {
454 buffer_args.emplace_back(GetDeviceAddress(*memref));
455 continue;
456 }
457
458 // Memref layout must be encoded in the compiled device kernel, so we don't
459 // have to pass strides or minor to major dimensions order to the kernel.
460 auto strided = args.get<runtime::StridedMemrefView>(i);
461 if (succeeded(strided)) {
462 buffer_args.emplace_back(GetDeviceAddress(*strided));
463 continue;
464 }
465
466 return MakeStringError("Unsupported argumeent type");
467 }
468
469 // Execute device kernel on a main stream.
470 auto executed =
471 ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, stream);
472 if (!executed.ok()) return AsError(executed);
473
474 return Error::success();
475 }
476
LaunchFunc(runtime::KernelContext * ctx,void ** args,void ** attrs)477 static bool LaunchFunc(runtime::KernelContext* ctx, void** args, void** attrs) {
478 static auto* handler = CustomCall::Bind("xla.gpu.func.launch")
479 .UserData<const ServiceExecutableRunOptions*>()
480 .UserData<JitRtKernelsCache*>()
481 .Arg<int32_t>() // grid_size_x
482 .Arg<int32_t>() // grid_size_y
483 .Arg<int32_t>() // grid_size_z
484 .Arg<int32_t>() // block_size_x
485 .Arg<int32_t>() // block_size_y
486 .Arg<int32_t>() // block_size_x
487 .RemainingArgs() // args
488 .Attr<StringRef>("ptx")
489 .Attr<StringRef>("kernel")
490 .To<RuntimeChecks()>(LaunchFunc::Handler())
491 .release();
492
493 return succeeded(Executable::Call(ctx, *handler, args, attrs));
494 }
495
496 // -------------------------------------------------------------------------- //
497
498 namespace {
499 struct Gemm {
500 LLVM_ATTRIBUTE_ALWAYS_INLINE
501 Error operator()(const ServiceExecutableRunOptions* run_options,
502 const DebugOptions* debug_options,
503 JitRtGemmConfigCache* configs,
504 runtime::StridedMemrefView lhs,
505 runtime::StridedMemrefView rhs,
506 runtime::StridedMemrefView out, int64_t algorithm,
507 double alpha_real, double alpha_imag, double beta,
508 DotDimensionNumbers dot_dims, int64_t uid) const;
509
Handlerxla::gpu::__anon8ea4ed0d0711::Gemm510 static Gemm Handler() { return Gemm(); }
511 };
512 } // namespace
513
operator ()(const ServiceExecutableRunOptions * run_options,const DebugOptions * debug_options,JitRtGemmConfigCache * configs,runtime::StridedMemrefView lhs,runtime::StridedMemrefView rhs,runtime::StridedMemrefView out,int64_t algorithm,double alpha_real,double alpha_imag,double beta,DotDimensionNumbers dot_dims,int64_t uid) const514 Error Gemm::operator()(const ServiceExecutableRunOptions* run_options,
515 const DebugOptions* debug_options,
516 JitRtGemmConfigCache* configs,
517 runtime::StridedMemrefView lhs,
518 runtime::StridedMemrefView rhs,
519 runtime::StridedMemrefView out, int64_t algorithm,
520 double alpha_real, double alpha_imag, double beta,
521 DotDimensionNumbers dot_dims, int64_t uid) const {
522 se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs);
523 se::DeviceMemoryBase rhs_data = GetDeviceAddress(rhs);
524 se::DeviceMemoryBase output_data = GetDeviceAddress(out);
525
526 VLOG(3) << "Running GEMM";
527 se::Stream* stream = run_options->stream();
528
529 // Find the gemm config for this instance of operation based on uid.
530 const GemmConfig* config = configs->Get(uid);
531 if (config == nullptr) {
532 auto cfg = GetGemmConfig(lhs, rhs, out, algorithm, alpha_real, alpha_imag,
533 beta, dot_dims.lhs_batch, dot_dims.lhs_contract,
534 dot_dims.rhs_batch, dot_dims.rhs_contract);
535 if (!cfg.ok()) return AsError(cfg);
536 config = configs->Set(uid, std::move(*cfg));
537 }
538
539 Status executed = [&]() -> Status {
540 return RunGemm(*config, lhs_data, rhs_data, output_data, stream);
541 }();
542
543 if (!executed.ok()) return AsError(executed);
544
545 return Error::success();
546 }
547
Gemm(runtime::KernelContext * ctx,void ** args,void ** attrs)548 static bool Gemm(runtime::KernelContext* ctx, void** args, void** attrs) {
549 static auto* handler = CustomCall::Bind("xla.gpu.gemm")
550 .UserData<const ServiceExecutableRunOptions*>()
551 .UserData<const DebugOptions*>()
552 .UserData<JitRtGemmConfigCache*>()
553 .Arg<runtime::StridedMemrefView>() // lhs
554 .Arg<runtime::StridedMemrefView>() // rhs
555 .Arg<runtime::StridedMemrefView>() // out
556 .Attr<int64_t>("algorithm")
557 .Attr<double>("alpha_real")
558 .Attr<double>("alpha_imag")
559 .Attr<double>("beta")
560 .Attr<DotDimensionNumbers>("dot_dims")
561 .Attr<int64_t>("uid")
562 .To<RuntimeChecks()>(Gemm::Handler())
563 .release();
564
565 return succeeded(Executable::Call(ctx, *handler, args, attrs));
566 }
567
568 // -------------------------------------------------------------------------- //
569
570 // TODO(ezhulenev): Cache matmul plans similar to GemmConfig for Gemm.
571
572 namespace {
573 struct CublasLtMatmul {
574 LLVM_ATTRIBUTE_ALWAYS_INLINE
575 Error operator()(const ServiceExecutableRunOptions* run_options,
576 const DebugOptions* debug_options,
577 runtime::StridedMemrefView a, runtime::StridedMemrefView b,
578 runtime::StridedMemrefView c, runtime::StridedMemrefView d,
579 Optional<runtime::StridedMemrefView> bias, int64_t algorithm,
580 double alpha_real, double alpha_imag, double beta,
581 DotDimensionNumbers dot_dims,
582 se::cuda::BlasLt::Epilogue epilogue,
583 ArrayRef<int32_t> precision, int64_t uid) const;
584
Handlerxla::gpu::__anon8ea4ed0d0911::CublasLtMatmul585 static CublasLtMatmul Handler() { return CublasLtMatmul(); }
586 };
587 } // namespace
588
operator ()(const ServiceExecutableRunOptions * run_options,const DebugOptions * debug_options,runtime::StridedMemrefView a,runtime::StridedMemrefView b,runtime::StridedMemrefView c,runtime::StridedMemrefView d,Optional<runtime::StridedMemrefView> bias,int64_t algorithm,double alpha_real,double alpha_imag,double beta,DotDimensionNumbers dot_dims,se::cuda::BlasLt::Epilogue epilogue,ArrayRef<int32_t> precision,int64_t uid) const589 Error CublasLtMatmul::operator()(
590 const ServiceExecutableRunOptions* run_options,
591 const DebugOptions* debug_options, runtime::StridedMemrefView a,
592 runtime::StridedMemrefView b, runtime::StridedMemrefView c,
593 runtime::StridedMemrefView d, Optional<runtime::StridedMemrefView> bias,
594 int64_t algorithm, double alpha_real, double alpha_imag, double beta,
595 DotDimensionNumbers dot_dims, se::cuda::BlasLt::Epilogue epilogue,
596 ArrayRef<int32_t> precision, int64_t uid) const {
597 VLOG(3) << "Running CublasLtMatmul";
598 se::Stream* stream = run_options->stream();
599
600 // Construct a plan from a gemm config and an epilogue.
601 auto cfg = GetGemmConfig(a, b, c, algorithm, alpha_real, alpha_imag, beta,
602 dot_dims.lhs_batch, dot_dims.lhs_contract,
603 dot_dims.rhs_batch, dot_dims.rhs_contract);
604 if (!cfg.ok()) return AsError(cfg);
605
606 auto plan = cublas_lt::MatmulPlan::From(*cfg, epilogue);
607 if (!plan.ok()) return AsError(plan);
608
609 auto algos = plan->GetAlgorithms(stream);
610 if (!algos.ok()) return AsError(algos);
611
612 se::DeviceMemoryBase a_data = GetDeviceAddress(a);
613 se::DeviceMemoryBase b_data = GetDeviceAddress(b);
614 se::DeviceMemoryBase c_data = GetDeviceAddress(c);
615 se::DeviceMemoryBase d_data = GetDeviceAddress(d);
616 se::DeviceMemoryBase bias_data;
617 if (bias.has_value()) bias_data = GetDeviceAddress(*bias);
618
619 se::OwningScratchAllocator<> scratch_allocator(
620 stream->parent()->device_ordinal(), stream->parent()->GetAllocator());
621
622 auto st =
623 plan->ExecuteOnStream(stream, a_data, b_data, c_data, d_data, bias_data,
624 (*algos)[algorithm], scratch_allocator);
625 if (!st.ok()) return AsError(st);
626
627 return Error::success();
628 }
629
630 // Adds custom call bindings for matmul operations.
631 template <typename... Ts>
BindMatmulAttributes(runtime::CustomCallBinding<Ts...> binding)632 static auto BindMatmulAttributes(runtime::CustomCallBinding<Ts...> binding) {
633 return std::move(binding)
634 .template Attr<int64_t>("algorithm")
635 .template Attr<double>("alpha_real")
636 .template Attr<double>("alpha_imag")
637 .template Attr<double>("beta")
638 .template Attr<DotDimensionNumbers>("dot_dims")
639 .template Attr<se::cuda::BlasLt::Epilogue>("epilogue")
640 .template Attr<ArrayRef<int32_t>>("precision")
641 .template Attr<int64_t>("uid");
642 }
643
CublasLtMatmul(runtime::KernelContext * ctx,void ** args,void ** attrs)644 static bool CublasLtMatmul(runtime::KernelContext* ctx, void** args,
645 void** attrs) {
646 static auto* handler =
647 BindMatmulAttributes(CustomCall::Bind("xla.gpu.cublas.lt.matmul")
648 .UserData<const ServiceExecutableRunOptions*>()
649 .UserData<const DebugOptions*>()
650 .Arg<runtime::StridedMemrefView>() // a
651 .Arg<runtime::StridedMemrefView>() // b
652 .Arg<runtime::StridedMemrefView>() // c
653 .Arg<runtime::StridedMemrefView>() // d
654 .Value(CustomCall::None) // bias
655 )
656 .To<RuntimeChecks()>(CublasLtMatmul::Handler())
657 .release();
658
659 return succeeded(Executable::Call(ctx, *handler, args, attrs));
660 }
661
CublasLtMatmulBias(runtime::KernelContext * ctx,void ** args,void ** attrs)662 static bool CublasLtMatmulBias(runtime::KernelContext* ctx, void** args,
663 void** attrs) {
664 static auto* handler =
665 BindMatmulAttributes(CustomCall::Bind("xla.gpu.cublas.lt.matmul.bias")
666 .UserData<const ServiceExecutableRunOptions*>()
667 .UserData<const DebugOptions*>()
668 .Arg<runtime::StridedMemrefView>() // a
669 .Arg<runtime::StridedMemrefView>() // b
670 .Arg<runtime::StridedMemrefView>() // c
671 .Arg<runtime::StridedMemrefView>() // d
672 .Arg<runtime::StridedMemrefView>() // bias
673 )
674 .To<RuntimeChecks()>(CublasLtMatmul::Handler())
675 .release();
676
677 return succeeded(Executable::Call(ctx, *handler, args, attrs));
678 }
679
680 // -------------------------------------------------------------------------- //
681
682 // TODO(ezhulenev): We need to find a better way to pass structured attributes
683 // to JitRt custom calls.
684
685 // TODO(ezhulenev): Add caching layer for convolution configs and runners.
686
687 namespace {
688
689 struct Window {
690 ArrayRef<int64_t> window_strides;
691 ArrayRef<int64_t> padding;
692 ArrayRef<int64_t> lhs_dilation;
693 ArrayRef<int64_t> rhs_dilation;
694 ArrayRef<int64_t> window_reversal;
695 };
696
697 struct ConvAttrs {
698 int64_t feature_group_count;
699 double result_scale;
700 };
701
702 struct FusedConvAttrs {
703 se::dnn::ActivationMode activation_mode;
704 };
705
706 struct SideInputAttrs {
707 double side_input_scale;
708 };
709
710 } // namespace
711
GetConvDescriptor(CudnnConvKind kind,runtime::StridedMemrefView operand0,runtime::StridedMemrefView operand1,runtime::StridedMemrefView output,runtime::FlatMemrefView scratch,ConvDimensionNumbers dims,Window w,ConvBackendConfig b,ConvAttrs attrs,Optional<FusedConvAttrs> fused=llvm::None,Optional<SideInputAttrs> side_input=llvm::None)712 static GpuConvDescriptor GetConvDescriptor(
713 CudnnConvKind kind,
714 // Arguments
715 runtime::StridedMemrefView operand0, runtime::StridedMemrefView operand1,
716 runtime::StridedMemrefView output, runtime::FlatMemrefView scratch,
717 // Attributes
718 ConvDimensionNumbers dims, Window w, ConvBackendConfig b, ConvAttrs attrs,
719 // Conv-specific arguments and attributes
720 Optional<FusedConvAttrs> fused = llvm::None,
721 Optional<SideInputAttrs> side_input = llvm::None) {
722 // Build a convolution descriptor from the attributes.
723 GpuConvDescriptor descriptor;
724 descriptor.kind = kind;
725
726 // Apply backend config layout to the shape.
727 auto apply_layout = [](runtime::StridedMemrefView& memref,
728 ArrayRef<int64_t> minor_to_major) {
729 Shape shape = ToShape(memref);
730 return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
731 shape.dimensions(), minor_to_major);
732 };
733
734 descriptor.operand0_shape = apply_layout(operand0, b.operand_0_layout);
735 descriptor.operand1_shape = apply_layout(operand1, b.operand_1_layout);
736 descriptor.result_shape = apply_layout(output, b.result_layout);
737
738 // Set up convolution dimensions numbers.
739 ConvolutionDimensionNumbers dns;
740 dns.set_input_batch_dimension(dims.input_batch_dim);
741 dns.set_input_feature_dimension(dims.input_feature_dim);
742 dns.set_kernel_input_feature_dimension(dims.kernel_in_feature_dim);
743 dns.set_kernel_output_feature_dimension(dims.kernel_out_feature_dim);
744 dns.set_output_batch_dimension(dims.output_batch_dim);
745 dns.set_output_feature_dimension(dims.output_feature_dim);
746 for (int64_t d : dims.input_spatial_dims) dns.add_input_spatial_dimensions(d);
747 for (int64_t d : dims.kernel_spatial_dims)
748 dns.add_kernel_spatial_dimensions(d);
749 for (int64_t d : dims.output_spatial_dims)
750 dns.add_output_spatial_dimensions(d);
751 descriptor.dnums = std::move(dns);
752
753 // Put together convolution window config.
754 for (auto index : llvm::seq<int>(0, w.window_strides.size())) {
755 WindowDimension* dim = descriptor.window.add_dimensions();
756 // Window size for a convolution is the same as the kernel size.
757 // Kernel size of the convolution is operand1_shape. We need to look at
758 // the convolution dimension numbers kernel spatial dimensions to get
759 // the window size.
760 int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
761 dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
762 dim->set_stride(w.window_strides[index]);
763 dim->set_padding_low(w.padding[index]);
764 dim->set_padding_high(w.padding[index]);
765 dim->set_base_dilation(w.lhs_dilation[index]);
766 dim->set_window_dilation(w.rhs_dilation[index]);
767 dim->set_window_reversal(w.window_reversal[index]);
768 }
769
770 descriptor.scratch_size = scratch.size_in_bytes;
771 descriptor.feature_group_count = attrs.feature_group_count;
772 descriptor.backend_config.set_conv_result_scale(attrs.result_scale);
773
774 // Set up convolution algorigthm.
775 auto* algo = descriptor.backend_config.mutable_algorithm();
776 algo->set_algo_id(b.algorithm);
777 algo->set_math_type(b.tensor_ops_enabled
778 ? se::dnn::AlgorithmProto::TENSOR_OP_MATH
779 : se::dnn::AlgorithmProto::DEFAULT_MATH);
780 algo->set_is_cudnn_frontend(b.is_cudnn_frontend);
781
782 if (b.workspace_size >= 0)
783 algo->mutable_workspace_size()->set_value(b.workspace_size);
784
785 for (unsigned i = 0; i < b.knob_ids.size(); ++i) {
786 algo->mutable_tuning_knobs()->insert({b.knob_ids[i], b.knob_values[i]});
787 }
788
789 // Set attributes specific for fused convolutions.
790 if (fused.has_value())
791 descriptor.backend_config.set_activation_mode(fused->activation_mode);
792
793 // Set attributes specific for convolutions with side input.
794 if (side_input.has_value())
795 descriptor.backend_config.set_side_input_scale(
796 side_input->side_input_scale);
797
798 return descriptor;
799 }
800
801 namespace {
802 struct Conv {
803 LLVM_ATTRIBUTE_ALWAYS_INLINE
operator ()xla::gpu::__anon8ea4ed0d0c11::Conv804 Error operator()(
805 const ServiceExecutableRunOptions* run_options,
806 const DebugOptions* debug_options, runtime::StridedMemrefView operand0,
807 runtime::StridedMemrefView operand1,
808 Optional<runtime::FlatMemrefView> bias,
809 Optional<runtime::StridedMemrefView> side_input,
810 runtime::StridedMemrefView output, runtime::FlatMemrefView scratch,
811 ConvDimensionNumbers conv_dims,
812 // Window config
813 ArrayRef<int64_t> window_strides, ArrayRef<int64_t> padding,
814 ArrayRef<int64_t> lhs_dilation, ArrayRef<int64_t> rhs_dilation,
815 ArrayRef<int64_t> window_reversal,
816 // Backend config attributes
817 ConvBackendConfig backend_config,
818 // Remaining attributes
819 int64_t feature_group_count, double result_scale,
820 // Optional attributes for fused convolutions.
821 Optional<se::dnn::ActivationMode> activation_mode = llvm::None,
822 Optional<double> side_input_scale = llvm::None) const {
823 // Build config for optional attributes.
824 Optional<FusedConvAttrs> fused_attrs = llvm::None;
825 if (activation_mode.has_value()) fused_attrs = {*activation_mode};
826
827 Optional<SideInputAttrs> side_input_attrs = llvm::None;
828 if (side_input_scale.has_value()) side_input_attrs = {*side_input_scale};
829
830 // Prepare a descriptor for the XLA convolution.
831 GpuConvDescriptor descriptor = GetConvDescriptor(
832 kind, operand0, operand1, output, scratch, conv_dims,
833 {window_strides, padding, lhs_dilation, rhs_dilation, window_reversal},
834 backend_config, {feature_group_count, result_scale}, fused_attrs,
835 side_input_attrs);
836
837 // Convert descriptor to the Conv config.
838 StatusOr<GpuConvConfig> config = GetGpuConvConfig(descriptor, "");
839 if (!config.ok()) return AsError(config);
840
841 // Prepare buffer arguments.
842 std::vector<se::DeviceMemoryBase> buffers = {GetDeviceAddress(operand0),
843 GetDeviceAddress(operand1)};
844 if (bias.has_value()) buffers.push_back(GetDeviceAddress(*bias));
845 if (side_input.has_value())
846 buffers.push_back(GetDeviceAddress(*side_input));
847
848 se::DeviceMemoryBase result_buffer = GetDeviceAddress(output);
849 se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch);
850
851 RunConvOptions opts;
852
853 // Create a runner for the given config.
854 MaybeFusedConvRunner runner(*config);
855 opts.runner_cache = &runner;
856
857 // Run the convolution.
858 auto st = RunGpuConv(*config, buffers, result_buffer, scratch_buffer,
859 run_options->stream(), opts);
860 if (!st.ok() || !run_options->stream()->ok()) {
861 return AsError(st);
862 }
863
864 return Error::success();
865 }
866
Handlerxla::gpu::__anon8ea4ed0d0c11::Conv867 static Conv Handler(CudnnConvKind kind) { return Conv{kind}; }
868
869 CudnnConvKind kind;
870 };
871
872 } // namespace
873
874 // Adds custom call bindings for convolution operations.
875 template <typename... Ts>
BindConvAttributes(runtime::CustomCallBinding<Ts...> binding)876 static auto BindConvAttributes(runtime::CustomCallBinding<Ts...> binding) {
877 return std::move(binding)
878 // Convolution dimensions numbers
879 .template Attr<ConvDimensionNumbers>("conv_dims")
880 // Window config
881 .template Attr<ArrayRef<int64_t>>("window_strides")
882 .template Attr<ArrayRef<int64_t>>("padding")
883 .template Attr<ArrayRef<int64_t>>("lhs_dilation")
884 .template Attr<ArrayRef<int64_t>>("rhs_dilation")
885 .template Attr<ArrayRef<int64_t>>("window_reversal")
886 // Backend config attributes
887 .template Attr<ConvBackendConfig>("backend_config")
888 // Remaining attributes.
889 .template Attr<int64_t>("feature_group_count")
890 .template Attr<double>("result_scale");
891 }
892
893 template <CudnnConvKind kind>
ConvFn(runtime::KernelContext * ctx,void ** args,void ** attrs)894 static bool ConvFn(runtime::KernelContext* ctx, void** args, void** attrs) {
895 static auto* handler =
896 BindConvAttributes(CustomCall::Bind("xla.gpu.conv")
897 .UserData<const ServiceExecutableRunOptions*>()
898 .UserData<const DebugOptions*>()
899 .Arg<runtime::StridedMemrefView>() // operand0
900 .Arg<runtime::StridedMemrefView>() // operand1
901 .Value(CustomCall::None) // bias
902 .Value(CustomCall::None) // side_input
903 .Arg<runtime::StridedMemrefView>() // output
904 .Arg<runtime::FlatMemrefView>() // scratch
905 )
906 .To(Conv::Handler(kind))
907 .release();
908
909 return succeeded(Executable::Call(ctx, *handler, args, attrs));
910 }
911
912 template <CudnnConvKind kind>
ConvFusedFn(runtime::KernelContext * ctx,void ** args,void ** attrs)913 static bool ConvFusedFn(runtime::KernelContext* ctx, void** args,
914 void** attrs) {
915 static auto* handler =
916 BindConvAttributes(CustomCall::Bind("xla.gpu.conv.fused")
917 .UserData<const ServiceExecutableRunOptions*>()
918 .UserData<const DebugOptions*>()
919 .Arg<runtime::StridedMemrefView>() // operand0
920 .Arg<runtime::StridedMemrefView>() // operand1
921 .Arg<runtime::FlatMemrefView>() // bias
922 .Value(CustomCall::None) // side_input
923 .Arg<runtime::StridedMemrefView>() // output
924 .Arg<runtime::FlatMemrefView>() // scratch
925 )
926 .Attr<se::dnn::ActivationMode>("activation_mode")
927 .To(Conv::Handler(kind))
928 .release();
929
930 return succeeded(Executable::Call(ctx, *handler, args, attrs));
931 }
932
933 template <CudnnConvKind kind>
ConvFuseSideInputdFn(runtime::KernelContext * ctx,void ** args,void ** attrs)934 static bool ConvFuseSideInputdFn(runtime::KernelContext* ctx, void** args,
935 void** attrs) {
936 static auto* handler =
937 BindConvAttributes(CustomCall::Bind("xla.gpu.conv.fused.side_input")
938 .UserData<const ServiceExecutableRunOptions*>()
939 .UserData<const DebugOptions*>()
940 .Arg<runtime::StridedMemrefView>() // operand0
941 .Arg<runtime::StridedMemrefView>() // operand1
942 .Arg<runtime::FlatMemrefView>() // bias
943 .Arg<runtime::StridedMemrefView>() // side_input
944 .Arg<runtime::StridedMemrefView>() // output
945 .Arg<runtime::FlatMemrefView>() // scratch
946 )
947 .Attr<se::dnn::ActivationMode>("activation_mode")
948 .Attr<double>("side_input_scale")
949 .To(Conv::Handler(kind))
950 .release();
951
952 return succeeded(Executable::Call(ctx, *handler, args, attrs));
953 }
954
955 // -------------------------------------------------------------------------- //
956
957 namespace {
958 struct Infeed {
959 Error operator()(const ServiceExecutableRunOptions* run_options,
960 CustomCall::RemainingArgs args, StringRef config) const;
Handlerxla::gpu::__anon8ea4ed0d0d11::Infeed961 static Infeed Handler() { return Infeed(); }
962 };
963 } // namespace
964
operator ()(const ServiceExecutableRunOptions * run_options,CustomCall::RemainingArgs args,StringRef config) const965 Error Infeed::operator()(const ServiceExecutableRunOptions* run_options,
966 CustomCall::RemainingArgs args,
967 StringRef config) const {
968 VLOG(3) << "Infeeding to GPU";
969
970 se::Stream* stream = run_options->stream();
971 ShapeTree<se::ScopedDeviceMemory<uint8_t>> source_buffers =
972 GetOrCreateInfeedManager(stream->parent())->BlockingGetNextDestination();
973
974 // Check that we have correct number of arguments.
975 if (args.size() != source_buffers.leaf_count())
976 return MakeStringError("Incorrect number of arguments");
977
978 size_t index = 0;
979 for (auto& source : source_buffers.leaves()) {
980 // Get the destination buffer.
981 auto dest = args.get<runtime::StridedMemrefView>(index);
982 if (failed(dest))
983 return MakeStringError("Failed to get the destination buffer");
984
985 // Get the source buffer shape.
986 const Shape& source_shape =
987 ShapeUtil::GetSubshape(source_buffers.shape(), source.first);
988
989 // Check that destination shape matches the source shape.
990 Shape dest_shape = ToShape(*dest);
991 if (!ShapeUtil::ReshapeIsBitcast(dest_shape, source_shape)) {
992 return MakeStringError(
993 "The destination shape does not match the source shape");
994 }
995
996 se::DeviceMemoryBase dest_address = GetDeviceAddress(*dest);
997 se::ScopedDeviceMemory<uint8_t>& buffer = source.second;
998 stream->ThenMemcpy(&dest_address, *buffer.ptr(), buffer.ptr()->size());
999
1000 ++index;
1001 }
1002
1003 // TODO(ezhulenev): Make this function async?
1004 Status block_status = stream->BlockHostUntilDone();
1005 if (!block_status.ok()) return AsError(block_status);
1006
1007 VLOG(3) << "Infeeding to GPU complete";
1008
1009 return Error::success();
1010 }
1011
Infeed(runtime::KernelContext * ctx,void ** args,void ** attrs)1012 static bool Infeed(runtime::KernelContext* ctx, void** args, void** attrs) {
1013 static auto* handler = CustomCall::Bind("xla.gpu.infeed")
1014 .UserData<const ServiceExecutableRunOptions*>()
1015 .Arg<CustomCall::RemainingArgs>() // args
1016 .Attr<StringRef>("config")
1017 .To<RuntimeChecks()>(Infeed::Handler())
1018 .release();
1019
1020 return succeeded(Executable::Call(ctx, *handler, args, attrs));
1021 }
1022
1023 // -------------------------------------------------------------------------- //
1024
1025 namespace {
1026 struct Outfeed {
1027 Error operator()(const ServiceExecutableRunOptions* run_options,
1028 CustomCall::RemainingArgs args, StringRef config) const;
Handlerxla::gpu::__anon8ea4ed0d0e11::Outfeed1029 static Outfeed Handler() { return Outfeed(); }
1030 };
1031 } // namespace
1032
operator ()(const ServiceExecutableRunOptions * run_options,CustomCall::RemainingArgs args,StringRef config) const1033 Error Outfeed::operator()(const ServiceExecutableRunOptions* run_options,
1034 CustomCall::RemainingArgs args,
1035 StringRef config) const {
1036 VLOG(3) << "Outfeeding from GPU";
1037
1038 se::Stream* stream = run_options->stream();
1039 OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager(stream->parent());
1040 ShapeTree<std::unique_ptr<OutfeedBuffer>>* dest_buffers =
1041 outfeed_manager->BlockingGetNextDestination();
1042
1043 // Nothing to be done for an outfeed with no inputs.
1044 // Note: Must do this after `BlockingGetNextDestination` above to dequeue an
1045 // entry from the outfeed manager.
1046 if (args.empty()) return Error::success();
1047
1048 // Check that we have correct number of arguments.
1049 if (args.size() != dest_buffers->leaf_count())
1050 return MakeStringError("Incorrect number of arguments");
1051
1052 size_t index = 0;
1053 for (auto& dest : dest_buffers->leaves()) {
1054 // Get the source buffer.
1055 auto source = args.get<runtime::StridedMemrefView>(index);
1056 if (failed(source))
1057 return MakeStringError("Failed to get the source buffer");
1058
1059 // Get the source buffer shape.
1060 const Shape& dest_shape =
1061 ShapeUtil::GetSubshape(dest_buffers->shape(), dest.first);
1062
1063 // Check that destination shape matches the source shape.
1064 Shape source_shape = ToShape(*source);
1065 if (!ShapeUtil::ReshapeIsBitcast(dest_shape, source_shape)) {
1066 return MakeStringError(
1067 "The destination shape does not match the source shape");
1068 }
1069
1070 se::DeviceMemoryBase source_address = GetDeviceAddress(*source);
1071 std::unique_ptr<OutfeedBuffer>& buffer = dest.second;
1072
1073 // Schedule the memory transfer.
1074 auto* dest_address = buffer->destination()->untyped_data();
1075 stream->ThenMemcpy(dest_address, source_address, buffer->length())
1076 .ThenDoHostCallback([&buffer]() { buffer->Done(); });
1077
1078 ++index;
1079 }
1080
1081 Status block_status = stream->BlockHostUntilDone();
1082 if (!block_status.ok()) return AsError(block_status);
1083
1084 VLOG(3) << "Outfeeding from GPU complete";
1085
1086 return Error::success();
1087 }
1088
Outfeed(runtime::KernelContext * ctx,void ** args,void ** attrs)1089 static bool Outfeed(runtime::KernelContext* ctx, void** args, void** attrs) {
1090 static auto* handler = CustomCall::Bind("xla.gpu.outfeed")
1091 .UserData<const ServiceExecutableRunOptions*>()
1092 .Arg<CustomCall::RemainingArgs>() // args
1093 .Attr<StringRef>("config")
1094 .To<RuntimeChecks()>(Outfeed::Handler())
1095 .release();
1096
1097 return succeeded(Executable::Call(ctx, *handler, args, attrs));
1098 }
1099
1100 // -------------------------------------------------------------------------- //
1101
1102 namespace {
1103
1104 enum class MemcpyDirection { kDeviceToDevice, kDeviceToHost, kHostToDevice };
1105
1106 template <MemcpyDirection direction>
1107 struct Memcpy {
1108 Error operator()(const ServiceExecutableRunOptions* run_options,
1109 runtime::FlatMemrefView dst,
1110 runtime::FlatMemrefView src) const;
Handlerxla::gpu::__anon8ea4ed0d1011::Memcpy1111 static Memcpy Handler() { return Memcpy(); }
1112 };
1113 } // namespace
1114
1115 template <MemcpyDirection direction>
operator ()(const ServiceExecutableRunOptions * run_options,runtime::FlatMemrefView dst,runtime::FlatMemrefView src) const1116 Error Memcpy<direction>::operator()(
1117 const ServiceExecutableRunOptions* run_options, runtime::FlatMemrefView dst,
1118 runtime::FlatMemrefView src) const {
1119 se::Stream* stream = run_options->stream();
1120
1121 if (dst.size_in_bytes != src.size_in_bytes) {
1122 return MakeStringError(
1123 "Source memref size does not match destination memref size");
1124 }
1125
1126 switch (direction) {
1127 case MemcpyDirection::kDeviceToDevice: {
1128 se::DeviceMemoryBase dst_data = GetDeviceAddress(dst);
1129 se::DeviceMemoryBase src_data = GetDeviceAddress(src);
1130 stream->ThenMemcpy(&dst_data, src_data, src.size_in_bytes);
1131 } break;
1132 case MemcpyDirection::kDeviceToHost: {
1133 se::DeviceMemoryBase src_data = GetDeviceAddress(src);
1134 stream->ThenMemcpy(dst.data, src_data, src.size_in_bytes);
1135 } break;
1136 case MemcpyDirection::kHostToDevice: {
1137 se::DeviceMemoryBase dst_data = GetDeviceAddress(dst);
1138 stream->ThenMemcpy(&dst_data, src.data, src.size_in_bytes);
1139 } break;
1140 }
1141
1142 // TODO(ezhulenev): H2D and D2H memcpy instead of blocking the execution
1143 // thread should return an async token that will become available when
1144 // transfer is completed.
1145 if (direction != MemcpyDirection::kDeviceToDevice) {
1146 auto st = stream->BlockHostUntilDone();
1147 if (!st.ok()) return AsError(st);
1148 }
1149
1150 return Error::success();
1151 }
1152
1153 template <MemcpyDirection direction>
MemcpyFn(runtime::KernelContext * ctx,void ** args,void ** attrs)1154 static bool MemcpyFn(runtime::KernelContext* ctx, void** args, void** attrs) {
1155 static auto* handler = CustomCall::Bind("xla.gpu.memcpy")
1156 .UserData<const ServiceExecutableRunOptions*>()
1157 .Arg<runtime::FlatMemrefView>() // dst
1158 .Arg<runtime::FlatMemrefView>() // src
1159 .To<RuntimeChecks()>(Memcpy<direction>::Handler())
1160 .release();
1161
1162 return succeeded(Executable::Call(ctx, *handler, args, attrs));
1163 }
1164
1165 // -------------------------------------------------------------------------- //
1166
1167 namespace {
1168
1169 struct Memset {
1170 Error operator()(const ServiceExecutableRunOptions* run_options,
1171 runtime::FlatMemrefView dst,
1172 CustomCall::VariantArg constant) const;
Handlerxla::gpu::__anon8ea4ed0d1111::Memset1173 static Memset Handler() { return Memset(); }
1174 };
1175
1176 } // namespace
1177
operator ()(const ServiceExecutableRunOptions * run_options,runtime::FlatMemrefView dst,CustomCall::VariantArg constant) const1178 Error Memset::operator()(const ServiceExecutableRunOptions* run_options,
1179 runtime::FlatMemrefView dst,
1180 CustomCall::VariantArg constant) const {
1181 se::Stream* stream = run_options->stream();
1182 se::DeviceMemoryBase dst_data = GetDeviceAddress(dst);
1183
1184 // If the constant is zero we can use memzero directly.
1185 bool set_zero = false;
1186
1187 // Check all supported data types to see if we have a zero value.
1188 if (auto i1 = constant.get<bool>(); succeeded(i1) && *i1 == false)
1189 set_zero = true;
1190 else if (auto i32 = constant.get<int32_t>(); succeeded(i32) && *i32 == 0)
1191 set_zero = true;
1192 else if (auto f16 = constant.get<half>(); succeeded(f16) && *f16 == half(0.0))
1193 set_zero = true;
1194 else if (auto f32 = constant.get<float>(); succeeded(f32) && *f32 == 0.0)
1195 set_zero = true;
1196
1197 if (set_zero) {
1198 stream->ThenMemZero(&dst_data, dst.size_in_bytes);
1199 return Error::success();
1200 }
1201
1202 // If the constant is not zero, use the given pattern to `memset`.
1203 // TODO(ezhulenev): Support 16 and 8 bit patterns.
1204 uint32_t pattern;
1205 if (auto i32 = constant.get<int32_t>(); succeeded(i32))
1206 pattern = *i32;
1207 else if (auto f32 = constant.get<float>(); succeeded(f32))
1208 pattern = reinterpret_cast<uint32_t&>(*f32);
1209 else
1210 return MakeStringError("Unsupported memset bit pattern type");
1211
1212 if (dst.size_in_bytes % 4 != 0)
1213 return MakeStringError("Memref size is not divisible by 4");
1214
1215 stream->ThenMemset32(&dst_data, pattern, dst.size_in_bytes);
1216
1217 return Error::success();
1218 }
1219
MemsetFn(runtime::KernelContext * ctx,void ** args,void ** attrs)1220 static bool MemsetFn(runtime::KernelContext* ctx, void** args, void** attrs) {
1221 static auto* handler = CustomCall::Bind("xla.gpu.memset")
1222 .UserData<const ServiceExecutableRunOptions*>()
1223 .Arg<runtime::FlatMemrefView>() // dst
1224 .Arg<CustomCall::VariantArg>() // constant
1225 .To<RuntimeChecks()>(Memset::Handler())
1226 .release();
1227
1228 return succeeded(Executable::Call(ctx, *handler, args, attrs));
1229 }
1230
1231 // -------------------------------------------------------------------------- //
1232
1233 namespace {
1234 struct Fft {
1235 LLVM_ATTRIBUTE_ALWAYS_INLINE
1236 Error operator()(const ServiceExecutableRunOptions* run_options,
1237 runtime::StridedMemrefView input,
1238 runtime::StridedMemrefView output,
1239 ArrayRef<int64_t> fft_length, se::fft::Type fft_type) const;
Handlerxla::gpu::__anon8ea4ed0d1211::Fft1240 static Fft Handler() { return Fft(); }
1241 };
1242 } // namespace
1243
operator ()(const ServiceExecutableRunOptions * run_options,runtime::StridedMemrefView input,runtime::StridedMemrefView output,ArrayRef<int64_t> fft_length,se::fft::Type fft_type) const1244 Error Fft::operator()(const ServiceExecutableRunOptions* run_options,
1245 runtime::StridedMemrefView input,
1246 runtime::StridedMemrefView output,
1247 ArrayRef<int64_t> fft_length,
1248 se::fft::Type fft_type) const {
1249 // TODO(ezhulenev): Cache FFT plans in the GpuExecutable.
1250 FftPlanCache fft_plan_cache;
1251
1252 se::Stream* stream = run_options->stream();
1253 se::StreamExecutor* executor = stream->parent();
1254
1255 if (input.dtype == tfrt::DType::F64 ||
1256 input.dtype == tfrt::DType::Complex128) {
1257 // Adjust FFT type to reflect double precision.
1258 switch (fft_type) {
1259 case se::fft::Type::kC2CForward:
1260 fft_type = se::fft::Type::kZ2ZForward;
1261 break;
1262 case se::fft::Type::kC2CInverse:
1263 fft_type = se::fft::Type::kZ2ZInverse;
1264 break;
1265 case se::fft::Type::kR2C:
1266 fft_type = se::fft::Type::kD2Z;
1267 break;
1268 case se::fft::Type::kC2R:
1269 fft_type = se::fft::Type::kZ2D;
1270 break;
1271 default:
1272 return MakeStringError("Unsupported FFT type");
1273 }
1274 }
1275
1276 auto st =
1277 RunFft(GetDeviceAddress(input), ToShape(input), GetDeviceAddress(output),
1278 ToShape(output), fft_type, fft_length, executor->device_ordinal(),
1279 &fft_plan_cache, stream, run_options->allocator());
1280 if (!st.ok()) return AsError(st);
1281
1282 return Error::success();
1283 }
1284
Fft(runtime::KernelContext * ctx,void ** args,void ** attrs)1285 static bool Fft(runtime::KernelContext* ctx, void** args, void** attrs) {
1286 static auto* handler = CustomCall::Bind("xla.gpu.fft")
1287 .UserData<const ServiceExecutableRunOptions*>()
1288 .Arg<runtime::StridedMemrefView>() // input
1289 .Arg<runtime::StridedMemrefView>() // output
1290 .Attr<ArrayRef<int64_t>>("fft_length")
1291 .Attr<se::fft::Type>("fft_type")
1292 .To<RuntimeChecks()>(Fft::Handler())
1293 .release();
1294 return succeeded(Executable::Call(ctx, *handler, args, attrs));
1295 }
1296
1297 // -------------------------------------------------------------------------- //
1298
namespace {
struct Cholesky {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   const DebugOptions* debug_options,
                   runtime::MemrefView operand, runtime::MemrefView a,
                   runtime::MemrefView workspace, runtime::MemrefView info,
                   int64_t batch_size, bool is_lower, int64_t n) const;
  static Cholesky Handler() { return Cholesky(); }
};
}  // namespace

Error Cholesky::operator()(const ServiceExecutableRunOptions* run_options,
                           const DebugOptions* debug_options,
                           runtime::MemrefView operand, runtime::MemrefView a,
                           runtime::MemrefView workspace,
                           runtime::MemrefView info, int64_t batch_size,
                           bool is_lower, int64_t n) const {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  se::DeviceMemoryBase operand_buffer = GetDeviceAddress(operand);
  se::DeviceMemoryBase a_buffer = GetDeviceAddress(a);
  se::DeviceMemoryBase workspace_buffer = GetDeviceAddress(workspace);
  se::DeviceMemoryBase info_buffer = GetDeviceAddress(info);

  VLOG(3) << "Running Cholesky";
  se::Stream* stream = run_options->stream();

  // Copy the operand into the 'a' buffer if they are different buffers.
  if (a.data != operand.data)
    stream->ThenMemcpy(&a_buffer, operand_buffer, operand_buffer.size());

  using UpperLower = se::blas::UpperLower;
  UpperLower uplo = is_lower ? UpperLower::kLower : UpperLower::kUpper;

  CholeskyParams params{n,        batch_size,       uplo,
                        a_buffer, workspace_buffer, info_buffer};
  auto executed =
      RunCholesky(xla::gpu::PtxOptsFromDebugOptions(*debug_options),
                  TfrtToPrimitiveType(operand.dtype), &params, stream);
  if (!executed.ok()) return AsError(executed);

  return Error::success();
#else   // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  return failure();
#endif
}

static bool Cholesky(runtime::KernelContext* ctx, void** args, void** attrs) {
  static auto* handler = CustomCall::Bind("xla.gpu.cholesky")
                             .UserData<const ServiceExecutableRunOptions*>()
                             .UserData<const DebugOptions*>()
                             .Arg<runtime::MemrefView>()  // operand
                             .Arg<runtime::MemrefView>()  // a
                             .Arg<runtime::MemrefView>()  // workspace
                             .Arg<runtime::MemrefView>()  // info
                             .Attr<int64_t>("batch_size")
                             .Attr<bool>("is_lower")
                             .Attr<int64_t>("n")
                             .To<RuntimeChecks()>(Cholesky::Handler())
                             .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

namespace {

// TODO(ezhulenev): Today XLA represents TriangularSolve as a "classic" XLA
// custom call operation, and we provide a thin adaptor from the XLA custom
// call to the JitRt custom call. Once we are fully migrated to JitRt
// execution, the XLA compiler should directly emit a properly typed
// TriangularSolve JitRt custom call (no need to pass the config via a
// serialized string).
struct TriangularSolve {
  // Adaptor from the XlaCustomCall API to the properly typed TriangularSolve
  // handler.
  static Error run(const ServiceExecutableRunOptions* run_options,
                   const DebugOptions* debug_options,
                   CustomCall::RemainingArgs args, StringRef backend_config);

  Error operator()(const ServiceExecutableRunOptions* run_options,
                   const DebugOptions* debug_options,
                   runtime::StridedMemrefView a, runtime::StridedMemrefView b,
                   runtime::StridedMemrefView result,
                   runtime::FlatMemrefView temp, bool left_side, bool lower,
                   bool unit_diagonal,
                   TriangularSolveOptions::Transpose transpose_a) const;
  static TriangularSolve Handler() { return TriangularSolve(); }
};

}  // namespace

Error TriangularSolve::run(const ServiceExecutableRunOptions* run_options,
                           const DebugOptions* debug_options,
                           CustomCall::RemainingArgs args,
                           StringRef backend_config) {
  TriangularSolve handler = TriangularSolve::Handler();

  if (args.size() != 4)
    return MakeStringError("Expected 4 arguments, got ", args.size());

  // Check that all arguments have the correct type.
  auto a = args.get<runtime::StridedMemrefView>(0);
  auto b = args.get<runtime::StridedMemrefView>(1);
  auto result = args.get<runtime::StridedMemrefView>(2);
  auto temp = args.get<runtime::FlatMemrefView>(3);
  if (failed(a) || failed(b) || failed(result) || failed(temp))
    return MakeStringError("Incorrect argument types");

  // Parse the backend config string.
  TriangularSolveOptions opts;
  auto st = tensorflow::HumanReadableJsonToProto(backend_config.str(), &opts);
  if (!st.ok()) return AsError(st);

  return handler(run_options, debug_options, *a, *b, *result, *temp,
                 opts.left_side(), opts.lower(), opts.unit_diagonal(),
                 opts.transpose_a());
}
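
// For reference, the backend config parsed above is the human-readable JSON
// encoding of the TriangularSolveOptions proto (left_side, lower,
// unit_diagonal, transpose_a). A minimal sketch of that round trip, with a
// purely illustrative config string (the exact JSON field spelling accepted
// here follows the proto JSON mapping):
//
//   TriangularSolveOptions opts;
//   auto status = tensorflow::HumanReadableJsonToProto(
//       R"({"left_side": true, "lower": true, "unit_diagonal": false,
//           "transpose_a": "NO_TRANSPOSE"})",
//       &opts);
//   // On success: opts.left_side() == true, opts.transpose_a() ==
//   // TriangularSolveOptions::NO_TRANSPOSE, ...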

Error TriangularSolve::operator()(
    const ServiceExecutableRunOptions* run_options,
    const DebugOptions* debug_options, runtime::StridedMemrefView a,
    runtime::StridedMemrefView b, runtime::StridedMemrefView result,
    runtime::FlatMemrefView temp, bool left_side, bool lower,
    bool unit_diagonal, TriangularSolveOptions::Transpose transpose_a) const {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  se::Stream* stream = run_options->stream();

  se::DeviceMemoryBase a_data = GetDeviceAddress(a);
  se::DeviceMemoryBase b_data = GetDeviceAddress(b);
  se::DeviceMemoryBase result_data = GetDeviceAddress(result);
  se::DeviceMemoryBase temp_data = GetDeviceAddress(temp);

  // Triangular solve is in-place on 'b', so copy 'b' to the output if they
  // aren't the same buffer.
  if (b.data != result.data)
    stream->ThenMemcpy(&result_data, b_data, b_data.size());

  Shape b_shape = ToShape(b);
  int64_t m = b_shape.dimensions(b_shape.rank() - 2);
  int64_t n = b_shape.dimensions(b_shape.rank() - 1);
  int64_t batch_size = std::accumulate(
      b_shape.dimensions().begin(), b_shape.dimensions().end() - 2, int64_t{1},
      [](int64_t a, int64_t b) { return a * b; });

  PrimitiveType elem_type = TfrtToPrimitiveType(b.dtype);
  int64_t elem_size = ShapeUtil::ByteSizeOfPrimitiveType(elem_type);
  int64_t a_batch_stride = left_side ? m * m * elem_size : n * n * elem_size;
  int64_t b_batch_stride = m * n * elem_size;

  using Side = se::blas::Side;
  using Diagonal = se::blas::Diagonal;
  using Transpose = se::blas::Transpose;
  using UpperLower = se::blas::UpperLower;

  // Convert custom call attributes to se::blas enums.
  UpperLower uplo = lower ? UpperLower::kLower : UpperLower::kUpper;
  Side side = left_side ? Side::kLeft : Side::kRight;
  Diagonal diagonal = unit_diagonal ? Diagonal::kUnit : Diagonal::kNonUnit;

  auto transpose = [&]() -> mlir::FailureOr<Transpose> {
    switch (transpose_a) {
      case TriangularSolveOptions::NO_TRANSPOSE:
        return se::blas::Transpose::kNoTranspose;
      case TriangularSolveOptions::TRANSPOSE:
        return se::blas::Transpose::kTranspose;
      case TriangularSolveOptions::ADJOINT:
        return se::blas::Transpose::kConjugateTranspose;
      default:
        return failure();
    }
  }();

  if (failed(transpose))
    return MakeStringError("Failed to convert transpose type");

  auto st = RunTriangulatSolve(
      a_data, result_data, temp_data, PtxOptsFromDebugOptions(*debug_options),
      uplo, side, diagonal, *transpose, elem_type, batch_size, m, n,
      a_batch_stride, b_batch_stride, stream);
  if (!st.ok()) return AsError(st);

  return Error::success();
#else   // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  return failure();
#endif
}

// -------------------------------------------------------------------------- //
// Implements a JitRt custom call that forwards to the classic XLA custom call
// handler.
//
// Longer term, all XLA custom calls should probably be implemented directly as
// JitRt custom calls. However, for a smooth migration from Thunks to JitRt we
// have to seamlessly support all current XLA users.
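//
// For reference, a "classic" XLA custom call target for API_VERSION_ORIGINAL
// has the signature expected by the dispatch code below. A minimal sketch
// (the function name is hypothetical):
//
//   void MyCustomCall(se::gpu::GpuStreamHandle stream, void** buffers,
//                     const char* opaque, size_t opaque_len) {
//     // buffers[] holds device pointers for the operands followed by the
//     // results; opaque carries the backend config string.
//   }
//   XLA_REGISTER_CUSTOM_CALL_TARGET(MyCustomCall, "CUDA");
//
// Status-returning targets (API_VERSION_STATUS_RETURNING) additionally take
// an XlaCustomCallStatus* out-parameter, as handled below.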
namespace {
struct XlaCustomCall {
  using Stream = se::gpu::GpuStreamHandle;

  Error operator()(const ServiceExecutableRunOptions* run_options,
                   const DebugOptions* debug_options,
                   CustomCall::RemainingArgs args, StringRef call_target_name,
                   int32_t api_version, StringRef backend_config) const;
  static XlaCustomCall Handler() { return XlaCustomCall(); }
};
}  // namespace

Error XlaCustomCall::operator()(const ServiceExecutableRunOptions* run_options,
                                const DebugOptions* debug_options,
                                CustomCall::RemainingArgs args,
                                StringRef call_target_name, int32_t api_version,
                                StringRef backend_config) const {
  // Pattern match the custom call to a few special cases, otherwise find the
  // custom call handler registered with the runtime.
  if (call_target_name == kTriangularSolveCallTarget)
    return TriangularSolve::run(run_options, debug_options, args,
                                backend_config);

  // Find the Xla custom call handler.
  auto& platform_name = run_options->stream()->parent()->platform()->Name();
  void* call_target = CustomCallTargetRegistry::Global()->Lookup(
      call_target_name.str(), platform_name);
  if (!call_target) {
    return MakeStringError("Cannot find the Xla custom call handler ",
                           call_target_name.str());
  }

  // Prepare pointers to buffers to pass to the Xla custom call handler.
  llvm::SmallVector<void*> buffers;
  for (unsigned i = 0; i < args.size(); ++i) {
    // We use zero-sized memrefs to represent holes in custom calls with target
    // arguments mapping (see `CustomCallTargetArgMapping`).
    if (auto memref = args.get<runtime::FlatMemrefView>(i); succeeded(memref)) {
      buffers.push_back(memref->size_in_bytes == 0 ? nullptr : memref->data);
      continue;
    }
    if (auto strided = args.get<runtime::StridedMemrefView>(i);
        succeeded(strided)) {
      int64_t size_in_bytes = GetHostSize(strided->dtype);
      for (int64_t size : strided->sizes) size_in_bytes *= size;
      buffers.push_back(size_in_bytes == 0 ? nullptr : strided->data);
      continue;
    }
    return MakeStringError("Failed to get arguments as (strided) memref view");
  }

  // Original custom call API version that doesn't support returning status.
  if (api_version == CustomCallApiVersion::API_VERSION_ORIGINAL) {
    using XlaCustomCallType = void (*)(Stream, void**, const char*, size_t);
    auto xla_call_target = reinterpret_cast<XlaCustomCallType>(call_target);

    xla_call_target(se::gpu::AsGpuStreamValue(run_options->stream()),
                    buffers.data(), backend_config.data(),
                    backend_config.size());

    return Error::success();
  }

  // Xla custom call API returning status.
  if (api_version == CustomCallApiVersion::API_VERSION_STATUS_RETURNING) {
    using XlaCustomCallType =
        void (*)(Stream, void**, const char*, size_t, XlaCustomCallStatus*);
    auto xla_call_target = reinterpret_cast<XlaCustomCallType>(call_target);

    XlaCustomCallStatus custom_call_status;
    xla_call_target(se::gpu::AsGpuStreamValue(run_options->stream()),
                    buffers.data(), backend_config.data(),
                    backend_config.size(), &custom_call_status);

    if (auto message = CustomCallStatusGetMessage(&custom_call_status)) {
      return MakeStringError(message.value());
    } else {
      return Error::success();
    }
  }

  return MakeStringError("Incorrect custom call API version");
}

static bool CustomCall(runtime::KernelContext* ctx, void** args, void** attrs) {
  static auto* handler = CustomCall::Bind("xla.gpu.custom_call")
                             .UserData<const ServiceExecutableRunOptions*>()
                             .UserData<const DebugOptions*>()
                             .Arg<CustomCall::RemainingArgs>()  // args
                             .Attr<StringRef>("call_target_name")
                             .Attr<int32_t>("api_version")
                             .Attr<StringRef>("backend_config")
                             .To<RuntimeChecks()>(XlaCustomCall::Handler())
                             .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

namespace {
struct AllReduce {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   JitRtCollectiveSupport* collectives,
                   CustomCall::RemainingArgs args, int32_t uid,
                   int64_t group_mode, int64_t op_id, int64_t reduction_kind,
                   ArrayRef<int64_t> replica_group_offsets,
                   ArrayRef<int64_t> replica_group_values) const;
  static AllReduce Handler() { return AllReduce(); }
};
}  // namespace

Error AllReduce::operator()(const ServiceExecutableRunOptions* run_options,
                            JitRtCollectiveSupport* collectives,
                            CustomCall::RemainingArgs args, int32_t uid,
                            int64_t group_mode, int64_t op_id,
                            int64_t reduction_kind,
                            ArrayRef<int64_t> replica_group_offsets,
                            ArrayRef<int64_t> replica_group_values) const {
#if XLA_ENABLE_XCCL
  VLOG(3) << "Running AllReduce";
  se::Stream* stream = run_options->stream();
  NcclExecuteParams params(*run_options, stream);

  auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
                          replica_group_values);
  if (failed(comm)) return MakeStringError("Failed to get NcclComm");

  auto device_buffers = GetDeviceBufferPairs(args);
  if (failed(device_buffers))
    return MakeStringError("Failed to get device buffers");

  auto executed = RunAllReduce(static_cast<ReductionKind>(reduction_kind),
                               *device_buffers, *stream, **comm);
  if (!executed.ok()) return AsError(executed);

  int32_t device_ordinal = stream->parent()->device_ordinal();
  auto st = collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream);
  if (!st.ok()) return AsError(st);

  return Error::success();
#else   // XLA_ENABLE_XCCL
  // NCCL disabled.
  return MakeStringError("NCCL disabled");
#endif  // XLA_ENABLE_XCCL
}

static bool AllReduce(runtime::KernelContext* ctx, void** args, void** attrs) {
  static auto* handler =
      CustomCall::Bind("xla.gpu.all_reduce")
          .UserData<const ServiceExecutableRunOptions*>()
          .UserData<JitRtCollectiveSupport*>()
          .RemainingArgs()              // args
          .Attr<int32_t>("uid")
          .Attr<int64_t>("group_mode")  // CollectiveOpGroupMode
          .Attr<int64_t>("op_id")
          .Attr<int64_t>("reduction_kind")  // ReductionKind
          .Attr<ArrayRef<int64_t>>("replica_group_offsets")
          .Attr<ArrayRef<int64_t>>("replica_group_values")
          .To<RuntimeChecks()>(AllReduce::Handler())
          .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

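// AllReduceStart and AllReduceDone implement the asynchronous variant of
// all-reduce: `start` launches the NCCL all-reduce on the dedicated async
// communication stream and records a completion event keyed by `uid`; `done`
// pops that event and makes the compute stream wait on it before dependent
// work is issued.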
namespace {
struct AllReduceStart {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   JitRtAsyncCollectiveSupport* async_collectives,
                   CustomCall::RemainingArgs args, int64_t group_mode,
                   int64_t op_id, int64_t reduction_kind,
                   ArrayRef<int64_t> replica_group_offsets,
                   ArrayRef<int64_t> replica_group_values, int32_t uid) const;
  static AllReduceStart Handler() { return AllReduceStart(); }
};
}  // namespace

Error AllReduceStart::operator()(const ServiceExecutableRunOptions* run_options,
                                 JitRtAsyncCollectiveSupport* async_collectives,
                                 CustomCall::RemainingArgs args,
                                 int64_t group_mode, int64_t op_id,
                                 int64_t reduction_kind,
                                 ArrayRef<int64_t> replica_group_offsets,
                                 ArrayRef<int64_t> replica_group_values,
                                 int32_t uid) const {
#if XLA_ENABLE_XCCL
  VLOG(3) << "Running AllReduceStart";
  se::Stream* stream = run_options->stream();
  NcclExecuteParams params(*run_options, stream);

  auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
                          replica_group_values);
  if (failed(comm)) return MakeStringError("Failed to get NcclComm");

  auto device_buffers = GetDeviceBufferPairs(args);
  if (failed(device_buffers))
    return MakeStringError("Failed to get device buffers");

  // Wait until compute inputs are ready.
  async_collectives->async_comm_stream()->ThenWaitFor(params.stream);

  auto executed =
      RunAllReduce(static_cast<ReductionKind>(reduction_kind), *device_buffers,
                   *async_collectives->async_comm_stream(), **comm);
  if (!executed.ok()) return AsError(executed);

  // Create an event on the async stream for the completion of the all-reduce.
  se::Event done_event(async_collectives->async_comm_stream()->parent());
  if (!done_event.Init()) return MakeStringError("Failed to create event");
  async_collectives->async_comm_stream()->ThenRecordEvent(&done_event);

  if (failed(async_collectives->PushEvent(
          uid, stream->parent()->device_ordinal(), std::move(done_event))))
    return MakeStringError("Failed to push event to async collectives");

  return Error::success();
#else   // XLA_ENABLE_XCCL
  return MakeStringError("NCCL disabled");
#endif  // XLA_ENABLE_XCCL
}

static bool AllReduceStart(runtime::KernelContext* ctx, void** args,
                           void** attrs) {
  static auto* handler =
      CustomCall::Bind("xla.gpu.all_reduce_start")
          .UserData<const ServiceExecutableRunOptions*>()
          .UserData<JitRtAsyncCollectiveSupport*>()
          .RemainingArgs()              // args
          .Attr<int64_t>("group_mode")  // CollectiveOpGroupMode
          .Attr<int64_t>("op_id")
          .Attr<int64_t>("reduction_kind")  // ReductionKind
          .Attr<ArrayRef<int64_t>>("replica_group_offsets")
          .Attr<ArrayRef<int64_t>>("replica_group_values")
          .Attr<int32_t>("uid")
          .To<RuntimeChecks()>(AllReduceStart::Handler())
          .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

namespace {
struct AllReduceDone {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   JitRtCollectiveSupport* collectives,
                   JitRtAsyncCollectiveSupport* async_collectives,
                   CustomCall::RemainingArgs args, int32_t uid) const;
  static AllReduceDone Handler() { return AllReduceDone(); }
};
}  // namespace

Error AllReduceDone::operator()(const ServiceExecutableRunOptions* run_options,
                                JitRtCollectiveSupport* collectives,
                                JitRtAsyncCollectiveSupport* async_collectives,
                                CustomCall::RemainingArgs args,
                                int32_t uid) const {
#if XLA_ENABLE_XCCL
  VLOG(3) << "Running AllReduceDone";
  se::Stream* stream = run_options->stream();

  int32_t device_ordinal = stream->parent()->device_ordinal();
  auto event = async_collectives->PopEvent(uid, device_ordinal);
  if (failed(event)) return MakeStringError("Failed to pop event");

  stream->ThenWaitFor(&*event);

  if (!collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream).ok())
    return MakeStringError("Failed to block host");

  return Error::success();
#else   // XLA_ENABLE_XCCL
  return MakeStringError("NCCL disabled");
#endif  // XLA_ENABLE_XCCL
}

static bool AllReduceDone(runtime::KernelContext* ctx, void** args,
                          void** attrs) {
  static auto* handler = CustomCall::Bind("xla.gpu.all_reduce_done")
                             .UserData<const ServiceExecutableRunOptions*>()
                             .UserData<JitRtCollectiveSupport*>()
                             .UserData<JitRtAsyncCollectiveSupport*>()
                             .RemainingArgs()  // args
                             .Attr<int32_t>("uid")
                             .To<RuntimeChecks()>(AllReduceDone::Handler())
                             .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

namespace {
struct ReduceScatter {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   JitRtCollectiveSupport* collectives,
                   CustomCall::RemainingArgs args, int32_t uid,
                   int64_t group_mode, int64_t op_id, int64_t reduction_kind,
                   ArrayRef<int64_t> replica_group_offsets,
                   ArrayRef<int64_t> replica_group_values) const;
  static ReduceScatter Handler() { return ReduceScatter(); }
};
}  // namespace

Error ReduceScatter::operator()(const ServiceExecutableRunOptions* run_options,
                                JitRtCollectiveSupport* collectives,
                                CustomCall::RemainingArgs args, int32_t uid,
                                int64_t group_mode, int64_t op_id,
                                int64_t reduction_kind,
                                ArrayRef<int64_t> replica_group_offsets,
                                ArrayRef<int64_t> replica_group_values) const {
#if XLA_ENABLE_XCCL
  VLOG(3) << "Running ReduceScatter";
  se::Stream* stream = run_options->stream();
  NcclExecuteParams params(*run_options, stream);

  auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
                          replica_group_values);
  if (failed(comm)) return MakeStringError("Failed to get NcclComm");

  auto device_buffers = GetDeviceBufferPairs(args);
  if (failed(device_buffers))
    return MakeStringError("Failed to get device buffers");

  auto executed = RunReduceScatter(static_cast<ReductionKind>(reduction_kind),
                                   *device_buffers, *stream, **comm);
  if (!executed.ok()) return AsError(executed);

  int32_t device_ordinal = stream->parent()->device_ordinal();
  if (!collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream).ok())
    return MakeStringError("Failed to block host");

  return Error::success();
#else   // XLA_ENABLE_XCCL
  return MakeStringError("NCCL disabled");
#endif  // XLA_ENABLE_XCCL
}

static bool ReduceScatter(runtime::KernelContext* ctx, void** args,
                          void** attrs) {
  static auto* handler =
      CustomCall::Bind("xla.gpu.reduce_scatter")
          .UserData<const ServiceExecutableRunOptions*>()
          .UserData<JitRtCollectiveSupport*>()
          .RemainingArgs()              // args
          .Attr<int32_t>("uid")
          .Attr<int64_t>("group_mode")  // CollectiveOpGroupMode
          .Attr<int64_t>("op_id")
          .Attr<int64_t>("reduction_kind")  // ReductionKind
          .Attr<ArrayRef<int64_t>>("replica_group_offsets")
          .Attr<ArrayRef<int64_t>>("replica_group_values")
          .To<RuntimeChecks()>(ReduceScatter::Handler())
          .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

namespace {
struct AllGather {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   JitRtCollectiveSupport* collectives,
                   CustomCall::RemainingArgs args, int32_t uid,
                   int64_t group_mode, int64_t op_id,
                   ArrayRef<int64_t> replica_group_offsets,
                   ArrayRef<int64_t> replica_group_values) const;
  static AllGather Handler() { return AllGather(); }
};
}  // namespace

Error AllGather::operator()(const ServiceExecutableRunOptions* run_options,
                            JitRtCollectiveSupport* collectives,
                            CustomCall::RemainingArgs args, int32_t uid,
                            int64_t group_mode, int64_t op_id,
                            ArrayRef<int64_t> replica_group_offsets,
                            ArrayRef<int64_t> replica_group_values) const {
#if XLA_ENABLE_XCCL
  VLOG(3) << "Running AllGather";
  se::Stream* stream = run_options->stream();
  NcclExecuteParams params(*run_options, stream);

  auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
                          replica_group_values);
  if (failed(comm)) return MakeStringError("Failed to get NCCL comm");

  auto device_buffers = GetDeviceBufferPairs(args);
  if (failed(device_buffers))
    return MakeStringError("Failed to get device buffers");

  auto st = RunAllGather(*device_buffers, *stream, **comm);
  if (!st.ok()) return AsError(st);

  int32_t device_ordinal = stream->parent()->device_ordinal();
  st = collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream);
  if (!st.ok()) return AsError(st);

  return Error::success();
#else   // XLA_ENABLE_XCCL
  return MakeStringError("NCCL disabled");
#endif  // XLA_ENABLE_XCCL
}

static bool AllGather(runtime::KernelContext* ctx, void** args, void** attrs) {
  static auto* handler =
      CustomCall::Bind("xla.gpu.all_gather")
          .UserData<const ServiceExecutableRunOptions*>()
          .UserData<JitRtCollectiveSupport*>()
          .RemainingArgs()              // args
          .Attr<int32_t>("uid")
          .Attr<int64_t>("group_mode")  // CollectiveOpGroupMode
          .Attr<int64_t>("op_id")
          .Attr<ArrayRef<int64_t>>("replica_group_offsets")
          .Attr<ArrayRef<int64_t>>("replica_group_values")
          .To<RuntimeChecks()>(AllGather::Handler())
          .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

namespace {
struct AllToAll {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   JitRtCollectiveSupport* collectives,
                   CustomCall::RemainingArgs args, int32_t uid,
                   int64_t group_mode, bool has_split_dimension, int64_t op_id,
                   ArrayRef<int64_t> replica_group_offsets,
                   ArrayRef<int64_t> replica_group_values) const;
  static AllToAll Handler() { return AllToAll(); }
};
}  // namespace

Error AllToAll::operator()(const ServiceExecutableRunOptions* run_options,
                           JitRtCollectiveSupport* collectives,
                           CustomCall::RemainingArgs args, int32_t uid,
                           int64_t group_mode, bool has_split_dimension,
                           int64_t op_id,
                           ArrayRef<int64_t> replica_group_offsets,
                           ArrayRef<int64_t> replica_group_values) const {
#if XLA_ENABLE_XCCL
  VLOG(3) << "Running AllToAll";
  se::Stream* stream = run_options->stream();
  NcclExecuteParams params(*run_options, stream);

  auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
                          replica_group_values);
  if (failed(comm)) return MakeStringError("Failed to get NCCL comm");

  auto device_buffers = GetDeviceBufferPairs(args);
  if (failed(device_buffers))
    return MakeStringError("Failed to get device buffers");

  auto st = RunAllToAll(has_split_dimension, *device_buffers, *stream, **comm);
  if (!st.ok()) return AsError(st);

  int32_t device_ordinal = stream->parent()->device_ordinal();
  st = collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream);
  if (!st.ok()) return AsError(st);

  return Error::success();
#else   // XLA_ENABLE_XCCL
  return MakeStringError("NCCL disabled");
#endif  // XLA_ENABLE_XCCL
}

static bool AllToAll(runtime::KernelContext* ctx, void** args, void** attrs) {
  static auto* handler =
      CustomCall::Bind("xla.gpu.all_to_all")
          .UserData<const ServiceExecutableRunOptions*>()
          .UserData<JitRtCollectiveSupport*>()
          .RemainingArgs()              // args
          .Attr<int32_t>("uid")
          .Attr<int64_t>("group_mode")  // CollectiveOpGroupMode
          .Attr<bool>("has_split_dimension")
          .Attr<int64_t>("op_id")
          .Attr<ArrayRef<int64_t>>("replica_group_offsets")
          .Attr<ArrayRef<int64_t>>("replica_group_values")
          .To<RuntimeChecks()>(AllToAll::Handler())
          .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

namespace {
struct CollectivePermute {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   JitRtCollectiveSupport* collectives,
                   CustomCall::RemainingArgs args, int32_t uid,
                   int64_t group_mode, int64_t op_id,
                   ArrayRef<int64_t> replica_group_offsets,
                   ArrayRef<int64_t> replica_group_values,
                   ArrayRef<int64_t> source_peers,
                   ArrayRef<int64_t> target_peers) const;
  static CollectivePermute Handler() { return CollectivePermute(); }
};
}  // namespace

Error CollectivePermute::operator()(
    const ServiceExecutableRunOptions* run_options,
    JitRtCollectiveSupport* collectives, CustomCall::RemainingArgs args,
    int32_t uid, int64_t group_mode, int64_t op_id,
    ArrayRef<int64_t> replica_group_offsets,
    ArrayRef<int64_t> replica_group_values, ArrayRef<int64_t> source_peers,
    ArrayRef<int64_t> target_peers) const {
#if XLA_ENABLE_XCCL
  VLOG(3) << "Running CollectivePermute";
  se::Stream* stream = run_options->stream();
  NcclExecuteParams params(*run_options, stream);

  auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
                          replica_group_values);
  if (failed(comm)) return MakeStringError("Failed to get NcclComm");

  auto device_buffers = GetDeviceBufferPairs(args);
  if (failed(device_buffers))
    return MakeStringError("Failed to get device buffers");
  if (device_buffers->size() != 1) {
    return MakeStringError("Expected device buffer size: 1, got ",
                           device_buffers->size());
  }

  StatusOr<GlobalDeviceId> global_device_id = params.GetGlobalDeviceId();
  if (!global_device_id.ok()) return AsError(global_device_id);

  StatusOr<DeviceAssignment::LogicalID> current_logical_id =
      params.device_assn->LogicalIdForDevice(global_device_id.value());
  if (!current_logical_id.ok()) return AsError(current_logical_id);

  const int64_t current_id = static_cast<CollectiveOpGroupMode>(group_mode) ==
                                     CollectiveOpGroupMode::kCrossReplica
                                 ? current_logical_id.value().replica_id
                                 : current_logical_id.value().computation_id;
  std::string device_string = NcclCollectiveThunk::GetDeviceString(params);

  NcclCollectivePermuteConfig::IdToSourceTargetMap id_to_source_target;
  for (int i = 0; i < source_peers.size(); ++i) {
    id_to_source_target.insert({target_peers[i], {}}).first->second.source =
        source_peers[i];
    id_to_source_target.insert({source_peers[i], {}}).first->second.target =
        target_peers[i];
  }
  const NcclCollectivePermuteConfig::SourceTargetMapEntry source_target =
      NcclCollectivePermuteConfig::GetSourceTarget(id_to_source_target,
                                                   current_id);

  auto executed =
      RunCollectivePermute(source_target, (*device_buffers)[0], *stream, **comm,
                           device_string, current_id);
  if (!executed.ok()) return AsError(executed);

  int32_t device_ordinal = stream->parent()->device_ordinal();
  auto st = collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream);
  if (!st.ok()) return AsError(st);

  return Error::success();
#else   // XLA_ENABLE_XCCL
  return MakeStringError("NCCL disabled");
#endif  // XLA_ENABLE_XCCL
}

static bool CollectivePermute(runtime::KernelContext* ctx, void** args,
                              void** attrs) {
  static auto* handler =
      CustomCall::Bind("xla.gpu.collective_permute")
          .UserData<const ServiceExecutableRunOptions*>()
          .UserData<JitRtCollectiveSupport*>()
          .RemainingArgs()              // args
          .Attr<int32_t>("uid")
          .Attr<int64_t>("group_mode")  // CollectiveOpGroupMode
          .Attr<int64_t>("op_id")
          .Attr<ArrayRef<int64_t>>("replica_group_offsets")
          .Attr<ArrayRef<int64_t>>("replica_group_values")
          .Attr<ArrayRef<int64_t>>("source_peers")
          .Attr<ArrayRef<int64_t>>("target_peers")
          .To<RuntimeChecks()>(CollectivePermute::Handler())
          .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

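// ReplicaId and PartitionId look up the logical id of the current device in
// the device assignment and write it into a 4-byte output buffer with a
// 32-bit memset on the stream.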
namespace {
struct ReplicaId {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   runtime::FlatMemrefView result) const;
  static ReplicaId Handler() { return ReplicaId(); }
};
}  // namespace

Error ReplicaId::operator()(const ServiceExecutableRunOptions* run_options,
                            runtime::FlatMemrefView result) const {
  VLOG(3) << "Running ReplicaId";
  se::Stream* stream = run_options->stream();
  NcclExecuteParams params(*run_options, stream);

  StatusOr<GlobalDeviceId> global_device_id = params.GetGlobalDeviceId();
  if (!global_device_id.ok()) return AsError(global_device_id);

  StatusOr<DeviceAssignment::LogicalID> logical_id =
      params.device_assn->LogicalIdForDevice(global_device_id.value());
  if (!logical_id.ok()) return AsError(logical_id);

  se::DeviceMemoryBase result_data = GetDeviceAddress(result);
  params.stream->ThenMemset32(&result_data, logical_id.value().replica_id,
                              /*size=*/4);

  return Error::success();
}

static bool ReplicaId(runtime::KernelContext* ctx, void** args, void** attrs) {
  static auto* handler = CustomCall::Bind("xla.gpu.replica_id")
                             .UserData<const ServiceExecutableRunOptions*>()
                             .Arg<runtime::FlatMemrefView>()  // result
                             .To<RuntimeChecks()>(ReplicaId::Handler())
                             .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

namespace {
struct PartitionId {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   runtime::FlatMemrefView result) const;
  static PartitionId Handler() { return PartitionId(); }
};
}  // namespace

Error PartitionId::operator()(const ServiceExecutableRunOptions* run_options,
                              runtime::FlatMemrefView result) const {
  VLOG(3) << "Running PartitionId";
  se::Stream* stream = run_options->stream();
  NcclExecuteParams params(*run_options, stream);

  StatusOr<GlobalDeviceId> global_device_id = params.GetGlobalDeviceId();
  if (!global_device_id.ok()) return AsError(global_device_id);

  StatusOr<DeviceAssignment::LogicalID> logical_id =
      params.device_assn->LogicalIdForDevice(global_device_id.value());
  if (!logical_id.ok()) return AsError(logical_id);

  se::DeviceMemoryBase result_data = GetDeviceAddress(result);
  params.stream->ThenMemset32(&result_data, logical_id.value().computation_id,
                              /*size=*/4);

  return Error::success();
}

static bool PartitionId(runtime::KernelContext* ctx, void** args,
                        void** attrs) {
  static auto* handler = CustomCall::Bind("xla.gpu.partition_id")
                             .UserData<const ServiceExecutableRunOptions*>()
                             .Arg<runtime::FlatMemrefView>()  // result
                             .To<RuntimeChecks()>(PartitionId::Handler())
                             .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

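// Builds the library of direct custom calls available to JitRt executables:
// each entry maps a custom call target name used by the lowered IR (e.g.
// "xla.gpu.gemm") to the C function that implements it.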
DirectCustomCallLibrary JitRtGpuCustomCalls() {
  DirectCustomCallLibrary lib;

  lib.Insert("xla.gpu.fft", &xla::gpu::Fft);
  lib.Insert("xla.gpu.cholesky", &xla::gpu::Cholesky);
  lib.Insert("xla.gpu.collective_permute", &xla::gpu::CollectivePermute);
  lib.Insert("xla.gpu.func.launch", &xla::gpu::LaunchFunc);
  lib.Insert("xla.gpu.gemm", &xla::gpu::Gemm);
  lib.Insert("xla.gpu.cublas.lt.matmul", &xla::gpu::CublasLtMatmul);
  lib.Insert("xla.gpu.cublas.lt.matmul.bias", &xla::gpu::CublasLtMatmulBias);

  auto conv = [](StringRef name) { return ("xla.gpu.conv." + name).str(); };
  lib.Insert(conv("forward"), &ConvFn<CudnnConvKind::kForward>);
  lib.Insert(conv("backward.input"), &ConvFn<CudnnConvKind::kBackwardInput>);
  lib.Insert(conv("backward.filter"), &ConvFn<CudnnConvKind::kBackwardFilter>);
  lib.Insert(conv("forward.fused"),
             &ConvFusedFn<CudnnConvKind::kForwardActivation>);
  lib.Insert(conv("forward.fused.side_input"),
             &ConvFuseSideInputdFn<CudnnConvKind::kForwardActivation>);

  lib.Insert("xla.gpu.memcpy.d2d", &MemcpyFn<MemcpyDirection::kDeviceToDevice>);
  lib.Insert("xla.gpu.memcpy.h2d", &MemcpyFn<MemcpyDirection::kHostToDevice>);
  lib.Insert("xla.gpu.memcpy.d2h", &MemcpyFn<MemcpyDirection::kDeviceToHost>);
  lib.Insert("xla.gpu.memset", &MemsetFn);
  lib.Insert("xla.gpu.infeed", &xla::gpu::Infeed);
  lib.Insert("xla.gpu.outfeed", &xla::gpu::Outfeed);
  lib.Insert("xla.gpu.custom_call", &xla::gpu::CustomCall);

  // Collective operations.
  lib.Insert("xla.gpu.all_gather", &xla::gpu::AllGather);
  lib.Insert("xla.gpu.all_reduce", &xla::gpu::AllReduce);
  lib.Insert("xla.gpu.all_reduce_done", &xla::gpu::AllReduceDone);
  lib.Insert("xla.gpu.all_reduce_start", &xla::gpu::AllReduceStart);
  lib.Insert("xla.gpu.all_to_all", &xla::gpu::AllToAll);
  lib.Insert("xla.gpu.reduce_scatter", &xla::gpu::ReduceScatter);
  lib.Insert("xla.gpu.partition_id", &xla::gpu::PartitionId);
  lib.Insert("xla.gpu.replica_id", &xla::gpu::ReplicaId);

  return lib;
}

}  // namespace gpu
}  // namespace xla