/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"

#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/call_once.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

#if XLA_ENABLE_XCCL
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#endif  // XLA_ENABLE_XCCL

namespace xla {
namespace gpu {

/*static*/ NcclCollectivePermuteConfig
NcclCollectivePermuteThunk::GetNcclCollectivePermuteConfig(
    mlir::lmhlo::CollectivePermuteOp op, int64_t replica_count,
    int64_t partition_count) {
  NcclCollectivePermuteConfig collective_permute_config;
  auto& config = collective_permute_config.config;

  config.operand_count = 1;
  const Shape shape = GetShape(op.getOperand());
  config.operand_element_type.push_back(shape.element_type());
  config.SetCollectiveOpKindAndID(op);
  config.group_mode = GetGroupMode(op);

  // With a collective permute, all execution instances together form one
  // replica group.
  const int64_t num_participants =
      config.group_mode == CollectiveOpGroupMode::kCrossReplica
          ? replica_count
          : partition_count;
  config.replica_groups.emplace_back();
  ReplicaGroup& replica_group = config.replica_groups.front();
  for (int i = 0; i < num_participants; ++i) {
    replica_group.add_replica_ids(i);
  }

  const std::vector<std::pair<int64_t, int64_t>> source_target_pairs =
      ConvertNx2Attribute(op.getSourceTargetPairs()).ValueOrDie();

  for (const std::pair<int64_t, int64_t>& source_target : source_target_pairs) {
    int64_t source = source_target.first;
    int64_t target = source_target.second;

    collective_permute_config.id_to_source_target.insert({target, {}})
        .first->second.source = source;
    collective_permute_config.id_to_source_target.insert({source, {}})
        .first->second.target = target;
  }
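  // Example: source_target_pairs = {{0, 1}, {1, 2}} produces
  //   id_to_source_target[0] = {source: nullopt, target: 1}
  //   id_to_source_target[1] = {source: 0,       target: 2}
  //   id_to_source_target[2] = {source: 1,       target: nullopt}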

  return collective_permute_config;
}

// The collective permute is degenerate if all source-target pairs are
// identity, and all the IDs appear in the list.
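// For example, with replica_count = 4, {{0, 0}, {1, 1}, {2, 2}, {3, 3}} is
// degenerate; {{0, 0}, {1, 1}} is not (IDs 2 and 3 are missing), and
// {{0, 1}, {1, 0}, {2, 2}, {3, 3}} is not (non-identity pairs).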
/*static*/ bool NcclCollectivePermuteThunk::IsDegenerate(
    mlir::lmhlo::CollectivePermuteOp op, int64_t replica_count,
    int64_t partition_count) {
  const std::vector<std::pair<int64_t, int64_t>> source_target_pairs =
      ConvertNx2Attribute(op.getSourceTargetPairs()).ValueOrDie();
  // Each ID can appear only once as a source and as a target. So if all pairs
  // are identity, all IDs appear in the list exactly when its size equals the
  // number of replicas/partitions.
  const int64_t expected_size =
      op.getChannelId() ? partition_count : replica_count;
  return source_target_pairs.size() == expected_size &&
         absl::c_all_of(source_target_pairs,
                        [](const std::pair<int64_t, int64_t>& source_target) {
                          return source_target.first == source_target.second;
                        });
}

/*static*/ bool NcclCollectivePermuteThunk::CanImplement(
    mlir::lmhlo::CollectivePermuteOp op) {
  const Shape shape = GetShape(op.getOperand());
  return IsTypeSupportedByNccl(shape.element_type());
}

NcclCollectivePermuteThunk::NcclCollectivePermuteThunk(
    ThunkInfo thunk_info, mlir::lmhlo::CollectivePermuteOp op,
    int64_t replica_count, int64_t partition_count, const Buffer& buffer)
    : NcclCollectiveThunk(Thunk::kCollectivePermute, thunk_info),
      config_(
          GetNcclCollectivePermuteConfig(op, replica_count, partition_count)),
      buffer_(buffer) {}

Status NcclCollectivePermuteThunk::RunNcclCollective(
    const ExecuteParams& params, ncclComm_t comm) {
  TF_ASSIGN_OR_RETURN(
      std::vector<DeviceBufferPair> device_buffers,
      ConvertToDeviceBuffers(params, {buffer_},
                             config_.config.operand_element_type));
  if (device_buffers.size() != 1)
    return FailedPrecondition("Expected a single input-output buffer pair.");

  TF_ASSIGN_OR_RETURN(const GlobalDeviceId global_device_id,
                      params.nccl_params.GetGlobalDeviceId());
  TF_ASSIGN_OR_RETURN(
      const DeviceAssignment::LogicalID current_logical_id,
      params.nccl_params.device_assn->LogicalIdForDevice(global_device_id));
  const int64_t current_id =
      config_.config.group_mode == CollectiveOpGroupMode::kCrossReplica
          ? current_logical_id.replica_id
          : current_logical_id.computation_id;
  std::string device_string = GetDeviceString(params.nccl_params);

  const NcclCollectivePermuteConfig::SourceTargetMapEntry source_target =
      NcclCollectivePermuteConfig::GetSourceTarget(config_.id_to_source_target,
                                                   current_id);

  return RunCollectivePermute(source_target, device_buffers[0], *params.stream,
                              comm, device_string, current_id);
}

Status RunCollectivePermute(
    NcclCollectivePermuteConfig::SourceTargetMapEntry source_target,
    DeviceBufferPair& buffer, se::Stream& stream, ncclComm_t comm,
    absl::string_view device_string, int64_t current_id) {
#if XLA_ENABLE_XCCL
  // Determine the source and target IDs for this instance. The source ID is
  // the ID which will copy its data to this instance. The target ID is the ID
  // to which this instance will copy its data. Both are optional.
  //
  // No source and no target:
  //  - This instance does not actually participate: no one sends it any data
  //    and it does not have to send any data either. Since there is no source,
  //    just memzero() the destination buffer as required by the collective
  //    permute semantics.
  //
  // No source, target present:
  //  - This instance has to send data to 'target'. Issue a send of the input.
  //    Since there is no source, memzero the destination buffer.
  //
  // Source present, no target:
  //  - This instance receives data from the source but does not have to send
  //    data to anyone. Issue a receive.
  //
  // Source and target both present:
  //  - Issue a send of the input to the target and a receive for the output
  //    from the source.

  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Performing collective permute from device ordinal: "
          << device_ordinal;

  const std::optional<int64_t> source_id = source_target.source;
  const std::optional<int64_t> target_id = source_target.target;

  // NCCL 2.8.x has an issue with point-to-point communication primitives if
  // different ranks process different amounts of data. This can happen in a
  // collective permute because certain nodes may not do any sends or receives,
  // or may only send or only receive. Sending and receiving to self (an
  // identity pair) also causes this imbalance. NCCL 2.8.x requires
  // NCCL_LAUNCH_MODE=PARALLEL to avoid these issues. See
  // https://docs.nvidia.com/deeplearning/nccl/release-notes/rel_2-8-4.html#rel_2-8-4
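  // A typical fix is to set the environment variable when launching the
  // process, e.g. NCCL_LAUNCH_MODE=PARALLEL <launch command>.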
  if (!IsNcclLaunchModeParallel()) {
    static absl::once_flag log_once;
    absl::call_once(log_once, [] {
      LOG(WARNING) << "NCCL based collective permute may not work correctly if "
                      "NCCL_LAUNCH_MODE is not set to PARALLEL";
    });
  }

  se::DeviceMemoryBase src_addr = buffer.source_buffer;
  se::DeviceMemoryBase dest_addr = buffer.destination_buffer;

  VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d, target_id = %d",
                                device_string, current_id,
                                source_id.value_or(-1), target_id.value_or(-1));

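  // Group the send and receive below into a single fused NCCL operation so
  // that the point-to-point calls of all participants are issued together
  // rather than serialized against each other.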
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());

  TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                      ToNcclDataTypeAndCountMultiplier(buffer.element_type));
  ncclDataType_t dtype = dtype_and_multiplier.first;
  int element_count = buffer.element_count * dtype_and_multiplier.second;
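  // NCCL has no complex data types, so complex buffers are transferred as
  // twice as many elements of the underlying real type; the count multiplier
  // accounts for this (it is 1 for ordinary element types).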

  se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

  // Send the source buffer to the target peer if needed.
  if (target_id) {
    VLOG(3) << absl::StreamFormat(
        "%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
        "comm=%p, stream=%p)",
        device_string, src_addr.opaque(), element_count, *target_id,
        static_cast<const void*>(comm), gpu_stream);
    XLA_CUDA_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype,
                                      *target_id, comm, gpu_stream));
  }

  // Receive data from the source peer into the destination buffer.
  if (source_id) {
    VLOG(3) << absl::StreamFormat(
        "%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, "
        "stream=%p)",
        device_string, dest_addr.opaque(), element_count, *source_id,
        static_cast<const void*>(comm), gpu_stream);
    XLA_CUDA_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype,
                                      *source_id, comm, gpu_stream));
  }
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  if (!source_id) {
    // If there is no source peer, i.e. no one sends us any data, zero out the
    // destination buffer as required by the collective permute semantics.
    VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero",
                                  device_string);
    stream.ThenMemZero(&dest_addr, dest_addr.size());
  }
  return OkStatus();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla