1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ======================================================================== 15"""Utilities to handle tensor tracer parameters.""" 16 17 18import os 19import os.path 20import re 21 22from tensorflow.python.ops import linalg_ops 23from tensorflow.python.ops import math_ops 24from tensorflow.python.platform import tf_logging as logging 25 26TRACE_MODE_PART_TENSOR = 'part-tensor' 27TRACE_MODE_FULL_TENSOR = 'full-tensor' 28TRACE_MODE_FULL_TENSOR_SUMMARY = 'full_tensor_summary' 29 30TRACE_MODE_NAN_INF = 'nan-inf' 31TRACE_MODE_NORM = 'norm' 32TRACE_MODE_MAX_ABS = 'max-abs' 33TRACE_MODE_SUMMARY = 'summary' 34# summary mode to collects a finite set of signatures for each traced tensor, 35# (such as norm, max, min, mean) and dumps it using tb summaries. 36 37# Full tensor mode dumps the whole tensor values for the traced tensors without 38# any processing on them; using tb summaries. 

# Submodes accepted by the --submode flag; 'detailed' is used when the flag is
# absent or empty.
_SUBMODE_BRIEF = 'brief'
_SUBMODE_DETAILED = 'detailed'

# Patterns that tokenize TENSOR_TRACER_FLAGS one flag at a time. They are tried
# in this order: --name="value", --name='value', --name=value, and a bare
# --name (a flag given without a value).
# Flag names may contain neither '=' nor whitespace. With a plain [^=]+ group,
# a bare flag followed by another flag (e.g. "--enable --trace_mode=norm")
# would be parsed as a single flag named "enable --trace_mode".
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=\s]+)='([^']*)'")
_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=\s]+)="([^"]*)"')
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=\s]+)=(\S*)')
_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=\s]+)\s*')

# Environment variable that holds the Tensor Tracer flags, followed by the
# names of the flags it may contain.
FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'
FLAG_NAME_ENABLE = 'enable'
FLAG_NAME_TRACE_MODE = 'trace_mode'
FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar'
FLAG_NAME_SUBMODE = 'submode'
FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
FLAG_NAME_TRACE_LEVEL = 'trace_level'
FLAG_NAME_TRACE_DIR = 'trace_dir'
FLAG_NAME_REPORT_FILE = 'report_file'
FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
FLAG_NAME_OP_RANGE = 'op_range'
# Folder to dump the pre (before tensor tracer updates) and post graphs (after
# tensor tracer updates).
FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs'
FLAG_NAME_SUMMARY_SIGNATURES = 'signatures'
FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core'
FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache'
FLAG_NAME_INSPECT_TRACE = 'inspect_trace'
FLAG_NAME_FINGERPRINT_DIR = 'use_fingerprint_subdirectory'
FLAG_FLUSH_SUMMARY = 'flush_summaries'


# All flag names accepted in TENSOR_TRACER_FLAGS. TTParameters raises a
# ValueError for any other name (see TTParameters._validate_flag_names).
VALID_FLAG_NAMES = [
    FLAG_NAME_ENABLE, FLAG_NAME_TRACE_MODE,
    FLAG_NAME_TRACE_SCALAR_OPS,
    FLAG_NAME_SUBMODE, FLAG_NAME_EXCLUDED_OPNAMES,
    FLAG_NAME_EXCLUDED_OPTYPES, FLAG_NAME_INCLUDED_OPNAMES,
    FLAG_NAME_INCLUDED_OPTYPES, FLAG_NAME_TRACE_DIR,
    FLAG_NAME_REPORT_FILE,
    FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
    FLAG_NAME_OP_RANGE,
    FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS, FLAG_NAME_TRACE_LEVEL,
    FLAG_NAME_SUMMARY_SIGNATURES, FLAG_NAME_SUMMARY_PER_CORE,
    FLAG_NAME_TEMP_CACHE_VAR, FLAG_NAME_FINGERPRINT_DIR,
    FLAG_NAME_INSPECT_TRACE, FLAG_FLUSH_SUMMARY,
]

# Value format of --op_range: "<start>:<end>" with non-negative integers.
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'

# Used when --trace_level is absent or cannot be parsed as an integer.
_TT_DEFAULT_TRACE_LEVEL = 3
_TT_PREFIX = 'tensor_tracer'

_TT_NORM = 'norm'
_TT_MAX = 'max'
_TT_MAX_ABS = 'max-abs'
_TT_MIN = 'min'
_TT_SPARSITY = 'sparsity'
_TT_MEAN = 'mean'
_TT_VAR = 'var'
_TT_SIZE = 'size'

# Summary signature identifiers: the names above prefixed with _TT_PREFIX.
TT_SUMMARY_NORM = '%s_%s' % (_TT_PREFIX, _TT_NORM)
TT_SUMMARY_MAX = '%s_%s' % (_TT_PREFIX, _TT_MAX)
TT_SUMMARY_MAX_ABS = '%s_%s' % (_TT_PREFIX, _TT_MAX_ABS)
TT_SUMMARY_MIN = '%s_%s' % (_TT_PREFIX, _TT_MIN)
TT_SUMMARY_SPARSITY = '%s_%s' % (_TT_PREFIX, _TT_SPARSITY)
TT_SUMMARY_MEAN = '%s_%s' % (_TT_PREFIX, _TT_MEAN)
TT_SUMMARY_VAR = '%s_%s' % (_TT_PREFIX, _TT_VAR)
TT_SUMMARY_SIZE = '%s_%s' % (_TT_PREFIX, _TT_SIZE)

# Every signature that may be requested through --signatures.
TT_SUMMARY_SIGNATURES = (TT_SUMMARY_NORM, TT_SUMMARY_MAX, TT_SUMMARY_MIN,
                         TT_SUMMARY_SPARSITY, TT_SUMMARY_MEAN, TT_SUMMARY_VAR,
                         TT_SUMMARY_SIZE, TT_SUMMARY_MAX_ABS)


class TTParameters(object):
  """A class that handles the parameters of Tensor Tracer.

  The parameters are parsed from the TENSOR_TRACER_FLAGS entry of an
  environment mapping. That entry holds whitespace-separated flags of the form
  `--name=value`, `--name='value'`, `--name="value"` or a bare `--name`.
  """

  def __init__(self, env=None):
    """Parses and validates all Tensor Tracer flags.

    Args:
      env: mapping used instead of `os.environ` to look up TENSOR_TRACER_FLAGS
        and TEST_UNDECLARED_OUTPUTS_DIR. Defaults to `os.environ` when None.
        An explicitly empty mapping is honored, meaning no flags are set.

    Raises:
      ValueError: if an unknown flag name, trace mode or submode is given, if
        the output-directory flags conflict, or (only when Tensor Tracer is
        enabled) if a summary mode is used without a trace directory.
      re.error: if an op name/type filter is not a valid regular expression.
    """
    # Compare against None so that an empty mapping is not silently replaced
    # by the process environment.
    self._env = os.environ if env is None else env
    self._validate_flag_names()
    self.trace_mode = self._get_trace_mode()
    self.submode = self._get_submode()
    self.trace_dir = self._get_trace_dir()
    self.report_file_path = self._get_report_filepath()
    self.op_range = self._get_op_range()
    self.excluded_opname_re_list = self._flag_value_to_re_list(
        FLAG_NAME_EXCLUDED_OPNAMES)
    self.excluded_optype_re_list = self._flag_value_to_re_list(
        FLAG_NAME_EXCLUDED_OPTYPES)

    self.included_opname_re_list = self._flag_value_to_re_list(
        FLAG_NAME_INCLUDED_OPNAMES)
    self.included_optype_re_list = self._flag_value_to_re_list(
        FLAG_NAME_INCLUDED_OPTYPES)

    self.trace_scalar_ops = self.is_flag_on(FLAG_NAME_TRACE_SCALAR_OPS)
    self.use_compact_trace = self.trace_mode in (TRACE_MODE_NAN_INF,
                                                 TRACE_MODE_NORM,
                                                 TRACE_MODE_MAX_ABS,
                                                 TRACE_MODE_SUMMARY)
    self.use_temp_cache_var = self.is_flag_on(FLAG_NAME_TEMP_CACHE_VAR)
    self.inspect_trace = self.is_flag_on(FLAG_NAME_INSPECT_TRACE)
    self.use_fingerprint_subdir = self.is_flag_on(FLAG_NAME_FINGERPRINT_DIR)

    _, self.graph_dump_path = self.get_flag_value(
        FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS)
    self.trace_level = self._get_flag_int_value(FLAG_NAME_TRACE_LEVEL,
                                                _TT_DEFAULT_TRACE_LEVEL)
    self.summary_signatures = self._get_summary_signatures()
    self.collect_summary_per_core = self.is_flag_on(FLAG_NAME_SUMMARY_PER_CORE)
    # TODO(b/199284834): Will be resolved with referenced bug.
    if self.collect_summary_per_core:
      logging.warning('Aggregate signatures are approximate for mean, variance'
                      ' and sparsity.')
    self.flush_summaries_with_outside_compile = self.is_flag_on(
        FLAG_FLUSH_SUMMARY)
    # Do not produce errors or warnings if Tensor Tracer is not enabled.
    if self.is_enabled():
      self._check_flag_errors()

  def _check_flag_errors(self):
    """Validates flag combinations that only matter when tracing is enabled.

    Raises:
      ValueError: if a summary trace mode is used without a trace directory,
        given either via --trace_dir or via use_test_undeclared_outputs_dir.
    """
    if self.trace_mode in (TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY):
      if not self.trace_dir:
        raise ValueError('trace_dir must be explicitly provided in '
                         'TENSOR_TRACER_FLAGS when summary mode is used.')

  def _get_report_filepath(self):
    """Returns the path of the output report file.

    When use_test_undeclared_outputs_dir is on, a non-empty report path is
    resolved relative to the TEST_UNDECLARED_OUTPUTS_DIR environment variable.

    Returns:
      The (possibly resolved) value of --report_file; None when the flag is
      absent or given without a value.

    Raises:
      ValueError: if use_test_undeclared_outputs_dir is on and the report path
        is absolute, or TEST_UNDECLARED_OUTPUTS_DIR is not set.
    """
    found, report_file_path = self.get_flag_value(FLAG_NAME_REPORT_FILE)
    if found and report_file_path and self.use_test_undeclared_outputs_dir():
      if os.path.isabs(report_file_path):
        raise ValueError('If use_test_undeclared_outputs_dir is set, '
                         'report_file_path cannot be an absolute path (%s)'
                         % report_file_path)
      outputs_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
      if outputs_dir is None:
        # Fail with an explanation instead of a TypeError from os.path.join.
        raise ValueError('--%s is set but the environment variable %s is not '
                         'defined.' % (FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
                                       _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR))
      report_file_path = os.path.join(outputs_dir, report_file_path)
    return report_file_path

  def _get_op_range(self):
    """Returns the index range of the Ops that we will consider tracing.

    Returns:
      A pair of ints parsed from "--op_range=<start>:<end>", or (-1, -1), which
      means all ops, when the flag is absent, empty or malformed.
    """
    found, op_range = self.get_flag_value(FLAG_NAME_OP_RANGE)
    match = _OP_RANGE_PAT.match(op_range) if found and op_range else None
    if not match:
      return (-1, -1)  # this means including all ops.
    return (int(match.group(1)), int(match.group(2)))

  def _get_trace_dir(self):
    """Returns the trace directory, or None if none is configured.

    Raises:
      ValueError: if both --trace_dir and use_test_undeclared_outputs_dir are
        given.
    """
    found, trace_dir = self.get_flag_value(FLAG_NAME_TRACE_DIR)
    use_outputs_dir = self.use_test_undeclared_outputs_dir()
    if found and trace_dir and use_outputs_dir:
      raise ValueError(
          'Cannot use --%s and --%s at the same time' %
          (FLAG_NAME_TRACE_DIR, FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
    if use_outputs_dir:
      trace_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
    return trace_dir

  def _get_trace_mode(self):
    """Returns the validated trace mode; defaults to TRACE_MODE_NORM.

    Raises:
      ValueError: if the given trace mode is not a known mode.
    """
    found, trace_mode = self.get_flag_value(FLAG_NAME_TRACE_MODE)
    if not found or not trace_mode:
      trace_mode = TRACE_MODE_NORM
    valid_trace_modes = [
        TRACE_MODE_NAN_INF, TRACE_MODE_PART_TENSOR, TRACE_MODE_FULL_TENSOR,
        TRACE_MODE_NORM, TRACE_MODE_MAX_ABS,
        TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY
    ]
    if trace_mode not in valid_trace_modes:
      raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer. '
                       'Valid trace modes are: %s' % (trace_mode,
                                                      valid_trace_modes))
    return trace_mode

  def is_brief_mode(self):
    """Returns True if the submode is 'brief'."""
    return self.submode == _SUBMODE_BRIEF

  def _get_submode(self):
    """Returns the validated submode; defaults to 'detailed'.

    Raises:
      ValueError: if the given submode is not a known submode.
    """
    found, submode = self.get_flag_value(FLAG_NAME_SUBMODE)
    if not found or not submode:
      submode = _SUBMODE_DETAILED
    valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF]
    if submode not in valid_submodes:
      raise ValueError('Invalid submode "%s" given to the Tensor_Tracer. '
                       'Valid submodes are: %s' % (submode, valid_submodes))
    return submode

  @staticmethod
  def match_next_flag(flags, pos):
    """Returns the match for the next TensorTracer flag.

    Args:
      flags: a string that contains the flags.
      pos: where in flags to start the search.

    Returns:
      A pair where the first element is the regular-expression match found
      (None if there is no flag at `pos`) and the second element indicates
      whether the match has a value (group 2).
    """
    # Quoted forms are tried first so that a quoted value may contain spaces.
    for pattern in (_FLAG_DOUBLE_QUOTE_PAT, _FLAG_SINGLE_QUOTE_PAT,
                    _FLAG_NO_QUOTE_PAT):
      match = pattern.match(flags, pos)
      if match:
        return match, True
    match = _FLAG_NO_EQUAL_PAT.match(flags, pos)
    if match:
      # The flag is found but is not given a value.
      return match, False
    # The flag is not found.
    return None, False

  def _validate_flag_names(self):
    """Validates if the TensorTrace flags passed are valid.

    Raises:
      ValueError: if a flag name is not in VALID_FLAG_NAMES.
    """
    tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return
    pos = 0
    while True:
      match, _ = TTParameters.match_next_flag(tensor_tracer_flags, pos)
      if not match:
        break
      flag_name = match.group(1)
      if flag_name not in VALID_FLAG_NAMES:
        raise ValueError(
            'The flag name "%s" passed via the environment variable "%s" '
            'is invalid. Valid flag names are:'
            '\n%s' % (flag_name, FLAGS_ENV_VAR, VALID_FLAG_NAMES))
      pos = match.end()

  def _supported_signatures(self):
    """Returns a tuple of supported signatures."""
    return TT_SUMMARY_SIGNATURES

  def _get_summary_signatures(self):
    """Verifies and returns the summary signatures.

    Each requested signature may be given with or without the
    'tensor_tracer_' prefix. Unknown signatures are logged and skipped, and
    repeated signatures are kept only once.

    Returns:
      A dictionary of the signature identifiers {signature: index} that will be
      computed when trace_mode is summary. Indices are contiguous from 0.
    """
    signatures = self._flag_value_as_list(FLAG_NAME_SUMMARY_SIGNATURES)
    supported_signatures = self._supported_signatures()

    tt_signatures = []
    for signature in signatures:
      signature_with_prefix = '%s_%s' % (_TT_PREFIX, signature)
      if signature in supported_signatures:
        canonical = signature
      elif signature_with_prefix in supported_signatures:
        canonical = signature_with_prefix
      else:
        logging.warning('Unknown signature:%s. Supported signatures: %s',
                        signature, supported_signatures)
        continue
      # A repeated signature would otherwise keep only its last index in the
      # dict below, leaving a gap in the index range.
      if canonical not in tt_signatures:
        tt_signatures.append(canonical)
    if not tt_signatures:
      # Default case collects max-abs and norm only.
      return {TT_SUMMARY_MAX_ABS: 0, TT_SUMMARY_NORM: 1}
    return {signature: idx for idx, signature in enumerate(tt_signatures)}

  def get_signature_to_agg_fn_map(self):
    """Returns a map that contains the aggregate function for each signature."""
    # TODO(b/199284834): Aggregations are not accurate for mean and sparsity if
    # cores have a different number of elements. Variance uses the maximal core
    # variance.
    return {TRACE_MODE_NORM: linalg_ops.norm,
            TRACE_MODE_MAX_ABS: math_ops.reduce_max,
            TRACE_MODE_NAN_INF: math_ops.reduce_max,
            TT_SUMMARY_NORM: linalg_ops.norm,
            TT_SUMMARY_MAX: math_ops.reduce_max,
            TT_SUMMARY_MAX_ABS:
                lambda t, axis=0: math_ops.reduce_max(math_ops.abs(t),  # pylint: disable=g-long-lambda
                                                      axis=axis),
            TT_SUMMARY_MIN: math_ops.reduce_min,
            # Exact if each part has the same number of values.
            TT_SUMMARY_SPARSITY: math_ops.reduce_mean,
            TT_SUMMARY_MEAN: math_ops.reduce_mean,
            TT_SUMMARY_VAR: math_ops.reduce_max,  # Simply reduce max variance.
            TT_SUMMARY_SIZE: math_ops.reduce_sum}

  def _flag_value_as_list(self, wanted_flag_name):
    """Returns the string list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated values of the flag; an empty list if the flag is
      absent, empty, or given without a value.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    # A bare flag yields None; splitting it would raise AttributeError.
    if not found or not flag_value:
      return []
    return flag_value.split(',')

  def _flag_value_as_int_list(self, wanted_flag_name):
    """Returns the integer list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated integer values of the flag; an empty list if the
      flag is absent, has no value, or any element is not an integer (the
      latter is logged).
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found or not flag_value:
      return []
    try:
      return [int(int_val) for int_val in flag_value.split(',')]
    except ValueError:
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return []

  def _get_flag_int_value(self, wanted_flag_name, default_value):
    """Returns the int value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.
      default_value: the default value for the flag, if not provided.

    Returns:
      The integer value of the flag, or `default_value` if the flag is absent,
      given without a value, or not an integer (the latter two are logged).
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found:
      return default_value
    try:
      return int(flag_value)
    except (TypeError, ValueError):
      # TypeError: the flag was given without a value (flag_value is None).
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return default_value

  def get_flag_value(self, wanted_flag_name):
    """Returns the value of a TensorTracer flags.

    Flags are scanned left to right; the first occurrence of the flag wins.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      A pair where the first element indicates if the flag is
      found and the second element is the value of the flag (None when the
      flag is absent or given without a value).
    """
    tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return False, None
    pos = 0
    while True:
      match, has_value = TTParameters.match_next_flag(
          tensor_tracer_flags, pos)
      if not match:
        return False, None
      if match.group(1) == wanted_flag_name:
        return True, match.group(2) if has_value else None
      # Every match consumes at least "--x", so pos strictly increases.
      pos = match.end()

  def _flag_value_to_re_list(self, flag_name):
    """Converts the comma-separated value of a flag to compiled regexes.

    Args:
      flag_name: the name of the flag whose value holds the patterns.

    Returns:
      A list of compiled patterns; empty if the flag is absent or empty.

    Raises:
      re.error: if a pattern is not a valid regular expression.
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found or not flag_value:
      return []
    # Skip empty entries (e.g. from a trailing comma): an empty pattern
    # matches every string, which is almost never intended.
    return [re.compile(v) for v in flag_value.split(',') if v]

  def is_flag_on(self, flag_name):
    """Returns True if the given flag is on.

    A bare `--flag` is on; otherwise the value must be one of 1/t/true/y/yes
    (case-insensitive).
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found:
      return False
    if flag_value is None:
      return True
    return flag_value.lower() in ('1', 't', 'true', 'y', 'yes')

  def is_enabled(self):
    """Returns True if TensorTracer is enabled."""
    if not self.is_flag_on(FLAG_NAME_ENABLE):
      return False
    logging.debug('Tensor Tracer is enabled with flags %s.',
                  self._env.get(FLAGS_ENV_VAR))
    return True

  def use_test_undeclared_outputs_dir(self):
    """Decides the output directory of the report and trace files.

    Returns:
      True if the output files should be written to the
      test-undeclared-outputs-directory defined via an
      env variable.
    """
    return self.is_flag_on(FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)