1import copy 2from typing import Dict, Optional, Tuple, List 3 4import torch 5from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument 6from torch._export.pass_infra.node_metadata import NodeMetadata 7from torch._export.pass_infra.proxy_value import ProxyValue 8from torch._ops import OpOverload 9 10aten = torch.ops.aten 11 12_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = { 13 aten.sym_constrain_range.default: aten._functional_sym_constrain_range, 14 aten._assert_async.msg: aten._functional_assert_async.msg, 15} 16 17 18class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse): 19 """ 20 Functionalize ops with side effect in graph module by replacing the op with 21 functional version of it. A new dependency token (`dep_token`) will be 22 created and propagated through functional ops to output. 23 For example: 24 ``` 25 def f(x): 26 sym_constrain_range(x.shape[0], min=1, max=3) 27 return x.add(3) 28 ``` 29 Will be transformed to: 30 ``` 31 def f(x): 32 dep_token0 = _make_dep_token() 33 dep_token1 = _functional_sym_constrain_range( 34 x.shape[0], min=1, max=3, dep_token=dep_token0 35 ) 36 37 return x.add(3), dep_token1 38 ``` 39 """ 40 41 def __init__(self) -> None: 42 super().__init__() 43 self._dep_token: Optional[ProxyValue] = None 44 self._next_dep_token_index: Optional[int] = None 45 46 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: 47 # Early return if no non-functional assertions. 48 if not any( 49 n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS 50 for n in graph_module.graph.nodes 51 ): 52 return PassResult(graph_module=graph_module, modified=False) 53 54 gm = copy.deepcopy(graph_module) 55 self._dep_token = None 56 self._next_dep_token_index = None 57 return super().call(gm) 58 59 def call_operator( 60 self, 61 op: OpOverload, 62 args: Tuple[Argument, ...], 63 kwargs: Dict[str, Argument], 64 meta: NodeMetadata, 65 ) -> ProxyValue: 66 if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: 67 return super().call_operator(op, args, kwargs, meta) 68 69 if self._dep_token is None: 70 self._dep_token = super().call_operator( 71 aten._make_dep_token, 72 args=(), 73 kwargs={}, 74 meta=self._create_dummy_node_metadata(), 75 ) 76 self._dep_token.node.name = "dep_token0" 77 self._next_dep_token_index = 1 78 79 self._dep_token = super().call_operator( 80 _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op], 81 args=args, 82 kwargs={**kwargs, "dep_token": self._dep_token}, 83 meta=meta, 84 ) 85 assert self._next_dep_token_index is not None 86 self._dep_token.node.name = f"dep_token{self._next_dep_token_index}" 87 self._next_dep_token_index += 1 88 89 return self._dep_token 90 91 def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue: 92 assert self._dep_token is not None 93 94 return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type] 95