# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import sympy

import torch
import torch.utils._pytree as pytree
from executorch.exir.delegate import LoweredBackendModule
from executorch.exir.dynamic_shape import (
    calculate_dynamic_shape_spec,
    DynamicMemoryPlanningMode,
)
from executorch.exir.pass_base import Argument, ExportPass
from executorch.exir.pass_infra.node_metadata import NodeMetadata
from executorch.exir.pass_infra.proxy_value import ProxyValue
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
from executorch.exir.schema import TensorShapeDynamism
from executorch.exir.sym_util import collect_free_symbols, eval_expr
from executorch.exir.tensor import TensorSpec
from torch._subclasses import FakeTensor
from torch.fx import GraphModule


@dataclass
class DSInfo:
    """
    Dynamic shape information we are tracking for each dynamic shape symbol.
    """

    # the output of format_node() for the node introducing the symbol
    node_debug_str: str
    # upper bound value or None for fully dynamic memory planning
    ubval: Optional[int]


class DynamicShapePropPass(ExportPass):
    """
    In general, for each op, this pass propagates dynamic shape information
    from op inputs to op outputs.

    For cond/map nodes, we need to pass dynamic shape information to the
    submodules' placeholder nodes, propagate the dynamic shape information
    through the graphs of the submodules, and finally set the node's dynamic
    shape info based on the submodules' output nodes' dynamic shape info.
    """
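    # A minimal usage sketch (hypothetical names; assumes the standard
    # PassBase calling convention, where the pass is invoked on a GraphModule
    # and returns a PassResult):
    #
    #     ds_prop = DynamicShapePropPass(DynamicMemoryPlanningMode.UPPER_BOUND)
    #     result = ds_prop(edge_graph_module)
    #     annotated_gm = result.graph_module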

    def __init__(
        self, mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND
    ):
        """
        mode controls how we do memory planning for dynamic shape tensors.
        In UPPER_BOUND mode, we plan dynamic shape tensors' memory based on
        their upper bound shapes;
        in FULL_DYNAMIC mode, the compiler does not allocate memory for
        dynamic shape tensors; the runtime does the allocation.
        """
        super().__init__()
        self.mode = mode
        self.sym_to_dsinfo = {}
        self.shape_env = None

    @contextmanager
    def apply_upper_bounds(self):
        """
        Context manager that evaluates expressions using upper bound values.
        """
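        # Example (hypothetical values): with sym_to_dsinfo == {s0: DSInfo(..., ubval=8)},
        # shape expressions evaluated inside this context see s0 bound to 8, so a
        # symbolic dim such as 2*s0 resolves to 16 when computing upper bound shapes.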
        try:
            if self.shape_env:
                old_var_to_val = dict(self.shape_env.var_to_val)
                for sym, dsinfo in self.sym_to_dsinfo.items():
                    assert dsinfo.ubval is not None
                    self.shape_env.var_to_val[sym] = sympy.Integer(dsinfo.ubval)
            yield
        finally:
            if self.shape_env:
                self.shape_env.var_to_val = old_var_to_val

    def copy_dsinfo_btw_specs(self, src_spec: TensorSpec, dst_spec: TensorSpec):
        dst_spec.shape_dynamism = src_spec.shape_dynamism
        dst_spec._upper_bound_shape = src_spec._upper_bound_shape

    def inject_dsinfo_to_graph(
        self,
        subgm: GraphModule,
        inputs: Union[List[ProxyValue], Tuple[ProxyValue, ...]],
        ignore_first_ph: bool = False,
    ):
        """
        ignore_first_ph: This argument is added for the map node. For a map node,
            the first placeholder is special, so we need to ignore it here
            and handle it separately.
        """
        phs = [n for n in subgm.graph.nodes if n.op == "placeholder"]
        if ignore_first_ph:
            phs = phs[1:]
        assert len(phs) == len(inputs)
        for ph, inp in zip(phs, inputs):
            dst_spec: TensorSpec = ph.meta["spec"]
            src_spec: TensorSpec = inp.node.meta["spec"]
            self.copy_dsinfo_btw_specs(src_spec, dst_spec)

    def inject_xs_dsinfo_to_graph(self, subgm: GraphModule, xs: ProxyValue):
        """
        xs is the first argument of the map node.

        Even if xs is an upper bound tensor, it's possible that the first
        placeholder of subgm is still a static shape tensor, if only the first
        dimension of xs is dynamic. We don't have that optimization yet: if xs
        is dynamic, we treat the first placeholder of subgm as dynamic.
        """
        ph = next(n for n in subgm.graph.nodes if n.op == "placeholder")
        src_spec: TensorSpec = xs.node.meta["spec"]
        dst_spec = ph.meta["spec"]

        self.copy_dsinfo_btw_specs(src_spec, dst_spec)
        # update dst_spec to remove the leading dimension
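        # (e.g., with hypothetical shapes, an xs upper bound of [B, C, H] becomes
        # a placeholder upper bound of [C, H], since map feeds subgm one slice
        # along dim 0 at a time)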
        if dst_spec._upper_bound_shape:
            dst_spec._upper_bound_shape = dst_spec._upper_bound_shape[1:]

    def verify_dsinfo_from_both_branches(
        self, true_gm: GraphModule, false_gm: GraphModule
    ):
        """
        For a cond node, the true and false branches should return outputs
        with the same shape.
        """
        *_, true_out = true_gm.graph.nodes
        *_, false_out = false_gm.graph.nodes
        true_out = pytree.tree_flatten(true_out)[0]
        false_out = pytree.tree_flatten(false_out)[0]
        assert len(true_out) == len(false_out)
        for true_out_item, false_out_item in zip(true_out, false_out):
            true_spec = true_out_item.meta["spec"]
            false_spec = false_out_item.meta["spec"]
            assert true_spec.shape_dynamism == false_spec.shape_dynamism
            assert true_spec._upper_bound_shape == false_spec._upper_bound_shape

    def extract_dsinfo_from_graph(self, subgm: GraphModule, meta: NodeMetadata):
        *_, out_node = subgm.graph.nodes
        dst_spec_list = pytree.tree_flatten(meta["spec"])[0]
        src_spec_list = pytree.tree_flatten(out_node.meta["spec"])[0]
        for src_spec, dst_spec in zip(src_spec_list, dst_spec_list):
            self.copy_dsinfo_btw_specs(src_spec, dst_spec)

    def call_cond(
        self,
        pred: ProxyValue,
        true_fn: torch.fx.GraphModule,
        false_fn: torch.fx.GraphModule,
        inputs: List[Any],
        meta: NodeMetadata,
    ) -> ProxyValue:
        self.inject_dsinfo_to_graph(true_fn, inputs)
        self.inject_dsinfo_to_graph(false_fn, inputs)
        retval = super().call_cond(pred, true_fn, false_fn, inputs, meta)

        self.verify_dsinfo_from_both_branches(true_fn, false_fn)

        # Note: 'meta' will override the metadata in retval,
        # so we update 'meta' rather than 'retval' here.
        self.extract_dsinfo_from_graph(true_fn, meta)
        return retval

    def call_map(
        self,
        f: torch.fx.GraphModule,
        xs: ProxyValue,
        args: Tuple[ProxyValue, ...],
        meta: NodeMetadata,
    ) -> ProxyValue:
        self.inject_dsinfo_to_graph(f, args, True)
        self.inject_xs_dsinfo_to_graph(f, xs)
        retval = super().call_map(f, xs, args, meta)

        # We are being a bit conservative: if xs or f's output is dynamic
        # shape, we treat the output of the map node as dynamic shape.
        xs_spec = xs.node.meta["spec"]
        *_, subgm_out = f.graph.nodes
        subgm_out_spec = subgm_out.meta["spec"]

        # Take advantage of the fact that the static TensorShapeDynamism is
        # the minimal value
        result_spec = meta["spec"]
        result_spec.shape_dynamism = max(
            spec.shape_dynamism
            for spec in pytree.tree_flatten((xs_spec, subgm_out_spec))[0]
        )
        if result_spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_BOUND:
            # On the right hand side of the assignment we use 'upper_bound_shape'
            # rather than '_upper_bound_shape'. The former returns the static
            # shape for static tensors, which is what we want.
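            # Worked example (hypothetical shapes): if xs has upper bound shape
            # [B, C] and the submodule output has upper bound shape [D], the map
            # output gets upper bound shape [B, D].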
            result_spec._upper_bound_shape = (
                xs_spec.upper_bound_shape[:1] + subgm_out_spec.upper_bound_shape
            )

        return retval

    def add_symint_upperbound(
        self, node_debug_str: str, symint: torch.SymInt, ubval: int
    ):
        if not isinstance(symint, torch.SymInt):
            return
        expr = symint.node.expr
        if isinstance(expr, sympy.Symbol):
            self.sym_to_dsinfo[expr] = DSInfo(node_debug_str, ubval)
            if self.shape_env is None:
                self.shape_env = symint.node.shape_env
            else:
                assert symint.node.shape_env is self.shape_env

    def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
        output = super().placeholder(name, arg, meta)
        # TODO: handle full dynamic
        if (
            self.mode == DynamicMemoryPlanningMode.UPPER_BOUND
            and meta.data.get("spec", None) is not None
            and meta.data.get("val", None) is not None
        ):
            spec = meta.data["spec"]
            val = meta.data["val"]
            if not isinstance(val, FakeTensor):
                return output

            if spec.shape_dynamism != TensorShapeDynamism.DYNAMIC_BOUND:
                return output

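            # Register an upper bound for each symbolic dim. For example
            # (hypothetical values), val.shape == (s0, 4) with
            # _upper_bound_shape == [16, 4] records s0 -> 16 in sym_to_dsinfo.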
            for sym, ubval in zip(val.shape, spec._upper_bound_shape):
                assert self.node_debug_str is not None
                self.add_symint_upperbound(self.node_debug_str, sym, ubval)
        return output

    def eval_symint_to_ubval(self, symint: torch.SymInt) -> int:
        return eval_expr(symint)

    def decide_upper_bound_from_symbols(self, meta):
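        # For example (hypothetical): a FakeTensor with shape (s0, 4), where
        # sym_to_dsinfo maps s0 to an upper bound of 16, is marked DYNAMIC_BOUND
        # with _upper_bound_shape == [16, 4]; a value whose shape has no tracked
        # free symbols is reset to STATIC.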
        with self.apply_upper_bounds():
            meta = meta.data
            if meta.get("val", None) is None or meta.get("spec", None) is None:
                return
            vallist, _ = pytree.tree_flatten(meta["val"])
            speclist, _ = pytree.tree_flatten(meta["spec"])
            for val, spec in zip(vallist, speclist):
                if not isinstance(val, FakeTensor) or not isinstance(spec, TensorSpec):
                    continue
                free_symbols = collect_free_symbols(val.shape)
                if len(free_symbols & set(self.sym_to_dsinfo.keys())) == 0:
                    spec.shape_dynamism = TensorShapeDynamism.STATIC
                    spec._upper_bound_shape = None
                    continue
                spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
                # evaluate the upper bound shape
                spec._upper_bound_shape = [
                    self.eval_symint_to_ubval(s) for s in val.shape
                ]

    def call_delegate(
        self,
        lowered_module: LoweredBackendModule,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """
        Override this method so we can properly calculate the dynamic shape
        information for the output of the delegate.
        """
        if self.mode == DynamicMemoryPlanningMode.UPPER_BOUND:
            self.decide_upper_bound_from_symbols(meta)
        else:
            raise RuntimeError("NYI: delegation support in full dynamic mode")
        return super().call_delegate(lowered_module, args, kwargs, meta)

    def call_operator(self, op, args, kwargs, meta):
        """
        If any of the arguments has dynamic shape, mark the output as dynamic shape.
        """

        # no need to do dynamic shape propagation for these ops
        if op.target in _EXECUTORCH_SYM_OPS:
            return super().call_operator(op, args, kwargs, meta)

        if self.mode == DynamicMemoryPlanningMode.UPPER_BOUND:
            self.decide_upper_bound_from_symbols(meta)
            return super().call_operator(op, args, kwargs, meta)

        ds_spec = calculate_dynamic_shape_spec(self.mode, op.target, args, kwargs)

        out_tensor_spec = meta["spec"]

        for ds_spec_item, tensor_spec_item in zip(
            pytree.tree_flatten(ds_spec)[0], pytree.tree_flatten(out_tensor_spec)[0]
        ):
            tensor_spec_item.shape_dynamism = ds_spec_item.shape_dynamism
            tensor_spec_item._upper_bound_shape = ds_spec_item.upper_bound_shape
        return super().call_operator(op, args, kwargs, meta)