1# mypy: allow-untyped-defs 2import abc 3import dataclasses 4import itertools 5import logging 6import re 7import typing 8from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union 9from unittest.mock import patch 10 11import sympy 12 13import torch 14from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols 15from torch.utils._ordered_set import OrderedSet 16 17from .codegen.common import index_prevent_reordering 18from .utils import ( 19 get_dtype_size, 20 reduction_num_outputs, 21 sympy_index_symbol, 22 sympy_str, 23 sympy_subs, 24 VarRanges, 25) 26from .virtualized import OpsHandler, ReductionType, V 27 28 29log = logging.getLogger(__name__) 30is_indirect = re.compile(r"indirect|tmp").search 31 32 33class Dep(abc.ABC): 34 name: str 35 index: sympy.Expr 36 37 @abc.abstractmethod 38 def rename(self, renames: Dict[str, str]) -> "Dep": 39 pass 40 41 @abc.abstractmethod 42 def get_numel(self) -> sympy.Expr: 43 pass 44 45 @abc.abstractmethod 46 def numbytes_hint(self): 47 pass 48 49 @abc.abstractmethod 50 def has_unbacked_symbols(self) -> bool: 51 pass 52 53 @abc.abstractmethod 54 def is_contiguous(self) -> bool: 55 pass 56 57 def normalize_with_stride_order(self, prefix="t"): 58 return self 59 60 61@dataclasses.dataclass(frozen=True) 62class MemoryDep(Dep): 63 name: str 64 index: sympy.Expr 65 var_names: Tuple[sympy.Symbol, ...] 66 size: Tuple[sympy.Expr, ...] 67 mode: Optional[str] = None 68 69 def __repr__(self) -> str: 70 return f"MemoryDep({self.name!r}, {self.index}, {self.ranges}, {self.mode})" 71 72 @property 73 def num_vars(self): 74 return len(self.var_names) 75 76 def decide_loop_order_to_match(self, other): 77 """ 78 Can return None if not able to decide loop orders. 79 """ 80 assert self.num_vars == other.num_vars 81 82 # ignore broadcast for now since broadcast causes extra 0 strides 83 # which makes it hard to decide the correct loop orders. 84 if self.num_vars != len(self.index.free_symbols): 85 return None 86 if other.num_vars != len(other.index.free_symbols): 87 return None 88 89 # bail out if any size is 0 or 1 90 # For size == 0, it's an empty tensor, any strides for that dimension 91 # are equivalent. Skip for simplicity and it may not matter that much. 92 # 93 # For size == 1, it cause cause tie for strides of different dimensions. 94 # Also when we first time create LoopBody in ComputedBuffer.simplify_and_reorder 95 # we can dependencies.index_vars_squeeze which should already sqeeuze 96 # the size == 1 dimensions. 97 if any(s == 0 or s == 1 for s in itertools.chain(self.size, other.size)): 98 return None 99 100 # Extract strides for both expression 101 self_strides = V.graph.sizevars.stride_hints(self.index, self.var_names) 102 other_strides = V.graph.sizevars.stride_hints(other.index, other.var_names) 103 104 # Even if the shape contains no 0/1, some complex index expression may 105 # still have duplicate stride values. Here is an example: 106 # https://gist.github.com/shunting314/511a7e1ec88aa2e1a8ec85d8445ab129 107 # We don't reorder the loop for these cases for now, but in theory 108 # we could improve the algorithm to detect the correct loop orders. 109 if len(set(self_strides)) != len(self_strides) or len( 110 set(other_strides) 111 ) != len(other_strides): 112 log.debug( 113 "unable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%s", 114 self, 115 other, 116 self_strides, 117 other_strides, 118 ) 119 return None 120 121 # May hanppen if self and other are as follows 122 # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None) 123 # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None) 124 if set(self_strides) != set(other_strides): 125 return None 126 127 stride_to_index = {s: i for i, s in enumerate(self_strides)} 128 order = [] 129 for s in other_strides: 130 order.append(stride_to_index[s]) 131 132 assert set(order) == set(range(0, self.num_vars)) 133 return order 134 135 def get_offset(self): 136 """ 137 Return the offset by setting every variable to be 0. 138 """ 139 return sympy_subs(self.index, dict.fromkeys(self.var_names, 0)) 140 141 def normalize(self) -> "MemoryDep": 142 """ 143 Normalize by merging loops. The different to normalize_with_stride_order is, 144 this method does not reorder loops while normalize_with_stride_order reorder 145 loops based on stride order. 146 """ 147 return MemoryDep( 148 self.name, 149 *_RecordLoadStoreInner._normalize(self.index, self.ranges), # type: ignore[arg-type] 150 self.mode, 151 ) 152 153 def normalize_with_stride_order(self, prefix="t"): 154 r""" 155 Used to decide if two MemoryDep does not equal due to different loop orders. 156 More specifically, when dep1 and dep2 are not equal, we can normalize 157 both and check if they are equal after that. If yes, then the mismatch is 158 caused by different loop orders. 159 """ 160 # import here to avoid circular import 161 from torch._inductor import ir 162 163 strides = V.graph.sizevars.stride_hints(self.index, self.var_names) 164 165 # pick a loop order with stride ordered decreasingly 166 order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) 167 stride_reorder = ir.same_reorder(order) 168 sizes = self.size 169 var_names = self.var_names 170 171 new_reordered_sizes = stride_reorder(sizes) 172 new_reordered_var_names = stride_reorder(var_names) 173 174 new_simplified_sizes, reindex, prune = V.graph.sizevars._simplify_loops( 175 new_reordered_var_names, 176 new_reordered_sizes, 177 index_prevent_reordering( 178 [self.index], new_reordered_var_names, new_reordered_sizes 179 ), 180 ) 181 182 # now let's create new symbols with the passed in prefix 183 var_ranges, add_var = var_builder(prefix) 184 replacement = dict( 185 zip( 186 new_reordered_var_names, 187 reindex([add_var(x) for x in new_simplified_sizes]), 188 ) 189 ) 190 new_index = sympy_subs(sympy.expand(self.index), replacement) # type: ignore[arg-type] # next PR 191 192 out = MemoryDep(self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())) # type: ignore[arg-type] 193 return out 194 195 @property 196 def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]: 197 """{c0: 128, c1: 512, ...}""" 198 return dict(zip(self.var_names, self.size)) 199 200 def get_numel(self) -> sympy.Expr: 201 if self.is_indirect(): 202 numel = V.graph.get_numel(self.name) 203 else: 204 vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols) 205 numel = sympy.Integer(1) 206 for var, size in zip(self.var_names, self.size): 207 if var in vars: 208 numel = numel * size 209 return numel # type: ignore[return-value] 210 211 def rename(self, renames: Dict[str, str]) -> "MemoryDep": 212 if self.name in renames: 213 return MemoryDep( 214 renames[self.name], 215 self.index, 216 var_names=self.var_names, 217 size=self.size, 218 mode=self.mode, 219 ) 220 return self 221 222 def numbytes_hint(self): 223 return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( 224 V.graph.get_dtype(self.name) 225 ) 226 227 def has_unbacked_symbols(self): 228 return len(free_unbacked_symbols(self.get_numel())) > 0 229 230 def is_contiguous(self) -> bool: 231 return isinstance(self.index, sympy.Symbol) and self.index in self.var_names 232 233 def stride1_for_last_dim(self, result_for_complex_expression=True) -> bool: 234 """ 235 Whether the stride for the last dimension is 1. 236 """ 237 # python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_masked_scatter_cuda_float16 238 # will exercise thru this corner case. 239 if len(self.var_names) == 0: 240 return True 241 242 terms = self.index.args if isinstance(self.index, sympy.Add) else [self.index] 243 244 last_sym = self.var_names[-1] 245 for term in terms: 246 if term is last_sym: 247 return True 248 249 # Having a >1 stride for the last dimension is bad for perf 250 # return False. 251 if ( 252 isinstance(term, sympy.Mul) 253 and len(term.args) == 2 254 and term.args[1] is last_sym 255 and isinstance(term.args[0], (int, sympy.Integer)) 256 and term.args[0] > 1 257 ): 258 return False 259 260 return result_for_complex_expression 261 262 def is_scalar(self) -> bool: 263 if isinstance(self.index, sympy.Symbol): 264 return self.index not in self.var_names and not self.is_indirect() 265 return isinstance(self.index, (int, sympy.Integer)) 266 267 def is_indirect(self) -> bool: 268 return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined] 269 270 271@dataclasses.dataclass(frozen=True) 272class StarDep(Dep): 273 name: str 274 mode: Optional[str] = None 275 276 # depends on the entire buffer 277 @property 278 def index(self): 279 raise NotImplementedError("StarDep does not have an index") 280 281 def get_numel(self) -> sympy.Expr: 282 return V.graph.get_numel(self.name) # type: ignore[return-value] 283 284 def rename(self, renames: Dict[str, str]) -> "StarDep": 285 if self.name in renames: 286 return StarDep(renames[self.name], self.mode) 287 return self 288 289 def numbytes_hint(self): 290 return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( 291 V.graph.get_dtype(self.name) 292 ) 293 294 def has_unbacked_symbols(self): 295 return len(free_unbacked_symbols(self.get_numel())) > 0 296 297 def is_contiguous(self) -> bool: 298 return False 299 300 def is_scalar(self) -> bool: 301 return False 302 303 def is_indirect(self) -> bool: 304 return False 305 306 307# Used for tracking mutation ordering 308# if A reads a buffer and B mutates it 309# B must be ordered after A 310# 311# This is useful for a variety of reasons. 312# For example, if A's read is never actually used, we can eliminate it. 313# Another case is if A's buffer ends up being fused away, we never need to 314# materialize that buffer 315@dataclasses.dataclass(frozen=True) 316class WeakDep(Dep): 317 # Fake dependency on unused buffer 318 name: str 319 # Buffer that is doing the mutation 320 mutating_buf: str 321 322 @property 323 def index(self): 324 raise NotImplementedError("WeakDep does not have an index") 325 326 def get_numel(self) -> sympy.Expr: 327 return sympy.Integer(1) 328 329 def rename(self, renames: Dict[str, str]) -> "WeakDep": 330 if self.name in renames: 331 return WeakDep(renames[self.name], self.mutating_buf) 332 return self 333 334 def numbytes_hint(self): 335 return 1 # Purely inserted for ordering, not an actual dep 336 337 def has_unbacked_symbols(self): 338 return False 339 340 def is_contiguous(self) -> bool: 341 return False 342 343 344@dataclasses.dataclass(frozen=True) 345class IndexExprDep: 346 index: sympy.Expr # type: ignore[assignment] 347 var_names: Tuple[sympy.Symbol, ...] 348 size: Tuple[sympy.Expr, ...] 349 350 351@dataclasses.dataclass 352class ReadWrites: 353 reads: OrderedSet[Dep] 354 writes: OrderedSet[Dep] 355 index_exprs: OrderedSet[IndexExprDep] 356 range_vars: Optional[List[sympy.Expr]] = None 357 var_ranges: Optional[VarRanges] = None 358 359 def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites": 360 return ReadWrites( 361 OrderedSet(dep.rename(renames) for dep in self.reads), 362 OrderedSet(dep.rename(renames) for dep in self.writes), 363 self.index_exprs, 364 self.range_vars, 365 self.var_ranges, 366 ) 367 368 def with_read(self, dep: Union[Dep, Set[Dep]]) -> "ReadWrites": 369 assert isinstance(dep, (WeakDep, StarDep, set)) 370 if not isinstance(dep, set): 371 dep = {dep} 372 return ReadWrites( 373 OrderedSet.union(self.reads, dep), 374 self.writes, 375 self.index_exprs, 376 self.range_vars, 377 self.var_ranges, 378 ) 379 380 def merge(self, other: "ReadWrites"): 381 reads = OrderedSet.union(self.reads, other.reads) 382 writes = OrderedSet.union(self.writes, other.writes) 383 index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs) 384 return ReadWrites(reads - writes, writes, index_exprs) 385 386 @staticmethod 387 def merge_list(read_writes: List["ReadWrites"]): 388 all_writes = OrderedSet.union(*[rw.writes for rw in read_writes]) 389 all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes 390 all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes]) 391 return ReadWrites(all_reads, all_writes, all_index_exprs) 392 393 def remove_reads(self, rem_reads): 394 return ReadWrites( 395 self.reads - rem_reads, 396 self.writes, 397 self.index_exprs, 398 self.range_vars, 399 self.var_ranges, 400 ) 401 402 def reads_and_writes(self): 403 return itertools.chain(self.reads, self.writes) 404 405 def buffer_names(self, ignore_integer_index=True): 406 """ 407 Integer index is used for load_seed. 408 """ 409 names: OrderedSet[str] = OrderedSet() 410 for dep in self.reads_and_writes(): 411 if not isinstance(dep, MemoryDep): 412 continue 413 if not ignore_integer_index or not isinstance( 414 dep.index, (int, sympy.Integer) 415 ): 416 names.add(dep.name) 417 return names 418 419 420class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] 421 def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: 422 super().__init__() 423 self._reads: OrderedSet[Dep] = OrderedSet() 424 self._writes: OrderedSet[MemoryDep] = OrderedSet() 425 self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet() 426 self._var_ranges: VarRanges = var_ranges 427 self._should_normalize: bool = normalize 428 429 @staticmethod 430 def drop_unused_symbols(index, var_names, sizes): 431 """ 432 Reduction has last (reduced) dim in its sizes, but 433 downstream users won't. Normalize this away. 434 """ 435 if not isinstance(index, sympy.Expr): 436 # index can be an int 437 return 438 free_symbols = index.free_symbols 439 while var_names and var_names[-1] not in free_symbols: 440 var_names.pop() 441 sizes.pop() 442 443 @classmethod 444 def _normalize( 445 cls, index: sympy.Expr, var_ranges: VarRanges 446 ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: 447 # Try to further simplify the indexes even if simplify_loops didn't 448 # convert it to the simplest form because of the interference from 449 # different indexing formulas. 450 index_vars = [*var_ranges.keys()] 451 sizes = tuple(var_ranges.values()) # type: ignore[assignment] 452 new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( 453 index_vars, 454 sizes, 455 index_prevent_reordering([index], index_vars, sizes), 456 ) 457 458 # assign new variables each dimension to deal with numbering mismatches 459 # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 460 new_vars, add_var = var_builder(canonicalization_prefix()) 461 replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) 462 index = sympy_subs(sympy.expand(index), replacement) 463 464 new_vars = [*new_vars.keys()] 465 new_sizes = [*new_sizes] 466 cls.drop_unused_symbols(index, new_vars, new_sizes) 467 return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type] 468 469 def canonicalize( 470 self, index: sympy.Expr 471 ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: 472 if not self._should_normalize: 473 sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()] 474 var_names = [k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1] 475 sizes = [v for v in sizes if v != 1] 476 477 self.drop_unused_symbols(index, var_names, sizes) 478 479 return index, tuple(var_names), tuple(sizes) # type: ignore[return-value, arg-type] 480 var_ranges = { 481 k: V.graph.sizevars.simplify(v) 482 for k, v in self._var_ranges.items() 483 # TODO(jansel): explore this further normalization 484 # if k in free_symbols 485 } 486 return self._normalize(index, var_ranges) 487 488 def load(self, name: str, index: sympy.Expr) -> str: 489 self._reads.add(MemoryDep(name, *self.canonicalize(index))) 490 return f"load({name}, {sympy_str(index)})" 491 492 def load_seed(self, name: str, index: int): 493 assert isinstance(index, int) 494 return self.load(name, sympy.Integer(index)) 495 496 def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str: 497 self._writes.add(MemoryDep(name, *self.canonicalize(index), mode=mode)) 498 return f"store({name}, {sympy_str(index)}, {value}, {mode})" 499 500 def store_reduction(self, name: str, index, value) -> str: 501 return self.store(name, index, f"store_reduction({value})") 502 503 def index_expr(self, index: sympy.Expr, dtype) -> str: 504 self._index_exprs.add(IndexExprDep(*self.canonicalize(index))) 505 return f"index_expr({sympy_str(index)}, {dtype})" 506 507 def bucketize( 508 self, 509 values, 510 offsets_name: str, 511 offsets_size: sympy.Expr, 512 indexing_dtype: torch.dtype, 513 right: bool, 514 ): 515 self._reads.add(StarDep(offsets_name)) 516 return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})" 517 518 519class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined] 520 def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: 521 parent_handler = _RecordLoadStoreInner( 522 var_ranges=var_ranges, normalize=normalize 523 ) 524 super().__init__(parent_handler=parent_handler) 525 526 527# TODO: check call sites 528def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]: 529 cnt = itertools.count() 530 var_ranges: VarRanges = {} 531 532 def add_var(length: sympy.Expr) -> sympy.Symbol: 533 v = sympy_index_symbol(f"{prefix}{next(cnt)}") 534 var_ranges[v] = length 535 return v 536 537 return var_ranges, add_var 538 539 540def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str): 541 var_ranges, add_var = var_builder(prefix) 542 args: List[List[sympy.Symbol]] = [] 543 for size in argsizes: 544 args.append(list(map(add_var, size))) 545 return args, var_ranges 546 547 548def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"): 549 from .ir import SqueezeView 550 551 var_ranges, add_var = var_builder(prefix) 552 args: List[List[sympy.Expr]] = [] 553 new_sizes: List[List[sympy.Expr]] = [] 554 for size in argsizes: 555 new_size, reindex = SqueezeView.squeezer(size) 556 new_sizes.append(new_size) 557 args.append(reindex(list(map(add_var, new_size)))) 558 return args, var_ranges 559 560 561def extract_read_writes( 562 fn: Callable[..., Any], 563 *argsizes: Tuple[sympy.Expr, ...], 564 normalize: bool = False, 565 prefix: str = "d", 566 hidden_args=(), 567): 568 args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix) 569 570 from .loop_body import LoopBody, MemoryUsageType 571 572 if isinstance(fn, LoopBody): 573 # Fast path to avoid tracing when we already have a LoopBody 574 inner = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize) 575 name_to_index = fn.indexing_from_args([*args, *hidden_args]) 576 if fn.indirect_vars: 577 # mimic the `tmpX` naming tracing gives us 578 repl = {v: sympy.Symbol(f"tmp{i}") for i, v in enumerate(fn.indirect_vars)} 579 name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()} 580 for entry in fn.memory_usage[MemoryUsageType.LOAD]: 581 inner.load(entry.buffer_name, name_to_index[entry.index_name]) 582 for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]: 583 inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) 584 for entry in fn.memory_usage[MemoryUsageType.STORE]: 585 inner.store( 586 entry.buffer_name, name_to_index[entry.index_name], None, entry.mode 587 ) 588 for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]: 589 inner.store_reduction( 590 entry.buffer_name, name_to_index[entry.index_name], None 591 ) 592 for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]: 593 inner.index_expr(name_to_index[entry.index_name], None) 594 for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]: 595 inner.bucketize( 596 None, entry.buffer_name, name_to_index[entry.index_name], None, None 597 ) 598 # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped 599 else: 600 # Slow path tracing the function 601 rw = RecordLoadStore(var_ranges, normalize=normalize) 602 with V.set_ops_handler(rw): 603 fn(*args, *hidden_args) 604 inner = rw.parent_handler 605 606 if normalize: 607 range_vars = [] # Number of vars could differ due to normalization 608 else: 609 range_vars = [*itertools.chain.from_iterable(args)] 610 611 return ReadWrites( 612 OrderedSet(inner._reads), 613 OrderedSet(inner._writes), 614 inner._index_exprs, 615 range_vars, 616 var_ranges, 617 ) 618 619 620def extract_input_node_reduction_ranges( 621 input_node: "torch._inductor.ir.TensorBox", 622) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]: 623 """ 624 Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same. 625 It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes. 626 In this case, reduction_sizes of the Reduction nodes need to be the same. 627 Otherwise returns (None, None). 628 """ 629 630 from .ir import ComputedBuffer, Loops 631 632 if isinstance(input_node.data, ComputedBuffer): 633 # Input node has already been realized. Return its size and reduction_size. 634 size = input_node.get_size() 635 reduction_size = input_node.get_reduction_size() 636 if len(reduction_size) > 0: 637 return (size, reduction_size) 638 else: 639 return (None, None) 640 641 if not isinstance(input_node.data.data, Loops): # type: ignore[attr-defined] 642 # Other IRNodes do not have reduction_ranges. 643 return (None, None) 644 645 # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes? 646 # The current method still uses reduction ranges from the dependent realized node, which is not ideal. 647 # Is there a way to check whether there are permutations inbetween? 648 reads = input_node.get_reads() 649 reduction_size = None 650 size = None 651 while reduction_size is None and len(reads) > 0: 652 seen: OrderedSet[str] = OrderedSet() 653 new_reads = [] 654 for read in reads: 655 if not isinstance(read, MemoryDep): 656 continue 657 if read.name in seen: 658 continue 659 seen.add(read.name) 660 buffer = V.graph.try_get_buffer(read.name) 661 if buffer is None: 662 continue 663 op = buffer.get_defining_op() 664 if op is None: 665 continue 666 667 if isinstance(op, ComputedBuffer) and len(op.get_reduction_size()) > 0: 668 if reduction_size is None: 669 reduction_size = op.get_reduction_size() 670 size = op.get_size() 671 elif reduction_size != op.get_reduction_size() or size != op.get_size(): 672 return (None, None) 673 else: 674 new_reads.extend(op.get_reads()) 675 if reads == new_reads: 676 return (size, reduction_size) 677 else: 678 reads = new_reads 679 return (size, reduction_size) 680 681 682def canonicalization_prefix(): 683 return "c" 684 685 686# ops handler which computes all the free unbacked symbols for an IR 687class FreeUnbackedSymbolsOpsHandler: 688 symbols: OrderedSet[sympy.Symbol] 689 690 def __init__(self) -> None: 691 self.symbols = OrderedSet() 692 693 def __getattr__(self, name: str) -> Callable[..., Any]: 694 def inner(*args, **kwargs): 695 for a in itertools.chain(args, kwargs.values()): 696 if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): 697 self.symbols |= free_unbacked_symbols(a) 698 699 return inner 700 701 def indirect_indexing( 702 self, index_var, size, check=True, wrap_neg=True 703 ) -> sympy.Symbol: 704 assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean)) 705 self.symbols |= free_unbacked_symbols(size) 706 return sympy_index_symbol(f"({str(index_var)})") 707 708 def frexp(self, x): 709 return (None,) * 2 710 711 def scan(self, dtypes, combine_fn, values): 712 return (None,) * len(values) 713 714 def sort(self, dtypes, values, stable, descending): 715 return (None,) * len(values) 716 717 def reduction( 718 self, 719 dtype: torch.dtype, 720 src_dtype: torch.dtype, 721 reduction_type: ReductionType, 722 value: Union[None, Tuple[None, ...]], 723 ) -> Union[None, Tuple[None, ...]]: 724 num_values = reduction_num_outputs(reduction_type) 725 return (None,) * num_values if num_values > 1 else None 726 727 728def _typecheck_FreeUnbackedSymbolsOpsHandler( 729 h: FreeUnbackedSymbolsOpsHandler, 730) -> OpsHandler[None]: 731 return h 732 733 734def extract_free_unbacked_symbols(fn: Callable[..., Any], index, rindex=None): 735 from .ir import FlexibleLayout 736 737 args = [index, rindex] if rindex is not None else [index] 738 handler = FreeUnbackedSymbolsOpsHandler() 739 # NB: I cargo culted the allow_indexing patch here, I don't understand why 740 # people do this all over 741 with V.set_ops_handler(handler), patch.object( 742 FlexibleLayout, "allow_indexing", True 743 ): 744 fn(*args) 745 return handler.symbols 746