# mypy: allow-untyped-defs
import logging
import operator
from functools import partial
from typing import Any, Callable, Dict

from sympy import Expr

import torch
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges

from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock
from .utils import cache_on_self, dominated_nodes
from .virtualized import V


log = logging.getLogger(__name__)


class BoundVars:
    """
    Value range analysis over the fx graph of a LoopBody.

    Call `get_bounds()` to run the analysis; it returns a dict mapping each
    fx node to its `ValueRanges`.

    Note. A current limitation of this analysis is that it only works on a
    per-loop basis. We should be able to propagate the bounds across the whole
    graph. This may benefit the case where a bounded variable is returned by a
    kernel and fed into another.
    """

    def __init__(self, loop_body: LoopBody) -> None:
        """
        Seed the analysis state from `loop_body`.

        Args:
            loop_body: the LoopBody whose graph will be analysed.
        """

        def is_unbounded_root(node) -> bool:
            # Roots of the "pessimistically unbounded" set: loads, reductions,
            # getitem calls and anything that targets a masked subblock.
            if node.target in ["load", "reduction", operator.getitem]:
                return True
            return "masked_subblock" in node.target

        self.loop_body = loop_body

        # Each entry of var_ranges becomes the range [0, upper - 1]; a sympy
        # extent is first reduced to its upper bound via bound_sympy.
        iteration_ranges = {}
        for var, extent in loop_body.var_ranges.items():
            limit = bound_sympy(extent).upper if isinstance(extent, Expr) else extent
            iteration_ranges[var] = ValueRanges[Expr](0, limit - 1)
        self.replacement_vals = iteration_ranges

        # avoid computing these values, pessimistically assume that they are unbounded
        self.unbounded_vars = dominated_nodes(
            node for node in self.loop_body.get_nodes() if is_unbounded_root(node)
        )

        # To access this variable call `get_bounds()`
        self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {}

    def __repr__(self) -> str:
        """Return a multi-line dump of the loop body and the analysis state."""
        return (
            f"{self.__class__.__name__}("
            f"loop_body={self.loop_body},\n "
            f"replacement_vals={self.replacement_vals}, \n"
            f"unbounded_vars={self.unbounded_vars}, \n"
            f"_bounds={self._bounds})"
        )

    # NOTE(review): cache_on_self is defined in .utils; presumably it memoizes
    # the result on the instance so the analysis runs once — confirm there.
    @cache_on_self
    def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]:
        """
        Interpret the root graph under `ValueRangeAnalysis` and return the bounds.

        Returns:
            `self._bounds`, which is handed to the interpreter as its initial
            environment.
        """
        analysis_submodules = self.swap_submodules(self.loop_body.submodules)

        # Initialize the environment with the unbounded variables
        for node in self.unbounded_vars:
            target = node.target
            # we need to evaluate masked_subblock to recurse, and we need to
            # set indirect values, so those nodes are not pre-seeded
            if isinstance(target, str) and (
                "masked_subblock" in target or "set_indirect" in target
            ):
                continue
            self._bounds[node] = ValueRanges[Expr].unknown()

        with V.set_ops_handler(ValueRangeAnalysis()):
            root_graph = self.loop_body.root_block.graph
            interpreter = InterpreterShim(root_graph, analysis_submodules)
            log.debug("get_bounds:\n%s", root_graph)
            interpreter.run(V.get_ops_handler(), initial_env=self._bounds)
        return self._bounds

    def swap_submodules(
        self, submodules: Dict[str, Callable[..., Any]]
    ) -> Dict[str, Callable[..., ValueRanges[Expr]]]:
        """
        Build the submodule table used while interpreting for value ranges.

        Keys are dispatched as follows:
          * "get_index"               -> `self.get_index`
          * contains "masked_subblock" -> closure that runs `self.masked_subblock`
          * contains "set_indirect"    -> `self.set_indirect` bound to the
                                          matching entry of `indirect_vars`
          * anything else              -> must contain "scan"; passed through

        Args:
            submodules: the original name -> callable mapping.

        Returns:
            A new dict with the same keys and the replacement callables.
        """
        result: Dict[str, Callable[..., ValueRanges[Expr]]] = {}

        # Python closures capture variables by reference, so each subblock is
        # bound through a factory call; closing over the loop variable directly
        # would leave every callable pointing at the last subblock of the loop.
        # `result` is intentionally captured by reference: the callable sees the
        # fully populated table when it is eventually invoked.
        def bind_subblock(block):
            def run_masked(mask, value):
                return self.masked_subblock(block, self._bounds, mask, value, result)

            return run_masked

        for key, original in submodules.items():
            if key == "get_index":
                result[key] = self.get_index
            elif "masked_subblock" in key:
                result[key] = bind_subblock(self.loop_body.subblocks[key])
            elif "set_indirect" in key:
                # The text after the "set_indirect" prefix is the index into
                # indirect_vars.
                position = int(key[len("set_indirect") :])
                indirect_var = self.loop_body.indirect_vars[position]
                result[key] = partial(self.set_indirect, indirect_var)
            else:
                assert "scan" in key
                result[key] = original

        return result

    def masked_subblock(
        self,
        subblock: LoopBodyBlock,
        env: Dict[torch.fx.Node, ValueRanges[Expr]],
        mask: Any,
        value: Any,
        submodules: Dict[str, Callable[..., Any]],
    ) -> ValueRanges[Expr]:
        """
        Interpret `subblock` and return the range of its single output node.

        Args:
            subblock: block whose graph is interpreted.
            env: initial environment passed to the interpreter.
            mask: unused here.
            value: unused here (see the comment before the return).
            submodules: submodule table for the nested interpreter.

        Returns:
            The interpreter's environment entry for the graph's "output" node.
        """
        interp = InterpreterShim(subblock.graph, submodules)
        interp.run(V.get_ops_handler(), initial_env=env)

        output_nodes = [n for n in subblock.graph.nodes if n.target == "output"]
        assert len(output_nodes) == 1
        (output_node,) = output_nodes
        # dont bother unioning with value since the load from buffer will be
        # pessimistically assumed to be inf anyway
        return interp.env[output_node]

    def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]:
        """
        Record `new` as the range of the variable `old` and return it.

        Args:
            old: key under which the range is stored in `replacement_vals`.
            new: the range to store; must be a `ValueRanges`.
        """
        assert isinstance(new, ValueRanges)
        self.replacement_vals[old] = new
        return new

    def get_index(self, name: Expr) -> ValueRanges[Expr]:
        """
        Return the range of the indexing expression registered under `name`.

        The range is also stored in `replacement_vals` under `name`.

        Args:
            name: key into `loop_body.indexing_exprs`.
        """
        index_expr = self.loop_body.indexing_exprs[name]
        known = self.replacement_vals.get(index_expr)
        # Only fall back to bound_sympy when the expression has no recorded
        # range. At the time of writing a recorded range is expected to equal
        # bound_sympy(index_expr, self.replacement_vals); this is deliberately
        # not asserted, so that bound_sympy is not executed when a range is
        # already known.
        if known is None:
            resolved = bound_sympy(index_expr, self.replacement_vals)
        else:
            resolved = known
        self.replacement_vals[name] = resolved
        return resolved