xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/tests/call_to_builtin_function_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Simple call to a builtin function."""
16
import unittest
import unittest.mock

import tensorflow as tf

from tensorflow.python.autograph.tests import reference_test_base
22
23
24# TODO(mdan): Add tests for all builtins.
25
26
def dict_call(x):
  """Builds a one-entry dict through a direct call to the `dict` builtin.

  Args:
    x: Value stored under the 'foo' key.

  Returns:
    A dict equal to {'foo': x}.
  """
  mapping = dict(foo=x)
  return mapping
29
30
def dict_call_aliased(x):
  """Calls a local function that has been bound to the name `dict`.

  The local assignment shadows the builtin, so the `dict(x)` call below
  resolves to the identity helper rather than the builtin `dict`.

  Args:
    x: Any value.

  Returns:
    `x` itself, unchanged.
  """
  def _identity(value):
    return value

  dict = _identity  # pylint:disable=redefined-builtin
  return dict(x)
37
38
def dict_call_dynamic(x):
  """Calls the `dict` builtin through a reference returned by a helper.

  The builtin is not named at the call site; a nested function returns it and
  the returned object is then invoked with a keyword argument.

  Args:
    x: Value stored under the 'foo' key.

  Returns:
    A dict equal to {'foo': x}.
  """
  def get_dict_builtin():
    return dict

  dict_fn = get_dict_builtin()
  return dict_fn(foo=x)
45
46
def len_call(x):
  """Returns the length of `x` via a direct call to the `len` builtin.

  Args:
    x: Any object supporting `len()`.

  Returns:
    `len(x)`.
  """
  length = len(x)
  return length
49
50
def nested_call(x):
  """Returns the indices of `x` using nested builtin calls.

  Args:
    x: Any object supporting `len()`.

  Returns:
    The list [0, 1, ..., len(x) - 1]; empty when `x` is empty.
  """
  indices = list(range(len(x)))
  return indices
53
54
def nested_cast(x):
  """Converts `x` to int and then to float using nested builtin calls.

  Args:
    x: A value accepted by `int()`.

  Returns:
    `float(int(x))`; the int step truncates any fractional part first.
  """
  as_float = float(int(x))
  return as_float
57
58
def len_call_aliased(x):
  """Calls a local function that has been bound to the name `len`.

  The local assignment shadows the builtin, so the `len(x)` call below
  resolves to the pass-through helper rather than the builtin `len`.

  Args:
    x: Any value.

  Returns:
    `x` itself, unchanged.
  """
  def _passthrough(value):
    return value

  len = _passthrough  # pylint:disable=redefined-builtin
  return len(x)
66
67
def len_call_dynamic(x):
  """Calls the `len` builtin through a reference returned by a helper.

  Args:
    x: Any object supporting `len()`.

  Returns:
    `len(x)`.
  """
  def get_len_builtin():
    return len

  len_fn = get_len_builtin()
  return len_fn(x)
75
76
def len_call_on_mock():
  """Returns `len()` of a freshly created `unittest.mock.MagicMock`.

  `MagicMock` configures `__len__` to return 0 by default, so this yields 0.
  `unittest.mock` is a submodule that a bare `import unittest` does not load;
  the module must import `unittest.mock` explicitly rather than relying on
  another package having imported it first.

  Returns:
    The length reported by the mock (0 with its default configuration).
  """
  mock_value = unittest.mock.MagicMock()
  return len(mock_value)
80
81
class ReferenceTest(reference_test_base.TestCase):
  """Compares each builtin-call fixture with its eager result.

  Both test methods run the same fixture/argument pairs through
  `assertFunctionMatchesEager`; they differ only in `all_inputs_tensors`.
  """

  def _assert_builtin_fixtures_match_eager(self):
    """Runs every (fixture, argument) pair through assertFunctionMatchesEager.

    The case list is rebuilt on each call so that no list argument object is
    shared between test methods.
    """
    cases = [
        (dict_call, 1),
        (len_call, [1, 2]),
        (dict_call_aliased, 1),
        (len_call_aliased, [1, 2]),
        (dict_call_dynamic, 1),
        (len_call_dynamic, [1, 2]),
        (nested_call, []),
        (nested_call, [1, 2, 3]),
    ]
    for fixture, arg in cases:
      self.assertFunctionMatchesEager(fixture, arg)

  def test_basic(self):
    self._assert_builtin_fixtures_match_eager()

  def test_basic_tensor(self):
    # NOTE(review): presumably makes the base class feed the arguments as
    # tensors rather than Python values; confirm in reference_test_base.
    self.all_inputs_tensors = True
    self._assert_builtin_fixtures_match_eager()
104
105
if __name__ == '__main__':
  # Entry point when this file is executed directly: hand control to
  # TensorFlow's test runner, which runs the test cases defined above.
  tf.test.main()
108