xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/tpu_replicated_variable.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A Variable class that is replicated to logical cores for model parallelism."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from collections import abc
21import contextlib
22
23from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
24from tensorflow.python.distribute import tpu_util
25from tensorflow.python.eager import context
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import gen_resource_variable_ops
29from tensorflow.python.ops import gen_tpu_partition_ops as tpu_partition_ops
30from tensorflow.python.ops import variable_scope
31from tensorflow.python.ops import variables as variables_lib
32from tensorflow.python.saved_model import save_context
33from tensorflow.python.trackable import base as trackable
34
35
def _on_device_update(update_fn, var, value, **kwargs):
  """Applies `update_fn(var, value, **kwargs)` while pinned to `var`'s device."""
  target_device = var.device
  with ops.device(target_device):
    result = update_fn(var, value, **kwargs)
  return result
39
40
class TPUReplicatedVariable(variables_lib.Variable):
  """Container for replicated `Variables` that are treated as a single variable.

  This class maintains a list of replicated variables that are stored on
  separate logic TPU devices. TF2XLA bridge accesses these variables as
  if they were a single variable.

  All replicas are required to share a single dtype and shape; most
  attribute accessors (dtype, shape, device, trainable, ...) simply
  delegate to the first replica.
  """

  def __init__(self, variables, name='TPUReplicatedVariable'):
    """Treats `variables` as a replicated list of `tf.Variable`s.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
    ]
    replicated_variable = TPUReplicatedVariable(variables)
    assert replicated_variable.shape.as_list() == [10, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this replicated
        variable. Variables should not be shared between different
        `TPUReplicatedVariable` objects.
      name: String. Name of this container. Defaults to "TPUReplicatedVariable".

    Raises:
      TypeError: If `variables` is not a non-empty sequence of
        `tf.Variable`s.
      ValueError: If the elements of `variables` do not all share the same
        dtype or the same shape.
    """
    # Validate the container shape up-front: a non-empty Sequence whose
    # elements are all `tf.Variable`s.
    if not isinstance(variables, abc.Sequence) or not variables or any(
        not isinstance(v, variables_lib.Variable) for v in variables):
      raise TypeError('Argument `variables` should be a non-empty list of '
                      f'`variables.Variable`s. Received {variables}')

    # Replicas must agree on dtype...
    if any(v.dtype != variables[0].dtype for v in variables):
      raise ValueError(
          'All elements in argument `variables` must have the same dtype. '
          f'Received dtypes: {[v.dtype for v in variables]}')

    # ...and on shape, so the container can present a single shape/dtype.
    if any(v.shape != variables[0].shape for v in variables):
      raise ValueError(
          'All elements in argument `variables` must have the same shape. '
          f'Received shapes: {[v.shape for v in variables]}')

    self._vars = variables
    self._name = name
    # Name with any ':<output_index>' suffix stripped; exposed via
    # `_shared_name` for checkpointing machinery.
    self._common_name = self._name.split(':')[0]
    # Optional cached read used by `_assign_dependencies`; never assigned
    # within this class (presumably set by external/subclass code — verify).
    self._cached_value = None

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._vars)

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._vars[0].dtype

  @property
  def is_initialized(self):
    # Delegates to the first replica (replicas are assumed to be
    # initialized together — not enforced here).
    return self._vars[0].is_initialized

  @property
  def trainable(self):
    # Trainability is read off the first replica only.
    return self._vars[0].trainable

  @property
  def device(self):
    """The device this variable is on."""
    return self._vars[0].device

  @contextlib.contextmanager
  def _handle_graph(self):
    # Enter the graph that owns `self.handle` so ops created inside the
    # context are placed in the same graph as the variable's handle.
    with self.handle.graph.as_default():
      yield

  @contextlib.contextmanager
  def _assign_dependencies(self):
    # If a cached value exists, force assignments to run after it so reads
    # of the cache cannot observe a half-updated state.
    if self._cached_value is not None:
      with ops.control_dependencies([self._cached_value]):
        yield
    else:
      yield

  @property
  def constraint(self):
    # Constraint function, if any, taken from the first replica.
    return self._vars[0].constraint

  @property
  def _in_graph_mode(self):
    return self._vars[0]._in_graph_mode  # pylint: disable=protected-access

  @property
  def _unique_id(self):
    return self._vars[0]._unique_id  # pylint: disable=protected-access

  @property
  def graph(self):
    # Graph of the first replica; all replicas are presumed to live in the
    # same graph.
    return self._vars[0].graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def synchronization(self):
    # The container itself performs no cross-replica synchronization.
    return variable_scope.VariableSynchronization.NONE

  @property
  def aggregation(self):
    # Likewise, no aggregation policy is applied by the container.
    return variable_scope.VariableAggregation.NONE

  @property
  def variables(self):
    """The list of `Variables`."""
    # While saving a SavedModel only a single replica is exposed, so the
    # saved graph references one variable instead of all replicas.
    if save_context.in_save_context():
      return [self._vars[0]]
    return self._vars

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    first_var = self._vars[0]
    obj_map, resource_map = first_var._map_resources(save_options)  # pylint:disable=protected-access
    # Map every other replica (and this container) onto the first replica's
    # saved object/resource, so the SavedModel contains one copy.
    for v in self._vars[1:]:
      obj_map[v] = obj_map[first_var]
      resource_map[v.handle] = resource_map[first_var.handle]
    obj_map[self] = obj_map[first_var]
    resource_map[self] = resource_map[first_var.handle]
    return obj_map, resource_map

  def _gather_saveables_for_saved_model(self):
    # Only the first replica is checkpointed on behalf of the container.
    return {trackable.VARIABLE_VALUE_KEY: self._vars[0]}

  @property
  def shape(self):
    # All replicas share this shape (enforced in __init__).
    return self._vars[0].shape

  @property
  def handle(self):
    # Outside a TPU computation (saving or eager execution) the first
    # replica's handle stands in for the whole container.
    if save_context.in_save_context() or context.executing_eagerly():
      return self._vars[0].handle

    if tpu_util.enclosing_tpu_context() is None:
      raise NotImplementedError('TPUReplicatedVariable.handle is not available '
                                'outside tpu context or save context')
    else:
      # Inside a TPU computation: merge all replica handles with a
      # TPUPartitionedInput op (partition_dim=-1, i.e. replicated rather
      # than split) and mark the result replicated for XLA sharding.
      with tpu_util.outside_or_skip_tpu_context():
        return xla_sharding.replicate(
            tpu_partition_ops.tpu_partitioned_input(
                [v.handle for v in self._vars], partition_dim=-1))

  def _read_variable_op(self):
    # Raw read through the (possibly merged) handle.
    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # NOTE(review): `dtype`, `name` and `as_ref` are accepted but ignored —
    # a caller requesting a different dtype will not get a cast; confirm
    # whether a dtype-compatibility check is needed here.
    # pylint: disable=protected-access
    if tpu_util.enclosing_tpu_context() is None:
      return self.read_value()
    else:
      return self._read_variable_op()

  def read_value(self):
    # Reads from the first replica; replicas are presumed kept in sync by
    # the assign paths below.
    return self._vars[0].read_value()

  def _update(self, update_fn, value, **kwargs):
    """Converts the value to tensor and updates the variable list."""
    # Convert once, then apply `update_fn` to each replica on its own
    # device and group the resulting ops into a single op.
    input_tensor = ops.convert_to_tensor(
        value, name='value_in_tensor', dtype=self.dtype)

    return control_flow_ops.group(
        *tuple(
            _on_device_update(update_fn, v, input_tensor, **kwargs)
            for v in self.variables))

  def assign(self, value, use_locking=False, name=None, read_value=True):
    # Outside a TPU computation (or when eager), assign each replica
    # individually; inside one, emit a single raw AssignVariableOp through
    # the merged handle.
    if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
      assign_fn = lambda var, *a, **ka: var.assign(*a, **ka)
      return self._update(
          assign_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    # Same dispatch as `assign`, but subtracting `value`.
    if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
      assign_sub_fn = lambda var, *a, **ka: var.assign_sub(*a, **ka)
      return self._update(
          assign_sub_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    # Same dispatch as `assign`, but adding `value`.
    if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
      assign_add_fn = lambda var, *a, **ka: var.assign_add(*a, **ka)
      return self._update(
          assign_add_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)

  def __str__(self):
    # One line per replica: "  <index>: <str(variable)>".
    debug_str = ',\n'.join(
        '  %d: %s' % (i, v) for i, v in enumerate(self._vars))
    return '%s:{\n%s\n}' % (self.__class__.__name__, debug_str)

  def __repr__(self):
    # Same layout as __str__, but with repr() of each replica.
    debug_repr = ',\n'.join(
        '  %d: %r' % (i, v) for i, v in enumerate(self._vars))
    return '%s:{\n%s\n}' % (self.__class__.__name__, debug_repr)
285
286
287# Register a conversion function which reads the value of the variable,
288# allowing instances of the class to be used as tensors.
289def _tensor_conversion_tpu_replicated_var(var,
290                                          dtype=None,
291                                          name=None,
292                                          as_ref=False):
293  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
294
295
# Registered at import time so that passing a TPUReplicatedVariable where a
# Tensor is expected implicitly reads its value.
ops.register_tensor_conversion_function(TPUReplicatedVariable,
                                        _tensor_conversion_tpu_replicated_var)
298