# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Variable class that is replicated to logical cores for model parallelism."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import abc
import contextlib

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.distribute import tpu_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_tpu_partition_ops as tpu_partition_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.trackable import base as trackable


def _on_device_update(update_fn, var, value, **kwargs):
  with ops.device(var.device):
    return update_fn(var, value, **kwargs)


class TPUReplicatedVariable(variables_lib.Variable):
  """Container for replicated `Variables` that are treated as a single variable.

  This class maintains a list of replicated variables that are stored on
  separate logical TPU devices. The TF2XLA bridge accesses these variables as
  if they were a single variable.
  """

  def __init__(self, variables, name='TPUReplicatedVariable'):
    """Treats `variables` as a replicated list of `tf.Variable`s.

    Example:

    ```
    variables = [
        tf.Variable(..., shape=(10, 100), dtype=tf.float32),
        tf.Variable(..., shape=(10, 100), dtype=tf.float32),
        tf.Variable(..., shape=(10, 100), dtype=tf.float32),
        tf.Variable(..., shape=(10, 100), dtype=tf.float32),
    ]
    replicated_variable = TPUReplicatedVariable(variables)
    assert replicated_variable.shape.as_list() == [10, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this replicated
        variable. Variables should not be shared between different
        `TPUReplicatedVariable` objects.
      name: String. Name of this container. Defaults to "TPUReplicatedVariable".
    """
    if not isinstance(variables, abc.Sequence) or not variables or any(
        not isinstance(v, variables_lib.Variable) for v in variables):
      raise TypeError('Argument `variables` should be a non-empty list of '
                      f'`variables.Variable`s. Received {variables}')

    if any(v.dtype != variables[0].dtype for v in variables):
      raise ValueError(
          'All elements in argument `variables` must have the same dtype. '
          f'Received dtypes: {[v.dtype for v in variables]}')

    if any(v.shape != variables[0].shape for v in variables):
      raise ValueError(
          'All elements in argument `variables` must have the same shape. '
          f'Received shapes: {[v.shape for v in variables]}')
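
    # All replicas are validated to share a single dtype and shape, so the
    # container can expose them below as if it held one variable.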
    self._vars = variables
    self._name = name
    self._common_name = self._name.split(':')[0]
    self._cached_value = None

  def __iter__(self):
    """Return an iterable for accessing the underlying replicated variables."""
    return iter(self._vars)

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._vars[0].dtype

  @property
  def is_initialized(self):
    return self._vars[0].is_initialized

  @property
  def trainable(self):
    return self._vars[0].trainable

  @property
  def device(self):
    """The device this variable is on."""
    return self._vars[0].device

  @contextlib.contextmanager
  def _handle_graph(self):
    with self.handle.graph.as_default():
      yield

  @contextlib.contextmanager
  def _assign_dependencies(self):
    if self._cached_value is not None:
      with ops.control_dependencies([self._cached_value]):
        yield
    else:
      yield

  @property
  def constraint(self):
    return self._vars[0].constraint

  @property
  def _in_graph_mode(self):
    return self._vars[0]._in_graph_mode  # pylint: disable=protected-access

  @property
  def _unique_id(self):
    return self._vars[0]._unique_id  # pylint: disable=protected-access

  @property
  def graph(self):
    return self._vars[0].graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def synchronization(self):
    return variable_scope.VariableSynchronization.NONE

  @property
  def aggregation(self):
    return variable_scope.VariableAggregation.NONE

  @property
  def variables(self):
    """The list of `Variables`."""
    if save_context.in_save_context():
      return [self._vars[0]]
    return self._vars

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    first_var = self._vars[0]
    obj_map, resource_map = first_var._map_resources(save_options)  # pylint:disable=protected-access
    for v in self._vars[1:]:
      obj_map[v] = obj_map[first_var]
      resource_map[v.handle] = resource_map[first_var.handle]
    obj_map[self] = obj_map[first_var]
    resource_map[self] = resource_map[first_var.handle]
    return obj_map, resource_map

  def _gather_saveables_for_saved_model(self):
    return {trackable.VARIABLE_VALUE_KEY: self._vars[0]}

  @property
  def shape(self):
    return self._vars[0].shape

  @property
  def handle(self):
    if save_context.in_save_context() or context.executing_eagerly():
      return self._vars[0].handle

    if tpu_util.enclosing_tpu_context() is None:
      raise NotImplementedError('TPUReplicatedVariable.handle is not available '
                                'outside tpu context or save context')
    else:
      with tpu_util.outside_or_skip_tpu_context():
        return xla_sharding.replicate(
            tpu_partition_ops.tpu_partitioned_input(
                [v.handle for v in self._vars], partition_dim=-1))

  def _read_variable_op(self):
    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if tpu_util.enclosing_tpu_context() is None:
      return self.read_value()
    else:
      return self._read_variable_op()
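
  # Every replica holds the same value, so reading from the first replica is
  # sufficient outside a replicated TPU computation.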
  def read_value(self):
    return self._vars[0].read_value()

  def _update(self, update_fn, value, **kwargs):
    """Converts the value to tensor and updates the variable list."""
    input_tensor = ops.convert_to_tensor(
        value, name='value_in_tensor', dtype=self.dtype)

    return control_flow_ops.group(
        *tuple(
            _on_device_update(update_fn, v, input_tensor, **kwargs)
            for v in self.variables))

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
      assign_fn = lambda var, *a, **ka: var.assign(*a, **ka)
      return self._update(
          assign_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
      assign_sub_fn = lambda var, *a, **ka: var.assign_sub(*a, **ka)
      return self._update(
          assign_sub_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
      assign_add_fn = lambda var, *a, **ka: var.assign_add(*a, **ka)
      return self._update(
          assign_add_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)

  def __str__(self):
    debug_str = ',\n'.join(
        '  %d: %s' % (i, v) for i, v in enumerate(self._vars))
    return '%s:{\n%s\n}' % (self.__class__.__name__, debug_str)

  def __repr__(self):
    debug_repr = ',\n'.join(
        '  %d: %r' % (i, v) for i, v in enumerate(self._vars))
    return '%s:{\n%s\n}' % (self.__class__.__name__, debug_repr)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_tpu_replicated_var(var,
                                          dtype=None,
                                          name=None,
                                          as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(TPUReplicatedVariable,
                                        _tensor_conversion_tpu_replicated_var)