# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for the shared functions and classes for tfdbg CLI."""
from collections import namedtuple

from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest


class BytesToReadableStrTest(test_util.TensorFlowTestCase):
  """Tests for cli_shared.bytes_to_readable_str."""

  def testNoneSizeWorks(self):
    self.assertEqual(str(None), cli_shared.bytes_to_readable_str(None))

  def testSizesBelowOneKiloByteWorks(self):
    self.assertEqual("0", cli_shared.bytes_to_readable_str(0))
    self.assertEqual("500", cli_shared.bytes_to_readable_str(500))
    self.assertEqual("1023", cli_shared.bytes_to_readable_str(1023))

  def testSizesBetweenOneKiloByteandOneMegaByteWorks(self):
    self.assertEqual("1.00k", cli_shared.bytes_to_readable_str(1024))
    self.assertEqual("2.40k", cli_shared.bytes_to_readable_str(int(1024 * 2.4)))
    self.assertEqual("1023.00k", cli_shared.bytes_to_readable_str(1024 * 1023))

  def testSizesBetweenOneMegaByteandOneGigaByteWorks(self):
    self.assertEqual("1.00M", cli_shared.bytes_to_readable_str(1024**2))
    self.assertEqual("2.40M",
                     cli_shared.bytes_to_readable_str(int(1024**2 * 2.4)))
    self.assertEqual("1023.00M",
                     cli_shared.bytes_to_readable_str(1024**2 * 1023))

  def testSizeAboveOneGigaByteWorks(self):
    self.assertEqual("1.00G", cli_shared.bytes_to_readable_str(1024**3))
    self.assertEqual("2000.00G",
                     cli_shared.bytes_to_readable_str(1024**3 * 2000))

  def testReadableStrIncludesBAtTheEndOnRequest(self):
    self.assertEqual("0B", cli_shared.bytes_to_readable_str(0, include_b=True))
    self.assertEqual(
        "1.00kB", cli_shared.bytes_to_readable_str(
            1024, include_b=True))
    self.assertEqual(
        "1.00MB", cli_shared.bytes_to_readable_str(
            1024**2, include_b=True))
    self.assertEqual(
        "1.00GB", cli_shared.bytes_to_readable_str(
            1024**3, include_b=True))


class TimeToReadableStrTest(test_util.TensorFlowTestCase):
  """Tests for cli_shared.time_to_readable_str (input in microseconds)."""

  def testNoneTimeWorks(self):
    self.assertEqual("0", cli_shared.time_to_readable_str(None))

  def testMicrosecondsTime(self):
    self.assertEqual("40us", cli_shared.time_to_readable_str(40))

  def testMillisecondTime(self):
    self.assertEqual("40ms", cli_shared.time_to_readable_str(40e3))

  def testSecondTime(self):
    self.assertEqual("40s", cli_shared.time_to_readable_str(40e6))

  def testForceTimeUnit(self):
    self.assertEqual("40s",
                     cli_shared.time_to_readable_str(
                         40e6, force_time_unit=cli_shared.TIME_UNIT_S))
    self.assertEqual("40000ms",
                     cli_shared.time_to_readable_str(
                         40e6, force_time_unit=cli_shared.TIME_UNIT_MS))
    self.assertEqual("40000000us",
                     cli_shared.time_to_readable_str(
                         40e6, force_time_unit=cli_shared.TIME_UNIT_US))
    self.assertEqual("4e-05s",
                     cli_shared.time_to_readable_str(
                         40, force_time_unit=cli_shared.TIME_UNIT_S))
    self.assertEqual("0",
                     cli_shared.time_to_readable_str(
                         0, force_time_unit=cli_shared.TIME_UNIT_S))

    with self.assertRaisesRegex(ValueError, r"Invalid time unit: ks"):
      cli_shared.time_to_readable_str(100, force_time_unit="ks")


@test_util.run_v1_only("tfdbg CLI is for tf.Session only")
class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
  """Tests for cli_shared.get_run_start_intro / get_run_short_description.

  The assertions below index into fixed line positions of the rendered
  run-start intro (e.g. line 4 for the first fetch), so they are sensitive to
  the intro's exact layout.
  """

  def setUp(self):
    self.const_a = constant_op.constant(11.0, name="a")
    self.const_b = constant_op.constant(22.0, name="b")
    self.const_c = constant_op.constant(33.0, name="c")

    self.sparse_d = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[1.0, 2.0], dense_shape=[3, 3])

  def tearDown(self):
    # Drop the constants created in setUp so tests do not share graph state.
    ops.reset_default_graph()

  def testSingleFetchNoFeeds(self):
    run_start_intro = cli_shared.get_run_start_intro(12, self.const_a, None, {})

    # Verify line about run() call number.
    self.assertEndsWith(run_start_intro.lines[1], "run() call #12:")

    # Verify line about fetch.
    const_a_name_line = run_start_intro.lines[4]
    self.assertEqual(self.const_a.name, const_a_name_line.strip())

    # Verify line about feeds.
    feeds_line = run_start_intro.lines[7]
    self.assertEqual("(Empty)", feeds_line.strip())

    # Verify lines about possible commands and their font attributes.
    self.assertEqual("run:", run_start_intro.lines[11][2:])
    annot = run_start_intro.font_attr_segs[11][0]
    self.assertEqual(2, annot[0])
    self.assertEqual(5, annot[1])
    self.assertEqual("run", annot[2][0].content)
    self.assertEqual("bold", annot[2][1])
    annot = run_start_intro.font_attr_segs[13][0]
    self.assertEqual(2, annot[0])
    self.assertEqual(8, annot[1])
    self.assertEqual("run -n", annot[2][0].content)
    self.assertEqual("bold", annot[2][1])
    self.assertEqual("run -t <T>:", run_start_intro.lines[15][2:])
    self.assertEqual([(2, 12, "bold")], run_start_intro.font_attr_segs[15])
    self.assertEqual("run -f <filter_name>:", run_start_intro.lines[17][2:])
    self.assertEqual([(2, 22, "bold")], run_start_intro.font_attr_segs[17])

    # Verify short description.
    description = cli_shared.get_run_short_description(12, self.const_a, None)
    self.assertEqual("run #12: 1 fetch (a:0); 0 feeds", description)

    # Verify the main menu associated with the run_start_intro.
    self.assertIn(debugger_cli_common.MAIN_MENU_KEY,
                  run_start_intro.annotations)
    menu = run_start_intro.annotations[debugger_cli_common.MAIN_MENU_KEY]
    self.assertEqual("run", menu.caption_to_item("run").content)
    self.assertEqual("exit", menu.caption_to_item("exit").content)

  def testSparseTensorAsFeedShouldHandleNoNameAttribute(self):
    sparse_feed_val = ([[0, 0], [1, 1]], [10.0, 20.0])
    run_start_intro = cli_shared.get_run_start_intro(
        1, self.sparse_d, {self.sparse_d: sparse_feed_val}, {})
    self.assertEqual(str(self.sparse_d), run_start_intro.lines[7].strip())

    short_description = cli_shared.get_run_short_description(
        1, self.sparse_d, {self.sparse_d: sparse_feed_val})
    self.assertEqual(
        "run #1: 1 fetch; 1 feed (%s)" % self.sparse_d, short_description)

  def testSparseTensorAsFetchShouldHandleNoNameAttribute(self):
    run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
    self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())

  def testTwoFetchesListNoFeeds(self):
    fetches = [self.const_a, self.const_b]
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    const_a_name_line = run_start_intro.lines[4]
    const_b_name_line = run_start_intro.lines[5]
    self.assertEqual(self.const_a.name, const_a_name_line.strip())
    self.assertEqual(self.const_b.name, const_b_name_line.strip())

    feeds_line = run_start_intro.lines[8]
    self.assertEqual("(Empty)", feeds_line.strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 2 fetches; 0 feeds", description)

  def testNestedListAsFetches(self):
    fetches = [self.const_c, [self.const_a, self.const_b]]
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    # Verify lines about the fetches.
    self.assertEqual(self.const_c.name, run_start_intro.lines[4].strip())
    self.assertEqual(self.const_a.name, run_start_intro.lines[5].strip())
    self.assertEqual(self.const_b.name, run_start_intro.lines[6].strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 3 fetches; 0 feeds", description)

  def testNestedDictAsFetches(self):
    fetches = {"c": self.const_c, "ab": {"a": self.const_a, "b": self.const_b}}
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    # Verify lines about the fetches. The ordering of the dict keys is
    # indeterminate.
    fetch_names = set()
    fetch_names.add(run_start_intro.lines[4].strip())
    fetch_names.add(run_start_intro.lines[5].strip())
    fetch_names.add(run_start_intro.lines[6].strip())

    self.assertEqual({"a:0", "b:0", "c:0"}, fetch_names)

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 3 fetches; 0 feeds", description)

  def testTwoFetchesAsTupleNoFeeds(self):
    fetches = (self.const_a, self.const_b)
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    const_a_name_line = run_start_intro.lines[4]
    const_b_name_line = run_start_intro.lines[5]
    self.assertEqual(self.const_a.name, const_a_name_line.strip())
    self.assertEqual(self.const_b.name, const_b_name_line.strip())

    feeds_line = run_start_intro.lines[8]
    self.assertEqual("(Empty)", feeds_line.strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 2 fetches; 0 feeds", description)

  def testTwoFetchesAsNamedTupleNoFeeds(self):
    fetches_namedtuple = namedtuple("fetches", "x y")
    fetches = fetches_namedtuple(self.const_b, self.const_c)
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    const_b_name_line = run_start_intro.lines[4]
    const_c_name_line = run_start_intro.lines[5]
    self.assertEqual(self.const_b.name, const_b_name_line.strip())
    self.assertEqual(self.const_c.name, const_c_name_line.strip())

    feeds_line = run_start_intro.lines[8]
    self.assertEqual("(Empty)", feeds_line.strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 2 fetches; 0 feeds", description)

  def testWithFeedDict(self):
    feed_dict = {
        self.const_a: 10.0,
        self.const_b: 20.0,
    }

    run_start_intro = cli_shared.get_run_start_intro(1, self.const_c, feed_dict,
                                                     {})

    const_c_name_line = run_start_intro.lines[4]
    self.assertEqual(self.const_c.name, const_c_name_line.strip())

    # Verify lines about the feed dict.
    feed_a_line = run_start_intro.lines[7]
    feed_b_line = run_start_intro.lines[8]
    self.assertEqual(self.const_a.name, feed_a_line.strip())
    self.assertEqual(self.const_b.name, feed_b_line.strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, self.const_c,
                                                       feed_dict)
    self.assertEqual("run #1: 1 fetch (c:0); 2 feeds", description)

  def testTensorFilters(self):
    feed_dict = {self.const_a: 10.0}
    tensor_filters = {
        "filter_a": lambda x: True,
        "filter_b": lambda x: False,
    }

    run_start_intro = cli_shared.get_run_start_intro(1, self.const_c, feed_dict,
                                                     tensor_filters)

    # Verify the listed names of the tensor filters.
    filter_names = set()
    filter_names.add(run_start_intro.lines[20].split(" ")[-1])
    filter_names.add(run_start_intro.lines[21].split(" ")[-1])

    self.assertEqual({"filter_a", "filter_b"}, filter_names)

    # Verify short description.
    description = cli_shared.get_run_short_description(1, self.const_c,
                                                       feed_dict)
    self.assertEqual("run #1: 1 fetch (c:0); 1 feed (a:0)", description)

    # Verify the command links for the two filters.
    command_set = set()
    annot = run_start_intro.font_attr_segs[20][0]
    command_set.add(annot[2].content)
    annot = run_start_intro.font_attr_segs[21][0]
    command_set.add(annot[2].content)
    self.assertEqual({"run -f filter_a", "run -f filter_b"}, command_set)

  def testGetRunShortDescriptionWorksForTensorFeedKey(self):
    short_description = cli_shared.get_run_short_description(
        1, self.const_a, {self.const_a: 42.0})
    self.assertEqual("run #1: 1 fetch (a:0); 1 feed (a:0)", short_description)

  def testGetRunShortDescriptionWorksForUnicodeFeedKey(self):
    short_description = cli_shared.get_run_short_description(
        1, self.const_a, {u"foo": 42.0})
    self.assertEqual("run #1: 1 fetch (a:0); 1 feed (foo)", short_description)


@test_util.run_v1_only("tfdbg CLI is for tf.Session only")
class GetErrorIntroTest(test_util.TensorFlowTestCase):
  """Tests for cli_shared.get_error_intro."""

  def setUp(self):
    self.var_a = variables.Variable(42.0, name="a")

  def tearDown(self):
    # Drop the variable created in setUp so tests do not share graph state.
    ops.reset_default_graph()

  def testShapeError(self):
    tf_error = errors.OpError(None, self.var_a.initializer, "foo description",
                              None)

    error_intro = cli_shared.get_error_intro(tf_error)

    self.assertEqual("!!! An error occurred during the run !!!",
                     error_intro.lines[1])
    self.assertEqual([(0, len(error_intro.lines[1]), "blink")],
                     error_intro.font_attr_segs[1])

    self.assertEqual(2, error_intro.lines[4].index("ni -a -d -t a/Assign"))
    self.assertEqual(2, error_intro.font_attr_segs[4][0][0])
    self.assertEqual(22, error_intro.font_attr_segs[4][0][1])
    self.assertEqual("ni -a -d -t a/Assign",
                     error_intro.font_attr_segs[4][0][2][0].content)
    self.assertEqual("bold", error_intro.font_attr_segs[4][0][2][1])

    self.assertEqual(2, error_intro.lines[6].index("li -r a/Assign"))
    self.assertEqual(2, error_intro.font_attr_segs[6][0][0])
    self.assertEqual(16, error_intro.font_attr_segs[6][0][1])
    self.assertEqual("li -r a/Assign",
                     error_intro.font_attr_segs[6][0][2][0].content)
    self.assertEqual("bold", error_intro.font_attr_segs[6][0][2][1])

    self.assertEqual(2, error_intro.lines[8].index("lt"))
    self.assertEqual(2, error_intro.font_attr_segs[8][0][0])
    self.assertEqual(4, error_intro.font_attr_segs[8][0][1])
    self.assertEqual("lt", error_intro.font_attr_segs[8][0][2][0].content)
    self.assertEqual("bold", error_intro.font_attr_segs[8][0][2][1])

    # assertStartsWith/assertEndsWith report the actual line on failure,
    # unlike assertTrue(s.endswith(...)).
    self.assertStartsWith(error_intro.lines[11], "Op name:")
    self.assertEndsWith(error_intro.lines[11], "a/Assign")

    self.assertStartsWith(error_intro.lines[12], "Error type:")
    self.assertEndsWith(error_intro.lines[12], str(type(tf_error)))

    self.assertEqual("Details:", error_intro.lines[14])
    self.assertStartsWith(error_intro.lines[15], "foo description")

  def testGetErrorIntroForNoOpName(self):
    tf_error = errors.OpError(None, None, "Fake OpError", -1)
    error_intro = cli_shared.get_error_intro(tf_error)
    self.assertIn("Cannot determine the name of the op", error_intro.lines[3])


if __name__ == "__main__":
  googletest.main()