/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"

#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/call_once.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

#if XLA_ENABLE_XCCL
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#endif

namespace xla {
namespace gpu {

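// Builds the thunk's config from the static op attributes: a single replica
// group containing all participating IDs, and a map from each ID to its
// (optional) source and target IDs derived from the op's source-target pairs.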
/*static*/ NcclCollectivePermuteConfig
NcclCollectivePermuteThunk::GetNcclCollectivePermuteConfig(
    mlir::lmhlo::CollectivePermuteOp op, int64_t replica_count,
    int64_t partition_count) {
  NcclCollectivePermuteConfig collective_permute_config;
  auto& config = collective_permute_config.config;

  config.operand_count = 1;
  const Shape shape = GetShape(op.getOperand());
  config.operand_element_type.push_back(shape.element_type());
  config.SetCollectiveOpKindAndID(op);
  config.group_mode = GetGroupMode(op);

  // With a collective permute, all execution instances together form one
  // replica group.
  const int64_t num_participants =
      config.group_mode == CollectiveOpGroupMode::kCrossReplica
          ? replica_count
          : partition_count;
  config.replica_groups.emplace_back();
  ReplicaGroup& replica_group = config.replica_groups.front();
  for (int i = 0; i < num_participants; ++i) {
    replica_group.add_replica_ids(i);
  }

  const std::vector<std::pair<int64_t, int64_t>> source_target_pairs =
      ConvertNx2Attribute(op.getSourceTargetPairs()).ValueOrDie();

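  // For each (source, target) pair, record 'source' as the source of the
  // entry for 'target', and 'target' as the target of the entry for 'source'.
  // insert() keeps any existing entry, so an ID that appears both as a source
  // and as a target ends up with both fields populated.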
  for (const std::pair<int64_t, int64_t>& source_target : source_target_pairs) {
    int64_t source = source_target.first;
    int64_t target = source_target.second;

    collective_permute_config.id_to_source_target.insert({target, {}})
        .first->second.source = source;
    collective_permute_config.id_to_source_target.insert({source, {}})
        .first->second.target = target;
  }

  return collective_permute_config;
}

// The collective permute is degenerate if all source-target pairs are
// identity, and all the IDs appear in the list.
/*static*/ bool NcclCollectivePermuteThunk::IsDegenerate(
    mlir::lmhlo::CollectivePermuteOp op, int64_t replica_count,
    int64_t partition_count) {
  const std::vector<std::pair<int64_t, int64_t>> source_target_pairs =
      ConvertNx2Attribute(op.getSourceTargetPairs()).ValueOrDie();
  // Each ID can appear only once as a source and as a target. So if all pairs
  // are identity, all IDs appear in the list iff the list size equals the
  // number of replicas/partitions.
  const int64_t expected_size =
      op.getChannelId() ? partition_count : replica_count;
  return source_target_pairs.size() == expected_size &&
         absl::c_all_of(source_target_pairs,
                        [](const std::pair<int64_t, int64_t>& source_target) {
                          return source_target.first == source_target.second;
                        });
}

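// A collective permute can be lowered to NCCL only if NCCL supports the
// operand's element type.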
/*static*/ bool NcclCollectivePermuteThunk::CanImplement(
    mlir::lmhlo::CollectivePermuteOp op) {
  const Shape shape = GetShape(op.getOperand());
  return IsTypeSupportedByNccl(shape.element_type());
}

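// The config is computed once at thunk construction time; only the
// device-specific IDs are resolved when the thunk runs.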
NcclCollectivePermuteThunk::NcclCollectivePermuteThunk(
    ThunkInfo thunk_info, mlir::lmhlo::CollectivePermuteOp op,
    int64_t replica_count, int64_t partition_count, const Buffer& buffer)
    : NcclCollectiveThunk(Thunk::kCollectivePermute, thunk_info),
      config_(
          GetNcclCollectivePermuteConfig(op, replica_count, partition_count)),
      buffer_(buffer) {}

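// Resolves this device's participating ID, looks up its source/target peers
// in the precomputed map, and issues the corresponding NCCL operations.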
Status NcclCollectivePermuteThunk::RunNcclCollective(
    const ExecuteParams& params, ncclComm_t comm) {
  TF_ASSIGN_OR_RETURN(
      std::vector<DeviceBufferPair> device_buffers,
      ConvertToDeviceBuffers(params, {buffer_},
                             config_.config.operand_element_type));
  if (device_buffers.size() != 1)
    return FailedPrecondition("Expected a single input-output buffer pair.");

  TF_ASSIGN_OR_RETURN(const GlobalDeviceId global_device_id,
                      params.nccl_params.GetGlobalDeviceId());
  TF_ASSIGN_OR_RETURN(
      const DeviceAssignment::LogicalID current_logical_id,
      params.nccl_params.device_assn->LogicalIdForDevice(global_device_id));
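  // The participating ID of this device is its replica ID in cross-replica
  // mode and its partition (computation) ID otherwise.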
  const int64_t current_id =
      config_.config.group_mode == CollectiveOpGroupMode::kCrossReplica
          ? current_logical_id.replica_id
          : current_logical_id.computation_id;
  std::string device_string = GetDeviceString(params.nccl_params);

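  // Look up the (optional) source and target peers for this ID in the map
  // built at thunk construction time.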
  const NcclCollectivePermuteConfig::SourceTargetMapEntry source_target =
      NcclCollectivePermuteConfig::GetSourceTarget(config_.id_to_source_target,
                                                   current_id);

  return RunCollectivePermute(source_target, device_buffers[0], *params.stream,
                              comm, device_string, current_id);
}

Status RunCollectivePermute(
    NcclCollectivePermuteConfig::SourceTargetMapEntry source_target,
    DeviceBufferPair& buffer, se::Stream& stream, ncclComm_t comm,
    absl::string_view device_string, int64_t current_id) {
#if XLA_ENABLE_XCCL
  // Determine the source and target IDs for this instance. The source ID is
  // the ID that will copy its data to this instance. The target ID is the ID
  // to which this instance will copy its data. Either may be absent.
  //
  // No source and no target:
  //  - This instance does not actually participate: no one sends it any data
  //    and it does not have to send any data either. Since there is no
  //    source, just memzero() the destination buffer, as required by the
  //    collective permute semantics.
  //
  // No source, target present:
  //  - This instance has to send data to 'target', so issue a send of the
  //    input. Since there is no source, memzero() the destination buffer.
  //
  // Source present, no target:
  //  - This instance receives data from the source but does not have to send
  //    data to anyone, so issue only a receive.
  //
  // Source and target both present:
  //  - Issue a send of the input to the target and a receive of the output
  //    from the source.

  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Performing collective permute from device ordinal: "
          << device_ordinal;

  const std::optional<int64_t> source_id = source_target.source;
  const std::optional<int64_t> target_id = source_target.target;

  // NCCL 2.8.x has an issue with point-to-point communication primitives if
  // different ranks process different amounts of data. This can happen in a
  // collective permute because certain nodes may not do any sends or
  // receives, or may do only a send or only a receive. Sending and receiving
  // to self (an identity pair) causes the same imbalance. NCCL 2.8.x requires
  // NCCL_LAUNCH_MODE=PARALLEL to avoid these issues. See
  // https://docs.nvidia.com/deeplearning/nccl/release-notes/rel_2-8-4.html#rel_2-8-4
  if (!IsNcclLaunchModeParallel()) {
    static absl::once_flag log_once;
    absl::call_once(log_once, [] {
      LOG(WARNING) << "NCCL based collective permute may not work correctly if "
                      "NCCL_LAUNCH_MODE is not set to PARALLEL";
    });
  }

  se::DeviceMemoryBase src_addr = buffer.source_buffer;
  se::DeviceMemoryBase dest_addr = buffer.destination_buffer;

  VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d, target_id = %d",
                                device_string, current_id,
                                source_id.value_or(-1), target_id.value_or(-1));

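  // Group the send and receive so NCCL can issue them together; this avoids a
  // deadlock when a rank both sends and receives within the same permute.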
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());

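  // Map the XLA element type to a NCCL data type. The count multiplier
  // accounts for element types that NCCL transfers as more than one primitive
  // element per logical element (e.g. complex types).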
  TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                      ToNcclDataTypeAndCountMultiplier(buffer.element_type));
  ncclDataType_t dtype = dtype_and_multiplier.first;
  int element_count = buffer.element_count * dtype_and_multiplier.second;

  se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

  // Send the source buffer to the target peer if needed.
  if (target_id) {
    VLOG(3) << absl::StreamFormat(
        "%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
        "comm=%p, stream=%p)",
        device_string, src_addr.opaque(), element_count, *target_id,
        static_cast<const void*>(comm), gpu_stream);
    XLA_CUDA_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype,
                                      *target_id, comm, gpu_stream));
  }

  // Receive data from the source peer into the destination buffer if needed.
  if (source_id) {
    VLOG(3) << absl::StreamFormat(
        "%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, "
        "stream=%p)",
        device_string, dest_addr.opaque(), element_count, *source_id,
        static_cast<const void*>(comm), gpu_stream);
    XLA_CUDA_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype,
                                      *source_id, comm, gpu_stream));
  }
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  if (!source_id) {
    // If there is no source peer, i.e. no one sends us any data, zero out the
    // destination buffer.
    VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero",
                                  device_string);
    stream.ThenMemZero(&dest_addr, dest_addr.size());
  }
  return OkStatus();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla