xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/default_gradient.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for computing default gradients."""
16from tensorflow.python.framework import dtypes
17from tensorflow.python.framework import tensor_shape
18from tensorflow.python.ops import array_ops
19from tensorflow.python.ops import resource_variable_ops
20
21
def get_zeros_dtype(t):
  """Return the dtype for the default gradient for a Tensor.

  For a non-resource tensor this is simply `t.dtype`. For a resource-handle
  tensor the dtype is taken from the single `shape_and_type` entry of the
  handle data returned by `resource_variable_ops.get_eager_safe_handle_data`.
  That value is passed through `dtypes.as_dtype` so both branches return a
  `DType`, matching what `shape_and_dtype` returns for the same tensor.

  Args:
    t: Tensor whose default-gradient dtype is wanted.

  Returns:
    A `DType`.

  Raises:
    ValueError: if `t` is a resource tensor whose handle data is missing, not
      set, or does not contain exactly one `shape_and_type` entry.
  """
  if t.dtype == dtypes.resource:
    handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
    if (handle_data is None or not handle_data.is_set or
        len(handle_data.shape_and_type) != 1):
      raise ValueError("Internal error: Tried to take gradients (or similar) "
                       "of a variable without handle data:\n%s" % str(t))
    # The handle data's dtype field is not a DType itself (shape_and_dtype
    # also wraps it in as_dtype); convert it so callers always get a DType.
    return dtypes.as_dtype(handle_data.shape_and_type[0].dtype)
  return t.dtype
32
33
def shape_and_dtype(t):
  """Compute the (shape, dtype) pair that a default gradient for `t` uses.

  Non-resource tensors simply report `t.shape` and `t.dtype`. For a
  resource-handle tensor the pair is instead read from the single
  `shape_and_type` entry of the handle data returned by
  `resource_variable_ops.get_eager_safe_handle_data`, converted to a
  `tensor_shape.TensorShape` and a `DType` respectively.

  Args:
    t: Tensor.

  Returns:
    A 2-tuple `(shape, dtype)`.

  Raises:
    ValueError: if `t` is a resource tensor and its handle data is missing,
      not set, or does not hold exactly one `shape_and_type` entry.
  """
  if t.dtype != dtypes.resource:
    return t.shape, t.dtype

  data = resource_variable_ops.get_eager_safe_handle_data(t)
  # Usable only when present, set, and describing exactly one component.
  has_single_entry = (data is not None and data.is_set and
                      len(data.shape_and_type) == 1)
  if not has_single_entry:
    raise ValueError("Internal error: Tried to take gradients (or similar) "
                     "of a variable without handle data:\n%s" % str(t))

  entry = data.shape_and_type[0]
  shape = tensor_shape.TensorShape(entry.shape)
  dtype = dtypes.as_dtype(entry.dtype)
  return shape, dtype
46
47
def zeros_like(t):
  """Return zeros shaped like `t`, with special handling for resource handles.

  Non-resource tensors are delegated to `array_ops.zeros_like`. For a
  resource-handle tensor, the shape and dtype come from `shape_and_dtype`
  (i.e. from the handle data), and `array_ops.zeros` builds the result.

  Args:
    t: Tensor.

  Returns:
    The tensor produced by `array_ops.zeros` or `array_ops.zeros_like`.
  """
  if t.dtype != dtypes.resource:
    return array_ops.zeros_like(t)
  shape, dtype = shape_and_dtype(t)
  return array_ops.zeros(shape, dtype)
54
55
def ones_like(t):
  """Return ones shaped like `t`, with special handling for resource handles.

  Non-resource tensors are delegated to `array_ops.ones_like`. For a
  resource-handle tensor, the shape and dtype come from `shape_and_dtype`
  (i.e. from the handle data), and `array_ops.ones` builds the result.

  Args:
    t: Tensor.

  Returns:
    The tensor produced by `array_ops.ones` or `array_ops.ones_like`.
  """
  if t.dtype != dtypes.resource:
    return array_ops.ones_like(t)
  shape, dtype = shape_and_dtype(t)
  return array_ops.ones(shape, dtype)
62
63
def supports_default_grad(t):
  """Report whether a default gradient can be created for tensor `t`.

  The caller is expected to have already established that `t` is of a
  trainable type; that is not re-checked here.

  Non-resource tensors always qualify. A resource-handle tensor qualifies
  only when its handle data (from
  `resource_variable_ops.get_eager_safe_handle_data`) is present, set, and
  holds exactly one `shape_and_type` entry. This is the same condition under
  which `shape_and_dtype` succeeds rather than raising.

  Args:
    t: Tensor

  Returns:
    Bool
  """
  if t.dtype != dtypes.resource:
    return True
  handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
  # bool() keeps the result a strict True/False regardless of field types.
  return bool(handle_data is not None and handle_data.is_set and
              len(handle_data.shape_and_type) == 1)
81