xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/core/converter_testing.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class for tests in this module."""
16
import contextlib
import imp
import inspect
import io
import sys
import types
22
23from tensorflow.python.autograph.core import config
24from tensorflow.python.autograph.core import converter
25from tensorflow.python.autograph.impl import api
26from tensorflow.python.framework import ops
27from tensorflow.python.platform import test
28
29
def allowlist(f):
  """Helper that marks a callable as allowlisted.

  On first use, registers an empty synthetic module named
  'allowlisted_module_for_testing' in `sys.modules` and prepends a
  `config.DoNotConvert` rule for that module name to
  `config.CONVERSION_RULES`. The callable is then attributed to that module
  by overwriting its `__module__`, so the rule applies to it.

  Args:
    f: The callable to mark. Its `__module__` attribute is modified in place.

  Returns:
    `f` itself, so this helper may also be used as a decorator.
  """
  module_name = 'allowlisted_module_for_testing'
  if module_name not in sys.modules:
    # `types.ModuleType` is the non-deprecated equivalent of `imp.new_module`.
    sys.modules[module_name] = types.ModuleType(module_name)
    # Prepended, presumably so it is consulted before the existing rules --
    # confirm against how config.CONVERSION_RULES is evaluated.
    config.CONVERSION_RULES = (
        (config.DoNotConvert(module_name),) + config.CONVERSION_RULES)

  f.__module__ = module_name
  return f
40
41
def is_inside_generated_code():
  """Reports whether the function calling this one is AutoGraph-generated.

  Starting from the caller's frame, frames whose code object is named
  `converted_call` or `_call_unconverted` are skipped. The first remaining
  frame is treated as generated code if it has a local named `ag__`. This
  depends on implementation details of the generated code.

  Returns:
    True if such a frame exists and defines `ag__` among its locals, False
    otherwise.
  """
  skipped_names = frozenset(('converted_call', '_call_unconverted'))
  caller = inspect.currentframe().f_back
  try:
    while caller is not None and caller.f_code.co_name in skipped_names:
      caller = caller.f_back
    return caller is not None and 'ag__' in caller.f_locals
  finally:
    # Release the frame reference promptly, per the `inspect` module's
    # guidance on holding frame objects.
    del caller
58    del frame
59
60
class TestingTranspiler(api.PyToTF):
  """Testing version that only applies given transformations.

  `transform_ast` runs the base class's `initial_analysis` followed by exactly
  the converters passed to the constructor. It records the resulting AST and
  context on the instance for later inspection.
  """

  def __init__(self, converters, ag_overrides):
    """Initializes the transpiler.

    Args:
      converters: A converter (any object exposing `transform(node, ctx)`),
        or a list/tuple of them, applied in order.
      ag_overrides: Optional dict of attributes that replace entries of the
        `ag__` module exposed to generated code. Falsy means no overrides.
    """
    super(TestingTranspiler, self).__init__()
    if isinstance(converters, (list, tuple)):
      self._converters = converters
    else:
      self._converters = (converters,)
    self._ag_overrides = ag_overrides
    # Both are populated by transform_ast; None until a transformation runs.
    self.transformed_ast = None
    self.transform_ctx = None

  def get_extra_locals(self):
    """Returns the base class's extra locals, patching `ag__` if requested.

    When overrides are set, `ag__` is replaced by a fresh module that holds a
    copy of the original `ag__` module's attributes, with the overrides
    applied on top. The original `ag__` module object is not modified.
    """
    retval = super(TestingTranspiler, self).get_extra_locals()
    if self._ag_overrides:
      # `types.ModuleType` is the non-deprecated equivalent of
      # `imp.new_module`.
      modified_ag = types.ModuleType('fake_autograph')
      modified_ag.__dict__.update(retval['ag__'].__dict__)
      modified_ag.__dict__.update(self._ag_overrides)
      # NOTE: this mutates the dict returned by the base class in place.
      retval['ag__'] = modified_ag
    return retval

  def transform_ast(self, node, ctx):
    """Runs initial analysis, then each configured converter in order.

    Args:
      node: The AST to transform.
      ctx: The transformation context passed to each converter.

    Returns:
      The transformed AST. It is also stored in `self.transformed_ast`, and
      `ctx` is stored in `self.transform_ctx`.
    """
    node = self.initial_analysis(node, ctx)

    for c in self._converters:
      node = c.transform(node, ctx)

    self.transformed_ast = node
    self.transform_ctx = ctx
    return node
91
92
class TestCase(test.TestCase):
  """Base class for unit tests in this module. Contains relevant utilities."""

  def setUp(self):
    super(TestCase, self).setUp()
    # AutoGraph tests must run in graph mode to properly test control flow.
    # NOTE: `self.graph` holds the context manager returned by
    # `Graph.as_default()`. It is entered here and exited in tearDown.
    self.graph = ops.Graph().as_default()
    self.graph.__enter__()

  def tearDown(self):
    self.graph.__exit__(None, None, None)
    super(TestCase, self).tearDown()

  @contextlib.contextmanager
  def assertPrints(self, expected_result):
    """Asserts that the wrapped code writes exactly `expected_result` to stdout.

    Output is captured only while the `with` body runs. If the body raises,
    the exception propagates and no comparison is made.

    Args:
      expected_result: The exact string expected on stdout.

    Raises:
      AssertionError: If the captured output differs from `expected_result`.
    """
    out_capturer = io.StringIO()
    # redirect_stdout restores whatever sys.stdout was before (not
    # sys.__stdout__), so this composes with outer capturing and nesting.
    with contextlib.redirect_stdout(out_capturer):
      yield
    # Compare only after stdout is restored.
    self.assertEqual(out_capturer.getvalue(), expected_result)

  def transform(
      self, f, converter_module, include_ast=False, ag_overrides=None):
    """Converts `f` applying only the given converter(s).

    Args:
      f: The function to convert.
      converter_module: A converter, or a list/tuple of converters, applied
        in order after initial analysis.
      include_ast: Whether to also return the transformed AST and the
        transformation context.
      ag_overrides: Optional dict of overrides for the `ag__` module visible
        to the converted code.

    Returns:
      The transformed function (first element of `transform_function`'s
      result). If `include_ast` is True, returns a tuple of (transformed
      function, transformed AST, transformation context) instead.
    """
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=api)

    tr = TestingTranspiler(converter_module, ag_overrides)
    transformed, _, _ = tr.transform_function(f, program_ctx)

    if include_ast:
      return transformed, tr.transformed_ast, tr.transform_ctx

    return transformed
127