xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/core/function_wrappers.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Support for wrapping converted functions bodies with auxiliary logic."""
16
import contextlib

from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.operators import variables
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.util import nest
24
25
26# TODO(mdan): Move this into operators - it represents a function definition.
27
28
class FunctionScope(object):
  """Context manager that wraps the body of a converted function.

  This context manager handles various operations related to the scope of a
  function:
    * optional TF name scopes - these name scopes match the name of the
        function, for easy visualization in TensorBoard;
    * optional automatic control dependencies - this adds the same mechanism
        for control dependencies that is used by `@tf.function`; it can be
        optionally enabled when using `tf.autograph.to_graph`;
    * tracking of autograph conversion state (whether it's enabled by the user,
        conversion options).

  On entry, the enabled sub-scopes are entered in this order: autograph
  control status, name scope, automatic control dependencies. On exit they
  are unwound in the reverse order, exactly like nested `with` statements.
  If entering one of them fails, the ones already entered are exited before
  the error propagates.
  """

  def __init__(self, function_name, scope_name, options):
    """Prepares (but does not enter) the sub-scopes selected by `options`.

    Args:
      function_name: Name of the wrapped function; sanitized and used for the
        TF name scope when the NAME_SCOPES feature is enabled.
      scope_name: Stored as `self.name`.
      options: Conversion options. `user_requested`, `call_options()` and
        `uses(...)` are read from it.
    """
    self.name = scope_name
    self.options = options
    # Holds the exit callbacks of the sub-scopes entered by __enter__.
    self._exit_stack = None

    if options.user_requested:
      self.autograph_ctx = ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED,
                                                   options)
    # Not read within this class; presumably consumed by generated code that
    # holds a reference to the scope (verify against callers).
    self.callopts = options.call_options()

    use_name_scope = options.uses(converter.Feature.NAME_SCOPES)
    self.use_name_scope = use_name_scope
    if use_name_scope:
      self.name_scope = ops.name_scope(self._sanitize(function_name))

    use_auto_deps = self.options.uses(converter.Feature.AUTO_CONTROL_DEPS)
    self.use_auto_deps = use_auto_deps
    if use_auto_deps:
      self.autodeps_scope = auto_control_deps.AutomaticControlDependencies()
      self._return_value_marked = False

  def _sanitize(self, name):
    """See https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope."""
    # TensorFlow doesn't like leading underscores at the top level.
    if name and name.startswith('_'):
      name = 'fn' + name
    return name

  def __enter__(self):
    """Enters the enabled sub-scopes and returns this object."""
    with contextlib.ExitStack() as stack:
      if self.options.user_requested:
        stack.enter_context(self.autograph_ctx)
      if self.use_name_scope:
        stack.enter_context(self.name_scope)
      if self.use_auto_deps:
        stack.enter_context(self.autodeps_scope)
      # Every sub-scope was entered successfully: transfer ownership of their
      # exit callbacks to this object. If any enter above raised instead, the
      # `with` statement unwinds the sub-scopes entered so far.
      self._exit_stack = stack.pop_all()
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Exits the sub-scopes in reverse order of entry.

    Exceptions are never suppressed (returns None). All sub-scopes are exited
    even if one of their `__exit__` methods raises.
    """
    exit_stack, self._exit_stack = self._exit_stack, None
    if exit_stack is not None:
      exit_stack.__exit__(exc_type, exc_val, exc_tb)

  def ret(self, value, did_return):
    """Marks a value as returned from the function guarded by the scope.

    Args:
      value: The value being returned. It may be a nested structure.
      did_return: Ignored.

    Returns:
      None if `value` is a `variables.UndefinedReturnValue`. Otherwise returns
      `value`; when automatic control dependencies are enabled, every TF-typed
      leaf of `value` is replaced by the result of
      `self.autodeps_scope.mark_as_return`.
    """
    del did_return

    if isinstance(value, variables.UndefinedReturnValue):
      return None

    if self.use_auto_deps:
      self._return_value_marked = True
      if value is None:
        # We don't create dummy returns, to preserve Python semantics. The user
        # is responsible for adding a return value to the top-level function.
        return None

      def _mark_return_if_tensor(t):
        if tensor_util.is_tf_type(t):
          return self.autodeps_scope.mark_as_return(t)
        return t

      value = nest.map_structure(_mark_return_if_tensor, value)
    return value
108
109
def with_function_scope(thunk, scope_name, options):
  """Calls `thunk` with an active `FunctionScope` and returns its result.

  This is an expression-style alternative to writing a
  `with FunctionScope(...) as scope:` block by hand. The scope is always
  created with the function name 'lambda_'.

  Args:
    thunk: Callable invoked with the entered `FunctionScope` as its only
      argument.
    scope_name: Forwarded to `FunctionScope` as its `scope_name`.
    options: Forwarded to `FunctionScope` as its conversion options.

  Returns:
    The value returned by `thunk`.
  """
  result = None
  function_scope = FunctionScope('lambda_', scope_name, options)
  with function_scope as active_scope:
    result = thunk(active_scope)
  return result
114