xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/cli/cli_shared_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Unit tests for the shared functions and classes for tfdbg CLI."""
16from collections import namedtuple
17
18from tensorflow.python.debug.cli import cli_shared
19from tensorflow.python.debug.cli import debugger_cli_common
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import errors
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import sparse_tensor
24from tensorflow.python.framework import test_util
25from tensorflow.python.ops import variables
26from tensorflow.python.platform import googletest
27
28
class BytesToReadableStrTest(test_util.TensorFlowTestCase):
  """Checks the byte-count formatting done by cli_shared.bytes_to_readable_str."""

  def _assertReadable(self, cases, **kwargs):
    """Asserts each (num_bytes, expected) pair formats as expected.

    Args:
      cases: Iterable of (num_bytes, expected_str) pairs, checked in order.
      **kwargs: Extra keyword arguments forwarded verbatim to
        bytes_to_readable_str (e.g. include_b=True).
    """
    for num_bytes, expected in cases:
      self.assertEqual(expected,
                       cli_shared.bytes_to_readable_str(num_bytes, **kwargs))

  def testNoneSizeWorks(self):
    self._assertReadable([(None, str(None))])

  def testSizesBelowOneKiloByteWorks(self):
    self._assertReadable([(0, "0"), (500, "500"), (1023, "1023")])

  def testSizesBetweenOneKiloByteandOneMegaByteWorks(self):
    self._assertReadable([
        (1024, "1.00k"),
        (int(1024 * 2.4), "2.40k"),
        (1024 * 1023, "1023.00k"),
    ])

  def testSizesBetweenOneMegaByteandOneGigaByteWorks(self):
    mega = 1024**2
    self._assertReadable([
        (mega, "1.00M"),
        (int(mega * 2.4), "2.40M"),
        (mega * 1023, "1023.00M"),
    ])

  def testSizeAboveOneGigaByteWorks(self):
    giga = 1024**3
    self._assertReadable([(giga, "1.00G"), (giga * 2000, "2000.00G")])

  def testReadableStrIncludesBAtTheEndOnRequest(self):
    # With include_b=True every magnitude gets a trailing "B".
    self._assertReadable(
        [(0, "0B"), (1024, "1.00kB"), (1024**2, "1.00MB"),
         (1024**3, "1.00GB")],
        include_b=True)
67
68
class TimeToReadableStrTest(test_util.TensorFlowTestCase):
  """Checks the duration formatting done by cli_shared.time_to_readable_str.

  Inputs are in microseconds: 40 formats as "40us" and 40e6 as "40s".
  """

  def testNoneTimeWorks(self):
    self.assertEqual("0", cli_shared.time_to_readable_str(None))

  def testMicrosecondsTime(self):
    self.assertEqual("40us", cli_shared.time_to_readable_str(40))

  def testMillisecondTime(self):
    self.assertEqual("40ms", cli_shared.time_to_readable_str(40e3))

  def testSecondTime(self):
    self.assertEqual("40s", cli_shared.time_to_readable_str(40e6))

  def testForceTimeUnit(self):
    # (expected output, input duration, forced unit), checked in this order.
    forced_cases = (
        ("40s", 40e6, cli_shared.TIME_UNIT_S),
        ("40000ms", 40e6, cli_shared.TIME_UNIT_MS),
        ("40000000us", 40e6, cli_shared.TIME_UNIT_US),
        ("4e-05s", 40, cli_shared.TIME_UNIT_S),
        ("0", 0, cli_shared.TIME_UNIT_S),
    )
    for expected, duration, unit in forced_cases:
      self.assertEqual(
          expected,
          cli_shared.time_to_readable_str(duration, force_time_unit=unit))

    # An unrecognized unit string must be rejected.
    with self.assertRaisesRegex(ValueError, r"Invalid time unit: ks"):
      cli_shared.time_to_readable_str(100, force_time_unit="ks")
102
103
@test_util.run_v1_only("tfdbg CLI is for tf.Session only")
class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
  """Tests cli_shared.get_run_start_intro and get_run_short_description.

  Most assertions index into `run_start_intro.lines` and
  `run_start_intro.font_attr_segs` at fixed positions. They are therefore tied
  to the exact layout of the generated intro text.
  """

  def setUp(self):
    # Run TensorFlowTestCase's own per-test setup first; the original override
    # skipped it entirely.
    # NOTE(review): the base setUp is expected to reset the default graph, so it
    # must run before the tensors below are created, not after.
    super().setUp()
    self.const_a = constant_op.constant(11.0, name="a")
    self.const_b = constant_op.constant(22.0, name="b")
    self.const_c = constant_op.constant(33.0, name="c")

    # Used by the tests that check handling of a fetch/feed key without a
    # `name` attribute.
    self.sparse_d = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[1.0, 2.0], dense_shape=[3, 3])

  def tearDown(self):
    ops.reset_default_graph()
    super().tearDown()

  def testSingleFetchNoFeeds(self):
    run_start_intro = cli_shared.get_run_start_intro(12, self.const_a, None, {})

    # Verify line about run() call number.
    self.assertEndsWith(run_start_intro.lines[1], "run() call #12:")

    # Verify line about fetch.
    const_a_name_line = run_start_intro.lines[4]
    self.assertEqual(self.const_a.name, const_a_name_line.strip())

    # Verify line about feeds.
    feeds_line = run_start_intro.lines[7]
    self.assertEqual("(Empty)", feeds_line.strip())

    # Verify lines about possible commands and their font attributes.
    self.assertEqual("run:", run_start_intro.lines[11][2:])
    annot = run_start_intro.font_attr_segs[11][0]
    self.assertEqual(2, annot[0])
    self.assertEqual(5, annot[1])
    self.assertEqual("run", annot[2][0].content)
    self.assertEqual("bold", annot[2][1])
    annot = run_start_intro.font_attr_segs[13][0]
    self.assertEqual(2, annot[0])
    self.assertEqual(8, annot[1])
    self.assertEqual("run -n", annot[2][0].content)
    self.assertEqual("bold", annot[2][1])
    self.assertEqual("run -t <T>:", run_start_intro.lines[15][2:])
    self.assertEqual([(2, 12, "bold")], run_start_intro.font_attr_segs[15])
    self.assertEqual("run -f <filter_name>:", run_start_intro.lines[17][2:])
    self.assertEqual([(2, 22, "bold")], run_start_intro.font_attr_segs[17])

    # Verify short description.
    description = cli_shared.get_run_short_description(12, self.const_a, None)
    self.assertEqual("run #12: 1 fetch (a:0); 0 feeds", description)

    # Verify the main menu associated with the run_start_intro.
    self.assertIn(debugger_cli_common.MAIN_MENU_KEY,
                  run_start_intro.annotations)
    menu = run_start_intro.annotations[debugger_cli_common.MAIN_MENU_KEY]
    self.assertEqual("run", menu.caption_to_item("run").content)
    self.assertEqual("exit", menu.caption_to_item("exit").content)

  def testSparseTensorAsFeedShouldHandleNoNameAttribute(self):
    sparse_feed_val = ([[0, 0], [1, 1]], [10.0, 20.0])
    run_start_intro = cli_shared.get_run_start_intro(
        1, self.sparse_d, {self.sparse_d: sparse_feed_val}, {})
    self.assertEqual(str(self.sparse_d), run_start_intro.lines[7].strip())

    short_description = cli_shared.get_run_short_description(
        1, self.sparse_d, {self.sparse_d: sparse_feed_val})
    self.assertEqual(
        "run #1: 1 fetch; 1 feed (%s)" % self.sparse_d, short_description)

  def testSparseTensorAsFetchShouldHandleNoNameAttribute(self):
    run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
    self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())

  def testTwoFetchesListNoFeeds(self):
    fetches = [self.const_a, self.const_b]
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    const_a_name_line = run_start_intro.lines[4]
    const_b_name_line = run_start_intro.lines[5]
    self.assertEqual(self.const_a.name, const_a_name_line.strip())
    self.assertEqual(self.const_b.name, const_b_name_line.strip())

    feeds_line = run_start_intro.lines[8]
    self.assertEqual("(Empty)", feeds_line.strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 2 fetches; 0 feeds", description)

  def testNestedListAsFetches(self):
    fetches = [self.const_c, [self.const_a, self.const_b]]
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    # Verify lines about the fetches.
    self.assertEqual(self.const_c.name, run_start_intro.lines[4].strip())
    self.assertEqual(self.const_a.name, run_start_intro.lines[5].strip())
    self.assertEqual(self.const_b.name, run_start_intro.lines[6].strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 3 fetches; 0 feeds", description)

  def testNestedDictAsFetches(self):
    fetches = {"c": self.const_c, "ab": {"a": self.const_a, "b": self.const_b}}
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    # Verify lines about the fetches. The ordering of the dict keys is
    # indeterminate, so compare as a set.
    fetch_names = set()
    fetch_names.add(run_start_intro.lines[4].strip())
    fetch_names.add(run_start_intro.lines[5].strip())
    fetch_names.add(run_start_intro.lines[6].strip())

    self.assertEqual({"a:0", "b:0", "c:0"}, fetch_names)

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 3 fetches; 0 feeds", description)

  def testTwoFetchesAsTupleNoFeeds(self):
    fetches = (self.const_a, self.const_b)
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    const_a_name_line = run_start_intro.lines[4]
    const_b_name_line = run_start_intro.lines[5]
    self.assertEqual(self.const_a.name, const_a_name_line.strip())
    self.assertEqual(self.const_b.name, const_b_name_line.strip())

    feeds_line = run_start_intro.lines[8]
    self.assertEqual("(Empty)", feeds_line.strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 2 fetches; 0 feeds", description)

  def testTwoFetchesAsNamedTupleNoFeeds(self):
    fetches_namedtuple = namedtuple("fetches", "x y")
    fetches = fetches_namedtuple(self.const_b, self.const_c)
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    const_b_name_line = run_start_intro.lines[4]
    const_c_name_line = run_start_intro.lines[5]
    self.assertEqual(self.const_b.name, const_b_name_line.strip())
    self.assertEqual(self.const_c.name, const_c_name_line.strip())

    feeds_line = run_start_intro.lines[8]
    self.assertEqual("(Empty)", feeds_line.strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 2 fetches; 0 feeds", description)

  def testWithFeedDict(self):
    feed_dict = {
        self.const_a: 10.0,
        self.const_b: 20.0,
    }

    run_start_intro = cli_shared.get_run_start_intro(1, self.const_c, feed_dict,
                                                     {})

    const_c_name_line = run_start_intro.lines[4]
    self.assertEqual(self.const_c.name, const_c_name_line.strip())

    # Verify lines about the feed dict.
    feed_a_line = run_start_intro.lines[7]
    feed_b_line = run_start_intro.lines[8]
    self.assertEqual(self.const_a.name, feed_a_line.strip())
    self.assertEqual(self.const_b.name, feed_b_line.strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, self.const_c,
                                                       feed_dict)
    self.assertEqual("run #1: 1 fetch (c:0); 2 feeds", description)

  def testTensorFilters(self):
    feed_dict = {self.const_a: 10.0}
    tensor_filters = {
        "filter_a": lambda x: True,
        "filter_b": lambda x: False,
    }

    run_start_intro = cli_shared.get_run_start_intro(1, self.const_c, feed_dict,
                                                     tensor_filters)

    # Verify the listed names of the tensor filters (order not asserted).
    filter_names = set()
    filter_names.add(run_start_intro.lines[20].split(" ")[-1])
    filter_names.add(run_start_intro.lines[21].split(" ")[-1])

    self.assertEqual({"filter_a", "filter_b"}, filter_names)

    # Verify short description.
    description = cli_shared.get_run_short_description(1, self.const_c,
                                                       feed_dict)
    self.assertEqual("run #1: 1 fetch (c:0); 1 feed (a:0)", description)

    # Verify the command links for the two filters.
    command_set = set()
    annot = run_start_intro.font_attr_segs[20][0]
    command_set.add(annot[2].content)
    annot = run_start_intro.font_attr_segs[21][0]
    command_set.add(annot[2].content)
    self.assertEqual({"run -f filter_a", "run -f filter_b"}, command_set)

  def testGetRunShortDescriptionWorksForTensorFeedKey(self):
    short_description = cli_shared.get_run_short_description(
        1, self.const_a, {self.const_a: 42.0})
    self.assertEqual("run #1: 1 fetch (a:0); 1 feed (a:0)", short_description)

  def testGetRunShortDescriptionWorksForUnicodeFeedKey(self):
    short_description = cli_shared.get_run_short_description(
        1, self.const_a, {u"foo": 42.0})
    self.assertEqual("run #1: 1 fetch (a:0); 1 feed (foo)", short_description)
316
317
@test_util.run_v1_only("tfdbg CLI is for tf.Session only")
class GetErrorIntroTest(test_util.TensorFlowTestCase):
  """Tests cli_shared.get_error_intro for tf.errors.OpError instances."""

  def setUp(self):
    # Run TensorFlowTestCase's own per-test setup first; the original override
    # skipped it entirely.
    # NOTE(review): the base setUp is expected to reset the default graph, so it
    # must run before the variable below is created.
    super().setUp()
    self.var_a = variables.Variable(42.0, name="a")

  def tearDown(self):
    ops.reset_default_graph()
    super().tearDown()

  def testShapeError(self):
    """Checks the error intro for an OpError attached to the `a/Assign` op.

    Despite the name, the error is a generic OpError built from the variable's
    initializer op, not a shape-specific error type.
    """
    tf_error = errors.OpError(None, self.var_a.initializer, "foo description",
                              None)

    error_intro = cli_shared.get_error_intro(tf_error)

    self.assertEqual("!!! An error occurred during the run !!!",
                     error_intro.lines[1])
    self.assertEqual([(0, len(error_intro.lines[1]), "blink")],
                     error_intro.font_attr_segs[1])

    # Suggested command: node info for the failing op.
    self.assertEqual(2, error_intro.lines[4].index("ni -a -d -t a/Assign"))
    self.assertEqual(2, error_intro.font_attr_segs[4][0][0])
    self.assertEqual(22, error_intro.font_attr_segs[4][0][1])
    self.assertEqual("ni -a -d -t a/Assign",
                     error_intro.font_attr_segs[4][0][2][0].content)
    self.assertEqual("bold", error_intro.font_attr_segs[4][0][2][1])

    # Suggested command: recursive inputs of the failing op.
    self.assertEqual(2, error_intro.lines[6].index("li -r a/Assign"))
    self.assertEqual(2, error_intro.font_attr_segs[6][0][0])
    self.assertEqual(16, error_intro.font_attr_segs[6][0][1])
    self.assertEqual("li -r a/Assign",
                     error_intro.font_attr_segs[6][0][2][0].content)
    self.assertEqual("bold", error_intro.font_attr_segs[6][0][2][1])

    # Suggested command: list tensors.
    self.assertEqual(2, error_intro.lines[8].index("lt"))
    self.assertEqual(2, error_intro.font_attr_segs[8][0][0])
    self.assertEqual(4, error_intro.font_attr_segs[8][0][1])
    self.assertEqual("lt", error_intro.font_attr_segs[8][0][2][0].content)
    self.assertEqual("bold", error_intro.font_attr_segs[8][0][2][1])

    self.assertStartsWith(error_intro.lines[11], "Op name:")
    self.assertEndsWith(error_intro.lines[11], "a/Assign")

    self.assertStartsWith(error_intro.lines[12], "Error type:")
    self.assertEndsWith(error_intro.lines[12], str(type(tf_error)))

    self.assertEqual("Details:", error_intro.lines[14])
    self.assertStartsWith(error_intro.lines[15], "foo description")

  def testGetErrorIntroForNoOpName(self):
    # An OpError with no op attached must still produce an intro.
    tf_error = errors.OpError(None, None, "Fake OpError", -1)
    error_intro = cli_shared.get_error_intro(tf_error)
    self.assertIn("Cannot determine the name of the op", error_intro.lines[3])
371
372
if __name__ == "__main__":
  # Run every test case in this module through TensorFlow's googletest runner.
  googletest.main()
375