1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TensorFlow collective Ops.""" 16from tensorflow.python.ops import gen_collective_ops 17 18 19def all_reduce(t, 20 group_size, 21 group_key, 22 instance_key, 23 merge_op='Add', 24 final_op='Id', 25 subdiv_offsets=(0,), 26 communication_hint='auto', 27 timeout=0): 28 """Reduces tensors collectively, across devices. 29 30 Args: 31 t: the tensor to be reduced. 32 group_size: the total number of tensors to be collectively reduced. 33 Each must reside on a different device. Should be a positive integer. 34 group_key: an integer identifying the group of devices. 35 instance_key: an integer identifying the participating group of Ops. 36 merge_op: string naming the binary Op to be applied to compute each 37 partial reduction. 38 final_op: string naming the unary Op to be applied to each fully 39 reduced value. Can be 'Id' for no operation. 40 subdiv_offsets: a list of integer offsets into the tensor at which each 41 independent subdivision should begin. Use [0] if no subdivision should 42 be done. 43 communication_hint: preferred collective communication. The implementation 44 may fall back to another mechanism. Options include `auto`, `ring`, and 45 `nccl`. 46 timeout: a float. If set to a non zero, set a completion timeout to detect 47 staleness. If the timer goes off, a DeadlineExceededError is raised. The 48 timeout value in seconds. This feature is experimental. 49 50 Returns: 51 An Op implementing the distributed reduction. 52 53 Raises: 54 ValueError: if any of the input parameter constraints are not met. 55 """ 56 if group_size < 1: 57 raise ValueError('Parameter `group_size` to all_reduce must be at least 1. ' 58 f'Received: {group_size}.') 59 return gen_collective_ops.collective_reduce( 60 t, 61 group_size=group_size, 62 group_key=group_key, 63 instance_key=instance_key, 64 merge_op=merge_op, 65 final_op=final_op, 66 subdiv_offsets=subdiv_offsets, 67 communication_hint=communication_hint.lower(), 68 timeout_seconds=timeout) 69 70 71def assign_group_v2(group_assignment, device_index, base_key): 72 """Assign group key based on group_assignment. 73 74 Args: 75 group_assignment: a 2 dimensional integer Tensor that encodes which devices 76 belong to the same group. The values are indices of the devices within 0 77 to number of devices. 78 device_index: integer for the index of the current device 79 base_key: integer to offset the resulted group_key. The base key shall be 80 unique for different values of group_assignment in the same tf.function. 81 Notes: The device_index argument must be consistent with the index of the 82 device of this Op in the device assignment list. The behavior of this Op is 83 undefined if they are inconsistent. 84 85 Returns: 86 group_size, group_key: The group size and group key for the current device. 87 """ 88 group_size, group_key = gen_collective_ops.collective_assign_group_v2( 89 group_assignment=group_assignment, 90 device_index=device_index, 91 base_key=base_key) 92 return group_size, group_key 93 94 95def all_reduce_v2(t, 96 group_size, 97 group_key, 98 instance_key, 99 merge_op='Add', 100 final_op='Id', 101 communication_hint='auto', 102 timeout=0, 103 ordering_token=None, 104 max_subdivs_per_device=-1, 105 name=None): 106 """Reduces tensors collectively, across devices. 107 108 Args: 109 t: the tensor to be reduced. 110 group_size: an int32 tensor. The total number of tensors to be collectively 111 reduced. Each must reside on a different device. Should be a positive 112 integer. 113 group_key: an int32 tensor identifying the group of devices. 114 instance_key: an int32 tensor identifying the participating group of Ops. 115 merge_op: string naming the binary Op to be applied to compute each partial 116 reduction. 117 final_op: string naming the unary Op to be applied to each fully reduced 118 value. Can be 'Id' for no operation. 119 communication_hint: preferred collective communication. The implementation 120 may fall back to another mechanism. Options include `auto`, `ring`, and 121 `nccl`. 122 timeout: a float. If set to a non zero, set a completion timeout to detect 123 staleness. If the timer goes off, a DeadlineExceededError is raised. The 124 timeout value in seconds. This feature is experimental. 125 ordering_token: a resource tensor on the same device as the op to order 126 the collectives in a per-device manner by auto control dependency. 127 This argument can be omited when there is one collective Op per 128 `tf.function`, or when explicit control dependency is used instead of 129 auto control dependency. 130 max_subdivs_per_device: int specifying the maximum number of subdivisions a 131 tensor on a device can be divided into. The runtime uses this contraint to 132 parallelize processing of each per-device tensor. Setting to -1 disables 133 subdivision and reverts to previous behavior of not sub-dividing tensor. 134 Setting to 0 uses sytem defaults. 135 name: name of the Op. 136 137 Returns: 138 An Op implementing the distributed reduction. 139 """ 140 if ordering_token is not None: 141 ordering_token = [ordering_token] 142 else: 143 ordering_token = [] 144 145 return gen_collective_ops.collective_reduce_v2( 146 t, 147 group_size=group_size, 148 group_key=group_key, 149 instance_key=instance_key, 150 merge_op=merge_op, 151 final_op=final_op, 152 communication_hint=communication_hint.lower(), 153 timeout_seconds=timeout, 154 ordering_token=ordering_token, 155 max_subdivs_per_device=max_subdivs_per_device, 156 name=name) 157 158 159def all_gather(t, 160 group_size, 161 group_key, 162 instance_key, 163 communication_hint='auto', 164 timeout=0): 165 """Accumulates tensors collectively, across devices, along first dimension. 166 167 Args: 168 t: the tensor to participate in the accumulation. 169 group_size: the total number of tensors to be collectively accumulated. 170 Each must reside on a different device. Should be a positive integer. 171 group_key: an integer identifying the group of devices. 172 instance_key: an integer identifying the participating group of Ops. 173 communication_hint: preferred collective communication. The implementation 174 may fall back to another mechanism. Options include `auto`, `ring`, and 175 `nccl`. 176 timeout: a float. If set to a non zero, set a completion timeout to detect 177 staleness. If the timer goes off, a DeadlineExceededError is raised. The 178 timeout value in seconds. This feature is experimental. 179 180 Returns: 181 An Op implementing the distributed operation. 182 183 Raises: 184 ValueError: if any of the input parameter constraints are not met. 185 """ 186 if group_size < 1: 187 raise ValueError('Parameter `group_size` to all_gather must be at least 1.' 188 f' Received: {group_size}.') 189 return gen_collective_ops.collective_gather( 190 t, 191 shape=[0], 192 group_size=group_size, 193 group_key=group_key, 194 instance_key=instance_key, 195 communication_hint=communication_hint.lower(), 196 timeout_seconds=timeout) 197 198 199def all_gather_v2(t, 200 group_size, 201 group_key, 202 instance_key, 203 communication_hint='auto', 204 timeout=0, 205 ordering_token=None, 206 name=None): 207 """Accumulates tensors collectively, across devices, along first dimension. 208 209 Args: 210 t: the tensor to participate in the accumulation. 211 group_size: an int32 tensor, the total number of tensors to be collectively 212 accumulated. Each must reside on a different device. Should be a positive 213 integer. 214 group_key: an int32 tensor identifying the group of devices. 215 instance_key: an int32 tensor identifying the participating group of Ops. 216 communication_hint: preferred collective communication. The implementation 217 may fall back to another mechanism. Options include `auto`, `ring`, and 218 `nccl`. 219 timeout: a float. If set to a non zero, set a completion timeout to detect 220 staleness. If the timer goes off, a DeadlineExceededError is raised. The 221 timeout value in seconds. This feature is experimental. 222 ordering_token: a resource tensor on the same device as the op to order 223 the collectives in a per-device manner by auto control dependency. 224 This argument can be omited when there is one collective Op per 225 `tf.function`, or when explicit control dependency is used instead of 226 auto control dependency. 227 name: name of the Op. 228 229 Returns: 230 An Op implementing the distributed operation. 231 """ 232 if ordering_token is not None: 233 ordering_token = [ordering_token] 234 else: 235 ordering_token = [] 236 237 return gen_collective_ops.collective_gather_v2( 238 t, 239 group_size=group_size, 240 group_key=group_key, 241 instance_key=instance_key, 242 communication_hint=communication_hint.lower(), 243 timeout_seconds=timeout, 244 ordering_token=ordering_token, 245 name=name) 246 247 248def broadcast_send(t, 249 shape, 250 dtype, 251 group_size, 252 group_key, 253 instance_key, 254 communication_hint='auto', 255 timeout=0): 256 """Broadcasts one tensor to a group of others, across devices. 257 258 Args: 259 t: the tensor to be sent. 260 shape: the shape of the tensor being sent, which must agree with t. 261 dtype: the type of the tensor being sent, which must agree with t. 262 group_size: one plus the number of receiving tensors, i.e. the total 263 number of devices participating. Each tensor must reside on a 264 different device. 265 group_key: an integer identifying the group of devices. 266 instance_key: an integer identifying the participating group of Ops. 267 communication_hint: preferred collective communication. The implementation 268 may fall back to another mechanism. Options include `auto`, `ring`, and 269 `nccl`. 270 timeout: If set to a non zero, set a completion timeout to detect staleness. 271 If the timer goes off, a DeadlineExceededError is raised. 272 The timeout value in seconds. This feature is experimental. 273 274 Returns: 275 An Op implementing the distributed broadcast send. 276 277 Raises: 278 ValueError: if any of the input parameter constraints are not met. 279 280 Note that the shape and dtype arguments appear redundant since they 281 should be obtainable from t. The are two reasons for including 282 them. First, the shape and type of tensors passed via broadcast must 283 be known ahead of time in their most specific form so that the receive 284 side can allocate memory for the operation and shape/type inference can 285 carry forward from there. Including the same declarations on the 286 send side clarifies a commitment already made. Secondly, having nearly 287 identical use syntax for send and receive sides may simplify tool-driven 288 generation of broadcast. 289 """ 290 if group_size <= 1: 291 raise ValueError( 292 'Parameter `group_size` to broadcast_send must be at least 2. ' 293 f'Received: {group_size}.') 294 if t.shape != shape: 295 raise ValueError( 296 'Shape of broadcast_send tensor `t` not equal to declared shape. ' 297 f'Received {t.shape}, expected {shape}.') 298 if t.dtype != dtype: 299 raise ValueError( 300 'Type of broadcast_send tensor `t` not equal to declared type. ' 301 f'Received {t.dtype}, expected {dtype}.') 302 return gen_collective_ops.collective_bcast_send( 303 t, 304 shape=shape, 305 group_size=group_size, 306 group_key=group_key, 307 instance_key=instance_key, 308 communication_hint=communication_hint.lower(), 309 timeout_seconds=timeout) 310 311 312def broadcast_send_v2(t, 313 group_size, 314 group_key, 315 instance_key, 316 communication_hint='auto', 317 timeout=0): 318 """Broadcasts one tensor to a group of others, across devices. 319 320 Args: 321 t: the tensor to be sent. 322 group_size: an int32 tensor. One plus the number of receiving tensors, i.e. 323 the total number of devices participating. Each tensor must reside on a 324 different device. 325 group_key: an int32 tensor identifying the group of devices. 326 instance_key: an int32 tensor identifying the participating group of Ops. 327 communication_hint: preferred collective communication. The implementation 328 may fall back to another mechanism. Options include `auto`, `ring`, and 329 `nccl`. 330 timeout: If set to a non zero, set a completion timeout to detect staleness. 331 If the timer goes off, a DeadlineExceededError is raised. 332 The timeout value in seconds. This feature is experimental. 333 334 Returns: 335 An Op implementing the distributed broadcast send. 336 """ 337 return gen_collective_ops.collective_bcast_send_v2( 338 t, 339 group_size=group_size, 340 group_key=group_key, 341 instance_key=instance_key, 342 communication_hint=communication_hint.lower(), 343 timeout_seconds=timeout) 344 345 346def broadcast_recv(shape, 347 dtype, 348 group_size, 349 group_key, 350 instance_key, 351 communication_hint='auto', 352 timeout=0): 353 """Receives a broadcasts tensor, across devices. 354 355 Args: 356 shape: Shape of the tensor to be received. 357 dtype: Type of the tensor to be received. 358 group_size: one plus the number of receiving tensors, i.e. the total 359 number of devices participating. Each tensor must reside on a 360 different device. 361 group_key: an integer identifying the group of devices. 362 instance_key: an integer identifying the participating group of Ops. 363 communication_hint: preferred collective communication. The implementation 364 may fall back to another mechanism. Options include `auto`, `ring`, and 365 `nccl`. 366 timeout: If set to a non zero, set a completion timeout to detect staleness. 367 If the timer goes off, a DeadlineExceededError is raised. 368 The timeout value in seconds. This feature is experimental. 369 370 Returns: 371 An Op implementing the broadcast receive. 372 373 Raises: 374 ValueError: if any of the input parameter constraints are not met. 375 """ 376 if group_size <= 1: 377 raise ValueError( 378 'Parameter `group_size` to broadcast_send must be at least 2. ' 379 f'Received: {group_size}.') 380 return gen_collective_ops.collective_bcast_recv( 381 shape=shape, 382 T=dtype, 383 group_size=group_size, 384 group_key=group_key, 385 instance_key=instance_key, 386 communication_hint=communication_hint.lower(), 387 timeout_seconds=timeout) 388 389 390def broadcast_recv_v2(shape, 391 dtype, 392 group_size, 393 group_key, 394 instance_key, 395 communication_hint='auto', 396 timeout=0): 397 """Receives a broadcasts tensor, across devices. 398 399 Args: 400 shape: an int tensor. Shape of the tensor to be received. 401 dtype: Type of the tensor to be received. 402 group_size: an int32 tensor. One plus the number of receiving tensors, i.e. 403 the total number of devices participating. Each tensor must reside on a 404 different device. 405 group_key: an int32 tensor identifying the group of devices. 406 instance_key: an int32 tensor identifying the participating group of Ops. 407 communication_hint: preferred collective communication. The implementation 408 may fall back to another mechanism. Options include `auto`, `ring`, and 409 `nccl`. 410 timeout: If set to a non zero, set a completion timeout to detect staleness. 411 If the timer goes off, a DeadlineExceededError is raised. 412 The timeout value in seconds. This feature is experimental. 413 414 Returns: 415 An Op implementing the broadcast receive. 416 """ 417 return gen_collective_ops.collective_bcast_recv_v2( 418 T=dtype, 419 group_size=group_size, 420 group_key=group_key, 421 instance_key=instance_key, 422 shape=shape, 423 communication_hint=communication_hint.lower(), 424 timeout_seconds=timeout) 425 426 427def initialize_communicator(group_key, 428 rank, 429 group_size, 430 communication_hint='auto', 431 timeout_seconds=0): 432 """Initializes a collective communicator. 433 434 This creates a collective communicator, which represents membership to a 435 collective group identified by the group_key. It should be called once per 436 member of the group, and each member needs to be on a different device. 437 It blocks until all members of the group run this op. 438 439 Communicators of a group can only be initialized once. Trying to initialize 440 communicators for an existing group key will result in an error. 441 442 Args: 443 group_key: an int32 `tf.Tensor` identifying the group. 444 rank: an `tf.Tensor` specifying the rank of this device in the group. If 445 specified, the rank is required to be unique in the group. 446 group_size: an int32 `tf.Tensor`. The size of the group. 447 communication_hint: preferred collective communication. The implementation 448 may fall back to another mechanism. Options include `auto`, `ring`, and 449 `nccl`. 450 timeout_seconds: If set to a non zero, set a completion timeout to detect 451 staleness. If the timer goes off, a DeadlineExceededError is raised. The 452 timeout value in seconds. This feature is experimental. 453 454 455 Returns: 456 A resource `tf.Tensor`. 457 """ 458 return gen_collective_ops.collective_initialize_communicator( 459 group_key=group_key, 460 rank=rank, 461 group_size=group_size, 462 communication_hint=communication_hint, 463 timeout_seconds=timeout_seconds) 464 465 466def all_reduce_v3(communicator, 467 t, 468 reduction='Add', 469 group_assignment=None, 470 timeout_seconds=None): 471 """Reduces tensors mutually. 472 473 Args: 474 communicator: the resource `tf.Tensor` returned from 475 `initialize_communicator`. 476 t: the `tf.Tensor` to be reduced. 477 reduction: a string. The name of the operation to reduce the values. 478 Accpeted values are `"min"`, `"max"`, `"mul"`, `"add"`. 479 group_assignment: Optional int32 `tf.Tensor` with shape [num_groups, 480 num_ranks_per_group]. `group_assignment[i]` represents the ranks in the 481 `ith` subgroup. 482 timeout_seconds: If set to a non zero, set a completion timeout to detect 483 staleness. If the timer goes off, a DeadlineExceededError is raised. The 484 timeout value in seconds. This feature is experimental. 485 486 Returns: 487 The reduced `tf.Tensor`. 488 """ 489 if group_assignment is None: 490 group_assignment = [] 491 return gen_collective_ops.collective_reduce_v3( 492 communicator=communicator, 493 input=t, 494 group_assignment=group_assignment, 495 reduction=reduction, 496 timeout_seconds=timeout_seconds) 497 498 499def all_to_all_v3(communicator, t, group_assignment=None, timeout_seconds=None): 500 """Exchanges tensors mutually. 501 502 Args: 503 communicator: the resource `tf.Tensor` returned from 504 `initialize_communicator`. 505 t: a `tf.Tensor`. The first dimension should have the length as the size of 506 the group. `t[i]` is sent to `rank i` within the group. 507 group_assignment: Optional int32 `tf.Tensor` with shape [num_groups, 508 num_ranks_per_group]. `group_assignment[i]` represents the ranks in the 509 `ith` subgroup. 510 timeout_seconds: If set to a non zero, set a completion timeout to detect 511 staleness. If the timer goes off, a DeadlineExceededError is raised. The 512 timeout value in seconds. This feature is experimental. 513 514 Returns: 515 a `tf.Tensor`. `t[i]` is sent from `rank i` within the group. 516 """ 517 if group_assignment is None: 518 group_assignment = [] 519 return gen_collective_ops.collective_all_to_all_v3( 520 communicator=communicator, 521 input=t, 522 group_assignment=group_assignment, 523 timeout_seconds=timeout_seconds) 524