# mypy: allow-untyped-defs
import atexit
import collections
import contextlib
import copy
import dataclasses
import datetime
import dis
import enum
import functools
import gc
import inspect
import itertools
import linecache
import logging
import math
import operator
import os
import re
import sys
import textwrap
import threading
import time
import types
import typing
import warnings
import weakref
from contextlib import contextmanager
from functools import lru_cache, wraps
from types import MethodWrapperType
from typing import (
    Any,
    Callable,
    cast,
    ClassVar,
    Counter,
    DefaultDict,
    Deque,
    Dict,
    Iterator,
    KeysView,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
    ValuesView,
)

from ..utils.hooks import RemovableHandle

try:
    import numpy as np
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]

try:
    import torch._logging
    import torch._numpy as tnp
    from torch._guards import detect_fake_mode  # noqa: F401
    from torch._logging import LazyString
    from . import config

    # NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync.
    if np:
        NP_SUPPORTED_MODULES: Tuple[types.ModuleType, ...] = (
            np,
            np.fft,
            np.linalg,
            np.random,
        )

        NP_TO_TNP_MODULE = {
            np: tnp,
            np.fft: tnp.fft,
            np.linalg: tnp.linalg,
            np.random: tnp.random,
        }
    else:
        NP_SUPPORTED_MODULES = tuple()

        NP_TO_TNP_MODULE = {}
    from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
except ImportError:
    pass

import importlib

import torch
import torch._functorch.config
import torch.fx.experimental.symbolic_shapes
import torch.utils._pytree as pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import log_compilation_event

from torch.fx._utils import _format_graph_code, lazy_format_graph_code
from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._triton import has_triton, has_triton_package


counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
optimus_scuba_log: Dict[str, Any] = {}
troubleshooting_url = (
    "https://pytorch.org/docs/main/torch.compiler_troubleshooting.html"
)
nnmodule_doc_url = "https://pytorch.org/docs/main/torch.compiler_nn_module.html"
nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
log = logging.getLogger(__name__)

# profiling compilation time by function
compilation_time_metrics: Dict[str, List[float]] = {}

# profiling compilation time by frame phase
frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict(
    lambda: collections.defaultdict(float)
)

timer_counter = itertools.count()


def tabulate(rows, headers):
    try:
        import tabulate

        return tabulate.tabulate(rows, headers=headers)
    except ImportError:
        return "\n".join(
            ", ".join(map(str, row)) for row in itertools.chain([headers], rows)
        )


curr_frame = 0
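
# A minimal sketch of how the stores above end up shaped (commented out,
# illustrative only; the specific key strings here are assumptions, not
# guaranteed values):
#
#   counters["graph_break"]["data dependent operator"] += 1
#   compilation_time_metrics["convert_frame.<locals>._compile"] = [0.12, 0.08]
#   frame_phase_timing["0"]["entire_frame_compile"] += 0.12
#
#   # tabulate() degrades to comma-separated rows when the optional `tabulate`
#   # package is not installed:
#   print(tabulate([["_compile", "0.12, 0.08"]], headers=("Function", "Runtimes (s)")))
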
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
def increment_frame():
    global curr_frame
    curr_frame = curr_frame + 1


# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
def reset_frame_count():
    global curr_frame
    frame_phase_timing.clear()
    compilation_time_metrics.clear()
    curr_frame = 0


op_count = 0


def increment_op_count(cnt):
    global op_count
    op_count += cnt


# Calculate total time spent so far for each phase
# For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806}
def calculate_time_spent():
    total = 0.0
    total_by_key = {}
    for timings in frame_phase_timing.values():
        for key, timing in timings.items():
            total += timing
            if key not in total_by_key:
                total_by_key[key] = timing
            else:
                total_by_key[key] += timing

    return total_by_key


# Print a report of time spent so far
# Ex:
# TIMING:
# entire_frame_compile:8.574629999999999
# backend_compile:5.26806
def print_time_report():
    total_by_key = calculate_time_spent()

    out = "TIMING:"
    for key, value in total_by_key.items():
        out = f"{out} {key}:{round(value, 5)}"

    print(out)


def _add_time_spent(key, phase_name, time_spent):
    frame_phase_timing[key][phase_name] += time_spent


# The dynamo_timed API works as a function decorator.
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
# where the key is the function's name.
# For example:
#
#   @dynamo_timed
#   def _foo(...):
#
# would show up as an entry in our timing dict:
#   OrderedDict([('bar.<locals>._foo', [0.083690, 0.23949, 3.1425e-05])])
# This is extremely useful for granular debugging.
#
# For a higher-level mode, pass a phase_name into dynamo_timed.
# phase_names add an extra record to a separate compilation timing structure,
# one keyed on frame+name rather than function.
# The frame is incremented outside of this function, in increment_frame() above.
# `fwd_only` identifies whether this phase or function is only called while
# compiling fwd graphs, e.g. `entire_frame_compile` and `backend_compile`.
# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs.


def dynamo_timed(original_function=None, phase_name=None, fwd_only=True):
    def dynamo_timed_inner(func):
        @wraps(func)
        def time_wrapper(*args, **kwargs):
            key = func.__qualname__
            if key not in compilation_time_metrics:
                compilation_time_metrics[key] = []

            fail_type: Optional[str] = None
            fail_reason: Optional[str] = None
            time_spent = float("-inf")
            try:
                with torch.profiler.record_function(f"{key} (dynamo_timed)"):
                    t0 = time.time()
                    r = func(*args, **kwargs)
                    time_spent = time.time() - t0
                compilation_time_metrics[key].append(time_spent)
            except Exception as e:
                fail_type = str(type(e))
                fail_reason = str(e)
                raise
            finally:
                # Only record backward compilation metrics if phase_name is not None!
                if phase_name:
                    frame_key = str(curr_frame)
                    # fwd only compilation stages: entire_frame_compile, backend_compile.
                    # use frame_key as time aggregation key.
                    if fwd_only and fail_type is None:
                        _add_time_spent(frame_key, phase_name, time_spent)
                    else:
                        # fwd + bwd compilation stages: inductor_compile, code_gen.
                        # use frame_key as time aggregation key for fwd graphs;
                        # use compile_id as time aggregation key for bwd graphs.
                        if torch._guards.TracingContext.try_get() is not None:
                            aot_graph_name = str(
                                torch._guards.TracingContext.get().aot_graph_name
                            )
                            if (
                                "forward" in aot_graph_name
                                or "inference" in aot_graph_name
                            ) and fail_type is None:
                                _add_time_spent(frame_key, phase_name, time_spent)
                            elif "backward" in aot_graph_name:
                                compile_id = str(
                                    torch._guards.CompileContext.current_compile_id()
                                )
                                if fail_type is None:
                                    _add_time_spent(compile_id, phase_name, time_spent)

                                # log backward compilation metrics at the end of
                                # `inductor_compile` of a bwd graph, one record per bwd graph.
                                if phase_name == "inductor_compile":
                                    if fail_type is None:
                                        inductor_compile_time = frame_phase_timing[
                                            compile_id
                                        ].get("inductor_compile", None)
                                        code_gen_time = frame_phase_timing[
                                            compile_id
                                        ].get("code_gen", None)
                                    else:
                                        inductor_compile_time = None
                                        code_gen_time = None
                                    metrics = BwdCompilationMetrics(
                                        compile_id,
                                        inductor_compile_time,
                                        code_gen_time,
                                        fail_type,
                                        fail_reason,
                                    )
                                    record_compilation_metrics(metrics)

            return r

        return time_wrapper

    if original_function:
        return dynamo_timed_inner(original_function)
    return dynamo_timed_inner
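
# Complementing the comment above, a usage sketch that also reads the results
# back (commented out, illustrative only; the decorated functions are
# hypothetical stand-ins, while "backend_compile" is one of the phase names
# described above):
#
#   @dynamo_timed
#   def _expensive_step(gm):
#       ...
#
#   @dynamo_timed(phase_name="backend_compile")
#   def _call_user_compiler(gm):
#       ...
#
#   # per-function wall times accumulate in compilation_time_metrics, e.g.
#   #   {"<module>.<locals>._expensive_step": [0.021, 0.019]}
#   # and phase_name'd calls also roll up per frame:
#   #   calculate_time_spent() -> {"backend_compile": 0.040}
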

def compile_times(repr="str", aggregate=False):
    """
    Get metrics about torchdynamo frontend/backend compilation times.

    Accumulates information from functions tagged with `@dynamo_timed`.

    repr='str' returns a printable string for user interaction, and 'csv'
    returns headers and rows that can be logged for output.

    aggregate causes values from multiple compilations (e.g. split graphs)
    to be accumulated into one value. If false, expect more than one value
    per metric.
309 """ 310 311 def fmt_fn(values, item_fn=lambda x: x): 312 if aggregate: 313 return item_fn(sum(values)) 314 return ", ".join(map(item_fn, values)) 315 316 if repr == "str": 317 rows = [ 318 (k, fmt_fn(compilation_time_metrics[k], item_fn=lambda x: f"{x:.4f}")) 319 for k in compilation_time_metrics 320 ] 321 out = "TorchDynamo compilation metrics:\n" 322 out += tabulate(rows, headers=("Function", "Runtimes (s)")) 323 return out 324 elif repr == "csv": 325 values = [ 326 fmt_fn(v, item_fn=lambda x: f"{x:.6f}") 327 for v in compilation_time_metrics.values() 328 ] 329 headers = list(compilation_time_metrics.keys()) 330 return headers, values 331 332 333@atexit.register 334def dump_compile_times(): 335 log.info(compile_times(repr="str", aggregate=True)) 336 337 338tensortype_to_dtype = { 339 torch.FloatTensor: (torch.float32, torch.float), 340 torch.DoubleTensor: (torch.float64, torch.double), 341 torch.HalfTensor: (torch.float16, torch.half), 342 torch.BFloat16Tensor: (torch.bfloat16,), 343 torch.ByteTensor: (torch.uint8,), 344 torch.CharTensor: (torch.int8,), 345 torch.LongTensor: (torch.int64, torch.long), 346 torch.IntTensor: (torch.int32, torch.int), 347 torch.ShortTensor: (torch.int16, torch.short), 348 torch.BoolTensor: (torch.bool,), 349} 350 351 352class DuplicateWarningChecker: 353 def __init__(self, maxsize=4096): 354 self.maxsize = maxsize 355 self.reset() 356 357 def reset(self): 358 self.set = collections.OrderedDict() 359 360 def add(self, key): 361 if key in self.set: 362 self.set.move_to_end(key, last=True) 363 if not config.verbose: 364 return False 365 else: 366 self.set[key] = None 367 while len(self.set) > self.maxsize: 368 self.set.popitem(last=False) 369 return True 370 371 372graph_break_dup_warning_checker = DuplicateWarningChecker() 373 374 375def setup_compile_debug(): 376 compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" 377 378 if compile_debug: 379 return add_file_handler() 380 381 return contextlib.ExitStack() 382 383 384def reset_graph_break_dup_checker(): 385 graph_break_dup_warning_checker.reset() 386 387 388def add_file_handler(): 389 log_path = os.path.join(get_debug_dir(), "torchdynamo") 390 os.makedirs(log_path, exist_ok=True) 391 392 log_file_handler = logging.FileHandler(os.path.join(log_path, "debug.log")) 393 logger = logging.getLogger("torch._dynamo") 394 logger.addHandler(log_file_handler) 395 396 exitstack = contextlib.ExitStack() 397 exitstack.callback(lambda: logger.removeHandler(log_file_handler)) 398 return exitstack 399 400 401def setup_log_file(): 402 exitstack = contextlib.ExitStack() 403 if config.log_file_name is not None: 404 log_file_handler = logging.FileHandler(config.log_file_name) 405 for logger in torch._logging._internal.get_loggers(): 406 logger.addHandler(log_file_handler) 407 exitstack.callback(lambda: logger.removeHandler(log_file_handler)) 408 return exitstack 409 410 return exitstack 411 412 413def gen_record_file_name(exc, code): 414 return f"{get_debug_dir()}/error_recordings/\ 415{code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec" 416 417 418def write_record_to_file(filename, exec_record): 419 try: 420 if os.path.exists(filename): 421 log.warning( 422 "Unable to write execution record %s; file already exists.", filename 423 ) 424 else: 425 os.makedirs(os.path.dirname(filename), exist_ok=True) 426 with open(filename, "wb") as f: 427 exec_record.dump(f) 428 except Exception: 429 log.exception("Unable to write execution record %s", filename) 430 431 432def count_calls(g: fx.Graph): 433 c = 0 434 
for n in g.nodes: 435 if "call" in n.op: 436 c += 1 437 return c 438 439 440def identity(x): 441 return x 442 443 444def hashable(x): 445 try: 446 hash(x) 447 return True 448 except TypeError: 449 return False 450 # cannot hash writable memoryview object 451 except ValueError: 452 return False 453 454 455def nothing(*args, **kwargs): 456 pass 457 458 459class ExactWeakKeyDictionary: 460 """Similar to weakref.WeakKeyDictionary, but use `is`/`id` rather than `==` to compare equality""" 461 462 def __init__(self): 463 self.values = dict() 464 self.refs = dict() 465 466 def __getitem__(self, key): 467 return self.values[id(key)] 468 469 def get(self, key, default=None): 470 return self.values.get(id(key), default) 471 472 def __contains__(self, key): 473 return id(key) in self.values 474 475 def __setitem__(self, key, value): 476 idx = id(key) 477 if idx not in self.refs: 478 self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx)) 479 self.values[idx] = value 480 481 def _remove_id(self, idx): 482 if idx in self.values: 483 del self.values[idx] 484 if idx in self.refs: 485 del self.refs[idx] 486 487 def clear(self): 488 self.refs.clear() 489 self.values.clear() 490 491 492def istype(obj, allowed_types): 493 """isinstance() without subclasses""" 494 if isinstance(allowed_types, (tuple, list, set)): 495 return type(obj) in allowed_types 496 return type(obj) is allowed_types 497 498 499if sys.version_info >= (3, 12): 500 # Some typing classes moved to C in 3.12, 501 # which no longer have the _Final mixin. 502 _builtin_final_typing_classes = ( 503 typing.ParamSpecArgs, 504 typing.ParamSpecKwargs, 505 typing.ParamSpec, 506 typing.TypeVar, 507 typing.TypeVarTuple, 508 typing.TypeAliasType, 509 ) 510 511 512def is_typing(value): 513 # _Final catches most of typing classes: 514 # - Any 515 # - Callable 516 # - Union 517 # ... 518 # 519 # NB: we intentionally ignore classes that inherit from Generic, since they 520 # can be used as both TypingVariable as well as UserDefinedClassVariable. 
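# A commented sketch of the identity-keyed behavior of ExactWeakKeyDictionary
# above (illustrative only):
#
#   class Key:
#       def __eq__(self, other):   # equality is deliberately ignored below
#           return True
#
#   a, b = Key(), Key()
#   d = ExactWeakKeyDictionary()
#   d[a] = 1
#   assert a in d and b not in d   # membership is by id(), not by __eq__/hash
#   del a                          # once the key dies, its entry is dropped too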
521 if sys.version_info >= (3, 12) and isinstance(value, _builtin_final_typing_classes): 522 return True 523 return isinstance(value, typing._Final) or value is typing.Generic # type: ignore[attr-defined] 524 525 526def is_numpy_int_type(value): 527 if not np: 528 return False 529 530 return istype( 531 value, 532 ( 533 np.int8, 534 np.int16, 535 np.int32, 536 np.int64, 537 np.uint8, 538 np.uint16, 539 np.uint32, 540 np.uint64, 541 ), 542 ) 543 544 545def is_numpy_float_type(value): 546 if not np: 547 return False 548 549 return istype( 550 value, 551 ( 552 np.float16, 553 np.float32, 554 np.float64, 555 ), 556 ) 557 558 559def is_function_or_wrapper(value): 560 return ( 561 is_function(value) 562 or isinstance(value, functools._lru_cache_wrapper) 563 and is_function(inspect.getattr_static(value, "__wrapped__")) 564 or isinstance(value, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)) 565 ) 566 567 568def is_function(value): 569 return isinstance( 570 value, 571 ( 572 types.FunctionType, 573 types.BuiltinFunctionType, 574 types.MethodDescriptorType, 575 types.WrapperDescriptorType, 576 torch.jit.ScriptFunction, 577 ), 578 ) 579 580 581def unwrap_if_wrapper(fn): 582 return unwrap_with_attr_name_if_wrapper(fn)[0] 583 584 585def unwrap_with_attr_name_if_wrapper(fn): 586 # unpack @functools.lru_cache wrapped function 587 if isinstance(fn, functools._lru_cache_wrapper): 588 fn = inspect.getattr_static(fn, "__wrapped__") 589 attr_name = "__wrapped__" 590 # unpack @torch._dynamo.optimize()(fn) wrapped function 591 elif is_function(fn) and inspect.getattr_static(fn, "_torchdynamo_inline", False): 592 fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) 593 attr_name = "_torchdynamo_inline" 594 # unpack torch.jit.script_if_tracing 595 elif is_function(fn) and inspect.getattr_static( 596 fn, "__script_if_tracing_wrapper", False 597 ): 598 fn = inspect.getattr_static(fn, "__original_fn", fn) 599 attr_name = "__original_fn" 600 else: 601 attr_name = None 602 return fn, attr_name 603 604 605def is_numpy_ndarray(value): 606 if not np: 607 return False 608 609 return istype(value, np.ndarray) 610 611 612def istensor(obj): 613 """Check of obj is a tensor""" 614 tensor_list = ( 615 torch.Tensor, 616 torch.nn.Parameter, 617 *config.traceable_tensor_subclasses, 618 ) 619 tensor_list = tensor_list + (torch._subclasses.FakeTensor,) 620 return istype(obj, tensor_list) 621 622 623def is_lazy_module(mod): 624 return isinstance(mod, LazyModuleMixin) 625 626 627@functools.lru_cache(4096) 628def print_once(*args): 629 print(*args) 630 631 632def make_cell(val=None): 633 """Some black magic to create a cell object that usually only exists in a closure""" 634 x = val 635 636 def f(): 637 return x 638 639 assert f.__closure__ is not None and len(f.__closure__) == 1 640 return f.__closure__[0] 641 642 643def proxy_args_kwargs(args, kwargs): 644 try: 645 proxy_args = tuple(arg.as_proxy() for arg in args) 646 proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} 647 return proxy_args, proxy_kwargs 648 except NotImplementedError as e: 649 from .exc import unimplemented 650 from .variables.base import typestr 651 652 unimplemented( 653 f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}", 654 from_exc=e, 655 ) 656 657 658@dataclasses.dataclass 659class CompilationMetrics: 660 compile_id: str 661 frame_key: str 662 co_name: str 663 co_filename: str 664 co_firstlineno: int 665 cache_size: int 666 accumulated_cache_size: int 667 guard_count: Optional[int] 668 shape_env_guard_count: 
Optional[int] 669 graph_op_count: Optional[int] 670 graph_node_count: Optional[int] 671 graph_input_count: Optional[int] 672 start_time: float 673 entire_frame_compile_time_s: Optional[float] 674 backend_compile_time_s: Optional[float] 675 inductor_compile_time_s: Optional[float] 676 code_gen_time_s: Optional[float] 677 fail_type: Optional[str] 678 fail_reason: Optional[str] 679 fail_user_frame_filename: Optional[str] 680 fail_user_frame_lineno: Optional[int] 681 non_compliant_ops: Set[str] 682 compliant_custom_ops: Set[str] 683 restart_reasons: Set[str] 684 dynamo_time_before_restart_s: float 685 # Sometimes, we will finish analyzing a frame but conclude we don't want 686 # to install any guarded code. True means we actually decided to install 687 # a compiled frame 688 has_guarded_code: bool 689 690 691@dataclasses.dataclass 692class BwdCompilationMetrics: 693 compile_id: str 694 inductor_compile_time_s: Optional[float] 695 code_gen_time_s: Optional[float] 696 fail_type: Optional[str] 697 fail_reason: Optional[str] 698 699 700DEFAULT_COMPILATION_METRICS_LIMIT = 64 701 702 703_compilation_metrics: Deque[ 704 Union[CompilationMetrics, BwdCompilationMetrics] 705] = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT) 706 707 708def record_compilation_metrics( 709 compilation_metrics: Union[CompilationMetrics, BwdCompilationMetrics] 710): 711 global _compilation_metrics 712 _compilation_metrics.append(compilation_metrics) 713 if isinstance(compilation_metrics, CompilationMetrics): 714 name = "compilation_metrics" 715 else: 716 name = "bwd_compilation_metrics" 717 # Currently only record fwd compilation metrics, will add bwd compilation metrics 718 # after the internal Scuba logging changes finish. 719 if isinstance(compilation_metrics, CompilationMetrics): 720 torch._logging.trace_structured( 721 name, 722 lambda: { 723 k: list(v) if isinstance(v, set) else v 724 for k, v in dataclasses.asdict(compilation_metrics).items() 725 }, 726 ) 727 if config.log_compilation_metrics: 728 log_compilation_event(compilation_metrics) 729 730 731def set_compilation_metrics_limit(new_size: int) -> None: 732 global _compilation_metrics 733 while len(_compilation_metrics) > new_size: 734 _compilation_metrics.popleft() 735 new_deque = collections.deque(_compilation_metrics, maxlen=new_size) 736 _compilation_metrics = new_deque 737 738 739def clear_compilation_metrics() -> None: 740 global _compilation_metrics 741 _compilation_metrics.clear() 742 743 744def get_compilation_metrics() -> List[Union[CompilationMetrics, BwdCompilationMetrics]]: 745 return list(_compilation_metrics) 746 747 748@dataclasses.dataclass 749class CleanupHook: 750 """Remove a global variable when hook is called""" 751 752 scope: Dict[str, Any] 753 name: str 754 755 def __call__(self, *args): 756 # Make sure we're not shutting down 757 if CleanupManager is not None: 758 CleanupManager.count -= 1 759 del self.scope[self.name] 760 761 @staticmethod 762 def create(scope, name, val): 763 assert name not in scope 764 CleanupManager.count += 1 765 scope[name] = val 766 return CleanupHook(scope, name) 767 768 769class CleanupManager(ExactWeakKeyDictionary): 770 count = 0 771 instance: ClassVar["CleanupManager"] 772 773 def _remove_id(self, idx): 774 for hook in self.values[idx]: 775 hook() 776 super()._remove_id(idx) 777 778 779CleanupManager.instance = CleanupManager() 780 781 782def clone_tensor(x): 783 """Clone the tensor and its gradient""" 784 y = x.clone().requires_grad_(x.requires_grad) 785 if x.is_leaf and x.grad is not None: 786 
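# A commented sketch of how the metrics ring buffer above is consumed
# (illustrative only; `metrics` stands in for a record built by Dynamo itself):
#
#   set_compilation_metrics_limit(8)       # keep only the 8 most recent records
#   record_compilation_metrics(metrics)    # normally invoked by Dynamo internals
#   for m in get_compilation_metrics():
#       if isinstance(m, CompilationMetrics) and m.fail_type is not None:
#           print(m.co_name, m.fail_reason)
#   clear_compilation_metrics()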
y.grad = x.grad.clone() 787 return y 788 789 790def clone_input(x, *, dtype=None): 791 """copy while preserving strides""" 792 # TODO: this is questionable 793 if is_fake(x): 794 # this func fails on fake tensors in __torch_dispatch__ 795 return x 796 797 def torch_clone(x): 798 y = torch.clone(x) 799 if x.is_leaf: 800 y.requires_grad_(x.requires_grad) 801 if x.is_leaf and x.grad is not None: 802 y.grad = clone_input(x.grad, dtype=dtype) 803 if hasattr(x, "_dynamo_dynamic_indices"): 804 y._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy() # type: ignore[attr-defined] 805 return y 806 807 with torch.no_grad(): 808 if x.device.type == "xla": 809 # Access data_ptr() for a xla tensor will cause crash 810 return torch_clone(x) 811 812 # Handle sparse storage (no stride). 813 if x.layout is torch.sparse_coo: 814 return torch.sparse_coo_tensor( 815 torch_clone(x._indices()), 816 torch_clone(x._values()), 817 x.shape, 818 is_coalesced=x.is_coalesced(), 819 ) 820 elif is_sparse_compressed(x): 821 if x.layout in {torch.sparse_csr, torch.sparse_bsr}: 822 compressed_indices = x.crow_indices() 823 plain_indices = x.col_indices() 824 else: 825 compressed_indices = x.ccol_indices() 826 plain_indices = x.row_indices() 827 return torch.sparse_compressed_tensor( 828 torch_clone(compressed_indices), 829 torch_clone(plain_indices), 830 torch_clone(x.values()), 831 x.shape, 832 layout=x.layout, 833 ) 834 835 needed_size = sum( 836 (shape - 1) * stride for shape, stride in zip(x.size(), x.stride()) 837 ) 838 if x.is_quantized: 839 result = torch.empty_quantized((needed_size + 32,), x) 840 else: 841 result = torch.empty( 842 needed_size + 32, dtype=dtype or x.dtype, device=x.device 843 ) 844 cache_line_offset = ( 845 (x.data_ptr() - result.data_ptr()) % 32 846 ) // x.element_size() 847 result.as_strided_(x.size(), x.stride(), cache_line_offset) 848 try: 849 result.copy_(x.clone()) 850 if x.is_leaf: 851 result.requires_grad_(x.requires_grad) 852 if x.is_leaf and x.grad is not None: 853 result.grad = clone_input(x.grad, dtype=dtype) 854 except RuntimeError: 855 # RuntimeError: unsupported operation: more than one element of the written-to 856 # tensor refers to a single memory location. Please clone() the tensor before 857 # performing the operation. 
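# What the stride-preserving path of clone_input buys us, as a commented
# sketch (illustrative only):
#
#   x = torch.randn(4, 8)[:, ::2]      # non-contiguous view, stride (8, 2)
#   y = clone_input(x)
#   assert y.stride() == x.stride()    # layout preserved
#   assert torch.equal(y, x)
#   # a plain x.clone() would instead return a contiguous tensor with stride (4, 1)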
858 return torch_clone(x) 859 if hasattr(x, "_dynamo_dynamic_indices"): 860 result._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy() # type: ignore[attr-defined] 861 return result 862 863 864def clone_inputs(example_inputs): 865 res: Union[Dict[Any, Any], List[Any]] 866 if type(example_inputs) is dict: 867 res = dict(example_inputs) 868 for key, value in res.items(): 869 if isinstance(value, tuple): 870 res[key] = clone_inputs(value) 871 else: 872 assert isinstance(value, torch.Tensor), type(value) 873 res[key] = clone_input(value) 874 return res 875 876 res = list(example_inputs) 877 for i in range(len(res)): 878 if isinstance(res[i], torch.Tensor): 879 res[i] = clone_input(res[i]) 880 return res 881 882 883def skip_frame_if_in_functorch_mode(val: torch.Tensor): 884 try: 885 val.data_ptr() # will throw for functorch tensors 886 except RuntimeError as e: 887 from .exc import SkipFrame 888 889 # This will be GradTrackingTensor/BatchedTensor/etc 890 functorch_subclass_name = re.sub(r"\(.*", "", repr(val)) 891 raise SkipFrame( 892 f"torch.compile cannot be run in context: {functorch_subclass_name}" 893 ) from e 894 895 896@contextmanager 897def preserve_rng_state(): 898 disable_functorch = torch._C._DisableFuncTorch 899 disable_current_modes = torch.utils._python_dispatch._disable_current_modes 900 with disable_current_modes(), disable_functorch(): 901 rng_state = torch.clone(torch.random.get_rng_state()) 902 skip_frame_if_in_functorch_mode(rng_state) 903 if torch.cuda.is_available(): 904 cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) 905 try: 906 yield 907 finally: 908 with torch.utils._python_dispatch._disable_current_modes(): 909 torch.random.set_rng_state(rng_state) 910 if torch.cuda.is_available(): 911 torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined] 912 913 914def is_jit_model(model0): 915 return isinstance( 916 model0, 917 ( 918 torch.jit._trace.TopLevelTracedModule, 919 torch.jit._script.RecursiveScriptModule, 920 torch.jit.ScriptFunction, 921 torch.jit.ScriptModule, 922 ), 923 ) 924 925 926def torchscript(model, example_inputs, verbose=False): 927 if is_jit_model(model): 928 # already done? 
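# A commented sketch of preserve_rng_state above (illustrative only):
#
#   before = torch.random.get_rng_state()
#   with preserve_rng_state():
#       torch.randn(16)                # perturbs the RNG inside the block
#   after = torch.random.get_rng_state()
#   assert torch.equal(before, after)  # restored on exit (CUDA state too, if available)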
929 return model 930 931 try: 932 return torch.jit.trace(model, example_inputs) 933 except Exception: 934 try: 935 return torch.jit.script(model) 936 except Exception: 937 if verbose: 938 log.exception("jit error") 939 else: 940 log.error("Both torch.jit.trace and torch.jit.script failed") 941 return None 942 943 944def getfile(obj): 945 try: 946 return inspect.getfile(obj) 947 except (TypeError, OSError): 948 return None 949 950 951def is_namedtuple(obj): 952 """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple""" 953 return is_namedtuple_cls(type(obj)) 954 955 956def is_namedtuple_cls(cls): 957 """Test if an object is a namedtuple or a (torch.return_types|torch.autograd.forward_ad).* quasi-namedtuple""" 958 try: 959 if issubclass(cls, tuple): 960 bases = getattr(cls, "__bases__", []) or [None] 961 module = getattr(cls, "__module__", None) 962 return module in ("torch.return_types", "torch.autograd.forward_ad") or ( 963 bases[0] is tuple and hasattr(cls, "_make") and hasattr(cls, "_fields") 964 ) 965 except TypeError: 966 pass 967 return False 968 969 970@functools.lru_cache(1) 971def namedtuple_fields(cls): 972 """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple""" 973 if cls is slice: 974 return ["start", "stop", "step"] 975 976 assert issubclass(cls, tuple) 977 if hasattr(cls, "_fields"): 978 # normal namedtuples 979 return cls._fields 980 981 @dataclasses.dataclass 982 class Marker: 983 index: int 984 985 # frustrating ones e.g. torch.return_types.max 986 assert cls.__module__ == "torch.return_types" 987 obj = cls(map(Marker, range(cls.n_fields))) 988 fields: List[Optional[str]] = [None] * cls.n_fields 989 for name in dir(obj): 990 if name[0] != "_" and isinstance(getattr(obj, name), Marker): 991 fields[getattr(obj, name).index] = name 992 return fields 993 994 995def checkpoint_params(gm): 996 with torch.no_grad(): 997 rng_state = torch.clone(torch.random.get_rng_state()) 998 if torch.cuda.is_available(): 999 cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) 1000 saved_state = [] 1001 for param in itertools.chain(gm.parameters(), gm.buffers()): 1002 saved_state.append((param, param._version, torch.clone(param))) 1003 1004 def restore(): 1005 with torch.no_grad(): 1006 torch.random.set_rng_state(rng_state) 1007 if torch.cuda.is_available(): 1008 torch.cuda.set_rng_state(cuda_rng_state) 1009 for param, version, original_value in saved_state: 1010 if param._version != version: 1011 param.copy_(original_value) 1012 1013 return restore 1014 1015 1016def timed(model, example_inputs, times=1): 1017 if torch.cuda.is_available(): 1018 synchronize = torch.cuda.synchronize 1019 else: 1020 synchronize = nothing 1021 1022 synchronize() 1023 gc.collect() 1024 torch.manual_seed(1337) 1025 t0 = time.perf_counter() 1026 for _ in range(times): 1027 result = model(*example_inputs) 1028 synchronize() 1029 t1 = time.perf_counter() 1030 return result, t1 - t0 # type: ignore[possibly-undefined] 1031 1032 1033def check_is_cuda(gm, example_inputs): 1034 return all(x.is_cuda for x in itertools.chain(example_inputs, gm.parameters(True))) 1035 1036 1037@lru_cache(32) 1038def rot_n_helper(n): 1039 assert n > 1 1040 vars = [f"v{i}" for i in range(n)] 1041 rotated = reversed(vars[-1:] + vars[:-1]) 1042 fn = eval(f"lambda {','.join(vars)}: ({','.join(rotated)})") 1043 fn.__name__ = f"rot_{n}_helper" 1044 return fn 1045 1046 1047common_constant_types = { 1048 int, 1049 float, 1050 complex, 1051 bool, 1052 str, 1053 bytes, 1054 type(None), 1055 
Ellipsis.__class__, 1056 types.CodeType, 1057 torch.device, 1058 torch.dtype, 1059 torch.memory_format, 1060 torch.layout, 1061} 1062 1063if has_triton_package(): 1064 import triton 1065 1066 common_constant_types.add(triton.language.dtype) 1067 1068 1069def is_safe_constant(v): 1070 if istype(v, (tuple, frozenset)): 1071 return all(map(is_safe_constant, v)) 1072 return isinstance(v, (enum.Enum, type)) or istype( 1073 v, 1074 common_constant_types | {slice}, 1075 ) 1076 1077 1078def specialize_symnode(arg): 1079 from .variables import ConstantVariable, SymNodeVariable 1080 1081 # Guard and specialize 1082 if isinstance(arg, SymNodeVariable): 1083 return ConstantVariable.create(arg.evaluate_expr()) 1084 1085 return arg 1086 1087 1088def guard_if_dyn(arg): 1089 from .variables import ConstantVariable 1090 1091 arg = specialize_symnode(arg) 1092 1093 if isinstance(arg, ConstantVariable): 1094 return arg.as_python_constant() 1095 1096 return arg 1097 1098 1099def check_constant_args(args, kwargs): 1100 return all(x.is_python_constant() for x in itertools.chain(args, kwargs.values())) 1101 1102 1103def check_unspec_python_args(args, kwargs): 1104 from .variables.constant import ConstantVariable 1105 from .variables.tensor import UnspecializedPythonVariable 1106 1107 unspec_count = 0 1108 for x in itertools.chain(args, kwargs.values()): 1109 if isinstance(x, UnspecializedPythonVariable): 1110 unspec_count += 1 1111 elif not isinstance(x, ConstantVariable): 1112 return False 1113 return unspec_count > 0 1114 1115 1116def check_unspec_or_constant_args(args, kwargs): 1117 # A fused version of: 1118 # return check_constant_args(args, kwargs) or check_unspec_python_args(args, kwargs) 1119 from .variables.tensor import UnspecializedPythonVariable 1120 1121 for x in itertools.chain(args, kwargs.values()): 1122 if not (x.is_python_constant() or isinstance(x, UnspecializedPythonVariable)): 1123 return False 1124 return True 1125 1126 1127def check_numpy_ndarray_args(args, kwargs): 1128 from .variables.tensor import NumpyNdarrayVariable 1129 1130 return any( 1131 isinstance(x, NumpyNdarrayVariable) 1132 for x in itertools.chain(args, kwargs.values()) 1133 ) 1134 1135 1136dict_keys: Type[KeysView[Any]] = type(dict().keys()) 1137dict_values: Type[ValuesView[Any]] = type(dict().values()) 1138odict_values: Type[ValuesView[Any]] = type(collections.OrderedDict().values()) 1139tuple_iterator: Type[Iterator[Any]] = type(iter(tuple())) 1140tuple_iterator_len = tuple_iterator.__length_hint__ # type: ignore[attr-defined] 1141object_new = object.__new__ 1142 1143 1144def nn_module_new(cls): 1145 obj = object_new(cls) 1146 torch.nn.Module.__init__(obj) 1147 return obj 1148 1149 1150def product(it): 1151 return functools.reduce(operator.mul, it, 1) 1152 1153 1154def tuple_iterator_getitem(it, index): 1155 _, (obj,), start = it.__reduce__() 1156 return obj[start + index] 1157 1158 1159iter_next = next 1160 1161 1162def to_subclass(t, cls): 1163 return t.as_subclass(cls) 1164 1165 1166def dict_keys_getitem(d, n): 1167 return next(itertools.islice(iter(d), n, n + 1)) 1168 1169 1170def enum_repr(value, local): 1171 # enum class can override __str__ method. Use __class__ and name attribute 1172 # to extract the class name and key name. 
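# Commented sketch of the positional-access helpers above (illustrative only):
#
#   d = {"a": 1, "b": 2, "c": 3}
#   dict_keys_getitem(d, 1)            # -> "b", the nth key in insertion order
#
#   it = iter((10, 20, 30))
#   next(it)                           # advance past the first element
#   tuple_iterator_len(it)             # -> 2, remaining length hint
#   tuple_iterator_getitem(it, 1)      # -> 30, indexed relative to current position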
1173 name = value.__class__.__name__ 1174 val = value.name 1175 scope = "L" if local else "G" 1176 local_name = f'{scope}["{name}"].{val}' 1177 return local_name 1178 1179 1180def set_example_value(node, example_value): 1181 # NB: example_value is a bit of a misnomer, because this is always a fake 1182 # tensor of some sort. Furthermore, these example values serve as the 1183 # runtime state of Dynamo tracing, which means if metadata mutation 1184 # occurs, the example_value gets directly updated (so you can't rely on 1185 # this to accurately reflect what the state of the value was at the time 1186 # the program was traced). 1187 node.meta["example_value"] = example_value 1188 shape_env = TracingContext.get().fake_mode.shape_env 1189 if symbol_to_path := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings( 1190 shape_env, example_value 1191 ): 1192 node.meta["unbacked_bindings"] = symbol_to_path 1193 1194 1195def _get_fake_tensor(vt): 1196 fake_tensor = vt.as_proxy().node.meta.get("example_value") 1197 if not is_fake(fake_tensor): 1198 from .exc import unimplemented 1199 1200 unimplemented("Cannot check Tensor object identity without its fake value") 1201 return fake_tensor 1202 1203 1204def iter_contains(items, search, tx, check_tensor_identity=False): 1205 from .variables import ( 1206 BuiltinVariable, 1207 ConstantVariable, 1208 TensorVariable, 1209 VariableTracker, 1210 ) 1211 1212 if search.is_python_constant(): 1213 found_const = any( 1214 x.is_python_constant() 1215 and x.as_python_constant() == search.as_python_constant() 1216 for x in items 1217 ) 1218 return ConstantVariable.create(found_const) 1219 1220 must_check_tensor_id = False 1221 if check_tensor_identity and isinstance(search, TensorVariable): 1222 must_check_tensor_id = True 1223 # Match of Tensor means match of FakeTensor 1224 search = _get_fake_tensor(search) 1225 1226 found: Optional[VariableTracker] = None 1227 for x in items: 1228 if must_check_tensor_id: 1229 if isinstance(x, TensorVariable): 1230 if search is _get_fake_tensor(x): # Object equivalence 1231 return ConstantVariable.create(True) 1232 else: 1233 check = BuiltinVariable(operator.eq).call_function(tx, [x, search], {}) 1234 if found is None: 1235 found = check 1236 else: 1237 found = BuiltinVariable(operator.or_).call_function( 1238 tx, [check, found], {} 1239 ) 1240 if found is None: 1241 found = ConstantVariable.create(False) 1242 return found 1243 1244 1245def key_is_id(k): 1246 """Returns whether it indexes dictionaries using its id""" 1247 return isinstance(k, (torch.Tensor, torch.nn.Module, MethodWrapperType)) 1248 1249 1250def key_to_id(value): 1251 return [id(k) if key_is_id(k) else k for k in value.keys()] 1252 1253 1254def const_repr(x, *, local) -> str: 1255 from .trace_rules import is_builtin_callable 1256 1257 if isinstance(x, (list, tuple)): 1258 elems_repr = ",".join(const_repr(s, local=local) for s in x) 1259 if isinstance(x, list): 1260 return f"[{elems_repr}]" 1261 else: 1262 assert isinstance(x, tuple) 1263 if len(x) == 1: 1264 return f"({elems_repr},)" 1265 else: 1266 return f"({elems_repr})" 1267 elif isinstance(x, enum.Enum): 1268 # To workaround repr(Enum) returning invalid global reference before python 3.11 1269 # by calling enum_repr and removing quotes to render enum in guard code. 
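# Commented sketch of key_is_id / key_to_id above (illustrative only):
#
#   m = torch.nn.Linear(2, 2)
#   d = {m: "module", "lr": 0.1}
#   key_to_id(d)    # -> [id(m), "lr"]; tensors, nn.Modules and method-wrapper
#                   #    keys are tracked by identity, plain constants by value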
1270 return enum_repr(x, local=local).replace("'", "") 1271 elif is_builtin_callable(x): 1272 return x.__name__ 1273 elif isinstance(x, type): 1274 1275 def fullname(o): 1276 klass = o.__class__ 1277 module = klass.__module__ 1278 if module == "builtins": 1279 return klass.__qualname__ # avoid outputs like 'builtins.str' 1280 return module + "." + klass.__qualname__ 1281 1282 return fullname(x) 1283 else: 1284 return f"{x!r}" 1285 1286 1287def dict_keys_repr(const_keys, *, local) -> str: 1288 keys_str = ",".join(const_repr(s, local=local) for s in const_keys) 1289 return "[" + keys_str + "]" 1290 1291 1292GLOBAL_KEY_PREFIX = "__dict_key" 1293 1294 1295from torch._subclasses import UnsupportedFakeTensorException # noqa: F401 1296 1297 1298def wrap_fake_exception(fn): 1299 try: 1300 return fn() 1301 except UnsupportedFakeTensorException as e: 1302 from .exc import unimplemented 1303 1304 msg = f"Unsupported: {e.reason} with fake tensor propagation." 1305 log.warning(msg) 1306 unimplemented(msg, from_exc=e) 1307 1308 1309def deepcopy_to_fake_tensor(obj, fake_mode): 1310 with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode): 1311 return wrap_fake_exception(lambda: copy.deepcopy(obj)) 1312 1313 1314def rmse(ref, res): 1315 """ 1316 Calculate root mean squared error 1317 """ 1318 return torch.sqrt(torch.mean(torch.square(ref - res))) 1319 1320 1321def same( 1322 ref, 1323 res, 1324 fp64_ref=None, 1325 cos_similarity=False, 1326 tol=1e-4, 1327 equal_nan=False, 1328 exact_dtype=True, 1329 relax_numpy_equality=False, 1330 ignore_non_fp=False, 1331 log_error=log.error, 1332): 1333 """Check correctness to see if ref and res match""" 1334 if fp64_ref is None: 1335 fp64_ref = ref 1336 if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)): 1337 assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}" 1338 if len(ref) != len(res): 1339 log_error("Length mismatch") 1340 return False 1341 return len(ref) == len(res) and all( 1342 same( 1343 ai, 1344 bi, 1345 fp64_refi, 1346 cos_similarity, 1347 tol, 1348 equal_nan, 1349 exact_dtype, 1350 relax_numpy_equality, 1351 ignore_non_fp, 1352 log_error=log_error, 1353 ) 1354 for ai, bi, fp64_refi in zip(ref, res, fp64_ref) 1355 ) 1356 elif type(ref).__name__ == "QuestionAnsweringModelOutput": 1357 # This skips checking accuracy for start_logits/end_logits. 1358 # Tentatively, start_logits/end_logits appear to be very prone to 1359 # inaccuracies and is somewhat subsumed by checking the loss. 
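# A commented sketch of deepcopy_to_fake_tensor above (illustrative only, and an
# assumption that constructing a bare FakeTensorMode like this is representative
# of how Dynamo supplies `fake_mode`):
#
#   from torch._subclasses.fake_tensor import FakeTensorMode
#   fake_mode = FakeTensorMode()
#   fake_mod = deepcopy_to_fake_tensor(torch.nn.Linear(4, 4), fake_mode)
#   assert is_fake(fake_mod.weight)    # parameters were copied as FakeTensors
#   # failures inside the copy surface as Dynamo `unimplemented` via wrap_fake_exception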
1360 return same( 1361 ref.loss, 1362 res.loss, 1363 fp64_ref.loss, 1364 cos_similarity, 1365 tol, 1366 equal_nan, 1367 exact_dtype, 1368 relax_numpy_equality, 1369 ignore_non_fp, 1370 log_error=log_error, 1371 ) 1372 elif isinstance(ref, dict): 1373 assert isinstance(res, dict) 1374 assert set(ref.keys()) == set( 1375 res.keys() 1376 ), f"keys mismatch {set(ref.keys())} == {set(res.keys())}" 1377 for k in sorted(ref.keys()): 1378 if not ( 1379 same( 1380 ref[k], 1381 res[k], 1382 fp64_ref[k], 1383 cos_similarity=cos_similarity, 1384 tol=tol, 1385 equal_nan=equal_nan, 1386 exact_dtype=exact_dtype, 1387 relax_numpy_equality=relax_numpy_equality, 1388 ignore_non_fp=ignore_non_fp, 1389 log_error=log_error, 1390 ) 1391 ): 1392 log_error("Accuracy failed for key name %s", k) 1393 return False 1394 return True 1395 elif isinstance(ref, (torch.Tensor, float)): 1396 assert not isinstance(ref, torch._subclasses.FakeTensor) 1397 assert not isinstance(res, torch._subclasses.FakeTensor) 1398 1399 def to_tensor(t): 1400 return t if isinstance(t, torch.Tensor) else torch.tensor(t) 1401 1402 ref, res, fp64_ref = (to_tensor(val) for val in (ref, res, fp64_ref)) 1403 1404 if ref.is_sparse: 1405 assert res.is_sparse 1406 ref = ref.to_dense() 1407 res = res.to_dense() 1408 assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}" 1409 if exact_dtype: 1410 if ref.dtype != res.dtype: 1411 log_error("dtype mismatch %s, %s", ref.dtype, res.dtype) 1412 return False 1413 if ref.dtype == torch.bool: 1414 if ignore_non_fp: 1415 return True 1416 # triton stores bool as int8, so add this for more accurate checking 1417 r = torch.allclose( 1418 ref.to(dtype=torch.uint8), 1419 res.to(dtype=torch.uint8), 1420 atol=tol, 1421 rtol=tol, 1422 equal_nan=equal_nan, 1423 ) 1424 if not r: 1425 log_error("Accuracy failed: uint8 tensor did not match") 1426 return r 1427 1428 if cos_similarity: 1429 ref = ref.flatten().to(torch.float32) 1430 res = res.flatten().to(torch.float32) 1431 if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=True): 1432 # early exit that handles zero/nan better 1433 # cosine_similarity(zeros(10), zeros(10), dim=0) is 0 1434 return True 1435 score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6) 1436 if score < 0.99: 1437 log.warning("Similarity score=%s", score.cpu().detach().item()) 1438 return score >= 0.99 1439 else: 1440 if not exact_dtype: 1441 ref = ref.to(res.dtype) 1442 1443 # First try usual allclose 1444 if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=equal_nan): 1445 return True 1446 1447 # Check error from fp64 version 1448 if fp64_ref.dtype == torch.float64: 1449 ref_error = rmse(fp64_ref, ref).item() 1450 # ref unable to produce this with stable numerics in this precision, ignore 1451 if math.isnan(ref_error): 1452 log.warning( 1453 "Found nan in reference. Consider running in higher precision." 1454 ) 1455 1456 res_error = rmse(fp64_ref, res).item() 1457 1458 # In the case of using AMP (Automatic Mixed Precision), certain models have 1459 # failed the benchmark's correctness check. However, the end-to-end model's 1460 # accuracy when comparing AMP with FP32 is within a difference of less than 0.1%. 1461 # Thus, it's possible that the correctness check failures for these models are 1462 # false alarms. We use multiplier of 3 instead of 2 to avoid these false alarms. 
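# Worked example of the threshold applied below (commented, numbers made up):
# with tol=1e-4 and a bfloat16 result, multiplier is 3.0, so the comparison
# passes when res_error <= 3.0 * ref_error + tol / 10; e.g. ref_error=0.008
# (fp32 vs fp64) admits res_error up to 0.02401, while res_error=0.05 would be
# reported as an accuracy failure.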
1463 multiplier = 3.0 if res.dtype == torch.bfloat16 else 2.0 1464 1465 if ( 1466 fp64_ref.numel() < 1000 1467 or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1) 1468 # large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE 1469 or tol >= 2 * 1e-2 1470 ): 1471 # In the presence of noise, noise might dominate our error 1472 # metric for smaller tensors. 1473 # Similary, for 1x1 kernels, there seems to be high noise with amp. 1474 multiplier = 3.0 1475 1476 passes_test = res_error <= (multiplier * ref_error + tol / 10.0) 1477 if not passes_test: 1478 log_error( 1479 "RMSE (res-fp64): %.5f, (ref-fp64): %.5f and shape=%s. res.dtype: %s, multiplier: %f, tol: %f", 1480 res_error, 1481 ref_error, 1482 res.size(), 1483 res.dtype, 1484 multiplier, 1485 tol, 1486 ) 1487 # import pdb; pdb.set_trace() 1488 return passes_test 1489 1490 if ignore_non_fp: 1491 return True 1492 1493 log_error("Accuracy failed: allclose not within tol=%s", tol) 1494 return False 1495 elif isinstance(ref, (str, int, type(None), bool, torch.device)): 1496 if ignore_non_fp: 1497 return True 1498 r = ref == res 1499 if not r: 1500 log_error("Accuracy failed (%s): %s != %s", type(ref), ref, res) 1501 return r 1502 elif is_numpy_int_type(ref) or is_numpy_float_type(ref): 1503 if relax_numpy_equality and not ( 1504 is_numpy_int_type(res) or is_numpy_float_type(res) 1505 ): 1506 ref = ref.item() 1507 r = (type(ref) is type(res)) and (ref == res) 1508 if not r: 1509 log_error("Accuracy failed (numpy): %s != %s", ref, res) 1510 return r 1511 elif is_numpy_ndarray(ref): 1512 return (type(ref) is type(res)) and same( 1513 torch.as_tensor(ref), 1514 torch.as_tensor(res), 1515 fp64_ref, 1516 cos_similarity=cos_similarity, 1517 tol=tol, 1518 equal_nan=equal_nan, 1519 exact_dtype=exact_dtype, 1520 relax_numpy_equality=relax_numpy_equality, 1521 ignore_non_fp=ignore_non_fp, 1522 log_error=log_error, 1523 ) 1524 elif type(ref).__name__ in ( 1525 "MaskedLMOutput", 1526 "Seq2SeqLMOutput", 1527 "CausalLMOutputWithCrossAttentions", 1528 "LongformerMaskedLMOutput", 1529 "Instances", 1530 "SquashedNormal", 1531 "Boxes", 1532 "Normal", 1533 "TanhTransform", 1534 "Foo", 1535 "Variable", 1536 ): 1537 assert type(ref) is type(res) 1538 return all( 1539 same( 1540 getattr(ref, key), 1541 getattr(res, key), 1542 getattr(fp64_ref, key), 1543 cos_similarity=cos_similarity, 1544 tol=tol, 1545 equal_nan=equal_nan, 1546 exact_dtype=exact_dtype, 1547 relax_numpy_equality=relax_numpy_equality, 1548 ignore_non_fp=ignore_non_fp, 1549 log_error=log_error, 1550 ) 1551 for key in ref.__dict__.keys() 1552 ) 1553 else: 1554 raise RuntimeError(f"unsupported type: {type(ref).__name__}") 1555 1556 1557def format_func_info(code): 1558 short_filename = code.co_filename.split("/")[-1] 1559 return f"'{code.co_name}' ({short_filename}:{code.co_firstlineno})" 1560 1561 1562@contextlib.contextmanager 1563def disable_cache_limit(): 1564 prior = config.cache_size_limit 1565 config.cache_size_limit = sys.maxsize 1566 prior_acc_limit = config.accumulated_cache_size_limit 1567 config.accumulated_cache_size_limit = sys.maxsize 1568 1569 try: 1570 yield 1571 finally: 1572 config.cache_size_limit = prior 1573 config.accumulated_cache_size_limit = prior_acc_limit 1574 1575 1576# map from transformed code back to original user code 1577orig_code_map = ExactWeakKeyDictionary() 1578 1579# keep a record of code_obj -> list of guard failure reasons for logging 1580guard_failures: DefaultDict[Any, List[Any]] = collections.defaultdict(list) 1581 1582# Keep a 
record of graph break reasons for logging 1583graph_break_reasons: List["torch._dynamo.output_graph.GraphCompileReason"] = list() 1584 1585# keep record of compiled code, if we are in "error if recompile" 1586# to track code that dynamo has compiled previously 1587seen_code_map = ExactWeakKeyDictionary() 1588 1589 1590class CompileProfiler: 1591 """Utility for profiling how and what dynamo would compile. 1592 1593 Can be used for 1594 * diagnosing recompilation issues 1595 * determining an appropriate compile cache limit 1596 * (TODO)confirming which functions got compiled/skipped 1597 """ 1598 1599 def __init__(self): 1600 self.frame_count = 0 1601 self.op_count = 0 1602 self.backend_ctx_ctor = disable_cache_limit 1603 1604 def __call__(self, gm: torch.fx.GraphModule, example_inputs): 1605 self.frame_count += 1 1606 for node in gm.graph.nodes: 1607 if "call" in node.op: 1608 self.op_count += 1 1609 return gm.forward 1610 1611 # no-op __enter__ and __exit__ to preserve BC 1612 def __enter__(self): 1613 return self 1614 1615 def __exit__(self, typ, val, traceback): 1616 pass 1617 1618 def get_metrics(self): 1619 return {"guard_failures": guard_failures} 1620 1621 def report(self): 1622 metrics = self.get_metrics() 1623 gf = metrics["guard_failures"] 1624 1625 def num_recompiles(code): 1626 return len(gf[code]) 1627 1628 def recompile_reasons(code): 1629 return "\n".join([str(x) for x in gf[code]]) 1630 1631 summarized_gf = [ 1632 [format_func_info(code), num_recompiles(code), recompile_reasons(code)] 1633 for code in gf 1634 ] 1635 1636 def graph_break_report(): 1637 if "graph_break" in counters: 1638 graph_breaks = counters["graph_break"] 1639 return tabulate( 1640 [[msg, graph_breaks[msg]] for msg in graph_breaks], 1641 headers=["Graph Break Reason", "Count"], 1642 ) 1643 1644 def recompilation_report(): 1645 if len(gf): 1646 max_recompiles = max(num_recompiles(code) for code in gf) 1647 recomp_table = tabulate( 1648 summarized_gf, 1649 headers=["Function", "Recompiles", "Recompile Reasons"], 1650 ) 1651 return recomp_table + textwrap.dedent( 1652 f""" 1653 1654 Set torch._dynamo.config.cache_size_limit to {max_recompiles} to avoid being cache limited. 1655 """ 1656 ) 1657 1658 report = textwrap.dedent( 1659 """ 1660 Torchdynamo Profiler Report 1661 =========================== 1662 1663 Graph Breaks 1664 ------------ 1665 Graph breaks happen when torchdynamo encounters code it can't safely trace. 1666 If you want to find out why breaks are happening, check below for each break reason 1667 You may gain additional insight by passing `fullgraph=True` to torch.compile, 1668 to stop at the first break. 1669 1670 """ 1671 ) 1672 report += graph_break_report() or "No graph breaks detected." 1673 report += textwrap.dedent( 1674 """ 1675 1676 Recompilation 1677 ------------- 1678 These subgraphs were recompiled more than once due to guard failures 1679 Guard failures indicate some condition assumed to be static by the tracer changed, 1680 making it unsafe to reuse the compiled program. 
1681 1682 """ 1683 ) 1684 report += recompilation_report() or "No recompilation detected.\n" 1685 return report 1686 1687 1688# return same dir unless user changes config between calls 1689@functools.lru_cache(None) 1690def _get_debug_dir(root_dir): 1691 dir_name = ( 1692 "run_" 1693 + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") 1694 # use pid to avoid conflicts among ranks 1695 + "-pid_" 1696 + str(os.getpid()) 1697 ) 1698 return os.path.join(root_dir, dir_name) 1699 1700 1701def get_debug_dir(): 1702 debug_root = config.debug_dir_root 1703 return _get_debug_dir(debug_root) 1704 1705 1706def extract_fake_example_value(node, required=True): 1707 if "example_value" in node.meta and is_fake(node.meta["example_value"]): 1708 return node.meta["example_value"] 1709 elif required: 1710 from torch._dynamo.exc import unimplemented 1711 1712 unimplemented("`FakeTensor` example value was required but not available") 1713 else: 1714 return None 1715 1716 1717def ensure_graph_fake(e, tx): 1718 assert maybe_get_fake_mode(e) is tx.fake_mode 1719 return e 1720 1721 1722def get_fake_values_from_nodes(tx, nodes, allow_non_graph_fake): 1723 def visit(n: torch.fx.Node): 1724 if n.op == "call_function" and "example_value" not in n.meta: 1725 # fake tensor validity is checked inside get_fake_value using 1726 # ensure_graph_fake 1727 return get_fake_value(n, tx, allow_non_graph_fake) 1728 1729 out = n.meta["example_value"] 1730 if not allow_non_graph_fake and isinstance(out, torch.Tensor): 1731 return ensure_graph_fake(out, tx) 1732 return out 1733 1734 return torch.fx.node.map_arg(nodes, visit) 1735 1736 1737def get_fake_value(node, tx, allow_non_graph_fake=False): 1738 """ 1739 Run the computation represented by `node` using fake tensors and return the result. 1740 1741 allow_non_graph_fake: whether to allow the return result to be: 1742 1. non-fake or 2. fake that is not created by this instance of Dynamo. 1743 If `True`, you must be prepared to deal with such return values, ideally 1744 by further wrapping them as this graph's fakes. 1745 """ 1746 from torch.utils._sympy.value_ranges import ValueRangeError 1747 from .exc import ( 1748 TorchRuntimeError, 1749 unimplemented, 1750 Unsupported, 1751 UserError, 1752 UserErrorType, 1753 ) 1754 1755 op = node.op 1756 1757 # FX Node should always return the same fake value 1758 if "example_value" in node.meta and is_fake(node.meta["example_value"]): 1759 return node.meta["example_value"] 1760 1761 args, kwargs = get_fake_values_from_nodes( 1762 tx, (node.args, node.kwargs), allow_non_graph_fake 1763 ) 1764 1765 nnmodule = None 1766 if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module): 1767 # If the first argument is nn.Module, should copy to fake mode. 1768 args = (deepcopy_to_fake_tensor(args[0], tx.fake_mode),) + tuple(args[1:]) 1769 1770 if op == "call_module": 1771 nnmodule = tx.output.nn_modules[node.target] 1772 1773 if is_lazy_module(nnmodule) and hasattr(nnmodule, "_initialize_hook"): 1774 # In the case of a lazy module, we want to run 1775 # the pre-hooks which initialize it. 1776 # Afterwards, lazy module deletes its pre-hooks 1777 # to avoid treating it as lazy on subsequent recompile. 1778 nnmodule._infer_parameters(nnmodule, args) 1779 1780 # no matter it's lazy module or not, we should copy to fake mode. 
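# A commented usage sketch for the CompileProfiler defined earlier in this file
# (illustrative only; `fn` is a hypothetical function being compiled):
#
#   prof = CompileProfiler()
#   opt_fn = torch.compile(fn, backend=prof)
#   opt_fn(torch.randn(8))
#   opt_fn(torch.randn(16))            # a new shape may trigger a recompile
#   print(prof.report())               # graph-break and recompilation summary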
1781 nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) 1782 1783 try: 1784 with tx.fake_mode, enable_python_dispatcher(): 1785 ret_val = wrap_fake_exception( 1786 lambda: run_node(tx.output, node, args, kwargs, nnmodule) 1787 ) 1788 except Unsupported: 1789 raise 1790 except RuntimeError as e: 1791 cause: BaseException = e 1792 if e.__cause__ is not None: 1793 cause = e.__cause__ 1794 1795 if isinstance( 1796 cause, torch._subclasses.fake_tensor.DataDependentOutputException 1797 ): 1798 unimplemented( 1799 f"data dependent operator: {cause.func}; " 1800 "to enable, set torch._dynamo.config.capture_scalar_outputs = True" 1801 ) 1802 elif isinstance( 1803 cause, torch._subclasses.fake_tensor.DynamicOutputShapeException 1804 ): 1805 if not torch._dynamo.config.capture_dynamic_output_shape_ops: 1806 unimplemented( 1807 f"dynamic shape operator: {cause.func}; " 1808 "to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True" 1809 ) 1810 else: 1811 unimplemented( 1812 f"dynamic shape operator: {cause.func}; " 1813 "Operator does not have a meta kernel that supports dynamic output shapes, " 1814 "please report an issue to PyTorch" 1815 ) 1816 elif isinstance( 1817 cause, torch._subclasses.fake_tensor.UnsupportedOperatorException 1818 ): 1819 op = cause.func 1820 import_suggestion = "" 1821 if isinstance(op, torch._ops.OpOverload): 1822 maybe_pystub = torch._C._dispatch_pystub( 1823 op._schema.name, op._schema.overload_name 1824 ) 1825 if maybe_pystub is not None: 1826 module, ctx = maybe_pystub 1827 import_suggestion = ( 1828 f"It's possible that the support was implemented in " 1829 f"module `{module}` and you may need to `import {module}`" 1830 f"({ctx}), otherwise " 1831 ) 1832 unimplemented( 1833 f"unsupported operator: {cause.func} ({import_suggestion}see " 1834 "https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0" 1835 " for how to fix)" 1836 ) 1837 elif isinstance( 1838 cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode 1839 ): 1840 raise UserError( # noqa: B904 1841 UserErrorType.CONSTRAINT_VIOLATION, 1842 "Tried to use data-dependent value in the subsequent computation. " 1843 "This can happen when we encounter unbounded dynamic value that is unknown during tracing time. " 1844 "You will need to explicitly give hint to the compiler. Please take a look at " 1845 f"torch._check OR torch._check_is_size APIs. {cause}", 1846 case_name="constrain_as_size_example", 1847 ) 1848 elif isinstance(cause, ValueRangeError): 1849 raise UserError(UserErrorType.CONSTRAINT_VIOLATION, e.args[0]) from e 1850 elif isinstance(cause, TypeError) and "argument" in str(cause): 1851 unimplemented(f"TypeError {node.target}: {cause}") 1852 1853 raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None 1854 1855 if not allow_non_graph_fake: 1856 _ = pytree.tree_map_only( 1857 torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val 1858 ) 1859 return ret_val 1860 1861 1862_current_node = threading.local() 1863 1864 1865def get_current_node(): 1866 return getattr(_current_node, "value", None) 1867 1868 1869@contextmanager 1870def set_current_node(node): 1871 old = get_current_node() 1872 _current_node.value = node 1873 try: 1874 yield 1875 finally: 1876 _current_node.value = old 1877 1878 1879def run_node(tracer, node, args, kwargs, nnmodule): 1880 """ 1881 Runs a given node, with the given args and kwargs. 1882 1883 Behavior is dictated by a node's op. 
1884 1885 run_node is useful for extracting real values out of nodes. 1886 See get_real_value for more info on common usage. 1887 1888 Note: The tracer arg is only used for 'get_attr' ops 1889 Note: The nnmodule arg is only used for 'call_module' ops 1890 1891 Nodes that are not call_function, call_method, call_module, or get_attr will 1892 raise an AssertionError. 1893 """ 1894 op = node.op 1895 1896 with set_current_node(node): 1897 1898 def make_error_message(e): 1899 return f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n" + str(e) 1900 1901 try: 1902 if op == "call_function": 1903 return node.target(*args, **kwargs) 1904 elif op == "call_method": 1905 return getattr(args[0], node.target)(*args[1:], **kwargs) 1906 elif op == "call_module": 1907 assert nnmodule is not None 1908 return nnmodule(*args, **kwargs) 1909 elif op == "get_attr": 1910 return tracer.output_graph.get_submodule(node.target) 1911 elif op == "placeholder": 1912 assert "example_value" in node.meta 1913 return node.meta["example_value"] 1914 1915 except (NotImplementedError, UnsupportedFakeTensorException) as e: 1916 # NB: mimic how wrap_fake_exception does it 1917 from .exc import unimplemented 1918 1919 unimplemented(make_error_message(e), from_exc=e) 1920 except Exception as e: 1921 raise RuntimeError(make_error_message(e)).with_traceback( 1922 e.__traceback__ 1923 ) from e 1924 1925 raise AssertionError(op) 1926 1927 1928def get_real_value(node, tracer): 1929 """ 1930 Run the actual computation represented by `node` and return the result. 1931 This will execute any dependent nodes in the graph as well. 1932 """ 1933 from .exc import TorchRuntimeError 1934 1935 cache = tracer.real_value_cache 1936 if node in cache: 1937 return cache[node] 1938 1939 op = node.op 1940 args, kwargs = torch.fx.node.map_arg( 1941 (node.args, node.kwargs), 1942 lambda n: get_real_value(n, tracer), 1943 ) 1944 1945 if op == "placeholder" and "grapharg" in node.meta: 1946 return node.meta["grapharg"].example 1947 1948 if op == "call_module": 1949 nn_module = tracer.output_graph.nn_modules[node.target] 1950 if not is_lazy_module(nn_module): 1951 nn_module = copy.deepcopy(nn_module) 1952 else: 1953 # In the case of a lazy module, we want to run 1954 # the pre-hooks which initialize it 1955 nn_module(*args, **kwargs) 1956 else: 1957 nn_module = None 1958 1959 try: 1960 real_value = run_node(tracer, node, args, kwargs, nn_module) 1961 cache[node] = real_value 1962 except RuntimeError as e: 1963 raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None 1964 return real_value 1965 1966 1967def assert_no_fake_params_or_buffers(gm): 1968 from torch._subclasses.fake_tensor import FakeTensorConfig, is_fake 1969 1970 def stack_or_hint(t): 1971 if FakeTensorConfig.debug: 1972 import traceback 1973 1974 return f"FAKE TENSOR CREATION TRACEBACK: \n {traceback.format_list(t._debug_trace)}" 1975 else: 1976 return "Enable TORCH_FAKE_TENSOR_DEBUG=1 to get creation stack traces on fake tensors." 1977 1978 for name, buffer in gm.named_buffers(): 1979 assert not is_fake( 1980 buffer 1981 ), f"Unexpected fake buffer {name} {stack_or_hint(buffer)}" 1982 for name, param in gm.named_parameters(): 1983 assert not is_fake( 1984 param 1985 ), f"Unexpected fake param {name} {stack_or_hint(param)}" 1986 1987 1988def fqn(obj: Any): 1989 """ 1990 Returns the fully qualified name of the object. 
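# Commented micro-examples for the helpers in this area (illustrative only):
#
#   fqn(torch.nn.Linear)           # -> "torch.nn.modules.linear.Linear"
#   fqn(torch.nn.Linear.forward)   # -> "torch.nn.modules.linear.Linear.forward"
#   assert_no_fake_params_or_buffers(gm)   # asserts a GraphModule carries no fake state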
1991 """ 1992 return f"{obj.__module__}.{obj.__qualname__}" 1993 1994 1995def ifdynstaticdefault(count1, count2): 1996 if torch._dynamo.config.assume_static_by_default: 1997 return count1 1998 else: 1999 return count2 2000 2001 2002def import_submodule(mod: types.ModuleType): 2003 """ 2004 Ensure all the files in a given submodule are imported 2005 """ 2006 for filename in sorted(os.listdir(os.path.dirname(cast(str, mod.__file__)))): 2007 if filename.endswith(".py") and filename[0] != "_": 2008 importlib.import_module(f"{mod.__name__}.{filename[:-3]}") 2009 2010 2011def object_has_getattribute(value: Any): 2012 try: 2013 if isinstance( 2014 inspect.getattr_static(type(value), "__getattribute__"), 2015 types.FunctionType, 2016 ): 2017 return True 2018 except AttributeError: 2019 pass 2020 return False 2021 2022 2023def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False): 2024 try: 2025 getattr_fn = inspect.getattr_static(type(value), "__getattr__") 2026 except AttributeError: 2027 getattr_fn = None 2028 if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__: 2029 # ignore this case of getattr 2030 getattr_fn = None 2031 return getattr_fn 2032 2033 2034class TensorStaticReason(enum.Enum): 2035 PARAMETER = 2 2036 NOT_TENSOR = 4 2037 NN_MODULE_PROPERTY = 5 2038 2039 2040def tensor_static_reason_to_message(reason: TensorStaticReason): 2041 if reason == TensorStaticReason.PARAMETER: 2042 return "mark_dynamic on parameter, parameters are always static today." 2043 if reason == TensorStaticReason.NOT_TENSOR: 2044 return "mark_dynamic on a non tensor, how did this happen?" 2045 if reason == TensorStaticReason.NN_MODULE_PROPERTY: 2046 return "tensor is static because it is nn module associated." 2047 raise AssertionError(f"Illegal reason {reason}") 2048 2049 2050def tensor_always_has_static_shape( 2051 tensor: Union[torch.Tensor, Any], 2052 is_tensor: bool, 2053 guard_source: "torch._guards.GuardSource", 2054) -> Tuple[bool, Optional[TensorStaticReason]]: 2055 """ 2056 Given a tensor, source, and is_tensor flag, determine if a shape should be static. 2057 2058 Args: 2059 tensor - the real tensor to evaluate, parameters force a static shape. 2060 is_tensor - internal dynamo check, essentially "is_tensor": target_cls is TensorVariable, 2061 tensors not in a TensorVariable for whatever reason are forced static. 2062 2063 Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape. 2064 The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed. 
2065 """ 2066 if guard_source.is_nn_module() and config.force_nn_module_property_static_shapes: 2067 return True, TensorStaticReason.NN_MODULE_PROPERTY 2068 if type(tensor) is torch.nn.Parameter and config.force_parameter_static_shapes: 2069 return True, TensorStaticReason.PARAMETER 2070 if not is_tensor: 2071 return True, TensorStaticReason.NOT_TENSOR 2072 return False, None 2073 2074 2075def lazy_format_graph_tabular(fn_name, gm): 2076 def inner(): 2077 try: 2078 from tabulate import tabulate # TODO: Check that this is installed 2079 except ImportError: 2080 return ( 2081 "Tabulate module missing, please install tabulate to log the graph in tabular format, logging code instead:\n" 2082 + str(lazy_format_graph_code(fn_name, gm)) 2083 ) 2084 2085 node_specs = [ 2086 [n.op, n.name, n.target, n.args, n.kwargs] for n in gm.graph.nodes 2087 ] 2088 graph_str = tabulate( 2089 node_specs, headers=["opcode", "name", "target", "args", "kwargs"] 2090 ) 2091 return _format_graph_code(fn_name, gm.forward.__code__.co_filename, graph_str) 2092 2093 return LazyString(inner) 2094 2095 2096def format_bytecode(prefix, name, filename, line_no, code): 2097 return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n" 2098 2099 2100forward_hook_names = ["_forward_pre_hooks", "_forward_hooks"] 2101backward_hook_names = ["_backward_pre_hooks", "_backward_hooks"] 2102state_dict_hook_names = [ 2103 "_state_dict_pre_hooks", 2104 "_state_dict_hooks", 2105 "_load_state_dict_pre_hooks", 2106 "_load_state_dict_post_hooks", 2107] 2108all_hook_names = forward_hook_names + backward_hook_names + state_dict_hook_names 2109 2110 2111def nn_module_has_global_hooks(): 2112 # This is limited to backward hooks for now because NNModuleVariable 2113 # supports fwd hooks underneath. 2114 return len(torch.nn.modules.module._global_backward_hooks) or len( 2115 torch.nn.modules.module._global_backward_pre_hooks 2116 ) 2117 2118 2119def nn_module_get_all_hooks( 2120 mod, 2121 check_forward_hooks=False, 2122 check_backward_hooks=False, 2123 check_state_dict_hooks=False, 2124): 2125 reset_code = torch._C._dynamo.eval_frame.reset_code 2126 """ 2127 Sometimes its useful to differentiate between types of hooks such as forward/backward/pre 2128 hooks executed during module.__call__, and state_dict hooks which are executed separately. 2129 """ 2130 hook_dicts_to_check = [] 2131 check_all_hooks = ( 2132 not check_forward_hooks 2133 and not check_backward_hooks 2134 and not check_state_dict_hooks 2135 ) 2136 if check_forward_hooks or check_all_hooks: 2137 hook_dicts_to_check.extend(forward_hook_names) 2138 if check_backward_hooks or check_all_hooks: 2139 hook_dicts_to_check.extend(backward_hook_names) 2140 if check_state_dict_hooks: 2141 hook_dicts_to_check.extend(state_dict_hook_names) 2142 2143 all_hooks = [] 2144 for hook_dict_name in hook_dicts_to_check: 2145 hooks = getattr(mod, hook_dict_name, []) 2146 for hook_name in hooks: 2147 hook = hooks[hook_name] 2148 2149 all_hooks.append(hook) 2150 return all_hooks 2151 2152 2153def nnmodule_has_hooks( 2154 mod, 2155 check_forward_hooks=False, 2156 check_backward_hooks=False, 2157 check_state_dict_hooks=False, 2158): 2159 """ 2160 Helper function to check if a module has any hooks attached to it. 
2161 """ 2162 hooks = nn_module_get_all_hooks( 2163 mod, 2164 check_forward_hooks=check_forward_hooks, 2165 check_backward_hooks=check_backward_hooks, 2166 check_state_dict_hooks=check_state_dict_hooks, 2167 ) 2168 return bool(hooks) 2169 2170 2171def to_numpy_helper(value): 2172 """Convert tensor and tnp.ndarray to numpy.ndarray.""" 2173 if is_fake(value): 2174 return value 2175 if isinstance(value, tnp.ndarray): 2176 return to_numpy_helper(value.tensor) 2177 elif isinstance(value, torch.Tensor): 2178 return value.numpy(force=True) 2179 elif isinstance(value, (tuple, list)): 2180 return type(value)(to_numpy_helper(obj) for obj in value) 2181 else: 2182 return value 2183 2184 2185def numpy_to_tensor(value): 2186 """Convert tnp.ndarray to tensor, leave other types intact. If a list/tuple, loop through it to convert.""" 2187 assert np is not None 2188 if isinstance(value, np.ndarray): 2189 return torch.as_tensor(value) 2190 if isinstance(value, tnp.ndarray): 2191 return value.tensor 2192 elif isinstance(value, (tuple, list)): 2193 return type(value)(numpy_to_tensor(obj) for obj in value) 2194 else: 2195 return value 2196 2197 2198class numpy_to_tensor_wrapper: 2199 def __init__(self, f): 2200 self.f = f 2201 self.__name__ = "wrapped_" + self.f.__name__ 2202 2203 def __repr__(self): 2204 return f"<Wrapped function <original {self.f.__name__}>>" 2205 2206 def __call__(self, *args, **kwargs): 2207 out = self.f(*args, **kwargs) 2208 return numpy_to_tensor(out) 2209 2210 2211def numpy_attr_wrapper(obj, name): 2212 if isinstance(obj, tnp.ndarray): 2213 out = getattr(obj, name) 2214 return numpy_to_tensor(out) 2215 elif isinstance(obj, torch.Tensor): 2216 out = getattr(tnp.ndarray(obj), name) 2217 return numpy_to_tensor(out) 2218 2219 2220class numpy_method_wrapper: 2221 """Convert obj from torch.Tensor to tnp.ndarray and call method. 
Then convert result back to torch.Tensor.""" 2222 2223 def __init__(self, method: str): 2224 self.method = method 2225 self.__name__ = "wrapped_" + self.method 2226 2227 def __repr__(self): 2228 return f"<Wrapped method <original {self.method}>>" 2229 2230 def __call__(self, *args, **kwargs): 2231 obj = args[0] 2232 if isinstance(obj, torch.Tensor): 2233 obj = tnp.ndarray(obj) 2234 method_callable = getattr(obj, self.method) 2235 out = method_callable(*args[1:], **kwargs) 2236 return numpy_to_tensor(out) 2237 2238 2239class numpy_operator_wrapper: 2240 """Implements dunder methods for tnp.ndarray via functions from the operator library""" 2241 2242 def __init__(self, op: Callable[..., Any]): 2243 self.op = op 2244 self.__name__ = f"wrapped_{op.__name__}" 2245 2246 def __repr__(self): 2247 return f"<Wrapped operator <original {self.__name__}>>" 2248 2249 def __call__(self, *args, **kwargs): 2250 assert not kwargs 2251 2252 args = ( 2253 tnp.ndarray(arg) if isinstance(arg, torch.Tensor) else arg for arg in args 2254 ) 2255 out = self.op(*args) 2256 return numpy_to_tensor(out) 2257 2258 2259def defake(x): 2260 if not isinstance(x, FakeTensor): 2261 return x 2262 size: torch._prims_common.ShapeType 2263 stride: torch._prims_common.StrideType 2264 if x._has_symbolic_sizes_strides: 2265 size = [] 2266 for s in x.size(): 2267 if isinstance(s, torch.SymInt): 2268 size.append(s.node.shape_env.size_hint(s.node.expr)) 2269 else: 2270 size.append(s) 2271 stride = [] 2272 for s in x.stride(): 2273 if isinstance(s, torch.SymInt): 2274 stride.append(s.node.shape_env.size_hint(s.node.expr)) 2275 else: 2276 stride.append(s) 2277 else: 2278 size = x.size() 2279 stride = x.stride() 2280 y = torch.empty_strided( 2281 size, 2282 stride, 2283 dtype=x.dtype, 2284 device=x.device, 2285 requires_grad=x.requires_grad, 2286 ) 2287 y.zero_() 2288 return y 2289 2290 2291def is_utils_checkpoint(obj): 2292 # Lazy import to avoid circular dependencies 2293 import torch.utils.checkpoint 2294 2295 return obj is torch.utils.checkpoint.checkpoint 2296 2297 2298def build_checkpoint_variable(**options): 2299 import torch._higher_order_ops.wrap as higher_order_ops 2300 from .variables.higher_order_ops import TorchHigherOrderOperatorVariable 2301 2302 # TODO - This is a temporary situation where we have two versions of 2303 # checkpointing implementation. We will converge on one and remove the other. 2304 activation_checkpoint_op: torch._ops.HigherOrderOperator = ( 2305 higher_order_ops.tag_activation_checkpoint 2306 ) 2307 if torch._functorch.config.functionalize_rng_ops: 2308 activation_checkpoint_op = higher_order_ops.wrap_activation_checkpoint 2309 2310 return TorchHigherOrderOperatorVariable.make( 2311 activation_checkpoint_op, 2312 **options, 2313 ) 2314 2315 2316def is_compile_supported(device_type): 2317 from .eval_frame import is_dynamo_supported 2318 2319 compile_supported = is_dynamo_supported() 2320 if device_type == "cpu": 2321 pass 2322 elif device_type == "cuda" and compile_supported: 2323 compile_supported = has_triton() 2324 else: 2325 compile_supported = False 2326 return compile_supported 2327 2328 2329# The following 3.11 source code functions are adapted from 2330# https://github.com/python/cpython/blob/v3.11.4/Lib/traceback.py 2331# in order to output source code corresponding to bytecode in 3.11+. 2332# We need our own versions since we want to support multiline expressions. 2333def _fix_offset(str: str, offset: int) -> int: 2334 """ 2335 Convert byte offset `offset` of `str` into character offset. 
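    For example, "→" occupies three bytes in UTF-8 but only one character, so a byte
    offset that falls after it is two larger than the corresponding character offset.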
2336 Byte offset is used for 3.11+ instruction column data. 2337 Takes things like unicode characters into consideration. 2338 2339 Unchanged from CPython implementation. 2340 """ 2341 as_utf8 = str.encode("utf-8") 2342 return len(as_utf8[:offset].decode("utf-8", errors="replace")) 2343 2344 2345@dataclasses.dataclass 2346class _Anchors: 2347 # inclusive 2348 left_end_lineno: int 2349 left_end_offset: int 2350 right_start_lineno: int 2351 # exclusive 2352 right_start_offset: int 2353 2354 2355def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]: 2356 """ 2357 Given source code `segment` corresponding to a bytecode 2358 instruction, determine: 2359 - for binary ops, the location of the binary op 2360 - for indexing, the location of the brackets. 2361 `segment` is expected to be a valid Python expression 2362 """ 2363 assert sys.version_info >= (3, 11) 2364 2365 import ast 2366 2367 try: 2368 # Without brackets, `segment` is parsed as a statement. 2369 # We expect an expression, so wrap `segment` in 2370 # brackets to handle multi-line expressions. 2371 tree = ast.parse("(\n" + segment + "\n)") 2372 except SyntaxError: 2373 return None 2374 2375 if len(tree.body) != 1: 2376 return None 2377 2378 lines = segment.split("\n") 2379 2380 # get character index given byte offset 2381 def normalize(lineno, offset): 2382 return _fix_offset(lines[lineno], offset) 2383 2384 # Gets the next valid character index in `lines`, if 2385 # the current location is not valid. Handles empty lines. 2386 def next_valid_char(lineno, col): 2387 while lineno < len(lines) and col >= len(lines[lineno]): 2388 col = 0 2389 lineno += 1 2390 assert lineno < len(lines) and col < len(lines[lineno]) 2391 return lineno, col 2392 2393 # Get the next valid character index in `lines`. 2394 def increment(lineno, col): 2395 col += 1 2396 lineno, col = next_valid_char(lineno, col) 2397 assert lineno < len(lines) and col < len(lines[lineno]) 2398 return lineno, col 2399 2400 # Get the next valid character at least on the next line 2401 def nextline(lineno, col): 2402 col = 0 2403 lineno += 1 2404 lineno, col = next_valid_char(lineno, col) 2405 assert lineno < len(lines) and col < len(lines[lineno]) 2406 return lineno, col 2407 2408 statement = tree.body[0] 2409 if isinstance(statement, ast.Expr): 2410 expr = statement.value 2411 if isinstance(expr, ast.BinOp): 2412 # ast gives locations for BinOp subexpressions, e.g. 2413 # ( left_expr ) + ( right_expr ) 2414 # left^^^^^ right^^^^^ 2415 # -2 since end_lineno is 1-indexed and because we added an extra 2416 # bracket to `segment` when calling ast.parse 2417 cur_lineno = cast(int, expr.left.end_lineno) - 2 2418 cur_col = normalize(cur_lineno, expr.left.end_col_offset) 2419 cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col) 2420 2421 # Heuristic to find the operator character. 2422 # The original CPython implementation did not look for ), \, or #, 2423 # leading to incorrect anchor location, e.g. 
2424 # (x) + (y) 2425 # ~~^~~~~~~ 2426 while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#": 2427 if ch in "\\#": 2428 cur_lineno, cur_col = nextline(cur_lineno, cur_col) 2429 else: 2430 cur_lineno, cur_col = increment(cur_lineno, cur_col) 2431 2432 # binary op is 1 or 2 characters long, on the same line 2433 right_col = cur_col + 1 2434 if ( 2435 right_col < len(lines[cur_lineno]) 2436 and not (ch := lines[cur_lineno][right_col]).isspace() 2437 and ch not in "\\#" 2438 ): 2439 right_col += 1 2440 # right_col can be invalid since it is exclusive 2441 2442 return _Anchors(cur_lineno, cur_col, cur_lineno, right_col) 2443 elif isinstance(expr, ast.Subscript): 2444 # ast gives locations for value and slice subexpressions, e.g. 2445 # ( value_expr ) [ slice_expr ] 2446 # value^^^^^ slice^^^^^ 2447 # subscript^^^^^^^^^^^^^^^^^^^^ 2448 # find left bracket (first '[' after value) 2449 left_lineno = cast(int, expr.value.end_lineno) - 2 2450 left_col = normalize(left_lineno, expr.value.end_col_offset) 2451 left_lineno, left_col = next_valid_char(left_lineno, left_col) 2452 while lines[left_lineno][left_col] != "[": 2453 left_lineno, left_col = increment(left_lineno, left_col) 2454 # find right bracket (final character of expression) 2455 right_lineno = cast(int, expr.end_lineno) - 2 2456 right_col = normalize(right_lineno, expr.end_col_offset) 2457 return _Anchors(left_lineno, left_col, right_lineno, right_col) 2458 elif isinstance(expr, ast.Call): 2459 # ( func_expr ) (args, kwargs) 2460 # func^^^^^ 2461 # call^^^^^^^^^^^^^^^^^^^^^^^^ 2462 # find left bracket (first '(' after func) 2463 left_lineno = cast(int, expr.func.end_lineno) - 2 2464 left_col = normalize(left_lineno, expr.func.end_col_offset) 2465 left_lineno, left_col = next_valid_char(left_lineno, left_col) 2466 while lines[left_lineno][left_col] != "(": 2467 left_lineno, left_col = increment(left_lineno, left_col) 2468 # find right bracket (final character of expression) 2469 right_lineno = cast(int, expr.end_lineno) - 2 2470 right_col = normalize(right_lineno, expr.end_col_offset) 2471 return _Anchors(left_lineno, left_col, right_lineno, right_col) 2472 2473 return None 2474 2475 2476def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> str: 2477 """ 2478 Python 3.11+ only. Returns lines of source code (from code object `code`) 2479 corresponding to `inst`'s location data, and underlines relevant code to `inst`. 2480 2481 Example: CALL on `g`: 2482 f(g( 2483 ^^ 2484 h(x))) 2485 ^^^^^ 2486 2487 We need our own implementation since `format_frame_summary` in 2488 Python's `traceback` module doesn't handle multi-line expressions 2489 (and their anchor extraction code is not completely correct). 2490 """ 2491 assert inst.positions is not None 2492 if inst.positions.lineno is None: 2493 return "" 2494 # The rstrip + "\n" pattern is used throughout this function to handle 2495 # linecache.getline errors. Error lines are treated as empty strings "", but we want 2496 # to treat them as blank lines "\n". 
2497 first_line = linecache.getline(code.co_filename, inst.positions.lineno).rstrip() 2498 if inst.positions.end_lineno is None: 2499 return first_line 2500 if inst.positions.col_offset is None or inst.positions.end_col_offset is None: 2501 return first_line 2502 2503 # character index of the start of the instruction 2504 start_offset = _fix_offset(first_line, inst.positions.col_offset) 2505 # character index of the end of the instruction 2506 # compute later since end may be a different line 2507 end_offset = None 2508 # expression corresponding to the instruction so we can get anchors 2509 segment = "" 2510 # underline markers to be printed - start with `~` marker and replace with `^` later 2511 markers = [] 2512 2513 # Compute segment and initial markers 2514 if inst.positions.end_lineno == inst.positions.lineno: 2515 end_offset = _fix_offset(first_line, inst.positions.end_col_offset) 2516 segment = first_line[start_offset:end_offset] 2517 markers.append(" " * start_offset + "~" * (end_offset - start_offset)) 2518 else: 2519 segment = first_line[start_offset:] + "\n" 2520 markers.append(" " * start_offset + "~" * (len(first_line) - start_offset)) 2521 last_line = linecache.getline( 2522 code.co_filename, inst.positions.end_lineno 2523 ).rstrip() 2524 end_offset = _fix_offset(last_line, inst.positions.end_col_offset) 2525 for lineno in range(inst.positions.lineno + 1, inst.positions.end_lineno): 2526 line = linecache.getline(code.co_filename, lineno).rstrip() 2527 segment += line + "\n" 2528 # don't underline leading spaces 2529 num_spaces = len(line) - len(line.lstrip()) 2530 markers.append(" " * num_spaces + "~" * (len(line) - num_spaces)) 2531 segment += last_line[:end_offset] 2532 num_spaces = len(last_line) - len(last_line.lstrip()) 2533 markers.append(" " * num_spaces + "~" * (end_offset - num_spaces)) 2534 2535 anchors: Optional[_Anchors] = None 2536 try: 2537 anchors = _extract_anchors_from_expr(segment) 2538 except AssertionError: 2539 pass 2540 2541 # replace `~` markers with `^` where necessary 2542 if anchors is None: 2543 markers = [marker.replace("~", "^") for marker in markers] 2544 else: 2545 # make markers mutable 2546 mutable_markers: List[List[str]] = [list(marker) for marker in markers] 2547 2548 # anchor positions do not take start_offset into account 2549 if anchors.left_end_lineno == 0: 2550 anchors.left_end_offset += start_offset 2551 if anchors.right_start_lineno == 0: 2552 anchors.right_start_offset += start_offset 2553 2554 # Turn `~`` markers between anchors to `^` 2555 for lineno in range(len(markers)): 2556 for col in range(len(mutable_markers[lineno])): 2557 if lineno < anchors.left_end_lineno: 2558 continue 2559 if lineno == anchors.left_end_lineno and col < anchors.left_end_offset: 2560 continue 2561 if ( 2562 lineno == anchors.right_start_lineno 2563 and col >= anchors.right_start_offset 2564 ): 2565 continue 2566 if lineno > anchors.right_start_lineno: 2567 continue 2568 if mutable_markers[lineno][col] == "~": 2569 mutable_markers[lineno][col] = "^" 2570 2571 # make markers into strings again 2572 markers = ["".join(marker) for marker in mutable_markers] 2573 2574 result = "" 2575 for i in range(len(markers)): 2576 result += ( 2577 linecache.getline(code.co_filename, inst.positions.lineno + i).rstrip() 2578 + "\n" 2579 ) 2580 result += markers[i] + "\n" 2581 return result 2582 2583 2584def get_static_address_type(t): 2585 if isinstance(t, torch.Tensor): 2586 return getattr(t, "_dynamo_static_input_type", None) 2587 2588 return None 2589 2590 2591def 
is_rng_state_getter_or_setter(value):
    getters = (
        # The following two functions are not identical, so don't remove either one!
        torch._C.Generator.get_state,
        torch.default_generator.get_state,
        torch.get_rng_state,
        torch.cuda.get_rng_state,
    )
    setters = (
        torch._C.Generator.set_state,
        torch.default_generator.set_state,
        torch.set_rng_state,
        torch.cuda.set_rng_state,
    )
    return value in (*setters, *getters)


def is_tensor_base_attr_getter(value):
    return (
        isinstance(value, types.MethodWrapperType)
        and value.__name__ == "__get__"
        and value.__self__.__objclass__ is torch._C._TensorBase  # type: ignore[attr-defined]
    )


def is_torch_function_object(value):
    return hasattr(value, "__torch_function__")


def has_torch_function(vt: "torch._dynamo.variables.base.VariableTracker") -> bool:
    from torch._dynamo.variables import LazyVariableTracker, UserDefinedObjectVariable
    from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable

    if isinstance(vt, TensorWithTFOverrideVariable):
        return True

    if isinstance(vt, LazyVariableTracker):
        LazyVariableTracker.realize(vt)

    return isinstance(vt, UserDefinedObjectVariable) and hasattr(
        vt.value, "__torch_function__"
    )


# see note [Tensor Fakification and Symbol Caching]
def to_fake_tensor(t, fake_mode):
    symbolic_context = None
    source = None
    if tracing_context := torch._guards.TracingContext.try_get():
        if t in tracing_context.tensor_to_context:
            symbolic_context = tracing_context.tensor_to_context[t]
            source = symbolic_context.tensor_source

    return fake_mode.from_tensor(
        t, static_shapes=False, symbolic_context=symbolic_context, source=source
    )


def get_first_attr(obj, *attrs):
    """
    Return the first available attribute, or raise an AssertionError if none is present.
    """
    for attr in attrs:
        if hasattr(obj, attr):
            return getattr(obj, attr)

    raise AssertionError(f"{obj} does not have any of the attributes: {attrs}")


@contextlib.contextmanager
def maybe_enable_compiled_autograd(should_enable):
    def compiler_fn(gm):
        def inner_compiler(gm_, example_inputs_):
            torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1
            return torch._inductor.compile(gm_, example_inputs_)

        return torch.compile(gm, backend=inner_compiler, fullgraph=True, dynamic=True)

    if should_enable:
        with torch._dynamo.compiled_autograd.enable(compiler_fn) as ctx:
            yield ctx
    else:
        yield


def invalid_removeable_handle():
    # need a subclass so weakref works
    class Invalid(dict):  # type: ignore[type-arg]
        pass

    return RemovableHandle(Invalid())


# Returns a "proxy" (new object with the same class and dict) for (non-GraphModule) nn.Module's.
# Attribute changes to the original object/proxy will be reflected in the other.
# This is useful for cases where we want a keep-alive reference to a module without increasing
# its reference count.
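# A minimal usage sketch (illustrative; assumes a plain, non-GraphModule nn.Module):
#
#     mod = torch.nn.Linear(2, 2)
#     proxy = nn_module_proxy(mod)
#     proxy.some_flag = True
#     assert mod.some_flag      # attribute is visible on both, via the shared __dict__
#     assert proxy is not mod   # but the proxy is a distinct object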
2688def nn_module_proxy(mod): 2689 if not isinstance(mod, torch.nn.Module): 2690 return mod 2691 if isinstance(mod, torch.fx.GraphModule): 2692 # Dynamo-generated GM's shouldn't contain user-created GM's 2693 return mod 2694 proxy = mod.__class__.__new__(mod.__class__) 2695 proxy.__dict__ = mod.__dict__ 2696 return proxy 2697 2698 2699class GmWrapper(torch.nn.Module): 2700 def __init__(self, gm, spec): 2701 super().__init__() 2702 self.gm = gm 2703 self.spec = spec 2704 2705 def forward(self, *args): 2706 args: List[Any] = list(args) 2707 return self.gm(*pytree.tree_unflatten(args, self.spec)) 2708 2709 2710def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): 2711 """ 2712 Mutate inputs so that they are flat and wrap gm such that it 2713 accepts those inputs. This is needed for graphs that take 2714 bumpy inputs. 2715 """ 2716 inputs, spec = pytree.tree_flatten(inputs) 2717 compiled_fn = compile_gm(GmWrapper(gm, spec), inputs) 2718 2719 idx_to_steal = [ 2720 i 2721 for i, node in enumerate(gm.graph.nodes) 2722 if node.op == "placeholder" and node.meta.get("steal_arg", False) 2723 ] 2724 2725 def wrapper(*args): 2726 # note this doesn't check the spec, assuming it is the same 2727 flat_args = pytree.arg_tree_leaves(*args) 2728 2729 # flat_args is a new list, so we need to clear references from the old list 2730 for i in idx_to_steal: 2731 args[i].clear() 2732 2733 # this call is boxed to avoid increasing refcount until we reach aot_module_simplified forward 2734 return compiled_fn(flat_args) 2735 2736 return wrapper 2737 2738 2739def get_locals_to_steal(maybe_gm): 2740 if not isinstance(maybe_gm, torch.fx.GraphModule) or not hasattr(maybe_gm, "meta"): 2741 return [] 2742 return maybe_gm.meta.get("locals_to_steal", []) 2743 2744 2745def set_locals_to_steal(gm, locals_to_steal): 2746 gm.meta["locals_to_steal"] = locals_to_steal 2747 2748 2749class Lit: 2750 def __init__(self, s): 2751 self.s = s 2752 2753 def __repr__(self): 2754 return self.s 2755 2756 2757warn_once_cache: Set[str] = set() 2758 2759 2760def warn_once(msg, stacklevel=1): 2761 # Dynamo causes all warnings.warn (in user code and in Dynamo code) to print all the time. 2762 # https://github.com/pytorch/pytorch/issues/128427. 2763 # warn_once is a workaround: if the msg has been warned on before, then we will not 2764 # warn again. 2765 # NB: it's totally ok to store a cache of all the strings: this is what warnings.warn does as well. 2766 if msg in warn_once_cache: 2767 return 2768 warn_once_cache.add(msg) 2769 warnings.warn(msg, stacklevel=stacklevel + 1) 2770
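# Example (illustrative): repeated calls with an identical message warn only once.
#
#     warn_once("foo() is experimental")  # emits a UserWarning
#     warn_once("foo() is experimental")  # no-op, message already in warn_once_cache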