# mypy: allow-untyped-defs
# mypy: disable-error-code="method-assign"
import atexit
import copy
import cProfile
import functools
import getpass
import inspect
import itertools
import logging
import os
import re
import subprocess
import sys
import tempfile
import textwrap
from collections import Counter
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional, TypeVar

import torch
import torch._prims_common as utils
import torch._subclasses.meta_utils
from torch import Tensor
from torch._dynamo.testing import rand_strided
from torch._prims_common import is_float_dtype
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter

from . import config
from .utils import clone_inputs, get_debug_dir


log = logging.getLogger(__name__)

T = TypeVar("T")


inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()

if use_buck:
    import libfb.py.build_info


extra_deps = []
extra_imports = ""
if use_buck:
    extra_deps = [
        "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
    ]
    cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//")  # type: ignore[possibly-undefined]
    extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])


BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"]


class BuckTargetWriter:
    def __init__(self, filename):
        self.subdir, self.py_file = os.path.split(os.path.abspath(filename))
        self.target = self.py_file.replace(".py", "")

        # Get main_module path from fbcode
        self.path = f'{self.subdir.replace("/", ".")}.{self.target}'
        self.path = self.path[self.path.find("fbcode.") :]
        self.path = self.path[7:]

        # Get cmd line path
        tmp = self.subdir
        tmp = tmp[tmp.find("fbcode/") :][7:]
        self.cmd_line_path = f"//{tmp}:{self.target}"

    def build(self):
        extra_cpp_deps = "\n".join([f'        "{x}",' for x in extra_deps])
        return textwrap.dedent(
            f"""
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")

python_binary(
    name="{self.target}",
    srcs = ["{self.py_file}"],
    compile = False,
    deps = [
        "//caffe2:torch",
        "//caffe2/functorch:functorch",
        "//triton:triton",
        "{cur_target}",
    ],
    cpp_deps = [
{extra_cpp_deps}
    ],
    main_module = "{self.path}",
    par_style = "xar",
)
"""
        )

    def write(self, print_msg=True):
        target_file = os.path.join(self.subdir, "TARGETS")
        with open(target_file, "w") as fd:
            fd.write(self.build())
        # log.warning("Wrote isolation TARGETS file at %s", target_file)
        cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path]
        if print_msg:
            log.warning(
                "Found an example that reproduces the error. Run this cmd to repro - %s",
                " ".join(cmd_split),
            )
        return cmd_split

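# Example (a minimal sketch; the path below is hypothetical and this flow only
# applies inside fbcode, where use_buck is True):
#
#     writer = BuckTargetWriter("/data/users/me/fbcode/scripts/repro.py")
#     cmd = writer.write()  # writes scripts/TARGETS, returns the buck2 run command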


def minifier_dir():
    # get_debug_dir() may not give us a usable directory (e.g. returns None
    # when no debug dir is configured); fall back to a per-user temp dir.
    debug_dir = get_debug_dir()
    if debug_dir is None:
        path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}"
    else:
        path = os.path.join(debug_dir, "minifier")
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    return path


MAX_CONSTANT_NUMEL_INLINE = 4


class NNModuleToString:
    safe_reprs = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.LayerNorm,
        torch.nn.Dropout,
        torch.nn.Softmax,
        torch.nn.ReLU,
        torch.nn.GELU,
        torch.nn.Identity,
        torch.nn.MaxPool2d,
        torch.nn.Embedding,
        torch.nn.Tanh,
        torch.nn.ConvTranspose1d,
        torch.nn.GLU,
        torch.nn.LSTM,
        torch.nn.Flatten,
        torch.nn.AdaptiveAvgPool2d,
    ]

    @staticmethod
    def can_convert_to_string(gm):
        cant_convert = set()
        for _, module in gm.named_children():
            if type(module) not in NNModuleToString.safe_reprs:
                cant_convert.add(module)

        if len(cant_convert) > 0:
            log.warning("We have not tested reprs of some modules - %s", cant_convert)
        # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
        return True

    @staticmethod
    def convert(gm):
        from torch.nn.modules.module import _addindent

        tab = " " * 4

        model_str = textwrap.dedent(
            """
            from torch.nn import *
            class Repro(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
            """
        )

        for module_name, module in gm.named_children():
            module_str = f"{module.__repr__()}"
            # module should be a core torch.nn.Module, so all parameters
            # should be on the same device.
            example_param = next(module.parameters(), None)
            if example_param is not None and example_param.is_cuda:
                module_str = f"{module_str}.cuda()"
            model_str += f"{tab*2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in gm._buffers.items():
            if buffer is None:
                continue
            # Serialize full data for small buffers
            if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE:
                from torch._tensor_str import PRINT_OPTS

                assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE
                tensor_str = repr(buffer)
            elif torch.is_floating_point(buffer):
                tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
            else:
                tensor_str = (
                    f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
                )
            if buffer.is_cuda:
                tensor_str = f"{tensor_str}.cuda()"
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n"

        for param_name, param in gm._parameters.items():
            if param is None:
                continue
            maybe_device = ""
            if param.is_cuda:
                maybe_device = ', device="cuda"'
            tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))"
            model_str += f"{tab*2}self.{param_name} = {tensor_str}\n"

        # TODO - Keep this code for now. But, I don't think we will need this.
        # attrs = dir(gm)
        # for attr in attrs:
        #     if "_tensor_constant" in attr:
        #         val = getattr(gm, attr)
        #         model_str += f"    {attr} = {val!r}\n"

        model_str += f"{_addindent(gm.code, 4)}\n"
        return model_str

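# Example (a minimal sketch; the traced module below is illustrative):
#
#     gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(2, 2)))
#     if NNModuleToString.can_convert_to_string(gm):
#         src = NNModuleToString.convert(gm)  # Python source defining `class Repro`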


@functools.lru_cache(None)  # subprocess is expensive
def _cuda_system_info_comment():
    if not torch.cuda.is_available():
        return "# torch.cuda.is_available()==False, no GPU info collected\n"

    model_str = "# CUDA Info: \n"
    try:
        cuda_version_out = subprocess.check_output(["nvcc", "--version"])
        cuda_version_lines = cuda_version_out.decode().split("\n")
        comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
        model_str += f"{comment}\n"
    except (FileNotFoundError, subprocess.CalledProcessError):
        model_str += "# nvcc not found\n"

    gpu_names = Counter(
        torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
    )

    model_str += "# GPU Hardware Info: \n"
    for name, count in gpu_names.items():
        model_str += f"# {name} : {count} \n"
    model_str += "\n"
    return model_str


def generate_config_string(*, stable_output=False):
    import torch._functorch.config
    import torch._inductor.config

    if stable_output:
        return "# config omitted due to stable_output=True"

    experimental_config = torch.fx.experimental._config.codegen_config()  # type: ignore[attr-defined]
    return f"""\
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
{torch._functorch.config.codegen_config()}
{experimental_config}
"""


def get_minifier_repro_path():
    return os.path.join(minifier_dir(), "minifier_launcher.py")


def helper_for_dump_minify(contents):
    minified_repro_path = get_minifier_repro_path()
    log.warning("Writing minified repro to:\n%s", minified_repro_path)

    if use_buck:
        BuckTargetWriter(minified_repro_path).write()
    try:
        with open(minified_repro_path, "w") as fd:
            fd.write(contents)

    except OSError as e:
        log.exception("")
        raise NotImplementedError(f"Could not write to {minified_repro_path}") from e


class AccuracyError(Exception):
    pass


def clone_inputs_retaining_gradness(example_inputs):
    """
    Unlike utils.clone_inputs, this helper is tailored to the minifier: when a
    new graph is created, all tensors are leaf tensors, so we set the
    requires_grad field directly without checking whether the tensor is a leaf.
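
    Example (a minimal sketch; values are illustrative)::

        x = torch.randn(3, requires_grad=True)
        (y,) = clone_inputs_retaining_gradness([x])
        assert y.requires_grad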
    """
    cloned_inputs = clone_inputs(example_inputs)
    for idx in range(len(example_inputs)):
        if isinstance(cloned_inputs[idx], torch.Tensor):
            cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad)
    return cloned_inputs


def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False):
    """
    Runs a forward and possibly backward iteration for a given mod and args.

    When disable_clone is True, we will use args as-is without cloning.
    This is higher fidelity but we may destroy the args in the process.
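
    Example (a minimal sketch; the traced module is illustrative)::

        gm = torch.fx.symbolic_trace(torch.nn.Linear(2, 2))
        out = run_fwd_maybe_bwd(gm, [torch.randn(4, 2)], only_fwd=True)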
    """
    from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass

    gm = copy.deepcopy(gm)
    if not disable_clone:
        args = clone_inputs_retaining_gradness(args)

    if hasattr(gm, "zero_grad"):
        gm.zero_grad(True)

    # The callable returned by TorchInductor expects a list of inputs (the
    # boxed calling convention), so we may need to pass args as a single list.
    out = gm(args) if hasattr(gm, "_boxed_call") else gm(*args)

    if only_fwd:
        return out
    if requires_bwd_pass(out):
        loss = reduce_to_scalar_loss(out)
        loss.backward()
    return collect_results(gm, out, None, args)


def same_two_models(
    gm,
    opt_gm,
    example_inputs,
    only_fwd=False,
    *,
    require_fp64=False,
    ignore_non_fp=False,
):
    """
    Check that two models have the same accuracy.

    require_fp64: if True, raise an error if we are unable to compute the fp64 reference
    ignore_non_fp: if True, do not compare outputs which are not floating point.  This
        is mostly useful for the minifier (which wants to avoid quantizing floating point
        error into integer/boolean error)
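
    Example (a minimal sketch; the model and backend are illustrative)::

        gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
        opt_gm = torch.compile(gm)
        assert same_two_models(gm, opt_gm, [torch.randn(2, 4)], only_fwd=True)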
    """
    from .utils import same

    ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

    fp64_ref = None
    if config.same_two_models_use_fp64:
        try:
            fp64_model, fp64_examples = cast_to_fp64(
                copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
            )
            fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
        except Exception:
            if require_fp64:
                raise RuntimeError(  # noqa: B904
                    "Could not generate fp64 outputs, workaround with torch._dynamo.config.same_two_models_use_fp64 = False"
                )
            log.warning("Could not generate fp64 outputs")

    try:
        res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
    except Exception:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return True.
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue."
            " Skipping this graph."
        )
        return True

    passing = same(
        ref,
        res,
        fp64_ref,
        tol=config.repro_tolerance,
        equal_nan=True,
        ignore_non_fp=ignore_non_fp,
    )
    return passing


def cast_dtype_args_to_fp64(model):
    for node in model.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.prims.convert_element_type.default
        ):
            assert len(node.args) == 2
            if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
                node.args = (node.args[0], torch.float64)
        if node.op == "call_function":
            dtype = node.kwargs.get("dtype")
            if dtype is not None and is_float_dtype(dtype):
                new_kwargs = dict(node.kwargs)
                new_kwargs["dtype"] = torch.float64
                node.kwargs = new_kwargs

    model.graph.lint()
    model.recompile()
    return model

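# For example, a graph node such as
#     prims.convert_element_type(x, torch.float32)
# or a call_function node carrying kwargs {"dtype": torch.float16} is rewritten
# to use torch.float64, so the fp64 reference run does not silently downcast.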


def cast_to(dtype, model, inputs):
    from torch.utils._pytree import tree_map

    model = model.to(dtype)
    if dtype == torch.float64:
        # If casting to fp64 for accuracy comparison, we need to
        # replace dtype arguments embedded in the graph with fp64
        model = cast_dtype_args_to_fp64(model)

    inputs = tree_map(
        lambda x: x.to(dtype)
        if isinstance(x, torch.Tensor) and x.is_floating_point()
        else x,
        inputs,
    )
    return model, inputs


def cast_to_fp64(model, inputs):
    return cast_to(torch.float64, model, inputs)


def backend_accuracy_fails(
    gm,
    example_inputs,
    compiler_fn,
    only_fwd=False,
    *,
    require_fp64=False,
    ignore_non_fp=False,
):
    try:
        compiled_gm = compiler_fn(
            copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
        )
        return not same_two_models(
            gm,
            compiled_gm,
            example_inputs,
            only_fwd,
            require_fp64=require_fp64,
            ignore_non_fp=ignore_non_fp,
        )
    except Exception:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return False.
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue."
            " Skipping this graph."
        )
        return False

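# Example (a minimal sketch; `my_backend` is a stand-in compiler function):
#
#     my_backend = lambda gm, example_inputs: torch.compile(gm)
#     fails = backend_accuracy_fails(gm, example_inputs, my_backend)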

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                       REPRO SUPPORT CODE
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# Helper functions for computing the default values of tensor attributes.
# These all coincide with the defaults of factory functions, e.g., torch.empty


def _stride_or_default(
    stride: Optional["torch._prims_common.StrideType"],
    *,
    shape: "torch._prims_common.ShapeType",
) -> "torch._prims_common.StrideType":
    return stride if stride is not None else utils.make_contiguous_strides_for(shape)


def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]:
    return lambda x: x if x is not None else d


_dtype_or_default = _mk_defaulter(torch.float32)
_device_or_default = _mk_defaulter(torch.device("cpu"))
_storage_offset_or_default = _mk_defaulter(0)
_requires_grad_or_default = _mk_defaulter(False)
_is_leaf_or_default = _mk_defaulter(False)
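# For example, _dtype_or_default(None) returns torch.float32, while
# _dtype_or_default(torch.int8) returns torch.int8; the other defaulters
# behave the same way for device, storage offset, requires_grad, and is_leaf.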


class NopInputReader:
    """Implements the InputReader interface but only counts storage loads
    (self.total), e.g. so a progress bar can be sized before running the
    real InputReader."""

    def __init__(self) -> None:
        self.total = 0

    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
        self.total += 1

    def tensor(self, *args, **kwargs):
        pass

    def symint(self, *args, **kwargs):
        pass


# TODO: Support bundling the entire repro into a zip file for ease of
# transferring around
class InputReader:
    def __init__(self, save_dir=None, *, pbar=None):
        # If None, we will generate random data instead.  It's important
        # to natively support this use case as it will allow people to
        # share repros without including the real data, if the problem
        # reproduces even on random data.
        if save_dir is None:
            log.warning("no save_dir specified, will generate random data")
        self.store = ContentStoreReader(save_dir) if save_dir is not None else None
        self.args = []
        self.pbar = pbar

    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
        if self.pbar is not None:
            self.pbar.update(1)
        device = _device_or_default(device)
        dtype_hint = _dtype_or_default(dtype_hint)
        if self.store is not None and storage_hash is not None:
            try:
                storage = self.store.read_storage(storage_hash)
            except FileNotFoundError:
                pass
            else:
                if device != storage.device:
                    log.warning("device mismatch: %s != %s", device, storage.device)
                    # TODO: transfer it to the right device?  But failing this
                    # way would be very mysterious!  Would have been better
                    # not to store device in the serialized format...
                return storage
        log.warning("could not load %s, generating random data instead", storage_hash)
        shape = (nbytes // dtype_hint.itemsize,)
        stride = _stride_or_default(None, shape=shape)
        return rand_strided(shape, stride, dtype_hint, device).untyped_storage()

    def tensor(
        self,
        storage,
        shape,
        stride=None,
        *,
        storage_offset=None,
        dtype=None,
        requires_grad=None,
        is_leaf=None,
        **metadata,
    ):
        stride = _stride_or_default(stride, shape=shape)
        storage_offset = _storage_offset_or_default(storage_offset)
        dtype = _dtype_or_default(dtype)
        is_leaf = _is_leaf_or_default(is_leaf)
        requires_grad = _requires_grad_or_default(requires_grad)
        t = torch.tensor(
            [], dtype=dtype, device=storage.device, requires_grad=requires_grad
        )
        with torch.no_grad():
            t.set_(storage, storage_offset, shape, stride)
        if not is_leaf:
            # Fake up some autograd history in a very naughty way
            with torch.enable_grad():
                t = t.clone(memory_format=torch.preserve_format)
            with torch.no_grad():
                t.set_(storage, storage_offset, shape, stride)
        assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf
        torch._utils.set_tensor_metadata(t, metadata)
        self.args.append(t)
        return t  # for BC

    def symint(self, val):
        self.args.append(val)
        return val  # for BC

# Here is our writer strategy:
#  1. We will stream all of the inputs to disk
#  2. You can now deterministically randomize the inputs, or reload
#     the inputs from disk
#  3. You can YOLO run the script without the inputs, in which case
#     we'll fill the inputs with random data and pray.  This is the
#     legacy behavior, but it's also useful if you want to find out
#     if we're so broken even random inputs trigger it
#  4. We could offer an in process "check if the randomized thing
#     works too" but this is delicate so we don't do it

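# The emitted snippet looks roughly like this (hash, sizes, and names are
# illustrative):
#
#     def load_args(reader):
#         buf0 = reader.storage('e2b39c...', 16)
#         reader.tensor(buf0, (2, 2))  # arg0_1
#     load_args._version = 0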


class InputWriter:
    def __init__(self, save_dir, *, stable_hash=False):
        self._lines = []
        # TODO: consider ensuring tensor and storage counters line up?
        self.storage_counter = itertools.count()
        self.save_dir = save_dir
        self.store = (
            ContentStoreWriter(save_dir, stable_hash=stable_hash)
            if save_dir is not None
            else None
        )
        self.seen_storages = {}

    def lines(self):
        r = [
            "def load_args(reader):",
        ]
        r.extend(f"    {line}" for line in self._lines)
        # In case we need to change the internal format of load_args
        # in an FC-breaking way
        r.append("load_args._version = 0")
        return r

    # Storages are untyped, but we need to initialize them with data if
    # we don't have the real data, so we give a hint saying what kind
    # of initialization may be appropriate
    #
    # If we had a FakeTensor, device_hint tells us what the device should be
    def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str:
        ws = StorageWeakRef(untyped_storage)
        v = self.seen_storages.get(ws)
        if v is not None:
            return v
        v = f"buf{next(self.storage_counter)}"
        maybe_dtype_hint = ""
        if _dtype_or_default(None) != _dtype_or_default(dtype_hint):
            maybe_dtype_hint = f", dtype_hint={dtype_hint!r}"
        # TODO: being optional on device is kind of pointless as the default
        # is CPU but most repros we care about are CUDA
        maybe_device = ""
        device = untyped_storage.device
        if device.type == "meta":
            assert device_hint is not None
            device = device_hint
        if _device_or_default(None) != device:
            maybe_device = f", device={device!r}"
        nbytes = untyped_storage.nbytes()
        storage_hash = None
        if self.store is not None and untyped_storage.device.type != "meta":
            storage_hash = self.store.write_storage(untyped_storage)
        self._lines.append(
            f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})"
        )
        self.seen_storages[ws] = v
        return v

    def tensor(self, name, t) -> None:
        from torch.fx.experimental.symbolic_shapes import statically_known_true

        storage = self.storage(
            t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
        )
        args = []
        # NB: this is positional, must come first
        if _stride_or_default(None, shape=t.shape) != t.stride():
            args.append(str(tuple(t.stride())))
        if _dtype_or_default(None) != t.dtype:
            args.append(f"dtype={t.dtype!r}")
        if not statically_known_true(
            _storage_offset_or_default(None) == t.storage_offset()
        ):
            args.append(f"storage_offset={t.storage_offset()!r}")
        tensor_metadata = torch._utils.get_tensor_metadata(t)
        if tensor_metadata:
            args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items())
        if _requires_grad_or_default(None) != t.requires_grad:
            args.append(f"requires_grad={t.requires_grad!r}")
        is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t)
        if _is_leaf_or_default(None) != is_leaf:
            args.append(f"is_leaf={is_leaf!r}")
        self._lines.append(
            "reader.tensor("
            + ", ".join([storage, str(tuple(t.shape)), *args])
            + f")  # {name}"
        )

    # TODO: this doesn't actually symint atm
    def symint(self, name, val) -> None:
        if isinstance(val, torch.SymInt):
            val = val.node.hint
        self._lines.append(f"reader.symint({val!r})  # {name}")


def aot_graph_input_parser(
    func: Callable[[List[Tensor]], List[Tensor]],
    device: str = "cuda",
    sym_shapes: Optional[Dict[str, int]] = None,
    default_sym_shape: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Takes in a function which has been printed with print_readable() and constructs kwargs to run it.

    Handles Tensor inputs, SymInts, and a graph module which might have tensor constants.

    Consider a function `forward` defined as follows:

    def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "Sym(s0)",):
        _tensor_constant0: "i64[4190]" = self._tensor_constant0
        # Further implementation

    kwargs = aot_graph_input_parser(forward)
    forward(**kwargs)
    """

    from torch.fx.graph import dtype_abbrs

    dtype_map = {value: key for key, value in dtype_abbrs.items()}
    dtype_pattern = "|".join(dtype_abbrs.values())

    # Extracting the source code from the function
    source = inspect.getsource(func)

    # Regular expressions
    tensor_assignment_regex = rf"(_tensor_constant\d+): \"({dtype_pattern})\[\s*(.*?)\s*\]\" = self\.(_tensor_constant\d+)"
    tensor_regex = rf"({dtype_pattern})\[\s*(.*?)\s*\]"
    sym_shape_regex = r"Sym\((s\d+)\)"

    class TensorContainer:
        "Container for tensors as attributes"

    # Dictionary for tensors from annotations
    kwargs: Dict[str, Any] = {}

    sym_shapes = sym_shapes or {}

    def get_sym_int(symint):
        torch._check(
            symint in sym_shapes or default_sym_shape is not None,
            lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in",
        )
        return sym_shapes.get(symint, default_sym_shape)

    def gen_tensor(shape, dtype) -> Tensor:
        # Resolve symbolic shapes to concrete values
        resolved_shape = []
        dynamic_dims = []
        for i, dim in enumerate(shape):
            dim = dim.strip()
            if "s" in dim:
                s = get_sym_int(dim)
                resolved_shape.append(s)
                dynamic_dims.append(i)
            else:
                if dim:
                    resolved_shape.append(int(dim))

        constructor = torch.randn if dtype.is_floating_point else torch.zeros
        out = constructor(resolved_shape, dtype=dtype, device=device)  # type: ignore[call-arg]
        for d in dynamic_dims:
            torch._dynamo.mark_dynamic(out, d)
        return out

    # Parse function annotations for tensor generation
    annotations = func.__annotations__
    for param, annotation in annotations.items():
        # Skip 'return' annotation
        if param == "return":
            continue

        match = re.search(tensor_regex, annotation)
        if match:
            data_type, shape_str = match.groups()
            shape = tuple(shape_str.split(","))
            dtype = dtype_map[data_type]
            kwargs[param] = gen_tensor(shape, dtype)

        match = re.search(sym_shape_regex, annotation)
        if match:
            kwargs[param] = get_sym_int(match.group(1))

    if "self" in inspect.signature(func).parameters:
        container = TensorContainer()
        kwargs["self"] = container
        for match in re.finditer(tensor_assignment_regex, source):
            attr_name, data_type, shape_str, _ = match.groups()
            shape = tuple(shape_str.split(","))
            dtype = dtype_map[data_type]
            setattr(container, attr_name, gen_tensor(shape, dtype))

    return kwargs


def profile_to_file(filename: str) -> Callable[[T], T]:
    """
    Decorator to cProfile a given function and save the result to disk on process exit.

    Args:
        filename: filename to save profile to
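
    Example (illustrative)::

        @profile_to_file("~/repro_profile.prof")
        def main():
            ...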
    """
    prof = cProfile.Profile()
    filename = os.path.abspath(os.path.expanduser(filename))

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            prof.enable()
            try:
                return fn(*args, **kwargs)
            finally:
                prof.disable()

        return wrapper

    def save_it():
        prof.dump_stats(filename)
        sys.stderr.write(
            textwrap.dedent(
                f"""\
                Wrote profile to {filename}, view with:

                    snakeviz {filename}

                """
            )
        )

    atexit.register(save_it)
    return decorator