# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cli_config."""
import json
import os
import tempfile

from tensorflow.python.debug.cli import cli_config
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest


class CLIConfigTest(test_util.TensorFlowTestCase):
  """Tests for `cli_config.CLIConfig`, each using a fresh temporary file."""

  def setUp(self):
    # Run the base-class setup first so it is not skipped if the
    # precondition assertion below fails.
    super(CLIConfigTest, self).setUp()
    self._tmp_dir = tempfile.mkdtemp()
    # Cleanups registered with addCleanup() run even when setUp() itself
    # fails (tearDown() does not), so the temporary directory never leaks.
    self.addCleanup(file_io.delete_recursively, self._tmp_dir)
    self._tmp_config_path = os.path.join(self._tmp_dir, ".tfdbg_config")
    # Every test must start with no config file on disk.
    self.assertFalse(gfile.Exists(self._tmp_config_path))

  def testConstructCLIConfigWithoutFile(self):
    config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    self.assertEqual(20, config.get("graph_recursion_depth"))
    self.assertEqual(True, config.get("mouse_mode"))
    with self.assertRaises(KeyError):
      config.get("property_that_should_not_exist")
    # Construction must have written the config file to disk.
    self.assertTrue(gfile.Exists(self._tmp_config_path))

  def testCLIConfigForwardCompatibilityTest(self):
    # Constructing a CLIConfig writes the config file; only that side effect
    # is needed here.
    cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    with open(self._tmp_config_path, "rt", encoding="utf-8") as f:
      config_json = json.load(f)
    # Remove a field to simulate a config file that predates that field.
    del config_json["graph_recursion_depth"]
    with open(self._tmp_config_path, "wt", encoding="utf-8") as f:
      json.dump(config_json, f)

    config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    # The missing field must fall back to its default value.
    self.assertEqual(20, config.get("graph_recursion_depth"))

  def testModifyConfigValue(self):
    config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    config.set("graph_recursion_depth", 9)
    config.set("mouse_mode", False)
    self.assertEqual(9, config.get("graph_recursion_depth"))
    self.assertEqual(False, config.get("mouse_mode"))

  def testModifyConfigValueWithTypeCasting(self):
    config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    # String values are cast to the type of the existing property.
    config.set("graph_recursion_depth", "18")
    config.set("mouse_mode", "false")
    self.assertEqual(18, config.get("graph_recursion_depth"))
    self.assertEqual(False, config.get("mouse_mode"))

  def testModifyConfigValueWithTypeCastingFailure(self):
    config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    with self.assertRaises(ValueError):
      config.set("mouse_mode", "maybe")

  def testLoadFromModifiedConfigFile(self):
    config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    config.set("graph_recursion_depth", 9)
    config.set("mouse_mode", False)
    # A second instance reading the same file must see the persisted values.
    config2 = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    self.assertEqual(9, config2.get("graph_recursion_depth"))
    self.assertEqual(False, config2.get("mouse_mode"))

  def testSummarizeFromConfig(self):
    config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    output = config.summarize()
    self.assertEqual(
        ["Command-line configuration:",
         "",
         "  graph_recursion_depth: %d" % config.get("graph_recursion_depth"),
         "  mouse_mode: %s" % config.get("mouse_mode")], output.lines)

  def testSummarizeFromConfigWithHighlight(self):
    config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    output = config.summarize(highlight="mouse_mode")
    self.assertEqual(
        ["Command-line configuration:",
         "",
         "  graph_recursion_depth: %d" % config.get("graph_recursion_depth"),
         "  mouse_mode: %s" % config.get("mouse_mode")], output.lines)
    # Line 3 is the highlighted "mouse_mode" entry: the name (columns 2-12) is
    # underlined and bold, and the value (columns 14-18) is bold.
    self.assertEqual((2, 12, ["underline", "bold"]),
                     output.font_attr_segs[3][0])
    self.assertEqual((14, 18, "bold"), output.font_attr_segs[3][1])

  def testSetCallback(self):
    config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)

    test_value = {"graph_recursion_depth": -1}

    # Use a distinct parameter name so the outer `config` is not shadowed.
    def callback(updated_config):
      test_value["graph_recursion_depth"] = updated_config.get(
          "graph_recursion_depth")

    config.set_callback("graph_recursion_depth", callback)

    config.set("graph_recursion_depth", config.get("graph_recursion_depth") - 1)
    # The callback must have observed the newly set value.
    self.assertEqual(test_value["graph_recursion_depth"],
                     config.get("graph_recursion_depth"))

  def testSetCallbackInvalidPropertyName(self):
    config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)

    with self.assertRaises(KeyError):
      config.set_callback("nonexistent_property_name", print)

  def testSetCallbackNotCallable(self):
    config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)

    with self.assertRaises(TypeError):
      config.set_callback("graph_recursion_depth", 1)


if __name__ == "__main__":
  googletest.main()