xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/cli/cli_config.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Configurations for TensorFlow Debugger (TFDBG) command-line interfaces."""
16import collections
17import json
18import os
19
20from tensorflow.python.debug.cli import debugger_cli_common
21from tensorflow.python.platform import gfile
22
# Short alias for the rich-text line class; used below to assemble the lines
# of the config summary.
RL = debugger_cli_common.RichLine
24
25
class CLIConfig(object):
  """Client-facing configurations for TFDBG command-line interfaces.

  The configuration is an ordered mapping from property names to values. It is
  seeded from `_DEFAULT_CONFIG`, overlaid with whatever can be read from the
  config file, and written back to that file on construction and after every
  successful `set()` call.
  """

  _CONFIG_FILE_NAME = ".tfdbg_config"

  # (name, default value) pairs. The type of each default value determines how
  # `set()` coerces new values for that property.
  _DEFAULT_CONFIG = [
      ("graph_recursion_depth", 20),
      ("mouse_mode", True),
  ]

  def __init__(self, config_file_path=None):
    """Constructor of CLIConfig.

    Args:
      config_file_path: Path of the config file to load from and save to. If
        falsy, `_CONFIG_FILE_NAME` under the user's home directory is used.
    """
    self._config_file_path = (config_file_path or
                              self._default_config_file_path())
    self._config = collections.OrderedDict(self._DEFAULT_CONFIG)
    if gfile.Exists(self._config_file_path):
      # Values from the file override the defaults; keys that are absent from
      # the file keep their default values.
      self._config.update(self._load_from_file())
    self._save_to_file()

    self._set_callbacks = {}

  def get(self, property_name):
    """Get the value of a property.

    Args:
      property_name: Name of the property.

    Returns:
      The current value of the property.

    Raises:
      KeyError: if `property_name` is an invalid property name.
    """
    if property_name not in self._config:
      raise KeyError("%s is not a valid property name." % property_name)
    return self._config[property_name]

  def set(self, property_name, property_val):
    """Set the value of a property.

    Supports limited property value types: `bool`, `int` and `str`. The new
    value is coerced to the type of the property's current value, the config
    is saved to the config file, and the set-callback registered for the
    property (if any) is invoked.

    Args:
      property_name: Name of the property.
      property_val: Value of the property. If the property has `bool` type and
        this argument has `str` type, the `str` value will be parsed as a `bool`

    Raises:
      ValueError: if a `str` property_value fails to be parsed as a `bool`, or
        if the value cannot be converted by `int()` for an `int` property.
      KeyError: if `property_name` is an invalid property name.
      TypeError: if the current value of the property is not a `bool`, `int`
        or `str`.
    """
    if property_name not in self._config:
      raise KeyError("%s is not a valid property name." % property_name)

    orig_val = self._config[property_name]
    # `bool` is a subclass of `int`, so the `bool` case must be tested first.
    if isinstance(orig_val, bool):
      if isinstance(property_val, str):
        lowered = property_val.lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
          property_val = True
        elif lowered in ("0", "false", "f", "no", "n", "off"):
          property_val = False
        else:
          raise ValueError(
              "Invalid string value for bool type: %s" % property_val)
      else:
        property_val = bool(property_val)
    elif isinstance(orig_val, int):
      property_val = int(property_val)
    elif isinstance(orig_val, str):
      property_val = str(property_val)
    else:
      raise TypeError("Unsupported property type: %s" % type(orig_val))
    self._config[property_name] = property_val
    self._save_to_file()

    # Invoke set-callback.
    if property_name in self._set_callbacks:
      self._set_callbacks[property_name](self._config)

  def set_callback(self, property_name, callback):
    """Set a set-callback for given property.

    Registering a callback replaces any callback previously registered for the
    same property.

    Args:
      property_name: Name of the property.
      callback: The callback as a `callable` of signature:
          def cbk(config):
        where config is the config after it is set to the new value.
        The callback is invoked each time the set() method is called with the
        matching property_name.

    Raises:
      KeyError: If property_name does not exist.
      TypeError: If `callback` is not callable.
    """
    if property_name not in self._config:
      raise KeyError("%s is not a valid property name." % property_name)
    if not callable(callback):
      raise TypeError("The callback object provided is not callable.")
    self._set_callbacks[property_name] = callback

  def _default_config_file_path(self):
    """Return the path of `_CONFIG_FILE_NAME` in the user's home directory."""
    return os.path.join(os.path.expanduser("~"), self._CONFIG_FILE_NAME)

  def _save_to_file(self):
    """Write the config to the config file as JSON, on a best-effort basis."""
    try:
      with gfile.Open(self._config_file_path, "w") as config_file:
        json.dump(self._config, config_file)
    except IOError:
      # Failing to persist the config is deliberately not treated as an error.
      pass

  def summarize(self, highlight=None):
    """Get a text summary of the config.

    Args:
      highlight: A property name to highlight in the output.

    Returns:
      A `RichTextLines` output.
    """
    lines = [RL("Command-line configuration:", "bold"), RL("")]
    for name, val in self._config.items():
      # Only the property matching `highlight` gets the extra "bold" attribute.
      highlight_attr = "bold" if name == highlight else None
      line = RL("  ")
      line += RL(name, ["underline", highlight_attr])
      line += RL(": ")
      line += RL(str(val), font_attr=highlight_attr)
      lines.append(line)
    return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)

  def _load_from_file(self):
    """Read the config file and return its content as a key-sorted mapping.

    Returns:
      A `collections.OrderedDict` with the keys in sorted order, or an empty
      `dict` if the file cannot be read, is not valid JSON, or does not hold a
      JSON object.
    """
    try:
      with gfile.Open(self._config_file_path, "r") as config_file:
        loaded = json.load(config_file)
    except (IOError, ValueError):
      # The reading of the config file may fail due to IO issues or file
      # corruption. We do not want tfdbg to error out just because of that.
      return dict()
    return self._ordered_config(loaded)

  @staticmethod
  def _ordered_config(loaded):
    """Convert a decoded JSON payload to a key-sorted config mapping.

    Args:
      loaded: The value decoded from the config file.

    Returns:
      A `collections.OrderedDict` holding the items of `loaded` in sorted key
      order, or an empty `dict` if `loaded` is not a `dict`.
    """
    if not isinstance(loaded, dict):
      # Valid JSON that is not an object (e.g. a list, number or null) is
      # treated like a corrupted file instead of raising AttributeError.
      return dict()
    return collections.OrderedDict(
        (key, loaded[key]) for key in sorted(loaded))
157