# mypy: allow-untyped-defs
import argparse
import copy
import functools
import logging
import os
import shutil
import sys
import textwrap
from importlib import import_module
from typing import Union

import torch
import torch.fx as fx
from torch._dynamo.debug_utils import (
    AccuracyError,
    backend_accuracy_fails,
    BUCK_CMD_PREFIX,
    BuckTargetWriter,
    extra_imports,
    generate_config_string,
    helper_for_dump_minify,
    InputReader,
    InputWriter,
    minifier_dir,
    NNModuleToString,
    NopInputReader,
    run_fwd_maybe_bwd,
    same_two_models,
)
from torch.fx.experimental.symbolic_shapes import fx_placeholder_targets
from torch.hub import tqdm

from .. import config
from ..backends.registry import lookup_backend, register_debug_backend
from ..debug_utils import clone_inputs_retaining_gradness


log = logging.getLogger(__name__)


inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# MAIN ENTRY POINT
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _accuracy_fails(gm, example_inputs, compiler_fn):
    return backend_accuracy_fails(
        gm,
        example_inputs,
        compiler_fn,
        only_fwd=config.repro_forward_only,
        ignore_non_fp=config.repro_ignore_non_fp,
    )


class WrapBackendDebug:
    def __init__(self, unconfigured_compiler_fn, compiler_name: str) -> None:
        functools.wraps(unconfigured_compiler_fn)(self)
        self._torchdynamo_orig_callable = unconfigured_compiler_fn  # type: ignore[attr-defined]
        self._compiler_name = compiler_name
        if hasattr(unconfigured_compiler_fn, "__name__"):
            self.__name__ = unconfigured_compiler_fn.__name__
        if hasattr(unconfigured_compiler_fn, "compiler_name"):
            self.__name__ = unconfigured_compiler_fn.compiler_name
        if hasattr(unconfigured_compiler_fn, "get_compiler_config"):
            self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config  # type: ignore[attr-defined]

    def __call__(self, gm, example_inputs, **kwargs):
        compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs)
        assert config.repro_after in ("dynamo", "aot", None)

        if config.repro_after == "dynamo":

            def add_paths(exc):
                exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py")
                if use_buck:
                    exc.buck_command = " ".join(
                        BUCK_CMD_PREFIX
                        + [BuckTargetWriter(exc.minifier_path).cmd_line_path]
                    )

            if config.repro_level == 3:
                dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name)

            # Check for either accuracy (level 4) or other types of failures.
            if config.repro_level == 4:
                # Check accuracy
                compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
                if _accuracy_fails(gm, example_inputs, compiler_fn):
                    log.warning(
                        "Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error."
                    )
                    dump_to_minify_after_dynamo(
                        fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                        example_inputs,
                        self._compiler_name,
                    )
                    exc = AccuracyError("Bad accuracy detected.")
                    add_paths(exc)
                    raise exc
            else:
                try:
                    compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
                    run_fwd_maybe_bwd(compiled_gm, example_inputs)
                except Exception as exc:
                    log.warning(
                        "Compiled Fx GraphModule failed. Creating script to minify the error."
                    )
                    if config.repro_level == 1:
                        dump_state_fn = functools.partial(
                            dump_backend_state, compiler_name=self._compiler_name
                        )
                        dump_state_fn(
                            fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs
                        )
                    elif config.repro_level == 2:
                        dump_to_minify_after_dynamo(
                            fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                            example_inputs,
                            self._compiler_name,
                        )
                    add_paths(exc)
                    raise
        else:
            compiled_gm = compiler_fn(gm, example_inputs)

        return compiled_gm
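
# A minimal usage sketch (an illustration, not part of this module's API):
# WrapBackendDebug only activates when config.repro_after == "dynamo". These
# knobs are typically set through environment variables before running the
# model; the variable names below are the conventional spellings and are an
# assumption here, since this file does not define them.
#
#   TORCHDYNAMO_REPRO_AFTER="dynamo" TORCHDYNAMO_REPRO_LEVEL=2 python model.py
#
# or, equivalently, in code:
#
#   import torch._dynamo.config as dynamo_config
#   dynamo_config.repro_after = "dynamo"  # intercept at the Dynamo-produced graph
#   dynamo_config.repro_level = 2  # 1: dump a repro directly on failure,
#                                  # 2: dump a minifier launcher on failure,
#                                  # 3: always dump a launcher, 4: check accuracy
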

def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
    """
    A minifier decorator that wraps the TorchDynamo produced FX graph modules.
    As opposed to wrap_compiler_debug, this wrapper intercepts at the
    TorchDynamo produced FX graph module. This makes it backend-agnostic to
    some degree, e.g., it is useful for minifying issues related to AOT
    Autograd tracing. If an error is found, we run the minifier and save the
    minified repro, by default as repro.py.
    """
    return WrapBackendDebug(unconfigured_compiler_fn, compiler_name)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# REPRO DUMPERS
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def generate_dynamo_fx_repro_string(
    gm,
    args,
    compiler_name,
    check_accuracy=False,
    *,
    stable_output=False,
    save_dir=None,
    command="run",
):
    """
    Generate a repro string for a backend-agnostic minified version.
    """

    model_str = NNModuleToString.convert(gm)

    # TODO: Figure out why the torch.compile'd hash isn't working on this codepath
    writer = InputWriter(save_dir, stable_hash=True)
    for placeholder, arg in zip(fx_placeholder_targets(gm), args):
        if isinstance(arg, (int, torch.SymInt)):
            writer.symint(placeholder, arg)
        elif isinstance(arg, torch.Tensor):
            # TODO: improve these names with FQN
            writer.tensor(placeholder, arg)
        else:
            raise TypeError(f"arg is neither SymInt/int nor torch.Tensor, {arg}")
    load_args = "\n".join(writer.lines())

    return textwrap.dedent(
        f"""
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd

{generate_config_string(stable_output=stable_output)}

{extra_imports}

{model_str}
mod = Repro()

{load_args}

if __name__ == '__main__':
    from torch._dynamo.repro.after_dynamo import run_repro
    run_repro(mod, load_args, accuracy={check_accuracy!r}, command={command!r},
        save_dir={save_dir!r}, autocast={torch.is_autocast_enabled()!r}, backend={compiler_name!r})
"""
    )
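
# For illustration, the tail of a file produced by the template above looks
# roughly like the following; the literal values are example assumptions, not
# fixed outputs:
#
#   mod = Repro()
#
#   # ... tensor/symint loading lines emitted by InputWriter ...
#
#   if __name__ == '__main__':
#       from torch._dynamo.repro.after_dynamo import run_repro
#       run_repro(mod, load_args, accuracy=False, command='minify',
#           save_dir='./checkpoints', autocast=False, backend='inductor')
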

def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False):
    """
    Saves the repro to a repro.py file
    """
    curdir = os.getcwd()
    subdir = os.path.join(os.getcwd(), "checkpoints")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    file_name = os.path.join(subdir, f"minified_{len(gm.graph.nodes)}_nodes.py")
    log.warning(
        "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name
    )

    with open(file_name, "w") as fd:
        fd.write(
            generate_dynamo_fx_repro_string(
                gm, args, compiler_name, check_accuracy, save_dir=subdir
            )
        )
    latest_repro = os.path.join(curdir, "repro.py")
    log.warning("Copying %s to %s for convenience", file_name, latest_repro)

    if use_buck:
        BuckTargetWriter(latest_repro).write()

    shutil.copyfile(file_name, latest_repro)


def dump_backend_state(gm, args, compiler_name, check_accuracy=False):
    """
    Dumps the dynamo graph to repro the issue.
    1) It tries to convert the FX GraphModule to a string. If we can, it
       writes to a repro.py file.
    2) If we can't convert the FX GraphModule to a string, we use to_folder
       to save the module and write a tar file.
    """
    assert NNModuleToString.can_convert_to_string(gm)
    return dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy)
    # return dump_backend_repro_as_tarfile(gm, args, compiler_name)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# MINIFIER DUMPER
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def dump_to_minify_after_dynamo(gm, args, compiler_name):
    # TODO: factor this out
    subdir = os.path.join(minifier_dir(), "checkpoints")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    helper_for_dump_minify(
        generate_dynamo_fx_repro_string(
            gm,
            args,
            compiler_name,
            check_accuracy=config.repro_level == 4,
            save_dir=subdir,
            command="minify",
        )
    )


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# MINIFIER BACKENDS
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


@register_debug_backend
def dynamo_minifier_backend(gm, example_inputs, compiler_name):
    from functorch.compile import minifier

    compiler_fn = lookup_backend(compiler_name)

    # TODO: It's inconsistent to pass SymInt inputs but REAL tensors.
    # We should pass ints and look at the GraphModule placeholders
    # to resolve them to SymInt (if necessary)
    example_inputs = [
        i.node.hint if isinstance(i, torch.SymInt) else i for i in example_inputs
    ]

    try:
        compiled_gm = compiler_fn(gm, example_inputs)
        run_fwd_maybe_bwd(compiled_gm, example_inputs)
        raise ValueError("No issue was detected")
    except Exception as exc:
        orig_failure = str(exc)
        log.warning(
            "Compiled Fx GraphModule failed. Creating script to minify the error."
        )
        dump_state_fn = functools.partial(
            dump_backend_state, compiler_name=compiler_name
        )
        dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs)
        fails_fn = functools.partial(
            backend_fails,
            compiler_fn=compiler_fn,
            orig_failure=orig_failure,
        )
        minifier(
            gm,
            example_inputs,
            module_fails=fails_fn,
            dump_state=dump_state_fn,
        )
    return gm
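
# A sketch of the minifier contract as consumed above: module_fails(gm, inputs)
# returns True when a candidate subgraph still reproduces the original failure,
# and dump_state(gm, inputs) writes a checkpoint repro for the smallest failing
# graph seen so far, so partial progress survives even if the minifier process
# itself dies midway.
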

@register_debug_backend
def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name):
    from functorch.compile import minifier

    compiler_fn = lookup_backend(compiler_name)

    # Set the eval mode to remove randomness.
    gm.eval()

    # Check accuracy
    if _accuracy_fails(gm, example_inputs, compiler_fn):
        log.warning("Accuracy failed for the TorchDynamo produced graph")
        dump_state_fn = functools.partial(
            dump_backend_state, compiler_name=compiler_name, check_accuracy=True
        )
        fails_fn = functools.partial(
            _accuracy_fails,
            compiler_fn=compiler_fn,
        )
        dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs)
        minifier(
            gm,
            example_inputs,
            module_fails=fails_fn,
            dump_state=dump_state_fn,
        )
    else:
        log.error("Input graph does not fail accuracy testing")
    return gm


def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
    """
    The minifier uses this function to identify whether the minified graph
    module fails with the same error.

    One caveat is that the minifier can potentially head in the wrong
    direction when the resulting graph module fails for a different reason.
    To avoid this, we save the string of the original exception and check the
    similarity between the new and old exceptions. They can be somewhat
    different in some cases, when the exception string depends on the failing
    node's information. So, we use a loose similarity metric to guide the
    minifier path.
    """
    from difflib import SequenceMatcher

    try:
        # Run the original gm to check eager validity
        run_fwd_maybe_bwd(gm, clone_inputs_retaining_gradness(example_inputs))
        compiled_gm = compiler_fn(gm, example_inputs)
        run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs))
    except Exception as e:
        new_failure = str(e)
        if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5:
            return True
    return False
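
# A worked example of the loose similarity check above, with hypothetical
# exception strings chosen for illustration:
#
#   from difflib import SequenceMatcher
#   orig = "RuntimeError: CUDA error: an illegal memory access was encountered"
#   new = orig + " in node add_3"
#   SequenceMatcher(None, orig, new).ratio()  # ~0.9 > 0.5 => same failure
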

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# REPRO MAIN
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def run_load_args(options, mod, load_args):
    if not hasattr(load_args, "_version"):
        log.warning(
            "load_args does not have a _version attribute; please file a bug with PyTorch "
            "and describe how you generated this repro script"
        )
    else:
        if load_args._version > 0:
            log.warning(
                "load_args is version %s, but this version of PyTorch only supports "
                "version 0. We will try to run it anyway, but there may be an incompatibility; "
                "if so, try upgrading your version of PyTorch.",
                load_args._version,
            )

    nop_reader = NopInputReader()
    load_args(nop_reader)

    with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar:
        input_reader = InputReader(save_dir=options.save_dir, pbar=pbar)
        load_args(input_reader)
        args = input_reader.args

    return args


def repro_minify(options, mod, load_args):
    args = run_load_args(options, mod, load_args)

    # Set up the debug minifier compiler
    if not options.accuracy:
        compiler_fn = lookup_backend("dynamo_minifier_backend")
    else:
        compiler_fn = lookup_backend("dynamo_accuracy_minifier_backend")

    if options.backend is None:
        raise RuntimeError(
            "Compiler name is None - this likely means that a custom compiler "
            "was called by torchdynamo. Please remove this error, import your "
            "custom compiler function, and replace the backend=None "
            "line in run_repro with backend=<my_imported_custom_function>"
        )

    dynamo_minifier_backend = functools.partial(
        compiler_fn,
        compiler_name=options.backend,
    )
    opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod)

    with torch.amp.autocast("cuda", enabled=options.autocast):
        opt_mod(*args)


def repro_run(options, mod, load_args):
    opt_mod = torch._dynamo.optimize(options.backend)(mod)

    if options.accuracy != "":
        mod.eval()
        opt_mod.eval()

        with torch.amp.autocast("cuda", enabled=options.autocast):
            # TODO: disable clone
            args = run_load_args(options, mod, load_args)
            # Sanity check: comparing eager against itself catches
            # nondeterminism in the model before blaming the backend.
            assert same_two_models(mod, mod, args), "Eager itself failed"
            if not same_two_models(
                mod,
                opt_mod,
                args,
                only_fwd=config.repro_forward_only,
                ignore_non_fp=config.repro_ignore_non_fp,
            ):
                raise AccuracyError("Dynamo failed")
    else:
        with torch.amp.autocast("cuda", enabled=options.autocast):
            args = run_load_args(options, mod, load_args)
            ref = run_fwd_maybe_bwd(
                mod, args, only_fwd=options.only_fwd, disable_clone=True
            )
            del args

            args = run_load_args(options, mod, load_args)
            res = run_fwd_maybe_bwd(
                opt_mod, args, only_fwd=options.only_fwd, disable_clone=True
            )
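
# Typical command lines against a generated repro.py (the flag spellings are
# defined by the argparse setup in run_repro below):
#
#   python repro.py                      # run the default command baked into the script
#   python repro.py run --only-fwd       # forward only; skip backward testing
#   python repro.py run --accuracy       # compare eager vs. compiled numerics
#   python repro.py minify               # bisect the graph down to a minimal repro
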

def run_repro(
    mod,
    load_args,
    *,
    command="run",
    accuracy: Union[bool, str] = "",
    save_dir=None,
    autocast=False,
    backend="inductor",
    **kwargs,
):
    for k in kwargs:
        log.warning(
            "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch",
            k,
        )

    if accuracy is True:
        accuracy = "accuracy"
    elif accuracy is False:
        accuracy = ""

    parser = argparse.ArgumentParser(
        description=f"""\
An after_dynamo repro script, typically triggering a bug in Dynamo or
AOTAutograd. When run with no arguments, this script defaults to running
'{command}'. Extra flags may be available; to find out more, try '{command}
--help'. There are also alternate subcommands available, see below.

default settings on this script:
  {accuracy=}
  {save_dir=}
""",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    def common_flags(parser):
        accuracy_group = parser.add_mutually_exclusive_group()
        accuracy_group.add_argument(
            "--no-accuracy",
            dest="accuracy",
            action="store_const",
            const="",
            default=accuracy,
            help="do not test accuracy, just run the module and see if it errors",
        )
        accuracy_group.add_argument(
            "--accuracy",
            action="store_const",
            const="accuracy",
            default=accuracy,
            help="test accuracy",
        )
        parser.add_argument(
            "--save-dir",
            type=str,
            default=save_dir,
            metavar="DIR",
            help="directory where saved inputs live",
        )
        parser.add_argument(
            "--no-save-dir",
            dest="save_dir",
            action="store_const",
            const=None,
            help="don't use any directory for saved inputs",
        )
        parser.add_argument(
            "--no-isolate",
            dest="isolate",
            action="store_false",
            default=False,
            help="no isolate (doesn't do anything for after_dynamo)",
        )
        parser.add_argument(
            "--autocast",
            default=autocast,
            action="store_true",
            help="use torch.cuda.amp.autocast",
        )
        parser.add_argument(
            "--no-autocast",
            dest="autocast",
            action="store_false",
            help="don't use torch.cuda.amp.autocast",
        )
        parser.add_argument(
            "--backend",
            type=str,
            default=backend,
            metavar="BACKEND",
            help="torch.compile backend to use",
        )

    subparsers = parser.add_subparsers(
        dest="command", metavar="{run,minify}", required=True
    )

    parser_run = subparsers.add_parser(
        "run",
        help="just run the repro",
    )
    common_flags(parser_run)
    parser_run.add_argument(
        "--only-fwd",
        action="store_true",
        help="don't run backwards compilation for testing",
    )

    parser_minify = subparsers.add_parser(
        "minify", help="run the minifier on the repro"
    )
    common_flags(parser_minify)

    args = None
    if len(sys.argv) <= 1:
        # With no explicit CLI arguments, fall back to the default command
        # baked into this repro script.
        args = [command, *sys.argv[1:]]

    options = parser.parse_args(args)
    COMMAND_FNS = {
        "minify": repro_minify,
        "run": repro_run,
    }
    COMMAND_FNS[options.command](options, mod, load_args)