# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ShardedVariable class."""
import copy
import math
from typing import Sequence
import weakref

import numpy as np

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices as indexed_slices_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class.

  Partitioners should implement a `__call__` method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given `shape` and returns the partition results.
    # See docstring of `__call__` method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given `shape` and returns the partition results.

    Example of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions)  # [2, 1]
    ```

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: a `tf.dtypes.Dtype` indicating the type of the partition value.
      axis: The axis to partition along. Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where the i-th value corresponds to the i-th axis.
    """
    raise NotImplementedError
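

# Illustrative sketch (not part of the TensorFlow API): a custom partitioner
# only needs to subclass `Partitioner` and implement `__call__`. The
# `HalvingPartitioner` name below is hypothetical, and the snippet assumes
# `tf` is TensorFlow 2.x:
#
#   class HalvingPartitioner(Partitioner):
#
#     def __call__(self, shape, dtype, axis=0):
#       del dtype  # Unused: this sketch partitions purely by shape.
#       result = [1] * len(shape)
#       result[axis] = min(2, shape.dims[axis].value)
#       return result
#
#   HalvingPartitioner()(tf.TensorShape([10, 3]), tf.float32)  # ==> [2, 1]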


@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> partitions
  [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)

  """

  def __init__(self, num_shards):
    """Creates a new `FixedShardsPartitioner`.

    Args:
      num_shards: `int`, number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype
    result = [1] * len(shape)
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result


@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least `min_shard_bytes`, and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
  `max_shards`.

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new `MinSizePartitioner`.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if min_shard_bytes < 1:
      raise ValueError('Argument `min_shard_bytes` must be positive. '
                       f'Received: {min_shard_bytes}')
    if max_shards < 1:
      raise ValueError('Argument `max_shards` must be positive. '
                       f'Received: {max_shards}')
    if bytes_per_string < 1:
      raise ValueError('Argument `bytes_per_string` must be positive. '
                       f'Received: {bytes_per_string}')
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)
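

# Worked example of the byte arithmetic behind `MinSizePartitioner`
# (illustrative; assumes `tf` is TensorFlow 2.x): a float32 variable of shape
# [6, 1] occupies 6 * 1 * 4 = 24 bytes. With `min_shard_bytes=4`, up to
# 24 / 4 = 6 shards still satisfy the minimum size, so the shard count is 6
# capped by `max_shards`:
#
#   >>> MinSizePartitioner(min_shard_bytes=4, max_shards=2)(
#   ...     tf.TensorShape([6, 1]), tf.float32)
#   [2, 1]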


@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below `max_shard_bytes`.

  This partitioner ensures each shard has at most `max_shard_bytes`, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new `MaxSizePartitioner`.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: `int`, the maximum number of shards to create; takes
        precedence over `max_shard_bytes`.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if max_shard_bytes < 1:
      raise ValueError('Argument `max_shard_bytes` must be positive. '
                       f'Received: {max_shard_bytes}')
    if max_shards and max_shards < 1:
      raise ValueError('Argument `max_shards` must be positive. '
                       f'Received: {max_shards}')
    if bytes_per_string < 1:
      raise ValueError('Argument `bytes_per_string` must be positive. '
                       f'Received: {bytes_per_string}')

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)


class ShardedVariableSpec(type_spec.TypeSpec):
  """Type specification for a `ShardedVariable`."""

  __slots__ = ['_variable_specs']

  value_type = property(lambda self: ShardedVariable)

  def __init__(self, *variable_specs):
    self._variable_specs = tuple(variable_specs)

  def _serialize(self):
    return self._variable_specs

  @property
  def _component_specs(self):
    return self._variable_specs

  def _to_components(self, value):
    return value.variables

  def _from_components(self, variables):
    return ShardedVariable(variables)


class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  # TODO(b/170877138): Remove this mixin once fixed. This mixin is required
  # since TPUEmbeddingVariable can't be a CompositeTensor.

  def __init__(self, variables, name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariableMixin(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this sharded
        variable. Variables should not be shared between different
        `ShardedVariableMixin` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    if not isinstance(variables, Sequence) or not variables or any(
        not isinstance(v, variables_lib.Variable) for v in variables):
      raise TypeError('Argument `variables` should be a non-empty list of '
                      f'`variables.Variable`s. Received {variables}')

    var_dtypes = {v.dtype for v in variables}
    if len(var_dtypes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same dtype. '
          f'Received dtypes: {[v.dtype for v in variables]}')

    first_var = variables[0]
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    if len(higher_dim_shapes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same shapes '
          'except for the first axis. '
          f'Received shapes: {[v.shape for v in variables]}')
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] +
                                           first_var.shape.as_list()[1:])

    for v in variables:
      v._sharded_container = weakref.ref(self)

    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
    if any(slice_info is not None for slice_info in save_slice_info):
      raise ValueError(
          '`SaveSliceInfo` should not be set for any element in argument '
          '`variables`. `ShardedVariable` will infer `SaveSliceInfo` according '
          'to the order of the elements in `variables`. '
          f'Received save slice info {save_slice_info}')

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name,
        trainable=self._variables[0].trainable,
        synchronization=variables_lib.VariableSynchronization.NONE,
        aggregation=variables_lib.VariableAggregation.NONE)

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to `Tensor.__getitem__`. Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing
        of the sharded variable.

    Returns:
      The appropriate slice of tensor based on `slice_spec`.

    Raises:
      IndexError: If a slice index is out of bound.
      TypeError: If `slice_spec` contains a Tensor.
    """

    # TODO(b/177482728): Support tensor input.
    # TODO(b/177482728): Support slice assign, similar to variable slice assign.

    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if not isinstance(slice_spec, (list, tuple)):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
      if s < 0 or s >= self._shape[0]:
        raise IndexError(
            f'ShardedVariable: slice index {s} of dimension 0 out of bounds.')
      for i in range(len(self._variables)):
        if i == len(self._variables) - 1 or (
            s >= self._var_offsets[i][0] and s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]
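
  # Usage sketch (illustrative; assumes `tf` is TensorFlow 2.x and eager
  # execution): shards are concatenated along the first axis, and global
  # indices and slices are translated to per-shard ones, so indexing can span
  # shard boundaries transparently.
  #
  #   sv = ShardedVariable([tf.Variable([0., 1., 2.]),
  #                         tf.Variable([3., 4., 5.]),
  #                         tf.Variable([6., 7., 8., 9.])])
  #   sv.shape.as_list()    # ==> [10]
  #   sv._var_offsets       # ==> [[0], [3], [6]]
  #   sv[-1].numpy()        # ==> 9.0
  #   sv[2:8:3].numpy()     # ==> [2., 5.]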

  def _decompose_slice_spec(self, slice_spec):
    """Decomposes a global slice_spec into a list of per-variable slice_specs.

    `ShardedVariable` only supports first-dimension partitioning, thus
    `slice_spec` must be for the first dimension.

    Args:
      slice_spec: A python `slice` object that specifies the global slicing.

    Returns:
      A list of python `slice` objects or None specifying the local slicing for
      each component variable. None means no slicing.

      For example, given component variables:
        v0 = [0, 1, 2]
        v1 = [3, 4, 5]
        v2 = [6, 7, 8, 9]

      If `slice_spec` is slice(start=None, stop=None, step=None), we will have:
        v0[returned[0]] = [0, 1, 2]
        v1[returned[1]] = [3, 4, 5]
        v2[returned[2]] = [6, 7, 8, 9]
      If `slice_spec` is slice(start=2, stop=8, step=3), we will have:
        v0[returned[0]] = [2]
        v1[returned[1]] = [5]
        returned[2] == None
      If `slice_spec` is slice(start=9, stop=3, step=-2), we will have:
        returned[0] == None
        v1[returned[1]] = [5]
        v2[returned[2]] = [9, 7]
    """
    if isinstance(slice_spec.start, ops.Tensor) or isinstance(
        slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor):
      raise TypeError(
          'ShardedVariable: using Tensor in slice_spec is not allowed. Please '
          'file a feature request with the TensorFlow team.')

    result = []
    # Normalize start, stop and step.
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative index, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor (`cur`) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local_spec is determined
    # by using whatever is smaller between the global slice end and the
    # variable range end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      result.reverse()

    return result

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object."""
    if save_context.in_save_context():
      return [self._saving_variable]
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def assign(self, value, use_locking=None, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign(array_ops.slice(value, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign_add(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign_sub(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

  def _decompose_indices(self, indices):
    """Decomposes global 1D indices into a list of per-variable indices."""
    if indices.shape.rank != 1:
      raise ValueError(
          'ShardedVariable: indices must be a 1D Tensor for sparse operations. '
          f'Received shape: {indices.shape}')

    base = self._shape[0] // len(self._variables)
    extra = self._shape[0] % len(self._variables)

    # Assert that sharding conforms to "div" sharding.
    expect_first_dim = [base] * len(self._variables)
    for i in range(extra):
      expect_first_dim[i] = expect_first_dim[i] + 1
    actual_first_dim = [v.shape.as_list()[0] for v in self._variables]
    if expect_first_dim != actual_first_dim:
      raise NotImplementedError(
          'scatter_xxx ops are not supported in ShardedVariable that does not '
          'conform to "div" sharding')

    # For an index that falls into a partition with one extra element, the
    # assignment is `index // (base + 1)`, which is no less than
    # `(index - extra) // base`.
    # For an index that falls into a partition without an extra element, the
    # assignment is `(index - extra) // base`, which is no less than
    # `index // (base + 1)`.
    #
    # Example:
    #   base = 10, extra = 2, partitions: [0, 11), [11, 22), [22, 32)
    #   index = 10 -> partition_assignment = 0
    #   index = 22 -> partition_assignment = 2
    partition_assignments = math_ops.maximum(indices // (base + 1),
                                             (indices - extra) // base)
    local_indices = array_ops.where(partition_assignments < extra,
                                    indices % (base + 1),
                                    (indices - extra) % base)
    # `dynamic_partition` only supports int32 partition assignments.
    partition_assignments = math_ops.cast(partition_assignments, dtypes.int32)
    per_var_indices = data_flow_ops.dynamic_partition(local_indices,
                                                      partition_assignments,
                                                      len(self._variables))

    return per_var_indices, partition_assignments

  def _decompose_indexed_slices(self, indexed_slices):
    """Decomposes a global `IndexedSlices` into a list of per-variable ones."""
    per_var_indices, partition_assignments = self._decompose_indices(
        indexed_slices.indices)
    per_var_values = data_flow_ops.dynamic_partition(indexed_slices.values,
                                                     partition_assignments,
                                                     len(self._variables))

    return [
        indexed_slices_lib.IndexedSlices(
            values=per_var_values[i], indices=per_var_indices[i])
        for i in range(len(self._variables))
    ]
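
  # Worked example of the "div" decomposition above (illustrative): with 3
  # shards and a first dimension of 32, base = 10 and extra = 2, so shard
  # sizes are [11, 11, 10] and the partitions cover [0, 11), [11, 22) and
  # [22, 32). For global indices [10, 11, 22]:
  #   partition_assignments = max([10, 11, 22] // 11, ([10, 11, 22] - 2) // 10)
  #                         = [0, 1, 2]
  #   local_indices         = [10 % 11, 11 % 11, (22 - 2) % 10] = [10, 0, 0]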

  # ==================== scatter ops implementations ======================== #

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_add."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_add(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_div."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_div(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_max."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_max(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_min."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_min(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_mul."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_mul(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_sub."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_sub(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_update."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_update(per_var_sparse_delta[i], name=new_name)
    return self

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.batch_scatter_update."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.batch_scatter_update(per_var_sparse_delta[i], name=new_name)
    return self

  # ================== scatter ops implementations END ====================== #

  def sparse_read(self, indices, name=None):
    """Implements tf.Variable.sparse_read."""
    per_var_indices, _ = self._decompose_indices(indices)
    result = []
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      result.append(v.sparse_read(per_var_indices[i], name=new_name))
    return array_ops.concat(result, axis=0)

  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`."""

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`."""
      saveables = []
      dims = len(self._variables[0].shape)
      var_offset = [0 for _ in range(dims)]
      for v in self._variables:
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            var_offset=copy.copy(var_offset),
            var_shape=v.shape.as_list())
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name))
        var_offset[0] += int(v.shape[0])
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    obj_map, resource_map = {}, {}
    for v in self._variables + [self._saving_variable]:
      v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
      obj_map.update(v_obj_map)
      resource_map.update(v_resource_map)
    obj_map[self] = ShardedVariable([obj_map[self._saving_variable]],
                                    name=self.name)
    return obj_map, resource_map

  @property
  def _unique_id(self):
    # String-replace to ensure uniqueness for checkpoint tracking.
    return self.variables[0]._unique_id.replace('part_0', 'sharded')  # pylint: disable=protected-access

  @property
  def _distribute_strategy(self):
    return self.variables[0]._distribute_strategy  # pylint: disable=protected-access

  @property
  def _shared_name(self):
    return self._name

  @property
  def is_sharded_variable(self):
    return True

  def numpy(self):
    """Copies the values in this ShardedVariable to a NumPy array.

    First converts to a single Tensor using the registered conversion function,
    which concatenates the shards, then uses Tensor.numpy() to convert to
    a NumPy array.

    Returns:
      A NumPy array of the same shape and dtype.
    """
    return _var_to_tensor(self).numpy()
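

# Usage sketch for the sparse ops above (illustrative; assumes `tf` is
# TensorFlow 2.x, eager execution, and "div"-style shard sizes): scatter and
# sparse-read calls take global indices and are routed to the right shards.
#
#   sv = ShardedVariable([tf.Variable([1., 2., 3.]), tf.Variable([4., 5.])])
#   sv.scatter_update(tf.IndexedSlices(tf.constant([10., 40.]),
#                                      tf.constant([0, 3])))
#   [v.numpy() for v in sv.variables]             # ==> [[10., 2., 3.], [40., 5.]]
#   sv.sparse_read(tf.constant([1, 4])).numpy()   # ==> [2., 5.]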


@tf_export('__internal__.distribute.ShardedVariable', v1=[])
class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for `Variables` that should be treated as shards.

  Variables that are too large to fit on a single device (e.g., large
  embeddings) may need to be sharded over multiple devices. This class
  maintains a list of smaller variables that can be independently stored on
  separate devices (e.g., multiple parameter servers), and saves and restores
  those variables as if they were a single larger variable.

  Objects of this class can be saved with a given number of shards and then
  restored from a checkpoint into a different number of shards.

  Objects of this class can be saved to SavedModel format using
  `tf.saved_model.save`. The SavedModel can be used by programs like TF serving
  APIs. It is not yet supported to load the SavedModel with
  `tf.saved_model.load`.

  Since a `ShardedVariable` can be saved and then restored to a different
  number of shards depending on the restore environment (for example, TF
  serving APIs restore to a single shard for serving efficiency), when using
  `ShardedVariable` in a `tf.function`, one should generally not assume it has
  the same number of shards across save and load.

  Sharding is only supported along the first dimension.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...                     signatures=model.serve_fn)
  """

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @classmethod
  def _overload_all_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      if operator == '__getitem__':
        continue

      cls._overload_operator(operator)

  @classmethod
  def _overload_operator(cls, operator):
    """Delegate an operator overload to `ops.Tensor`."""
    tensor_operator = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      return tensor_operator(_var_to_tensor(v), *args, **kwargs)

    setattr(cls, operator, _operator)

  def __tf_experimental_restore_capture__(self, concrete_function,
                                          internal_capture):
    # Avoid restoring captures for functions that use ShardedVariable - the
    # layer will be recreated during Keras model loading.
    # TODO(jmullenbach): support loading models with ShardedVariables using
    # tf.saved_model.load
    return None

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    return True

  def _write_object_proto(self, proto, options):
    resource_variable_ops.write_object_proto_for_resource_variable(
        self._saving_variable, proto, options, enforce_naming=False)


def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `ShardedVariable` to a `Tensor`."""
  del name
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
        'Incompatible type conversion requested to type {!r} for variable '
        'of type {!r}'.format(dtype.name, var.dtype.name))
  if as_ref:
    raise NotImplementedError(
        "ShardedVariable doesn't support being used as a reference.")
  # We use the op dispatch mechanism to override embedding_lookup ops when
  # called with ShardedVariable. This requires embedding_lookup ops to raise
  # TypeError when called with ShardedVariable. However, since ShardedVariable
  # can be converted to a tensor via concat, embedding_lookup ops would
  # silently do the conversion and never raise a TypeError.
  # To be able to properly raise a TypeError, the name scope is used to detect
  # whether this method is called within an embedding_lookup op.
  # NOTE: This doesn't work in eager mode since the op name scope is always
  # cleared in eager. This also breaks if the user sets the name of the
  # embedding_lookup op to something that doesn't contain the string
  # "embedding_lookup".
  #
  # TODO(chenkai): Find a more robust way to do this, which should not rely
  # on the name scope.
  if 'embedding_lookup' in ops.get_name_scope():
    raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
                    ' ops is disallowed.')
  return array_ops.concat(var.variables, axis=0)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor)

ShardedVariable._overload_all_operators()  # pylint: disable=protected-access


# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)
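

# Usage sketch for the dispatcher above (illustrative; assumes `tf` is
# TensorFlow 2.x and eager execution): a `ShardedVariable` can be passed
# directly to `tf.nn.embedding_lookup`, which then performs the lookup over
# the underlying shard variables instead of concatenating them first.
#
#   sv = ShardedVariable([tf.Variable([3.0]), tf.Variable([2.0])])
#   tf.nn.embedding_lookup(sv, 1).numpy()   # ==> 2.0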