# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support for wrapping converted function bodies with auxiliary logic."""

import contextlib

from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.operators import variables
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.util import nest


# TODO(mdan): Move this into operators - it represents a function definition.


class FunctionScope(object):
  """Context manager that wraps the body of a converted function.

  This context manager handles various operations related to the scope of a
  function:
    * optional TF name scopes - these name scopes match the name of the
      function, for easy visualization in TensorBoard;
    * optional automatic control dependencies - this adds the same mechanism
      for control dependencies that is used by `@tf.function`; it can be
      optionally enabled when using `tf.autograph.to_graph`;
    * tracking of autograph conversion state (whether it's enabled by the user,
      conversion options).

  The enabled sub-scopes are entered in the order listed above (autograph
  state, name scope, control dependencies). They are exited in reverse order,
  so they nest properly. If entering one of them raises, the ones already
  entered are exited before the error propagates.
  """

  def __init__(self, function_name, scope_name, options):
    """Creates the (not yet entered) sub-scopes selected by `options`.

    Args:
      function_name: name of the converted function. After `_sanitize`, it is
        used as the TF name scope when the NAME_SCOPES feature is enabled.
      scope_name: stored as `self.name`.
      options: conversion options. This method reads `user_requested`,
        `call_options()` and `uses(...)` from it.
    """
    self.name = scope_name
    self.options = options

    if options.user_requested:
      self.autograph_ctx = ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED,
                                                   options)
    self.callopts = options.call_options()

    self.use_name_scope = options.uses(converter.Feature.NAME_SCOPES)
    if self.use_name_scope:
      self.name_scope = ops.name_scope(self._sanitize(function_name))

    self.use_auto_deps = options.uses(converter.Feature.AUTO_CONTROL_DEPS)
    if self.use_auto_deps:
      self.autodeps_scope = auto_control_deps.AutomaticControlDependencies()
    self._return_value_marked = False
    # Holds the exit callbacks of the sub-scopes entered by `__enter__` until
    # `__exit__` unwinds them; None while the scope is not active.
    self._exit_stack = None

  def _sanitize(self, name):
    """Returns `name` adjusted so it is accepted as a top-level name scope.

    See https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope.
    """
    # TensorFlow doesn't like leading underscores at the top level.
    if name and name.startswith('_'):
      name = 'fn' + name
    return name

  def __enter__(self):
    with contextlib.ExitStack() as stack:
      if self.options.user_requested:
        stack.enter_context(self.autograph_ctx)
      if self.use_name_scope:
        stack.enter_context(self.name_scope)
      if self.use_auto_deps:
        stack.enter_context(self.autodeps_scope)
      # All sub-scopes were entered successfully: take ownership of their exit
      # callbacks so that leaving this `with` statement does not unwind them.
      # Had any `__enter__` above raised, the `with` would have exited the
      # sub-scopes already entered.
      self._exit_stack = stack.pop_all()
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    exit_stack, self._exit_stack = self._exit_stack, None
    if exit_stack is not None:
      # Exits the sub-scopes in reverse order of entry. Each one still runs if
      # a later-entered one raises. The return value is discarded, so this
      # scope never suppresses an exception raised by the wrapped body.
      exit_stack.__exit__(exc_type, exc_val, exc_tb)

  def ret(self, value, did_return):
    """Marks a value as returned from the function guarded by the scope.

    Args:
      value: the value being returned; may be a nested structure.
      did_return: unused.

    Returns:
      None if `value` is a `variables.UndefinedReturnValue`. Otherwise `value`;
      when automatic control dependencies are enabled, each TF tensor in the
      structure is replaced by the result of `mark_as_return`.
    """
    del did_return

    if isinstance(value, variables.UndefinedReturnValue):
      return None

    if self.use_auto_deps:
      self._return_value_marked = True
      if value is None:
        # We don't create dummy returns, to preserve Python semantics. The user
        # is responsible for adding a return value to the top-level function.
        return None

      def _mark_return_if_tensor(t):
        if tensor_util.is_tf_type(t):
          return self.autodeps_scope.mark_as_return(t)
        return t

      value = nest.map_structure(_mark_return_if_tensor, value)
    return value


def with_function_scope(thunk, scope_name, options):
  """Inline version of the FunctionScope context manager.

  Calls `thunk(scope)` inside a `FunctionScope` whose function name is
  'lambda_', and returns the thunk's result.

  Args:
    thunk: callable taking the active `FunctionScope`.
    scope_name: forwarded to `FunctionScope` as its `scope_name`.
    options: forwarded to `FunctionScope` as its `options`.
  """
  with FunctionScope('lambda_', scope_name, options) as scope:
    return thunk(scope)