# mypy: ignore-errors
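"""Dynamo VariableTracker implementations for torch.distributed objects
(ProcessGroup, DeviceMesh, DTensor placements) and module-level backward
hooks."""
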
import functools
import inspect
from typing import Dict, List, TYPE_CHECKING

import torch
from torch.fx.experimental._backward_state import BackwardState

from .. import compiled_autograd, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented
from ..external_utils import call_module_hooks_from_backward_state
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource
from ..utils import istype
from .base import VariableTracker
from .constant import ConstantVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


class DistributedVariable(VariableTracker):
    """
    The base distributed variable that encapsulates common methods
    for distributed objects (e.g. ProcessGroup, DeviceMesh, etc.).
    Concrete distributed objects can inherit this class and add
    object-specific logic.

    In particular, it checks that the distributed package exists
    and holds the tracked value for the corresponding distributed object.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        if not DistributedVariable.is_available():
            unimplemented("torch.distributed package is not available!")
        self.value = value

    def python_type(self):
        return type(self.value)

    @staticmethod
    def is_available():
        # check if the distributed package is available or not
        return torch.distributed.is_available()


def is_from_local(value):
    if not DistributedVariable.is_available():
        return False
    from torch.distributed.tensor import DTensor

    return inspect.isfunction(value) and value is DTensor.from_local


def is_constant_pg_functions(value):
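    """
    Return True if ``value`` is one of the c10d helpers whose result is
    constant for a given process group, so Dynamo can evaluate it at trace
    time instead of putting the call in the graph.
    """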
    if not DistributedVariable.is_available():
        return False

    from torch.distributed.distributed_c10d import (
        _get_group_size_by_name,
        _get_group_tag,
        _rank_not_in_group,
        _resolve_group_name_by_ranks_and_tag,
        get_process_group_ranks,
    )

    constant_processgroup_functions = [
        _get_group_size_by_name,
        _get_group_tag,
        _rank_not_in_group,
        get_process_group_ranks,
        _resolve_group_name_by_ranks_and_tag,
    ]

    return inspect.isfunction(value) and value in constant_processgroup_functions


class WorldMetaClassVariable(DistributedVariable):
    """
    Tracks torch.distributed.GroupMember and torch.distributed.group, which are
    instances of the metaclass _WorldMeta.
    """

    @classmethod
    def is_group_member_type(cls, value):
        if not cls.is_available():
            return False

        from torch.distributed.distributed_c10d import _WorldMeta

        return type(value) is _WorldMeta

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        if name == "WORLD":
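            # Guard on the identity of GroupMember.WORLD so the compiled code
            # is invalidated if the default process group object changes.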
            source = AttrSource(base=self.source, member="WORLD")
            install_guard(source.make_guard(GuardBuilder.ID_MATCH))
            return ProcessGroupVariable(self.value.WORLD)
        return super().var_getattr(tx, name)


class PlacementClassVariable(DistributedVariable):
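    """
    Tracks Placement subclasses (e.g. Shard, Replicate) so that calling the
    class to construct a placement instance can be handled during tracing.
    """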
    @staticmethod
    def is_placement_type(value):
        # We can't rely on importing/accessing torch.distributed; it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed.tensor.placement_types import Placement

        return type(value) is type and issubclass(value, Placement)

    def as_python_constant(self):
        return self.value

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if (
            inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
            and self.source
        ):
            # NOTE: we don't need to track mutations to the placement class as they
            # are supposed to be immutable.
            new_obj = object.__new__(self.value)
            var = PlacementVariable(new_obj)
            if inspect.getattr_static(self.value, "__init__", None):
                var.call_method(tx, "__init__", args, kwargs)
                return var

        return super().call_function(tx, args, kwargs)


class PlacementVariable(DistributedVariable):
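    """
    Tracks individual Placement instances (e.g. Shard, Replicate, Partial);
    their constant-foldable methods are evaluated eagerly at trace time.
    """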
    @staticmethod
    def is_placement(value):
        # We can't rely on importing/accessing torch.distributed; it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed.tensor.placement_types import Placement

        return isinstance(value, Placement)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        if name == "dim":
            return ConstantVariable.create(self.value.dim)
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import ConstantVariable

        # Dynamo tracking of Placement types only allows the following methods;
        # __setattr__ is included to handle cases like `Shard(dim)`.
        # Methods in the list must satisfy:
        #    1. Input arguments are constants and do not need to be guarded on;
        #    2. Output is constant with respect to the inputs.
        constant_fold_functions = [
            "__init__",
            "__setattr__",
            "is_shard",
            "is_partial",
            "is_replicate",
        ]

        if name in constant_fold_functions:
            try:
                value_type = type(self.value)
                assert (
                    inspect.getattr_static(value_type, "__getattr__", None) is None
                ), "no custom getattr allowed!"
                method = inspect.getattr_static(value_type, name)
            except AttributeError:
                method = None
            if method is object.__init__:
                return ConstantVariable.create(None)

            args = [x.as_python_constant() for x in args]
            kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            if name == "__setattr__":
                method(self.value, *args, **kwargs)
                return self
            constant_val = method(self.value, *args, **kwargs)
            return ConstantVariable.create(constant_val)

        return super().call_method(tx, name, args, kwargs)


class DeviceMeshVariable(DistributedVariable):
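    """
    Tracks torch.distributed.device_mesh.DeviceMesh objects; queries such as
    ndim, device_type, size, and get_coordinate are folded to constants.
    """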
    @staticmethod
    def is_device_mesh(value):
        # We can't rely on importing/accessing torch.distributed; it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed.device_mesh import DeviceMesh

        return istype(value, DeviceMesh)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        if name == "ndim":
            return ConstantVariable.create(self.value.ndim)
        if name == "device_type":
            return ConstantVariable.create(self.value.device_type)
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "size":
            const_args = [x.as_python_constant() for x in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            return ConstantVariable.create(self.value.size(*const_args, **const_kwargs))
        if name == "get_coordinate":
            return ConstantVariable.create(self.value.get_coordinate())
        if name == "get_group":
            return ConstantVariable.create(self.value.get_group())
        if name == "_get_or_create_default_group":
            return ProcessGroupVariable(self.value._get_or_create_default_group())
        return super().call_method(tx, name, args, kwargs)


class ProcessGroupVariable(DistributedVariable):
    """
    We don't want a ProcessGroup object to end up in our output graph.

    But it's common for dynamo to intercept a PG that is then used to get info like
    rank() or world_size(), as well as passed to utility functions in distributed_c10d
    which desugar it into plain types like a rank list and tag.

    For convenience and proper guarding, we construct a variable type.

    TODO: make it possible to use ProcessGroupVariable as input to simple functions
          like _expand_group without dynamo complaining about making a proxy for it.
          It is not a tensor-like type, and we don't want a proxy, but dynamo assumes
          torch library functions deal with tensor-like types and would have proxies
          for their args.
    TODO: should we make this inherit VT instead of UDOV? Do we want any of the default
          behaviors, or just graph-break whenever one of our special cases is not hit?
    """

    def as_python_constant(self):
        return self.value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "rank":
            return variables.ConstantVariable.create(self.value.rank())
        if name == "size":
            return variables.ConstantVariable.create(self.value.size())
        if name == "_get_backend_name":
            return variables.ConstantVariable.create(self.value._get_backend_name())

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx: "InstructionTranslator", name):
        if name == "group_name":
            return variables.ConstantVariable.create(self.value.group_name)
        if name in ["rank", "size"]:
            return variables.LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        # TODO should this just raise unimplemented?
        return super().var_getattr(tx, name)

    @staticmethod
    def is_process_group(value):
        # We can't rely on importing/accessing torch.distributed; it is not always built.
        if not DistributedVariable.is_available():
            return False
        from torch._C._distributed_c10d import ProcessGroup
        from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

        return istype(value, (ProcessGroup, FakeProcessGroup))


class BackwardHookVariable(VariableTracker):
    """
    Handles torch.utils.hooks.BackwardHook for module-level backward
    hooks.
    """

    @staticmethod
    def create(
        tx,
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
    ):
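        """
        Build a BackwardHookVariable whose in-graph BackwardHook dispatches to
        the user's hooks through BackwardState during the backward pass.
        Requires compiled autograd to be enabled.
        """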
        if not compiled_autograd.compiled_autograd_enabled:
            unimplemented("module-level backwards hooks require compiled autograd")

        def _in_graph_bw_hooks(bw_state: BackwardState):
            """
            Rather than installing the user hooks in the graph (they would
            not survive AOTAutograd), we install hooks that call
            trace_wrapped in the backward pass, which CompiledAutograd
            can turn into calls to the actual hooks.
            """
            return torch.utils.hooks.BackwardHook(
                None,
                (
                    functools.partial(
                        trace_wrapped,
                        fn=call_module_hooks_from_backward_state,
                        bw_state=bw_state,
                        hooks_name=user_hooks_name,
                        module_name=module_name,
                    ),
                ),
                (
                    functools.partial(
                        trace_wrapped,
                        fn=call_module_hooks_from_backward_state,
                        bw_state=bw_state,
                        hooks_name=user_pre_hooks_name,
                        module_name=module_name,
                    ),
                ),
            )

        module_name, bw_state_proxy = tx.output.add_backward_state_hook(module, "mod")
        user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks)
        user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks)
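        # Emit a call_function node that invokes _in_graph_bw_hooks with the
        # BackwardState proxy; the hook/module names registered above are
        # captured by that closure.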
        proxy = tx.output.create_proxy(
            "call_function",
            _in_graph_bw_hooks,
            (bw_state_proxy,),
            {},
        )
        proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(None, (), ())
        return BackwardHookVariable(proxy, module, user_hooks, user_pre_hooks)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
        **options,
    ) -> None:
        super().__init__(**options)
        self.proxy = proxy
        self.module = module
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks

    def as_proxy(self):
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if name in ("setup_input_hook", "setup_output_hook"):
            return self._setup_hook(tx, name, *args, **kwargs)
        return super().call_method(tx, name, args, kwargs)

    def _setup_hook(self, tx: "InstructionTranslator", hook_method_name, args):
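        """
        Proxy a call to BackwardHook.setup_input_hook / setup_output_hook into
        the FX graph and wrap the resulting node as a variable.
        """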
        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                hook_method_name,
                (self.as_proxy(), args.as_proxy()),
                {},
            ),
        )
400