xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
18 
19 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
20 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
21 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
22 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
23 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 
27 namespace xla {
28 namespace gpu {
29 
30 struct NcclAllReduceConfig {
31   NcclCollectiveConfig config;
32   ReductionKind reduction_kind;
33 };
34 
35 // Thunk that performs a NCCL-based All-Reduce or Reduce-Scatter among CUDA
36 // GPU-based replicas.
37 class NcclAllReduceThunkBase : public NcclCollectiveThunk {
38  public:
39   static std::optional<ReductionKind> MatchAllReduceComputation(
40       mlir::Region& computation);
41 
42   NcclAllReduceThunkBase(Kind kind, ThunkInfo thunk_info,
43                          NcclAllReduceConfig config,
44                          std::vector<Buffer> buffers);
45 
46  protected:
config()47   const NcclCollectiveConfig& config() const override { return config_.config; }
48 
49  protected:
50   const NcclAllReduceConfig config_;
51   const std::vector<Buffer> buffers_;
52 };
53 
54 class NcclAllReduceThunk : public NcclAllReduceThunkBase {
55  public:
56   NcclAllReduceThunk(ThunkInfo thunk_info, mlir::lmhlo::AllReduceOp op,
57                      std::vector<Buffer> buffers);
58 
GetName()59   static const char* GetName() { return "AllReduce"; }
60 
61   static bool CanImplement(mlir::lmhlo::AllReduceOp op);
62   static bool IsDegenerate(mlir::lmhlo::AllReduceOp op, int64_t replica_count,
63                            int64_t partition_count);
64   static CollectiveOpGroupMode GetGroupMode(mlir::lmhlo::AllReduceOp op);
65 
66  protected:
67   Status RunNcclCollective(const ExecuteParams& params,
68                            ncclComm_t comm) override;
69 };
70 
71 class NcclAllReduceStartThunk : public NcclAllReduceThunkBase {
72  public:
73   NcclAllReduceStartThunk(ThunkInfo thunk_info,
74                           mlir::lmhlo_gpu::AllReduceStartOp op,
75                           std::vector<Buffer> buffers);
76 
GetName()77   static const char* GetName() { return "AllReduceStart"; }
78 
79   static bool CanImplement(mlir::lmhlo_gpu::AllReduceStartOp op);
80   static bool IsDegenerate(mlir::lmhlo_gpu::AllReduceStartOp op,
81                            int64_t replica_count, int64_t partition_count);
82   static CollectiveOpGroupMode GetGroupMode(
83       mlir::lmhlo_gpu::AllReduceStartOp op);
84 
85   StatusOr<se::Event> TakeDoneEvent(int device_ordinal)
86       ABSL_LOCKS_EXCLUDED(mu_);
87 
88  protected:
89   Status RunNcclCollective(const ExecuteParams& params,
90                            ncclComm_t comm) override;
91 
92  private:
93   absl::Mutex mu_;
94   // Store done events (by device ordinal) for the done thunk to wait on.
95   absl::flat_hash_map<int, se::Event> done_events_ ABSL_GUARDED_BY(mu_);
96 };
97 
98 class NcclAllReduceDoneThunk : public Thunk {
99  public:
100   explicit NcclAllReduceDoneThunk(ThunkInfo thunk_info,
101                                   NcclAllReduceStartThunk& start_thunk);
102 
103   Status ExecuteOnStream(const ExecuteParams& params) override;
104 
105  private:
106   NcclAllReduceStartThunk& start_thunk_;
107 };
108 
109 class NcclReduceScatterThunk : public NcclAllReduceThunkBase {
110  public:
111   NcclReduceScatterThunk(ThunkInfo thunk_info, mlir::lmhlo::ReduceScatterOp op,
112                          std::vector<Buffer> buffers);
113 
GetName()114   static const char* GetName() { return "ReduceScatter"; }
115 
116   // Returns whether the given instruction can be lowered to a nccl
117   // reduce-scatter call.
118   static bool CanImplement(mlir::lmhlo::ReduceScatterOp op);
119   static bool IsDegenerate(mlir::lmhlo::ReduceScatterOp op,
120                            int64_t replica_count, int64_t partition_count);
121   static CollectiveOpGroupMode GetGroupMode(mlir::lmhlo::ReduceScatterOp op);
122 
123  protected:
124   Status RunNcclCollective(const ExecuteParams& params,
125                            ncclComm_t comm) override;
126 };
127 
128 Status RunAllReduce(ReductionKind reduction_kind,
129                     std::vector<DeviceBufferPair>& buffers, se::Stream& stream,
130                     ncclComm_t comm);
131 
132 Status RunReduceScatter(ReductionKind reduction_kind,
133                         std::vector<DeviceBufferPair>& buffers,
134                         se::Stream& stream, ncclComm_t comm);
135 
136 }  // namespace gpu
137 }  // namespace xla
138 
139 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
140