# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
This module contains tools for rewriting a dynamic PyTorch program such
that the dynamic part (e.g. control flow) can be properly captured by
DispatchTracer.
The core idea is to annotate all branches in the graph with unique keys,
and to use a dictionary of supplemental inputs as arguments to these
local branches so that every path gets a canonical input during tracing.

For example, consider the following usage of a Python if statement:

.. code-block:: python

    if pred:
        ...
        ret = a
    else:
        ...
        ret = b

To rewrite the code so that it is traceable, users may use the
tracing_context decorator and the cond operator:

.. code-block:: python

    @control_flow.tracing_context(inputs)
    def branch_true(args):
        ...
        return a

    @control_flow.tracing_context(inputs)
    def branch_false(args):
        ...
        return b

    ret = control_flow.cond(pred, branch_true, branch_false, args)

and then capture the module with the usual exir.capture() function:

.. code-block:: python

    exir.capture(module, args)

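Loops can be rewritten in the same spirit using the while_loop operator
defined in this module. A minimal sketch (the loop state, threshold, and
function names are illustrative only):

.. code-block:: python

    @control_flow.tracing_context(inputs)
    def cond_fn(val):
        return val.sum() < 10

    @control_flow.tracing_context(inputs)
    def body_fn(val):
        return (val + 1,)

    ret = control_flow.while_loop(cond_fn, body_fn, (val,))
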
"""

from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from executorch.exir.error import ExportError, ExportErrorType, internal_assert
from executorch.exir.tracer import (
    DispatchTracer,
    flattened_dispatch_trace,
    PythonTensor,
    tree_return,
    unwrap_functional,
    unwrap_proxy,
    using_tracer,
    Value,
)
from executorch.exir.wrap import update_with_proxy


def shape(x: torch.Tensor) -> Union[torch._C.Size, torch.Tensor]:
    """
    A helper function for capturing the shape as a tensor from a tensor
    value.
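
    Outside of tracing this simply returns ``x.shape``. Under DispatchTracer
    it returns a 1-D int64 tensor backed by an ``aten._shape_as_tensor``
    proxy, so the shape stays in the traced graph. A minimal, illustrative
    use:

    .. code-block:: python

        sz = control_flow.shape(x)  # 1-D int64 tensor holding x's sizes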
    """
    tracer = DispatchTracer.get()
    if tracer is None:
        return x.shape
    x = unwrap_functional(x)
    if not isinstance(x, PythonTensor):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"exir custom shape function only takes an EXIR dispatch tensor, but got: {type(x)}",
        )
    # TODO _shape_as_tensor should work with functional tensors but currently does not.
    # TODO torch.tensor() should succeed under functionalization but currently does not.
    #      see: https://github.com/pytorch/pytorch/pull/76319
    tmp = torch.empty(len(x.shape), dtype=torch.int64)
    for i, s in enumerate(x.shape):
        tmp[i] = s
    proxy = torch.ops.aten._shape_as_tensor.default(x.proxy)
    return PythonTensor(unwrap_functional(tmp), proxy)


def _make_submodule(
    fn: Callable[..., Union[torch.Tensor, Tuple[torch.Tensor]]],
    example_returns: Optional[List[torch.Tensor]] = None,
    single_return: bool = False,
) -> torch.fx.GraphModule:
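    """
    Dispatch-trace ``fn`` into a ``torch.fx.GraphModule`` using the example
    inputs recorded by the ``tracing_context`` decorator
    (``fn.__tracing_inputs__``), so that the result can be passed as a
    submodule to a higher-order op such as ``while_loop``. When
    ``single_return`` is set, the traced graph is required to return exactly
    one value.
    """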
    if not hasattr(fn, "__tracing_inputs__"):
        raise ExportError(
            ExportErrorType.MISSING_PROPERTY,
            f"Expected function '{fn.__name__}' to be decorated with tracing_context.",
        )
    # pyre-ignore
    args = fn.__tracing_inputs__
    # TODO(yidi): functionalization is not enabled here because this code path
    # will not be used going forward.
    gm, _ = flattened_dispatch_trace(fn, args, set(), enable_functionalization=False)
    output = next(iter(reversed(gm.graph.nodes)))
    if example_returns:
        internal_assert(
            len(example_returns) == len(output.args[0]),
            f"Eager mode of {gm} returns {len(example_returns)} elements, but the traced graph returns {len(output.args[0])} elements",
        )

    if single_return:
        # Force the number of returned values to be 1.
        internal_assert(
            len(output.args[0]) == 1,
            f"Graph {gm} should return just one element, but got {len(output.args[0])}",
        )
        output.args = tuple(output.args[0])
        gm.recompile()
    # pyre-fixme[16]: `GraphModule` has no attribute `__tracing_inputs__`.
    gm.__tracing_inputs__ = args
    return gm


def while_loop(
    cond_fn: Callable[..., torch.Tensor],
    body_fn: Callable[..., Tuple[torch.Tensor]],
    init_val: pytree.PyTree,
) -> Union[Tuple[torch.Tensor], Value]:
    """
    A higher-order function that repeatedly executes body_fn on the loop state
    until cond_fn returns False, and returns the final state.
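
    Both ``cond_fn`` and ``body_fn`` receive the loop state unpacked as
    positional arguments, and both must be decorated with ``tracing_context``
    so that they can be traced into submodules. A minimal sketch (the state,
    threshold, and names are illustrative only):

    .. code-block:: python

        init = (torch.zeros(1),)

        @control_flow.tracing_context(init)
        def cond_fn(x):
            return x.sum() < 10

        @control_flow.tracing_context(init)
        def body_fn(x):
            return (x + 1,)

        out = control_flow.while_loop(cond_fn, body_fn, init)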
    """
    flattened_inputs, _ = pytree.tree_flatten(init_val)
    if not all(isinstance(i, torch.Tensor) for i in flattened_inputs):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"control_flow.while_loop() expects all input values to be tensors, actual inputs: {init_val}",
        )

    # Run the Python loop eagerly under using_tracer(None) to compute the
    # concrete result.
    with using_tracer(None):
        val = init_val
        while cond_fn(*val):
            val = body_fn(*val)

    flattened_outputs, _ = pytree.tree_flatten(val)
    if not all(isinstance(o, torch.Tensor) for o in flattened_outputs):
        raise ExportError(
            ExportErrorType.INVALID_OUTPUT_TYPE,
            f"control_flow.while_loop() expects all returned values to be tensors, actual outputs: {val}",
        )

    tracer = DispatchTracer.get()

    # Not tracing: just return the eagerly computed result.
    if tracer is None:
        return val

    # Tracing: trace cond_fn and body_fn into submodules and emit a single
    # while_loop call_function node whose proxy stands in for the loop result.
    gm_cond = _make_submodule(cond_fn, single_return=True)
    gm_body = _make_submodule(body_fn)

    proxies = tuple([unwrap_proxy(v) for v in flattened_inputs])

    proxy = tracer.create_proxy(
        "call_function",
        while_loop,
        (gm_cond, gm_body, proxies),
        {},
    )

    return tree_return(val, proxy, update_with_proxy)


def tracing_context(
    inputs: Tuple[torch.Tensor, ...],
) -> Callable[..., Callable[..., Union[torch.Tensor, Tuple[torch.Tensor]]]]:
    """
    A decorator used to annotate code paths that are conditionally run during
    tracing. We need to annotate these paths for now because during
    exir.capture(), the tracer does not know the proper local inputs to pass
    to the untaken path.
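
    A minimal, illustrative sketch (the example input is arbitrary):

    .. code-block:: python

        @control_flow.tracing_context((torch.randn(2, 3),))
        def branch(x):
            return x.relu()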
    """

    def decorator(
        f: Callable[..., Tuple[torch.Tensor]]
    ) -> Callable[..., Union[torch.Tensor, Tuple[torch.Tensor]]]:
        def wrapper(
            *args: torch.Tensor, **kwargs: Tuple[torch.Tensor]
        ) -> Tuple[torch.Tensor]:
            if kwargs:
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    "kwargs are not supported for @tracing_context decorated functions.",
                )

            return f(*args)

        wrapper.__tracing_inputs__ = inputs  # pyre-ignore
        return wrapper

    return decorator