xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/tensor_tracer_flags.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ========================================================================
15"""Utilities to handle tensor tracer parameters."""
16
17
18import os
19import os.path
20import re
21
22from tensorflow.python.ops import linalg_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.platform import tf_logging as logging
25
TRACE_MODE_PART_TENSOR = 'part-tensor'
TRACE_MODE_FULL_TENSOR = 'full-tensor'
# Full tensor summary mode dumps the whole values of the traced tensors,
# without any processing on them, using tb (TensorBoard) summaries.
TRACE_MODE_FULL_TENSOR_SUMMARY = 'full_tensor_summary'

TRACE_MODE_NAN_INF = 'nan-inf'
TRACE_MODE_NORM = 'norm'
TRACE_MODE_MAX_ABS = 'max-abs'
# Summary mode collects a finite set of signatures for each traced tensor
# (such as norm, max, min, mean) and dumps them using tb summaries.
TRACE_MODE_SUMMARY = 'summary'

_SUBMODE_BRIEF = 'brief'
_SUBMODE_DETAILED = 'detailed'

# Patterns for one "--name..." flag inside the flags environment variable;
# TTParameters.match_next_flag tries them in turn. Group 1 is the flag name and
# group 2 (when present) its value. Flag names may contain neither '=' nor
# whitespace: if whitespace were allowed, a bare flag followed by another flag
# (e.g. "--enable --trace_mode=norm") would be read as a single flag named
# "enable --trace_mode".
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=\s]+)='([^']*)'")
_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=\s]+)="([^"]*)"')
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=\s]+)=(\S*)')
_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=\s]+)\s*')
47
# Environment variable from which all Tensor Tracer flags are read.
FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'

# --- Flag names (the text between "--" and "=" in FLAGS_ENV_VAR). ---
FLAG_NAME_ENABLE = 'enable'
FLAG_NAME_TRACE_MODE = 'trace_mode'
FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar'
FLAG_NAME_SUBMODE = 'submode'

# Op filters; each value is a comma-separated list of regular expressions.
FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'

FLAG_NAME_TRACE_LEVEL = 'trace_level'

# Output locations.
FLAG_NAME_TRACE_DIR = 'trace_dir'
FLAG_NAME_REPORT_FILE = 'report_file'
FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'

FLAG_NAME_OP_RANGE = 'op_range'

# Folder to dump the pre (before tensor tracer updates) and post (after tensor
# tracer updates) graphs.
FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs'

FLAG_NAME_SUMMARY_SIGNATURES = 'signatures'
FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core'
FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache'
FLAG_NAME_INSPECT_TRACE = 'inspect_trace'
FLAG_NAME_FINGERPRINT_DIR = 'use_fingerprint_subdirectory'
FLAG_FLUSH_SUMMARY = 'flush_summaries'


# Every flag name accepted in FLAGS_ENV_VAR. The order matters only for the
# listing printed in the "invalid flag name" error message.
VALID_FLAG_NAMES = [
    FLAG_NAME_ENABLE,
    FLAG_NAME_TRACE_MODE,
    FLAG_NAME_TRACE_SCALAR_OPS,
    FLAG_NAME_SUBMODE,
    FLAG_NAME_EXCLUDED_OPNAMES,
    FLAG_NAME_EXCLUDED_OPTYPES,
    FLAG_NAME_INCLUDED_OPNAMES,
    FLAG_NAME_INCLUDED_OPTYPES,
    FLAG_NAME_TRACE_DIR,
    FLAG_NAME_REPORT_FILE,
    FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
    FLAG_NAME_OP_RANGE,
    FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS,
    FLAG_NAME_TRACE_LEVEL,
    FLAG_NAME_SUMMARY_SIGNATURES,
    FLAG_NAME_SUMMARY_PER_CORE,
    FLAG_NAME_TEMP_CACHE_VAR,
    FLAG_NAME_FINGERPRINT_DIR,
    FLAG_NAME_INSPECT_TRACE,
    FLAG_FLUSH_SUMMARY,
]
87
# An --op_range value: two non-negative integers separated by a colon.
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'

_TT_DEFAULT_TRACE_LEVEL = 3
_TT_PREFIX = 'tensor_tracer'

# Short (unprefixed) signature names, as they may be given in --signatures.
_TT_NORM = 'norm'
_TT_MAX = 'max'
_TT_MAX_ABS = 'max-abs'
_TT_MIN = 'min'
_TT_SPARSITY = 'sparsity'
_TT_MEAN = 'mean'
_TT_VAR = 'var'
_TT_SIZE = 'size'

# Fully qualified signature names: "<_TT_PREFIX>_<short name>".
TT_SUMMARY_NORM = '_'.join((_TT_PREFIX, _TT_NORM))
TT_SUMMARY_MAX = '_'.join((_TT_PREFIX, _TT_MAX))
TT_SUMMARY_MAX_ABS = '_'.join((_TT_PREFIX, _TT_MAX_ABS))
TT_SUMMARY_MIN = '_'.join((_TT_PREFIX, _TT_MIN))
TT_SUMMARY_SPARSITY = '_'.join((_TT_PREFIX, _TT_SPARSITY))
TT_SUMMARY_MEAN = '_'.join((_TT_PREFIX, _TT_MEAN))
TT_SUMMARY_VAR = '_'.join((_TT_PREFIX, _TT_VAR))
TT_SUMMARY_SIZE = '_'.join((_TT_PREFIX, _TT_SIZE))

# All signatures accepted by --signatures.
TT_SUMMARY_SIGNATURES = (
    TT_SUMMARY_NORM,
    TT_SUMMARY_MAX,
    TT_SUMMARY_MIN,
    TT_SUMMARY_SPARSITY,
    TT_SUMMARY_MEAN,
    TT_SUMMARY_VAR,
    TT_SUMMARY_SIZE,
    TT_SUMMARY_MAX_ABS,
)
115
116
class TTParameters(object):
  """A class that handles the parameters of Tensor Tracer.

  Parameters are parsed from the TENSOR_TRACER_FLAGS environment variable,
  which holds flags of the form --name=value, --name="value", --name='value'
  or a bare --name, separated by whitespace.
  """

  def __init__(self, env=None):
    """Parses and validates the Tensor Tracer flags.

    Args:
      env: optional mapping used instead of os.environ to look up environment
        variables. An explicitly passed mapping is always used, even if empty.

    Raises:
      ValueError: if a flag name is unknown or a flag value is invalid.
    """
    # Compare with None so that an explicitly empty mapping (e.g. in tests) is
    # not silently replaced by the real process environment.
    self._env = os.environ if env is None else env
    self._validate_flag_names()
    self.trace_mode = self._get_trace_mode()
    self.submode = self._get_submode()
    self.trace_dir = self._get_trace_dir()
    self.report_file_path = self._get_report_filepath()
    self.op_range = self._get_op_range()
    self.excluded_opname_re_list = self._flag_value_to_re_list(
        FLAG_NAME_EXCLUDED_OPNAMES)
    self.excluded_optype_re_list = self._flag_value_to_re_list(
        FLAG_NAME_EXCLUDED_OPTYPES)

    self.included_opname_re_list = self._flag_value_to_re_list(
        FLAG_NAME_INCLUDED_OPNAMES)
    self.included_optype_re_list = self._flag_value_to_re_list(
        FLAG_NAME_INCLUDED_OPTYPES)

    self.trace_scalar_ops = self.is_flag_on(FLAG_NAME_TRACE_SCALAR_OPS)
    self.use_compact_trace = self.trace_mode in (TRACE_MODE_NAN_INF,
                                                 TRACE_MODE_NORM,
                                                 TRACE_MODE_MAX_ABS,
                                                 TRACE_MODE_SUMMARY)
    self.use_temp_cache_var = self.is_flag_on(FLAG_NAME_TEMP_CACHE_VAR)
    self.inspect_trace = self.is_flag_on(FLAG_NAME_INSPECT_TRACE)
    self.use_fingerprint_subdir = self.is_flag_on(FLAG_NAME_FINGERPRINT_DIR)

    _, self.graph_dump_path = self.get_flag_value(
        FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS)
    self.trace_level = self._get_flag_int_value(FLAG_NAME_TRACE_LEVEL,
                                                _TT_DEFAULT_TRACE_LEVEL)
    self.summary_signatures = self._get_summary_signatures()
    self.collect_summary_per_core = self.is_flag_on(FLAG_NAME_SUMMARY_PER_CORE)
    # TODO(b/199284834): Will be resolved with referenced bug.
    if self.collect_summary_per_core:
      logging.warning('Aggregate signatures are approximate for mean, variance'
                      ' and sparsity.')
    self.flush_summaries_with_outside_compile = self.is_flag_on(
        FLAG_FLUSH_SUMMARY)
    # Do not produce errors or warnings if Tensor Tracer is not enabled.
    if self.is_enabled():
      self._check_flag_errors()

  def _check_flag_errors(self):
    """Checks flag combinations that are only invalid when tracing is enabled.

    Raises:
      ValueError: if a summary-based trace mode is used without a trace_dir.
    """
    if self.trace_mode in (TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY):
      if not self.trace_dir:
        raise ValueError('trace_dir must be explicitly provided in '
                         'TENSOR_TRACER_FLAGS when summary mode is used.')

  def _get_report_filepath(self):
    """Returns the path of the output report file.

    Returns:
      The --report_file value (None if the flag is absent or has no value).
      When --use_test_undeclared_outputs_dir is on, a non-empty path is
      resolved relative to the TEST_UNDECLARED_OUTPUTS_DIR env variable.

    Raises:
      ValueError: if --use_test_undeclared_outputs_dir is on and the report
        path is absolute, or TEST_UNDECLARED_OUTPUTS_DIR is not set.
    """
    found, report_file_path = self.get_flag_value(FLAG_NAME_REPORT_FILE)
    if found and report_file_path and self.use_test_undeclared_outputs_dir():
      if os.path.isabs(report_file_path):
        raise ValueError('If use_test_undeclared_outputs_dir is set, '
                         'report_file_path cannot be an absolute path (%s)'
                         % report_file_path)
      outputs_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
      if outputs_dir is None:
        # Otherwise os.path.join(None, ...) fails with an opaque TypeError.
        raise ValueError(
            'The environment variable %s must be set when --%s is used.' %
            (_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR,
             FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
      report_file_path = os.path.join(outputs_dir, report_file_path)
    return report_file_path

  def _get_op_range(self):
    """Returns the index range of the ops that will be considered for tracing.

    Returns:
      A pair of ints parsed from an --op_range value that starts with
      "<int>:<int>", or (-1, -1), meaning all ops, when the flag is absent,
      empty or does not start with that form.
    """
    found, op_range = self.get_flag_value(FLAG_NAME_OP_RANGE)
    match = _OP_RANGE_PAT.match(op_range) if found and op_range else None
    if not match:
      return (-1, -1)  # This means including all ops.
    return (int(match.group(1)), int(match.group(2)))

  def _get_trace_dir(self):
    """Returns the directory for trace outputs (may be None).

    Returns:
      The --trace_dir value or, when --use_test_undeclared_outputs_dir is on,
      the value of the TEST_UNDECLARED_OUTPUTS_DIR environment variable.

    Raises:
      ValueError: if a non-empty --trace_dir is combined with
        --use_test_undeclared_outputs_dir.
    """
    found, trace_dir = self.get_flag_value(FLAG_NAME_TRACE_DIR)
    use_outputs_dir = self.use_test_undeclared_outputs_dir()
    if found and trace_dir and use_outputs_dir:
      raise ValueError(
          'Cannot use --%s and --%s at the same time' %
          (FLAG_NAME_TRACE_DIR, FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
    if use_outputs_dir:
      trace_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
    return trace_dir

  def _get_trace_mode(self):
    """Returns the validated trace mode (TRACE_MODE_NORM if not given).

    Raises:
      ValueError: if the given trace mode is not supported.
    """
    found, trace_mode = self.get_flag_value(FLAG_NAME_TRACE_MODE)
    if not found or not trace_mode:
      trace_mode = TRACE_MODE_NORM
    valid_trace_modes = [
        TRACE_MODE_NAN_INF, TRACE_MODE_PART_TENSOR, TRACE_MODE_FULL_TENSOR,
        TRACE_MODE_NORM, TRACE_MODE_MAX_ABS,
        TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY
    ]
    if trace_mode not in valid_trace_modes:
      raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer. '
                       'Valid trace modes are: %s' % (trace_mode,
                                                      valid_trace_modes))
    return trace_mode

  def is_brief_mode(self):
    """Returns True if the submode is 'brief'."""
    return self.submode == _SUBMODE_BRIEF

  def _get_submode(self):
    """Returns the validated submode ('detailed' if not given).

    Raises:
      ValueError: if the given submode is not supported.
    """
    found, submode = self.get_flag_value(FLAG_NAME_SUBMODE)
    if not found or not submode:
      submode = _SUBMODE_DETAILED
    valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF]
    if submode not in valid_submodes:
      raise ValueError('Invalid submode "%s" given to the Tensor_Tracer. '
                       'Valid submodes are: %s' % (submode, valid_submodes))
    return submode

  @staticmethod
  def match_next_flag(flags, pos):
    """Returns the match for the next TensorTracer flag.

    The double-quoted, single-quoted, unquoted and valueless forms are tried
    in that order, anchored at `pos`.

    Args:
       flags: a string that contains the flags.
       pos: where in flags to start the search.

    Returns:
       A pair where the first element is the regular-expression
       match found (None if no flag starts at pos) and the second element
       indicates if the match has a value.
    """
    for pattern in (_FLAG_DOUBLE_QUOTE_PAT, _FLAG_SINGLE_QUOTE_PAT,
                    _FLAG_NO_QUOTE_PAT):
      match = pattern.match(flags, pos)
      if match:
        return match, True
    match = _FLAG_NO_EQUAL_PAT.match(flags, pos)
    if match:
      # The flag is found but is not given a value.
      return match, False
    # The flag is not found.
    return None, False

  def _validate_flag_names(self):
    """Validates that every parsed flag in TENSOR_TRACER_FLAGS has a known name.

    Parsing stops at the first position where no flag pattern matches; text
    after that point is not checked.

    Raises:
      ValueError: if a flag name is not in VALID_FLAG_NAMES.
    """
    tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return
    pos = 0
    while True:
      match, _ = TTParameters.match_next_flag(tensor_tracer_flags, pos)
      if not match:
        break
      flag_name = match.group(1)
      if flag_name not in VALID_FLAG_NAMES:
        raise ValueError(
            'The flag name "%s" passed via the environment variable "%s" '
            'is invalid. Valid flag names are:'
            '\n%s' % (flag_name, FLAGS_ENV_VAR, VALID_FLAG_NAMES))
      pos = match.end()

  def _supported_signatures(self):
    """Returns a tuple of supported signatures."""
    return TT_SUMMARY_SIGNATURES

  def _get_summary_signatures(self):
    """Verifies and returns the summary signatures.

    Each comma-separated --signatures entry may be given with or without the
    "tensor_tracer_" prefix. Unknown entries are logged and ignored; empty and
    duplicate entries are skipped.

    Returns:
      A dictionary of the signature identifiers {signature: index} that will be
      computed when trace_mode is summary. Indices are 0..len-1 in the order
      the signatures were first given.
    """
    signatures = self._flag_value_as_list(FLAG_NAME_SUMMARY_SIGNATURES)
    supported_signatures = self._supported_signatures()

    tt_signatures = []
    for signature in signatures:
      signature = signature.strip()
      if not signature:
        continue
      signature_with_prefix = '%s_%s' % (_TT_PREFIX, signature)
      if signature in supported_signatures:
        resolved = signature
      elif signature_with_prefix in supported_signatures:
        resolved = signature_with_prefix
      else:
        logging.warning('Unknown signature:%s. Supported signatures: %s',
                        signature, supported_signatures)
        continue
      # A duplicate would leave a gap in the {signature: index} mapping below.
      if resolved not in tt_signatures:
        tt_signatures.append(resolved)
    if not tt_signatures:
      # Default case collects max-abs and norm only.
      return {TT_SUMMARY_MAX_ABS: 0, TT_SUMMARY_NORM: 1}
    return {signature: idx for idx, signature in enumerate(tt_signatures)}

  def get_signature_to_agg_fn_map(self):
    """Returns a map that contains the aggregate function for each signature."""
    # TODO(b/199284834): Aggregations are not accurate for mean and sparsity if
    # cores have a different number of elements. Variance uses the maximal core
    # variance.
    return {TRACE_MODE_NORM: linalg_ops.norm,
            TRACE_MODE_MAX_ABS: math_ops.reduce_max,
            TRACE_MODE_NAN_INF: math_ops.reduce_max,
            TT_SUMMARY_NORM: linalg_ops.norm,
            TT_SUMMARY_MAX: math_ops.reduce_max,
            TT_SUMMARY_MAX_ABS:
                lambda t, axis=0: math_ops.reduce_max(math_ops.abs(t),  # pylint: disable=g-long-lambda
                                                      axis=axis),
            TT_SUMMARY_MIN: math_ops.reduce_min,
            # Exact if each part has the same number of values.
            TT_SUMMARY_SPARSITY: math_ops.reduce_mean,
            TT_SUMMARY_MEAN: math_ops.reduce_mean,
            TT_SUMMARY_VAR: math_ops.reduce_max,  # Simply reduce max variance.
            TT_SUMMARY_SIZE: math_ops.reduce_sum}

  def _flag_value_as_list(self, wanted_flag_name):
    """Returns the string list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated values of the flag, or [] if the flag is absent,
      has no value, or has an empty value.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    # A bare "--flag" yields None, which has no split() method.
    if not found or not flag_value:
      return []
    return flag_value.split(',')

  def _flag_value_as_int_list(self, wanted_flag_name):
    """Returns the integer list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated integer values of the flag, or [] if the flag is
      absent, empty, or any entry is not an integer (a warning is logged in
      the last case).
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found or not flag_value:
      return []
    try:
      return [int(int_val) for int_val in flag_value.split(',')]
    except ValueError:
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return []

  def _get_flag_int_value(self, wanted_flag_name, default_value):
    """Returns the int value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.
      default_value: the default value for the flag, if not provided.

    Returns:
      The flag value as an int, or default_value if the flag is absent, has
      no value, or is not an integer (a warning is logged in the last two
      cases).
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found:
      return default_value
    try:
      return int(flag_value)
    except (TypeError, ValueError):
      # TypeError: a bare "--flag" has the value None.
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return default_value

  def get_flag_value(self, wanted_flag_name):
    """Returns the value of a TensorTracer flag.

    Only the first occurrence of the flag is considered.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      A pair where the first element indicates if the flag is
      found and the second element is the value of the flag (None when
      the flag is not found or given without a value).
    """
    tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return False, None
    pos = 0
    while True:
      match, has_value = TTParameters.match_next_flag(
          tensor_tracer_flags, pos)
      if not match:
        return False, None
      if match.group(1) == wanted_flag_name:
        return True, (match.group(2) if has_value else None)
      pos = match.end()

  def _flag_value_to_re_list(self, flag_name):
    """Converts a comma-separated flag value into a list of compiled REs.

    Args:
      flag_name: the name of the flag we are looking for.

    Returns:
      A list of compiled regular expressions; [] if the flag is absent or
      empty.

    Raises:
      re.error: if an entry is not a valid regular expression.
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found or not flag_value:
      return []
    # Skip empty entries (e.g. from a trailing comma): an empty regex matches
    # every string via match()/search().
    return [re.compile(v) for v in flag_value.split(',') if v]

  def is_flag_on(self, flag_name):
    """Returns True if the given flag is on.

    A bare --flag counts as on; otherwise the value must be one of 1, t, true,
    y or yes (case-insensitive).
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found:
      return False
    if flag_value is None:
      return True
    # Depends on the flag value.
    return flag_value.lower() in ['1', 't', 'true', 'y', 'yes']

  def is_enabled(self):
    """Returns True if TensorTracer is enabled."""
    if self.is_flag_on(FLAG_NAME_ENABLE):
      logging.debug('Tensor Tracer is enabled with flags %s.',
                    self._env.get(FLAGS_ENV_VAR))
      return True
    return False

  def use_test_undeclared_outputs_dir(self):
    """Decides the output directory of the report and trace files.

    Args:
       None.

    Returns:
       True if the output files should be written to the
       test-undeclared-outputs-directory defined via an
       env variable.
    """
    return self.is_flag_on(FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)
485