// Copyright 2022 The TensorFlow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h"

#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <utility>

#include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/xla/runtime/arguments.h"
#include "tensorflow/compiler/xla/runtime/custom_call.h"
#include "tensorflow/compiler/xla/runtime/executable.h"
#include "tensorflow/compiler/xla/runtime/jit_executable.h"
#include "tensorflow/compiler/xla/runtime/type_id.h"
#include "tensorflow/compiler/xla/runtime/types.h"
#include "tensorflow/compiler/xla/service/custom_call_status_internal.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tfrt_utils.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/gpu/gpu_types.h"
#include "tfrt/dtype/dtype.h"  // from @tf_runtime

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                   xla::gpu::JitRtKernelsCache);
TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                   xla::gpu::JitRtGemmConfigCache);
TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                   xla::gpu::JitRtCollectiveSupport);
TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                   xla::gpu::JitRtAsyncCollectiveSupport);
TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                   const xla::ServiceExecutableRunOptions);
TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                   const xla::DebugOptions);

namespace xla {
namespace gpu {

using Eigen::half;

using llvm::ArrayRef;
using llvm::Error;
using llvm::Optional;

using mlir::failure;
using mlir::FailureOr;
using mlir::LogicalResult;
using mlir::StringRef;
using mlir::succeeded;
using mlir::success;

using tfrt::MakeStringError;

using ::xla::runtime::AggregateAttrDef;
using ::xla::runtime::AggregateAttrEncoding;
using ::xla::runtime::CustomCall;
using ::xla::runtime::CustomCallAttrEncodingSet;
using ::xla::runtime::DirectCustomCallLibrary;
using ::xla::runtime::EnumAttrEncoding;
using ::xla::runtime::Executable;
using ::xla::runtime::Tagged;
using ::xla::runtime::TypeIDNameRegistry;

namespace se = ::stream_executor;
namespace lmhlo_gpu = ::mlir::lmhlo_gpu;
namespace mhlo = ::mlir::mhlo;

// Disable all CustomCall checks in optimized build.
static constexpr CustomCall::RuntimeChecks RuntimeChecks() {
#if defined(NDEBUG)
  return CustomCall::RuntimeChecks::kNone;
#else
  return CustomCall::RuntimeChecks::kDefault;
#endif
}
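// With checks disabled (NDEBUG builds) the custom call framework skips
// argument/attribute verification, so the bindings declared below are assumed
// to exactly match the signatures produced by the lowering to JitRt.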

// -------------------------------------------------------------------------- //

// Populate mapping from XLA (SE) enums/structs type id to symbol names.
void PopulateXlaTypeIdNames(TypeIDNameRegistry& registry) {
  registry.Register<Tagged<se::dnn::ActivationMode>>(
      "__type_id_se_dnn_activation");
  registry.Register<Tagged<se::cuda::BlasLt::Epilogue>>(
      "__type_id_se_cublas_lt_epilogue");
  registry.Register<Tagged<se::fft::Type>>("__type_id_se_fft_type");

  registry.Register<Tagged<DotDimensionNumbers>>(
      "__type_id_dot_dimension_numbers");
  registry.Register<Tagged<ConvDimensionNumbers>>(
      "__type_id_conv_dimension_numbers");
  registry.Register<Tagged<ConvBackendConfig>>("__type_id_conv_backend_config");
}

// Add custom call arguments and attributes encoding for custom HLO enums and
// structs, so that we can pass them to custom calls.
void PopulateLmhloToXlaAttrEncoding(CustomCallAttrEncodingSet& encoding) {
  encoding
      .Add<EnumAttrEncoding<lmhlo_gpu::ActivationAttr, lmhlo_gpu::Activation,
                            se::dnn::ActivationMode>>(
          [](lmhlo_gpu::Activation value) -> se::dnn::ActivationMode {
            return ConvertConvActivationMode(value).value();
          });

  encoding.Add<EnumAttrEncoding<lmhlo_gpu::CublasLtMatmulEpilogueAttr,
                                lmhlo_gpu::CublasLtMatmulEpilogue,
                                se::cuda::BlasLt::Epilogue>>(
      [](lmhlo_gpu::CublasLtMatmulEpilogue value)
          -> se::cuda::BlasLt::Epilogue {
        return cublas_lt::AsBlasLtEpilogue(value).value();
      });

  encoding
      .Add<EnumAttrEncoding<mhlo::FftTypeAttr, mhlo::FftType, se::fft::Type>>(
          [](mhlo::FftType value) -> se::fft::Type {
            switch (value) {
              case mhlo::FftType::FFT:
                return se::fft::Type::kC2CForward;
              case mhlo::FftType::IFFT:
                return se::fft::Type::kC2CInverse;
              case mhlo::FftType::RFFT:
                return se::fft::Type::kR2C;
              case mhlo::FftType::IRFFT:
                return se::fft::Type::kC2R;
              default:
                return se::fft::Type::kInvalid;
            }
          });

  using DotDimsAttr = mhlo::DotDimensionNumbersAttr;
  encoding.Add<
      xla::runtime::AggregateAttrEncoding<DotDimsAttr, DotDimensionNumbers>>(
      encoding,
      xla::runtime::AggregateAttrDef<DotDimsAttr>()
          .Add("lhs_batch", &DotDimsAttr::getLhsBatchingDimensions)
          .Add("lhs_contract", &DotDimsAttr::getLhsContractingDimensions)
          .Add("rhs_batch", &DotDimsAttr::getRhsBatchingDimensions)
          .Add("rhs_contract", &DotDimsAttr::getRhsContractingDimensions));

  using ConvDimsAttr = mhlo::ConvDimensionNumbersAttr;
  encoding.Add<
      xla::runtime::AggregateAttrEncoding<ConvDimsAttr, ConvDimensionNumbers>>(
      encoding,
      xla::runtime::AggregateAttrDef<ConvDimsAttr>()
          .Add("input_batch_dim", &ConvDimsAttr::getInputBatchDimension)
          .Add("input_feature_dim", &ConvDimsAttr::getInputFeatureDimension)
          .Add("input_spatial_dims", &ConvDimsAttr::getInputSpatialDimensions)
          .Add("kernel_in_feature_dim",
               &ConvDimsAttr::getKernelInputFeatureDimension)
          .Add("kernel_out_feature_dim",
               &ConvDimsAttr::getKernelOutputFeatureDimension)
          .Add("kernel_spatial_dims", &ConvDimsAttr::getKernelSpatialDimensions)
          .Add("output_batch_dim", &ConvDimsAttr::getOutputBatchDimension)
          .Add("output_feature_dim", &ConvDimsAttr::getOutputFeatureDimension)
          .Add("output_spatial_dims",
               &ConvDimsAttr::getOutputSpatialDimensions));

  using ConvConfigAttr = lmhlo_gpu::ConvolutionBackendConfigAttr;
  encoding.Add<
      xla::runtime::AggregateAttrEncoding<ConvConfigAttr, ConvBackendConfig>>(
      encoding,
      xla::runtime::AggregateAttrDef<ConvConfigAttr>()
          .Add("algorithm", &ConvConfigAttr::getAlgorithm)
          .Add("tensor_ops_enabled", &ConvConfigAttr::getTensorOpsEnabled)
          .Add("is_cudnn_frontend", &ConvConfigAttr::getIsCudnnFrontend)
          .Add("knob_ids", &ConvConfigAttr::getKnobIds)
          .Add("knob_values", &ConvConfigAttr::getKnobValues)
          .Add("operand_0_layout", &ConvConfigAttr::getOperand_0Layout)
          .Add("operand_1_layout", &ConvConfigAttr::getOperand_1Layout)
          .Add("result_layout", &ConvConfigAttr::getResultLayout)
          .Add("workspace_size", &ConvConfigAttr::getWorkspaceSize));
}

// -------------------------------------------------------------------------- //

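// Returns a cached kernel for the given (executor, ptx, kernel name) key, or
// nullptr if no kernel has been loaded for it yet. `Set` below is a
// get-or-insert: if another thread already cached a kernel for the same key,
// the existing kernel is returned and the newly created one is discarded.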
se::KernelBase* JitRtKernelsCache::Get(se::StreamExecutor* executor,
                                       const char* data, StringRef name) {
  Key key(executor, data, name);

  absl::MutexLock lock(&mutex_);
  auto it = kernels_cache_.find(key);
  if (it != kernels_cache_.end()) return it->second.get();

  return nullptr;
}

se::KernelBase* JitRtKernelsCache::Set(se::StreamExecutor* executor,
                                       const char* data, StringRef name,
                                       std::unique_ptr<se::KernelBase> kernel) {
  Key key(executor, data, name);

  absl::MutexLock lock(&mutex_);
  auto it = kernels_cache_.find(key);
  if (it != kernels_cache_.end()) return it->second.get();

  auto emplaced = kernels_cache_.try_emplace(key, std::move(kernel));
  return emplaced.first->second.get();
}

template <typename MemrefArg>
static se::DeviceMemoryBase GetDeviceAddress(MemrefArg& memref) {
  uint64_t size = tfrt::GetHostSize(memref.dtype);
  for (auto dim : memref.sizes) size *= dim;
  return se::DeviceMemoryBase(memref.data, size);
}

static se::DeviceMemoryBase GetDeviceAddress(runtime::FlatMemrefView& memref) {
  return se::DeviceMemoryBase(memref.data, memref.size_in_bytes);
}

// -------------------------------------------------------------------------- //

const GemmConfig* JitRtGemmConfigCache::Get(int64_t uid) {
  absl::MutexLock lock(&mutex_);
  auto it = configs_.find(uid);
  if (it != configs_.end()) return &it->second;
  return nullptr;
}

const GemmConfig* JitRtGemmConfigCache::Set(int64_t uid, GemmConfig config) {
  absl::MutexLock lock(&mutex_);
  auto it = configs_.find(uid);
  if (it != configs_.end()) return &it->second;

  auto emplaced = configs_.try_emplace(uid, std::move(config));
  return &emplaced.first->second;
}

// -------------------------------------------------------------------------- //

JitRtAsyncCollectiveSupport::JitRtAsyncCollectiveSupport(
    se::Stream* async_comm_stream)
    : async_comm_stream_(async_comm_stream) {}

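// Blocks the host until the stream is done the first time a collective with
// the given `uid` runs on a device; later executions for the same
// (uid, device_ordinal) pair return immediately without blocking.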
Status JitRtCollectiveSupport::MaybeBlockAfterFirstRun(int32_t uid,
                                                       int32_t device_ordinal,
                                                       se::Stream* stream) {
  bool block = [&] {
    absl::MutexLock lock(&mutex_);
    return executed_.try_emplace(Key(uid, device_ordinal), true).second;
  }();
  return block ? stream->BlockHostUntilDone() : Status::OK();
}

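// PopEvent/PushEvent implement a per-(uid, device_ordinal) hand-off of "done"
// events recorded on the async communication stream: an async collective
// start presumably pushes the event and the matching done operation pops it
// to synchronize with the main stream. Pushing a second event before the
// previous one has been consumed fails.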
FailureOr<se::Event> JitRtAsyncCollectiveSupport::PopEvent(
    int32_t uid, int32_t device_ordinal) {
  const int64_t key = EventKey(uid, device_ordinal);

  absl::MutexLock lock(&mutex_);
  auto it = done_events_.find(key);
  if (it == done_events_.end()) return failure();

  se::Event done_event = std::move(it->second);
  done_events_.erase(it);
  return done_event;
}

LogicalResult JitRtAsyncCollectiveSupport::PushEvent(int32_t uid,
                                                     int32_t device_ordinal,
                                                     se::Event done_event) {
  const int64_t key = EventKey(uid, device_ordinal);

  absl::MutexLock lock(&mutex_);
  auto result = done_events_.try_emplace(key, std::move(done_event));
  if (!result.second) return failure();  // done event has not been consumed

  return success();
}

// -------------------------------------------------------------------------- //

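// Converts a strided memref into an XLA shape, recovering the layout from the
// strides. For example, sizes = [2, 3] with row-major strides [3, 1] yields
// minor_to_major = [1, 0], while column-major strides [1, 2] yield
// minor_to_major = [0, 1].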
static Shape ToShape(const runtime::StridedMemrefView& memref) {
  PrimitiveType type = TfrtToPrimitiveType(memref.dtype);

  // Recover `minor_to_major` dimensions permutation from strides.
  auto indexed_strides_range =
      llvm::map_range(llvm::enumerate(memref.strides), [](auto pair) {
        return std::pair<int64_t, size_t>{pair.value(), pair.index()};
      });

  auto indexed_strides = llvm::to_vector(indexed_strides_range);
  llvm::stable_sort(indexed_strides);

  llvm::SmallVector<int64_t> minor_to_major;
  minor_to_major.reserve(indexed_strides.size());
  for (auto& pair : indexed_strides) minor_to_major.push_back(pair.second);

  return ShapeUtil::MakeShapeWithLayout(type, memref.sizes, minor_to_major);
}

static StatusOr<GemmConfig> GetGemmConfig(const runtime::StridedMemrefView& lhs,
                                          const runtime::StridedMemrefView& rhs,
                                          const runtime::StridedMemrefView& out,
                                          int64_t algorithm, double alpha_real,
                                          double alpha_imag, double beta,
                                          ArrayRef<int64_t> lhs_batch,
                                          ArrayRef<int64_t> lhs_contract,
                                          ArrayRef<int64_t> rhs_batch,
                                          ArrayRef<int64_t> rhs_contract) {
  return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs),
                         rhs_batch, rhs_contract, ToShape(out), alpha_real,
                         alpha_imag, beta, algorithm,
                         se::blas::kDefaultComputePrecision);
}

// -------------------------------------------------------------------------- //

#if XLA_ENABLE_XCCL
FailureOr<NcclComm::Lock> GetNcclComm(const NcclExecuteParams& params,
                                      int64_t group_mode, int64_t op_id,
                                      ArrayRef<int64_t> replica_group_offsets,
                                      ArrayRef<int64_t> replica_group_values) {
  // TODO(b/233930690): Pass the attribute below as a nested array.
  // Pass an array of arrays using two vectors; one specifying all the values
  // and another specifying the (ending) offsets of each array in the other
  // vector. Example: [ [10, 20, 30, 40], [50, 60], [70, 80, 90] ] turns into
  // offsets=[4, 6, 9] values=[10, 20, 30, 40, 50, 60, 70, 80, 90].
  std::vector<ReplicaGroup> replica_groups;
  int i = 0;
  for (int64_t replica_group_end : replica_group_offsets) {
    ReplicaGroup replica_group;
    while (i < replica_group_end)
      replica_group.add_replica_ids(replica_group_values[i++]);
    replica_groups.push_back(replica_group);
  }

  auto comm =
      LockNcclComm(params, replica_groups,
                   static_cast<CollectiveOpGroupMode>(group_mode), op_id);
  if (comm.ok()) return std::move(comm.value());
  return failure();
}
#endif  // XLA_ENABLE_XCCL

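// Converts the remaining custom call arguments into device buffer pairs. The
// arguments are expected to be laid out as N source memrefs followed by the
// N matching destination memrefs; any argument that is not a strided memref
// makes the conversion fail.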
FailureOr<std::vector<DeviceBufferPair>> GetDeviceBufferPairs(
    CustomCall::RemainingArgs& args) {
  // Add MemRef arguments as buffer arguments.
  const int buffer_pairs = args.size() / 2;
  std::vector<DeviceBufferPair> device_buffers;
  device_buffers.reserve(buffer_pairs);
  for (int i = 0; i < buffer_pairs; ++i) {
    auto source = args.get<runtime::StridedMemrefView>(i);
    auto destination = args.get<runtime::StridedMemrefView>(i + buffer_pairs);
    if (failed(source) || failed(destination)) {
      // Unsupported argument type.
      return failure();
    }

    int element_count = 1;
    for (int size : source->sizes) element_count *= size;
    device_buffers.emplace_back(DeviceBufferPair{
        TfrtToPrimitiveType(source->dtype), element_count,
        GetDeviceAddress(*source), GetDeviceAddress(*destination)});
  }
  return device_buffers;
}

// -------------------------------------------------------------------------- //

Error AsError(Status s) { return MakeStringError(s.error_message()); }

template <typename T>
Error AsError(StatusOr<T>& s) {
  assert(!s.ok());
  return AsError(s.status());
}

// -------------------------------------------------------------------------- //

namespace {
struct LaunchFunc {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   JitRtKernelsCache* kernels_cache, int32_t grid_size_x,
                   int32_t grid_size_y, int32_t grid_size_z,
                   int32_t block_size_x, int32_t block_size_y,
                   int32_t block_size_z, CustomCall::RemainingArgs args,
                   StringRef ptx, StringRef name) const;

  static LaunchFunc Handler() { return LaunchFunc(); }
};
}  // namespace

Error LaunchFunc::operator()(const ServiceExecutableRunOptions* run_options,
                             JitRtKernelsCache* kernels_cache,
                             int32_t grid_size_x, int32_t grid_size_y,
                             int32_t grid_size_z, int32_t block_size_x,
                             int32_t block_size_y, int32_t block_size_z,
                             CustomCall::RemainingArgs args, StringRef ptx,
                             StringRef name) const {
  se::Stream* stream = run_options->stream();
  se::StreamExecutor* executor = stream->parent();

  LaunchDimensions launch_dimensions(
      {grid_size_x, grid_size_y, grid_size_z},
      {block_size_x, block_size_y, block_size_z});

  se::KernelBase* kernel = kernels_cache->Get(executor, ptx.data(), name);

  // If the kernel does not exist yet, create it from the PTX.
  if (kernel == nullptr) {
    auto created = CreateKernel(absl::string_view(name.data(), name.size()),
                                args.size(), ptx.data(), {}, executor);
    if (!created.ok()) return AsError(created);

    kernel =
        kernels_cache->Set(executor, ptx.data(), name, std::move(*created));
  }

  VLOG(3) << "Launching " << kernel->name();
  absl::InlinedVector<se::DeviceMemoryBase, 4> buffer_args;
  buffer_args.reserve(args.size());

  // Add MemRef arguments as buffer arguments.
  for (unsigned i = 0; i < args.size(); ++i) {
    // Simple row major memref passed as shapeless buffer.
    auto memref = args.get<runtime::FlatMemrefView>(i);
    if (succeeded(memref)) {
      buffer_args.emplace_back(GetDeviceAddress(*memref));
      continue;
    }

    // Memref layout must be encoded in the compiled device kernel, so we don't
    // have to pass strides or minor to major dimensions order to the kernel.
    auto strided = args.get<runtime::StridedMemrefView>(i);
    if (succeeded(strided)) {
      buffer_args.emplace_back(GetDeviceAddress(*strided));
      continue;
    }

    return MakeStringError("Unsupported argument type");
  }

  // Execute device kernel on a main stream.
  auto executed =
      ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, stream);
  if (!executed.ok()) return AsError(executed);

  return Error::success();
}

static bool LaunchFunc(runtime::KernelContext* ctx, void** args, void** attrs) {
  static auto* handler = CustomCall::Bind("xla.gpu.func.launch")
                             .UserData<const ServiceExecutableRunOptions*>()
                             .UserData<JitRtKernelsCache*>()
                             .Arg<int32_t>()   // grid_size_x
                             .Arg<int32_t>()   // grid_size_y
                             .Arg<int32_t>()   // grid_size_z
                             .Arg<int32_t>()   // block_size_x
                             .Arg<int32_t>()   // block_size_y
                             .Arg<int32_t>()   // block_size_z
                             .RemainingArgs()  // args
                             .Attr<StringRef>("ptx")
                             .Attr<StringRef>("kernel")
                             .To<RuntimeChecks()>(LaunchFunc::Handler())
                             .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

namespace {
struct Gemm {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   const DebugOptions* debug_options,
                   JitRtGemmConfigCache* configs,
                   runtime::StridedMemrefView lhs,
                   runtime::StridedMemrefView rhs,
                   runtime::StridedMemrefView out, int64_t algorithm,
                   double alpha_real, double alpha_imag, double beta,
                   DotDimensionNumbers dot_dims, int64_t uid) const;

  static Gemm Handler() { return Gemm(); }
};
}  // namespace

Error Gemm::operator()(const ServiceExecutableRunOptions* run_options,
                       const DebugOptions* debug_options,
                       JitRtGemmConfigCache* configs,
                       runtime::StridedMemrefView lhs,
                       runtime::StridedMemrefView rhs,
                       runtime::StridedMemrefView out, int64_t algorithm,
                       double alpha_real, double alpha_imag, double beta,
                       DotDimensionNumbers dot_dims, int64_t uid) const {
  se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs);
  se::DeviceMemoryBase rhs_data = GetDeviceAddress(rhs);
  se::DeviceMemoryBase output_data = GetDeviceAddress(out);

  VLOG(3) << "Running GEMM";
  se::Stream* stream = run_options->stream();

  // Find the gemm config for this instance of operation based on uid.
  const GemmConfig* config = configs->Get(uid);
  if (config == nullptr) {
    auto cfg = GetGemmConfig(lhs, rhs, out, algorithm, alpha_real, alpha_imag,
                             beta, dot_dims.lhs_batch, dot_dims.lhs_contract,
                             dot_dims.rhs_batch, dot_dims.rhs_contract);
    if (!cfg.ok()) return AsError(cfg);
    config = configs->Set(uid, std::move(*cfg));
  }

  Status executed = [&]() -> Status {
    return RunGemm(*config, lhs_data, rhs_data, output_data, stream);
  }();

  if (!executed.ok()) return AsError(executed);

  return Error::success();
}

static bool Gemm(runtime::KernelContext* ctx, void** args, void** attrs) {
  static auto* handler = CustomCall::Bind("xla.gpu.gemm")
                             .UserData<const ServiceExecutableRunOptions*>()
                             .UserData<const DebugOptions*>()
                             .UserData<JitRtGemmConfigCache*>()
                             .Arg<runtime::StridedMemrefView>()  // lhs
                             .Arg<runtime::StridedMemrefView>()  // rhs
                             .Arg<runtime::StridedMemrefView>()  // out
                             .Attr<int64_t>("algorithm")
                             .Attr<double>("alpha_real")
                             .Attr<double>("alpha_imag")
                             .Attr<double>("beta")
                             .Attr<DotDimensionNumbers>("dot_dims")
                             .Attr<int64_t>("uid")
                             .To<RuntimeChecks()>(Gemm::Handler())
                             .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

// TODO(ezhulenev): Cache matmul plans similar to GemmConfig for Gemm.

namespace {
struct CublasLtMatmul {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   const DebugOptions* debug_options,
                   runtime::StridedMemrefView a, runtime::StridedMemrefView b,
                   runtime::StridedMemrefView c, runtime::StridedMemrefView d,
                   Optional<runtime::StridedMemrefView> bias, int64_t algorithm,
                   double alpha_real, double alpha_imag, double beta,
                   DotDimensionNumbers dot_dims,
                   se::cuda::BlasLt::Epilogue epilogue,
                   ArrayRef<int32_t> precision, int64_t uid) const;

  static CublasLtMatmul Handler() { return CublasLtMatmul(); }
};
}  // namespace

Error CublasLtMatmul::operator()(
    const ServiceExecutableRunOptions* run_options,
    const DebugOptions* debug_options, runtime::StridedMemrefView a,
    runtime::StridedMemrefView b, runtime::StridedMemrefView c,
    runtime::StridedMemrefView d, Optional<runtime::StridedMemrefView> bias,
    int64_t algorithm, double alpha_real, double alpha_imag, double beta,
    DotDimensionNumbers dot_dims, se::cuda::BlasLt::Epilogue epilogue,
    ArrayRef<int32_t> precision, int64_t uid) const {
  VLOG(3) << "Running CublasLtMatmul";
  se::Stream* stream = run_options->stream();

  // Construct a plan from a gemm config and an epilogue.
  auto cfg = GetGemmConfig(a, b, c, algorithm, alpha_real, alpha_imag, beta,
                           dot_dims.lhs_batch, dot_dims.lhs_contract,
                           dot_dims.rhs_batch, dot_dims.rhs_contract);
  if (!cfg.ok()) return AsError(cfg);

  auto plan = cublas_lt::MatmulPlan::From(*cfg, epilogue);
  if (!plan.ok()) return AsError(plan);

  auto algos = plan->GetAlgorithms(stream);
  if (!algos.ok()) return AsError(algos);

  se::DeviceMemoryBase a_data = GetDeviceAddress(a);
  se::DeviceMemoryBase b_data = GetDeviceAddress(b);
  se::DeviceMemoryBase c_data = GetDeviceAddress(c);
  se::DeviceMemoryBase d_data = GetDeviceAddress(d);
  se::DeviceMemoryBase bias_data;
  if (bias.has_value()) bias_data = GetDeviceAddress(*bias);

  se::OwningScratchAllocator<> scratch_allocator(
      stream->parent()->device_ordinal(), stream->parent()->GetAllocator());

  auto st =
      plan->ExecuteOnStream(stream, a_data, b_data, c_data, d_data, bias_data,
                            (*algos)[algorithm], scratch_allocator);
  if (!st.ok()) return AsError(st);

  return Error::success();
}

// Adds custom call bindings for matmul operations.
template <typename... Ts>
static auto BindMatmulAttributes(runtime::CustomCallBinding<Ts...> binding) {
  return std::move(binding)
      .template Attr<int64_t>("algorithm")
      .template Attr<double>("alpha_real")
      .template Attr<double>("alpha_imag")
      .template Attr<double>("beta")
      .template Attr<DotDimensionNumbers>("dot_dims")
      .template Attr<se::cuda::BlasLt::Epilogue>("epilogue")
      .template Attr<ArrayRef<int32_t>>("precision")
      .template Attr<int64_t>("uid");
}

static bool CublasLtMatmul(runtime::KernelContext* ctx, void** args,
                           void** attrs) {
  static auto* handler =
      BindMatmulAttributes(CustomCall::Bind("xla.gpu.cublas.lt.matmul")
                               .UserData<const ServiceExecutableRunOptions*>()
                               .UserData<const DebugOptions*>()
                               .Arg<runtime::StridedMemrefView>()  // a
                               .Arg<runtime::StridedMemrefView>()  // b
                               .Arg<runtime::StridedMemrefView>()  // c
                               .Arg<runtime::StridedMemrefView>()  // d
                               .Value(CustomCall::None)            // bias
                           )
          .To<RuntimeChecks()>(CublasLtMatmul::Handler())
          .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

static bool CublasLtMatmulBias(runtime::KernelContext* ctx, void** args,
                               void** attrs) {
  static auto* handler =
      BindMatmulAttributes(CustomCall::Bind("xla.gpu.cublas.lt.matmul.bias")
                               .UserData<const ServiceExecutableRunOptions*>()
                               .UserData<const DebugOptions*>()
                               .Arg<runtime::StridedMemrefView>()  // a
                               .Arg<runtime::StridedMemrefView>()  // b
                               .Arg<runtime::StridedMemrefView>()  // c
                               .Arg<runtime::StridedMemrefView>()  // d
                               .Arg<runtime::StridedMemrefView>()  // bias
                           )
          .To<RuntimeChecks()>(CublasLtMatmul::Handler())
          .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

// TODO(ezhulenev): We need to find a better way to pass structured attributes
// to JitRt custom calls.

// TODO(ezhulenev): Add caching layer for convolution configs and runners.

namespace {

struct Window {
  ArrayRef<int64_t> window_strides;
  ArrayRef<int64_t> padding;
  ArrayRef<int64_t> lhs_dilation;
  ArrayRef<int64_t> rhs_dilation;
  ArrayRef<int64_t> window_reversal;
};

struct ConvAttrs {
  int64_t feature_group_count;
  double result_scale;
};

struct FusedConvAttrs {
  se::dnn::ActivationMode activation_mode;
};

struct SideInputAttrs {
  double side_input_scale;
};

}  // namespace

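// Assembles a GpuConvDescriptor from the custom call arguments and
// attributes. Operand and result shapes are rebuilt from the memrefs with the
// layouts taken from the backend config, and the window uses the single
// `padding` value per spatial dimension for both padding_low and padding_high
// (i.e. symmetric padding).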
static GpuConvDescriptor GetConvDescriptor(
    CudnnConvKind kind,
    // Arguments
    runtime::StridedMemrefView operand0, runtime::StridedMemrefView operand1,
    runtime::StridedMemrefView output, runtime::FlatMemrefView scratch,
    // Attributes
    ConvDimensionNumbers dims, Window w, ConvBackendConfig b, ConvAttrs attrs,
    // Conv-specific arguments and attributes
    Optional<FusedConvAttrs> fused = llvm::None,
    Optional<SideInputAttrs> side_input = llvm::None) {
  // Build a convolution descriptor from the attributes.
  GpuConvDescriptor descriptor;
  descriptor.kind = kind;

  // Apply backend config layout to the shape.
  auto apply_layout = [](runtime::StridedMemrefView& memref,
                         ArrayRef<int64_t> minor_to_major) {
    Shape shape = ToShape(memref);
    return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
                                          shape.dimensions(), minor_to_major);
  };

  descriptor.operand0_shape = apply_layout(operand0, b.operand_0_layout);
  descriptor.operand1_shape = apply_layout(operand1, b.operand_1_layout);
  descriptor.result_shape = apply_layout(output, b.result_layout);

  // Set up convolution dimensions numbers.
  ConvolutionDimensionNumbers dns;
  dns.set_input_batch_dimension(dims.input_batch_dim);
  dns.set_input_feature_dimension(dims.input_feature_dim);
  dns.set_kernel_input_feature_dimension(dims.kernel_in_feature_dim);
  dns.set_kernel_output_feature_dimension(dims.kernel_out_feature_dim);
  dns.set_output_batch_dimension(dims.output_batch_dim);
  dns.set_output_feature_dimension(dims.output_feature_dim);
  for (int64_t d : dims.input_spatial_dims) dns.add_input_spatial_dimensions(d);
  for (int64_t d : dims.kernel_spatial_dims)
    dns.add_kernel_spatial_dimensions(d);
  for (int64_t d : dims.output_spatial_dims)
    dns.add_output_spatial_dimensions(d);
  descriptor.dnums = std::move(dns);

  // Put together convolution window config.
  for (auto index : llvm::seq<int>(0, w.window_strides.size())) {
    WindowDimension* dim = descriptor.window.add_dimensions();
    // Window size for a convolution is the same as the kernel size.
    // Kernel size of the convolution is operand1_shape. We need to look at
    // the convolution dimension numbers kernel spatial dimensions to get
    // the window size.
    int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
    dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
    dim->set_stride(w.window_strides[index]);
    dim->set_padding_low(w.padding[index]);
    dim->set_padding_high(w.padding[index]);
    dim->set_base_dilation(w.lhs_dilation[index]);
    dim->set_window_dilation(w.rhs_dilation[index]);
    dim->set_window_reversal(w.window_reversal[index]);
  }

  descriptor.scratch_size = scratch.size_in_bytes;
  descriptor.feature_group_count = attrs.feature_group_count;
  descriptor.backend_config.set_conv_result_scale(attrs.result_scale);

  // Set up convolution algorithm.
  auto* algo = descriptor.backend_config.mutable_algorithm();
  algo->set_algo_id(b.algorithm);
  algo->set_math_type(b.tensor_ops_enabled
                          ? se::dnn::AlgorithmProto::TENSOR_OP_MATH
                          : se::dnn::AlgorithmProto::DEFAULT_MATH);
  algo->set_is_cudnn_frontend(b.is_cudnn_frontend);

  if (b.workspace_size >= 0)
    algo->mutable_workspace_size()->set_value(b.workspace_size);

  for (unsigned i = 0; i < b.knob_ids.size(); ++i) {
    algo->mutable_tuning_knobs()->insert({b.knob_ids[i], b.knob_values[i]});
  }

  // Set attributes specific for fused convolutions.
  if (fused.has_value())
    descriptor.backend_config.set_activation_mode(fused->activation_mode);

  // Set attributes specific for convolutions with side input.
  if (side_input.has_value())
    descriptor.backend_config.set_side_input_scale(
        side_input->side_input_scale);

  return descriptor;
}

namespace {
struct Conv {
  LLVM_ATTRIBUTE_ALWAYS_INLINE
  Error operator()(
      const ServiceExecutableRunOptions* run_options,
      const DebugOptions* debug_options, runtime::StridedMemrefView operand0,
      runtime::StridedMemrefView operand1,
      Optional<runtime::FlatMemrefView> bias,
      Optional<runtime::StridedMemrefView> side_input,
      runtime::StridedMemrefView output, runtime::FlatMemrefView scratch,
      ConvDimensionNumbers conv_dims,
      // Window config
      ArrayRef<int64_t> window_strides, ArrayRef<int64_t> padding,
      ArrayRef<int64_t> lhs_dilation, ArrayRef<int64_t> rhs_dilation,
      ArrayRef<int64_t> window_reversal,
      // Backend config attributes
      ConvBackendConfig backend_config,
      // Remaining attributes
      int64_t feature_group_count, double result_scale,
      // Optional attributes for fused convolutions.
      Optional<se::dnn::ActivationMode> activation_mode = llvm::None,
      Optional<double> side_input_scale = llvm::None) const {
    // Build config for optional attributes.
    Optional<FusedConvAttrs> fused_attrs = llvm::None;
    if (activation_mode.has_value()) fused_attrs = {*activation_mode};

    Optional<SideInputAttrs> side_input_attrs = llvm::None;
    if (side_input_scale.has_value()) side_input_attrs = {*side_input_scale};

    // Prepare a descriptor for the XLA convolution.
    GpuConvDescriptor descriptor = GetConvDescriptor(
        kind, operand0, operand1, output, scratch, conv_dims,
        {window_strides, padding, lhs_dilation, rhs_dilation, window_reversal},
        backend_config, {feature_group_count, result_scale}, fused_attrs,
        side_input_attrs);

    // Convert descriptor to the Conv config.
    StatusOr<GpuConvConfig> config = GetGpuConvConfig(descriptor, "");
    if (!config.ok()) return AsError(config);

    // Prepare buffer arguments.
    std::vector<se::DeviceMemoryBase> buffers = {GetDeviceAddress(operand0),
                                                 GetDeviceAddress(operand1)};
    if (bias.has_value()) buffers.push_back(GetDeviceAddress(*bias));
    if (side_input.has_value())
      buffers.push_back(GetDeviceAddress(*side_input));

    se::DeviceMemoryBase result_buffer = GetDeviceAddress(output);
    se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch);

    RunConvOptions opts;

    // Create a runner for the given config.
    MaybeFusedConvRunner runner(*config);
    opts.runner_cache = &runner;

    // Run the convolution.
    auto st = RunGpuConv(*config, buffers, result_buffer, scratch_buffer,
                         run_options->stream(), opts);
    if (!st.ok() || !run_options->stream()->ok()) {
      return AsError(st);
    }

    return Error::success();
  }

  static Conv Handler(CudnnConvKind kind) { return Conv{kind}; }

  CudnnConvKind kind;
};

}  // namespace

// Adds custom call bindings for convolution operations.
template <typename... Ts>
static auto BindConvAttributes(runtime::CustomCallBinding<Ts...> binding) {
  return std::move(binding)
      // Convolution dimensions numbers
      .template Attr<ConvDimensionNumbers>("conv_dims")
      // Window config
      .template Attr<ArrayRef<int64_t>>("window_strides")
      .template Attr<ArrayRef<int64_t>>("padding")
      .template Attr<ArrayRef<int64_t>>("lhs_dilation")
      .template Attr<ArrayRef<int64_t>>("rhs_dilation")
      .template Attr<ArrayRef<int64_t>>("window_reversal")
      // Backend config attributes
      .template Attr<ConvBackendConfig>("backend_config")
      // Remaining attributes.
      .template Attr<int64_t>("feature_group_count")
      .template Attr<double>("result_scale");
}

template <CudnnConvKind kind>
static bool ConvFn(runtime::KernelContext* ctx, void** args, void** attrs) {
  static auto* handler =
      BindConvAttributes(CustomCall::Bind("xla.gpu.conv")
                             .UserData<const ServiceExecutableRunOptions*>()
                             .UserData<const DebugOptions*>()
                             .Arg<runtime::StridedMemrefView>()  // operand0
                             .Arg<runtime::StridedMemrefView>()  // operand1
                             .Value(CustomCall::None)            // bias
                             .Value(CustomCall::None)            // side_input
                             .Arg<runtime::StridedMemrefView>()  // output
                             .Arg<runtime::FlatMemrefView>()     // scratch
                         )
          .To(Conv::Handler(kind))
          .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

template <CudnnConvKind kind>
static bool ConvFusedFn(runtime::KernelContext* ctx, void** args,
                        void** attrs) {
  static auto* handler =
      BindConvAttributes(CustomCall::Bind("xla.gpu.conv.fused")
                             .UserData<const ServiceExecutableRunOptions*>()
                             .UserData<const DebugOptions*>()
                             .Arg<runtime::StridedMemrefView>()  // operand0
                             .Arg<runtime::StridedMemrefView>()  // operand1
                             .Arg<runtime::FlatMemrefView>()     // bias
                             .Value(CustomCall::None)            // side_input
                             .Arg<runtime::StridedMemrefView>()  // output
                             .Arg<runtime::FlatMemrefView>()     // scratch
                         )
          .Attr<se::dnn::ActivationMode>("activation_mode")
          .To(Conv::Handler(kind))
          .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

template <CudnnConvKind kind>
static bool ConvFuseSideInputdFn(runtime::KernelContext* ctx, void** args,
                                 void** attrs) {
  static auto* handler =
      BindConvAttributes(CustomCall::Bind("xla.gpu.conv.fused.side_input")
                             .UserData<const ServiceExecutableRunOptions*>()
                             .UserData<const DebugOptions*>()
                             .Arg<runtime::StridedMemrefView>()  // operand0
                             .Arg<runtime::StridedMemrefView>()  // operand1
                             .Arg<runtime::FlatMemrefView>()     // bias
                             .Arg<runtime::StridedMemrefView>()  // side_input
                             .Arg<runtime::StridedMemrefView>()  // output
                             .Arg<runtime::FlatMemrefView>()     // scratch
                         )
          .Attr<se::dnn::ActivationMode>("activation_mode")
          .Attr<double>("side_input_scale")
          .To(Conv::Handler(kind))
          .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

namespace {
struct Infeed {
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   CustomCall::RemainingArgs args, StringRef config) const;
  static Infeed Handler() { return Infeed(); }
};
}  // namespace

Error Infeed::operator()(const ServiceExecutableRunOptions* run_options,
                         CustomCall::RemainingArgs args,
                         StringRef config) const {
  VLOG(3) << "Infeeding to GPU";

  se::Stream* stream = run_options->stream();
  ShapeTree<se::ScopedDeviceMemory<uint8_t>> source_buffers =
      GetOrCreateInfeedManager(stream->parent())->BlockingGetNextDestination();

  // Check that we have the correct number of arguments.
  if (args.size() != source_buffers.leaf_count())
    return MakeStringError("Incorrect number of arguments");

  size_t index = 0;
  for (auto& source : source_buffers.leaves()) {
    // Get the destination buffer.
    auto dest = args.get<runtime::StridedMemrefView>(index);
    if (failed(dest))
      return MakeStringError("Failed to get the destination buffer");

    // Get the source buffer shape.
    const Shape& source_shape =
        ShapeUtil::GetSubshape(source_buffers.shape(), source.first);

    // Check that destination shape matches the source shape.
    Shape dest_shape = ToShape(*dest);
    if (!ShapeUtil::ReshapeIsBitcast(dest_shape, source_shape)) {
      return MakeStringError(
          "The destination shape does not match the source shape");
    }

    se::DeviceMemoryBase dest_address = GetDeviceAddress(*dest);
    se::ScopedDeviceMemory<uint8_t>& buffer = source.second;
    stream->ThenMemcpy(&dest_address, *buffer.ptr(), buffer.ptr()->size());

    ++index;
  }

  // TODO(ezhulenev): Make this function async?
  Status block_status = stream->BlockHostUntilDone();
  if (!block_status.ok()) return AsError(block_status);

  VLOG(3) << "Infeeding to GPU complete";

  return Error::success();
}

static bool Infeed(runtime::KernelContext* ctx, void** args, void** attrs) {
  static auto* handler = CustomCall::Bind("xla.gpu.infeed")
                             .UserData<const ServiceExecutableRunOptions*>()
                             .Arg<CustomCall::RemainingArgs>()  // args
                             .Attr<StringRef>("config")
                             .To<RuntimeChecks()>(Infeed::Handler())
                             .release();

  return succeeded(Executable::Call(ctx, *handler, args, attrs));
}

// -------------------------------------------------------------------------- //

namespace {
struct Outfeed {
  Error operator()(const ServiceExecutableRunOptions* run_options,
                   CustomCall::RemainingArgs args, StringRef config) const;
  static Outfeed Handler() { return Outfeed(); }
};
}  // namespace

Error Outfeed::operator()(const ServiceExecutableRunOptions* run_options,
                          CustomCall::RemainingArgs args,
                          StringRef config) const {
  VLOG(3) << "Outfeeding from GPU";

  se::Stream* stream = run_options->stream();
  OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager(stream->parent());
  ShapeTree<std::unique_ptr<OutfeedBuffer>>* dest_buffers =
      outfeed_manager->BlockingGetNextDestination();

  // Nothing to be done for an outfeed with no inputs.
  // Note: Must do this after `BlockingGetNextDestination` above to dequeue an
  // entry from the outfeed manager.
  if (args.empty()) return Error::success();

  // Check that we have the correct number of arguments.
  if (args.size() != dest_buffers->leaf_count())
    return MakeStringError("Incorrect number of arguments");

  size_t index = 0;
  for (auto& dest : dest_buffers->leaves()) {
    // Get the source buffer.
    auto source = args.get<runtime::StridedMemrefView>(index);
    if (failed(source))
      return MakeStringError("Failed to get the source buffer");

    // Get the destination buffer shape.
1060     const Shape& dest_shape =
1061         ShapeUtil::GetSubshape(dest_buffers->shape(), dest.first);
1062 
1063     // Check that destination shape matches the source shape.
1064     Shape source_shape = ToShape(*source);
1065     if (!ShapeUtil::ReshapeIsBitcast(dest_shape, source_shape)) {
1066       return MakeStringError(
1067           "The destination shape does not match the source shape");
1068     }
1069 
1070     se::DeviceMemoryBase source_address = GetDeviceAddress(*source);
1071     std::unique_ptr<OutfeedBuffer>& buffer = dest.second;
1072 
1073     // Schedule the memory transfer.
1074     auto* dest_address = buffer->destination()->untyped_data();
1075     stream->ThenMemcpy(dest_address, source_address, buffer->length())
1076         .ThenDoHostCallback([&buffer]() { buffer->Done(); });
1077 
1078     ++index;
1079   }
1080 
1081   Status block_status = stream->BlockHostUntilDone();
1082   if (!block_status.ok()) return AsError(block_status);
1083 
1084   VLOG(3) << "Outfeeding from GPU complete";
1085 
1086   return Error::success();
1087 }
1088 
Outfeed(runtime::KernelContext * ctx,void ** args,void ** attrs)1089 static bool Outfeed(runtime::KernelContext* ctx, void** args, void** attrs) {
1090   static auto* handler = CustomCall::Bind("xla.gpu.outfeed")
1091                              .UserData<const ServiceExecutableRunOptions*>()
1092                              .Arg<CustomCall::RemainingArgs>()  // args
1093                              .Attr<StringRef>("config")
1094                              .To<RuntimeChecks()>(Outfeed::Handler())
1095                              .release();
1096 
1097   return succeeded(Executable::Call(ctx, *handler, args, attrs));
1098 }
1099 
1100 // -------------------------------------------------------------------------- //
1101 
1102 namespace {
1103 
1104 enum class MemcpyDirection { kDeviceToDevice, kDeviceToHost, kHostToDevice };
1105 
1106 template <MemcpyDirection direction>
1107 struct Memcpy {
1108   Error operator()(const ServiceExecutableRunOptions* run_options,
1109                    runtime::FlatMemrefView dst,
1110                    runtime::FlatMemrefView src) const;
Handlerxla::gpu::__anon8ea4ed0d1011::Memcpy1111   static Memcpy Handler() { return Memcpy(); }
1112 };
1113 }  // namespace
1114 
1115 template <MemcpyDirection direction>
operator ()(const ServiceExecutableRunOptions * run_options,runtime::FlatMemrefView dst,runtime::FlatMemrefView src) const1116 Error Memcpy<direction>::operator()(
1117     const ServiceExecutableRunOptions* run_options, runtime::FlatMemrefView dst,
1118     runtime::FlatMemrefView src) const {
1119   se::Stream* stream = run_options->stream();
1120 
1121   if (dst.size_in_bytes != src.size_in_bytes) {
1122     return MakeStringError(
1123         "Source memref size does not match destination memref size");
1124   }
1125 
1126   switch (direction) {
1127     case MemcpyDirection::kDeviceToDevice: {
1128       se::DeviceMemoryBase dst_data = GetDeviceAddress(dst);
1129       se::DeviceMemoryBase src_data = GetDeviceAddress(src);
1130       stream->ThenMemcpy(&dst_data, src_data, src.size_in_bytes);
1131     } break;
1132     case MemcpyDirection::kDeviceToHost: {
1133       se::DeviceMemoryBase src_data = GetDeviceAddress(src);
1134       stream->ThenMemcpy(dst.data, src_data, src.size_in_bytes);
1135     } break;
1136     case MemcpyDirection::kHostToDevice: {
1137       se::DeviceMemoryBase dst_data = GetDeviceAddress(dst);
1138       stream->ThenMemcpy(&dst_data, src.data, src.size_in_bytes);
1139     } break;
1140   }
1141 
1142   // TODO(ezhulenev): H2D and D2H memcpy instead of blocking the execution
1143   // thread should return an async token that will become available when
1144   // transfer is completed.
1145   if (direction != MemcpyDirection::kDeviceToDevice) {
1146     auto st = stream->BlockHostUntilDone();
1147     if (!st.ok()) return AsError(st);
1148   }
1149 
1150   return Error::success();
1151 }
1152 
1153 template <MemcpyDirection direction>
MemcpyFn(runtime::KernelContext * ctx,void ** args,void ** attrs)1154 static bool MemcpyFn(runtime::KernelContext* ctx, void** args, void** attrs) {
1155   static auto* handler = CustomCall::Bind("xla.gpu.memcpy")
1156                              .UserData<const ServiceExecutableRunOptions*>()
1157                              .Arg<runtime::FlatMemrefView>()  // dst
1158                              .Arg<runtime::FlatMemrefView>()  // src
1159                              .To<RuntimeChecks()>(Memcpy<direction>::Handler())
1160                              .release();
1161 
1162   return succeeded(Executable::Call(ctx, *handler, args, attrs));
1163 }
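// MemcpyFn is instantiated once per MemcpyDirection, and each instantiation is
// registered as its own custom call target in JitRtGpuCustomCalls() at the
// bottom of this file, e.g.
//
//   lib.Insert("xla.gpu.memcpy.d2h", &MemcpyFn<MemcpyDirection::kDeviceToHost>);
//
// so the copy direction is fixed at compile time, and only the H2D/D2H variants
// block on BlockHostUntilDone above.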
1164 
1165 // -------------------------------------------------------------------------- //
1166 
1167 namespace {
1168 
1169 struct Memset {
1170   Error operator()(const ServiceExecutableRunOptions* run_options,
1171                    runtime::FlatMemrefView dst,
1172                    CustomCall::VariantArg constant) const;
1173   static Memset Handler() { return Memset(); }
1174 };
1175 
1176 }  // namespace
1177 
1178 Error Memset::operator()(const ServiceExecutableRunOptions* run_options,
1179                          runtime::FlatMemrefView dst,
1180                          CustomCall::VariantArg constant) const {
1181   se::Stream* stream = run_options->stream();
1182   se::DeviceMemoryBase dst_data = GetDeviceAddress(dst);
1183 
1184   // If the constant is zero we can use memzero directly.
1185   bool set_zero = false;
1186 
1187   // Check all supported data types to see if we have a zero value.
1188   if (auto i1 = constant.get<bool>(); succeeded(i1) && *i1 == false)
1189     set_zero = true;
1190   else if (auto i32 = constant.get<int32_t>(); succeeded(i32) && *i32 == 0)
1191     set_zero = true;
1192   else if (auto f16 = constant.get<half>(); succeeded(f16) && *f16 == half(0.0))
1193     set_zero = true;
1194   else if (auto f32 = constant.get<float>(); succeeded(f32) && *f32 == 0.0)
1195     set_zero = true;
1196 
1197   if (set_zero) {
1198     stream->ThenMemZero(&dst_data, dst.size_in_bytes);
1199     return Error::success();
1200   }
1201 
1202   // If the constant is not zero, use the given pattern to `memset`.
1203   // TODO(ezhulenev): Support 16 and 8 bit patterns.
1204   uint32_t pattern;
1205   if (auto i32 = constant.get<int32_t>(); succeeded(i32))
1206     pattern = *i32;
1207   else if (auto f32 = constant.get<float>(); succeeded(f32))
1208     pattern = reinterpret_cast<uint32_t&>(*f32);
1209   else
1210     return MakeStringError("Unsupported memset bit pattern type");
1211 
1212   if (dst.size_in_bytes % 4 != 0)
1213     return MakeStringError("Memref size is not divisible by 4");
1214 
1215   stream->ThenMemset32(&dst_data, pattern, dst.size_in_bytes);
1216 
1217   return Error::success();
1218 }
1219 
1220 static bool MemsetFn(runtime::KernelContext* ctx, void** args, void** attrs) {
1221   static auto* handler = CustomCall::Bind("xla.gpu.memset")
1222                              .UserData<const ServiceExecutableRunOptions*>()
1223                              .Arg<runtime::FlatMemrefView>()  // dst
1224                              .Arg<CustomCall::VariantArg>()   // constant
1225                              .To<RuntimeChecks()>(Memset::Handler())
1226                              .release();
1227 
1228   return succeeded(Executable::Call(ctx, *handler, args, attrs));
1229 }
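// Non-zero constants are lowered to a 32-bit fill pattern above: an int32 is
// used as-is and a float is bit-cast to its IEEE-754 encoding before being
// handed to ThenMemset32. For example, a memset with 1.0f fills the buffer
// with the 4-byte pattern 0x3F800000, which is also why the destination size
// must be a multiple of 4 bytes.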
1230 
1231 // -------------------------------------------------------------------------- //
1232 
1233 namespace {
1234 struct Fft {
1235   LLVM_ATTRIBUTE_ALWAYS_INLINE
1236   Error operator()(const ServiceExecutableRunOptions* run_options,
1237                    runtime::StridedMemrefView input,
1238                    runtime::StridedMemrefView output,
1239                    ArrayRef<int64_t> fft_length, se::fft::Type fft_type) const;
1240   static Fft Handler() { return Fft(); }
1241 };
1242 }  // namespace
1243 
1244 Error Fft::operator()(const ServiceExecutableRunOptions* run_options,
1245                       runtime::StridedMemrefView input,
1246                       runtime::StridedMemrefView output,
1247                       ArrayRef<int64_t> fft_length,
1248                       se::fft::Type fft_type) const {
1249   // TODO(ezhulenev): Cache FFT plans in the GpuExecutable.
1250   FftPlanCache fft_plan_cache;
1251 
1252   se::Stream* stream = run_options->stream();
1253   se::StreamExecutor* executor = stream->parent();
1254 
1255   if (input.dtype == tfrt::DType::F64 ||
1256       input.dtype == tfrt::DType::Complex128) {
1257     // Adjust FFT type to reflect double precision.
1258     switch (fft_type) {
1259       case se::fft::Type::kC2CForward:
1260         fft_type = se::fft::Type::kZ2ZForward;
1261         break;
1262       case se::fft::Type::kC2CInverse:
1263         fft_type = se::fft::Type::kZ2ZInverse;
1264         break;
1265       case se::fft::Type::kR2C:
1266         fft_type = se::fft::Type::kD2Z;
1267         break;
1268       case se::fft::Type::kC2R:
1269         fft_type = se::fft::Type::kZ2D;
1270         break;
1271       default:
1272         return MakeStringError("Unsupported FFT type");
1273     }
1274   }
1275 
1276   auto st =
1277       RunFft(GetDeviceAddress(input), ToShape(input), GetDeviceAddress(output),
1278              ToShape(output), fft_type, fft_length, executor->device_ordinal(),
1279              &fft_plan_cache, stream, run_options->allocator());
1280   if (!st.ok()) return AsError(st);
1281 
1282   return Error::success();
1283 }
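// The switch above only widens the FFT kind when the input is double
// precision: e.g. a complex128 forward FFT arrives with fft_type ==
// kC2CForward and is executed as kZ2ZForward, while f32/complex64 inputs keep
// the fft_type attribute unchanged.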
1284 
1285 static bool Fft(runtime::KernelContext* ctx, void** args, void** attrs) {
1286   static auto* handler = CustomCall::Bind("xla.gpu.fft")
1287                              .UserData<const ServiceExecutableRunOptions*>()
1288                              .Arg<runtime::StridedMemrefView>()  // input
1289                              .Arg<runtime::StridedMemrefView>()  // output
1290                              .Attr<ArrayRef<int64_t>>("fft_length")
1291                              .Attr<se::fft::Type>("fft_type")
1292                              .To<RuntimeChecks()>(Fft::Handler())
1293                              .release();
1294   return succeeded(Executable::Call(ctx, *handler, args, attrs));
1295 }
1296 
1297 // -------------------------------------------------------------------------- //
1298 
1299 namespace {
1300 struct Cholesky {
1301   LLVM_ATTRIBUTE_ALWAYS_INLINE
1302   Error operator()(const ServiceExecutableRunOptions* run_options,
1303                    const DebugOptions* debug_options,
1304                    runtime::MemrefView operand, runtime::MemrefView a,
1305                    runtime::MemrefView workspace, runtime::MemrefView info,
1306                    int64_t batch_size, bool is_lower, int64_t n) const;
1307   static Cholesky Handler() { return Cholesky(); }
1308 };
1309 }  // namespace
1310 
1311 Error Cholesky::operator()(const ServiceExecutableRunOptions* run_options,
1312                            const DebugOptions* debug_options,
1313                            runtime::MemrefView operand, runtime::MemrefView a,
1314                            runtime::MemrefView workspace,
1315                            runtime::MemrefView info, int64_t batch_size,
1316                            bool is_lower, int64_t n) const {
1317 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1318   se::DeviceMemoryBase operand_buffer = GetDeviceAddress(operand);
1319   se::DeviceMemoryBase a_buffer = GetDeviceAddress(a);
1320   se::DeviceMemoryBase workspace_buffer = GetDeviceAddress(workspace);
1321   se::DeviceMemoryBase info_buffer = GetDeviceAddress(info);
1322 
1323   VLOG(3) << "Running Cholesky";
1324   se::Stream* stream = run_options->stream();
1325 
1326   // Copy the operand into the `a` buffer if they are different buffers.
1327   if (a.data != operand.data)
1328     stream->ThenMemcpy(&a_buffer, operand_buffer, operand_buffer.size());
1329 
1330   using UpperLower = se::blas::UpperLower;
1331   UpperLower uplo = is_lower ? UpperLower::kLower : UpperLower::kUpper;
1332 
1333   CholeskyParams params{n,        batch_size,       uplo,
1334                         a_buffer, workspace_buffer, info_buffer};
1335   auto executed =
1336       RunCholesky(xla::gpu::PtxOptsFromDebugOptions(*debug_options),
1337                   TfrtToPrimitiveType(operand.dtype), &params, stream);
1338   if (!executed.ok()) return AsError(executed);
1339 
1340   return Error::success();
1341 #else  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1342   return failure();
1343 #endif
1344 }
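// The Cholesky factorization above runs in place on the `a` buffer: the
// operand is first copied into `a` when the two are different buffers, and
// RunCholesky then uses `workspace` as scratch space, with `info` presumably
// receiving the per-batch status reported by the solver. Without GOOGLE_CUDA
// or TENSORFLOW_USE_ROCM the handler simply reports failure.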
1345 
1346 static bool Cholesky(runtime::KernelContext* ctx, void** args, void** attrs) {
1347   static auto* handler = CustomCall::Bind("xla.gpu.cholesky")
1348                              .UserData<const ServiceExecutableRunOptions*>()
1349                              .UserData<const DebugOptions*>()
1350                              .Arg<runtime::MemrefView>()  // operand
1351                              .Arg<runtime::MemrefView>()  // a
1352                              .Arg<runtime::MemrefView>()  // workspace
1353                              .Arg<runtime::MemrefView>()  // info
1354                              .Attr<int64_t>("batch_size")
1355                              .Attr<bool>("is_lower")
1356                              .Attr<int64_t>("n")
1357                              .To<RuntimeChecks()>(Cholesky::Handler())
1358                              .release();
1359 
1360   return succeeded(Executable::Call(ctx, *handler, args, attrs));
1361 }
1362 
1363 // -------------------------------------------------------------------------- //
1364 
1365 namespace {
1366 
1367 // TODO(ezhulenev): Today XLA represents TriangularSolve as a "classic" XLA
1368 // custom call operation, and we provide a thin adaptor from the XLA custom
1369 // call to the JitRt custom call. Once we are fully migrated to JitRt execution,
1370 // the XLA compiler should directly emit a properly typed TriangularSolve JitRt
1371 // custom call (no need to pass the config via a serialized string).
1372 struct TriangularSolve {
1373   // Adaptor from XlaCustomCall API to properly typed TriangularSolve handler.
1374   static Error run(const ServiceExecutableRunOptions* run_options,
1375                    const DebugOptions* debug_options,
1376                    CustomCall::RemainingArgs args, StringRef backend_config);
1377 
1378   Error operator()(const ServiceExecutableRunOptions* run_options,
1379                    const DebugOptions* debug_options,
1380                    runtime::StridedMemrefView a, runtime::StridedMemrefView b,
1381                    runtime::StridedMemrefView result,
1382                    runtime::FlatMemrefView temp, bool left_side, bool lower,
1383                    bool unit_diagonal,
1384                    TriangularSolveOptions::Transpose transpose_a) const;
1385   static TriangularSolve Handler() { return TriangularSolve(); }
1386 };
1387 
1388 }  // namespace
1389 
1390 Error TriangularSolve::run(const ServiceExecutableRunOptions* run_options,
1391                            const DebugOptions* debug_options,
1392                            CustomCall::RemainingArgs args,
1393                            StringRef backend_config) {
1394   TriangularSolve handler = TriangularSolve::Handler();
1395 
1396   if (args.size() != 4)
1397     return MakeStringError("Expected 4 arguments, got ", args.size());
1398 
1399   // Check if all arguments have the correct type.
1400   auto a = args.get<runtime::StridedMemrefView>(0);
1401   auto b = args.get<runtime::StridedMemrefView>(1);
1402   auto result = args.get<runtime::StridedMemrefView>(2);
1403   auto temp = args.get<runtime::FlatMemrefView>(3);
1404   if (failed(a) || failed(b) || failed(result) || failed(temp))
1405     return MakeStringError("Incorrect argument types");
1406 
1407   // Parse backend config string.
1408   TriangularSolveOptions opts;
1409   auto st = tensorflow::HumanReadableJsonToProto(backend_config.str(), &opts);
1410   if (!st.ok()) return AsError(st);
1411 
1412   return handler(run_options, debug_options, *a, *b, *result, *temp,
1413                  opts.left_side(), opts.lower(), opts.unit_diagonal(),
1414                  opts.transpose_a());
1415 }
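// The backend_config attribute is a human-readable JSON encoding of the
// TriangularSolveOptions proto, parsed above with
// tensorflow::HumanReadableJsonToProto. An illustrative (not authoritative)
// config covering the fields read above could look like:
//
//   {"left_side": true, "lower": true, "unit_diagonal": false,
//    "transpose_a": "NO_TRANSPOSE"}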
1416 
1417 Error TriangularSolve::operator()(
1418     const ServiceExecutableRunOptions* run_options,
1419     const DebugOptions* debug_options, runtime::StridedMemrefView a,
1420     runtime::StridedMemrefView b, runtime::StridedMemrefView result,
1421     runtime::FlatMemrefView temp, bool left_side, bool lower,
1422     bool unit_diagonal, TriangularSolveOptions::Transpose transpose_a) const {
1423 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1424   se::Stream* stream = run_options->stream();
1425 
1426   se::DeviceMemoryBase a_data = GetDeviceAddress(a);
1427   se::DeviceMemoryBase b_data = GetDeviceAddress(b);
1428   se::DeviceMemoryBase result_data = GetDeviceAddress(result);
1429   se::DeviceMemoryBase temp_data = GetDeviceAddress(temp);
1430 
1431   // Triangular solve is in-place on 'b', so copy 'b' to the output if they
1432   // aren't the same buffer.
1433   if (b.data != result.data)
1434     stream->ThenMemcpy(&result_data, b_data, b_data.size());
1435 
1436   Shape b_shape = ToShape(b);
1437   int64_t m = b_shape.dimensions(b_shape.rank() - 2);
1438   int64_t n = b_shape.dimensions(b_shape.rank() - 1);
1439   int64_t batch_size = std::accumulate(
1440       b_shape.dimensions().begin(), b_shape.dimensions().end() - 2, int64_t{1},
1441       [](int64_t a, int64_t b) { return a * b; });
1442 
1443   PrimitiveType elem_type = TfrtToPrimitiveType(b.dtype);
1444   int64_t elem_size = ShapeUtil::ByteSizeOfPrimitiveType(elem_type);
1445   int64_t a_batch_stride = left_side ? m * m * elem_size : n * n * elem_size;
1446   int64_t b_batch_stride = m * n * elem_size;
1447 
1448   using Side = se::blas::Side;
1449   using Diagonal = se::blas::Diagonal;
1450   using Transpose = se::blas::Transpose;
1451   using UpperLower = se::blas::UpperLower;
1452 
1453   // Convert custom call attributes to se::blas enums.
1454   UpperLower uplo = lower ? UpperLower::kLower : UpperLower::kUpper;
1455   Side side = left_side ? Side::kLeft : Side::kRight;
1456   Diagonal diagonal = unit_diagonal ? Diagonal::kUnit : Diagonal::kNonUnit;
1457 
1458   auto transpose = [&]() -> mlir::FailureOr<Transpose> {
1459     switch (transpose_a) {
1460       case TriangularSolveOptions::NO_TRANSPOSE:
1461         return se::blas::Transpose::kNoTranspose;
1462       case TriangularSolveOptions::TRANSPOSE:
1463         return se::blas::Transpose::kTranspose;
1464       case TriangularSolveOptions::ADJOINT:
1465         return se::blas::Transpose::kConjugateTranspose;
1466       default:
1467         return failure();
1468     }
1469   }();
1470 
1471   if (failed(transpose))
1472     return MakeStringError("Failed to convert transpose type");
1473 
1474   auto st = RunTriangulatSolve(
1475       a_data, result_data, temp_data, PtxOptsFromDebugOptions(*debug_options),
1476       uplo, side, diagonal, *transpose, elem_type, batch_size, m, n,
1477       a_batch_stride, b_batch_stride, stream);
1478   if (!st.ok()) return AsError(st);
1479 
1480   return Error::success();
1481 #else  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1482   return failure();
1483 #endif
1484 }
1485 
1486 // -------------------------------------------------------------------------- //
1487 // Implements a JitRt custom call that forwards to the XLA custom call handler.
1488 //
1489 // Longer term, all XLA custom calls should probably be implemented directly as
1490 // JitRt custom calls. However, for a smooth migration from Thunks to JitRt we
1491 // have to seamlessly support all current XLA users.
1492 namespace {
1493 struct XlaCustomCall {
1494   using Stream = se::gpu::GpuStreamHandle;
1495 
1496   Error operator()(const ServiceExecutableRunOptions* run_options,
1497                    const DebugOptions* debug_options,
1498                    CustomCall::RemainingArgs args, StringRef call_target_name,
1499                    int32_t api_version, StringRef backend_config) const;
1500   static XlaCustomCall Handler() { return XlaCustomCall(); }
1501 };
1502 }  // namespace
1503 
1504 Error XlaCustomCall::operator()(const ServiceExecutableRunOptions* run_options,
1505                                 const DebugOptions* debug_options,
1506                                 CustomCall::RemainingArgs args,
1507                                 StringRef call_target_name, int32_t api_version,
1508                                 StringRef backend_config) const {
1509   // Pattern match the custom call to a few special cases, otherwise find the
1510   // custom call handler registered with the runtime.
1511   if (call_target_name == kTriangularSolveCallTarget)
1512     return TriangularSolve::run(run_options, debug_options, args,
1513                                 backend_config);
1514 
1515   // Find the Xla custom call handler.
1516   auto& platform_name = run_options->stream()->parent()->platform()->Name();
1517   void* call_target = CustomCallTargetRegistry::Global()->Lookup(
1518       call_target_name.str(), platform_name);
1519   if (!call_target) {
1520     return MakeStringError("Cannot find the Xla custom call handler ",
1521                            call_target_name.str());
1522   }
1523 
1524   // Prepare pointers to buffers to pass to the Xla custom call handler.
1525   llvm::SmallVector<void*> buffers;
1526   for (unsigned i = 0; i < args.size(); ++i) {
1527     // We use zero-sized memrefs to represent holes in custom calls with target
1528     // arguments mapping (see `CustomCallTargetArgMapping`).
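    // For example, a hole produced by that mapping shows up here as a memref
    // with size_in_bytes == 0 and is forwarded to the XLA custom call as a
    // nullptr entry in `buffers`.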
1529     if (auto memref = args.get<runtime::FlatMemrefView>(i); succeeded(memref)) {
1530       buffers.push_back(memref->size_in_bytes == 0 ? nullptr : memref->data);
1531       continue;
1532     }
1533     if (auto strided = args.get<runtime::StridedMemrefView>(i);
1534         succeeded(strided)) {
1535       int64_t size_in_bytes = GetHostSize(strided->dtype);
1536       for (int64_t size : strided->sizes) size_in_bytes *= size;
1537       buffers.push_back(size_in_bytes == 0 ? nullptr : strided->data);
1538       continue;
1539     }
1540     return MakeStringError("Failed to get arguments as (strided) memref view");
1541   }
1542 
1543   // Original custom call API version that doesn't support returning status.
1544   if (api_version == CustomCallApiVersion::API_VERSION_ORIGINAL) {
1545     using XlaCustomCallType = void (*)(Stream, void**, const char*, size_t);
1546     auto xla_call_target = reinterpret_cast<XlaCustomCallType>(call_target);
1547 
1548     xla_call_target(se::gpu::AsGpuStreamValue(run_options->stream()),
1549                     buffers.data(), backend_config.data(),
1550                     backend_config.size());
1551 
1552     return Error::success();
1553   }
1554 
1555   // Xla Custom call API returning status.
1556   if (api_version == CustomCallApiVersion::API_VERSION_STATUS_RETURNING) {
1557     using XlaCustomCallType =
1558         void (*)(Stream, void**, const char*, size_t, XlaCustomCallStatus*);
1559     auto xla_call_target = reinterpret_cast<XlaCustomCallType>(call_target);
1560 
1561     XlaCustomCallStatus custom_call_status;
1562     xla_call_target(se::gpu::AsGpuStreamValue(run_options->stream()),
1563                     buffers.data(), backend_config.data(),
1564                     backend_config.size(), &custom_call_status);
1565 
1566     if (auto message = CustomCallStatusGetMessage(&custom_call_status)) {
1567       return MakeStringError(message.value());
1568     } else {
1569       return Error::success();
1570     }
1571   }
1572 
1573   return MakeStringError("Incorrect custom call API version");
1574 }
1575 
1576 static bool CustomCall(runtime::KernelContext* ctx, void** args, void** attrs) {
1577   static auto* handler = CustomCall::Bind("xla.gpu.custom_call")
1578                              .UserData<const ServiceExecutableRunOptions*>()
1579                              .UserData<const DebugOptions*>()
1580                              .Arg<CustomCall::RemainingArgs>()  // args
1581                              .Attr<StringRef>("call_target_name")
1582                              .Attr<int32_t>("api_version")
1583                              .Attr<StringRef>("backend_config")
1584                              .To<RuntimeChecks()>(XlaCustomCall::Handler())
1585                              .release();
1586 
1587   return succeeded(Executable::Call(ctx, *handler, args, attrs));
1588 }
1589 
1590 // ------------------------------------------------------------------------- //
1591 
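// The collective handlers below (all_reduce, all_reduce_start/done,
// reduce_scatter, all_gather, all_to_all, collective_permute) all follow the
// same shape when XLA_ENABLE_XCCL is defined: build NcclExecuteParams from the
// run options, resolve the communicator with GetNcclComm() from the group
// mode, op id and replica groups, convert the remaining custom call arguments
// into device buffer pairs, run the corresponding Nccl* routine on the stream,
// and finally call JitRtCollectiveSupport::MaybeBlockAfterFirstRun, which (as
// the name suggests) may block the host after the first execution of a
// collective.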
1592 namespace {
1593 struct AllReduce {
1594   LLVM_ATTRIBUTE_ALWAYS_INLINE
1595   Error operator()(const ServiceExecutableRunOptions* run_options,
1596                    JitRtCollectiveSupport* collectives,
1597                    CustomCall::RemainingArgs args, int32_t uid,
1598                    int64_t group_mode, int64_t op_id, int64_t reduction_kind,
1599                    ArrayRef<int64_t> replica_group_offsets,
1600                    ArrayRef<int64_t> replica_group_values) const;
1601   static AllReduce Handler() { return AllReduce(); }
1602 };
1603 }  // namespace
1604 
1605 Error AllReduce::operator()(const ServiceExecutableRunOptions* run_options,
1606                             JitRtCollectiveSupport* collectives,
1607                             CustomCall::RemainingArgs args, int32_t uid,
1608                             int64_t group_mode, int64_t op_id,
1609                             int64_t reduction_kind,
1610                             ArrayRef<int64_t> replica_group_offsets,
1611                             ArrayRef<int64_t> replica_group_values) const {
1612 #if XLA_ENABLE_XCCL
1613   VLOG(3) << "Running AllReduce";
1614   se::Stream* stream = run_options->stream();
1615   NcclExecuteParams params(*run_options, stream);
1616 
1617   auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
1618                           replica_group_values);
1619   if (failed(comm)) return MakeStringError("Failed to get NcclComm");
1620 
1621   auto device_buffers = GetDeviceBufferPairs(args);
1622   if (failed(device_buffers))
1623     return MakeStringError("Failed to get device buffers");
1624 
1625   auto executed = RunAllReduce(static_cast<ReductionKind>(reduction_kind),
1626                                *device_buffers, *stream, **comm);
1627   if (!executed.ok()) return AsError(executed);
1628 
1629   int32_t device_ordinal = stream->parent()->device_ordinal();
1630   auto st = collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream);
1631   if (!st.ok()) return AsError(st);
1632 
1633   return Error::success();
1634 #else   // XLA_ENABLE_XCCL
1635   // NCCL disabled.
1636   return MakeStringError("NCCL disabled");
1637 #endif  // XLA_ENABLE_XCCL
1638 }
1639 
1640 static bool AllReduce(runtime::KernelContext* ctx, void** args, void** attrs) {
1641   static auto* handler =
1642       CustomCall::Bind("xla.gpu.all_reduce")
1643           .UserData<const ServiceExecutableRunOptions*>()
1644           .UserData<JitRtCollectiveSupport*>()
1645           .RemainingArgs()  // args
1646           .Attr<int32_t>("uid")
1647           .Attr<int64_t>("group_mode")  // CollectiveOpGroupMode
1648           .Attr<int64_t>("op_id")
1649           .Attr<int64_t>("reduction_kind")  // ReductionKind
1650           .Attr<ArrayRef<int64_t>>("replica_group_offsets")
1651           .Attr<ArrayRef<int64_t>>("replica_group_values")
1652           .To<RuntimeChecks()>(AllReduce::Handler())
1653           .release();
1654 
1655   return succeeded(Executable::Call(ctx, *handler, args, attrs));
1656 }
1657 
1658 // ------------------------------------------------------------------------- //
1659 
1660 namespace {
1661 struct AllReduceStart {
1662   LLVM_ATTRIBUTE_ALWAYS_INLINE
1663   Error operator()(const ServiceExecutableRunOptions* run_options,
1664                    JitRtAsyncCollectiveSupport* async_collectives,
1665                    CustomCall::RemainingArgs args, int64_t group_mode,
1666                    int64_t op_id, int64_t reduction_kind,
1667                    ArrayRef<int64_t> replica_group_offsets,
1668                    ArrayRef<int64_t> replica_group_values, int32_t uid) const;
1669   static AllReduceStart Handler() { return AllReduceStart(); }
1670 };
1671 }  // namespace
1672 
1673 Error AllReduceStart::operator()(const ServiceExecutableRunOptions* run_options,
1674                                  JitRtAsyncCollectiveSupport* async_collectives,
1675                                  CustomCall::RemainingArgs args,
1676                                  int64_t group_mode, int64_t op_id,
1677                                  int64_t reduction_kind,
1678                                  ArrayRef<int64_t> replica_group_offsets,
1679                                  ArrayRef<int64_t> replica_group_values,
1680                                  int32_t uid) const {
1681 #if XLA_ENABLE_XCCL
1682   VLOG(3) << "Running AllReduceStart";
1683   se::Stream* stream = run_options->stream();
1684   NcclExecuteParams params(*run_options, stream);
1685 
1686   auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
1687                           replica_group_values);
1688   if (failed(comm)) return MakeStringError("Failed to get NcclComm");
1689 
1690   auto device_buffers = GetDeviceBufferPairs(args);
1691   if (failed(device_buffers))
1692     return MakeStringError("Failed to get device buffers");
1693 
1694   // Wait until compute inputs are ready.
1695   async_collectives->async_comm_stream()->ThenWaitFor(params.stream);
1696 
1697   auto executed =
1698       RunAllReduce(static_cast<ReductionKind>(reduction_kind), *device_buffers,
1699                    *async_collectives->async_comm_stream(), **comm);
1700   if (!executed.ok()) return AsError(executed);
1701 
1702   // Create an event on the async stream for the completion of the all-reduce.
1703   se::Event done_event(async_collectives->async_comm_stream()->parent());
1704   if (!done_event.Init()) return MakeStringError("Failed to create event");
1705   async_collectives->async_comm_stream()->ThenRecordEvent(&done_event);
1706 
1707   if (failed(async_collectives->PushEvent(
1708           uid, stream->parent()->device_ordinal(), std::move(done_event))))
1709     return MakeStringError("Failed to push event to async collectives");
1710 
1711   return Error::success();
1712 #else   // XLA_ENABLE_XCCL
1713   return MakeStringError("NCCL disabled");
1714 #endif  // XLA_ENABLE_XCCL
1715 }
1716 
1717 static bool AllReduceStart(runtime::KernelContext* ctx, void** args,
1718                            void** attrs) {
1719   static auto* handler =
1720       CustomCall::Bind("xla.gpu.all_reduce_start")
1721           .UserData<const ServiceExecutableRunOptions*>()
1722           .UserData<JitRtAsyncCollectiveSupport*>()
1723           .RemainingArgs()              // args
1724           .Attr<int64_t>("group_mode")  // CollectiveOpGroupMode
1725           .Attr<int64_t>("op_id")
1726           .Attr<int64_t>("reduction_kind")  // ReductionKind
1727           .Attr<ArrayRef<int64_t>>("replica_group_offsets")
1728           .Attr<ArrayRef<int64_t>>("replica_group_values")
1729           .Attr<int32_t>("uid")
1730           .To<RuntimeChecks()>(AllReduceStart::Handler())
1731           .release();
1732 
1733   return succeeded(Executable::Call(ctx, *handler, args, attrs));
1734 }
1735 
1736 // ------------------------------------------------------------------------- //
1737 
1738 namespace {
1739 struct AllReduceDone {
1740   LLVM_ATTRIBUTE_ALWAYS_INLINE
1741   Error operator()(const ServiceExecutableRunOptions* run_options,
1742                    JitRtCollectiveSupport* collectives,
1743                    JitRtAsyncCollectiveSupport* async_collectives,
1744                    CustomCall::RemainingArgs args, int32_t uid) const;
1745   static AllReduceDone Handler() { return AllReduceDone(); }
1746 };
1747 }  // namespace
1748 
1749 Error AllReduceDone::operator()(const ServiceExecutableRunOptions* run_options,
1750                                 JitRtCollectiveSupport* collectives,
1751                                 JitRtAsyncCollectiveSupport* async_collectives,
1752                                 CustomCall::RemainingArgs args,
1753                                 int32_t uid) const {
1754 #if XLA_ENABLE_XCCL
1755   VLOG(3) << "Running AllReduceDone";
1756   se::Stream* stream = run_options->stream();
1757 
1758   int32_t device_ordinal = stream->parent()->device_ordinal();
1759   auto event = async_collectives->PopEvent(uid, device_ordinal);
1760   if (failed(event)) return MakeStringError("Failed to pop event");
1761 
1762   stream->ThenWaitFor(&*event);
1763 
1764   if (!collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream).ok())
1765     return MakeStringError("Failed to block host");
1766 
1767   return Error::success();
1768 #else   // XLA_ENABLE_XCCL
1769   return MakeStringError("NCCL disabled");
1770 #endif  // XLA_ENABLE_XCCL
1771 }
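// AllReduceStart and AllReduceDone are paired through
// JitRtAsyncCollectiveSupport: the start handler records an se::Event on the
// async communication stream and pushes it keyed by (uid, device_ordinal), and
// the done handler above pops that event and makes the compute stream wait on
// it, which lets the all-reduce overlap with work issued in between.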
1772 
1773 static bool AllReduceDone(runtime::KernelContext* ctx, void** args,
1774                           void** attrs) {
1775   static auto* handler = CustomCall::Bind("xla.gpu.all_reduce_done")
1776                              .UserData<const ServiceExecutableRunOptions*>()
1777                              .UserData<JitRtCollectiveSupport*>()
1778                              .UserData<JitRtAsyncCollectiveSupport*>()
1779                              .RemainingArgs()  // args
1780                              .Attr<int32_t>("uid")
1781                              .To<RuntimeChecks()>(AllReduceDone::Handler())
1782                              .release();
1783 
1784   return succeeded(Executable::Call(ctx, *handler, args, attrs));
1785 }
1786 
1787 // -------------------------------------------------------------------------- //
1788 
1789 namespace {
1790 struct ReduceScatter {
1791   LLVM_ATTRIBUTE_ALWAYS_INLINE
1792   Error operator()(const ServiceExecutableRunOptions* run_options,
1793                    JitRtCollectiveSupport* collectives,
1794                    CustomCall::RemainingArgs args, int32_t uid,
1795                    int64_t group_mode, int64_t op_id, int64_t reduction_kind,
1796                    ArrayRef<int64_t> replica_group_offsets,
1797                    ArrayRef<int64_t> replica_group_values) const;
1798   static ReduceScatter Handler() { return ReduceScatter(); }
1799 };
1800 }  // namespace
1801 
1802 Error ReduceScatter::operator()(const ServiceExecutableRunOptions* run_options,
1803                                 JitRtCollectiveSupport* collectives,
1804                                 CustomCall::RemainingArgs args, int32_t uid,
1805                                 int64_t group_mode, int64_t op_id,
1806                                 int64_t reduction_kind,
1807                                 ArrayRef<int64_t> replica_group_offsets,
1808                                 ArrayRef<int64_t> replica_group_values) const {
1809 #if XLA_ENABLE_XCCL
1810   VLOG(3) << "Running ReduceScatter";
1811   se::Stream* stream = run_options->stream();
1812   NcclExecuteParams params(*run_options, stream);
1813 
1814   auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
1815                           replica_group_values);
1816   if (failed(comm)) return MakeStringError("Failed to get NcclComm");
1817 
1818   auto device_buffers = GetDeviceBufferPairs(args);
1819   if (failed(device_buffers))
1820     return MakeStringError("Failed to get device buffers");
1821 
1822   auto executed = RunReduceScatter(static_cast<ReductionKind>(reduction_kind),
1823                                    *device_buffers, *stream, **comm);
1824   if (!executed.ok()) return AsError(executed);
1825 
1826   int32_t device_ordinal = stream->parent()->device_ordinal();
1827   if (!collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream).ok())
1828     return MakeStringError("Failed to block host");
1829 
1830   return Error::success();
1831 #else   // XLA_ENABLE_XCCL
1832   return MakeStringError("NCCL disabled");
1833 #endif  // XLA_ENABLE_XCCL
1834 }
1835 
1836 static bool ReduceScatter(runtime::KernelContext* ctx, void** args,
1837                           void** attrs) {
1838   static auto* handler =
1839       CustomCall::Bind("xla.gpu.reduce_scatter")
1840           .UserData<const ServiceExecutableRunOptions*>()
1841           .UserData<JitRtCollectiveSupport*>()
1842           .RemainingArgs()  // args
1843           .Attr<int32_t>("uid")
1844           .Attr<int64_t>("group_mode")  // CollectiveOpGroupMode
1845           .Attr<int64_t>("op_id")
1846           .Attr<int64_t>("reduction_kind")  // ReductionKind
1847           .Attr<ArrayRef<int64_t>>("replica_group_offsets")
1848           .Attr<ArrayRef<int64_t>>("replica_group_values")
1849           .To<RuntimeChecks()>(ReduceScatter::Handler())
1850           .release();
1851 
1852   return succeeded(Executable::Call(ctx, *handler, args, attrs));
1853 }
1854 
1855 // -------------------------------------------------------------------------- //
1856 
1857 namespace {
1858 struct AllGather {
1859   LLVM_ATTRIBUTE_ALWAYS_INLINE
1860   Error operator()(const ServiceExecutableRunOptions* run_options,
1861                    JitRtCollectiveSupport* collectives,
1862                    CustomCall::RemainingArgs args, int32_t uid,
1863                    int64_t group_mode, int64_t op_id,
1864                    ArrayRef<int64_t> replica_group_offsets,
1865                    ArrayRef<int64_t> replica_group_values) const;
1866   static AllGather Handler() { return AllGather(); }
1867 };
1868 }  // namespace
1869 
1870 Error AllGather::operator()(const ServiceExecutableRunOptions* run_options,
1871                             JitRtCollectiveSupport* collectives,
1872                             CustomCall::RemainingArgs args, int32_t uid,
1873                             int64_t group_mode, int64_t op_id,
1874                             ArrayRef<int64_t> replica_group_offsets,
1875                             ArrayRef<int64_t> replica_group_values) const {
1876 #if XLA_ENABLE_XCCL
1877   VLOG(3) << "Running AllGather";
1878   se::Stream* stream = run_options->stream();
1879   NcclExecuteParams params(*run_options, stream);
1880 
1881   auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
1882                           replica_group_values);
1883   if (failed(comm)) return MakeStringError("Failed to get NCCL comm");
1884 
1885   auto device_buffers = GetDeviceBufferPairs(args);
1886   if (failed(device_buffers))
1887     return MakeStringError("Failed to get device buffers");
1888 
1889   auto st = RunAllGather(*device_buffers, *stream, **comm);
1890   if (!st.ok()) return AsError(st);
1891 
1892   int32_t device_ordinal = stream->parent()->device_ordinal();
1893   st = collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream);
1894   if (!st.ok()) return AsError(st);
1895 
1896   return Error::success();
1897 #else   // XLA_ENABLE_XCCL
1898   return MakeStringError("NCCL disabled");
1899 #endif  // XLA_ENABLE_XCCL
1900 }
1901 
1902 static bool AllGather(runtime::KernelContext* ctx, void** args, void** attrs) {
1903   static auto* handler =
1904       CustomCall::Bind("xla.gpu.all_gather")
1905           .UserData<const ServiceExecutableRunOptions*>()
1906           .UserData<JitRtCollectiveSupport*>()
1907           .RemainingArgs()  // args
1908           .Attr<int32_t>("uid")
1909           .Attr<int64_t>("group_mode")  // CollectiveOpGroupMode
1910           .Attr<int64_t>("op_id")
1911           .Attr<ArrayRef<int64_t>>("replica_group_offsets")
1912           .Attr<ArrayRef<int64_t>>("replica_group_values")
1913           .To<RuntimeChecks()>(AllGather::Handler())
1914           .release();
1915 
1916   return succeeded(Executable::Call(ctx, *handler, args, attrs));
1917 }
1918 
1919 // -------------------------------------------------------------------------- //
1920 
1921 namespace {
1922 struct AllToAll {
1923   LLVM_ATTRIBUTE_ALWAYS_INLINE
1924   Error operator()(const ServiceExecutableRunOptions* run_options,
1925                    JitRtCollectiveSupport* collectives,
1926                    CustomCall::RemainingArgs args, int32_t uid,
1927                    int64_t group_mode, bool has_split_dimension, int64_t op_id,
1928                    ArrayRef<int64_t> replica_group_offsets,
1929                    ArrayRef<int64_t> replica_group_values) const;
1930   static AllToAll Handler() { return AllToAll(); }
1931 };
1932 }  // namespace
1933 
1934 Error AllToAll::operator()(const ServiceExecutableRunOptions* run_options,
1935                            JitRtCollectiveSupport* collectives,
1936                            CustomCall::RemainingArgs args, int32_t uid,
1937                            int64_t group_mode, bool has_split_dimension,
1938                            int64_t op_id,
1939                            ArrayRef<int64_t> replica_group_offsets,
1940                            ArrayRef<int64_t> replica_group_values) const {
1941 #if XLA_ENABLE_XCCL
1942   VLOG(3) << "Running AllToAll";
1943   se::Stream* stream = run_options->stream();
1944   NcclExecuteParams params(*run_options, stream);
1945 
1946   auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
1947                           replica_group_values);
1948   if (failed(comm)) return MakeStringError("Failed to get NCCL comm");
1949 
1950   auto device_buffers = GetDeviceBufferPairs(args);
1951   if (failed(device_buffers))
1952     return MakeStringError("Failed to get device buffers");
1953 
1954   auto st = RunAllToAll(has_split_dimension, *device_buffers, *stream, **comm);
1955   if (!st.ok()) return AsError(st);
1956 
1957   int32_t device_ordinal = stream->parent()->device_ordinal();
1958   st = collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream);
1959   if (!st.ok()) return AsError(st);
1960 
1961   return Error::success();
1962 #else   // XLA_ENABLE_XCCL
1963   return MakeStringError("NCCL disabled");
1964 #endif  // XLA_ENABLE_XCCL
1965 }
1966 
1967 static bool AllToAll(runtime::KernelContext* ctx, void** args, void** attrs) {
1968   static auto* handler =
1969       CustomCall::Bind("xla.gpu.all_to_all")
1970           .UserData<const ServiceExecutableRunOptions*>()
1971           .UserData<JitRtCollectiveSupport*>()
1972           .RemainingArgs()  // args
1973           .Attr<int32_t>("uid")
1974           .Attr<int64_t>("group_mode")  // CollectiveOpGroupMode
1975           .Attr<bool>("has_split_dimension")
1976           .Attr<int64_t>("op_id")
1977           .Attr<ArrayRef<int64_t>>("replica_group_offsets")
1978           .Attr<ArrayRef<int64_t>>("replica_group_values")
1979           .To<RuntimeChecks()>(AllToAll::Handler())
1980           .release();
1981 
1982   return succeeded(Executable::Call(ctx, *handler, args, attrs));
1983 }
1984 
1985 // -------------------------------------------------------------------------- //
1986 
1987 namespace {
1988 struct CollectivePermute {
1989   LLVM_ATTRIBUTE_ALWAYS_INLINE
1990   Error operator()(const ServiceExecutableRunOptions* run_options,
1991                    JitRtCollectiveSupport* collectives,
1992                    CustomCall::RemainingArgs args, int32_t uid,
1993                    int64_t group_mode, int64_t op_id,
1994                    ArrayRef<int64_t> replica_group_offsets,
1995                    ArrayRef<int64_t> replica_group_values,
1996                    ArrayRef<int64_t> source_peers,
1997                    ArrayRef<int64_t> target_peers) const;
1998   static CollectivePermute Handler() { return CollectivePermute(); }
1999 };
2000 }  // namespace
2001 
2002 Error CollectivePermute::operator()(
2003     const ServiceExecutableRunOptions* run_options,
2004     JitRtCollectiveSupport* collectives, CustomCall::RemainingArgs args,
2005     int32_t uid, int64_t group_mode, int64_t op_id,
2006     ArrayRef<int64_t> replica_group_offsets,
2007     ArrayRef<int64_t> replica_group_values, ArrayRef<int64_t> source_peers,
2008     ArrayRef<int64_t> target_peers) const {
2009 #if XLA_ENABLE_XCCL
2010   VLOG(3) << "Running CollectivePermute";
2011   se::Stream* stream = run_options->stream();
2012   NcclExecuteParams params(*run_options, stream);
2013 
2014   auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
2015                           replica_group_values);
2016   if (failed(comm)) return MakeStringError("Failed to get NcclComm");
2017 
2018   auto device_buffers = GetDeviceBufferPairs(args);
2019   if (failed(device_buffers))
2020     return MakeStringError("Failed to get device buffers");
2021   if (device_buffers->size() != 1) {
2022     return MakeStringError("Expected device buffer size: 1, got ",
2023                            device_buffers->size());
2024   }
2025 
2026   StatusOr<GlobalDeviceId> global_device_id = params.GetGlobalDeviceId();
2027   if (!global_device_id.ok()) return AsError(global_device_id);
2028 
2029   StatusOr<DeviceAssignment::LogicalID> current_logical_id =
2030       params.device_assn->LogicalIdForDevice(global_device_id.value());
2031   if (!current_logical_id.ok()) return AsError(current_logical_id);
2032 
2033   const int64_t current_id = static_cast<CollectiveOpGroupMode>(group_mode) ==
2034                                      CollectiveOpGroupMode::kCrossReplica
2035                                  ? current_logical_id.value().replica_id
2036                                  : current_logical_id.value().computation_id;
2037   std::string device_string = NcclCollectiveThunk::GetDeviceString(params);
2038 
2039   NcclCollectivePermuteConfig::IdToSourceTargetMap id_to_source_target;
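  // source_peers[i] sends to target_peers[i], so the map below records for
  // every id both who it receives from and who it sends to. For example, with
  // source_peers = {0, 1} and target_peers = {1, 2}, id 1 receives from 0 and
  // sends to 2, id 0 only sends, and id 2 only receives.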
2040   for (int i = 0; i < source_peers.size(); ++i) {
2041     id_to_source_target.insert({target_peers[i], {}}).first->second.source =
2042         source_peers[i];
2043     id_to_source_target.insert({source_peers[i], {}}).first->second.target =
2044         target_peers[i];
2045   }
2046   const NcclCollectivePermuteConfig::SourceTargetMapEntry source_target =
2047       NcclCollectivePermuteConfig::GetSourceTarget(id_to_source_target,
2048                                                    current_id);
2049 
2050   auto executed =
2051       RunCollectivePermute(source_target, (*device_buffers)[0], *stream, **comm,
2052                            device_string, current_id);
2053   if (!executed.ok()) return AsError(executed);
2054 
2055   int32_t device_ordinal = stream->parent()->device_ordinal();
2056   auto st = collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream);
2057   if (!st.ok()) return AsError(st);
2058 
2059   return Error::success();
2060 #else   // XLA_ENABLE_XCCL
2061   return MakeStringError("NCCL disabled");
2062 #endif  // XLA_ENABLE_XCCL
2063 }
2064 
2065 static bool CollectivePermute(runtime::KernelContext* ctx, void** args,
2066                               void** attrs) {
2067   static auto* handler =
2068       CustomCall::Bind("xla.gpu.collective_permute")
2069           .UserData<const ServiceExecutableRunOptions*>()
2070           .UserData<JitRtCollectiveSupport*>()
2071           .RemainingArgs()  // args
2072           .Attr<int32_t>("uid")
2073           .Attr<int64_t>("group_mode")  // CollectiveOpGroupMode
2074           .Attr<int64_t>("op_id")
2075           .Attr<ArrayRef<int64_t>>("replica_group_offsets")
2076           .Attr<ArrayRef<int64_t>>("replica_group_values")
2077           .Attr<ArrayRef<int64_t>>("source_peers")
2078           .Attr<ArrayRef<int64_t>>("target_peers")
2079           .To<RuntimeChecks()>(CollectivePermute::Handler())
2080           .release();
2081 
2082   return succeeded(Executable::Call(ctx, *handler, args, attrs));
2083 }
2084 
2085 // -------------------------------------------------------------------------- //
2086 
2087 namespace {
2088 struct ReplicaId {
2089   LLVM_ATTRIBUTE_ALWAYS_INLINE
2090   Error operator()(const ServiceExecutableRunOptions* run_options,
2091                    runtime::FlatMemrefView result) const;
2092   static ReplicaId Handler() { return ReplicaId(); }
2093 };
2094 }  // namespace
2095 
2096 Error ReplicaId::operator()(const ServiceExecutableRunOptions* run_options,
2097                             runtime::FlatMemrefView result) const {
2098   VLOG(3) << "Running ReplicaId";
2099   se::Stream* stream = run_options->stream();
2100   NcclExecuteParams params(*run_options, stream);
2101 
2102   StatusOr<GlobalDeviceId> global_device_id = params.GetGlobalDeviceId();
2103   if (!global_device_id.ok()) return AsError(global_device_id);
2104 
2105   StatusOr<DeviceAssignment::LogicalID> logical_id =
2106       params.device_assn->LogicalIdForDevice(global_device_id.value());
2107   if (!logical_id.ok()) return AsError(logical_id);
2108 
2109   se::DeviceMemoryBase result_data = GetDeviceAddress(result);
2110   params.stream->ThenMemset32(&result_data, logical_id.value().replica_id,
2111                               /*size=*/4);
2112 
2113   return Error::success();
2114 }
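// ReplicaId above and PartitionId below are mirror images: both resolve the
// device's DeviceAssignment::LogicalID and write a single 32-bit value into
// the result buffer with ThenMemset32, differing only in whether replica_id or
// computation_id is written.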
2115 
2116 static bool ReplicaId(runtime::KernelContext* ctx, void** args, void** attrs) {
2117   static auto* handler = CustomCall::Bind("xla.gpu.replica_id")
2118                              .UserData<const ServiceExecutableRunOptions*>()
2119                              .Arg<runtime::FlatMemrefView>()  // result
2120                              .To<RuntimeChecks()>(ReplicaId::Handler())
2121                              .release();
2122 
2123   return succeeded(Executable::Call(ctx, *handler, args, attrs));
2124 }
2125 
2126 // -------------------------------------------------------------------------- //
2127 
2128 namespace {
2129 struct PartitionId {
2130   LLVM_ATTRIBUTE_ALWAYS_INLINE
2131   Error operator()(const ServiceExecutableRunOptions* run_options,
2132                    runtime::FlatMemrefView result) const;
2133   static PartitionId Handler() { return PartitionId(); }
2134 };
2135 }  // namespace
2136 
2137 Error PartitionId::operator()(const ServiceExecutableRunOptions* run_options,
2138                               runtime::FlatMemrefView result) const {
2139   VLOG(3) << "Running PartitionId";
2140   se::Stream* stream = run_options->stream();
2141   NcclExecuteParams params(*run_options, stream);
2142 
2143   StatusOr<GlobalDeviceId> global_device_id = params.GetGlobalDeviceId();
2144   if (!global_device_id.ok()) return AsError(global_device_id);
2145 
2146   StatusOr<DeviceAssignment::LogicalID> logical_id =
2147       params.device_assn->LogicalIdForDevice(global_device_id.value());
2148   if (!logical_id.ok()) return AsError(logical_id);
2149 
2150   se::DeviceMemoryBase result_data = GetDeviceAddress(result);
2151   params.stream->ThenMemset32(&result_data, logical_id.value().computation_id,
2152                               /*size=*/4);
2153 
2154   return Error::success();
2155 }
2156 
2157 static bool PartitionId(runtime::KernelContext* ctx, void** args,
2158                         void** attrs) {
2159   static auto* handler = CustomCall::Bind("xla.gpu.partition_id")
2160                              .UserData<const ServiceExecutableRunOptions*>()
2161                              .Arg<runtime::FlatMemrefView>()  // result
2162                              .To<RuntimeChecks()>(PartitionId::Handler())
2163                              .release();
2164 
2165   return succeeded(Executable::Call(ctx, *handler, args, attrs));
2166 }
2167 
2168 // -------------------------------------------------------------------------- //
2169 
2170 DirectCustomCallLibrary JitRtGpuCustomCalls() {
2171   DirectCustomCallLibrary lib;
2172 
2173   lib.Insert("xla.gpu.fft", &xla::gpu::Fft);
2174   lib.Insert("xla.gpu.cholesky", &xla::gpu::Cholesky);
2175   lib.Insert("xla.gpu.collective_permute", &xla::gpu::CollectivePermute);
2176   lib.Insert("xla.gpu.func.launch", &xla::gpu::LaunchFunc);
2177   lib.Insert("xla.gpu.gemm", &xla::gpu::Gemm);
2178   lib.Insert("xla.gpu.cublas.lt.matmul", &xla::gpu::CublasLtMatmul);
2179   lib.Insert("xla.gpu.cublas.lt.matmul.bias", &xla::gpu::CublasLtMatmulBias);
2180 
2181   auto conv = [](StringRef name) { return ("xla.gpu.conv." + name).str(); };
2182   lib.Insert(conv("forward"), &ConvFn<CudnnConvKind::kForward>);
2183   lib.Insert(conv("backward.input"), &ConvFn<CudnnConvKind::kBackwardInput>);
2184   lib.Insert(conv("backward.filter"), &ConvFn<CudnnConvKind::kBackwardFilter>);
2185   lib.Insert(conv("forward.fused"),
2186              &ConvFusedFn<CudnnConvKind::kForwardActivation>);
2187   lib.Insert(conv("forward.fused.side_input"),
2188              &ConvFuseSideInputdFn<CudnnConvKind::kForwardActivation>);
2189 
2190   lib.Insert("xla.gpu.memcpy.d2d", &MemcpyFn<MemcpyDirection::kDeviceToDevice>);
2191   lib.Insert("xla.gpu.memcpy.h2d", &MemcpyFn<MemcpyDirection::kHostToDevice>);
2192   lib.Insert("xla.gpu.memcpy.d2h", &MemcpyFn<MemcpyDirection::kDeviceToHost>);
2193   lib.Insert("xla.gpu.memset", &MemsetFn);
2194   lib.Insert("xla.gpu.infeed", &xla::gpu::Infeed);
2195   lib.Insert("xla.gpu.outfeed", &xla::gpu::Outfeed);
2196   lib.Insert("xla.gpu.custom_call", &xla::gpu::CustomCall);
2197 
2198   // Collective operations.
2199   lib.Insert("xla.gpu.all_gather", &xla::gpu::AllGather);
2200   lib.Insert("xla.gpu.all_reduce", &xla::gpu::AllReduce);
2201   lib.Insert("xla.gpu.all_reduce_done", &xla::gpu::AllReduceDone);
2202   lib.Insert("xla.gpu.all_reduce_start", &xla::gpu::AllReduceStart);
2203   lib.Insert("xla.gpu.all_to_all", &xla::gpu::AllToAll);
2204   lib.Insert("xla.gpu.reduce_scatter", &xla::gpu::ReduceScatter);
2205   lib.Insert("xla.gpu.partition_id", &xla::gpu::PartitionId);
2206   lib.Insert("xla.gpu.replica_id", &xla::gpu::ReplicaId);
2207 
2208   return lib;
2209 }
2210 
2211 }  // namespace gpu
2212 }  // namespace xla
2213