xref: /aosp_15_r20/external/pytorch/torch/_export/passes/functionalize_side_effectful_ops_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import copy
2from typing import Dict, Optional, Tuple, List
3
4import torch
5from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument
6from torch._export.pass_infra.node_metadata import NodeMetadata
7from torch._export.pass_infra.proxy_value import ProxyValue
8from torch._ops import OpOverload
9
10aten = torch.ops.aten
11
12_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = {
13    aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
14    aten._assert_async.msg: aten._functional_assert_async.msg,
15}
16
17
18class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
19    """
20    Functionalize ops with side effect in graph module by replacing the op with
21    functional version of it. A new dependency token (`dep_token`) will be
22    created and propagated through functional ops to output.
23    For example:
24    ```
25    def f(x):
26        sym_constrain_range(x.shape[0], min=1, max=3)
27        return x.add(3)
28    ```
29    Will be transformed to:
30    ```
31    def f(x):
32        dep_token0 = _make_dep_token()
33        dep_token1 = _functional_sym_constrain_range(
34            x.shape[0], min=1, max=3, dep_token=dep_token0
35        )
36
37        return x.add(3), dep_token1
38    ```
39    """
40
41    def __init__(self) -> None:
42        super().__init__()
43        self._dep_token: Optional[ProxyValue] = None
44        self._next_dep_token_index: Optional[int] = None
45
46    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
47        # Early return if no non-functional assertions.
48        if not any(
49            n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS
50            for n in graph_module.graph.nodes
51        ):
52            return PassResult(graph_module=graph_module, modified=False)
53
54        gm = copy.deepcopy(graph_module)
55        self._dep_token = None
56        self._next_dep_token_index = None
57        return super().call(gm)
58
59    def call_operator(
60        self,
61        op: OpOverload,
62        args: Tuple[Argument, ...],
63        kwargs: Dict[str, Argument],
64        meta: NodeMetadata,
65    ) -> ProxyValue:
66        if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS:
67            return super().call_operator(op, args, kwargs, meta)
68
69        if self._dep_token is None:
70            self._dep_token = super().call_operator(
71                aten._make_dep_token,
72                args=(),
73                kwargs={},
74                meta=self._create_dummy_node_metadata(),
75            )
76            self._dep_token.node.name = "dep_token0"
77            self._next_dep_token_index = 1
78
79        self._dep_token = super().call_operator(
80            _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op],
81            args=args,
82            kwargs={**kwargs, "dep_token": self._dep_token},
83            meta=meta,
84        )
85        assert self._next_dep_token_index is not None
86        self._dep_token.node.name = f"dep_token{self._next_dep_token_index}"
87        self._next_dep_token_index += 1
88
89        return self._dep_token
90
91    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
92        assert self._dep_token is not None
93
94        return super().output(results=(*results, self._dep_token), meta=meta)  # type: ignore[arg-type]
95