# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains AutoCastVariable, a variable which automatically casts itself."""

import threading
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.types import core


# _autocast_dtype.dtype is the dtype AutoCastVariables should be cast to, or
# None if AutoCastVariables should not be cast.
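# For example, a sketch of the intended usage (only the
# `enable_auto_cast_variables` context manager at the bottom of this file
# should set or restore this):
#
#   with enable_auto_cast_variables(tf.float16):
#     ...  # AutoCastVariables read on this thread are cast to float16 here.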
_autocast_dtype = threading.local()


def numpy_text(tensor, is_repr=False):
  """Human readable representation of a tensor's numpy value."""
  if tensor.dtype.is_numpy_compatible:
    # pylint: disable=protected-access
    text = repr(tensor._numpy()) if is_repr else str(tensor._numpy())
    # pylint: enable=protected-access
  else:
    text = '<unprintable>'
  if '\n' in text:
    text = '\n' + text
  return text


class AutoCastVariable(variables.Variable, core.Tensor):
  """Variable that will cast itself to a different dtype in applicable contexts.

  This class wraps a floating-point `tf.Variable`. It emulates the variable
  interface and delegates to the wrapped variable, but it additionally will cast
  the wrapped variable under an `enable_auto_cast_variables(dtype)` context
  manager.

  For example:

  >>> v = tf.Variable(1.0, dtype=tf.float32)
  >>> v = AutoCastVariable(v)
  >>> tf.identity(v).dtype
  tf.float32
  >>> with enable_auto_cast_variables(tf.float16):
  ...   tf.identity(v).dtype
  tf.float16

  The purpose of this class is to allow Keras layers to create variables in
  float32, and automatically cast them to float16 or bfloat16 when the layer is
  called.
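
  Reads made through methods such as `read_value` and through the operator
  overloads defined below are cast the same way. For example, a small
  additional sketch using the same `v` as above:

  >>> with enable_auto_cast_variables(tf.float16):
  ...   (v + 1.).dtype
  tf.float16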
66  """
67
68  def __init__(self, variable):
69    """Creates an AutoCastVariable instance.
70
71    Args:
72      variable: A floating-point resource variable to wrap.
73
74    Raises:
75      ValueError: If `variable` is not a floating-point resource variable
76    """
77    if not isinstance(variable, variables.Variable):
78      raise ValueError('variable must be of type tf.ResourceVariable, but got: '
79                       '%s' % variable)
80    if not variable.dtype.is_floating:
81      raise ValueError('variable must be a floating point variable but has '
82                       'type: %s' % variable.dtype.name)
83    self._variable = variable
84    # 'delegate' means AutoCastVariable.op return self._variable.op, which will
85    # raise an AttributeError in Eager (as intended). If set to any other value,
86    # AutoCastVariable.op returns that value instead, which is used to set the
87    # op attribute in AutoCastVariable.assign().
88    self._op = 'delegate'
89
90  def _should_cast(self):
91    """Returns True if this variable should be casted when accessed."""
92    autocast_dtype = getattr(_autocast_dtype, 'dtype', None)
93    return autocast_dtype is not None and self.dtype != autocast_dtype
94
  @property
  def dtype(self):
    """The dtype of the underlying variable, before any casts are done."""
    return self._variable.dtype

  @property
  def true_dtype(self):
    """Deprecated alias of `dtype`."""
    return self._variable.dtype

  @property
  def _cast_dtype(self):
    dtype = getattr(_autocast_dtype, 'dtype', None)
    return dtype or self._variable.dtype

  def value(self):
    val = self._variable.value()
    if not self._should_cast():
      return val
    return math_ops.cast(val, self._cast_dtype)

  def read_value(self):
    val = self._variable.read_value()
    return math_ops.cast(val, self._cast_dtype)

  def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    val = self._variable.sparse_read(indices, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def gather_nd(self, indices, name=None):
    """Gather slices of the variable into a Tensor."""
    val = self._variable.gather_nd(indices, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def __getattr__(self, name):
    return getattr(self._variable, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts this variable to a tensor."""
    if as_ref:
      # This ValueError should not occur in practice since it is impossible to
      # pass as_ref=True using public APIs.
      raise ValueError('Cannot convert AutoCastVariable to a tensor if '
                       'as_ref=True is passed to convert_to_tensor')
    if not self._should_cast():
      return ops.convert_to_tensor_v2_with_dispatch(self._variable, dtype=dtype,
                                                    name=name)
    if dtype is not None and not dtype.is_compatible_with(self._cast_dtype):
      raise ValueError(
          'Incompatible type conversion requested to type {!r} for '
          'AutoCastVariable which is cast to type {!r}'.format(
              dtype.name, self._cast_dtype.name))
    val = ops.convert_to_tensor_v2_with_dispatch(
        self._variable, dtype=self._variable.dtype, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def __repr__(self):
    if context.executing_eagerly() and not self._in_graph_mode:
      repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
                  'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, '
                  'numpy={np_repr}>')
      return repr_str.format(
          v=self, np_repr=numpy_text(self.read_value(), is_repr=True))
    else:
      repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
                  'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>')
      return repr_str.format(v=self)

  # Method delegations: We delegate the following methods to self._variable.
  # Each of these methods simply calls the same method on self._variable. The
  # base Variable raises NotImplementedError for most of these, so we must
  # override them.
  #
  # We do not define the following methods from Variable for the following
  # reasons:
  #   * 'count_up_to': This method only applies to int variables, which cannot
  #     be wrapped with an AutoCastVariable.
  #   * 'ref': Instead we inherit the definition from Variable. If we delegated
  #     'ref' to the wrapped variable, the ref of an AutoCastVariable would be
  #     the same as the ref of the underlying variable, which would be strange
  #     as they are different Python objects.

  def set_shape(self, shape):
    return self._variable.set_shape(shape)

  @property
  def trainable(self):
    return self._variable.trainable

  @property
  def synchronization(self):
    return self._variable.synchronization

  @property
  def aggregation(self):
    return self._variable.aggregation

  def eval(self, session=None):
    return self._variable.eval(session)

  def initialized_value(self):
    return self._variable.initialized_value()

  @property
  def initial_value(self):
    return self._variable.initial_value

  @property
  def constraint(self):
    return self._variable.constraint

  def _apply_assign_update(self,
                           update_fn,
                           value,
                           use_locking=None,
                           name=None,
                           read_value=True):
    # TODO(b/146181571): This logic can be simplified once
    # DistributedVariable.assign returns a DistributedVariable. Currently for
    # MirroredStrategy, it returns a Mirrored value.
    if ops.executing_eagerly_outside_functions():
      assign_op = update_fn(value, use_locking, name, False)
      if read_value:
        # We create a new AutoCastVariable with the same underlying tf.Variable.
        # The new AutoCastVariable is identical except the 'op' attribute is
        # defined. This matches the behavior of tf.Variable.assign.
        var = create_autocast_variable(self._variable)
        var._op = assign_op  # pylint:disable=protected-access
        return var
      return assign_op

    # Fallback to wrapping the returned variable in graph mode if possible
    assign_var = update_fn(value, use_locking, name, read_value)
    if read_value and resource_variable_ops.is_resource_variable(assign_var):
      return create_autocast_variable(assign_var)
    return assign_var
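
  # For illustration, a sketch of the eager path in `_apply_assign_update`
  # above (not exercised in this module):
  #
  #   v = create_autocast_variable(tf.Variable(1.0))
  #   w = v.assign(2.0)  # In eager mode, `w` is a new AutoCastVariable that
  #                      # wraps the same tf.Variable, so v and w both read 2.0.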

  def _apply_update(self, update_fn, *args, **kwargs):
    update_var = update_fn(*args, **kwargs)
    if ops.executing_eagerly_outside_functions():
      return self

    # Fallback to wrapping the returned variable in graph mode if possible
    if resource_variable_ops.is_resource_variable(update_var):
      return create_autocast_variable(update_var)
    return update_var

  def assign(self, value, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign, value, use_locking,
                                     name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign_add, delta,
                                     use_locking, name, read_value)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign_sub, delta,
                                     use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_sub, sparse_delta,
                              use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_add, sparse_delta,
                              use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_max, sparse_delta,
                              use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_min, sparse_delta,
                              use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_mul, sparse_delta,
                              use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_div, sparse_delta,
                              use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_update, sparse_delta,
                              use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.batch_scatter_update, sparse_delta,
                              use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_sub, indices, updates,
                              name)

  def scatter_nd_add(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_add, indices, updates,
                              name)

  def scatter_nd_update(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_update, indices,
                              updates, name)

  def load(self, value, session=None):
    return self._variable.load(value, session)

  @property
  def name(self):
    return self._variable.name

  @property
  def _shared_name(self):
    return self._variable._shared_name  # pylint:disable=protected-access

  @property
  def initializer(self):
    return self._variable.initializer

  @property
  def device(self):
    return self._variable.device

  @property
  def op(self):
    if self._op == 'delegate':
      return self._variable.op
    return self._op

  def _as_graph_element(self):
    graph_element = self._variable._as_graph_element()  # pylint:disable=protected-access
    if graph_element is None:
      return self._op
    return graph_element

  @property
  def graph(self):
    return self._variable.graph

  @property
  def shape(self):
    return self._variable.shape

  def get_shape(self):
    return self._variable.get_shape()

  def _gather_saveables_for_checkpoint(self):
    # By delegating this method to the wrapped variable, checkpoints with
    # AutoCastVariables are identical to checkpoints with normal variables.
    # Therefore models checkpointed with AutoCastVariables can be restored on
    # models with normal variables, and vice versa.
    return self._variable._gather_saveables_for_checkpoint()  # pylint:disable=protected-access

  def _map_resources(self, save_options):
    # By delegating this method to the wrapped variable, SavedModels with
    # AutoCastVariables are identical to SavedModels with normal variables.
    obj_map, resource_map = self._variable._map_resources(save_options)  # pylint:disable=protected-access
    obj_map[self] = obj_map[self._variable]
    return obj_map, resource_map

  # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in
  # to_proto().
  def to_proto(self, export_scope=None):
    return self._variable.to_proto(export_scope)

  def from_proto(self, variable_def, import_scope=None):
    return self._variable.from_proto(variable_def, import_scope)

  # Delegate the private attributes _handle_name and _initializer_op to
  # self._variable. SavedModel sets these attributes when loading a model. For
  # example, it sets _handle_name here:
  # https://github.com/tensorflow/tensorflow/blob/db26bd574fa95b5bdd53c08463dd19407cc0297e/tensorflow/python/keras/saving/saved_model/load.py#L211
  # We need to expose these attributes on AutoCastVariable as well for
  # SavedModel to work properly.
  # TODO(reedwm/kathywu): Find a better way to support SavedModel. Exposing
  # private attributes is hacky and difficult to maintain.
  @property
  def _handle_name(self):
    return self._variable._handle_name  # pylint: disable=protected-access

  @_handle_name.setter
  def _handle_name(self, handle_name):
    self._variable._handle_name = handle_name  # pylint: disable=protected-access

  @property
  def _initializer_op(self):
    return self._variable._initializer_op  # pylint: disable=protected-access

  @_initializer_op.setter
  def _initializer_op(self, initializer_op):
    self._variable._initializer_op = initializer_op  # pylint: disable=protected-access

  # Operator overloads:
  # Note we only overload operators that support floating-point types, as
  # non-float variables cannot be wrapped with an AutoCastVariable.
  # Also note: We call read_value() instead of value(), because value() causes
  # gradients not to work properly when TPUStrategy is used: b/143380936

  def __add__(self, o):
    return self.read_value() + o

  def __radd__(self, o):
    return o + self.read_value()

  def __sub__(self, o):
    return self.read_value() - o

  def __rsub__(self, o):
    return o - self.read_value()

  def __mul__(self, o):
    return self.read_value() * o

  def __rmul__(self, o):
    return o * self.read_value()

  def __truediv__(self, o):
    return self.read_value() / o

  def __rtruediv__(self, o):
    return o / self.read_value()

  def __floordiv__(self, o):
    return self.read_value() // o

  def __rfloordiv__(self, o):
    return o // self.read_value()

  def __mod__(self, o):
    return self.read_value() % o

  def __rmod__(self, o):
    return o % self.read_value()

  def __lt__(self, o):
    return self.read_value() < o

  def __le__(self, o):
    return self.read_value() <= o

  def __gt__(self, o):
    return self.read_value() > o

  def __ge__(self, o):
    return self.read_value() >= o

  def __getitem__(self, o):
    return self.read_value()[o]

  def __pow__(self, o, modulo=None):
    return pow(self.read_value(), o, modulo)

  def __rpow__(self, o):
    return pow(o, self.read_value())

  def __neg__(self):
    return -self.read_value()  # pylint: disable=invalid-unary-operand-type

  def __abs__(self):
    return abs(self.read_value())

  def __div__(self, o):
    try:
      return self.read_value().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self.read_value().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self.read_value().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self.read_value().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented


ops.register_tensor_conversion_function(AutoCastVariable,
                                        AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access
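
# For illustration, a sketch of what the registration above enables (it mirrors
# the doctest in the AutoCastVariable docstring; `create_autocast_variable` and
# `enable_auto_cast_variables` are defined below): ordinary ops and
# `tf.convert_to_tensor` see the cast value of an AutoCastVariable.
#
#   v = create_autocast_variable(tf.Variable(1.0))
#   with enable_auto_cast_variables(tf.float16):
#     tf.convert_to_tensor(v).dtype  # -> tf.float16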


def create_autocast_variable(variable):
  """Creates an AutoCastVariable that wraps another variable.

  This typically just returns `AutoCastVariable(variable)`. But, if the variable
  is a DistributedVariable or one of its subclasses, we instead dynamically
  create a class that subclasses from both AutoCastVariable and
  variable.__class__. This is so the returned variable will still pass
  `isinstance(variable, variable.__class__)`, which is required for
  DistributedVariable and its subclasses to work properly.

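  A minimal sketch of the non-distributed path:

  >>> v = tf.Variable(1.0)
  >>> isinstance(create_autocast_variable(v), AutoCastVariable)
  True
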
  Args:
    variable: A floating-point resource variable to wrap.

  Returns:
    An AutoCastVariable that wraps the variable.
  """
  if not distributed_training_utils.is_distributed_variable(variable):
    return AutoCastVariable(variable)

  class AutoCastDistributedVariable(AutoCastVariable, variable.__class__):
    """An AutoCastVariable that also subclasses from variable.__class__.

    variable.__class__ is either a DistributedVariable or an
    AggregatingVariable.
    """

    def __repr__(self):
      # pylint: disable=missing-format-attribute
      return ('<AutoCastDistributedVariable dtype={v.dtype.name} '
              'dtype_to_cast_to={v._cast_dtype.name} '
              'inner_variable={v._variable}>'
             ).format(v=self)
      # pylint: enable=missing-format-attribute

  return AutoCastDistributedVariable(variable)


class enable_auto_cast_variables(object):  # pylint:disable=invalid-name
  """Context manager which enables the autocasting of `AutoCastVariable`s.

  Under this context manager, `AutoCastVariable`s will be cast to `dtype` if
  `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
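
  For example, a small illustrative sketch:

  >>> v = AutoCastVariable(tf.Variable(1.0))
  >>> with enable_auto_cast_variables(tf.float16):
  ...   v.read_value().dtype
  tf.float16
  >>> with enable_auto_cast_variables(None):
  ...   v.read_value().dtype
  tf.float32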
538  """
539
540  __slots__ = ['_dtype', '_prev_dtype']
541
542  def __init__(self, dtype):
543    if dtype and not dtype.is_floating:
544      dtype = None
545    self._dtype = dtype
546
547  def __enter__(self):
548    self._prev_dtype = getattr(_autocast_dtype, 'dtype', None)
549    _autocast_dtype.dtype = self._dtype
550
551  def __exit__(self, type_arg, value_arg, traceback_arg):
552    _autocast_dtype.dtype = self._prev_dtype
553