1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains AutoCastVariable, a variable which automatically casts itself.""" 16 17import threading 18from tensorflow.python.eager import context 19from tensorflow.python.framework import ops 20from tensorflow.python.keras.distribute import distributed_training_utils 21from tensorflow.python.ops import math_ops 22from tensorflow.python.ops import resource_variable_ops 23from tensorflow.python.ops import variables 24from tensorflow.python.types import core 25 26 27# _autocast_dtype.dtype is the dtype AutoCastVariables should be cast to, or 28# None if AutoCastVariables should not be cast. 29_autocast_dtype = threading.local() 30 31 32def numpy_text(tensor, is_repr=False): 33 """Human readable representation of a tensor's numpy value.""" 34 if tensor.dtype.is_numpy_compatible: 35 # pylint: disable=protected-access 36 text = repr(tensor._numpy()) if is_repr else str(tensor._numpy()) 37 # pylint: enable=protected-access 38 else: 39 text = '<unprintable>' 40 if '\n' in text: 41 text = '\n' + text 42 return text 43 44 45class AutoCastVariable(variables.Variable, core.Tensor): 46 """Variable that will cast itself to a different dtype in applicable contexts. 47 48 This class wraps a floating-point `tf.Variable`. It emulates the variable 49 interface and delegates to the wrapped variable, but it additionally will cast 50 the wrapped variable under an `enable_auto_cast_variables(dtype)` context 51 manager. 52 53 For example: 54 55 >>> v = tf.Variable(1.0, dtype=tf.float32) 56 >>> v = AutoCastVariable(v) 57 >>> tf.identity(v).dtype 58 tf.float32 59 >>> with enable_auto_cast_variables(tf.float16): 60 ... tf.identity(v).dtype 61 tf.float16 62 63 The purpose of this class is to allow Keras layers to create variables in 64 float32, and automatically cast them to float16 or bfloat16 when the layer is 65 called. 66 """ 67 68 def __init__(self, variable): 69 """Creates an AutoCastVariable instance. 70 71 Args: 72 variable: A floating-point resource variable to wrap. 73 74 Raises: 75 ValueError: If `variable` is not a floating-point resource variable 76 """ 77 if not isinstance(variable, variables.Variable): 78 raise ValueError('variable must be of type tf.ResourceVariable, but got: ' 79 '%s' % variable) 80 if not variable.dtype.is_floating: 81 raise ValueError('variable must be a floating point variable but has ' 82 'type: %s' % variable.dtype.name) 83 self._variable = variable 84 # 'delegate' means AutoCastVariable.op return self._variable.op, which will 85 # raise an AttributeError in Eager (as intended). If set to any other value, 86 # AutoCastVariable.op returns that value instead, which is used to set the 87 # op attribute in AutoCastVariable.assign(). 88 self._op = 'delegate' 89 90 def _should_cast(self): 91 """Returns True if this variable should be casted when accessed.""" 92 autocast_dtype = getattr(_autocast_dtype, 'dtype', None) 93 return autocast_dtype is not None and self.dtype != autocast_dtype 94 95 @property 96 def dtype(self): 97 """The dtype of the underlying variable, before any casts are done.""" 98 return self._variable.dtype 99 100 @property 101 def true_dtype(self): 102 """Deprecated alias of `dtype`.""" 103 return self._variable.dtype 104 105 @property 106 def _cast_dtype(self): 107 dtype = getattr(_autocast_dtype, 'dtype', None) 108 return dtype or self._variable.dtype 109 110 def value(self): 111 val = self._variable.value() 112 if not self._should_cast(): 113 return val 114 return math_ops.cast(val, self._cast_dtype) 115 116 def read_value(self): 117 val = self._variable.read_value() 118 return math_ops.cast(val, self._cast_dtype) 119 120 def sparse_read(self, indices, name=None): 121 """Reads the value of this variable sparsely, using `gather`.""" 122 val = self._variable.sparse_read(indices, name=name) 123 return math_ops.cast(val, self._cast_dtype) 124 125 def gather_nd(self, indices, name=None): 126 """Gather slices of the variable into a Tensor.""" 127 val = self._variable.gather_nd(indices, name=name) 128 return math_ops.cast(val, self._cast_dtype) 129 130 def __getattr__(self, name): 131 return getattr(self._variable, name) 132 133 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 134 """Converts this variable to a tensor.""" 135 if as_ref: 136 # This ValueError should not occur in practice since it is impossible to 137 # pass as_ref=True using public APIs. 138 raise ValueError('Cannot convert AutoCastVariable to a tensor if ' 139 'as_ref=True is passed to convert_to_tensor') 140 if not self._should_cast(): 141 return ops.convert_to_tensor_v2_with_dispatch(self._variable, dtype=dtype, 142 name=name) 143 if dtype is not None and not dtype.is_compatible_with(self._cast_dtype): 144 raise ValueError( 145 'Incompatible type conversion requested to type {!r} for ' 146 'AutoCastVariable which is casted to type {!r}'.format( 147 dtype.name, self._cast_dtype.name)) 148 val = ops.convert_to_tensor_v2_with_dispatch( 149 self._variable, dtype=self._variable.dtype, name=name) 150 return math_ops.cast(val, self._cast_dtype) 151 152 def _should_act_as_resource_variable(self): 153 """Pass resource_variable_ops.is_resource_variable check.""" 154 pass 155 156 def __repr__(self): 157 if context.executing_eagerly() and not self._in_graph_mode: 158 repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} " 159 'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, ' 160 'numpy={np_repr}>') 161 return repr_str.format( 162 v=self, np_repr=numpy_text(self.read_value(), is_repr=True)) 163 else: 164 repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} " 165 'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>') 166 return repr_str.format(v=self) 167 168 # Method delegations: We delegate the following methods to self._variable. 169 # Each of these methods simply calls the same method on self._variable. The 170 # base Variable raises NotImplementedError for most of these, so we must 171 # override them. 172 # 173 # We do not define the following methods from Variable for the following 174 # reasons: 175 # * 'count_up_to': This method only applies to int variables, which cannot 176 # be wrapped with an AutoCastVariable. 177 # * 'ref': Instead we inherit the definition from Variable. 178 # If we defined and delegated to Variable, the ref of an AutoCastVariable 179 # would be the same as the ref of the underlying variable, which would be 180 # strange as they are different Python objects. 181 182 def set_shape(self, shape): 183 return self._variable.set_shape(self, shape) 184 185 @property 186 def trainable(self): 187 return self._variable.trainable 188 189 @property 190 def synchronization(self): 191 return self._variable.synchronization 192 193 @property 194 def aggregation(self): 195 return self._variable.aggregation 196 197 def eval(self, session=None): 198 return self._variable.eval(session) 199 200 def initialized_value(self): 201 return self._variable.initialized_value() 202 203 @property 204 def initial_value(self): 205 return self._variable.initial_value 206 207 @property 208 def constraint(self): 209 return self._variable.constraint 210 211 def _apply_assign_update(self, 212 update_fn, 213 value, 214 use_locking=None, 215 name=None, 216 read_value=True): 217 # TODO(b/146181571): This logic can be simplified once 218 # DistributedVariable.assign returns a DistributedVariable. Currently for 219 # MirroredStrategy, it returns a Mirrored value. 220 if ops.executing_eagerly_outside_functions(): 221 assign_op = update_fn(value, use_locking, name, False) 222 if read_value: 223 # We create a new AutoCastVariable with the same underlying tf.Variable. 224 # The new AutoCastVariable is identical except the 'op' attribute is 225 # defined. This matches the behavior of tf.Variable.assign. 226 var = create_autocast_variable(self._variable) 227 var._op = assign_op # pylint:disable=protected-access 228 return var 229 return assign_op 230 231 # Fallback to wrapping the returned variable in graph mode if possible 232 assign_var = update_fn(value, use_locking, name, read_value) 233 if read_value and resource_variable_ops.is_resource_variable(assign_var): 234 return create_autocast_variable(assign_var) 235 return assign_var 236 237 def _apply_update(self, update_fn, *args, **kwargs): 238 update_var = update_fn(*args, **kwargs) 239 if ops.executing_eagerly_outside_functions(): 240 return self 241 242 # Fallback to wrapping the returned variable in graph mode if possible 243 if resource_variable_ops.is_resource_variable(update_var): 244 return create_autocast_variable(update_var) 245 return update_var 246 247 def assign(self, value, use_locking=None, name=None, read_value=True): 248 return self._apply_assign_update(self._variable.assign, value, use_locking, 249 name, read_value) 250 251 def assign_add(self, delta, use_locking=None, name=None, read_value=True): 252 return self._apply_assign_update(self._variable.assign_add, delta, 253 use_locking, name, read_value) 254 255 def assign_sub(self, delta, use_locking=None, name=None, read_value=True): 256 return self._apply_assign_update(self._variable.assign_sub, delta, 257 use_locking, name, read_value) 258 259 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 260 return self._apply_update(self._variable.scatter_sub, sparse_delta, 261 use_locking, name) 262 263 def scatter_add(self, sparse_delta, use_locking=False, name=None): 264 return self._apply_update(self._variable.scatter_add, sparse_delta, 265 use_locking, name) 266 267 def scatter_max(self, sparse_delta, use_locking=False, name=None): 268 return self._apply_update(self._variable.scatter_max, sparse_delta, 269 use_locking, name) 270 271 def scatter_min(self, sparse_delta, use_locking=False, name=None): 272 return self._apply_update(self._variable.scatter_min, sparse_delta, 273 use_locking, name) 274 275 def scatter_mul(self, sparse_delta, use_locking=False, name=None): 276 return self._apply_update(self._variable.scatter_mul, sparse_delta, 277 use_locking, name) 278 279 def scatter_div(self, sparse_delta, use_locking=False, name=None): 280 return self._apply_update(self._variable.scatter_div, sparse_delta, 281 use_locking, name) 282 283 def scatter_update(self, sparse_delta, use_locking=False, name=None): 284 return self._apply_update(self._variable.scatter_update, sparse_delta, 285 use_locking, name) 286 287 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None): 288 return self._apply_update(self._variable.batch_scatter_update, sparse_delta, 289 use_locking, name) 290 291 def scatter_nd_sub(self, indices, updates, name=None): 292 return self._apply_update(self._variable.scatter_nd_sub, indices, updates, 293 name) 294 295 def scatter_nd_add(self, indices, updates, name=None): 296 return self._apply_update(self._variable.scatter_nd_add, indices, updates, 297 name) 298 299 def scatter_nd_update(self, indices, updates, name=None): 300 return self._apply_update(self._variable.scatter_nd_update, indices, 301 updates, name) 302 303 def load(self, value, session=None): 304 return self._variable.load(value, session) 305 306 @property 307 def name(self): 308 return self._variable.name 309 310 @property 311 def _shared_name(self): 312 return self._variable._shared_name # pylint:disable=protected-access 313 314 @property 315 def initializer(self): 316 return self._variable.initializer 317 318 @property 319 def device(self): 320 return self._variable.device 321 322 @property 323 def op(self): 324 if self._op == 'delegate': 325 return self._variable.op 326 return self._op 327 328 def _as_graph_element(self): 329 graph_element = self._variable._as_graph_element() # pylint:disable=protected-access 330 if graph_element is None: 331 return self._op 332 return graph_element 333 334 @property 335 def graph(self): 336 return self._variable.graph 337 338 @property 339 def shape(self): 340 return self._variable.shape 341 342 def get_shape(self): 343 return self._variable.get_shape() 344 345 def _gather_saveables_for_checkpoint(self): 346 # By delegating this method to the wrapped variable, checkpoints with 347 # AutoCastVariables are identical to checkpoints with normal variables. 348 # Therefore models checkpointed with AutoCastVariables can be restored on 349 # models with normal variables, and vice versa. 350 return self._variable._gather_saveables_for_checkpoint() # pylint:disable=protected-access 351 352 def _map_resources(self, save_options): 353 # By delegating this method to the wrapped variable, SavedModel with 354 # AutoCastVariables are identical to SavedModel with normal variables. 355 obj_map, resource_map = self._variable._map_resources(save_options) # pylint:disable=protected-access 356 obj_map[self] = obj_map[self._variable] 357 return obj_map, resource_map 358 359 # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in 360 # to_proto(). 361 def to_proto(self, export_scope=None): 362 return self._variable.to_proto(export_scope) 363 364 def from_proto(self, variable_def, import_scope=None): 365 return self._variable.from_proto(variable_def, import_scope) 366 367 # Delegate the private attributes _handle_name and _initializer_op to 368 # self._variable. SavedModel sets these attributes when loading a model. For 369 # example, it sets _handle_name here: 370 # https://github.com/tensorflow/tensorflow/blob/db26bd574fa95b5bdd53c08463dd19407cc0297e/tensorflow/python/keras/saving/saved_model/load.py#L211 371 # We need to expose these attributes on AutoCastVariable as well for 372 # SavedModel to work properly. 373 # TODO(reedwm/kathywu): Find a better way to support SavedModel. Exposing 374 # private attributes is hacky and difficult to maintain. 375 @property 376 def _handle_name(self): 377 return self._variable._handle_name # pylint: disable=protected-access 378 379 @_handle_name.setter 380 def _handle_name(self, handle_name): 381 self._variable._handle_name = handle_name # pylint: disable=protected-access 382 383 @property 384 def _initializer_op(self): 385 return self._variable._initializer_op # pylint: disable=protected-access 386 387 @_initializer_op.setter 388 def _initializer_op(self, initializer_op): 389 self._variable._initializer_op = initializer_op # pylint: disable=protected-access 390 391 # Operator overloads: 392 # Note we only overload operators that support floating-point types, as 393 # non-float variables cannot be wrapped with an AutoCastVariable. 394 # Also note: We call read_value() instead of value(), because value() causes 395 # gradients not to work properly when TPUStrategy is used: b/143380936 396 397 def __add__(self, o): 398 return self.read_value() + o 399 400 def __radd__(self, o): 401 return o + self.read_value() 402 403 def __sub__(self, o): 404 return self.read_value() - o 405 406 def __rsub__(self, o): 407 return o - self.read_value() 408 409 def __mul__(self, o): 410 return self.read_value() * o 411 412 def __rmul__(self, o): 413 return o * self.read_value() 414 415 def __truediv__(self, o): 416 return self.read_value() / o 417 418 def __rtruediv__(self, o): 419 return o / self.read_value() 420 421 def __floordiv__(self, o): 422 return self.read_value() // o 423 424 def __rfloordiv__(self, o): 425 return o // self.read_value() 426 427 def __mod__(self, o): 428 return self.read_value() % o 429 430 def __rmod__(self, o): 431 return o % self.read_value() 432 433 def __lt__(self, o): 434 return self.read_value() < o 435 436 def __le__(self, o): 437 return self.read_value() <= o 438 439 def __gt__(self, o): 440 return self.read_value() > o 441 442 def __ge__(self, o): 443 return self.read_value() >= o 444 445 def __getitem__(self, o): 446 return self.read_value()[o] 447 448 def __pow__(self, o, modulo=None): 449 return pow(self.read_value(), o, modulo) 450 451 def __rpow__(self, o): 452 return pow(o, self.read_value()) 453 454 def __neg__(self): 455 return -self.read_value() # pylint: disable=invalid-unary-operand-type 456 457 def __abs__(self): 458 return abs(self.read_value()) 459 460 def __div__(self, o): 461 try: 462 return self.read_value().__div__(o) 463 except AttributeError: 464 # See https://docs.python.org/3/library/constants.html#NotImplemented 465 return NotImplemented 466 467 def __rdiv__(self, o): 468 try: 469 return self.read_value().__rdiv__(o) 470 except AttributeError: 471 # See https://docs.python.org/3/library/constants.html#NotImplemented 472 return NotImplemented 473 474 def __matmul__(self, o): 475 try: 476 return self.read_value().__matmul__(o) 477 except AttributeError: 478 # See https://docs.python.org/3/library/constants.html#NotImplemented 479 return NotImplemented 480 481 def __rmatmul__(self, o): 482 try: 483 return self.read_value().__rmatmul__(o) 484 except AttributeError: 485 # See https://docs.python.org/3/library/constants.html#NotImplemented 486 return NotImplemented 487 488 # pylint: enable=multiple-statements 489 490 491ops.register_tensor_conversion_function(AutoCastVariable, 492 AutoCastVariable._dense_var_to_tensor) # pylint:disable=protected-access 493 494 495def create_autocast_variable(variable): 496 """Creates an AutoCastVariable that wraps another variable. 497 498 This typically just returns `AutoCastVariable(variable)`. But, if the variable 499 is a DistributedVariable or one of its subclasses, we instead dynamically 500 create a class that subclasses from both AutoCastVariable and 501 variable.__class__. This is so the returned variable will still pass 502 `isinstance(variable, variable.__class__)`, which is required for 503 DistributedVariables and its subclasses to work properly. 504 505 Args: 506 variable: A floating-point resource variable to wrap. 507 508 Returns: 509 An AutoCastVariable that wraps the variable. 510 """ 511 if not distributed_training_utils.is_distributed_variable(variable): 512 return AutoCastVariable(variable) 513 514 class AutoCastDistributedVariable(AutoCastVariable, variable.__class__): 515 """An AutoCastVariable that also subclasses from variable.__class__. 516 517 variable.__class__ is either a DistributedVariable or an 518 AggregatingVariable. 519 """ 520 521 def __repr__(self): 522 523 # pylint: disable=missing-format-attribute 524 return ('<AutoCastDistributedVariable dtype={v.dtype.name} ' 525 'dtype_to_cast_to={v._cast_dtype.name} ' 526 'inner_variable={v._variable}>' 527 ).format(v=self) 528 # pylint: enable=missing-format-attribute 529 530 return AutoCastDistributedVariable(variable) 531 532 533class enable_auto_cast_variables(object): # pylint:disable=invalid-name 534 """Context manager which enables the autocasting of `AutoCastVariable`s. 535 536 Under this context manager, `AutoCastVariable`s will be cast to `dtype` if 537 `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast. 538 """ 539 540 __slots__ = ['_dtype', '_prev_dtype'] 541 542 def __init__(self, dtype): 543 if dtype and not dtype.is_floating: 544 dtype = None 545 self._dtype = dtype 546 547 def __enter__(self): 548 self._prev_dtype = getattr(_autocast_dtype, 'dtype', None) 549 _autocast_dtype.dtype = self._dtype 550 551 def __exit__(self, type_arg, value_arg, traceback_arg): 552 _autocast_dtype.dtype = self._prev_dtype 553