# mypy: ignore-errors
import functools
import inspect
from typing import Dict, List, TYPE_CHECKING

import torch
from torch.fx.experimental._backward_state import BackwardState

from .. import compiled_autograd, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented
from ..external_utils import call_module_hooks_from_backward_state
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource
from ..utils import istype
from .base import VariableTracker
from .constant import ConstantVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


class DistributedVariable(VariableTracker):
    """
    The base distributed variable that encapsulates common methods
    for distributed objects (e.g. ProcessGroup, DeviceMesh, etc.).
    Concrete distributed objects can inherit this class and add
    object-specific logic.

    That is, it checks that the distributed package is available and
    holds the tracked value for the corresponding distributed object.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        if not DistributedVariable.is_available():
            unimplemented("torch.distributed package is not available!")
        self.value = value

    def python_type(self):
        return type(self.value)

    @staticmethod
    def is_available():
        # check if the distributed package is available or not
        return torch.distributed.is_available()


def is_from_local(value):
    if not DistributedVariable.is_available():
        return False
    from torch.distributed.tensor import DTensor

    return inspect.isfunction(value) and value is DTensor.from_local


def is_constant_pg_functions(value):
    if not DistributedVariable.is_available():
        return False

    from torch.distributed.distributed_c10d import (
        _get_group_size_by_name,
        _get_group_tag,
        _rank_not_in_group,
        _resolve_group_name_by_ranks_and_tag,
        get_process_group_ranks,
    )

    constant_processgroup_functions = [
        _get_group_size_by_name,
        _get_group_tag,
        _rank_not_in_group,
        get_process_group_ranks,
        _resolve_group_name_by_ranks_and_tag,
    ]

    return inspect.isfunction(value) and value in constant_processgroup_functions


class WorldMetaClassVariable(DistributedVariable):
    """
    Tracks torch.distributed.GroupMember and torch.distributed.group, which are
    instances of the metaclass _WorldMeta.
    """

    @classmethod
    def is_group_member_type(cls, value):
        if not cls.is_available():
            return False

        from torch.distributed.distributed_c10d import _WorldMeta

        return type(value) is _WorldMeta

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        if name == "WORLD":
            source = AttrSource(base=self.source, member="WORLD")
            install_guard(source.make_guard(GuardBuilder.ID_MATCH))
            return ProcessGroupVariable(self.value.WORLD)
        return super().var_getattr(tx, name)
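
# Illustrative sketch (not part of the original module): WorldMetaClassVariable above is
# what lets compiled user code reach the default process group through the _WorldMeta
# singletons. Assuming a default group has been initialized, something like the following
# specializes `pg.size()` to a Python constant via ProcessGroupVariable instead of putting
# the ProcessGroup object into the output graph:
#
#     import torch
#     import torch.distributed as dist
#
#     @torch.compile
#     def scale_by_world(x):
#         pg = dist.group.WORLD  # WorldMetaClassVariable.var_getattr -> ProcessGroupVariable
#         return x / pg.size()   # folded to a constant by ProcessGroupVariable.call_method
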

class PlacementClassVariable(DistributedVariable):
    @staticmethod
    def is_placement_type(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed.tensor.placement_types import Placement

        return type(value) is type and issubclass(value, Placement)

    def as_python_constant(self):
        return self.value

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if (
            inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
            and self.source
        ):
            # NOTE: we don't need to track mutations to the placement class as they
            # are supposed to be immutable.
            new_obj = object.__new__(self.value)
            var = PlacementVariable(new_obj)
            if inspect.getattr_static(self.value, "__init__", None):
                var.call_method(tx, "__init__", args, kwargs)
            return var

        return super().call_function(tx, args, kwargs)


class PlacementVariable(DistributedVariable):
    @staticmethod
    def is_placement(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed.tensor.placement_types import Placement

        return isinstance(value, Placement)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        if name == "dim":
            return ConstantVariable.create(self.value.dim)
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import ConstantVariable

        # Dynamo tracking of Placement types only allows the following methods;
        # __setattr__ is included for cases like `Shard(dim)`.
        # Methods in the list must satisfy:
        # 1. their input arguments are constants and do not need to be guarded on;
        # 2. their output is constant with respect to their inputs.
        constant_fold_functions = [
            "__init__",
            "__setattr__",
            "is_shard",
            "is_partial",
            "is_replicate",
        ]

        if name in constant_fold_functions:
            try:
                value_type = type(self.value)
                assert (
                    inspect.getattr_static(value_type, "__getattr__", None) is None
                ), "no custom getattr allowed!"
                method = inspect.getattr_static(value_type, name)
            except AttributeError:
                method = None
            if method is object.__init__:
                return ConstantVariable.create(None)

            args = [x.as_python_constant() for x in args]
            kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            if name == "__setattr__":
                method(self.value, *args, **kwargs)
                return self
            constant_val = method(self.value, *args, **kwargs)
            return ConstantVariable.create(constant_val)

        return super().call_method(tx, name, args, kwargs)
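
# Illustrative sketch (not part of the original module): the constant-fold path in
# PlacementVariable.call_method above is what allows placement-inspecting user code to
# trace without graph breaks. Assuming a regular torch.distributed build, something along
# these lines compiles with `p.is_shard()` evaluated at trace time:
#
#     import torch
#     from torch.distributed.tensor.placement_types import Shard
#
#     @torch.compile(fullgraph=True)
#     def branch_on_placement(x):
#         p = Shard(0)      # PlacementClassVariable.call_function builds a PlacementVariable
#         if p.is_shard():  # folded to a constant via constant_fold_functions
#             return x + 1
#         return x - 1
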

class DeviceMeshVariable(DistributedVariable):
    @staticmethod
    def is_device_mesh(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed.device_mesh import DeviceMesh

        return istype(value, DeviceMesh)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        if name == "ndim":
            return ConstantVariable.create(self.value.ndim)
        if name == "device_type":
            return ConstantVariable.create(self.value.device_type)
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "size":
            const_args = [x.as_python_constant() for x in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            return ConstantVariable.create(self.value.size(*const_args, **const_kwargs))
        if name == "get_coordinate":
            return ConstantVariable.create(self.value.get_coordinate())
        if name == "get_group":
            return ConstantVariable.create(self.value.get_group())
        if name == "_get_or_create_default_group":
            return ProcessGroupVariable(self.value._get_or_create_default_group())
        return super().call_method(tx, name, args, kwargs)
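
# Illustrative sketch (not part of the original module): DeviceMeshVariable specializes on
# mesh metadata, so reads like the ones below become Python constants at trace time. The
# mesh construction is assumed to happen outside the compiled region, and `world_size` is
# a placeholder for the size of the job:
#
#     import torch
#     from torch.distributed.device_mesh import init_device_mesh
#
#     mesh = init_device_mesh("cuda", (world_size,))
#
#     @torch.compile
#     def f(x):
#         return x + mesh.ndim + mesh.size(0)  # both fold to constants via DeviceMeshVariable
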

class ProcessGroupVariable(DistributedVariable):
    """
    We don't want a ProcessGroup object to end up in our output graph.

    But it's common for dynamo to intercept a PG that is then used to get info like
    rank() or world_size(), or passed to utility functions in distributed_c10d
    which desugar it into plain types like a ranklist and tag.

    For convenience and proper guarding, we construct a variable type.

    TODO: make it possible to use ProcessGroupVariable as input to simple functions
          like _expand_group without dynamo complaining about making a proxy for it.
          It is not a tensor-like type, and we don't want a proxy, but dynamo assumes
          torch library functions are dealing with tensor-like types and would have
          proxies for their args.
    TODO: should we make this inherit VT instead of UDOV? Do we want any of the default
          behaviors, or just graph-break whenever one of our special cases is not hit?
    """

    def as_python_constant(self):
        return self.value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "rank":
            return variables.ConstantVariable.create(self.value.rank())
        if name == "size":
            return variables.ConstantVariable.create(self.value.size())
        if name == "_get_backend_name":
            return variables.ConstantVariable.create(self.value._get_backend_name())

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx: "InstructionTranslator", name):
        if name == "group_name":
            return variables.ConstantVariable.create(self.value.group_name)
        if name in ["rank", "size"]:
            return variables.LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        # TODO should this just raise unimplemented?
        return super().var_getattr(tx, name)

    @staticmethod
    def is_process_group(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False
        from torch._C._distributed_c10d import ProcessGroup
        from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

        return istype(value, (ProcessGroup, FakeProcessGroup))


class BackwardHookVariable(VariableTracker):
    """
    Handles torch.utils.hooks.BackwardHook for module-level backward
    hooks.
    """

    @staticmethod
    def create(
        tx,
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
    ):
        if not compiled_autograd.compiled_autograd_enabled:
            unimplemented("module-level backwards hooks require compiled autograd")

        def _in_graph_bw_hooks(bw_state: BackwardState):
            """
            Rather than installing the user hooks in the graph (which
            don't survive AotAutograd), we install hooks that call
            trace_wrapped in the backward pass, which CompiledAutograd
            can turn into actual hook calls.
            """
            return torch.utils.hooks.BackwardHook(
                None,
                (
                    functools.partial(
                        trace_wrapped,
                        fn=call_module_hooks_from_backward_state,
                        bw_state=bw_state,
                        hooks_name=user_hooks_name,
                        module_name=module_name,
                    ),
                ),
                (
                    functools.partial(
                        trace_wrapped,
                        fn=call_module_hooks_from_backward_state,
                        bw_state=bw_state,
                        hooks_name=user_pre_hooks_name,
                        module_name=module_name,
                    ),
                ),
            )

        module_name, bw_state_proxy = tx.output.add_backward_state_hook(module, "mod")
        user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks)
        user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks)
        proxy = tx.output.create_proxy(
            "call_function",
            _in_graph_bw_hooks,
            (bw_state_proxy,),
            {},
        )
        proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(None, (), ())
        return BackwardHookVariable(proxy, module, user_hooks, user_pre_hooks)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
        **options,
    ) -> None:
        super().__init__(**options)
        self.proxy = proxy
        self.module = module
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks

    def as_proxy(self):
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if name in ("setup_input_hook", "setup_output_hook"):
            return self._setup_hook(tx, name, *args, **kwargs)
        return super().call_method(tx, name, args, kwargs)

    def _setup_hook(self, tx: "InstructionTranslator", hook_method_name, args):
        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                hook_method_name,
                (self.as_proxy(), args.as_proxy()),
                {},
            ),
        )
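
# Illustrative sketch (not part of the original module): BackwardHookVariable.create graph
# breaks unless compiled autograd is enabled, so module-level backward hooks are only
# captured under a setup roughly like the one below (using torch.compile as the compiled
# autograd compiler is just one possible choice):
#
#     import torch
#
#     mod = torch.nn.Linear(4, 4)
#     mod.register_full_backward_hook(lambda m, grad_in, grad_out: None)
#     opt_mod = torch.compile(mod)
#
#     x = torch.randn(2, 4, requires_grad=True)
#     with torch._dynamo.compiled_autograd.enable(torch.compile):
#         opt_mod(x).sum().backward()  # hooks are re-dispatched via trace_wrapped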