# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient support for Composite Tensors."""

import abc
import sys

from tensorflow.python.framework import composite_tensor
from tensorflow.python.util import nest


# pylint:disable=g-import-not-at-top
if sys.version_info >= (3, 8):
  from typing import Protocol
  from typing import runtime_checkable
else:
  from typing_extensions import Protocol
  from typing_extensions import runtime_checkable
# pylint:enable=g-import-not-at-top


# TODO(xjun): Add CompositeTensorGradient support for SparseTensor,
# StructuredTensor, and MaskedTensor.
class CompositeTensorGradient(object, metaclass=abc.ABCMeta):
  """Class used to help compute gradients for CompositeTensors.

  This abstract base class defines two methods: `get_gradient_components`, which
  returns the components of a value that should be included in gradients; and
  `replace_gradient_components`, which replaces the gradient components in a
  value. These methods can be used to compute the gradient of a `y` with
  respect to `x` (`grad(y, x)`) as follows:

  * If `y` is a `CompositeTensor` with `CompositeTensorGradient` `cg` =
    `y.__composite_gradient__`, then `grad(y, x)` =
    `grad(cg.get_gradient_components(y), x)`.

  * If `x` is a `CompositeTensor` with `CompositeTensorGradient` `cg` =
    `x.__composite_gradient__`, then `grad(y, x)` =
    `cg.replace_gradient_components(x, grad(y, cg.get_gradient_components(x)))`.
  """

  @abc.abstractmethod
  def get_gradient_components(self, value):
    """Returns the components of `value` that should be included in gradients.

    This method may not call TensorFlow ops, since any new ops added to the
    graph would not be properly tracked by the gradient mechanisms.

    Args:
      value: A `CompositeTensor` value.

    Returns:
      A nested structure of `Tensor` or `CompositeTensor`.  Returning `value`
      itself is allowed; the helpers in this module then treat `value` as a
      single, indivisible gradient component.
    """
    raise NotImplementedError(
        f"{type(self).__name__}.get_gradient_components()")

  @abc.abstractmethod
  def replace_gradient_components(self, value, component_grads):
    """Replaces the gradient components in `value` with `component_grads`.

    This method may not call TensorFlow ops, since any new ops added to the
    graph would not be properly tracked by the gradient mechanisms.

    Args:
      value: A value with its gradient components compatible with
        `component_grads`.
      component_grads: A nested structure of `Tensor` or `CompositeTensor` or
        `None` (for unconnected gradients).

    Returns:
      A copy of `value`, where the components that should be included in
      gradients have been replaced by `component_grads`; or `None` (if
      `component_grads` includes `None`).
    """
    raise NotImplementedError(
        f"{type(self).__name__}.replace_gradient_components()")


@runtime_checkable
class CompositeTensorGradientProtocol(Protocol):
  """Protocol for adding gradient support to CompositeTensors."""
  # The gradient helper used for values of this type; `isinstance` checks
  # against this runtime-checkable protocol test for this attribute.
  __composite_gradient__: CompositeTensorGradient


class WithValuesCompositeTensorGradient(CompositeTensorGradient):
  """CompositeTensorGradient based on `T.values` and `T.with_values`."""

  def get_gradient_components(self, value):
    return value.values

  def replace_gradient_components(self, value, component_grads):
    return value.with_values(component_grads)


def _get_tensors_for_gradient(x):
  """Returns the Tensors in `x` that should be differentiated.

  Args:
    x: A `Tensor` or `CompositeTensor`.

  Returns:
    A `Tensor` or a nested structure of `Tensor`; or `x` itself, if `x` is a
    `CompositeTensor` whose `get_gradient_components(x)` returns `x`.

  Raises:
    ValueError: If `x` is a `CompositeTensor` that does not satisfy
      `CompositeTensorGradientProtocol` (i.e. has no `__composite_gradient__`).
  """
  if not isinstance(x, composite_tensor.CompositeTensor):
    return x

  if not isinstance(x, CompositeTensorGradientProtocol):
    raise ValueError(
        f"Type {type(x).__name__} is not supported as a gradient source or "
        "gradient target.")
  composite_gradient = x.__composite_gradient__
  gradient_components = composite_gradient.get_gradient_components(x)
  # A CompositeTensor may use itself as its gradient component.  `nest` treats
  # CompositeTensors as leaves, so mapping over `x` would call this function on
  # `x` again, recursing forever; stop here and return `x` unchanged.
  if gradient_components is x:
    return x
  return nest.map_structure(_get_tensors_for_gradient, gradient_components)


def _replace_tensors_for_gradient(x, grad):
  """Replaces the tensors in `x` that should be differentiated with `grad`.

  Args:
    x: A `Tensor` or `CompositeTensor`.
    grad: A nested structure of `Tensor`, with the same structure as the value
      returned by `_get_tensors_for_gradient(x)`.

  Returns:
    A `Tensor` or `CompositeTensor`; or `None` if the component gradients
    collapse to `None` (unconnected gradient).

  Raises:
    ValueError: If `x` is a `CompositeTensor` that does not satisfy
      `CompositeTensorGradientProtocol` (i.e. has no `__composite_gradient__`).
  """
  if not isinstance(x, composite_tensor.CompositeTensor):
    return grad

  if not isinstance(x, CompositeTensorGradientProtocol):
    raise ValueError(
        f"Type {type(x).__name__} is not supported as a gradient source.")

  composite_gradient = x.__composite_gradient__
  x_components = composite_gradient.get_gradient_components(x)
  if x_components is x:
    # `x` is its own (single) gradient component, so `grad` already is the
    # component gradient.  Recursing via `nest` would revisit `x` forever.
    grad_components = grad
  else:
    grad_components = nest.map_structure_up_to(x_components,
                                               _replace_tensors_for_gradient,
                                               x_components, grad)
  if grad_components is None:
    return None
  return composite_gradient.replace_gradient_components(x, grad_components)


def get_flat_tensors_for_gradients(xs):
  """Returns a flat list of Tensors that should be differentiated for `xs`.

  Args:
    xs: A list of `Tensor`s or `CompositeTensor`s.

  Returns:
    A flat list of `Tensor`s constructed from `xs`, where `Tensor` values are
    left as-is, and `CompositeTensor`s are replaced with
    `_get_tensors_for_gradient(x)`.
  """
  return nest.flatten([_get_tensors_for_gradient(x) for x in xs])


def replace_flat_tensors_for_gradients(xs, flat_grads):
  """Replaces Tensors that should be differentiated in `xs` with `flat_grads`.

  Args:
    xs: A list of `Tensor`s or `CompositeTensor`s.
    flat_grads: A list of `Tensor`, in the order produced by
      `get_flat_tensors_for_gradients(xs)`.

  Returns:
    A list of `Tensor` or `CompositeTensor`.
  """
  # Rebuild the per-`x` nesting so each `x` receives its own slice of grads.
  xs_structure = [_get_tensors_for_gradient(x) for x in xs]
  grads = nest.pack_sequence_as(xs_structure, flat_grads)
  return [_replace_tensors_for_gradient(x, grad) for x, grad in zip(xs, grads)]