xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/composite_tensor_gradient.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Gradient support for Composite Tensors."""
16
17import abc
18import sys
19
20from tensorflow.python.framework import composite_tensor
21from tensorflow.python.util import nest
22
23
24# pylint:disable=g-import-not-at-top
25if sys.version_info >= (3, 8):
26  from typing import Protocol
27  from typing import runtime_checkable
28else:
29  from typing_extensions import Protocol
30  from typing_extensions import runtime_checkable
31# pylint:enable=g-import-not-at-top
32
33
34# TODO(xjun): Add CompositeTensorGradient support for SparseTensor,
35# StructuredTensor, and MaskedTensor.
class CompositeTensorGradient(object, metaclass=abc.ABCMeta):
  """Class used to help compute gradients for CompositeTensors.

  This abstract base class defines two methods: `get_gradient_components`, which
  returns the components of a value that should be included in gradients; and
  `replace_gradient_components`, which replaces the gradient components in a
  value.  These methods can be used to compute the gradient of a `y` with
  respect to `x` (`grad(y, x)`) as follows:

  * If `y` is a `CompositeTensor` with `CompositeTensorGradient` `cg` =
    `y.__composite_gradient__`, then `grad(y, x)` =
    `grad(cg.get_gradient_components(y), x)`.

  * If `x` is a `CompositeTensor` with `CompositeTensorGradient` `cg` =
    `x.__composite_gradient__`, then `grad(y, x)` =
    `cg.replace_gradient_components(x, grad(y, cg.get_gradient_components(x)))`.

  Both methods are abstract, so a subclass must override both before it can be
  instantiated. The base implementations only raise `NotImplementedError`.
  """

  @abc.abstractmethod
  def get_gradient_components(self, value):
    """Returns the components of `value` that should be included in gradients.

    This method may not call TensorFlow ops, since any new ops added to the
    graph would not be properly tracked by the gradient mechanisms.

    Args:
      value: A `CompositeTensor` value.

    Returns:
      A nested structure of `Tensor` or `CompositeTensor`.

    Raises:
      NotImplementedError: Always, when the base implementation is reached
        (for example via `super()`).
    """
    raise NotImplementedError(
        f"{type(self).__name__}.get_gradient_components()")

  @abc.abstractmethod
  def replace_gradient_components(self, value, component_grads):
    """Replaces the gradient components in `value` with `component_grads`.

    This method may not call TensorFlow ops, since any new ops added to the
    graph would not be properly tracked by the gradient mechanisms.

    Args:
      value: A value with its gradient components compatible with
        `component_grads`.
      component_grads: A nested structure of `Tensor` or `CompositeTensor` or
        `None` (for unconnected gradients).

    Returns:
      A copy of `value`, where the components that should be included in
      gradients have been replaced by `component_grads`; or `None` (if
      `component_grads` includes `None`).

    Raises:
      NotImplementedError: Always, when the base implementation is reached
        (for example via `super()`).
    """
    raise NotImplementedError(
        f"{type(self).__name__}.replace_gradient_components()")
90
91
@runtime_checkable
class CompositeTensorGradientProtocol(Protocol):
  """Protocol for adding gradient support to CompositeTensors.

  A type satisfies this protocol by exposing a `__composite_gradient__`
  attribute that holds a `CompositeTensorGradient`. Because the protocol is
  decorated with `runtime_checkable`, callers can use
  `isinstance(x, CompositeTensorGradientProtocol)` to test whether `x` provides
  gradient support. Such a runtime check verifies only that the attribute
  exists, not the type of its value.
  """
  # Helper used to extract and replace the gradient components of the value.
  __composite_gradient__: CompositeTensorGradient
96
97
class WithValuesCompositeTensorGradient(CompositeTensorGradient):
  """`CompositeTensorGradient` for types exposing `values` and `with_values`.

  The gradient components of a value are its `values` attribute. Gradients are
  written back by calling `value.with_values(...)` with the new components.
  """

  def get_gradient_components(self, value):
    """Returns `value.values` as the differentiable components of `value`."""
    components = value.values
    return components

  def replace_gradient_components(self, value, component_grads):
    """Returns `value.with_values(component_grads)`."""
    rebuilt = value.with_values(component_grads)
    return rebuilt
106
107
def _get_tensors_for_gradient(x):
  """Extracts the `Tensor`s of `x` that take part in differentiation.

  Values that are not `CompositeTensor`s are passed through unchanged. For a
  `CompositeTensor`, the components reported by its `__composite_gradient__`
  are processed recursively, so composites nested inside those components are
  reduced as well.

  Args:
    x: A `Tensor` or `CompositeTensor`.

  Returns:
    `x` itself when it is not a `CompositeTensor`; otherwise a nested structure
    of `Tensor`.

  Raises:
    ValueError: If `x` is a `CompositeTensor` that does not satisfy
      `CompositeTensorGradientProtocol`.
  """
  if isinstance(x, composite_tensor.CompositeTensor):
    if isinstance(x, CompositeTensorGradientProtocol):
      gradient_helper = x.__composite_gradient__
      components = gradient_helper.get_gradient_components(x)
      return nest.map_structure(_get_tensors_for_gradient, components)
    raise ValueError(
        f"Type {type(x).__name__} is not supported as a gradient source or "
        "gradient target.")
  return x
127
128
def _replace_tensors_for_gradient(x, grad):
  """Rebuilds `x` with its differentiable tensors swapped for `grad`.

  Values that are not `CompositeTensor`s are replaced by `grad` outright. For a
  `CompositeTensor`, `grad` is matched against the components reported by its
  `__composite_gradient__`, recursing into nested composites, and the result is
  handed to `replace_gradient_components`.

  Args:
    x: A `Tensor` or `CompositeTensor`.
    grad: A nested structure of `Tensor`, mirroring the structure returned by
      `_get_tensors_for_gradient(x)`.

  Returns:
    A `Tensor` or `CompositeTensor`, or `None` when the recursive replacement of
    the components yields `None`.

  Raises:
    ValueError: If `x` is a `CompositeTensor` that does not satisfy
      `CompositeTensorGradientProtocol`.
  """
  if isinstance(x, composite_tensor.CompositeTensor):
    if isinstance(x, CompositeTensorGradientProtocol):
      gradient_helper = x.__composite_gradient__
      components = gradient_helper.get_gradient_components(x)
      # Walk `components` and `grad` in lockstep, only as deep as the structure
      # of `components`, so nested composites are rebuilt by the recursion.
      replaced = nest.map_structure_up_to(
          components, _replace_tensors_for_gradient, components, grad)
      if replaced is None:
        # No gradient reached the components; propagate the missing gradient.
        return None
      return gradient_helper.replace_gradient_components(x, replaced)
    raise ValueError(
        f"Type {type(x).__name__} is not supported as a gradient source.")
  return grad
155
156
def get_flat_tensors_for_gradients(xs):
  """Collects, as one flat list, the `Tensor`s to differentiate for `xs`.

  Args:
    xs: A list of `Tensor`s or `CompositeTensor`s.

  Returns:
    A flat list of `Tensor`s. Each plain `Tensor` in `xs` contributes itself,
    and each `CompositeTensor` contributes the flattened result of
    `_get_tensors_for_gradient` applied to it.
  """
  per_value_tensors = list(map(_get_tensors_for_gradient, xs))
  return nest.flatten(per_value_tensors)
169
170
def replace_flat_tensors_for_gradients(xs, flat_grads):
  """Counterpart of `get_flat_tensors_for_gradients`: rebuilds `xs` from grads.

  Args:
    xs: A list of `Tensor`s or `CompositeTensor`s.
    flat_grads: A list of `Tensor`, laid out like the output of
      `get_flat_tensors_for_gradients(xs)`.

  Returns:
    A list of `Tensor` or `CompositeTensor`, one entry per element of `xs`.
  """
  template = list(map(_get_tensors_for_gradient, xs))
  # Regroup the flat gradients so that each entry lines up with one element of
  # `xs` before rebuilding that element.
  grouped_grads = nest.pack_sequence_as(template, flat_grads)
  return list(map(_replace_tensors_for_gradient, xs, grouped_grads))
184