# mypy: allow-untyped-defs
import argparse
import copy
import functools
import logging
import os
import shutil
import sys
import textwrap
from importlib import import_module
from typing import Union

import torch
import torch.fx as fx
from torch._dynamo.debug_utils import (
    AccuracyError,
    backend_accuracy_fails,
    BUCK_CMD_PREFIX,
    BuckTargetWriter,
    extra_imports,
    generate_config_string,
    helper_for_dump_minify,
    InputReader,
    InputWriter,
    minifier_dir,
    NNModuleToString,
    NopInputReader,
    run_fwd_maybe_bwd,
    same_two_models,
)
from torch.fx.experimental.symbolic_shapes import fx_placeholder_targets
from torch.hub import tqdm

from .. import config
from ..backends.registry import lookup_backend, register_debug_backend
from ..debug_utils import clone_inputs_retaining_gradness


log = logging.getLogger(__name__)


inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                           MAIN ENTRY POINT
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _accuracy_fails(gm, example_inputs, compiler_fn):
    return backend_accuracy_fails(
        gm,
        example_inputs,
        compiler_fn,
        only_fwd=config.repro_forward_only,
        ignore_non_fp=config.repro_ignore_non_fp,
    )


class WrapBackendDebug:
    def __init__(self, unconfigured_compiler_fn, compiler_name: str) -> None:
        functools.wraps(unconfigured_compiler_fn)(self)
        self._torchdynamo_orig_callable = unconfigured_compiler_fn  # type: ignore[attr-defined]
        self._compiler_name = compiler_name
        if hasattr(unconfigured_compiler_fn, "__name__"):
            self.__name__ = unconfigured_compiler_fn.__name__
        if hasattr(unconfigured_compiler_fn, "compiler_name"):
            self.__name__ = unconfigured_compiler_fn.compiler_name
        if hasattr(unconfigured_compiler_fn, "get_compiler_config"):
            self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config  # type: ignore[attr-defined]

    def __call__(self, gm, example_inputs, **kwargs):
        compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs)
        assert config.repro_after in ("dynamo", "aot", None)

        if config.repro_after == "dynamo":

            def add_paths(exc):
                exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py")
                if use_buck:
                    exc.buck_command = " ".join(
                        BUCK_CMD_PREFIX
                        + [BuckTargetWriter(exc.minifier_path).cmd_line_path]
                    )

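            # What the repro levels mean, as implemented by the branches
            # below: level 1 dumps the graph as a repro script on failure,
            # level 2 dumps a minifier launcher script on failure, level 3
            # always dumps a minifier launcher script, and level 4 checks
            # accuracy rather than crashes.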
            if config.repro_level == 3:
                dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name)

            # Check for either accuracy (level 4) or other types of failures.
            if config.repro_level == 4:
                # Check Accuracy
                compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
                if _accuracy_fails(gm, example_inputs, compiler_fn):
                    log.warning(
                        "Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error."
                    )
                    dump_to_minify_after_dynamo(
                        fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                        example_inputs,
                        self._compiler_name,
                    )
                    exc = AccuracyError("Bad accuracy detected.")
                    add_paths(exc)
                    raise exc
            else:
                try:
                    compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
                    run_fwd_maybe_bwd(compiled_gm, example_inputs)
                except Exception as exc:
                    log.warning(
                        "Compiled Fx GraphModule failed. Creating script to minify the error."
                    )
                    if config.repro_level == 1:
                        dump_state_fn = functools.partial(
                            dump_backend_state, compiler_name=self._compiler_name
                        )
                        dump_state_fn(
                            fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs
                        )
                    elif config.repro_level == 2:
                        dump_to_minify_after_dynamo(
                            fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                            example_inputs,
                            self._compiler_name,
                        )
                    add_paths(exc)
                    raise
        else:
            compiled_gm = compiler_fn(gm, example_inputs)

        return compiled_gm


def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
    """
    A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
    As opposed to wrap_compiler_debug, this wrapper intercepts at the
    TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to
    some degree, e.g., it is useful for minifying issues related to AOT
    Autograd tracing. If an error is found, we minify and save the minified
    repro as repro.py.
    """
    return WrapBackendDebug(unconfigured_compiler_fn, compiler_name)
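
# A minimal usage sketch (illustrative only; `my_backend` is a hypothetical
# user-defined compiler function, not something defined in this module):
#
#   def my_backend(gm, example_inputs):
#       return gm.forward
#
#   debug_backend = wrap_backend_debug(my_backend, "my_backend")
#   opt_mod = torch._dynamo.optimize(debug_backend)(mod)
#
# With config.repro_after == "dynamo", a failure inside the backend (or an
# accuracy failure at repro_level 4) triggers the repro dumpers below.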


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                           REPRO DUMPERS
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def generate_dynamo_fx_repro_string(
    gm,
    args,
    compiler_name,
    check_accuracy=False,
    *,
    stable_output=False,
    save_dir=None,
    command="run",
):
    """
    Generate a repro string for a backend-agnostic minified version.
    """

    model_str = NNModuleToString.convert(gm)

    # TODO: Figure out why torch.compile'd hash isn't working on this codepath
    writer = InputWriter(save_dir, stable_hash=True)
    for placeholder, arg in zip(fx_placeholder_targets(gm), args):
        if isinstance(arg, (int, torch.SymInt)):
            writer.symint(placeholder, arg)
        elif isinstance(arg, torch.Tensor):
            # TODO: improve these names with FQN
            writer.tensor(placeholder, arg)
        else:
            raise TypeError(f"arg is neither SymInt/int nor torch.Tensor: {arg}")
    load_args = "\n".join(writer.lines())

    return textwrap.dedent(
        f"""
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd

{generate_config_string(stable_output=stable_output)}

{extra_imports}

{model_str}
mod = Repro()

{load_args}

if __name__ == '__main__':
    from torch._dynamo.repro.after_dynamo import run_repro
    run_repro(mod, load_args, accuracy={check_accuracy!r}, command={command!r},
        save_dir={save_dir!r}, autocast={torch.is_autocast_enabled()!r}, backend={compiler_name!r})
"""
    )


def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False):
    """
    Saves the repro to a repro.py file
    """
    curdir = os.getcwd()
    subdir = os.path.join(os.getcwd(), "checkpoints")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    file_name = os.path.join(subdir, f"minified_{len(gm.graph.nodes)}_nodes.py")
    log.warning(
        "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name
    )

    with open(file_name, "w") as fd:
        fd.write(
            generate_dynamo_fx_repro_string(
                gm, args, compiler_name, check_accuracy, save_dir=subdir
            )
        )
    latest_repro = os.path.join(curdir, "repro.py")
    log.warning("Copying %s to %s for convenience", file_name, latest_repro)

    if use_buck:
        BuckTargetWriter(latest_repro).write()

    shutil.copyfile(file_name, latest_repro)


def dump_backend_state(gm, args, compiler_name, check_accuracy=False):
    """
    Dumps the dynamo graph to repro the issue.
    1) It tries to convert the Fx GraphModule to a string. If it can, it
    writes the repro to a repro.py file.
    2) If it can't convert the Fx GraphModule to a string, it uses to_folder
    to save the module and saves a tar file.
    """
    assert NNModuleToString.can_convert_to_string(gm)
    return dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy)
    # return dump_backend_repro_as_tarfile(gm, args, compiler_name)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                       MINIFIER DUMPER
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def dump_to_minify_after_dynamo(gm, args, compiler_name):
    # TODO: factor this out
    subdir = os.path.join(minifier_dir(), "checkpoints")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    helper_for_dump_minify(
        generate_dynamo_fx_repro_string(
            gm,
            args,
            compiler_name,
            check_accuracy=config.repro_level == 4,
            save_dir=subdir,
            command="minify",
        )
    )

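# The script written here is the minifier_launcher.py referenced by add_paths
# above; running it invokes run_repro(..., command="minify"), which routes to
# repro_minify below.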

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                       MINIFIER BACKENDS
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


@register_debug_backend
def dynamo_minifier_backend(gm, example_inputs, compiler_name):
    from functorch.compile import minifier

    compiler_fn = lookup_backend(compiler_name)

    # TODO: It's inconsistent to pass SymInt inputs but REAL tensors.
    # We should pass ints and look at the GraphModule placeholders
    # to resolve them to SymInt (if necessary)
    example_inputs = [
        i.node.hint if isinstance(i, torch.SymInt) else i for i in example_inputs
    ]

    try:
        compiled_gm = compiler_fn(gm, example_inputs)
        run_fwd_maybe_bwd(compiled_gm, example_inputs)
        raise ValueError("No issue was detected")
    except Exception as exc:
        orig_failure = str(exc)
        log.warning(
            "Compiled Fx GraphModule failed. Creating script to minify the error."
        )
        dump_state_fn = functools.partial(
            dump_backend_state, compiler_name=compiler_name
        )
        dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs)
        fails_fn = functools.partial(
            backend_fails,
            compiler_fn=compiler_fn,
            orig_failure=orig_failure,
        )
        minifier(
            gm,
            example_inputs,
            module_fails=fails_fn,
            dump_state=dump_state_fn,
        )
    return gm

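# This backend is normally reached through repro_minify below, roughly:
#
#   backend = functools.partial(
#       lookup_backend("dynamo_minifier_backend"), compiler_name="inductor"
#   )
#   torch._dynamo.optimize(backend)(mod)(*args)
#
# ("inductor" is just an example compiler name here.)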

@register_debug_backend
def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name):
    from functorch.compile import minifier

    compiler_fn = lookup_backend(compiler_name)

    # Set the eval mode to remove randomness.
    gm.eval()

    # Check Accuracy
    if _accuracy_fails(gm, example_inputs, compiler_fn):
        log.warning("Accuracy failed for the TorchDynamo produced graph")
        dump_state_fn = functools.partial(
            dump_backend_state, compiler_name=compiler_name, check_accuracy=True
        )
        fails_fn = functools.partial(
            _accuracy_fails,
            compiler_fn=compiler_fn,
        )
        dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs)
        minifier(
            gm,
            example_inputs,
            module_fails=fails_fn,
            dump_state=dump_state_fn,
        )
    else:
        log.error("Input graph does not fail accuracy testing")
    return gm


def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
    """
    The minifier uses this function to identify whether a minified graph
    module fails with the same error as the original.

    One caveat is that the minifier can head in the wrong direction if the
    reduced graph module fails for a different reason. To avoid this, we save
    the string of the original exception and check the similarity between the
    new and old exceptions. They can differ somewhat when the exception string
    depends on information about the failing node, so we use a loose
    similarity metric to guide the minifier.
    """
    from difflib import SequenceMatcher

    try:
        # Run the original gm to check eager validity
        run_fwd_maybe_bwd(gm, clone_inputs_retaining_gradness(example_inputs))
        compiled_gm = compiler_fn(gm, example_inputs)
        run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs))
    except Exception as e:
        new_failure = str(e)
        if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5:
            return True
    return False

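# For intuition (illustrative strings, not from a real run): two messages that
# differ only in the failing node name stay well above the 0.5 cutoff, e.g.
#
#   from difflib import SequenceMatcher
#   a = "RuntimeError: CUDA error in node add_3"
#   b = "RuntimeError: CUDA error in node mul_7"
#   SequenceMatcher(None, a, b).ratio()  # ~0.9, treated as the same failure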

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                           REPRO MAIN
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def run_load_args(options, mod, load_args):
    if not hasattr(load_args, "_version"):
        log.warning(
            "load_args does not have a _version attribute; please file a bug with PyTorch "
            "and describe how you generated this repro script"
        )
    else:
        if load_args._version > 0:
            log.warning(
                "load_args is version %s, but this version of PyTorch only supports "
                "version 0.  We will try to run it anyway but there may be an incompatibility; "
                "if so, try upgrading your version of PyTorch.",
                load_args._version,
            )

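    # load_args is invoked twice: first with a NopInputReader purely to count
    # the stored inputs (so the progress bar has a total), then with a real
    # InputReader that actually materializes the args from save_dir.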
    nop_reader = NopInputReader()
    load_args(nop_reader)

    with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar:
        input_reader = InputReader(save_dir=options.save_dir, pbar=pbar)
        load_args(input_reader)
        args = input_reader.args

    return args


def repro_minify(options, mod, load_args):
    args = run_load_args(options, mod, load_args)

    # Setup debug minifier compiler
    if not options.accuracy:
        compiler_fn = lookup_backend("dynamo_minifier_backend")
    else:
        compiler_fn = lookup_backend("dynamo_accuracy_minifier_backend")

    if options.backend is None:
        raise RuntimeError(
            "Compiler name is None - this likely means that a custom compiler "
            "was called by torchdynamo. Please remove this error, import your "
            "custom compiler function, and replace the backend=None "
            "line in run_repro with backend=<my_imported_custom_function>"
        )

    dynamo_minifier_backend = functools.partial(
        compiler_fn,
        compiler_name=options.backend,
    )
    opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod)

    with torch.amp.autocast("cuda", enabled=options.autocast):
        opt_mod(*args)


def repro_run(options, mod, load_args):
    opt_mod = torch._dynamo.optimize(options.backend)(mod)

    if options.accuracy != "":
        mod.eval()
        opt_mod.eval()

        with torch.amp.autocast("cuda", enabled=options.autocast):
            # TODO: disable clone
            args = run_load_args(options, mod, load_args)
            assert same_two_models(mod, mod, args), "Eager itself failed"
            if not same_two_models(
                mod,
                opt_mod,
                args,
                only_fwd=config.repro_forward_only,
                ignore_non_fp=config.repro_ignore_non_fp,
            ):
                raise AccuracyError("Dynamo failed")
    else:
        with torch.amp.autocast("cuda", enabled=options.autocast):
            args = run_load_args(options, mod, load_args)
            ref = run_fwd_maybe_bwd(
                mod, args, only_fwd=options.only_fwd, disable_clone=True
            )
            del args

            args = run_load_args(options, mod, load_args)
            res = run_fwd_maybe_bwd(
                opt_mod, args, only_fwd=options.only_fwd, disable_clone=True
            )
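            # Note: ref and res are not compared in this branch; without
            # accuracy checking, this only verifies that the eager and
            # compiled runs complete without raising.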


def run_repro(
    mod,
    load_args,
    *,
    command="run",
    accuracy: Union[bool, str] = "",
    save_dir=None,
    autocast=False,
    backend="inductor",
    **kwargs,
):
    for k in kwargs:
        log.warning(
            "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch",
            k,
        )

    if accuracy is True:
        accuracy = "accuracy"
    elif accuracy is False:
        accuracy = ""

    parser = argparse.ArgumentParser(
        description=f"""\
An after_dynamo repro script, typically triggering a bug in Dynamo or
AOTAutograd.  When run with no arguments, this script defaults to running
'{command}'.  Extra flags may be available; to find out more, try '{command}
--help'.  There are also alternate subcommands available, see below.

default settings on this script:
  {accuracy=}
  {save_dir=}
""",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    def common_flags(parser):
        accuracy_group = parser.add_mutually_exclusive_group()
        accuracy_group.add_argument(
            "--no-accuracy",
            dest="accuracy",
            action="store_const",
            const="",
            default=accuracy,
            help="do not test accuracy, just run the module and see if it errors",
        )
        accuracy_group.add_argument(
            "--accuracy",
            action="store_const",
            const="accuracy",
            default=accuracy,
            help="test accuracy",
        )
        parser.add_argument(
            "--save-dir",
            type=str,
            default=save_dir,
            metavar="DIR",
            help="directory where saved inputs live",
        )
        parser.add_argument(
            "--no-save-dir",
            dest="save_dir",
            action="store_const",
            const=None,
            help="don't use any directory for saved inputs",
        )
        parser.add_argument(
            "--no-isolate",
            dest="isolate",
            action="store_false",
            default=False,
            help="no isolate (doesn't do anything for after_dynamo)",
        )
        parser.add_argument(
            "--autocast",
            default=autocast,
            action="store_true",
            help="use torch.amp.autocast('cuda')",
        )
        parser.add_argument(
            "--no-autocast",
            dest="autocast",
            action="store_false",
            help="don't use torch.amp.autocast('cuda')",
        )
        parser.add_argument(
            "--backend",
            type=str,
            default=backend,
            metavar="BACKEND",
            help="torch.compile backend to use",
        )

    subparsers = parser.add_subparsers(
        dest="command", metavar="{run,minify}", required=True
    )

    parser_run = subparsers.add_parser(
        "run",
        help="just run the repro",
    )
    common_flags(parser_run)
    parser_run.add_argument(
        "--only-fwd",
        action="store_true",
        help="don't run backwards compilation for testing",
    )

    parser_minify = subparsers.add_parser(
        "minify", help="run the minifier on the repro"
    )
    common_flags(parser_minify)

    args = None
    if len(sys.argv) <= 1:
        args = [command, *sys.argv[1:]]

    options = parser.parse_args(args)
    COMMAND_FNS = {
        "minify": repro_minify,
        "run": repro_run,
    }
    COMMAND_FNS[options.command](options, mod, load_args)
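
# Typical invocations of a generated repro script (the flags come from the
# argparse setup above):
#
#   python repro.py                  # run the default command baked into the script
#   python repro.py run --accuracy   # re-run with accuracy checking
#   python repro.py minify           # hand the graph to the minifier backends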