# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import sympy

import torch
import torch.utils._pytree as pytree
from executorch.exir.delegate import LoweredBackendModule
from executorch.exir.dynamic_shape import (
    calculate_dynamic_shape_spec,
    DynamicMemoryPlanningMode,
)
from executorch.exir.pass_base import Argument, ExportPass
from executorch.exir.pass_infra.node_metadata import NodeMetadata
from executorch.exir.pass_infra.proxy_value import ProxyValue
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
from executorch.exir.schema import TensorShapeDynamism
from executorch.exir.sym_util import collect_free_symbols, eval_expr
from executorch.exir.tensor import TensorSpec
from torch._subclasses import FakeTensor
from torch.fx import GraphModule


@dataclass
class DSInfo:
    """
    Dynamic shape information we track for each dynamic shape symbol.
    """

    # the output of format_node() for the node introducing the symbol
    node_debug_str: str
    # upper bound value, or None for fully dynamic memory planning
    ubval: Optional[int]


class DynamicShapePropPass(ExportPass):
    """
    In general, for each op, this pass propagates dynamic shape information
    from op inputs to op outputs.

    For cond/map nodes, we need to pass dynamic shape information to the
    submodules' placeholder nodes, propagate the dynamic shape information
    through the graphs of the submodules, and finally set the node's dynamic
    shape info based on the submodules' output nodes' dynamic shape info.
    """

    def __init__(
        self, mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND
    ):
        """
        mode controls how we do memory planning for dynamic shape tensors.
        In UPPER_BOUND mode, we plan a dynamic shape tensor's memory based on
        its upper bound shape;
        in FULL_DYNAMIC mode, the compiler does not allocate memory for
        dynamic shape tensors and the runtime does the allocation.
        """
        super().__init__()
        self.mode = mode
        self.sym_to_dsinfo: Dict[sympy.Symbol, DSInfo] = {}
        self.shape_env = None

    @contextmanager
    def apply_upper_bounds(self):
        """
        Context manager to use upper bound values when evaluating expressions.
        """
        old_var_to_val = None
        try:
            if self.shape_env:
                old_var_to_val = dict(self.shape_env.var_to_val)
                for sym, dsinfo in self.sym_to_dsinfo.items():
                    assert dsinfo.ubval is not None
                    self.shape_env.var_to_val[sym] = sympy.Integer(dsinfo.ubval)
            yield
        finally:
            # restore the original symbol values if we overrode them
            if self.shape_env and old_var_to_val is not None:
                self.shape_env.var_to_val = old_var_to_val

    def copy_dsinfo_btw_specs(self, src_spec: TensorSpec, dst_spec: TensorSpec):
        dst_spec.shape_dynamism = src_spec.shape_dynamism
        dst_spec._upper_bound_shape = src_spec._upper_bound_shape

    def inject_dsinfo_to_graph(
        self,
        subgm: GraphModule,
        inputs: Union[List[ProxyValue], Tuple[ProxyValue, ...]],
        ignore_first_ph: bool = False,
    ):
        """
        ignore_first_ph: This argument is added for the map node. For map,
        the first placeholder is special; we ignore it here and handle it
        separately.
        """
        phs = [n for n in subgm.graph.nodes if n.op == "placeholder"]
        if ignore_first_ph:
            phs = phs[1:]
        assert len(phs) == len(inputs)
        for ph, inp in zip(phs, inputs):
            dst_spec: TensorSpec = ph.meta["spec"]
            src_spec: TensorSpec = inp.node.meta["spec"]
            self.copy_dsinfo_btw_specs(src_spec, dst_spec)

    def inject_xs_dsinfo_to_graph(self, subgm: GraphModule, xs: ProxyValue):
        """
        xs is the first argument of the map node.

        Even if xs is an upper bound tensor, the first placeholder of subgm
        could still be a static shape tensor if only the first dimension of
        xs is dynamic. But we don't have this optimization yet: if xs is
        dynamic, we treat the first placeholder of subgm as dynamic.
        """
        ph = next(n for n in subgm.graph.nodes if n.op == "placeholder")
        src_spec: TensorSpec = xs.node.meta["spec"]
        dst_spec = ph.meta["spec"]

        self.copy_dsinfo_btw_specs(src_spec, dst_spec)
        # update dst_spec to remove the leading dimension
        if dst_spec._upper_bound_shape:
            dst_spec._upper_bound_shape = dst_spec._upper_bound_shape[1:]

    def verify_dsinfo_from_both_branches(
        self, true_gm: GraphModule, false_gm: GraphModule
    ):
        """
        For a cond node, the true and false branches should return outputs
        with the same shape.
        """
        *_, true_out = true_gm.graph.nodes
        *_, false_out = false_gm.graph.nodes
        true_out = pytree.tree_flatten(true_out)[0]
        false_out = pytree.tree_flatten(false_out)[0]
        assert len(true_out) == len(false_out)
        for true_out_item, false_out_item in zip(true_out, false_out):
            true_spec = true_out_item.meta["spec"]
            false_spec = false_out_item.meta["spec"]
            assert true_spec.shape_dynamism == false_spec.shape_dynamism
            assert true_spec._upper_bound_shape == false_spec._upper_bound_shape

    def extract_dsinfo_from_graph(self, subgm: GraphModule, meta: NodeMetadata):
        *_, out_node = subgm.graph.nodes
        dst_spec_list = pytree.tree_flatten(meta["spec"])[0]
        src_spec_list = pytree.tree_flatten(out_node.meta["spec"])[0]
        for src_spec, dst_spec in zip(src_spec_list, dst_spec_list):
            self.copy_dsinfo_btw_specs(src_spec, dst_spec)

    def call_cond(
        self,
        pred: ProxyValue,
        true_fn: torch.fx.GraphModule,
        false_fn: torch.fx.GraphModule,
        inputs: List[Any],
        meta: NodeMetadata,
    ) -> ProxyValue:
        self.inject_dsinfo_to_graph(true_fn, inputs)
        self.inject_dsinfo_to_graph(false_fn, inputs)
        retval = super().call_cond(pred, true_fn, false_fn, inputs, meta)

        self.verify_dsinfo_from_both_branches(true_fn, false_fn)

        # Note: 'meta' will override the metadata in retval,
        # so we update 'meta' rather than 'retval' here.
        self.extract_dsinfo_from_graph(true_fn, meta)
        return retval

    def call_map(
        self,
        f: torch.fx.GraphModule,
        xs: ProxyValue,
        args: Tuple[ProxyValue, ...],
        meta: NodeMetadata,
    ) -> ProxyValue:
        self.inject_dsinfo_to_graph(f, args, True)
        self.inject_xs_dsinfo_to_graph(f, xs)
        retval = super().call_map(f, xs, args, meta)

        # We are being a bit conservative: if xs or f's output has a dynamic
        # shape, we treat the output of the map node as dynamic shape.
        xs_spec = xs.node.meta["spec"]
        *_, subgm_out = f.graph.nodes
        subgm_out_spec = subgm_out.meta["spec"]

        # Take advantage of the fact that the static TensorShapeDynamism is
        # the minimal value.
        result_spec = meta["spec"]
        result_spec.shape_dynamism = max(
            spec.shape_dynamism
            for spec in pytree.tree_flatten((xs_spec, subgm_out_spec))[0]
        )
        if result_spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_BOUND:
            # On the right hand side of the assignment we use 'upper_bound_shape'
            # rather than '_upper_bound_shape'. The former returns the static
            # shape for static tensors, which is what we want.
            result_spec._upper_bound_shape = (
                xs_spec.upper_bound_shape[:1] + subgm_out_spec.upper_bound_shape
            )

        return retval

    def add_symint_upperbound(
        self, node_debug_str: str, symint: torch.SymInt, ubval: int
    ):
        if not isinstance(symint, torch.SymInt):
            return
        expr = symint.node.expr
        if isinstance(expr, sympy.Symbol):
            self.sym_to_dsinfo[expr] = DSInfo(node_debug_str, ubval)

        if self.shape_env is None:
            self.shape_env = symint.node.shape_env
        else:
            assert symint.node.shape_env is self.shape_env

    def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
        output = super().placeholder(name, arg, meta)
        # TODO: handle full dynamic
        if (
            self.mode == DynamicMemoryPlanningMode.UPPER_BOUND
            and meta.data.get("spec", None) is not None
            and meta.data.get("val", None) is not None
        ):
            spec = meta.data["spec"]
            val = meta.data["val"]
            if not isinstance(val, FakeTensor):
                return output

            if spec.shape_dynamism != TensorShapeDynamism.DYNAMIC_BOUND:
                return output

            for sym, ubval in zip(val.shape, spec._upper_bound_shape):
                assert self.node_debug_str is not None
                self.add_symint_upperbound(self.node_debug_str, sym, ubval)
        return output

    def eval_symint_to_ubval(self, symint: torch.SymInt) -> int:
        return eval_expr(symint)

    def decide_upper_bound_from_symbols(self, meta: NodeMetadata):
        with self.apply_upper_bounds():
            meta = meta.data
            if meta.get("val", None) is None or meta.get("spec", None) is None:
                return
            vallist, _ = pytree.tree_flatten(meta["val"])
            speclist, _ = pytree.tree_flatten(meta["spec"])
            for val, spec in zip(vallist, speclist):
                if not isinstance(val, FakeTensor) or not isinstance(spec, TensorSpec):
                    continue
                free_symbols = collect_free_symbols(val.shape)
                if len(free_symbols & set(self.sym_to_dsinfo.keys())) == 0:
                    spec.shape_dynamism = TensorShapeDynamism.STATIC
                    spec._upper_bound_shape = None
                    continue
                spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
                # evaluate the upper bound shape
                spec._upper_bound_shape = [
                    self.eval_symint_to_ubval(s) for s in val.shape
                ]

    def call_delegate(
        self,
        lowered_module: LoweredBackendModule,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """
        Override this method so we can properly calculate the dynamic shape
        information for the output of a delegate.
        """
        if self.mode == DynamicMemoryPlanningMode.UPPER_BOUND:
            self.decide_upper_bound_from_symbols(meta)
        else:
            raise RuntimeError("NYI: delegation support in full dynamic mode")
        return super().call_delegate(lowered_module, args, kwargs, meta)

    def call_operator(self, op, args, kwargs, meta):
        """
        If any of the arguments has a dynamic shape, mark the output as
        dynamic shape.
        """

        # no need to do dynamic shape propagation for these ops
        if op.target in _EXECUTORCH_SYM_OPS:
            return super().call_operator(op, args, kwargs, meta)

        if self.mode == DynamicMemoryPlanningMode.UPPER_BOUND:
            self.decide_upper_bound_from_symbols(meta)
            return super().call_operator(op, args, kwargs, meta)

        ds_spec = calculate_dynamic_shape_spec(self.mode, op.target, args, kwargs)

        out_tensor_spec = meta["spec"]

        for ds_spec_item, tensor_spec_item in zip(
            pytree.tree_flatten(ds_spec)[0], pytree.tree_flatten(out_tensor_spec)[0]
        ):
            tensor_spec_item.shape_dynamism = ds_spec_item.shape_dynamism
            tensor_spec_item._upper_bound_shape = ds_spec_item.upper_bound_shape
        return super().call_operator(op, args, kwargs, meta)
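
# Example usage (illustrative sketch, not part of this module): this pass is an
# ExportPass, so it is applied by calling it on a GraphModule whose node
# metadata already carries TensorSpec ("spec") and FakeTensor ("val") entries.
# The `edge_gm` name below is a hypothetical GraphModule obtained from an
# earlier export/to-edge step; the snippet only shows how the memory-planning
# mode is selected and how the updated graph is retrieved from the PassResult.
#
#     prop_pass = DynamicShapePropPass(mode=DynamicMemoryPlanningMode.UPPER_BOUND)
#     result = prop_pass(edge_gm)
#     edge_gm = result.graph_module  # specs now carry shape_dynamism and
#                                    # _upper_bound_shape for dynamic tensors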