# mypy: allow-untyped-defs
# ruff: noqa: TCH004
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from . import trace_rules, variables
from .comptime import comptime
from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
from .exc import IncorrectUsage
from .external_utils import is_compiling

if TYPE_CHECKING:
    from torch._C._dynamo.eval_frame import (  # noqa: F401
        reset_code,
        set_eval_frame,
        set_guard_error_hook,
        skip_code,
        unsupported,
    )
else:
    for name in dir(torch._C._dynamo.eval_frame):
        if name.startswith("__"):
            continue
        globals()[name] = getattr(torch._C._dynamo.eval_frame, name)


def run(fn=None):
    """Don't do any dynamic compiles, just use prior optimizations"""
    if fn is not None:
        fn = innermost_fn(fn)
        assert callable(fn)
        return RunOnlyContext()(fn)
    return RunOnlyContext()


def disable(fn=None, recursive=True):
    """
    Decorator and context manager to disable TorchDynamo.

    If recursive=True, Dynamo is completely skipped on the decorated function
    frame as well as on recursively invoked functions.

    If recursive=False, Dynamo skips frames associated with the function code,
    but still processes recursively invoked frames.
    """
    if recursive:
        if fn is not None:
            fn = innermost_fn(fn)
            assert callable(fn)
            return DisableContext()(fn)
        return DisableContext()
    else:
        return skip(fn)


def skip(fn=None):
    """
    Skip frames associated with the function code, but still process recursively
    invoked frames.
    """
    if fn is None:
        return skip
    fn = innermost_fn(fn)
    assert callable(fn)
    skip_code(fn.__code__)
    fn._torchdynamo_disable = True
    return fn


def assume_constant_result(fn):
    fn._dynamo_marked_constant = True
    return fn


def allow_in_graph(fn):
    """
    Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
    and instead directly write it to the graph when encountered.

    See :func:`torch.compiler.allow_in_graph`'s docstring for the full documentation.

    WARNING: this API can be a footgun; please read the documentation carefully.
    """
    if isinstance(fn, (list, tuple)):
        return [allow_in_graph(x) for x in fn]
    assert callable(fn), "allow_in_graph expects a callable"
    if trace_rules.lookup_callable(fn) != variables.TorchInGraphFunctionVariable:
        trace_rules._disallowed_callable_ids.remove(id(fn))
        trace_rules._allowed_callable_ids.add(id(fn))
    return fn


def _disallow_in_graph_helper(throw_if_not_allowed):
    def inner(fn):
        if isinstance(fn, (list, tuple)):
            return [disallow_in_graph(x) for x in fn]
        assert callable(fn), "disallow_in_graph expects a callable"
        if (
            throw_if_not_allowed
            and trace_rules.lookup_callable(fn)
            != variables.TorchInGraphFunctionVariable
            and trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable
        ):
            raise IncorrectUsage(
                "disallow_in_graph is expected to be used on an already allowed callable (like torch.* ops). "
                "Allowed callables means callables that TorchDynamo puts as-is in the extracted graph."
            )
        trace_rules._allowed_callable_ids.remove(id(fn))
        trace_rules._disallowed_callable_ids.add(id(fn))
        return fn

    return inner


def disallow_in_graph(fn):
    """
    Customize which functions TorchDynamo will exclude from the generated
    graph and force a graph break on.
    ::

        torch._dynamo.disallow_in_graph(torch.sub)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = torch.sub(x, 1)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will break the graph on `torch.sub`, and give two graphs each with a
    single `torch.add()` op.
    """
    return _disallow_in_graph_helper(throw_if_not_allowed=True)(fn)


@_disallow_in_graph_helper(throw_if_not_allowed=False)
def graph_break():
    """Force a graph break"""
    pass


def forbid_in_graph(fn):
    """
    Customize which functions TorchDynamo will assert are not present while tracing.

    If you want a graph break on this function instead, use disallow_in_graph.
    TODO(voz): We now have allow_in_graph, disallow_in_graph, forbid_in_graph - some more robust
    documentation would not be amiss.
    """
    if isinstance(fn, (list, tuple)):
        return [forbid_in_graph(x) for x in fn]
    assert callable(fn), "forbid_in_graph applies only to callables"
    fn._dynamo_forbidden = True
    return fn


# Helper function to flatten a tensor subclass and apply a function to
# all inner tensors that match the outer dim. Used to reduce duplication
# across the various marking APIs.
def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs):
    assert is_traceable_wrapper_subclass(t)

    attrs, ctx = t.__tensor_flatten__()
    for attr in attrs:
        inner = getattr(t, attr)
        if inner.dim() == t.dim():
            func(inner, *args, **kwargs)


@dataclass(frozen=True)
class _DimRange:
    """
    This represents a dimension of a tensor and the corresponding
    min and max values it can take. Don't create this
    class directly; instead, use :func:`mark_dynamic`.
    """

    dim: int
    min: int
    max: int


@forbid_in_graph
def mark_dynamic(t, index, *, min=None, max=None):
    """
    Mark a tensor as having a dynamic dim and set the corresponding min and max range for that dim.

    [Note - on the state of mark_dynamic]

    The behavior of having a dynamic dimension on a tensor is governed by a few factors:

    1) torch._dynamo.config.dynamic_shapes True or False.
        a) dynamic_shapes=True - dynamic_shapes must be True for mark_dynamic to work.
        b) dynamic_shapes=False - This config will raise an exception when used in conjunction with
        mark_dynamic. We will eventually support this.

    2) If the dimension is fully constrained - as in, it does not allow more than a single value
    in both eager (torch.compile, torch._dynamo.optimize) mode and export mode (torch._dynamo.export) -
    we will raise an error.

    3) If the dimension is partially constrained - allowing at least 2 values but not the full unbounded
    range of shapes - in eager we will pass it through, but export will raise an error.

    4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made
    before torch.compile.
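
    A minimal usage sketch (illustrative; the exact compiled behavior also
    depends on the config factors above)::

        x = torch.randn(8, 16)
        # Let dim 0 of ``x`` vary between 2 and 1024 across calls.
        torch._dynamo.mark_dynamic(x, 0, min=2, max=1024)

        @torch.compile
        def fn(t):
            return t * 2

        fn(x)  # compiled with a symbolic size for dim 0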

    """
    if is_traceable_wrapper_subclass(t):
        # default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
        # TODO: Make this configurable via a supported public API
        _apply_func_to_inner_tensors_of_same_dim(
            mark_dynamic, t, index, min=min, max=max
        )

    if isinstance(index, int):
        if not hasattr(t, "_dynamo_dynamic_indices"):
            t._dynamo_dynamic_indices = set()
            t._dynamo_dynamic_range = set()
        # TODO(voz): Should we bounds check?
        t._dynamo_dynamic_indices.add(index)
        t._dynamo_dynamic_range.add(_DimRange(index, min, max))
        return

    assert isinstance(index, (list, tuple))
    for i in index:
        mark_dynamic(t, i, min=min, max=max)


@forbid_in_graph
def maybe_mark_dynamic(t, index):
    """
    Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this
    dimension ends up getting specialized, don't error).
    """
    if is_traceable_wrapper_subclass(t):
        # default behavior: mirror maybe_mark_dynamic() on all inner tensors with same dim as t
        # TODO: Make this configurable via a supported public API
        _apply_func_to_inner_tensors_of_same_dim(maybe_mark_dynamic, t, index)

    if isinstance(index, int):
        if not hasattr(t, "_dynamo_weak_dynamic_indices"):
            t._dynamo_weak_dynamic_indices = set()
        # TODO(voz): Should we bounds check?
        t._dynamo_weak_dynamic_indices.add(index)
        return

    assert isinstance(index, (list, tuple))
    for i in index:
        maybe_mark_dynamic(t, i)


def mark_static(t, index=None):
    """
    Mark a tensor as having a static dim.

    This will prevent us from attempting to compile it dynamically
    when dynamic=True; this can improve trace-time performance.

    This has lower precedence than mark_dynamic.

    Unlike mark_dynamic, this can be done inside a graph, in which case it
    induces specialization on the tensor.
    """
    if is_compiling():
        if index is None:
            for s in t.size():
                comptime.force_static(s)
        else:
            comptime.force_static(t.size(index))
        return

    if is_traceable_wrapper_subclass(t):
        # default behavior: mirror mark_static() on all inner tensors with same dim as t
        # TODO: Make this configurable via a supported public API
        _apply_func_to_inner_tensors_of_same_dim(mark_static, t, index)

    if isinstance(index, int):
        if not hasattr(t, "_dynamo_static_indices"):
            t._dynamo_static_indices = set()
        # TODO(voz): Should we bounds check?
        t._dynamo_static_indices.add(index)
    elif index is None:
        for i in range(t.dim()):
            mark_static(t, i)
    else:
        assert isinstance(index, (list, tuple))
        for i in index:
            mark_static(t, i)


@forbid_in_graph
def mark_static_address(t, guard=True):
    """
    Marks an input tensor whose data_ptr will not change across multiple calls
    to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation
    is not needed for this input. The data_ptr will be guarded if guard=True. Note:
    Tensors marked in this way will be kept alive until `torch._dynamo.reset()` is called.
    """
    if not isinstance(t, torch.Tensor):
        raise TypeError(f"mark_static_address expects a tensor but received {type(t)}")

    if guard:
        t._dynamo_static_input_type = "guarded"  # type: ignore[attr-defined]
    else:
        t._dynamo_static_input_type = "unguarded"  # type: ignore[attr-defined]


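# Usage sketch for mark_static_address (illustrative only; it assumes a CUDA
# tensor and a cudagraphs-enabled compile mode such as
# torch.compile(mode="reduce-overhead")):
#
#     buf = torch.zeros(32, device="cuda")
#     torch._dynamo.mark_static_address(buf)  # buf's data_ptr is treated as stable
#
#     @torch.compile(mode="reduce-overhead")
#     def step(x):
#         return x + buf
#
#     step(torch.randn(32, device="cuda"))

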
# Note: this carefully avoids eagerly importing einops.
# TODO: we should delete this whole _allow_in_graph_einops logic by approximately 2024 Q2
def _allow_in_graph_einops():
    import einops

    try:
        # requires einops > 0.6.1, torch >= 2.0
        from einops._torch_specific import (  # type: ignore[attr-defined]  # noqa: F401
            _ops_were_registered_in_torchdynamo,
        )

        # einops > 0.6.1 will call the op registration logic as it is imported.
        pass
    except ImportError:
        # einops <= 0.6.1
        allow_in_graph(einops.rearrange)
        allow_in_graph(einops.reduce)
        if hasattr(einops, "repeat"):
            allow_in_graph(einops.repeat)  # available since einops 0.2.0
        if hasattr(einops, "einsum"):
            allow_in_graph(einops.einsum)  # available since einops 0.5.0
        if hasattr(einops, "pack"):
            allow_in_graph(einops.pack)  # available since einops 0.6.0
        if hasattr(einops, "unpack"):
            allow_in_graph(einops.unpack)  # available since einops 0.6.0


trace_rules.add_module_init_func("einops", _allow_in_graph_einops)
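

# Putting several of the decorators above together from user code (an
# illustrative sketch, kept as a comment so that importing this module has no
# side effects; it assumes a torch build where torch.compile is available):
#
#     import torch
#     import torch._dynamo as dynamo
#
#     @dynamo.disable
#     def eager_only(x):
#         return x.sin() + 1  # Dynamo skips this frame (and its callees)
#
#     x = torch.randn(4, 8)
#     dynamo.mark_dynamic(x, 0, min=2, max=64)  # dim 0 becomes a symbolic size
#
#     @torch.compile
#     def fn(t):
#         y = t * 2
#         dynamo.graph_break()  # force a split into two graphs
#         return eager_only(y)
#
#     fn(x)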