# mypy: allow-untyped-defs
from __future__ import annotations

import collections
import contextlib
import dataclasses
import enum
import functools
import inspect
import io
import itertools
import logging
import math
import operator
import os
import platform
import shutil
import sys
import tempfile
import textwrap
import time
import unittest
from datetime import datetime
from io import StringIO
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Protocol,
    Sequence,
    Set,
    TypeVar,
    Union,
    ValuesView,
)
from typing_extensions import Concatenate, ParamSpec
from unittest import mock

import sympy

import torch


GPU_TYPES = ["cuda", "xpu"]


# Defined here, before torch._dynamo is imported, to avoid a circular import
# when get_gpu_type is imported from dynamo.
@functools.lru_cache(None)
def get_gpu_type():
    avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()]
    assert len(avail_gpus) <= 1
    gpu_type = "cuda" if len(avail_gpus) == 0 else avail_gpus.pop()
    return gpu_type


from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import detect_fake_mode
from torch.autograd import DeviceType
from torch.autograd.profiler_util import EventList
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.fx.passes.shape_prop import ShapeProp
from torch.utils._sympy.functions import (
    CeilDiv,
    CleanDiv,
    FloorDiv,
    Identity,
    ModularIndexing,
)
from torch.utils._sympy.symbol import make_symbol, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges

from . import config
from .runtime.runtime_utils import ceildiv as runtime_ceildiv


_IS_WINDOWS = sys.platform == "win32"

log = logging.getLogger(__name__)

_T = TypeVar("_T")
VarRanges = Dict[sympy.Expr, sympy.Expr]
InputType = Union[torch.Tensor, int]


GPU_ALIGN_BYTES = 16
ALIGNMENT = 16

ALIGN_BYTES = 64
assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2"


def _align(nbytes):
    """Round up to the nearest multiple of ALIGN_BYTES"""
    return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES


def _is_aligned(v: sympy.Expr):
    """v can be statically proven to be a multiple of ALIGN_BYTES"""
    if isinstance(v, (sympy.Add, sympy.Max)):
        return all(map(_is_aligned, v.args))
    return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES


class align(sympy.Function):
    """Symbolically round up to the nearest multiple of ALIGN_BYTES"""

    nargs = (1,)
    is_integer = True

    @classmethod
    def eval(cls, value):
        if isinstance(value, (int, sympy.Integer)):
            return _align(int(value))
        if _is_aligned(value):
            return value


def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float:
    """
    Returns benchmark results by examining torch profiler events.
    This can be more accurate as it doesn't count CPU-side overhead.
    However, it also requires manually excluding irrelevant events, e.g.
    the vectorized_elementwise_kernel used to fill the L2 cache and
    various CUDA events, so it can also be fragile.
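
    Example (illustrative; ``model`` and ``inp`` stand in for any CUDA workload):

        ms = do_bench_using_profiling(lambda: model(inp))

    The returned value is the mean device time per call, in milliseconds.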
    """

    fn()
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Estimate the runtime of the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # Warm-up
    for _ in range(n_warmup):
        fn()

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        # Benchmark
        for i in range(n_repeat):
            # we clear the L2 cache before each run
            cache.zero_()
            # record time of `fn`
            fn()
        # Record clocks
        torch.cuda.synchronize()

    log.debug("raw events")
    log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1))

    filtered_events = EventList(
        [
            event
            for event in p.events()
            if event.device_type == DeviceType.CUDA and event.name != "Context Sync"
        ]
    )
    if len(filtered_events) % n_repeat != 0:
        raise RuntimeError(
            "Failed to divide all profiling events into #repeat groups. "
            f"#CUDA events: {len(filtered_events)}, #repeats: {n_repeat}"
        )
    num_event_per_group = len(filtered_events) // n_repeat
    actual_events = EventList(
        [
            event
            for i, event in enumerate(filtered_events)
            if i % num_event_per_group != 0
        ]
    )
    actual_events._build_tree()
    actual_events = actual_events.key_averages()

    log.debug("profiling time breakdown")
    log.debug(actual_events.table(row_limit=-1))

    res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat
    log.debug("profiling results: %s ms", res)
    return res


@functools.lru_cache(None)
def has_torchvision_roi_align() -> bool:
    try:
        from torchvision.ops import roi_align  # noqa: F401

        torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta")
        return roi_align is not None and hasattr(
            getattr(torch.ops, "torchvision", None), "roi_align"
        )
    except ImportError:
        return False
    except RuntimeError as e:
        assert "torchvision::nms does not exist" in str(e)
        return False


def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
    if device is None:
        return torch.tensor(0.0).device  # default device
    if isinstance(device, str):
        device = torch.device(device)
    if device.type not in ("cpu", "meta") and device.index is None:
        device_interface = get_interface_for_device(device.type)
        return torch.device(device.type, index=device_interface.Worker.current_device())
    return device


def sympy_product(it):
    return functools.reduce(operator.mul, it, sympy.Integer(1))


def sympy_dot(seq1, seq2):
    assert len(seq1) == len(seq2)
    return sympy.expand(sum(a * b for a, b in zip(seq1, seq2)))


def unique(it: Iterable[_T]) -> ValuesView[_T]:
    return {id(x): x for x in it}.values()


def ceildiv(
    numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
) -> Union[int, sympy.Expr]:
    if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr):
        return CeilDiv(sympy.sympify(numer), sympy.sympify(denom))
    # TODO:
There is a bug in a call to this function, to repro: 252 # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy 253 # --amp --only YituTechConvBert --dynamic-shapes 254 assert isinstance(numer, int) and isinstance( 255 denom, int 256 ), f"{numer}: {type(numer)}, {denom}: {type(denom)}" 257 return runtime_ceildiv(numer, denom) 258 259 260def _type_of(key): 261 # Use the function here to get rid of dependencies on the Triton during the codegen. 262 # Refer to Triton implementation here: 263 # https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238 264 # `None` is nullptr. Implicitly convert to *i8. 265 if key is None: 266 return "*i8" 267 dtype_str = str(key).split(".")[-1] 268 tys = { 269 "bool": "i1", 270 "float8e4nv": "fp8e4nv", 271 "float8e5": "fp8e5", 272 "float8e4b15": "fp8e4b15", 273 "float8e4b15x4": "fp8e4b15x4", 274 "float8_e4m3fn": "fp8e4nv", 275 "float8_e5m2": "fp8e5", 276 "float16": "fp16", 277 "bfloat16": "bf16", 278 "float32": "fp32", 279 "float64": "fp64", 280 "int8": "i8", 281 "int16": "i16", 282 "int32": "i32", 283 "int64": "i64", 284 "uint8": "u8", 285 "uint16": "u16", 286 "uint32": "u32", 287 "uint64": "u64", 288 } 289 # reinterpret can create triton type 290 for v in list(tys.values()): 291 tys[v] = v 292 return key if isinstance(key, str) else f"*{tys[dtype_str]}" 293 294 295def convert_shape_to_inductor( 296 lst: Iterable[Union[int, torch.SymInt]] 297) -> List[sympy.Expr]: 298 """ 299 Gets the shape and stride of a tensor. For non-symbolic tensors, this is 300 trivial. But for symbolic tensors, we need to map from SymIntNode into 301 sympy.Expr. 302 """ 303 return [sympy.sympify(i) for i in lst] 304 305 306def convert_shape_to_symint( 307 lst: Iterable[Union[int, sympy.Expr]] 308) -> List[Union[int, torch.SymInt]]: 309 """ 310 Takes a list of shapes from Inductor and converts them into symints (or just 311 ints if all shapes are static). 
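
    Example (illustrative): [8, 16] is returned unchanged, while [s0, 16] with a
    symbolic s0 becomes [SymInt(s0), 16] via the graph's ShapeEnv.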
312 """ 313 from .virtualized import V 314 315 return [ 316 i 317 if isinstance(i, int) 318 else int(i) 319 if isinstance(i, sympy.Integer) 320 else V.graph.sizevars.shape_env.create_symintnode(i, hint=None) 321 for i in lst 322 ] 323 324 325def is_view(op: torch._ops.OpOverload): 326 """ 327 Does this op overload have aliasing 328 """ 329 assert isinstance(op, torch._ops.OpOverload) 330 return any(a.alias_info is not None for a in op._schema.arguments) 331 332 333def is_pointwise_use( 334 use, is_pointwise_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None 335): 336 """ 337 Do all uses of this op have torch.Tag.pointwise or return True for optional `is_pointwise_fn` 338 339 Uses in views ops will follow the views uses 340 """ 341 342 if not use.op == "call_function": 343 return False 344 345 if not ( 346 isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem 347 ): 348 return False 349 350 if use.target is operator.getitem or is_view(use.target): 351 return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users) 352 353 return torch.Tag.pointwise in use.target.tags or ( 354 is_pointwise_fn is not None and is_pointwise_fn(use.target) 355 ) 356 357 358def gen_gm_and_inputs(target, args, kwargs): 359 g = torch.fx.Graph() 360 g_args = [] 361 a_args = [] 362 for n, arg in enumerate(args): 363 if isinstance(arg, torch.Tensor): 364 g_args.append(g.placeholder(f"arg{n}")) 365 a_args.append(arg) 366 else: 367 g_args.append(arg) 368 assert all(not isinstance(x, torch.Tensor) for x in kwargs.values()) 369 node = g.call_function(target, tuple(g_args), kwargs) 370 if ( 371 len(target._schema.returns) == 1 372 and str(target._schema.returns[0].type) == "Tensor" 373 ): 374 node = (node,) # type: ignore[assignment] 375 g.output(node) 376 377 gm = torch.fx.GraphModule({}, g) 378 return gm, a_args 379 380 381def synchronize(device: str = "cuda"): 382 if device == "cpu": 383 return 384 device_interface = get_interface_for_device(device) 385 if device_interface.is_available(): 386 device_interface.synchronize() 387 388 389def timed( 390 model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda" 391) -> float: 392 synchronize(device) 393 torch.manual_seed(1337) 394 t0 = time.perf_counter() 395 for _ in range(times): 396 result = model(*example_inputs) 397 synchronize(device) 398 t1 = time.perf_counter() 399 # GC the result after timing 400 assert result is not None # type: ignore[possibly-undefined] 401 return t1 - t0 402 403 404def print_performance( 405 fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda" 406): 407 timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)]) 408 took = torch.median(timings) / times 409 print(f"{took / baseline:.6f}") 410 return took 411 412 413def precompute_method(obj: Any, method: str): 414 """Replace obj.method() with a new method that returns a precomputed constant.""" 415 result = getattr(obj, method)() 416 setattr(obj, method, lambda: result) 417 418 419def precompute_methods(obj: Any, methods: List[str]): 420 """Replace methods with new methods that returns a precomputed constants.""" 421 for method in methods: 422 precompute_method(obj, method) 423 424 425def cmp(a, b) -> int: 426 return int(a > b) - int(a < b) 427 428 429def pad_listlike(x, size): 430 if len(x) == 1: 431 return type(x)([x[0]]) * size 432 else: 433 return x 434 435 436# Used to ensure that iterating over a set is deterministic 437def tuple_sorted(x): 438 if len(x) == 0: 439 return [] 440 441 def 
sort_func(elem): 442 if isinstance(elem, str): 443 return elem 444 else: 445 # We expect `elem` to be `scheduler.BaseSchedulerNode` type here, 446 # but we are not able to do isinstance assert because of circular dependency 447 return elem.get_name() 448 449 return sorted(x, key=sort_func) 450 451 452P = ParamSpec("P") 453RV = TypeVar("RV", covariant=True) 454 455 456class CachedMethod(Protocol, Generic[P, RV]): 457 @staticmethod 458 def clear_cache(self) -> None: 459 ... 460 461 def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: 462 ... 463 464 465# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature 466def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]: 467 key = f"__{fn.__name__}_cache" 468 469 @functools.wraps(fn) 470 def wrapper(self): 471 if not hasattr(self, key): 472 setattr(self, key, fn(self)) 473 return getattr(self, key) 474 475 def clear_cache(self): 476 if hasattr(self, key): 477 delattr(self, key) 478 479 wrapper.clear_cache = clear_cache # type: ignore[attr-defined] 480 return wrapper # type: ignore[return-value] 481 482 483def aggregate_origins(node_schedule): 484 from . import ir 485 486 if isinstance(node_schedule, list): 487 return functools.reduce( 488 operator.or_, 489 [ 490 node.node.origins 491 for node in node_schedule 492 if hasattr(node, "node") and node.node 493 ], 494 set(), 495 ) 496 elif isinstance(node_schedule, ir.ExternKernel): 497 return node_schedule.origins 498 else: 499 return set() 500 501 502def get_fused_kernel_name(node_schedule, descriptive_names): 503 all_origins = aggregate_origins(node_schedule) 504 if descriptive_names == "original_aten": 505 # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions) 506 sources = [ 507 origin.meta["original_aten"]._overloadpacket.__name__ 508 for origin in all_origins 509 if origin.op == "call_function" 510 and "original_aten" in origin.meta 511 and origin.meta["original_aten"] is not None 512 ] 513 sources = sorted(set(sources)) 514 elif descriptive_names == "torch": 515 # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph) 516 sources = [] 517 for origin in all_origins: 518 if origin.op == "call_function" and "source_fn_stack" in origin.meta: 519 source_fn = origin.meta["source_fn_stack"][-1] 520 if isinstance(source_fn[1], str): 521 sources.append(source_fn[1]) 522 else: 523 sources.append(source_fn[1].__name__) 524 sources = sorted(set(sources)) 525 elif descriptive_names == "inductor_node": 526 sources = [ 527 origin.name for origin in all_origins if origin.op == "call_function" 528 ] 529 else: 530 raise NotImplementedError 531 sources = sources 532 return "_".join(["fused"] + sources) 533 534 535def get_kernel_metadata(node_schedule, wrapper): 536 all_origins = aggregate_origins(node_schedule) 537 inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"] 538 539 from_node_dict = collections.defaultdict(list) 540 original_aten_dict = collections.defaultdict(list) 541 542 # Attempt to sort `inductor_nodes` topologically. Note that the case 543 # where `inductor_nodes` contains nodes from multiple graph instances 544 # is not supported. An example of this is conditional statements. 
    single_graph = None
    if len(inductor_nodes):
        unique_graphs = {n.graph for n in inductor_nodes}
        if len(unique_graphs) == 1:
            single_graph = inductor_nodes[0].graph
            # create a map of idx -> node and cache it
            if not hasattr(single_graph, "_inductor_kernel_metadata_node_to_idx_map"):
                node_to_idx_map = {}
                for idx, n in enumerate(single_graph.nodes):
                    node_to_idx_map[n] = idx
                single_graph._inductor_kernel_metadata_node_to_idx_map = node_to_idx_map
            inductor_nodes.sort(
                key=lambda n: single_graph._inductor_kernel_metadata_node_to_idx_map[n]
            )

    for node in inductor_nodes:
        if "original_aten" in node.meta and node.meta["original_aten"] is not None:
            key = str(node.meta["original_aten"]._overloadpacket)
            original_aten_dict[key].append(node.name)
        if "from_node" in node.meta:
            key = node.meta["from_node"][0][0]
            from_node_dict[key].append(node.name)
    sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted"
    metadata = (
        f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], "
        f"Original ATen: [{', '.join(original_aten_dict.keys())}]"
    )

    # trace back to original node here
    detailed_metadata = [f"{wrapper.comment} Source node to ATen node mapping:"]
    for original_node, nodes in sorted(from_node_dict.items()):
        detailed_metadata.append(
            f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}"
        )

    # print the aot_autograd graph fragment
    if single_graph is not None:
        detailed_metadata.append(f"{wrapper.comment} Graph fragment:")
        for n in inductor_nodes:
            # TODO(future): maybe refactor torch/fx/graph.py to make it easy to
            # generate python code for graph fragments
            detailed_metadata.append(f"{wrapper.comment} {n.format_node()}")

    return metadata, "\n".join(detailed_metadata)


def dominated_nodes(
    initial_queue: Iterable[torch.fx.Node], skip_filter=None
) -> Set[torch.fx.Node]:
    """Returns the set of nodes whose values depend on those within initial_queue"""
    initial_queue = list(initial_queue)
    dominated_set = set(initial_queue)

    while initial_queue:
        node = initial_queue.pop()
        for user in node.users:
            if skip_filter and skip_filter(user):
                continue
            if user not in dominated_set:
                dominated_set.add(user)
                initial_queue.append(user)

    return dominated_set


def gather_origins(args, kwargs):
    import itertools

    from . import ir

    def is_unrealized_node(n):
        if isinstance(n, ir.TensorBox):
            return is_unrealized_node(n.data)
        if isinstance(n, ir.StorageBox):
            return is_unrealized_node(n.data)
        return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)

    kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)]
    arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)]
    return set(itertools.chain(*arg_origins, *kwarg_origins))


def sympy_str(expr: sympy.Expr) -> str:
    """
    Normal sympy str is very slow; this is a lot faster. The results are
    somewhat worse, as it doesn't do as much simplification, so don't
    use this for final codegen.
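
    Example (illustrative): for FloorDiv(x + y, 2) this returns the string
    "FloorDiv(x + y, 2)" without invoking sympy's printer.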
    """
    if isinstance(expr, sympy.Symbol):
        return expr.name
    if isinstance(expr, sympy.Add):
        return " + ".join(map(sympy_str, expr.args))
    if isinstance(expr, sympy.Mul):
        return " * ".join(map(sympy_str, expr.args))

    if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)):
        return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
    return str(expr)


def get_bounds_index_expr(index):
    from .virtualized import V

    # If this expression does not come from an FX node, we compute its bounds
    if (
        config.compute_all_bounds
        and (fx_node := getattr(V.interpreter, "current_node", None))
        and fx_node.target != "index_expr"
    ):
        return bound_sympy(index)
    else:
        return ValueRanges.unknown()


def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
    """
    Used to generate an integer-nonnegative symbol.
    """
    # This should never be used for creating shape/stride symbols, as those
    # should all be allocated before Inductor.
    assert prefix != SymT.SIZE
    # NOTE: shape symbols are positive (> 0), but index variables are only
    # non-negative (>= 0).
    return make_symbol(prefix, idx, integer=True, nonnegative=True)


def generate_assert(check):
    return (check or config.debug_index_asserts) and config.assert_indirect_indexing


def sympy_index_symbol(name: str) -> sympy.Symbol:
    """
    Used to generate an integer-nonnegative symbol.
    """
    # This should never be used for creating shape/stride symbols, as those
    # should all be allocated before Inductor.
    assert name[0] != "s"
    # NOTE: shape symbols are positive (> 0), but index variables are only
    # non-negative (>= 0).
    return sympy.Symbol(name, integer=True, nonnegative=True)


def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr:
    """
    When a replacement value v is a string, it is converted to a symbol named v
    that carries the integer and nonnegative assumptions of the expression it
    replaces.
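
    Example (illustrative): sympy_subs(x + 1, {x: "y"}) yields y + 1, where the
    fresh symbol y inherits x's integer/nonnegative assumptions.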
691 """ 692 693 def to_symbol(replaced, replacement): 694 assert isinstance(replaced, sympy.Expr) 695 if isinstance(replacement, str): 696 return sympy.Symbol( 697 replacement, 698 integer=replaced.is_integer, # type: ignore[attr-defined] 699 nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined] 700 ) 701 else: 702 return replacement 703 704 # xreplace is faster than subs, but is way more picky 705 return sympy.sympify(expr).xreplace( 706 {k: to_symbol(k, v) for k, v in replacements.items()} 707 ) 708 709 710def is_symbolic(a: Any) -> bool: 711 return isinstance(a, torch.SymInt) or ( 712 isinstance(a, torch.Tensor) 713 and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride())) 714 ) 715 716 717def any_is_symbolic(*args: Any) -> bool: 718 return any(is_symbolic(a) for a in args) 719 720 721def get_first_incompatible_cudagraph_node(gm): 722 from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols 723 724 forbidden_set = { 725 "aten._fused_moving_avg_obs_fq_helper.default", 726 "aten._fused_moving_avg_obs_fq_helper_functional.default", 727 "aten.multinomial.default", 728 "fbgemm.dense_to_jagged.default", 729 "fbgemm.jagged_to_padded_dense.default", 730 "run_and_save_rng_state", 731 "run_with_rng_state", 732 "aten._local_scalar_dense", 733 # Technically, it's not necessary to ban this, because an 734 # assert_scalar with constant arguments can be validly run 735 # with CUDA graphs, but the operator is also pointless with 736 # constant arguments, so might as well ban 737 "aten._assert_scalar", 738 } 739 if torch.are_deterministic_algorithms_enabled(): 740 forbidden_set.update( 741 { 742 "aten._unsafe_index_put.default", 743 "aten._unsafe_masked_index_put_accumulate.default", 744 "aten.index_put.default", 745 "aten.index_put_.default", 746 "aten.scatter.src", 747 "aten.scatter.reduce", 748 "aten.scatter.value_reduce", 749 "aten.scatter_add_", 750 "aten.scatter_add.default", 751 "aten.scatter_reduce.two", 752 "aten.scatter_reduce_.two", 753 "aten.scatter_reduce.two_out", 754 } 755 ) 756 for node in gm.graph.nodes: 757 if str(node.target) in forbidden_set: 758 return node 759 if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val): 760 return node 761 return None 762 763 764def has_incompatible_cudagraph_ops(gm): 765 return get_first_incompatible_cudagraph_node(gm) is not None 766 767 768def output_node(gm: torch.fx.GraphModule): 769 """Get the output node from an FX graph""" 770 last_node = next(iter(reversed(gm.graph.nodes))) 771 assert last_node.op == "output" 772 return last_node 773 774 775_registered_caches: List[Any] = [] 776 777 778def clear_on_fresh_inductor_cache(obj: Any): 779 """ 780 Use this decorator to register any caches that should be cache_clear'd 781 with fresh_inductor_cache(). 782 """ 783 if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear): 784 raise AttributeError(f"{obj} does not have a cache_clear method") 785 786 _registered_caches.append(obj) 787 return obj 788 789 790def clear_inductor_caches(): 791 """ 792 Clear all registered caches. 793 """ 794 for obj in _registered_caches: 795 obj.cache_clear() 796 797 798@contextlib.contextmanager 799def fresh_inductor_cache(cache_entries=None, dir=None, delete=True): 800 """ 801 Contextmanager that provides a clean tmp cachedir for inductor. 802 803 Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes 804 generated with this cache instance. 
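
    Example (illustrative; `fn` and `args` are placeholders):

        cache_info: Dict[str, int] = {}
        with fresh_inductor_cache(cache_entries=cache_info):
            torch.compile(fn)(*args)
        # cache_info now maps Triton cache filenames to their sizes.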
805 """ 806 clear_inductor_caches() 807 808 inductor_cache_dir = tempfile.mkdtemp(dir=dir) 809 try: 810 with mock.patch.dict( 811 os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} 812 ): 813 log.debug("Using inductor cache dir %s", inductor_cache_dir) 814 triton_cache_dir = os.path.join(inductor_cache_dir, "triton") 815 with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): 816 yield 817 if isinstance(cache_entries, dict): 818 assert len(cache_entries) == 0, "expected empty cache_entries dict" 819 if os.path.exists(triton_cache_dir): 820 files = os.listdir(triton_cache_dir) 821 cache_entries.update( 822 { 823 f: os.path.getsize(os.path.join(triton_cache_dir, f)) 824 for f in files 825 if ".lock" not in f 826 } 827 ) 828 if delete: 829 shutil.rmtree(inductor_cache_dir) 830 except Exception: 831 if not _IS_WINDOWS: 832 """ 833 Windows can't delete the loaded modules, because the modules binaries are opened. 834 TODO: discuss if have better solution to handle this issue. 835 """ 836 log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir) 837 raise 838 finally: 839 clear_inductor_caches() 840 841 842def argsort(seq) -> List[int]: 843 # preserve original order for equal strides 844 getter = seq.__getitem__ 845 a_r = range(len(seq)) 846 return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413 847 848 849@functools.lru_cache(8) 850def get_dtype_size(dtype): 851 return torch.empty((), dtype=dtype).element_size() 852 853 854class LineContext(NamedTuple): 855 context: Any 856 857 858class IndentedBuffer: 859 tabwidth = 4 860 861 def __init__(self, initial_indent=0): 862 self._lines = [] 863 self._indent = initial_indent 864 865 def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]: 866 buf = StringIO() 867 p = 1 868 linemap = [] 869 for line in self._lines: 870 if isinstance(line, DeferredLineBase): 871 line = line() 872 if line is None: 873 continue 874 elif isinstance(line, LineContext): 875 linemap.append((p, line.context)) 876 continue 877 assert isinstance(line, str) 878 buf.write(line) 879 buf.write("\n") 880 p += 1 + line.count("\n") 881 return buf.getvalue(), linemap 882 883 def getvalue(self) -> str: 884 v, _ = self.getvaluewithlinemap() 885 return v 886 887 def getrawvalue(self) -> str: 888 buf = StringIO() 889 for line in self._lines: 890 if isinstance(line, DeferredLineBase): 891 line = line() 892 if line is None: 893 continue 894 elif isinstance(line, LineContext): 895 continue 896 assert isinstance(line, str) 897 # backslash implies line continuation 898 if line.endswith("\\"): 899 buf.write(line[:-1]) 900 else: 901 buf.write(line) 902 buf.write("\n") 903 return buf.getvalue() 904 905 def clear(self): 906 self._lines.clear() 907 908 def __bool__(self): 909 return bool(self._lines) 910 911 def prefix(self): 912 return " " * (self._indent * self.tabwidth) 913 914 def newline(self): 915 self.writeline("\n") 916 917 def writeline(self, line): 918 if isinstance(line, LineContext): 919 self._lines.append(line) 920 elif isinstance(line, DeferredLineBase): 921 self._lines.append(line.with_prefix(self.prefix())) 922 elif line.strip(): 923 self._lines.append(f"{self.prefix()}{line}") 924 else: 925 self._lines.append("") 926 927 def writelines(self, lines): 928 for line in lines: 929 self.writeline(line) 930 931 def indent(self, offset=1): 932 @contextlib.contextmanager 933 def ctx(): 934 self._indent += offset 935 try: 936 yield 937 finally: 938 self._indent -= offset 939 940 return ctx() 941 942 def 
do_indent(self, offset=1): 943 self._indent += offset 944 945 def do_unindent(self, offset=1): 946 self._indent -= offset 947 948 def splice(self, other_code, strip=False): 949 if isinstance(other_code, IndentedBuffer): 950 dedent = float("inf") 951 for line in other_code._lines: 952 if not isinstance(line, LineContext) and line: 953 dedent = min(dedent, len(line) - len(line.lstrip())) 954 if math.isinf(dedent): 955 dedent = 0 956 for line in other_code._lines: 957 if isinstance(line, LineContext): 958 self._lines.append(line) 959 else: 960 IndentedBuffer.writeline(self, line[int(dedent) :]) 961 else: 962 other_code = textwrap.dedent(other_code) 963 if strip: 964 other_code = other_code.lstrip() 965 if not other_code: 966 return 967 other_code = other_code.rstrip() 968 for line in other_code.split("\n"): 969 self.writeline(line) 970 971 def map(self, func: Callable[[Any], Any]) -> IndentedBuffer: 972 res = IndentedBuffer(initial_indent=self._indent) 973 res._lines = [func(line) for line in self._lines] 974 return res 975 976 def __repr__(self): 977 return f"{type(self)}({self.getvalue()})" 978 979 def __add__(self, other): 980 assert self._indent == other._indent 981 res = IndentedBuffer(initial_indent=self._indent) 982 res.writelines(self._lines) 983 res.writelines(other._lines) 984 return res 985 986 987class FakeIndentedBuffer(IndentedBuffer): 988 def __init__(self) -> None: 989 super().__init__() 990 991 def __getattribute__(self, name): 992 if name == "__class__": # Allow access to the class attribute 993 return object.__getattribute__(self, name) 994 raise RuntimeError( 995 f"Tried to call self.{name} on FakeIndentedBuffer. This buffer" 996 "is currently used on TritonTemplateKernel to prevent actual" 997 "writes to the body without explicitly specifying the body with" 998 "`TritonTemplateKernel.set_subgraph_body(name)`" 999 ) 1000 1001 1002@contextlib.contextmanager 1003def restore_stdout_stderr(initial_stdout, initial_stderr): 1004 try: 1005 yield 1006 finally: 1007 sys.stdout = initial_stdout 1008 sys.stderr = initial_stderr 1009 1010 1011class DeferredLineBase: 1012 """A line that can be 'unwritten' at a later time""" 1013 1014 def __init__(self, line): 1015 if not line.strip(): 1016 line = "" 1017 self.line = line 1018 1019 def __call__(self) -> Optional[str]: 1020 """Returns either self.line or None to indicate the line has been 'unwritten'""" 1021 raise NotImplementedError 1022 1023 def _new_line(self, line: str) -> DeferredLineBase: 1024 """Returns a new deferred line with the same condition""" 1025 raise NotImplementedError 1026 1027 def with_prefix(self, prefix): 1028 return self._new_line(f"{prefix}{self.line}") 1029 1030 def lstrip(self): 1031 return self._new_line(self.line.lstrip()) 1032 1033 def __getitem__(self, index): 1034 return self._new_line(self.line[index]) 1035 1036 def __bool__(self): 1037 return bool(self.line) 1038 1039 def __len__(self): 1040 return len(self.line) 1041 1042 1043@functools.lru_cache(None) 1044def is_big_gpu(index) -> bool: 1045 min_sms = 68 # 3080 1046 avail_sms = torch.cuda.get_device_properties(index).multi_processor_count 1047 if avail_sms < min_sms: 1048 log.warning( 1049 "Not enough SMs to use max_autotune_gemm mode", 1050 extra={"min_sms": min_sms, "avail_sms": avail_sms}, 1051 ) 1052 return False 1053 return True 1054 1055 1056def use_max_autotune() -> bool: 1057 return config.max_autotune or config.max_autotune_gemm 1058 1059 1060def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool: 1061 return ( 1062 
use_max_autotune() 1063 and layout.device.type == "cuda" 1064 and layout.dtype in allowed_layout_dtypes 1065 and is_big_gpu(layout.device.index or 0) 1066 ) 1067 1068 1069def _use_autotune_backend(backend: str) -> bool: 1070 return backend.upper() in [ 1071 x.strip() for x in config.max_autotune_gemm_backends.upper().split(",") 1072 ] 1073 1074 1075def _use_conv_autotune_backend(backend: str) -> bool: 1076 return backend.upper() in [ 1077 x.strip() for x in config.max_autotune_conv_backends.upper().split(",") 1078 ] 1079 1080 1081def use_triton_template(layout, *, enable_int32=False, enable_float8=False): 1082 from .codegen.common import BackendFeature, has_backend_feature 1083 1084 layout_dtypes = [torch.float16, torch.bfloat16, torch.float32] 1085 if enable_int32: 1086 layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] 1087 if enable_float8: 1088 layout_dtypes.extend([torch.float8_e4m3fn, torch.float8_e5m2]) 1089 return ( 1090 _use_template_for_cuda(layout, layout_dtypes) 1091 and _use_autotune_backend("TRITON") 1092 and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES) 1093 ) 1094 1095 1096def use_cutlass_template(layout, m, n, k): 1097 from .virtualized import V 1098 1099 gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) 1100 if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size: 1101 return False 1102 from .codegen.cuda.cutlass_utils import try_import_cutlass 1103 1104 # Do not use cutlass template on ROCm 1105 if torch.version.hip: 1106 return False 1107 1108 layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] 1109 res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend( 1110 "CUTLASS" 1111 ) 1112 1113 if res: 1114 if not try_import_cutlass(): 1115 log.warning( 1116 "Failed to import CUTLASS lib. Please check whether " 1117 "_inductor.config.cuda.cutlass_dir is set correctly. " 1118 "Skipping CUTLASS backend for now." 
1119 ) 1120 return False 1121 return res 1122 1123 1124@functools.lru_cache(None) 1125def _rocm_native_device_arch_name(device): 1126 return torch.cuda.get_device_properties(device).gcnArchName 1127 1128 1129@functools.lru_cache(None) 1130def try_import_ck_lib(): 1131 try: 1132 import ck4inductor # type: ignore[import] 1133 from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import] 1134 gen_ops_library, 1135 gen_ops_preselected, 1136 ) 1137 from ck4inductor.universal_gemm.op import ( # type: ignore[import] 1138 CKGemmOperation, 1139 ) 1140 1141 package_dirname = os.path.dirname(ck4inductor.__file__) 1142 except ImportError: 1143 1144 def gen_ops_library(): 1145 return [] 1146 1147 def gen_ops_preselected(): 1148 return [] 1149 1150 class CKGemmOperation: # type: ignore[no-redef] 1151 pass 1152 1153 package_dirname = None 1154 return package_dirname, gen_ops_library, gen_ops_preselected, CKGemmOperation 1155 1156 1157def use_ck_template(layout, m, n, k): 1158 # config knobs check 1 1159 if not use_max_autotune(): 1160 return False 1161 # config knobs check 2 1162 if not _use_autotune_backend("CK"): 1163 return False 1164 # platform check 1165 if not torch.version.hip: 1166 return False 1167 # tensors must be on GPU 1168 if not layout.device.type == "cuda": 1169 return False 1170 # hardware check 1171 # if config arch list is not specified, get the native arch from the device properties 1172 native_arch = _rocm_native_device_arch_name(layout.device) 1173 requested_archs = {k.split(":")[0]: k for k in config.rocm.arch} or { 1174 native_arch.split(":")[0]: native_arch 1175 } 1176 requested_supported_archs = [ 1177 requested_archs[k] 1178 for k in requested_archs.keys() & config.rocm.ck_supported_arch 1179 ] 1180 if not requested_supported_archs: 1181 return False 1182 # supported input dtypes 1183 if layout.dtype not in [torch.float16, torch.bfloat16]: 1184 return False 1185 # TBD: investigate if we need to disable backend based on number of available CUs similar to `is_big_gpu` 1186 # check if shape is static and gemm size is not 0 1187 from .virtualized import V 1188 1189 gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) 1190 if gemm_size <= 0: 1191 return False 1192 # TBD: investigate if backend needs to be disabled for small gemms similar to CUTLASS 1193 1194 ck_package_dirname, _, _, _ = try_import_ck_lib() 1195 1196 if not ck_package_dirname: 1197 log.warning("Please pip install Composable Kernel package") 1198 return False 1199 1200 if not config.rocm.ck_dir: 1201 log.warning("Please set TORCHINDUCTOR_CK_DIR env variable") 1202 return False 1203 1204 if ck_package_dirname != config.rocm.ck_dir: 1205 log.warning("Invalid path to CK library") 1206 return False 1207 1208 return True 1209 1210 1211def _use_template_for_cpu(layout): 1212 return use_max_autotune() and layout.device.type == "cpu" 1213 1214 1215def use_cpp_packed_gemm_template(layout, mat1, mat2, mat2_transposed=False): 1216 from . 
import ir 1217 from .codegen.cpp_micro_gemm import create_micro_gemm 1218 from .codegen.cpp_utils import get_gemm_template_output_and_compute_dtype 1219 from .kernel.mm_common import mm_args 1220 1221 if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"): 1222 return False 1223 1224 if not config.cpp.weight_prepack: 1225 return False 1226 1227 int8_gemm = mat1.get_dtype() == torch.uint8 1228 layout_dtypes = [torch.float32, torch.bfloat16, torch.half, torch.uint8] 1229 m, n, k, layout, mat1, mat2 = mm_args( 1230 mat1, 1231 mat2, 1232 out_dtype=layout.dtype if int8_gemm else None, 1233 mat2_transposed=mat2_transposed, 1234 ) 1235 1236 # TODO(jgong5): support dynamic shapes for n or k 1237 if has_free_symbols((n, k)): 1238 return False 1239 if isinstance(mat2, ir.BaseView): 1240 mat2 = mat2.unwrap_view() 1241 1242 output_dtype, _ = get_gemm_template_output_and_compute_dtype(mat1.get_dtype()) 1243 micro_gemm = create_micro_gemm( 1244 "micro_gemm", 1245 m, 1246 n, 1247 k, 1248 input_dtype=mat1.get_dtype(), 1249 input2_dtype=mat2.get_dtype(), 1250 output_dtype=output_dtype, 1251 num_threads=parallel_num_threads(), 1252 ) 1253 1254 def is_last_dim_stride1(x): 1255 x.freeze_layout() 1256 return x.get_stride()[-1] == 1 1257 1258 return ( 1259 layout.dtype in layout_dtypes 1260 and micro_gemm is not None 1261 and is_last_dim_stride1(mat1) # TODO(jgong5): support transposed input 1262 and isinstance(mat2, ir.StorageBox) 1263 and mat2.is_module_buffer() 1264 ) 1265 1266 1267def use_aten_gemm_kernels(): 1268 return not use_max_autotune() or _use_autotune_backend("ATEN") 1269 1270 1271class DebugDirManager: 1272 counter = itertools.count(0) 1273 prev_debug_name: str 1274 1275 def __init__(self) -> None: 1276 self.id = next(DebugDirManager.counter) 1277 1278 def __enter__(self): 1279 self.prev_debug_name = torch._dynamo.config.debug_dir_root 1280 self.new_name = f"{self.prev_debug_name}_tmp_{self.id}" 1281 torch._dynamo.config.debug_dir_root = self.new_name 1282 1283 def __exit__(self, *args): 1284 shutil.rmtree(self.new_name) 1285 torch._dynamo.config.debug_dir_root = self.prev_debug_name 1286 1287 1288def run_and_get_code(fn, *args, **kwargs): 1289 from .graph import GraphLowering 1290 1291 source_codes: List[str] = [] 1292 1293 def save_output_code(code: str): 1294 source_codes.append(code) 1295 1296 with mock.patch.object(GraphLowering, "save_output_code", save_output_code): 1297 torch._dynamo.reset() 1298 result = fn(*args, **kwargs) 1299 return result, source_codes 1300 1301 1302def run_fw_bw_and_get_code(fn): 1303 def run_with_backward(): 1304 result = fn() 1305 result.sum().backward() 1306 return result 1307 1308 return run_and_get_code(run_with_backward) 1309 1310 1311def get_code(fn, *args, **kwargs): 1312 """Get the inductor-generated code, but skip any actual compilation or running.""" 1313 from .graph import GraphLowering 1314 1315 source_codes: List[str] = [] 1316 1317 def save_output_code(code: str): 1318 source_codes.append(code) 1319 1320 def patched_compile_to_module(self: GraphLowering): 1321 class DummyModule: 1322 """This is empty to replace the generated triton module""" 1323 1324 def __init__(self) -> None: 1325 pass 1326 1327 def call(self, *args, **kwargs): 1328 # Don't do anything when called 1329 pass 1330 1331 code, _ = ( 1332 self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() 1333 ) 1334 # Skip all the actual compiling. 
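        # Instead, capture the generated source and hand back a DummyModule stub.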
        nonlocal save_output_code
        save_output_code(code)

        return DummyModule()

    with mock.patch.object(
        GraphLowering, "compile_to_module", patched_compile_to_module
    ), mock.patch.object(GraphLowering, "save_output_code", save_output_code):
        torch._dynamo.reset()
        # Note the return here is None
        _ = fn(*args, **kwargs)

    return source_codes


def get_triton_code(fn, *args, **kwargs):
    source_codes = get_code(fn, *args, **kwargs)
    # Can have two outputs if backwards was eagerly compiled
    assert (
        1 <= len(source_codes) <= 2
    ), f"expected one or two code outputs got {len(source_codes)}"
    return source_codes[0]


def run_and_get_triton_code(fn, *args, **kwargs):
    _, source_codes = run_and_get_code(fn, *args, **kwargs)
    # Can have two outputs if backwards was eagerly compiled
    assert (
        1 <= len(source_codes) <= 2
    ), f"expected one or two code outputs got {len(source_codes)}"
    return source_codes[0]


def run_and_get_graph_lowering(fn, *args, **kwargs):
    from torch._inductor.codecache import CompiledFxGraph
    from torch._inductor.graph import GraphLowering

    real_init = CompiledFxGraph.__init__
    graph_lowerings = []

    def fake_init(*args, **kwargs):
        real_init(*args, **kwargs)
        graph = args[2]
        assert isinstance(graph, GraphLowering)
        graph_lowerings.append(graph)

    with mock.patch.object(CompiledFxGraph, "__init__", fake_init):
        result = fn(*args, **kwargs)

    return result, graph_lowerings


@contextlib.contextmanager
def override_lowering(aten_op, override_fn):
    """
    Override the lowering of aten_op with override_fn.
    The first argument of override_fn is the original lowering fn.
    """
    from torch._inductor import lowering

    orig_fn = lowering.lowerings[aten_op]
    try:
        lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn)
        yield
    finally:
        lowering.lowerings[aten_op] = orig_fn


def add_scheduler_init_hook(pre_fn, post_fn=None):
    """
    Add hook functions to be called at the beginning and end of Scheduler.__init__.
    Used for unit tests.
    """
    from torch._inductor.scheduler import Scheduler

    orig_fn = Scheduler.__init__

    def wrapper(scheduler, nodes):
        pre_fn(scheduler, nodes)
        out = orig_fn(scheduler, nodes)
        if post_fn:
            post_fn(scheduler, nodes)
        return out

    return unittest.mock.patch.object(Scheduler, "__init__", wrapper)


def developer_warning(msg):
    """
    Warnings that will be actionable for PyTorch developers, but not
    end users. Allows us to easily disable them in stable releases but
    keep them on for nightly builds.
    """
    if config.developer_warnings:
        log.warning(msg)
    else:
        log.info(msg)


def get_benchmark_name():
    """
    An experimental API used only when config.benchmark_kernel is true.

    The benchmark name is only available at codegen time. So we can not
    directly call it in benchmark_all_kernels which is run after codegen.

    The function assumes the argument after --only is the benchmark name.
    It works for torchbench.py/huggingface.py/timm_models.py. But for ad-hoc
    scripts, this function may return None.

    There are 2 flavors of --only argument we need to handle:
    1. --only model_name
    2.
--only=model_name 1448 """ 1449 try: 1450 idx = sys.argv.index("--only") 1451 if ( 1452 idx + 1 < len(sys.argv) 1453 and len(sys.argv[idx + 1]) > 0 1454 and sys.argv[idx + 1][0] != "-" 1455 ): 1456 return sys.argv[idx + 1] 1457 except ValueError: 1458 pass 1459 1460 for arg in sys.argv: 1461 if arg.startswith("--only="): 1462 return arg[len("--only=") :] 1463 1464 1465def is_ones(items): 1466 return all(x == 1 for x in items) 1467 1468 1469def is_zeros(items): 1470 return all(x == 0 for x in items) 1471 1472 1473def is_cpu_device(inputs): 1474 return all( 1475 item.device == torch.device("cpu") 1476 for item in inputs 1477 if isinstance(item, torch.Tensor) 1478 ) 1479 1480 1481def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype: 1482 assert isinstance( 1483 val, sympy.Expr 1484 ), "only support sympy.Expr as input to get_sympy_Expr_dtype" 1485 if val.is_integer: # type: ignore[attr-defined] 1486 return torch.int64 1487 else: 1488 return torch.float64 1489 1490 1491@contextlib.contextmanager 1492def maybe_profile(should_profile, *args, **kwargs): 1493 if should_profile: 1494 with torch.profiler.profile(*args, **kwargs) as p: 1495 yield p 1496 else: 1497 yield 1498 1499 1500def parallel_num_threads(): 1501 threads = config.cpp.threads 1502 if threads < 1: 1503 threads = torch.get_num_threads() 1504 return threads 1505 1506 1507@functools.lru_cache(None) 1508def get_device_tflops(dtype): 1509 from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops 1510 1511 assert dtype in (torch.float16, torch.bfloat16, torch.float32) 1512 1513 if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"): 1514 # Triton API change in https://github.com/openai/triton/pull/2293 1515 from torch._utils_internal import max_clock_rate 1516 1517 sm_clock = max_clock_rate() 1518 if dtype in (torch.float16, torch.bfloat16): 1519 return get_max_tensorcore_tflops(dtype, sm_clock) 1520 1521 if torch.backends.cuda.matmul.allow_tf32: 1522 return get_max_tensorcore_tflops(torch.float32, sm_clock) 1523 else: 1524 return get_max_simd_tflops(torch.float32, sm_clock) 1525 else: 1526 if dtype in (torch.float16, torch.bfloat16): 1527 return get_max_tensorcore_tflops(dtype) 1528 1529 if torch.backends.cuda.matmul.allow_tf32: 1530 return get_max_tensorcore_tflops(torch.float32) 1531 else: 1532 return get_max_simd_tflops(torch.float32) 1533 1534 1535@functools.lru_cache(None) 1536def get_gpu_dram_gbps(): 1537 from triton.testing import get_dram_gbps 1538 1539 return get_dram_gbps() 1540 1541 1542def get_gpu_shared_memory(): 1543 from triton.runtime import driver 1544 1545 return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0) 1546 1547 1548def is_welford_reduction(reduction_type): 1549 return reduction_type.startswith("welford") 1550 1551 1552def reduction_num_outputs(reduction_type): 1553 return 3 if is_welford_reduction(reduction_type) else 1 1554 1555 1556def is_linux() -> bool: 1557 return platform.system() == "Linux" 1558 1559 1560def is_windows(): 1561 return sys.platform == "win32" 1562 1563 1564def has_free_symbols(itr: Iterable[Any]): 1565 return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr) 1566 1567 1568def is_dynamic(*args): 1569 from . 
import ir 1570 1571 for t in args: 1572 if isinstance(t, ir.TensorBox): 1573 if has_free_symbols(t.data.get_size()) or ( 1574 hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride()) 1575 ): 1576 return True 1577 elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)): 1578 assert hasattr(t, "get_size") and hasattr(t, "get_stride") 1579 if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()): 1580 return True 1581 elif not isinstance(t, ir.IRNode): 1582 continue 1583 else: 1584 raise TypeError(f"unexpected type for is_dynamic {type(t)}") 1585 1586 return False 1587 1588 1589# Placeholder strings used in triton codegen. 1590class Placeholder(enum.Enum): 1591 # The placeholder for the actual name of a triton kernel. 1592 # e.g. for "def triton_" it would be "triton_" 1593 KERNEL_NAME = "KERNEL_NAME" 1594 1595 # The descriptive name of the triton kernel; when unique_kernel_names = False, this 1596 # placeholder will be replaced with a string with more information. 1597 DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME" 1598 1599 1600def pass_execution_and_save(func, gm, inp, msg): 1601 from .pattern_matcher import stable_topological_sort 1602 1603 with tempfile.NamedTemporaryFile( 1604 mode="w", 1605 encoding="utf-8", 1606 delete=False, 1607 ) as f: 1608 before_io = io.StringIO() 1609 after_io = io.StringIO() 1610 ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp) 1611 print(f"Before:\n{gm.graph}", file=f) 1612 print(gm.graph, file=before_io) 1613 start_time = datetime.now() 1614 with GraphTransformObserver(gm, msg, config.trace.log_url_for_graph_xform): 1615 func(gm.graph) 1616 time_elapsed = datetime.now() - start_time 1617 # recompile graph 1618 stable_topological_sort(gm.graph) 1619 gm.graph.lint() 1620 gm.recompile() 1621 1622 print(f"After:\n{gm.graph}", file=f) 1623 print(gm.graph, file=after_io) 1624 t = before_io.getvalue() == after_io.getvalue() 1625 log.info( 1626 "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s", 1627 msg, 1628 f.name, 1629 t, 1630 time_elapsed, 1631 ) 1632 1633 1634def is_collective(node, op=None): 1635 from . import ir 1636 1637 return type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op) 1638 1639 1640def is_wait(node): 1641 from . import ir 1642 1643 return type(node) == ir._WaitKernel 1644 1645 1646def contains_collective(snode): 1647 from torch._inductor.scheduler import BaseSchedulerNode, GroupedSchedulerNode 1648 1649 assert isinstance(snode, BaseSchedulerNode) 1650 if isinstance(snode, GroupedSchedulerNode): 1651 return any(contains_collective(x) for x in snode.snodes) 1652 else: 1653 return is_collective(snode.node) 1654 1655 1656def contains_wait(snode): 1657 from torch._inductor.scheduler import BaseSchedulerNode, GroupedSchedulerNode 1658 1659 assert isinstance(snode, BaseSchedulerNode) 1660 if isinstance(snode, GroupedSchedulerNode): 1661 return any(contains_wait(x) for x in snode.snodes) 1662 else: 1663 return is_wait(snode.node) 1664 1665 1666def is_fallback_op(node, op): 1667 from . 
import ir 1668 1669 if isinstance(op, torch._ops.OpOverload): 1670 op = {op} 1671 return isinstance(node, ir.FallbackKernel) and node.op_overload in op 1672 1673 1674def buf_name_to_fused_snode(buf_name, name_to_buf, name_to_fused_node): 1675 return name_to_fused_node[name_to_buf[buf_name].defining_op.get_name()] 1676 1677 1678def find_recursive_deps_of_node( 1679 snode, collected_node_set, name_to_buf, name_to_fused_node, criteria_cb=None 1680): 1681 if criteria_cb and criteria_cb(snode): 1682 return 1683 collected_node_set.add(snode) 1684 for dep in snode.unmet_dependencies: 1685 defining_op_for_dep = buf_name_to_fused_snode( 1686 dep.name, name_to_buf, name_to_fused_node 1687 ) 1688 if defining_op_for_dep in collected_node_set: 1689 continue 1690 find_recursive_deps_of_node( 1691 defining_op_for_dep, 1692 collected_node_set, 1693 name_to_buf, 1694 name_to_fused_node, 1695 criteria_cb=criteria_cb, 1696 ) 1697 1698 1699def find_recursive_users_of_node( 1700 snode, collected_node_set, name_to_buf, name_to_fused_node, criteria_cb=None 1701): 1702 if criteria_cb and criteria_cb(snode): 1703 return 1704 collected_node_set.add(snode) 1705 for o in snode.get_outputs(): 1706 for user in o.users: 1707 assert user.node is not None 1708 if user.node.get_name() == "OUTPUT": 1709 continue 1710 if user.node.get_name() not in name_to_fused_node: 1711 continue 1712 user_op = name_to_fused_node[user.node.get_name()] 1713 if user_op in collected_node_set: 1714 continue 1715 find_recursive_users_of_node( 1716 user_op, 1717 collected_node_set, 1718 name_to_buf, 1719 name_to_fused_node, 1720 criteria_cb=criteria_cb, 1721 ) 1722 1723 1724def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int): 1725 "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)" 1726 num_rng_seed_offset_inputs = ( 1727 2 if torch._functorch.config.functionalize_rng_ops else 0 1728 ) 1729 # AOT won't lift any parameters if we're inlining NN Modules 1730 # however desugaring subclasses will still add arguments 1731 # resulted in extra fixed inputs https://github.com/pytorch/pytorch/issues/130502 1732 if ( 1733 torch._dynamo.config.inline_inbuilt_nn_modules 1734 and not torch._dynamo.utils.is_parameter_freezing() 1735 ): 1736 return 0 1737 1738 return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs 1739 1740 1741def count_tangents(fx_g: torch.fx.GraphModule): 1742 """ 1743 Infers which inputs are static for a backwards graph 1744 """ 1745 1746 def is_saved_tensor(x): 1747 return ( 1748 "tangents" not in x.name 1749 and "bwd_seed" not in x.name 1750 and "bwd_base_offset" not in x.name 1751 ) 1752 1753 arg_count = 0 1754 static_arg_idxs = [] 1755 for n in fx_g.graph.nodes: 1756 if n.op == "placeholder": 1757 if is_saved_tensor(n): 1758 static_arg_idxs.append(arg_count) 1759 arg_count += 1 1760 1761 assert static_arg_idxs == list(range(len(static_arg_idxs))) 1762 return len(static_arg_idxs) 1763 1764 1765@dataclasses.dataclass 1766class BoxedBool: 1767 value: bool 1768 1769 def __bool__(self): 1770 return self.value 1771 1772 @staticmethod 1773 def disable(obj): 1774 if isinstance(obj, BoxedBool): 1775 obj.value = False 1776 return obj 1777 return False 1778 1779 1780@contextlib.contextmanager 1781def collect_defined_kernels(kernel_list): 1782 from .codegen.wrapper import WrapperCodeGen 1783 1784 orig_define_kernel = WrapperCodeGen.define_kernel 1785 1786 def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs): 1787 nonlocal 
kernel_list
        kernel_list.append(kernel_code)
        return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs)

    with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel):
        yield


def get_cloned_parameter_buffer_name(name: str):
    return name + "__original__"


def is_gpu(device: str):
    assert isinstance(device, str) or device is None, device
    return device in ["cuda", "xpu"]


def device_need_guard(device: str):
    assert isinstance(device, str)
    return is_gpu(device)


def needs_fallback_due_to_atomic_add_limitations(dtype):
    # tl.atomic_add does NOT support the following types
    return dtype in {torch.int64, torch.bool, torch.bfloat16}


def use_scatter_fallback(
    op_overload: torch._ops.OpOverload,
    reduction_type,
    self_dtype,
    src_dtype,
    src_device_type,
    src_is_tensor,
):
    if (
        op_overload.overloadpacket
        in (torch.ops.aten.scatter_reduce_, torch.ops.aten.scatter_reduce)
        and reduction_type is None
    ):
        return False

    reduce_ty = (
        "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum"
    )

    return (
        reduction_type not in {None, reduce_ty}
        or (
            src_is_tensor
            and is_gpu(src_device_type)
            and needs_fallback_due_to_atomic_add_limitations(src_dtype)
        )
        or (
            op_overload.overloadpacket == torch.ops.aten.scatter_reduce_
            and reduction_type == "sum"
            and src_is_tensor
            and src_device_type == "cpu"
            and config.cpp.fallback_scatter_reduce_sum
            and (config.cpp.dynamic_threads or parallel_num_threads() != 1)
        )
        or (reduction_type == reduce_ty and self_dtype in {torch.bool, torch.int64})
        or torch.are_deterministic_algorithms_enabled()
    )


def dump_node_schedule(node_schedule):
    """
    An API that can be used in pdb to dump a node_schedule.
    Right now it mainly dumps the read/write dependencies, but more can be
    added as needed.
    """
    from torch._inductor.codegen.simd import DisableReduction, EnableReduction
    from torch._inductor.scheduler import SchedulerNode

    print(f"Node schedule with {len(node_schedule)} nodes")
    for idx, node in enumerate(node_schedule):
        print(f" {idx:3}:")
        if node is EnableReduction:
            print("enable reduction")
        elif node is DisableReduction:
            print("disable reduction")
        elif isinstance(node, SchedulerNode):
            is_red = node.is_reduction()
            print(f"{'red' if is_red else 'pw'} scheduler node")
            if is_red:
                assert node.node is not None
                print(f"original reduction hint {node.node.data.reduction_hint}")  # type: ignore[attr-defined]
            print("ReadDep:")
            for dep in node.read_writes.reads:
                print(dep)
            print("WriteDep:")
            for dep in node.read_writes.writes:
                print(dep)
        else:
            raise RuntimeError(f"Unrecognized node type: {type(node)}")


def tensor_is_aligned(tensor: torch.Tensor):
    # See Note: [Input Alignment handling in Inductor]
    # Right now, we don't try to guard on the alignment of the storage offset.
    # When this comment was written, non-symbolic storage_offsets are not guarded on
    # but symbolic storage_offsets are. For consistency, we suppress guard creation
    # upon performing this check: that ensures that we don't add recompiles when we
    # add this logic.
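    # statically_known_true only answers from facts that are already known; it
    # never installs a new guard, which is exactly the behavior described above.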
1891 from torch.fx.experimental.symbolic_shapes import statically_known_true 1892 1893 return statically_known_true( 1894 (tensor.storage_offset() * get_dtype_size(tensor.dtype)) % GPU_ALIGN_BYTES == 0 1895 ) 1896 1897 1898def should_assume_input_aligned(example_input: torch.Tensor): 1899 # See Note: [Input Alignment handling in Inductor] 1900 1901 # right now, we only care about alignment for cuda tensors. 1902 if not is_gpu(example_input.device.type): 1903 return False 1904 return config.assume_aligned_inputs or tensor_is_aligned(example_input) 1905 1906 1907def maybe_get_suppress_shape_guards_ctx(): 1908 # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards() 1909 # If it's not available, return a nullcontext. 1910 1911 # If we're dealing with cudagraphs, we might not have a tracing_context 1912 tracing_context = torch._guards.TracingContext.try_get() 1913 if not tracing_context: 1914 return contextlib.nullcontext() 1915 1916 # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode 1917 shape_env = tracing_context.fake_mode.shape_env 1918 if not shape_env: 1919 return contextlib.nullcontext() 1920 1921 return shape_env.suppress_guards() 1922 1923 1924def run_and_get_cpp_code(fn, *args, **kwargs): 1925 # We use the patch context manager instead of using it as a decorator. 1926 # In this way, we can ensure that the attribute is patched and unpatched correctly 1927 # even if this run_and_get_cpp_code function is called multiple times. 1928 with unittest.mock.patch.object(config, "debug", True): 1929 torch._dynamo.reset() 1930 import io 1931 import logging 1932 1933 log_capture_string = io.StringIO() 1934 ch = logging.StreamHandler(log_capture_string) 1935 from torch._inductor.codecache import output_code_log 1936 1937 output_code_log.addHandler(ch) 1938 prev_level = output_code_log.level 1939 output_code_log.setLevel(logging.DEBUG) 1940 result = fn(*args, **kwargs) 1941 s = log_capture_string.getvalue() 1942 output_code_log.setLevel(prev_level) 1943 output_code_log.removeHandler(ch) 1944 return result, s 1945 1946 1947def shape_env_from_inputs(inputs: List[torch.Tensor]): 1948 shape_env = None 1949 fake_mode = detect_fake_mode(inputs) 1950 1951 # TODO(voz): It would be nice to enable this assert, but there are lots of tests that 1952 # pass in real inputs for now. 1953 # if len(inputs) > 0: 1954 # assert fake_mode is not None, breakpoint() 1955 1956 if fake_mode is not None: 1957 return fake_mode.shape_env 1958 1959 # When there are no tensor inputs, get shape_env from the first SymInt. 1960 for input in inputs: 1961 if isinstance(input, torch.SymInt): 1962 return input.node.shape_env 1963 1964 # TODO(voz): Should we always have one anyway? 
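    # Neither a fake mode nor a SymInt input was found, so there is no ShapeEnv
    # to return.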
1965 return None 1966 1967 1968def align_inputs_from_check_idxs( 1969 model: Callable[[List[InputType]], Any], 1970 inputs_to_check: Sequence[int], 1971) -> Callable[[List[InputType]], Any]: 1972 if len(inputs_to_check) == 0: 1973 return model 1974 1975 def run(new_inputs: List[InputType]): 1976 copy_misaligned_inputs(new_inputs, inputs_to_check) 1977 return model(new_inputs) 1978 1979 return run 1980 1981 1982def clone_preserve_strides(x: torch.Tensor): 1983 needed_size = ( 1984 sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 1985 ) 1986 buffer = torch.as_strided(x, (needed_size,), (1,)).clone() 1987 return torch.as_strided(buffer, x.size(), x.stride()) 1988 1989 1990def copy_misaligned_inputs( 1991 new_inputs: List[InputType], check_inputs_idxs: Sequence[int] 1992) -> None: 1993 for i in check_inputs_idxs: 1994 _inp = new_inputs[i] 1995 assert isinstance(_inp, torch.Tensor) 1996 if _inp.data_ptr() % ALIGNMENT: 1997 new_inputs[i] = clone_preserve_strides(_inp) 1998 1999 2000def remove_unaligned_input_idxs( 2001 inputs: List[InputType], 2002 static_input_idxs: Sequence[int], 2003): 2004 """ 2005 We require all inputs to be aligned, so introduce a copy for any 2006 that aren't. 2007 """ 2008 aligned_static_input_idxs = [] 2009 for idx in static_input_idxs: 2010 input = inputs[idx] 2011 if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0: 2012 aligned_static_input_idxs.append(idx) 2013 if len(aligned_static_input_idxs) != len(static_input_idxs): 2014 return aligned_static_input_idxs 2015 return static_input_idxs 2016 2017 2018def set_tracing_context_output_strides(example_inputs, compiled_graph): 2019 # Return the output strides to the caller via TracingContext 2020 context = torch._guards.TracingContext.try_get() 2021 if context is not None and context.output_strides is not None: 2022 assert len(context.output_strides) == 0 2023 shape_env = shape_env_from_inputs(example_inputs) 2024 for exprs in compiled_graph.output_strides: 2025 if exprs is None: 2026 context.output_strides.append(None) 2027 else: 2028 context.output_strides.append( 2029 tuple( 2030 ( 2031 shape_env.evaluate_symexpr(e) 2032 if shape_env is not None 2033 else int(e) 2034 ) 2035 for e in exprs 2036 ) 2037 ) 2038
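

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the inductor API): exercises a few of
# the small helpers defined above. Guarded so it only runs when this module is
# executed directly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _CacheOnSelfDemo:
        @cache_on_self
        def expensive(self):
            # Computed once per instance, then served from __expensive_cache.
            return sum(range(1000))

    demo = _CacheOnSelfDemo()
    assert demo.expensive() == demo.expensive() == 499500
    _CacheOnSelfDemo.expensive.clear_cache(demo)  # drop the cached value

    assert ceildiv(7, 2) == 4

    buf = IndentedBuffer()
    buf.writeline("def generated():")
    with buf.indent():
        buf.writeline("return 1")
    print(buf.getvalue())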