1# mypy: allow-untyped-defs
2import atexit
3import collections
4import contextlib
5import copy
6import dataclasses
7import datetime
8import dis
9import enum
10import functools
11import gc
12import inspect
13import itertools
14import linecache
15import logging
16import math
17import operator
18import os
19import re
20import sys
21import textwrap
22import threading
23import time
24import types
25import typing
26import warnings
27import weakref
28from contextlib import contextmanager
29from functools import lru_cache, wraps
30from types import MethodWrapperType
31from typing import (
32    Any,
33    Callable,
34    cast,
35    ClassVar,
36    Counter,
37    DefaultDict,
38    Deque,
39    Dict,
40    Iterator,
41    KeysView,
42    List,
43    Optional,
44    Set,
45    Tuple,
46    Type,
47    Union,
48    ValuesView,
49)
50
51from ..utils.hooks import RemovableHandle
52
53try:
54    import numpy as np
55except ModuleNotFoundError:
56    np = None  # type: ignore[assignment]
57
58try:
59    import torch._logging
60    import torch._numpy as tnp
61    from torch._guards import detect_fake_mode  # noqa: F401
62    from torch._logging import LazyString
63    from . import config
64
65    # NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync.
66    if np:
67        NP_SUPPORTED_MODULES: Tuple[types.ModuleType, ...] = (
68            np,
69            np.fft,
70            np.linalg,
71            np.random,
72        )
73
74        NP_TO_TNP_MODULE = {
75            np: tnp,
76            np.fft: tnp.fft,
77            np.linalg: tnp.linalg,
78            np.random: tnp.random,
79        }
80    else:
81        NP_SUPPORTED_MODULES = tuple()
82
83        NP_TO_TNP_MODULE = {}
84    from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
85except ImportError:
86    pass
87
88import importlib
89
90import torch
91import torch._functorch.config
92import torch.fx.experimental.symbolic_shapes
93import torch.utils._pytree as pytree
94from torch import fx
95from torch._dispatch.python import enable_python_dispatcher
96from torch._guards import TracingContext
97from torch._subclasses.meta_utils import is_sparse_compressed
98from torch._utils_internal import log_compilation_event
99
100from torch.fx._utils import _format_graph_code, lazy_format_graph_code
101from torch.nn.modules.lazy import LazyModuleMixin
102from torch.utils._triton import has_triton, has_triton_package
103
104
105counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
106optimus_scuba_log: Dict[str, Any] = {}
107troubleshooting_url = (
108    "https://pytorch.org/docs/main/torch.compiler_troubleshooting.html"
109)
110nnmodule_doc_url = "https://pytorch.org/docs/main/torch.compiler_nn_module.html"
111nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
112log = logging.getLogger(__name__)
113
114# profiling compilation time by function
115compilation_time_metrics: Dict[str, List[float]] = {}
116
117# profiling compilation time by frame phase
118frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict(
119    lambda: collections.defaultdict(float)
120)
121
122timer_counter = itertools.count()
123
124
125def tabulate(rows, headers):
126    try:
127        import tabulate
128
129        return tabulate.tabulate(rows, headers=headers)
130    except ImportError:
131        return "\n".join(
132            ", ".join(map(str, row)) for row in itertools.chain([headers], rows)
133        )
134
135
136curr_frame = 0
137
138
139# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
140def increment_frame():
141    global curr_frame
142    curr_frame = curr_frame + 1
143
144
145# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
146def reset_frame_count():
147    global curr_frame
148    frame_phase_timing.clear()
149    compilation_time_metrics.clear()
150    curr_frame = 0
151
152
153op_count = 0
154
155
156def increment_op_count(cnt):
157    global op_count
158    op_count += cnt
159
160
161# Calculate total time spent so far for each phase
162# For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806}
163def calculate_time_spent():
164    total = 0.0
165    total_by_key = {}
166    for timings in frame_phase_timing.values():
167        for key, timing in timings.items():
168            total += timing
169            if key not in total_by_key:
170                total_by_key[key] = timing
171            else:
172                total_by_key[key] += timing
173
174    return total_by_key
175
176
177# Print a report of time spent so far
178# Ex:
179# TIMING:
180# entire_frame_compile:8.574629999999999
181# backend_compile:5.26806
182def print_time_report():
183    total_by_key = calculate_time_spent()
184
185    out = "TIMING:"
186    for key, value in total_by_key.items():
187        out = f"{out} {key}:{round(value, 5)}"
188
189    print(out)
190
191
192def _add_time_spent(key, phase_name, time_spent):
193    frame_phase_timing[key][phase_name] += time_spent
194
195
196# The dynamo_timed API works as a function decorator.
197# By wrapping a function in dynamo_timed, we store a record in compilation_time_metrics
198# where the key is the function's name.
199# For example:
200#
201#  @dynamo_timed
202#  def _foo(...):
203#
204# would show up as an entry in our timing dict:
205# OrderedDict([('bar.<locals>._foo', [0.083690, 0.23949, 3.1425e-05])])
206# This is extremely useful for granular debugging.
207#
208# For a higher-level mode, pass a phase_name into dynamo_timed.
209# A phase_name records an extra entry in a separate compilation timing structure,
210# one keyed on frame+name rather than function.
211# The frame is incremented outside of this function, in increment_frame() above.
212# `fwd_only` identifies whether this phase or function is only called
213# while compiling fwd graphs, e.g., `entire_frame_compile` and `backend_compile`.
214# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs.
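#
# Illustrative sketch (not in the original source): a hypothetical backend step
# decorated with a phase_name, and how the aggregated timings could be read back.
# `my_backend_step` and the numbers are invented for illustration only.
#
#  @dynamo_timed(phase_name="backend_compile")
#  def my_backend_step(gm, example_inputs):
#      ...
#
#  # after a few compilations:
#  compilation_time_metrics["my_backend_step"]  # e.g. [0.41, 0.07]
#  calculate_time_spent()                       # e.g. {'backend_compile': 0.48}
#  print_time_report()                          # "TIMING: backend_compile:0.48"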
215
216
217def dynamo_timed(original_function=None, phase_name=None, fwd_only=True):
218    def dynamo_timed_inner(func):
219        @wraps(func)
220        def time_wrapper(*args, **kwargs):
221            key = func.__qualname__
222            if key not in compilation_time_metrics:
223                compilation_time_metrics[key] = []
224
225            fail_type: Optional[str] = None
226            fail_reason: Optional[str] = None
227            time_spent = float("-inf")
228            try:
229                with torch.profiler.record_function(f"{key} (dynamo_timed)"):
230                    t0 = time.time()
231                    r = func(*args, **kwargs)
232                    time_spent = time.time() - t0
233                compilation_time_metrics[key].append(time_spent)
234            except Exception as e:
235                fail_type = str(type(e))
236                fail_reason = str(e)
237                raise
238            finally:
239                # Only record backward compilation metrics if phase_name is not None!
240                if phase_name:
241                    frame_key = str(curr_frame)
242                    # fwd only compilation stages: entire_frame_compile, backend_compile.
243                    # use frame_key as time aggregation key.
244                    if fwd_only and fail_type is None:
245                        _add_time_spent(frame_key, phase_name, time_spent)
246                    else:
247                        # fwd + bwd compilation stages: inductor_compile, code_gen.
248                        # use frame_key as time aggregation key for fwd graphs;
249                        # use compile_id as time aggregation key for bwd graphs.
250                        if torch._guards.TracingContext.try_get() is not None:
251                            aot_graph_name = str(
252                                torch._guards.TracingContext.get().aot_graph_name
253                            )
254                            if (
255                                "forward" in aot_graph_name
256                                or "inference" in aot_graph_name
257                            ) and fail_type is None:
258                                _add_time_spent(frame_key, phase_name, time_spent)
259                            elif "backward" in aot_graph_name:
260                                compile_id = str(
261                                    torch._guards.CompileContext.current_compile_id()
262                                )
263                                if fail_type is None:
264                                    _add_time_spent(compile_id, phase_name, time_spent)
265
266                                # log backward compilation metrics at the end of `inductor_compile` of bwd graph,
267                                # one record for one bwd graph.
268                                if phase_name == "inductor_compile":
269                                    if fail_type is None:
270                                        inductor_compile_time = frame_phase_timing[
271                                            compile_id
272                                        ].get("inductor_compile", None)
273                                        code_gen_time = frame_phase_timing[
274                                            compile_id
275                                        ].get("code_gen", None)
276                                    else:
277                                        inductor_compile_time = None
278                                        code_gen_time = None
279                                    metrics = BwdCompilationMetrics(
280                                        compile_id,
281                                        inductor_compile_time,
282                                        code_gen_time,
283                                        fail_type,
284                                        fail_reason,
285                                    )
286                                    record_compilation_metrics(metrics)
287
288            return r
289
290        return time_wrapper
291
292    if original_function:
293        return dynamo_timed_inner(original_function)
294    return dynamo_timed_inner
295
296
297def compile_times(repr="str", aggregate=False):
298    """
299    Get metrics about torchdynamo frontend/backend compilation times.
300
301    Accumulates information from functions tagged with `@dynamo_timed`.
302
303    repr='str' returns a printable string for user interaction, and 'csv'
304    returns (headers, values), which can be logged for output.
305
306    aggregate causes values from multiple compilations (e.g. split graphs)
307    to be accumulated into one value.  If false, expect more than one value
308    per metric.
309    """
310
311    def fmt_fn(values, item_fn=lambda x: x):
312        if aggregate:
313            return item_fn(sum(values))
314        return ", ".join(map(item_fn, values))
315
316    if repr == "str":
317        rows = [
318            (k, fmt_fn(compilation_time_metrics[k], item_fn=lambda x: f"{x:.4f}"))
319            for k in compilation_time_metrics
320        ]
321        out = "TorchDynamo compilation metrics:\n"
322        out += tabulate(rows, headers=("Function", "Runtimes (s)"))
323        return out
324    elif repr == "csv":
325        values = [
326            fmt_fn(v, item_fn=lambda x: f"{x:.6f}")
327            for v in compilation_time_metrics.values()
328        ]
329        headers = list(compilation_time_metrics.keys())
330        return headers, values
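
# Illustrative sketch (not in the original source): reading back the timings
# accumulated by @dynamo_timed in both supported representations.
#
#   print(compile_times(repr="str", aggregate=True))
#   # TorchDynamo compilation metrics:
#   # Function    Runtimes (s)
#   # ...
#   headers, values = compile_times(repr="csv", aggregate=False)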
331
332
333@atexit.register
334def dump_compile_times():
335    log.info(compile_times(repr="str", aggregate=True))
336
337
338tensortype_to_dtype = {
339    torch.FloatTensor: (torch.float32, torch.float),
340    torch.DoubleTensor: (torch.float64, torch.double),
341    torch.HalfTensor: (torch.float16, torch.half),
342    torch.BFloat16Tensor: (torch.bfloat16,),
343    torch.ByteTensor: (torch.uint8,),
344    torch.CharTensor: (torch.int8,),
345    torch.LongTensor: (torch.int64, torch.long),
346    torch.IntTensor: (torch.int32, torch.int),
347    torch.ShortTensor: (torch.int16, torch.short),
348    torch.BoolTensor: (torch.bool,),
349}
350
351
352class DuplicateWarningChecker:
353    def __init__(self, maxsize=4096):
354        self.maxsize = maxsize
355        self.reset()
356
357    def reset(self):
358        self.set = collections.OrderedDict()
359
360    def add(self, key):
361        if key in self.set:
362            self.set.move_to_end(key, last=True)
363            if not config.verbose:
364                return False
365        else:
366            self.set[key] = None
367            while len(self.set) > self.maxsize:
368                self.set.popitem(last=False)
369        return True
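
# Illustrative sketch (not in the original source): the checker returns True the
# first time a key is seen and False on repeats (unless config.verbose is set),
# which is how duplicate graph-break warnings are suppressed.
#
#   checker = DuplicateWarningChecker(maxsize=2)
#   checker.add("graph break: foo")   # True  -> emit the warning
#   checker.add("graph break: foo")   # False -> suppress the duplicate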
370
371
372graph_break_dup_warning_checker = DuplicateWarningChecker()
373
374
375def setup_compile_debug():
376    compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
377
378    if compile_debug:
379        return add_file_handler()
380
381    return contextlib.ExitStack()
382
383
384def reset_graph_break_dup_checker():
385    graph_break_dup_warning_checker.reset()
386
387
388def add_file_handler():
389    log_path = os.path.join(get_debug_dir(), "torchdynamo")
390    os.makedirs(log_path, exist_ok=True)
391
392    log_file_handler = logging.FileHandler(os.path.join(log_path, "debug.log"))
393    logger = logging.getLogger("torch._dynamo")
394    logger.addHandler(log_file_handler)
395
396    exitstack = contextlib.ExitStack()
397    exitstack.callback(lambda: logger.removeHandler(log_file_handler))
398    return exitstack
399
400
401def setup_log_file():
402    exitstack = contextlib.ExitStack()
403    if config.log_file_name is not None:
404        log_file_handler = logging.FileHandler(config.log_file_name)
405        for logger in torch._logging._internal.get_loggers():
406            logger.addHandler(log_file_handler)
407            exitstack.callback(lambda: logger.removeHandler(log_file_handler))
408        return exitstack
409
410    return exitstack
411
412
413def gen_record_file_name(exc, code):
414    return f"{get_debug_dir()}/error_recordings/\
415{code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec"
416
417
418def write_record_to_file(filename, exec_record):
419    try:
420        if os.path.exists(filename):
421            log.warning(
422                "Unable to write execution record %s; file already exists.", filename
423            )
424        else:
425            os.makedirs(os.path.dirname(filename), exist_ok=True)
426            with open(filename, "wb") as f:
427                exec_record.dump(f)
428    except Exception:
429        log.exception("Unable to write execution record %s", filename)
430
431
432def count_calls(g: fx.Graph):
433    c = 0
434    for n in g.nodes:
435        if "call" in n.op:
436            c += 1
437    return c
438
439
440def identity(x):
441    return x
442
443
444def hashable(x):
445    try:
446        hash(x)
447        return True
448    except TypeError:
449        return False
450    # cannot hash writable memoryview object
451    except ValueError:
452        return False
453
454
455def nothing(*args, **kwargs):
456    pass
457
458
459class ExactWeakKeyDictionary:
460    """Similar to weakref.WeakKeyDictionary, but uses `is`/`id` rather than `==` to compare keys"""
461
462    def __init__(self):
463        self.values = dict()
464        self.refs = dict()
465
466    def __getitem__(self, key):
467        return self.values[id(key)]
468
469    def get(self, key, default=None):
470        return self.values.get(id(key), default)
471
472    def __contains__(self, key):
473        return id(key) in self.values
474
475    def __setitem__(self, key, value):
476        idx = id(key)
477        if idx not in self.refs:
478            self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx))
479        self.values[idx] = value
480
481    def _remove_id(self, idx):
482        if idx in self.values:
483            del self.values[idx]
484        if idx in self.refs:
485            del self.refs[idx]
486
487    def clear(self):
488        self.refs.clear()
489        self.values.clear()
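
# Illustrative sketch (not in the original source): keys are matched by identity,
# not __eq__, and an entry disappears once its key is garbage collected.
#
#   class Key:
#       pass
#
#   a, b = Key(), Key()
#   d = ExactWeakKeyDictionary()
#   d[a] = "value"
#   assert a in d and b not in d   # identity-based lookup
#   del a                          # weakref callback drops the entry for `a`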
490
491
492def istype(obj, allowed_types):
493    """isinstance() without subclasses"""
494    if isinstance(allowed_types, (tuple, list, set)):
495        return type(obj) in allowed_types
496    return type(obj) is allowed_types
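
# Illustrative sketch (not in the original source): unlike isinstance, istype
# does not accept subclasses.
#
#   assert isinstance(True, int)     # bool is a subclass of int
#   assert not istype(True, int)     # but type(True) is bool, not int
#   assert istype(3, (int, float))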
497
498
499if sys.version_info >= (3, 12):
500    # Some typing classes moved to C in 3.12,
501    # which no longer have the _Final mixin.
502    _builtin_final_typing_classes = (
503        typing.ParamSpecArgs,
504        typing.ParamSpecKwargs,
505        typing.ParamSpec,
506        typing.TypeVar,
507        typing.TypeVarTuple,
508        typing.TypeAliasType,
509    )
510
511
512def is_typing(value):
513    # _Final catches most of typing classes:
514    #   - Any
515    #   - Callable
516    #   - Union
517    #   ...
518    #
519    # NB: we intentionally ignore classes that inherit from Generic, since they
520    # can be used as both TypingVariable as well as UserDefinedClassVariable.
521    if sys.version_info >= (3, 12) and isinstance(value, _builtin_final_typing_classes):
522        return True
523    return isinstance(value, typing._Final) or value is typing.Generic  # type: ignore[attr-defined]
524
525
526def is_numpy_int_type(value):
527    if not np:
528        return False
529
530    return istype(
531        value,
532        (
533            np.int8,
534            np.int16,
535            np.int32,
536            np.int64,
537            np.uint8,
538            np.uint16,
539            np.uint32,
540            np.uint64,
541        ),
542    )
543
544
545def is_numpy_float_type(value):
546    if not np:
547        return False
548
549    return istype(
550        value,
551        (
552            np.float16,
553            np.float32,
554            np.float64,
555        ),
556    )
557
558
559def is_function_or_wrapper(value):
560    return (
561        is_function(value)
562        or isinstance(value, functools._lru_cache_wrapper)
563        and is_function(inspect.getattr_static(value, "__wrapped__"))
564        or isinstance(value, (torch._ops.OpOverloadPacket, torch._ops.OpOverload))
565    )
566
567
568def is_function(value):
569    return isinstance(
570        value,
571        (
572            types.FunctionType,
573            types.BuiltinFunctionType,
574            types.MethodDescriptorType,
575            types.WrapperDescriptorType,
576            torch.jit.ScriptFunction,
577        ),
578    )
579
580
581def unwrap_if_wrapper(fn):
582    return unwrap_with_attr_name_if_wrapper(fn)[0]
583
584
585def unwrap_with_attr_name_if_wrapper(fn):
586    # unpack @functools.lru_cache wrapped function
587    if isinstance(fn, functools._lru_cache_wrapper):
588        fn = inspect.getattr_static(fn, "__wrapped__")
589        attr_name = "__wrapped__"
590    # unpack @torch._dynamo.optimize()(fn) wrapped function
591    elif is_function(fn) and inspect.getattr_static(fn, "_torchdynamo_inline", False):
592        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
593        attr_name = "_torchdynamo_inline"
594    # unpack torch.jit.script_if_tracing
595    elif is_function(fn) and inspect.getattr_static(
596        fn, "__script_if_tracing_wrapper", False
597    ):
598        fn = inspect.getattr_static(fn, "__original_fn", fn)
599        attr_name = "__original_fn"
600    else:
601        attr_name = None
602    return fn, attr_name
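
# Illustrative sketch (not in the original source): unwrapping a functools.lru_cache
# wrapper yields the underlying function plus the attribute that was followed.
# `cached_fn` is an invented example name.
#
#   @functools.lru_cache(None)
#   def cached_fn(x):
#       return x + 1
#
#   inner, attr = unwrap_with_attr_name_if_wrapper(cached_fn)
#   # inner is the undecorated function; attr == "__wrapped__"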
603
604
605def is_numpy_ndarray(value):
606    if not np:
607        return False
608
609    return istype(value, np.ndarray)
610
611
612def istensor(obj):
613    """Check if obj is a tensor"""
614    tensor_list = (
615        torch.Tensor,
616        torch.nn.Parameter,
617        *config.traceable_tensor_subclasses,
618    )
619    tensor_list = tensor_list + (torch._subclasses.FakeTensor,)
620    return istype(obj, tensor_list)
621
622
623def is_lazy_module(mod):
624    return isinstance(mod, LazyModuleMixin)
625
626
627@functools.lru_cache(4096)
628def print_once(*args):
629    print(*args)
630
631
632def make_cell(val=None):
633    """Some black magic to create a cell object that usually only exists in a closure"""
634    x = val
635
636    def f():
637        return x
638
639    assert f.__closure__ is not None and len(f.__closure__) == 1
640    return f.__closure__[0]
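
# Illustrative sketch (not in the original source): the returned object is a real
# CPython cell, the same type found in function closures.
#
#   cell = make_cell(42)
#   assert cell.cell_contents == 42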
641
642
643def proxy_args_kwargs(args, kwargs):
644    try:
645        proxy_args = tuple(arg.as_proxy() for arg in args)
646        proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
647        return proxy_args, proxy_kwargs
648    except NotImplementedError as e:
649        from .exc import unimplemented
650        from .variables.base import typestr
651
652        unimplemented(
653            f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}",
654            from_exc=e,
655        )
656
657
658@dataclasses.dataclass
659class CompilationMetrics:
660    compile_id: str
661    frame_key: str
662    co_name: str
663    co_filename: str
664    co_firstlineno: int
665    cache_size: int
666    accumulated_cache_size: int
667    guard_count: Optional[int]
668    shape_env_guard_count: Optional[int]
669    graph_op_count: Optional[int]
670    graph_node_count: Optional[int]
671    graph_input_count: Optional[int]
672    start_time: float
673    entire_frame_compile_time_s: Optional[float]
674    backend_compile_time_s: Optional[float]
675    inductor_compile_time_s: Optional[float]
676    code_gen_time_s: Optional[float]
677    fail_type: Optional[str]
678    fail_reason: Optional[str]
679    fail_user_frame_filename: Optional[str]
680    fail_user_frame_lineno: Optional[int]
681    non_compliant_ops: Set[str]
682    compliant_custom_ops: Set[str]
683    restart_reasons: Set[str]
684    dynamo_time_before_restart_s: float
685    # Sometimes, we will finish analyzing a frame but conclude we don't want
686    # to install any guarded code.  True means we actually decided to install
687    # a compiled frame
688    has_guarded_code: bool
689
690
691@dataclasses.dataclass
692class BwdCompilationMetrics:
693    compile_id: str
694    inductor_compile_time_s: Optional[float]
695    code_gen_time_s: Optional[float]
696    fail_type: Optional[str]
697    fail_reason: Optional[str]
698
699
700DEFAULT_COMPILATION_METRICS_LIMIT = 64
701
702
703_compilation_metrics: Deque[
704    Union[CompilationMetrics, BwdCompilationMetrics]
705] = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT)
706
707
708def record_compilation_metrics(
709    compilation_metrics: Union[CompilationMetrics, BwdCompilationMetrics]
710):
711    global _compilation_metrics
712    _compilation_metrics.append(compilation_metrics)
713    if isinstance(compilation_metrics, CompilationMetrics):
714        name = "compilation_metrics"
715    else:
716        name = "bwd_compilation_metrics"
717    # Currently only record fwd compilation metrics, will add bwd compilation metrics
718    # after the internal Scuba logging changes finish.
719    if isinstance(compilation_metrics, CompilationMetrics):
720        torch._logging.trace_structured(
721            name,
722            lambda: {
723                k: list(v) if isinstance(v, set) else v
724                for k, v in dataclasses.asdict(compilation_metrics).items()
725            },
726        )
727        if config.log_compilation_metrics:
728            log_compilation_event(compilation_metrics)
729
730
731def set_compilation_metrics_limit(new_size: int) -> None:
732    global _compilation_metrics
733    while len(_compilation_metrics) > new_size:
734        _compilation_metrics.popleft()
735    new_deque = collections.deque(_compilation_metrics, maxlen=new_size)
736    _compilation_metrics = new_deque
737
738
739def clear_compilation_metrics() -> None:
740    global _compilation_metrics
741    _compilation_metrics.clear()
742
743
744def get_compilation_metrics() -> List[Union[CompilationMetrics, BwdCompilationMetrics]]:
745    return list(_compilation_metrics)
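
# Illustrative sketch (not in the original source): iterating over the most recent
# metrics records; BwdCompilationMetrics lacks some fields, hence the getattr default.
#
#   for m in get_compilation_metrics():
#       print(m.compile_id, getattr(m, "entire_frame_compile_time_s", None))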
746
747
748@dataclasses.dataclass
749class CleanupHook:
750    """Remove a global variable when hook is called"""
751
752    scope: Dict[str, Any]
753    name: str
754
755    def __call__(self, *args):
756        # Make sure we're not shutting down
757        if CleanupManager is not None:
758            CleanupManager.count -= 1
759        del self.scope[self.name]
760
761    @staticmethod
762    def create(scope, name, val):
763        assert name not in scope
764        CleanupManager.count += 1
765        scope[name] = val
766        return CleanupHook(scope, name)
767
768
769class CleanupManager(ExactWeakKeyDictionary):
770    count = 0
771    instance: ClassVar["CleanupManager"]
772
773    def _remove_id(self, idx):
774        for hook in self.values[idx]:
775            hook()
776        super()._remove_id(idx)
777
778
779CleanupManager.instance = CleanupManager()
780
781
782def clone_tensor(x):
783    """Clone the tensor and its gradient"""
784    y = x.clone().requires_grad_(x.requires_grad)
785    if x.is_leaf and x.grad is not None:
786        y.grad = x.grad.clone()
787    return y
788
789
790def clone_input(x, *, dtype=None):
791    """copy while preserving strides"""
792    # TODO: this is questionable
793    if is_fake(x):
794        # this func fails on fake tensors in __torch_dispatch__
795        return x
796
797    def torch_clone(x):
798        y = torch.clone(x)
799        if x.is_leaf:
800            y.requires_grad_(x.requires_grad)
801        if x.is_leaf and x.grad is not None:
802            y.grad = clone_input(x.grad, dtype=dtype)
803        if hasattr(x, "_dynamo_dynamic_indices"):
804            y._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy()  # type: ignore[attr-defined]
805        return y
806
807    with torch.no_grad():
808        if x.device.type == "xla":
809            # Accessing data_ptr() on an XLA tensor will cause a crash
810            return torch_clone(x)
811
812        # Handle sparse storage (no stride).
813        if x.layout is torch.sparse_coo:
814            return torch.sparse_coo_tensor(
815                torch_clone(x._indices()),
816                torch_clone(x._values()),
817                x.shape,
818                is_coalesced=x.is_coalesced(),
819            )
820        elif is_sparse_compressed(x):
821            if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
822                compressed_indices = x.crow_indices()
823                plain_indices = x.col_indices()
824            else:
825                compressed_indices = x.ccol_indices()
826                plain_indices = x.row_indices()
827            return torch.sparse_compressed_tensor(
828                torch_clone(compressed_indices),
829                torch_clone(plain_indices),
830                torch_clone(x.values()),
831                x.shape,
832                layout=x.layout,
833            )
834
835        needed_size = sum(
836            (shape - 1) * stride for shape, stride in zip(x.size(), x.stride())
837        )
838        if x.is_quantized:
839            result = torch.empty_quantized((needed_size + 32,), x)
840        else:
841            result = torch.empty(
842                needed_size + 32, dtype=dtype or x.dtype, device=x.device
843            )
844        cache_line_offset = (
845            (x.data_ptr() - result.data_ptr()) % 32
846        ) // x.element_size()
847        result.as_strided_(x.size(), x.stride(), cache_line_offset)
848        try:
849            result.copy_(x.clone())
850            if x.is_leaf:
851                result.requires_grad_(x.requires_grad)
852            if x.is_leaf and x.grad is not None:
853                result.grad = clone_input(x.grad, dtype=dtype)
854        except RuntimeError:
855            # RuntimeError: unsupported operation: more than one element of the written-to
856            # tensor refers to a single memory location. Please clone() the tensor before
857            # performing the operation.
858            return torch_clone(x)
859        if hasattr(x, "_dynamo_dynamic_indices"):
860            result._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy()  # type: ignore[attr-defined]
861        return result
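
# Illustrative sketch (not in the original source): unlike a plain clone(),
# clone_input keeps the original (possibly non-contiguous) strides.
#
#   base = torch.randn(4, 4)
#   view = base[::2, ::2]                   # strided, non-dense view
#   assert clone_input(view).stride() == view.stride()
#   assert view.clone().stride() != view.stride()   # clone() falls back to contiguous here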
862
863
864def clone_inputs(example_inputs):
865    res: Union[Dict[Any, Any], List[Any]]
866    if type(example_inputs) is dict:
867        res = dict(example_inputs)
868        for key, value in res.items():
869            if isinstance(value, tuple):
870                res[key] = clone_inputs(value)
871            else:
872                assert isinstance(value, torch.Tensor), type(value)
873                res[key] = clone_input(value)
874        return res
875
876    res = list(example_inputs)
877    for i in range(len(res)):
878        if isinstance(res[i], torch.Tensor):
879            res[i] = clone_input(res[i])
880    return res
881
882
883def skip_frame_if_in_functorch_mode(val: torch.Tensor):
884    try:
885        val.data_ptr()  # will throw for functorch tensors
886    except RuntimeError as e:
887        from .exc import SkipFrame
888
889        # This will be GradTrackingTensor/BatchedTensor/etc
890        functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
891        raise SkipFrame(
892            f"torch.compile cannot be run in context: {functorch_subclass_name}"
893        ) from e
894
895
896@contextmanager
897def preserve_rng_state():
898    disable_functorch = torch._C._DisableFuncTorch
899    disable_current_modes = torch.utils._python_dispatch._disable_current_modes
900    with disable_current_modes(), disable_functorch():
901        rng_state = torch.clone(torch.random.get_rng_state())
902        skip_frame_if_in_functorch_mode(rng_state)
903        if torch.cuda.is_available():
904            cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
905    try:
906        yield
907    finally:
908        with torch.utils._python_dispatch._disable_current_modes():
909            torch.random.set_rng_state(rng_state)
910            if torch.cuda.is_available():
911                torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
912
913
914def is_jit_model(model0):
915    return isinstance(
916        model0,
917        (
918            torch.jit._trace.TopLevelTracedModule,
919            torch.jit._script.RecursiveScriptModule,
920            torch.jit.ScriptFunction,
921            torch.jit.ScriptModule,
922        ),
923    )
924
925
926def torchscript(model, example_inputs, verbose=False):
927    if is_jit_model(model):
928        # already done?
929        return model
930
931    try:
932        return torch.jit.trace(model, example_inputs)
933    except Exception:
934        try:
935            return torch.jit.script(model)
936        except Exception:
937            if verbose:
938                log.exception("jit error")
939            else:
940                log.error("Both torch.jit.trace and torch.jit.script failed")
941    return None
942
943
944def getfile(obj):
945    try:
946        return inspect.getfile(obj)
947    except (TypeError, OSError):
948        return None
949
950
951def is_namedtuple(obj):
952    """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple"""
953    return is_namedtuple_cls(type(obj))
954
955
956def is_namedtuple_cls(cls):
957    """Test if an object is a namedtuple or a (torch.return_types|torch.autograd.forward_ad).* quasi-namedtuple"""
958    try:
959        if issubclass(cls, tuple):
960            bases = getattr(cls, "__bases__", []) or [None]
961            module = getattr(cls, "__module__", None)
962            return module in ("torch.return_types", "torch.autograd.forward_ad") or (
963                bases[0] is tuple and hasattr(cls, "_make") and hasattr(cls, "_fields")
964            )
965    except TypeError:
966        pass
967    return False
968
969
970@functools.lru_cache(1)
971def namedtuple_fields(cls):
972    """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple"""
973    if cls is slice:
974        return ["start", "stop", "step"]
975
976    assert issubclass(cls, tuple)
977    if hasattr(cls, "_fields"):
978        # normal namedtuples
979        return cls._fields
980
981    @dataclasses.dataclass
982    class Marker:
983        index: int
984
985    # frustrating ones e.g. torch.return_types.max
986    assert cls.__module__ == "torch.return_types"
987    obj = cls(map(Marker, range(cls.n_fields)))
988    fields: List[Optional[str]] = [None] * cls.n_fields
989    for name in dir(obj):
990        if name[0] != "_" and isinstance(getattr(obj, name), Marker):
991            fields[getattr(obj, name).index] = name
992    return fields
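
# Illustrative sketch (not in the original source): both real namedtuples and
# torch.return_types structseqs resolve to their field names.
#
#   Point = collections.namedtuple("Point", ["x", "y"])
#   namedtuple_fields(Point)                   # ('x', 'y')
#   namedtuple_fields(torch.return_types.max)  # e.g. ['values', 'indices']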
993
994
995def checkpoint_params(gm):
996    with torch.no_grad():
997        rng_state = torch.clone(torch.random.get_rng_state())
998        if torch.cuda.is_available():
999            cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
1000        saved_state = []
1001        for param in itertools.chain(gm.parameters(), gm.buffers()):
1002            saved_state.append((param, param._version, torch.clone(param)))
1003
1004    def restore():
1005        with torch.no_grad():
1006            torch.random.set_rng_state(rng_state)
1007            if torch.cuda.is_available():
1008                torch.cuda.set_rng_state(cuda_rng_state)
1009            for param, version, original_value in saved_state:
1010                if param._version != version:
1011                    param.copy_(original_value)
1012
1013    return restore
1014
1015
1016def timed(model, example_inputs, times=1):
1017    if torch.cuda.is_available():
1018        synchronize = torch.cuda.synchronize
1019    else:
1020        synchronize = nothing
1021
1022    synchronize()
1023    gc.collect()
1024    torch.manual_seed(1337)
1025    t0 = time.perf_counter()
1026    for _ in range(times):
1027        result = model(*example_inputs)
1028        synchronize()
1029    t1 = time.perf_counter()
1030    return result, t1 - t0  # type: ignore[possibly-undefined]
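
# Illustrative sketch (not in the original source): timing a model's forward pass;
# `times` controls how many iterations are folded into one measurement.
#
#   model = torch.nn.Linear(8, 8)
#   out, seconds = timed(model, (torch.randn(2, 8),), times=10)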
1031
1032
1033def check_is_cuda(gm, example_inputs):
1034    return all(x.is_cuda for x in itertools.chain(example_inputs, gm.parameters(True)))
1035
1036
1037@lru_cache(32)
1038def rot_n_helper(n):
1039    assert n > 1
1040    vars = [f"v{i}" for i in range(n)]
1041    rotated = reversed(vars[-1:] + vars[:-1])
1042    fn = eval(f"lambda {','.join(vars)}: ({','.join(rotated)})")
1043    fn.__name__ = f"rot_{n}_helper"
1044    return fn
1045
1046
1047common_constant_types = {
1048    int,
1049    float,
1050    complex,
1051    bool,
1052    str,
1053    bytes,
1054    type(None),
1055    Ellipsis.__class__,
1056    types.CodeType,
1057    torch.device,
1058    torch.dtype,
1059    torch.memory_format,
1060    torch.layout,
1061}
1062
1063if has_triton_package():
1064    import triton
1065
1066    common_constant_types.add(triton.language.dtype)
1067
1068
1069def is_safe_constant(v):
1070    if istype(v, (tuple, frozenset)):
1071        return all(map(is_safe_constant, v))
1072    return isinstance(v, (enum.Enum, type)) or istype(
1073        v,
1074        common_constant_types | {slice},
1075    )
1076
1077
1078def specialize_symnode(arg):
1079    from .variables import ConstantVariable, SymNodeVariable
1080
1081    # Guard and specialize
1082    if isinstance(arg, SymNodeVariable):
1083        return ConstantVariable.create(arg.evaluate_expr())
1084
1085    return arg
1086
1087
1088def guard_if_dyn(arg):
1089    from .variables import ConstantVariable
1090
1091    arg = specialize_symnode(arg)
1092
1093    if isinstance(arg, ConstantVariable):
1094        return arg.as_python_constant()
1095
1096    return arg
1097
1098
1099def check_constant_args(args, kwargs):
1100    return all(x.is_python_constant() for x in itertools.chain(args, kwargs.values()))
1101
1102
1103def check_unspec_python_args(args, kwargs):
1104    from .variables.constant import ConstantVariable
1105    from .variables.tensor import UnspecializedPythonVariable
1106
1107    unspec_count = 0
1108    for x in itertools.chain(args, kwargs.values()):
1109        if isinstance(x, UnspecializedPythonVariable):
1110            unspec_count += 1
1111        elif not isinstance(x, ConstantVariable):
1112            return False
1113    return unspec_count > 0
1114
1115
1116def check_unspec_or_constant_args(args, kwargs):
1117    # A fused version of:
1118    # return check_constant_args(args, kwargs) or check_unspec_python_args(args, kwargs)
1119    from .variables.tensor import UnspecializedPythonVariable
1120
1121    for x in itertools.chain(args, kwargs.values()):
1122        if not (x.is_python_constant() or isinstance(x, UnspecializedPythonVariable)):
1123            return False
1124    return True
1125
1126
1127def check_numpy_ndarray_args(args, kwargs):
1128    from .variables.tensor import NumpyNdarrayVariable
1129
1130    return any(
1131        isinstance(x, NumpyNdarrayVariable)
1132        for x in itertools.chain(args, kwargs.values())
1133    )
1134
1135
1136dict_keys: Type[KeysView[Any]] = type(dict().keys())
1137dict_values: Type[ValuesView[Any]] = type(dict().values())
1138odict_values: Type[ValuesView[Any]] = type(collections.OrderedDict().values())
1139tuple_iterator: Type[Iterator[Any]] = type(iter(tuple()))
1140tuple_iterator_len = tuple_iterator.__length_hint__  # type: ignore[attr-defined]
1141object_new = object.__new__
1142
1143
1144def nn_module_new(cls):
1145    obj = object_new(cls)
1146    torch.nn.Module.__init__(obj)
1147    return obj
1148
1149
1150def product(it):
1151    return functools.reduce(operator.mul, it, 1)
1152
1153
1154def tuple_iterator_getitem(it, index):
1155    _, (obj,), start = it.__reduce__()
1156    return obj[start + index]
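
# Illustrative sketch (not in the original source): __reduce__ exposes the backing
# tuple and the current position, so indexing does not advance the iterator.
#
#   it = iter((10, 20, 30))
#   next(it)                       # -> 10
#   tuple_iterator_getitem(it, 0)  # -> 20, and `it` is left untouched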
1157
1158
1159iter_next = next
1160
1161
1162def to_subclass(t, cls):
1163    return t.as_subclass(cls)
1164
1165
1166def dict_keys_getitem(d, n):
1167    return next(itertools.islice(iter(d), n, n + 1))
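
# Illustrative sketch (not in the original source): fetch the n-th key in insertion
# order without materializing the whole key list.
#
#   d = {"a": 1, "b": 2, "c": 3}
#   dict_keys_getitem(d, 1)   # -> "b"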
1168
1169
1170def enum_repr(value, local):
1171    # An enum class can override the __str__ method. Use __class__ and the name attribute
1172    # to extract the class name and member name.
1173    name = value.__class__.__name__
1174    val = value.name
1175    scope = "L" if local else "G"
1176    local_name = f'{scope}["{name}"].{val}'
1177    return local_name
1178
1179
1180def set_example_value(node, example_value):
1181    # NB: example_value is a bit of a misnomer, because this is always a fake
1182    # tensor of some sort.  Furthermore, these example values serve as the
1183    # runtime state of Dynamo tracing, which means if metadata mutation
1184    # occurs, the example_value gets directly updated (so you can't rely on
1185    # this to accurately reflect what the state of the value was at the time
1186    # the program was traced).
1187    node.meta["example_value"] = example_value
1188    shape_env = TracingContext.get().fake_mode.shape_env
1189    if symbol_to_path := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
1190        shape_env, example_value
1191    ):
1192        node.meta["unbacked_bindings"] = symbol_to_path
1193
1194
1195def _get_fake_tensor(vt):
1196    fake_tensor = vt.as_proxy().node.meta.get("example_value")
1197    if not is_fake(fake_tensor):
1198        from .exc import unimplemented
1199
1200        unimplemented("Cannot check Tensor object identity without its fake value")
1201    return fake_tensor
1202
1203
1204def iter_contains(items, search, tx, check_tensor_identity=False):
1205    from .variables import (
1206        BuiltinVariable,
1207        ConstantVariable,
1208        TensorVariable,
1209        VariableTracker,
1210    )
1211
1212    if search.is_python_constant():
1213        found_const = any(
1214            x.is_python_constant()
1215            and x.as_python_constant() == search.as_python_constant()
1216            for x in items
1217        )
1218        return ConstantVariable.create(found_const)
1219
1220    must_check_tensor_id = False
1221    if check_tensor_identity and isinstance(search, TensorVariable):
1222        must_check_tensor_id = True
1223        # Match of Tensor means match of FakeTensor
1224        search = _get_fake_tensor(search)
1225
1226    found: Optional[VariableTracker] = None
1227    for x in items:
1228        if must_check_tensor_id:
1229            if isinstance(x, TensorVariable):
1230                if search is _get_fake_tensor(x):  # Object equivalence
1231                    return ConstantVariable.create(True)
1232        else:
1233            check = BuiltinVariable(operator.eq).call_function(tx, [x, search], {})
1234            if found is None:
1235                found = check
1236            else:
1237                found = BuiltinVariable(operator.or_).call_function(
1238                    tx, [check, found], {}
1239                )
1240    if found is None:
1241        found = ConstantVariable.create(False)
1242    return found
1243
1244
1245def key_is_id(k):
1246    """Returns whether the key indexes dictionaries by its id"""
1247    return isinstance(k, (torch.Tensor, torch.nn.Module, MethodWrapperType))
1248
1249
1250def key_to_id(value):
1251    return [id(k) if key_is_id(k) else k for k in value.keys()]
1252
1253
1254def const_repr(x, *, local) -> str:
1255    from .trace_rules import is_builtin_callable
1256
1257    if isinstance(x, (list, tuple)):
1258        elems_repr = ",".join(const_repr(s, local=local) for s in x)
1259        if isinstance(x, list):
1260            return f"[{elems_repr}]"
1261        else:
1262            assert isinstance(x, tuple)
1263            if len(x) == 1:
1264                return f"({elems_repr},)"
1265            else:
1266                return f"({elems_repr})"
1267    elif isinstance(x, enum.Enum):
1268        # Work around repr(Enum) returning an invalid global reference before Python 3.11
1269        # by calling enum_repr and removing quotes to render the enum in guard code.
1270        return enum_repr(x, local=local).replace("'", "")
1271    elif is_builtin_callable(x):
1272        return x.__name__
1273    elif isinstance(x, type):
1274
1275        def fullname(o):
1276            klass = o.__class__
1277            module = klass.__module__
1278            if module == "builtins":
1279                return klass.__qualname__  # avoid outputs like 'builtins.str'
1280            return module + "." + klass.__qualname__
1281
1282        return fullname(x)
1283    else:
1284        return f"{x!r}"
1285
1286
1287def dict_keys_repr(const_keys, *, local) -> str:
1288    keys_str = ",".join(const_repr(s, local=local) for s in const_keys)
1289    return "[" + keys_str + "]"
1290
1291
1292GLOBAL_KEY_PREFIX = "__dict_key"
1293
1294
1295from torch._subclasses import UnsupportedFakeTensorException  # noqa: F401
1296
1297
1298def wrap_fake_exception(fn):
1299    try:
1300        return fn()
1301    except UnsupportedFakeTensorException as e:
1302        from .exc import unimplemented
1303
1304        msg = f"Unsupported: {e.reason} with fake tensor propagation."
1305        log.warning(msg)
1306        unimplemented(msg, from_exc=e)
1307
1308
1309def deepcopy_to_fake_tensor(obj, fake_mode):
1310    with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode):
1311        return wrap_fake_exception(lambda: copy.deepcopy(obj))
1312
1313
1314def rmse(ref, res):
1315    """
1316    Calculate root mean squared error
1317    """
1318    return torch.sqrt(torch.mean(torch.square(ref - res)))
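
# Illustrative sketch (not in the original source): rmse computes
# sqrt(mean((ref - res) ** 2)).
#
#   rmse(torch.tensor([0.0, 0.0]), torch.tensor([3.0, 4.0]))   # -> tensor(3.5355)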
1319
1320
1321def same(
1322    ref,
1323    res,
1324    fp64_ref=None,
1325    cos_similarity=False,
1326    tol=1e-4,
1327    equal_nan=False,
1328    exact_dtype=True,
1329    relax_numpy_equality=False,
1330    ignore_non_fp=False,
1331    log_error=log.error,
1332):
1333    """Check correctness to see if ref and res match"""
1334    if fp64_ref is None:
1335        fp64_ref = ref
1336    if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
1337        assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}"
1338        if len(ref) != len(res):
1339            log_error("Length mismatch")
1340            return False
1341        return len(ref) == len(res) and all(
1342            same(
1343                ai,
1344                bi,
1345                fp64_refi,
1346                cos_similarity,
1347                tol,
1348                equal_nan,
1349                exact_dtype,
1350                relax_numpy_equality,
1351                ignore_non_fp,
1352                log_error=log_error,
1353            )
1354            for ai, bi, fp64_refi in zip(ref, res, fp64_ref)
1355        )
1356    elif type(ref).__name__ == "QuestionAnsweringModelOutput":
1357        # This skips checking accuracy for start_logits/end_logits.
1358        # Tentatively, start_logits/end_logits appear to be very prone to
1359        # inaccuracies and are somewhat subsumed by checking the loss.
1360        return same(
1361            ref.loss,
1362            res.loss,
1363            fp64_ref.loss,
1364            cos_similarity,
1365            tol,
1366            equal_nan,
1367            exact_dtype,
1368            relax_numpy_equality,
1369            ignore_non_fp,
1370            log_error=log_error,
1371        )
1372    elif isinstance(ref, dict):
1373        assert isinstance(res, dict)
1374        assert set(ref.keys()) == set(
1375            res.keys()
1376        ), f"keys mismatch {set(ref.keys())} == {set(res.keys())}"
1377        for k in sorted(ref.keys()):
1378            if not (
1379                same(
1380                    ref[k],
1381                    res[k],
1382                    fp64_ref[k],
1383                    cos_similarity=cos_similarity,
1384                    tol=tol,
1385                    equal_nan=equal_nan,
1386                    exact_dtype=exact_dtype,
1387                    relax_numpy_equality=relax_numpy_equality,
1388                    ignore_non_fp=ignore_non_fp,
1389                    log_error=log_error,
1390                )
1391            ):
1392                log_error("Accuracy failed for key name %s", k)
1393                return False
1394        return True
1395    elif isinstance(ref, (torch.Tensor, float)):
1396        assert not isinstance(ref, torch._subclasses.FakeTensor)
1397        assert not isinstance(res, torch._subclasses.FakeTensor)
1398
1399        def to_tensor(t):
1400            return t if isinstance(t, torch.Tensor) else torch.tensor(t)
1401
1402        ref, res, fp64_ref = (to_tensor(val) for val in (ref, res, fp64_ref))
1403
1404        if ref.is_sparse:
1405            assert res.is_sparse
1406            ref = ref.to_dense()
1407            res = res.to_dense()
1408        assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}"
1409        if exact_dtype:
1410            if ref.dtype != res.dtype:
1411                log_error("dtype mismatch %s, %s", ref.dtype, res.dtype)
1412                return False
1413            if ref.dtype == torch.bool:
1414                if ignore_non_fp:
1415                    return True
1416                # triton stores bool as int8, so add this for more accurate checking
1417                r = torch.allclose(
1418                    ref.to(dtype=torch.uint8),
1419                    res.to(dtype=torch.uint8),
1420                    atol=tol,
1421                    rtol=tol,
1422                    equal_nan=equal_nan,
1423                )
1424                if not r:
1425                    log_error("Accuracy failed: uint8 tensor did not match")
1426                return r
1427
1428        if cos_similarity:
1429            ref = ref.flatten().to(torch.float32)
1430            res = res.flatten().to(torch.float32)
1431            if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=True):
1432                # early exit that handles zero/nan better
1433                # cosine_similarity(zeros(10), zeros(10), dim=0) is 0
1434                return True
1435            score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6)
1436            if score < 0.99:
1437                log.warning("Similarity score=%s", score.cpu().detach().item())
1438            return score >= 0.99
1439        else:
1440            if not exact_dtype:
1441                ref = ref.to(res.dtype)
1442
1443            # First try usual allclose
1444            if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=equal_nan):
1445                return True
1446
1447            # Check error from fp64 version
1448            if fp64_ref.dtype == torch.float64:
1449                ref_error = rmse(fp64_ref, ref).item()
1450                # ref is unable to produce this with stable numerics in this precision; ignore
1451                if math.isnan(ref_error):
1452                    log.warning(
1453                        "Found nan in reference. Consider running in higher precision."
1454                    )
1455
1456                res_error = rmse(fp64_ref, res).item()
1457
1458                # In the case of using AMP (Automatic Mixed Precision), certain models have
1459                # failed the benchmark's correctness check. However, the end-to-end model's
1460                # accuracy when comparing AMP with FP32 is within a difference of less than 0.1%.
1461                # Thus, it's possible that the correctness check failures for these models are
1462                # false alarms. We use a multiplier of 3 instead of 2 to avoid these false alarms.
1463                multiplier = 3.0 if res.dtype == torch.bfloat16 else 2.0
1464
1465                if (
1466                    fp64_ref.numel() < 1000
1467                    or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1)
1468                    # large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE
1469                    or tol >= 2 * 1e-2
1470                ):
1471                    # In the presence of noise, noise might dominate our error
1472                    # metric for smaller tensors.
1473                    # Similarly, for 1x1 kernels, there seems to be high noise with amp.
1474                    multiplier = 3.0
1475
1476                passes_test = res_error <= (multiplier * ref_error + tol / 10.0)
1477                if not passes_test:
1478                    log_error(
1479                        "RMSE (res-fp64): %.5f, (ref-fp64): %.5f and shape=%s. res.dtype: %s, multiplier: %f, tol: %f",
1480                        res_error,
1481                        ref_error,
1482                        res.size(),
1483                        res.dtype,
1484                        multiplier,
1485                        tol,
1486                    )
1487                    # import pdb; pdb.set_trace()
1488                return passes_test
1489
1490            if ignore_non_fp:
1491                return True
1492
1493            log_error("Accuracy failed: allclose not within tol=%s", tol)
1494            return False
1495    elif isinstance(ref, (str, int, type(None), bool, torch.device)):
1496        if ignore_non_fp:
1497            return True
1498        r = ref == res
1499        if not r:
1500            log_error("Accuracy failed (%s): %s != %s", type(ref), ref, res)
1501        return r
1502    elif is_numpy_int_type(ref) or is_numpy_float_type(ref):
1503        if relax_numpy_equality and not (
1504            is_numpy_int_type(res) or is_numpy_float_type(res)
1505        ):
1506            ref = ref.item()
1507        r = (type(ref) is type(res)) and (ref == res)
1508        if not r:
1509            log_error("Accuracy failed (numpy): %s != %s", ref, res)
1510        return r
1511    elif is_numpy_ndarray(ref):
1512        return (type(ref) is type(res)) and same(
1513            torch.as_tensor(ref),
1514            torch.as_tensor(res),
1515            fp64_ref,
1516            cos_similarity=cos_similarity,
1517            tol=tol,
1518            equal_nan=equal_nan,
1519            exact_dtype=exact_dtype,
1520            relax_numpy_equality=relax_numpy_equality,
1521            ignore_non_fp=ignore_non_fp,
1522            log_error=log_error,
1523        )
1524    elif type(ref).__name__ in (
1525        "MaskedLMOutput",
1526        "Seq2SeqLMOutput",
1527        "CausalLMOutputWithCrossAttentions",
1528        "LongformerMaskedLMOutput",
1529        "Instances",
1530        "SquashedNormal",
1531        "Boxes",
1532        "Normal",
1533        "TanhTransform",
1534        "Foo",
1535        "Variable",
1536    ):
1537        assert type(ref) is type(res)
1538        return all(
1539            same(
1540                getattr(ref, key),
1541                getattr(res, key),
1542                getattr(fp64_ref, key),
1543                cos_similarity=cos_similarity,
1544                tol=tol,
1545                equal_nan=equal_nan,
1546                exact_dtype=exact_dtype,
1547                relax_numpy_equality=relax_numpy_equality,
1548                ignore_non_fp=ignore_non_fp,
1549                log_error=log_error,
1550            )
1551            for key in ref.__dict__.keys()
1552        )
1553    else:
1554        raise RuntimeError(f"unsupported type: {type(ref).__name__}")
1555
1556
1557def format_func_info(code):
1558    short_filename = code.co_filename.split("/")[-1]
1559    return f"'{code.co_name}' ({short_filename}:{code.co_firstlineno})"
1560
1561
1562@contextlib.contextmanager
1563def disable_cache_limit():
1564    prior = config.cache_size_limit
1565    config.cache_size_limit = sys.maxsize
1566    prior_acc_limit = config.accumulated_cache_size_limit
1567    config.accumulated_cache_size_limit = sys.maxsize
1568
1569    try:
1570        yield
1571    finally:
1572        config.cache_size_limit = prior
1573        config.accumulated_cache_size_limit = prior_acc_limit
1574
1575
1576# map from transformed code back to original user code
1577orig_code_map = ExactWeakKeyDictionary()
1578
1579# keep a record of code_obj -> list of guard failure reasons for logging
1580guard_failures: DefaultDict[Any, List[Any]] = collections.defaultdict(list)
1581
1582# Keep a record of graph break reasons for logging
1583graph_break_reasons: List["torch._dynamo.output_graph.GraphCompileReason"] = list()
1584
1585# keep record of compiled code, if we are in "error if recompile"
1586# to track code that dynamo has compiled previously
1587seen_code_map = ExactWeakKeyDictionary()
1588
1589
1590class CompileProfiler:
1591    """Utility for profiling how and what dynamo would compile.
1592
1593    Can be used for
1594     * diagnosing recompilation issues
1595     * determining an appropriate compile cache limit
1596     * (TODO) confirming which functions got compiled/skipped
1597    """
1598
1599    def __init__(self):
1600        self.frame_count = 0
1601        self.op_count = 0
1602        self.backend_ctx_ctor = disable_cache_limit
1603
1604    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
1605        self.frame_count += 1
1606        for node in gm.graph.nodes:
1607            if "call" in node.op:
1608                self.op_count += 1
1609        return gm.forward
1610
1611    # no-op __enter__ and __exit__ to preserve BC
1612    def __enter__(self):
1613        return self
1614
1615    def __exit__(self, typ, val, traceback):
1616        pass
1617
1618    def get_metrics(self):
1619        return {"guard_failures": guard_failures}
1620
1621    def report(self):
1622        metrics = self.get_metrics()
1623        gf = metrics["guard_failures"]
1624
1625        def num_recompiles(code):
1626            return len(gf[code])
1627
1628        def recompile_reasons(code):
1629            return "\n".join([str(x) for x in gf[code]])
1630
1631        summarized_gf = [
1632            [format_func_info(code), num_recompiles(code), recompile_reasons(code)]
1633            for code in gf
1634        ]
1635
1636        def graph_break_report():
1637            if "graph_break" in counters:
1638                graph_breaks = counters["graph_break"]
1639                return tabulate(
1640                    [[msg, graph_breaks[msg]] for msg in graph_breaks],
1641                    headers=["Graph Break Reason", "Count"],
1642                )
1643
1644        def recompilation_report():
1645            if len(gf):
1646                max_recompiles = max(num_recompiles(code) for code in gf)
1647                recomp_table = tabulate(
1648                    summarized_gf,
1649                    headers=["Function", "Recompiles", "Recompile Reasons"],
1650                )
1651                return recomp_table + textwrap.dedent(
1652                    f"""
1653
1654                    Set torch._dynamo.config.cache_size_limit to {max_recompiles} to avoid being cache limited.
1655                """
1656                )
1657
1658        report = textwrap.dedent(
1659            """
1660            Torchdynamo Profiler Report
1661            ===========================
1662
1663            Graph Breaks
1664            ------------
1665            Graph breaks happen when torchdynamo encounters code it can't safely trace.
1666            If you want to find out why breaks are happening, check below for each break reason.
1667            You may gain additional insight by passing `fullgraph=True` to torch.compile
1668            to stop at the first break.
1669
1670        """
1671        )
1672        report += graph_break_report() or "No graph breaks detected."
1673        report += textwrap.dedent(
1674            """
1675
1676            Recompilation
1677            -------------
1678            These subgraphs were recompiled more than once due to guard failures.
1679            Guard failures indicate that a condition assumed to be static by the tracer changed,
1680            making it unsafe to reuse the compiled program.
1681
1682        """
1683        )
1684        report += recompilation_report() or "No recompilation detected.\n"
1685        return report
1686
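
# Editor's sketch (hypothetical, not part of the original source): CompileProfiler
# doubles as a Dynamo backend, so it can be passed to torch.compile to count compiled
# frames and summarize guard failures. This helper is never called here.
def _example_compile_profiler():
    prof = CompileProfiler()

    @torch.compile(backend=prof)
    def fn(x):
        return x * 2

    fn(torch.randn(3))
    print(prof.report())
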
1687
1688# return same dir unless user changes config between calls
1689@functools.lru_cache(None)
1690def _get_debug_dir(root_dir):
1691    dir_name = (
1692        "run_"
1693        + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
1694        # use pid to avoid conflicts among ranks
1695        + "-pid_"
1696        + str(os.getpid())
1697    )
1698    return os.path.join(root_dir, dir_name)
1699
1700
1701def get_debug_dir():
1702    debug_root = config.debug_dir_root
1703    return _get_debug_dir(debug_root)
1704
1705
1706def extract_fake_example_value(node, required=True):
1707    if "example_value" in node.meta and is_fake(node.meta["example_value"]):
1708        return node.meta["example_value"]
1709    elif required:
1710        from torch._dynamo.exc import unimplemented
1711
1712        unimplemented("`FakeTensor` example value was required but not available")
1713    else:
1714        return None
1715
1716
1717def ensure_graph_fake(e, tx):
1718    assert maybe_get_fake_mode(e) is tx.fake_mode
1719    return e
1720
1721
1722def get_fake_values_from_nodes(tx, nodes, allow_non_graph_fake):
1723    def visit(n: torch.fx.Node):
1724        if n.op == "call_function" and "example_value" not in n.meta:
1725            # fake tensor validity is checked inside get_fake_value using
1726            # ensure_graph_fake
1727            return get_fake_value(n, tx, allow_non_graph_fake)
1728
1729        out = n.meta["example_value"]
1730        if not allow_non_graph_fake and isinstance(out, torch.Tensor):
1731            return ensure_graph_fake(out, tx)
1732        return out
1733
1734    return torch.fx.node.map_arg(nodes, visit)
1735
1736
1737def get_fake_value(node, tx, allow_non_graph_fake=False):
1738    """
1739    Run the computation represented by `node` using fake tensors and return the result.
1740
1741    allow_non_graph_fake: whether to allow the returned result to be:
1742        1. non-fake, or 2. a fake that was not created by this instance of Dynamo.
1743        If `True`, you must be prepared to deal with such return values, ideally
1744        by further wrapping them as this graph's fakes.
1745    """
1746    from torch.utils._sympy.value_ranges import ValueRangeError
1747    from .exc import (
1748        TorchRuntimeError,
1749        unimplemented,
1750        Unsupported,
1751        UserError,
1752        UserErrorType,
1753    )
1754
1755    op = node.op
1756
1757    # FX Node should always return the same fake value
1758    if "example_value" in node.meta and is_fake(node.meta["example_value"]):
1759        return node.meta["example_value"]
1760
1761    args, kwargs = get_fake_values_from_nodes(
1762        tx, (node.args, node.kwargs), allow_non_graph_fake
1763    )
1764
1765    nnmodule = None
1766    if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module):
1767        # If the first argument is an nn.Module, we should copy it to fake mode.
1768        args = (deepcopy_to_fake_tensor(args[0], tx.fake_mode),) + tuple(args[1:])
1769
1770    if op == "call_module":
1771        nnmodule = tx.output.nn_modules[node.target]
1772
1773        if is_lazy_module(nnmodule) and hasattr(nnmodule, "_initialize_hook"):
1774            # In the case of a lazy module, we want to run
1775            # the pre-hooks which initialize it.
1776            # Afterwards, the lazy module deletes its pre-hooks
1777            # so it is not treated as lazy on subsequent recompiles.
1778            nnmodule._infer_parameters(nnmodule, args)
1779
1780        # Whether or not it's a lazy module, we should copy it to fake mode.
1781        nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)
1782
1783    try:
1784        with tx.fake_mode, enable_python_dispatcher():
1785            ret_val = wrap_fake_exception(
1786                lambda: run_node(tx.output, node, args, kwargs, nnmodule)
1787            )
1788    except Unsupported:
1789        raise
1790    except RuntimeError as e:
1791        cause: BaseException = e
1792        if e.__cause__ is not None:
1793            cause = e.__cause__
1794
1795        if isinstance(
1796            cause, torch._subclasses.fake_tensor.DataDependentOutputException
1797        ):
1798            unimplemented(
1799                f"data dependent operator: {cause.func}; "
1800                "to enable, set torch._dynamo.config.capture_scalar_outputs = True"
1801            )
1802        elif isinstance(
1803            cause, torch._subclasses.fake_tensor.DynamicOutputShapeException
1804        ):
1805            if not torch._dynamo.config.capture_dynamic_output_shape_ops:
1806                unimplemented(
1807                    f"dynamic shape operator: {cause.func}; "
1808                    "to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True"
1809                )
1810            else:
1811                unimplemented(
1812                    f"dynamic shape operator: {cause.func}; "
1813                    "Operator does not have a meta kernel that supports dynamic output shapes, "
1814                    "please report an issue to PyTorch"
1815                )
1816        elif isinstance(
1817            cause, torch._subclasses.fake_tensor.UnsupportedOperatorException
1818        ):
1819            op = cause.func
1820            import_suggestion = ""
1821            if isinstance(op, torch._ops.OpOverload):
1822                maybe_pystub = torch._C._dispatch_pystub(
1823                    op._schema.name, op._schema.overload_name
1824                )
1825                if maybe_pystub is not None:
1826                    module, ctx = maybe_pystub
1827                    import_suggestion = (
1828                        f"It's possible that the support was implemented in "
1829                        f"module `{module}` and you may need to `import {module}` "
1830                        f"({ctx}), otherwise "
1831                    )
1832            unimplemented(
1833                f"unsupported operator: {cause.func} ({import_suggestion}see "
1834                "https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0"
1835                " for how to fix)"
1836            )
1837        elif isinstance(
1838            cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
1839        ):
1840            raise UserError(  # noqa: B904
1841                UserErrorType.CONSTRAINT_VIOLATION,
1842                "Tried to use a data-dependent value in the subsequent computation. "
1843                "This can happen when we encounter an unbounded dynamic value that is unknown during tracing time. "
1844                "You will need to explicitly give a hint to the compiler. Please take a look at "
1845                f"the torch._check or torch._check_is_size APIs.  {cause}",
1846                case_name="constrain_as_size_example",
1847            )
1848        elif isinstance(cause, ValueRangeError):
1849            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, e.args[0]) from e
1850        elif isinstance(cause, TypeError) and "argument" in str(cause):
1851            unimplemented(f"TypeError {node.target}: {cause}")
1852
1853        raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
1854
1855    if not allow_non_graph_fake:
1856        _ = pytree.tree_map_only(
1857            torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val
1858        )
1859    return ret_val
1860
1861
1862_current_node = threading.local()
1863
1864
1865def get_current_node():
1866    return getattr(_current_node, "value", None)
1867
1868
1869@contextmanager
1870def set_current_node(node):
1871    old = get_current_node()
1872    _current_node.value = node
1873    try:
1874        yield
1875    finally:
1876        _current_node.value = old
1877
1878
1879def run_node(tracer, node, args, kwargs, nnmodule):
1880    """
1881    Runs a given node, with the given args and kwargs.
1882
1883    Behavior is dictated by a node's op.
1884
1885    run_node is useful for extracting real values out of nodes.
1886    See get_real_value for more info on common usage.
1887
1888    Note: The tracer arg is only used for 'get_attr' ops
1889    Note: The nnmodule arg is only used for 'call_module' ops
1890
1891    Nodes that are not call_function, call_method, call_module, get_attr, or
1892    placeholder will raise an AssertionError.
1893    """
1894    op = node.op
1895
1896    with set_current_node(node):
1897
1898        def make_error_message(e):
1899            return f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n" + str(e)
1900
1901        try:
1902            if op == "call_function":
1903                return node.target(*args, **kwargs)
1904            elif op == "call_method":
1905                return getattr(args[0], node.target)(*args[1:], **kwargs)
1906            elif op == "call_module":
1907                assert nnmodule is not None
1908                return nnmodule(*args, **kwargs)
1909            elif op == "get_attr":
1910                return tracer.output_graph.get_submodule(node.target)
1911            elif op == "placeholder":
1912                assert "example_value" in node.meta
1913                return node.meta["example_value"]
1914
1915        except (NotImplementedError, UnsupportedFakeTensorException) as e:
1916            # NB: mimic how wrap_fake_exception does it
1917            from .exc import unimplemented
1918
1919            unimplemented(make_error_message(e), from_exc=e)
1920        except Exception as e:
1921            raise RuntimeError(make_error_message(e)).with_traceback(
1922                e.__traceback__
1923            ) from e
1924
1925    raise AssertionError(op)
1926
1927
1928def get_real_value(node, tracer):
1929    """
1930    Run the actual computation represented by `node` and return the result.
1931    This will execute any dependent nodes in the graph as well.
1932    """
1933    from .exc import TorchRuntimeError
1934
1935    cache = tracer.real_value_cache
1936    if node in cache:
1937        return cache[node]
1938
1939    op = node.op
1940    args, kwargs = torch.fx.node.map_arg(
1941        (node.args, node.kwargs),
1942        lambda n: get_real_value(n, tracer),
1943    )
1944
1945    if op == "placeholder" and "grapharg" in node.meta:
1946        return node.meta["grapharg"].example
1947
1948    if op == "call_module":
1949        nn_module = tracer.output_graph.nn_modules[node.target]
1950        if not is_lazy_module(nn_module):
1951            nn_module = copy.deepcopy(nn_module)
1952        else:
1953            # In the case of a lazy module, we want to run
1954            # the pre-hooks which initialize it
1955            nn_module(*args, **kwargs)
1956    else:
1957        nn_module = None
1958
1959    try:
1960        real_value = run_node(tracer, node, args, kwargs, nn_module)
1961        cache[node] = real_value
1962    except RuntimeError as e:
1963        raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
1964    return real_value
1965
1966
1967def assert_no_fake_params_or_buffers(gm):
1968    from torch._subclasses.fake_tensor import FakeTensorConfig, is_fake
1969
1970    def stack_or_hint(t):
1971        if FakeTensorConfig.debug:
1972            import traceback
1973
1974            return f"FAKE TENSOR CREATION TRACEBACK: \n {traceback.format_list(t._debug_trace)}"
1975        else:
1976            return "Enable TORCH_FAKE_TENSOR_DEBUG=1 to get creation stack traces on fake tensors."
1977
1978    for name, buffer in gm.named_buffers():
1979        assert not is_fake(
1980            buffer
1981        ), f"Unexpected fake buffer {name} {stack_or_hint(buffer)}"
1982    for name, param in gm.named_parameters():
1983        assert not is_fake(
1984            param
1985        ), f"Unexpected fake param {name} {stack_or_hint(param)}"
1986
1987
1988def fqn(obj: Any):
1989    """
1990    Returns the fully qualified name of the object.
1991    """
1992    return f"{obj.__module__}.{obj.__qualname__}"
1993
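
# Editor's sketch (hypothetical, not part of the original source): fqn simply joins
# __module__ and __qualname__. Never called here.
def _example_fqn():
    assert fqn(torch.nn.Linear) == "torch.nn.modules.linear.Linear"
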
1994
1995def ifdynstaticdefault(count1, count2):
1996    if torch._dynamo.config.assume_static_by_default:
1997        return count1
1998    else:
1999        return count2
2000
2001
2002def import_submodule(mod: types.ModuleType):
2003    """
2004    Ensure all the files in a given submodule are imported
2005    """
2006    for filename in sorted(os.listdir(os.path.dirname(cast(str, mod.__file__)))):
2007        if filename.endswith(".py") and filename[0] != "_":
2008            importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
2009
2010
2011def object_has_getattribute(value: Any):
2012    try:
2013        if isinstance(
2014            inspect.getattr_static(type(value), "__getattribute__"),
2015            types.FunctionType,
2016        ):
2017            return True
2018    except AttributeError:
2019        pass
2020    return False
2021
2022
2023def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False):
2024    try:
2025        getattr_fn = inspect.getattr_static(type(value), "__getattr__")
2026    except AttributeError:
2027        getattr_fn = None
2028    if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__:
2029        # ignore this case of getattr
2030        getattr_fn = None
2031    return getattr_fn
2032
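
# Editor's sketch (hypothetical, not part of the original source): how a user-defined
# __getattr__ is detected, while plain objects report no custom attribute hooks.
# Never called here.
def _example_custom_getattr_detection():
    class WithGetattr:
        def __getattr__(self, name):
            return 42

    assert get_custom_getattr(WithGetattr()) is not None
    assert get_custom_getattr(object()) is None
    assert not object_has_getattribute(object())
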
2033
2034class TensorStaticReason(enum.Enum):
2035    PARAMETER = 2
2036    NOT_TENSOR = 4
2037    NN_MODULE_PROPERTY = 5
2038
2039
2040def tensor_static_reason_to_message(reason: TensorStaticReason):
2041    if reason == TensorStaticReason.PARAMETER:
2042        return "mark_dynamic on a parameter; parameters are always static today."
2043    if reason == TensorStaticReason.NOT_TENSOR:
2044        return "mark_dynamic on a non-tensor; how did this happen?"
2045    if reason == TensorStaticReason.NN_MODULE_PROPERTY:
2046        return "tensor is static because it is associated with an nn.Module."
2047    raise AssertionError(f"Illegal reason {reason}")
2048
2049
2050def tensor_always_has_static_shape(
2051    tensor: Union[torch.Tensor, Any],
2052    is_tensor: bool,
2053    guard_source: "torch._guards.GuardSource",
2054) -> Tuple[bool, Optional[TensorStaticReason]]:
2055    """
2056    Given a tensor, source, and is_tensor flag, determine if a shape should be static.
2057
2058    Args:
2059    tensor - the real tensor to evaluate; parameters force a static shape.
2060    is_tensor - internal dynamo check, essentially "is_tensor": target_cls is TensorVariable;
2061    tensors not wrapped in a TensorVariable for whatever reason are forced static.
2062
2063    Returns a tuple, where the first element is a bool indicating whether this tensor should have a static shape.
2064    The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed.
2065    """
2066    if guard_source.is_nn_module() and config.force_nn_module_property_static_shapes:
2067        return True, TensorStaticReason.NN_MODULE_PROPERTY
2068    if type(tensor) is torch.nn.Parameter and config.force_parameter_static_shapes:
2069        return True, TensorStaticReason.PARAMETER
2070    if not is_tensor:
2071        return True, TensorStaticReason.NOT_TENSOR
2072    return False, None
2073
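
# Editor's sketch (hypothetical, not part of the original source), assuming the
# default config where force_parameter_static_shapes is enabled: parameters are
# reported as always static. Never called here.
def _example_tensor_always_has_static_shape():
    from torch._guards import GuardSource

    param = torch.nn.Parameter(torch.randn(2, 2))
    static, reason = tensor_always_has_static_shape(
        param, is_tensor=True, guard_source=GuardSource.LOCAL
    )
    assert static and reason == TensorStaticReason.PARAMETER
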
2074
2075def lazy_format_graph_tabular(fn_name, gm):
2076    def inner():
2077        try:
2078            from tabulate import tabulate  # TODO: Check that this is installed
2079        except ImportError:
2080            return (
2081                "Tabulate module missing; please install tabulate to log the graph in tabular format. Logging code instead:\n"
2082                + str(lazy_format_graph_code(fn_name, gm))
2083            )
2084
2085        node_specs = [
2086            [n.op, n.name, n.target, n.args, n.kwargs] for n in gm.graph.nodes
2087        ]
2088        graph_str = tabulate(
2089            node_specs, headers=["opcode", "name", "target", "args", "kwargs"]
2090        )
2091        return _format_graph_code(fn_name, gm.forward.__code__.co_filename, graph_str)
2092
2093    return LazyString(inner)
2094
2095
2096def format_bytecode(prefix, name, filename, line_no, code):
2097    return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n"
2098
2099
2100forward_hook_names = ["_forward_pre_hooks", "_forward_hooks"]
2101backward_hook_names = ["_backward_pre_hooks", "_backward_hooks"]
2102state_dict_hook_names = [
2103    "_state_dict_pre_hooks",
2104    "_state_dict_hooks",
2105    "_load_state_dict_pre_hooks",
2106    "_load_state_dict_post_hooks",
2107]
2108all_hook_names = forward_hook_names + backward_hook_names + state_dict_hook_names
2109
2110
2111def nn_module_has_global_hooks():
2112    # This is limited to backward hooks for now because NNModuleVariable
2113    # supports fwd hooks underneath.
2114    return len(torch.nn.modules.module._global_backward_hooks) or len(
2115        torch.nn.modules.module._global_backward_pre_hooks
2116    )
2117
2118
2119def nn_module_get_all_hooks(
2120    mod,
2121    check_forward_hooks=False,
2122    check_backward_hooks=False,
2123    check_state_dict_hooks=False,
2124):
2125    """
2126    Sometimes it's useful to differentiate between types of hooks, such as forward/backward/pre
2127    hooks executed during module.__call__, and state_dict hooks, which are executed separately.
2128    """
2130    hook_dicts_to_check = []
2131    check_all_hooks = (
2132        not check_forward_hooks
2133        and not check_backward_hooks
2134        and not check_state_dict_hooks
2135    )
2136    if check_forward_hooks or check_all_hooks:
2137        hook_dicts_to_check.extend(forward_hook_names)
2138    if check_backward_hooks or check_all_hooks:
2139        hook_dicts_to_check.extend(backward_hook_names)
2140    if check_state_dict_hooks:
2141        hook_dicts_to_check.extend(state_dict_hook_names)
2142
2143    all_hooks = []
2144    for hook_dict_name in hook_dicts_to_check:
2145        hooks = getattr(mod, hook_dict_name, [])
2146        for hook_name in hooks:
2147            hook = hooks[hook_name]
2148
2149            all_hooks.append(hook)
2150    return all_hooks
2151
2152
2153def nnmodule_has_hooks(
2154    mod,
2155    check_forward_hooks=False,
2156    check_backward_hooks=False,
2157    check_state_dict_hooks=False,
2158):
2159    """
2160    Helper function to check if a module has any hooks attached to it.
2161    """
2162    hooks = nn_module_get_all_hooks(
2163        mod,
2164        check_forward_hooks=check_forward_hooks,
2165        check_backward_hooks=check_backward_hooks,
2166        check_state_dict_hooks=check_state_dict_hooks,
2167    )
2168    return bool(hooks)
2169
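
# Editor's sketch (hypothetical, not part of the original source): hook detection on
# a plain module, before and after registering a forward hook. Never called here.
def _example_nnmodule_has_hooks():
    mod = torch.nn.Linear(2, 2)
    assert not nnmodule_has_hooks(mod, check_forward_hooks=True)
    handle = mod.register_forward_hook(lambda m, inp, out: None)
    assert nnmodule_has_hooks(mod, check_forward_hooks=True)
    handle.remove()
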
2170
2171def to_numpy_helper(value):
2172    """Convert tensor and tnp.ndarray to numpy.ndarray."""
2173    if is_fake(value):
2174        return value
2175    if isinstance(value, tnp.ndarray):
2176        return to_numpy_helper(value.tensor)
2177    elif isinstance(value, torch.Tensor):
2178        return value.numpy(force=True)
2179    elif isinstance(value, (tuple, list)):
2180        return type(value)(to_numpy_helper(obj) for obj in value)
2181    else:
2182        return value
2183
2184
2185def numpy_to_tensor(value):
2186    """Convert tnp.ndarray to tensor, leave other types intact. If a list/tuple, loop through it to convert."""
2187    assert np is not None
2188    if isinstance(value, np.ndarray):
2189        return torch.as_tensor(value)
2190    if isinstance(value, tnp.ndarray):
2191        return value.tensor
2192    elif isinstance(value, (tuple, list)):
2193        return type(value)(numpy_to_tensor(obj) for obj in value)
2194    else:
2195        return value
2196
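
# Editor's sketch (hypothetical, not part of the original source): round-tripping a
# nested container through the two conversion helpers above. Never called here.
def _example_numpy_conversion_helpers():
    if np is None:
        return
    as_numpy = to_numpy_helper([torch.ones(2), 3])
    assert isinstance(as_numpy[0], np.ndarray) and as_numpy[1] == 3
    back = numpy_to_tensor(as_numpy)
    assert isinstance(back[0], torch.Tensor) and back[1] == 3
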
2197
2198class numpy_to_tensor_wrapper:
2199    def __init__(self, f):
2200        self.f = f
2201        self.__name__ = "wrapped_" + self.f.__name__
2202
2203    def __repr__(self):
2204        return f"<Wrapped function <original {self.f.__name__}>>"
2205
2206    def __call__(self, *args, **kwargs):
2207        out = self.f(*args, **kwargs)
2208        return numpy_to_tensor(out)
2209
2210
2211def numpy_attr_wrapper(obj, name):
2212    if isinstance(obj, tnp.ndarray):
2213        out = getattr(obj, name)
2214        return numpy_to_tensor(out)
2215    elif isinstance(obj, torch.Tensor):
2216        out = getattr(tnp.ndarray(obj), name)
2217        return numpy_to_tensor(out)
2218
2219
2220class numpy_method_wrapper:
2221    """Convert obj from torch.Tensor to tnp.ndarray and call method. Then convert result back to torch.Tensor."""
2222
2223    def __init__(self, method: str):
2224        self.method = method
2225        self.__name__ = "wrapped_" + self.method
2226
2227    def __repr__(self):
2228        return f"<Wrapped method <original {self.method}>>"
2229
2230    def __call__(self, *args, **kwargs):
2231        obj = args[0]
2232        if isinstance(obj, torch.Tensor):
2233            obj = tnp.ndarray(obj)
2234        method_callable = getattr(obj, self.method)
2235        out = method_callable(*args[1:], **kwargs)
2236        return numpy_to_tensor(out)
2237
2238
2239class numpy_operator_wrapper:
2240    """Implements dunder methods for tnp.ndarray via functions from the operator library"""
2241
2242    def __init__(self, op: Callable[..., Any]):
2243        self.op = op
2244        self.__name__ = f"wrapped_{op.__name__}"
2245
2246    def __repr__(self):
2247        return f"<Wrapped operator <original {self.__name__}>>"
2248
2249    def __call__(self, *args, **kwargs):
2250        assert not kwargs
2251
2252        args = (
2253            tnp.ndarray(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
2254        )
2255        out = self.op(*args)
2256        return numpy_to_tensor(out)
2257
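
# Editor's sketch (hypothetical, not part of the original source): the wrappers above
# route Tensor arguments through torch._numpy and convert results back to
# torch.Tensor. Never called here.
def _example_numpy_method_and_operator_wrappers():
    reshape = numpy_method_wrapper("reshape")
    out = reshape(torch.arange(6), (2, 3))
    assert isinstance(out, torch.Tensor) and out.shape == (2, 3)

    add = numpy_operator_wrapper(operator.add)
    assert isinstance(add(torch.ones(2), torch.ones(2)), torch.Tensor)
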
2258
2259def defake(x):
2260    if not isinstance(x, FakeTensor):
2261        return x
2262    size: torch._prims_common.ShapeType
2263    stride: torch._prims_common.StrideType
2264    if x._has_symbolic_sizes_strides:
2265        size = []
2266        for s in x.size():
2267            if isinstance(s, torch.SymInt):
2268                size.append(s.node.shape_env.size_hint(s.node.expr))
2269            else:
2270                size.append(s)
2271        stride = []
2272        for s in x.stride():
2273            if isinstance(s, torch.SymInt):
2274                stride.append(s.node.shape_env.size_hint(s.node.expr))
2275            else:
2276                stride.append(s)
2277    else:
2278        size = x.size()
2279        stride = x.stride()
2280    y = torch.empty_strided(
2281        size,
2282        stride,
2283        dtype=x.dtype,
2284        device=x.device,
2285        requires_grad=x.requires_grad,
2286    )
2287    y.zero_()
2288    return y
2289
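
# Editor's sketch (hypothetical, not part of the original source): defake materializes
# a real, zero-filled tensor with the same metadata as a FakeTensor. Never called here.
def _example_defake():
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode() as mode:
        fake = mode.from_tensor(torch.randn(3, 4))
    real = defake(fake)
    assert not is_fake(real) and real.shape == (3, 4) and real.sum() == 0
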
2290
2291def is_utils_checkpoint(obj):
2292    # Lazy import to avoid circular dependencies
2293    import torch.utils.checkpoint
2294
2295    return obj is torch.utils.checkpoint.checkpoint
2296
2297
2298def build_checkpoint_variable(**options):
2299    import torch._higher_order_ops.wrap as higher_order_ops
2300    from .variables.higher_order_ops import TorchHigherOrderOperatorVariable
2301
2302    # TODO - This is a temporary situation where we have two versions of
2303    # checkpointing implementation. We will converge on one and remove the other.
2304    activation_checkpoint_op: torch._ops.HigherOrderOperator = (
2305        higher_order_ops.tag_activation_checkpoint
2306    )
2307    if torch._functorch.config.functionalize_rng_ops:
2308        activation_checkpoint_op = higher_order_ops.wrap_activation_checkpoint
2309
2310    return TorchHigherOrderOperatorVariable.make(
2311        activation_checkpoint_op,
2312        **options,
2313    )
2314
2315
2316def is_compile_supported(device_type):
2317    from .eval_frame import is_dynamo_supported
2318
2319    compile_supported = is_dynamo_supported()
2320    if device_type == "cpu":
2321        pass
2322    elif device_type == "cuda" and compile_supported:
2323        compile_supported = has_triton()
2324    else:
2325        compile_supported = False
2326    return compile_supported
2327
2328
2329# The following 3.11 source code functions are adapted from
2330# https://github.com/python/cpython/blob/v3.11.4/Lib/traceback.py
2331# in order to output source code corresponding to bytecode in 3.11+.
2332# We need our own versions since we want to support multiline expressions.
2333def _fix_offset(str: str, offset: int) -> int:
2334    """
2335    Convert byte offset `offset` of `str` into character offset.
2336    Byte offset is used for 3.11+ instruction column data.
2337    Takes things like unicode characters into consideration.
2338
2339    Unchanged from CPython implementation.
2340    """
2341    as_utf8 = str.encode("utf-8")
2342    return len(as_utf8[:offset].decode("utf-8", errors="replace"))
2343
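
# Editor's sketch (hypothetical, not part of the original source): byte offsets from
# 3.11 position data are translated to character offsets, accounting for multi-byte
# characters. Never called here.
def _example_fix_offset():
    # "é" is two bytes in UTF-8, so byte offset 3 corresponds to character offset 2.
    assert _fix_offset("héllo", 3) == 2
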
2344
2345@dataclasses.dataclass
2346class _Anchors:
2347    # inclusive
2348    left_end_lineno: int
2349    left_end_offset: int
2350    right_start_lineno: int
2351    # exclusive
2352    right_start_offset: int
2353
2354
2355def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]:
2356    """
2357    Given source code `segment` corresponding to a bytecode
2358    instruction, determine:
2359        - for binary ops, the location of the binary op
2360        - for indexing, the location of the brackets.
2361    `segment` is expected to be a valid Python expression
2362    """
2363    assert sys.version_info >= (3, 11)
2364
2365    import ast
2366
2367    try:
2368        # Without brackets, `segment` is parsed as a statement.
2369        # We expect an expression, so wrap `segment` in
2370        # brackets to handle multi-line expressions.
2371        tree = ast.parse("(\n" + segment + "\n)")
2372    except SyntaxError:
2373        return None
2374
2375    if len(tree.body) != 1:
2376        return None
2377
2378    lines = segment.split("\n")
2379
2380    # get character index given byte offset
2381    def normalize(lineno, offset):
2382        return _fix_offset(lines[lineno], offset)
2383
2384    # Gets the next valid character index in `lines`, if
2385    # the current location is not valid. Handles empty lines.
2386    def next_valid_char(lineno, col):
2387        while lineno < len(lines) and col >= len(lines[lineno]):
2388            col = 0
2389            lineno += 1
2390        assert lineno < len(lines) and col < len(lines[lineno])
2391        return lineno, col
2392
2393    # Get the next valid character index in `lines`.
2394    def increment(lineno, col):
2395        col += 1
2396        lineno, col = next_valid_char(lineno, col)
2397        assert lineno < len(lines) and col < len(lines[lineno])
2398        return lineno, col
2399
2400    # Get the next valid character at least on the next line
2401    def nextline(lineno, col):
2402        col = 0
2403        lineno += 1
2404        lineno, col = next_valid_char(lineno, col)
2405        assert lineno < len(lines) and col < len(lines[lineno])
2406        return lineno, col
2407
2408    statement = tree.body[0]
2409    if isinstance(statement, ast.Expr):
2410        expr = statement.value
2411        if isinstance(expr, ast.BinOp):
2412            # ast gives locations for BinOp subexpressions, e.g.
2413            # ( left_expr ) + ( right_expr )
2414            #   left^^^^^       right^^^^^
2415            # -2 since end_lineno is 1-indexed and because we added an extra
2416            # bracket to `segment` when calling ast.parse
2417            cur_lineno = cast(int, expr.left.end_lineno) - 2
2418            cur_col = normalize(cur_lineno, expr.left.end_col_offset)
2419            cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col)
2420
2421            # Heuristic to find the operator character.
2422            # The original CPython implementation did not look for ), \, or #,
2423            # leading to incorrect anchor location, e.g.
2424            # (x) + (y)
2425            # ~~^~~~~~~
2426            while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#":
2427                if ch in "\\#":
2428                    cur_lineno, cur_col = nextline(cur_lineno, cur_col)
2429                else:
2430                    cur_lineno, cur_col = increment(cur_lineno, cur_col)
2431
2432            # binary op is 1 or 2 characters long, on the same line
2433            right_col = cur_col + 1
2434            if (
2435                right_col < len(lines[cur_lineno])
2436                and not (ch := lines[cur_lineno][right_col]).isspace()
2437                and ch not in "\\#"
2438            ):
2439                right_col += 1
2440            # right_col can be invalid since it is exclusive
2441
2442            return _Anchors(cur_lineno, cur_col, cur_lineno, right_col)
2443        elif isinstance(expr, ast.Subscript):
2444            # ast gives locations for value and slice subexpressions, e.g.
2445            # ( value_expr ) [ slice_expr ]
2446            #   value^^^^^     slice^^^^^
2447            # subscript^^^^^^^^^^^^^^^^^^^^
2448            # find left bracket (first '[' after value)
2449            left_lineno = cast(int, expr.value.end_lineno) - 2
2450            left_col = normalize(left_lineno, expr.value.end_col_offset)
2451            left_lineno, left_col = next_valid_char(left_lineno, left_col)
2452            while lines[left_lineno][left_col] != "[":
2453                left_lineno, left_col = increment(left_lineno, left_col)
2454            # find right bracket (final character of expression)
2455            right_lineno = cast(int, expr.end_lineno) - 2
2456            right_col = normalize(right_lineno, expr.end_col_offset)
2457            return _Anchors(left_lineno, left_col, right_lineno, right_col)
2458        elif isinstance(expr, ast.Call):
2459            # ( func_expr ) (args, kwargs)
2460            #   func^^^^^
2461            # call^^^^^^^^^^^^^^^^^^^^^^^^
2462            # find left bracket (first '(' after func)
2463            left_lineno = cast(int, expr.func.end_lineno) - 2
2464            left_col = normalize(left_lineno, expr.func.end_col_offset)
2465            left_lineno, left_col = next_valid_char(left_lineno, left_col)
2466            while lines[left_lineno][left_col] != "(":
2467                left_lineno, left_col = increment(left_lineno, left_col)
2468            # find right bracket (final character of expression)
2469            right_lineno = cast(int, expr.end_lineno) - 2
2470            right_col = normalize(right_lineno, expr.end_col_offset)
2471            return _Anchors(left_lineno, left_col, right_lineno, right_col)
2472
2473    return None
2474
2475
2476def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> str:
2477    """
2478    Python 3.11+ only. Returns lines of source code (from code object `code`)
2479    corresponding to `inst`'s location data, and underlines relevant code to `inst`.
2480
2481    Example: CALL on `g`:
2482    f(g(
2483      ^^
2484        h(x)))
2485        ^^^^^
2486
2487    We need our own implementation since `format_frame_summary` in
2488    Python's `traceback` module doesn't handle multi-line expressions
2489    (and their anchor extraction code is not completely correct).
2490    """
2491    assert inst.positions is not None
2492    if inst.positions.lineno is None:
2493        return ""
2494    # The rstrip + "\n" pattern is used throughout this function to handle
2495    # linecache.getline errors. Error lines are treated as empty strings "", but we want
2496    # to treat them as blank lines "\n".
2497    first_line = linecache.getline(code.co_filename, inst.positions.lineno).rstrip()
2498    if inst.positions.end_lineno is None:
2499        return first_line
2500    if inst.positions.col_offset is None or inst.positions.end_col_offset is None:
2501        return first_line
2502
2503    # character index of the start of the instruction
2504    start_offset = _fix_offset(first_line, inst.positions.col_offset)
2505    # character index of the end of the instruction
2506    # compute later since end may be a different line
2507    end_offset = None
2508    # expression corresponding to the instruction so we can get anchors
2509    segment = ""
2510    # underline markers to be printed - start with `~` marker and replace with `^` later
2511    markers = []
2512
2513    # Compute segment and initial markers
2514    if inst.positions.end_lineno == inst.positions.lineno:
2515        end_offset = _fix_offset(first_line, inst.positions.end_col_offset)
2516        segment = first_line[start_offset:end_offset]
2517        markers.append(" " * start_offset + "~" * (end_offset - start_offset))
2518    else:
2519        segment = first_line[start_offset:] + "\n"
2520        markers.append(" " * start_offset + "~" * (len(first_line) - start_offset))
2521        last_line = linecache.getline(
2522            code.co_filename, inst.positions.end_lineno
2523        ).rstrip()
2524        end_offset = _fix_offset(last_line, inst.positions.end_col_offset)
2525        for lineno in range(inst.positions.lineno + 1, inst.positions.end_lineno):
2526            line = linecache.getline(code.co_filename, lineno).rstrip()
2527            segment += line + "\n"
2528            # don't underline leading spaces
2529            num_spaces = len(line) - len(line.lstrip())
2530            markers.append(" " * num_spaces + "~" * (len(line) - num_spaces))
2531        segment += last_line[:end_offset]
2532        num_spaces = len(last_line) - len(last_line.lstrip())
2533        markers.append(" " * num_spaces + "~" * (end_offset - num_spaces))
2534
2535    anchors: Optional[_Anchors] = None
2536    try:
2537        anchors = _extract_anchors_from_expr(segment)
2538    except AssertionError:
2539        pass
2540
2541    # replace `~` markers with `^` where necessary
2542    if anchors is None:
2543        markers = [marker.replace("~", "^") for marker in markers]
2544    else:
2545        # make markers mutable
2546        mutable_markers: List[List[str]] = [list(marker) for marker in markers]
2547
2548        # anchor positions do not take start_offset into account
2549        if anchors.left_end_lineno == 0:
2550            anchors.left_end_offset += start_offset
2551        if anchors.right_start_lineno == 0:
2552            anchors.right_start_offset += start_offset
2553
2554        # Turn `~` markers between anchors into `^`
2555        for lineno in range(len(markers)):
2556            for col in range(len(mutable_markers[lineno])):
2557                if lineno < anchors.left_end_lineno:
2558                    continue
2559                if lineno == anchors.left_end_lineno and col < anchors.left_end_offset:
2560                    continue
2561                if (
2562                    lineno == anchors.right_start_lineno
2563                    and col >= anchors.right_start_offset
2564                ):
2565                    continue
2566                if lineno > anchors.right_start_lineno:
2567                    continue
2568                if mutable_markers[lineno][col] == "~":
2569                    mutable_markers[lineno][col] = "^"
2570
2571        # make markers into strings again
2572        markers = ["".join(marker) for marker in mutable_markers]
2573
2574    result = ""
2575    for i in range(len(markers)):
2576        result += (
2577            linecache.getline(code.co_filename, inst.positions.lineno + i).rstrip()
2578            + "\n"
2579        )
2580        result += markers[i] + "\n"
2581    return result
2582
2583
2584def get_static_address_type(t):
2585    if isinstance(t, torch.Tensor):
2586        return getattr(t, "_dynamo_static_input_type", None)
2587
2588    return None
2589
2590
2591def is_rng_state_getter_or_setter(value):
2592    getters = (
2593        # The following two functions are not identical, so don't remove either one!
2594        torch._C.Generator.get_state,
2595        torch.default_generator.get_state,
2596        torch.get_rng_state,
2597        torch.cuda.get_rng_state,
2598    )
2599    setters = (
2600        torch._C.Generator.set_state,
2601        torch.default_generator.set_state,
2602        torch.set_rng_state,
2603        torch.cuda.set_rng_state,
2604    )
2605    return value in (*setters, *getters)
2606
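
# Editor's sketch (hypothetical, not part of the original source): membership check
# against the known RNG state getters/setters. Never called here.
def _example_is_rng_state_getter_or_setter():
    assert is_rng_state_getter_or_setter(torch.get_rng_state)
    assert is_rng_state_getter_or_setter(torch._C.Generator.set_state)
    assert not is_rng_state_getter_or_setter(torch.randn)
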
2607
2608def is_tensor_base_attr_getter(value):
2609    return (
2610        isinstance(value, types.MethodWrapperType)
2611        and value.__name__ == "__get__"
2612        and value.__self__.__objclass__ is torch._C._TensorBase  # type: ignore[attr-defined]
2613    )
2614
2615
2616def is_torch_function_object(value):
2617    return hasattr(value, "__torch_function__")
2618
2619
2620def has_torch_function(vt: "torch._dynamo.variables.base.VariableTracker") -> bool:
2621    from torch._dynamo.variables import LazyVariableTracker, UserDefinedObjectVariable
2622    from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable
2623
2624    if isinstance(vt, TensorWithTFOverrideVariable):
2625        return True
2626
2627    if isinstance(vt, LazyVariableTracker):
2628        LazyVariableTracker.realize(vt)
2629
2630    return isinstance(vt, UserDefinedObjectVariable) and hasattr(
2631        vt.value, "__torch_function__"
2632    )
2633
2634
2635# see note [Tensor Fakification and Symbol Caching]
2636def to_fake_tensor(t, fake_mode):
2637    symbolic_context = None
2638    source = None
2639    if tracing_context := torch._guards.TracingContext.try_get():
2640        if t in tracing_context.tensor_to_context:
2641            symbolic_context = tracing_context.tensor_to_context[t]
2642            source = symbolic_context.tensor_source
2643
2644    return fake_mode.from_tensor(
2645        t, static_shapes=False, symbolic_context=symbolic_context, source=source
2646    )
2647
2648
2649def get_first_attr(obj, *attrs):
2650    """
2651    Return the first available attribute or throw an exception if none is present.
2652    """
2653    for attr in attrs:
2654        if hasattr(obj, attr):
2655            return getattr(obj, attr)
2656
2657    raise AssertionError(f"{obj} does not have any of the attributes: {attrs}")
2658
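
# Editor's sketch (hypothetical, not part of the original source): fall through
# candidate attribute names (e.g. across library versions) and return the first one
# present. Never called here.
def _example_get_first_attr():
    class Cfg:
        new_name = 1

    assert get_first_attr(Cfg(), "old_name", "new_name") == 1
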
2659
2660@contextlib.contextmanager
2661def maybe_enable_compiled_autograd(should_enable):
2662    def compiler_fn(gm):
2663        def inner_compiler(gm_, example_inputs_):
2664            torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1
2665            return torch._inductor.compile(gm_, example_inputs_)
2666
2667        return torch.compile(gm, backend=inner_compiler, fullgraph=True, dynamic=True)
2668
2669    if should_enable:
2670        with torch._dynamo.compiled_autograd.enable(compiler_fn) as ctx:
2671            yield ctx
2672    else:
2673        yield
2674
2675
2676def invalid_removeable_handle():
2677    # need a subclass so weakref works
2678    class Invalid(dict):  # type: ignore[type-arg]
2679        pass
2680
2681    return RemovableHandle(Invalid())
2682
2683
2684# Returns a "proxy" (new object with the same class and dict) for (non-GraphModule) nn.Module's.
2685# Attribute changes to the original object/proxy will be reflected in the other.
2686# This is useful for cases where we want a keep-alive reference to a module without increasing
2687# its reference count.
2688def nn_module_proxy(mod):
2689    if not isinstance(mod, torch.nn.Module):
2690        return mod
2691    if isinstance(mod, torch.fx.GraphModule):
2692        # Dynamo-generated GM's shouldn't contain user-created GM's
2693        return mod
2694    proxy = mod.__class__.__new__(mod.__class__)
2695    proxy.__dict__ = mod.__dict__
2696    return proxy
2697
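
# Editor's sketch (hypothetical, not part of the original source): the proxy shares
# __dict__ with the original module, so attribute changes are visible on both objects.
# Never called here.
def _example_nn_module_proxy():
    mod = torch.nn.Linear(2, 2)
    proxy = nn_module_proxy(mod)
    proxy.some_flag = True  # hypothetical attribute, for illustration only
    assert mod.some_flag is True and type(proxy) is torch.nn.Linear
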
2698
2699class GmWrapper(torch.nn.Module):
2700    def __init__(self, gm, spec):
2701        super().__init__()
2702        self.gm = gm
2703        self.spec = spec
2704
2705    def forward(self, *args):
2706        args: List[Any] = list(args)
2707        return self.gm(*pytree.tree_unflatten(args, self.spec))
2708
2709
2710def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
2711    """
2712    Mutate inputs so that they are flat and wrap gm such that it
2713    accepts those inputs.  This is needed for graphs that take
2714    nested ("bumpy") inputs.
2715    """
2716    inputs, spec = pytree.tree_flatten(inputs)
2717    compiled_fn = compile_gm(GmWrapper(gm, spec), inputs)
2718
2719    idx_to_steal = [
2720        i
2721        for i, node in enumerate(gm.graph.nodes)
2722        if node.op == "placeholder" and node.meta.get("steal_arg", False)
2723    ]
2724
2725    def wrapper(*args):
2726        # note this doesn't check the spec, assuming it is the same
2727        flat_args = pytree.arg_tree_leaves(*args)
2728
2729        # flat_args is a new list, so we need to clear references from the old list
2730        for i in idx_to_steal:
2731            args[i].clear()
2732
2733        # this call is boxed to avoid increasing refcount until we reach aot_module_simplified forward
2734        return compiled_fn(flat_args)
2735
2736    return wrapper
2737
2738
2739def get_locals_to_steal(maybe_gm):
2740    if not isinstance(maybe_gm, torch.fx.GraphModule) or not hasattr(maybe_gm, "meta"):
2741        return []
2742    return maybe_gm.meta.get("locals_to_steal", [])
2743
2744
2745def set_locals_to_steal(gm, locals_to_steal):
2746    gm.meta["locals_to_steal"] = locals_to_steal
2747
2748
2749class Lit:
2750    def __init__(self, s):
2751        self.s = s
2752
2753    def __repr__(self):
2754        return self.s
2755
2756
2757warn_once_cache: Set[str] = set()
2758
2759
2760def warn_once(msg, stacklevel=1):
2761    # Dynamo causes all warnings.warn (in user code and in Dynamo code) to print all the time.
2762    # https://github.com/pytorch/pytorch/issues/128427.
2763    # warn_once is a workaround: if the msg has been warned on before, then we will not
2764    # warn again.
2765    # NB: it's totally ok to store a cache of all the strings: this is what warnings.warn does as well.
2766    if msg in warn_once_cache:
2767        return
2768    warn_once_cache.add(msg)
2769    warnings.warn(msg, stacklevel=stacklevel + 1)
2770
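
# Editor's sketch (hypothetical, not part of the original source): repeated calls with
# the same message only emit a single warning. Never called here.
def _example_warn_once():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        warn_once("hypothetical demo message")
        warn_once("hypothetical demo message")
    assert len(caught) == 1
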