# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
This module contains tools for rewriting a dynamic PyTorch program such
that the dynamic part (e.g. control flow) can be properly captured by
DispatchTracer.
The core idea is to annotate every branch with its own supplemental example
inputs, which are used as arguments to that local branch so that every path
(including the one not taken at runtime) gets a canonical input during
tracing.

For example, consider the following usage of Python if statement:

.. code-block:: python

    if pred:
        ...
        ret = a
    else:
        ...
        ret = b

To rewrite the code to be traceable, users may use the tracing_context
decorator and the cond operator:

.. code-block:: python

    @control_flow.tracing_context(inputs)
    def branch_true(args):
        ...
        return a

    @control_flow.tracing_context(inputs)
    def branch_false(args):
        ...
        return b

    ret = control_flow.cond(pred, branch_true, branch_false, args)

and we can use the usual exir.capture() function.

.. code-block:: python

    exir.capture(module, args)

"""

from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from executorch.exir.error import ExportError, ExportErrorType, internal_assert
from executorch.exir.tracer import (
    DispatchTracer,
    flattened_dispatch_trace,
    PythonTensor,
    tree_return,
    unwrap_functional,
    unwrap_proxy,
    using_tracer,
    Value,
)
from executorch.exir.wrap import update_with_proxy


def shape(x: torch.Tensor) -> Union[torch._C.Size, torch.Tensor]:
    """
    A helper function for capturing the shape as a tensor from a tensor
    value.

    Args:
        x: The tensor whose shape is captured. While a DispatchTracer is
            active, it must resolve to an EXIR ``PythonTensor`` after
            ``unwrap_functional``.

    Returns:
        ``x.shape`` unchanged when no DispatchTracer is active. Otherwise a
        ``PythonTensor`` whose value is an int64 tensor holding the sizes of
        ``x`` and whose proxy is the result of applying
        ``aten._shape_as_tensor`` to ``x.proxy``.

    Raises:
        ExportError: ``INVALID_INPUT_TYPE`` if, while tracing, ``x`` is not a
            ``PythonTensor`` after ``unwrap_functional``.
    """
    tracer = DispatchTracer.get()
    if tracer is None:
        # Not tracing: plain eager behavior.
        return x.shape
    x = unwrap_functional(x)
    if not isinstance(x, PythonTensor):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"exir custom shape function only takes EXIR dispatch tensor, but got: {type(x)}",
        )
    # TODO _shape_as_tensor should work with functional tensor but currently not.
    # TODO torch.tensor() should succeed under functionalization but currently not.
    # see: https://github.com/pytorch/pytorch/pull/76319
    # Build the concrete shape tensor element by element as a workaround for
    # the limitations noted above.
    tmp = torch.empty(len(x.shape), dtype=torch.int64)
    for i, s in enumerate(x.shape):
        tmp[i] = s
    proxy = torch.ops.aten._shape_as_tensor.default(x.proxy)
    return PythonTensor(unwrap_functional(tmp), proxy)


def _make_submodule(
    fn: Callable[..., Union[torch.Tensor, Tuple[torch.Tensor]]],
    example_returns: Optional[List[torch.Tensor]] = None,
    single_return: bool = False,
) -> torch.fx.GraphModule:
    """
    Trace a ``tracing_context``-decorated function into a GraphModule.

    Args:
        fn: Function decorated with ``tracing_context``. Its
            ``__tracing_inputs__`` are used as the example inputs for tracing.
        example_returns: If not None, the values ``fn`` returned in eager
            mode. The traced graph must return the same number of elements.
        single_return: If True, the traced graph must return exactly one
            element. Its output node is then rewritten to return that element
            directly instead of a one-element sequence.

    Returns:
        The traced GraphModule, with ``fn.__tracing_inputs__`` copied onto it
        as ``__tracing_inputs__``.

    Raises:
        ExportError: ``MISSING_PROPERTY`` if ``fn`` was not decorated with
            ``tracing_context``.
    """
    if not hasattr(fn, "__tracing_inputs__"):
        # Not every callable has ``__name__`` (e.g. functools.partial). Fall
        # back to its repr so the intended ExportError is raised rather than
        # an AttributeError.
        fn_name = getattr(fn, "__name__", repr(fn))
        raise ExportError(
            ExportErrorType.MISSING_PROPERTY,
            f"Expect function '{fn_name}' to be decorated with tracing_context.",
        )
    # pyre-ignore
    args = fn.__tracing_inputs__
    # TODO(yidi): we don't want to enable here because we are not gonna use this code path in the future anyways
    gm, _ = flattened_dispatch_trace(fn, args, set(), enable_functionalization=False)
    # An FX graph ends with its ``output`` node. ``output.args[0]`` holds the
    # returned values. NOTE(review): this assumes a flat sequence, since the
    # trace is "flattened"; confirm against flattened_dispatch_trace.
    output = next(iter(reversed(gm.graph.nodes)))
    # Compare against None explicitly, so that an empty ``example_returns``
    # still verifies that the graph returns zero elements.
    if example_returns is not None:
        internal_assert(
            len(example_returns) == len(output.args[0]),
            f"Eager mode of this {gm} returns {len(example_returns)} elements, but this graph returns {len(output.args[0])} elements",
        )

    if single_return:
        # Force number of returned value to be 1.
        internal_assert(
            len(output.args[0]) == 1,
            f"Graph {gm} should return just one element, but got {len(output.args[0])}",
        )
        # ``output.args`` changes from ``([value],)`` to ``(value,)``, so the
        # graph now returns the single value itself. Recompile to regenerate
        # the module's code from the edited graph.
        output.args = tuple(output.args[0])
        gm.recompile()
    # pyre-fixme[16]: `GraphModule` has no attribute `__tracing_inputs__`.
    gm.__tracing_inputs__ = args
    return gm


def while_loop(
    cond_fn: Callable[..., torch.Tensor],
    body_fn: Callable[..., Tuple[torch.Tensor]],
    init_val: pytree.PyTree,
) -> Union[Tuple[torch.Tensor], Value]:
    """
    A higher order function returning the result based on executing body_fn
    until cond_fn returns False.

    The loop is always run eagerly, with the tracer disabled, to compute the
    concrete result. If a DispatchTracer is active, ``cond_fn`` and
    ``body_fn`` are also traced into submodules from their
    ``tracing_context`` inputs, and a single ``call_function`` node
    targeting ``while_loop`` is recorded.

    Args:
        cond_fn: Called as ``cond_fn(*val)``. The loop continues while its
            result is truthy. When tracing, it must be decorated with
            ``tracing_context``.
        body_fn: Called as ``body_fn(*val)``. Its result becomes the next
            ``val``. When tracing, it must be decorated with
            ``tracing_context``.
        init_val: Initial loop state. Every leaf must be a tensor. It is
            unpacked positionally, so it is presumably a tuple of tensors.

    Returns:
        The final loop state when not tracing. Otherwise, the result of
        ``tree_return`` applied to the final state and the recorded proxy.

    Raises:
        ExportError: ``INVALID_INPUT_TYPE`` if a leaf of ``init_val`` is not
            a tensor, or ``INVALID_OUTPUT_TYPE`` if a leaf of the final loop
            state is not a tensor.
    """
    flattened_inputs, _ = pytree.tree_flatten(init_val)
    if not all(isinstance(i, torch.Tensor) for i in flattened_inputs):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"control_flow.while_loop() expects all inputs values to be tensors, actual inputs: {init_val}",
        )

    # Run the loop eagerly with tracing disabled to get the concrete result.
    with using_tracer(None):
        val = init_val
        while cond_fn(*val):
            val = body_fn(*val)

    flattened_outputs, _ = pytree.tree_flatten(val)
    if not all(isinstance(o, torch.Tensor) for o in flattened_outputs):
        raise ExportError(
            ExportErrorType.INVALID_OUTPUT_TYPE,
            f"control_flow.while_loop() expects all returned values to be tensors, actual outputs: {val}",
        )

    tracer = DispatchTracer.get()

    if tracer is None:
        return val

    gm_cond = _make_submodule(cond_fn, single_return=True)
    gm_body = _make_submodule(body_fn)

    proxies = tuple(unwrap_proxy(v) for v in flattened_inputs)

    proxy = tracer.create_proxy(
        "call_function",
        while_loop,
        (gm_cond, gm_body, proxies),
        {},
    )

    return tree_return(val, proxy, update_with_proxy)


def tracing_context(
    inputs: Tuple[torch.Tensor, ...],
) -> Callable[..., Callable[..., Union[torch.Tensor, Tuple[torch.Tensor]]]]:
    """
    A decorator function to annotate code paths that are conditionally run
    during tracing. We need to annotate these paths for now because during
    exir.capture(), the tracer does not know the proper local inputs to be
    passed to the untaken path.

    Args:
        inputs: Example inputs stored on the decorated function as
            ``__tracing_inputs__``. ``_make_submodule`` traces the function
            with them.

    Returns:
        A decorator that wraps ``f`` in a function that forwards positional
        arguments to ``f`` and rejects keyword arguments.
    """

    def decorator(
        f: Callable[..., Tuple[torch.Tensor]]
    ) -> Callable[..., Union[torch.Tensor, Tuple[torch.Tensor]]]:
        def wrapper(
            *args: torch.Tensor, **kwargs: Tuple[torch.Tensor]
        ) -> Tuple[torch.Tensor]:
            # Only positional arguments are forwarded to ``f``.
            if kwargs:
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    "kwargs are not supported for @tracing_context decorated functions.",
                )

            return f(*args)

        wrapper.__tracing_inputs__ = inputs  # pyre-ignore
        return wrapper

    return decorator