xref: /aosp_15_r20/external/pytorch/torch/distributed/fsdp/_trace_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3from contextlib import contextmanager
4from dataclasses import dataclass, field
5from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
6
7import torch
8import torch.nn as nn
9
10
@dataclass
class TracingConfig:
    """
    Configuration bundle for symbolic tracing.

    Args:
        tracer (torch.fx.Tracer): The :class:`torch.fx.Tracer` instance that
            performs the symbolic trace. When omitted, a native
            :class:`torch.fx.Tracer` is constructed with default arguments.
            A different tracer may be supplied instead, for example the
            ``HFTracer`` for models in the HuggingFace Transformers_ library.
        concrete_args (Optional[Dict[str, Any]]): Arguments that are kept as
            concrete values instead of being treated as ``torch.fx.Proxy``
            while tracing the module ``forward()``. Supplying them partially
            specializes the forward, e.g. to remove control flow or data
            structures. This is the same ``concrete_args`` argument that
            :meth:`~torch.fx.Tracer.trace` accepts.

    .. _Transformers: https://huggingface.co/docs/transformers/index
    """

    # A new tracer is built per config instance, so instances never share one.
    tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer)
    concrete_args: Optional[Dict[str, Any]] = field(default=None)
33
34
class _ParamUsageInfo(NamedTuple):
    """
    One group of parameters that are used together during execution.

    Instances are the list elements of
    ``_ExecutionInfo.module_to_param_usage_infos``, a ``dict`` that maps each
    module to a list of ``_ParamUsageInfo``. The list for a module follows
    the execution order.

    For a given module key in that ``dict``, an instance is one of two kinds:
    (1) the module itself paired with some sublist of its
    ``named_parameters()`` that were used together in execution (recorded in
    ``_patched_create_proxy()``), or
    (2) a submodule paired with all of ``submodule.named_parameters()``
    (recorded in ``_patched_call_module()``).

    Kind (1) arises when parameters are used directly in ops without calling
    ``forward()``; kind (2) arises from calling ``forward()``.
    """

    # Module that the parameter group belongs to.
    module: nn.Module
    # ``(name, parameter)`` pairs making up the group.
    named_params: List[Tuple[str, nn.Parameter]]
56
57
class _ExecutionInfo:
    """
    Container for the execution order information of a forward pass.

    Attributes:
        curr_module (nn.Module): The module currently being traced.
        module_forward_order (List[nn.Module]): Modules in (pre-)forward
            order, i.e. ordered by when their ``forward()`` methods are
            called. Every call to a module's ``forward()`` contributes one
            element.
        module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]):
            Maps a module to its list of parameter usage records. See
            :class:`_ParamUsageInfo` for details.
        param_forward_order (List[nn.Parameter]): Parameters in forward
            execution order; a parameter appears only for its first
            participation.
        visited_params (Set[nn.Parameter]): Parameters visited so far during
            the trace, kept only for fast membership checks while tracing.
            Invariant: this set holds exactly the parameters that are in
            ``param_forward_order``.
    """

    def __init__(self, root_module: nn.Module) -> None:
        # The root module is the starting point: it is the current module,
        # the first entry of the forward order, and the only key that exists
        # in the usage-info mapping at construction time.
        forward_order: List[nn.Module] = [root_module]
        usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = {}
        usage_infos[root_module] = []

        self.curr_module: nn.Module = root_module
        self.module_forward_order: List[nn.Module] = forward_order
        self.module_to_param_usage_infos: Dict[
            nn.Module, List[_ParamUsageInfo]
        ] = usage_infos
        # No parameter has been seen yet.
        self.param_forward_order: List[nn.Parameter] = list()
        self.visited_params: Set[nn.Parameter] = set()
88
89
class _ExecOrderTracer:
    """
    Records module and parameter execution order during a symbolic trace.

    Recording works by temporarily replacing a tracer's ``call_module`` and
    ``create_proxy`` with the patched versions defined on this class (see
    :meth:`patch_tracer`). Everything recorded is written to ``exec_info``.

    Attributes:
        exec_info (Optional[_ExecutionInfo]): ``None`` until
            :meth:`patch_tracer` is entered; afterwards the execution
            information object created by the most recent call.
    """

    def __init__(self) -> None:
        self.exec_info: Optional[_ExecutionInfo] = None

    @contextmanager
    def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module):
        """
        Context manager that patches ``tracer`` to record execution order.

        On entry, a fresh :class:`_ExecutionInfo` for ``root_module`` is
        stored in ``self.exec_info``, and ``tracer.call_module`` and
        ``tracer.create_proxy`` are replaced by wrappers that write to it.
        On exit, including when the body raises, the values that
        ``tracer.call_module`` and ``tracer.create_proxy`` had on entry are
        assigned back.

        Args:
            tracer (torch.fx.Tracer): Tracer whose ``call_module`` and
                ``create_proxy`` are patched for the duration of the context.
            root_module (nn.Module): Root module used to initialize the
                execution information and to build the FQN-to-parameter
                mapping from its ``named_parameters()``.
        """
        self.exec_info = _ExecutionInfo(root_module)
        orig_call_module = tracer.call_module
        orig_create_proxy = tracer.create_proxy
        tracer.call_module = functools.partial(  # type: ignore[method-assign]
            self._patched_call_module, orig_call_module, self.exec_info
        )
        fqn_to_param = dict(root_module.named_parameters())
        tracer.create_proxy = functools.partial(  # type: ignore[method-assign]
            self._patched_create_proxy,
            orig_create_proxy,
            self.exec_info,
            fqn_to_param,
        )
        try:
            yield
        finally:
            tracer.call_module = orig_call_module  # type: ignore[method-assign]
            tracer.create_proxy = orig_create_proxy  # type: ignore[method-assign]

    @staticmethod
    def _record_first_visit(exec_info: _ExecutionInfo, param: nn.Parameter) -> None:
        """Append ``param`` to the forward order unless it was already visited."""
        if param not in exec_info.visited_params:
            exec_info.visited_params.add(param)
            exec_info.param_forward_order.append(param)

    def _patched_call_module(
        self,
        call_module: Callable,
        exec_info: _ExecutionInfo,
        # Below are the expected arguments to `call_module()`
        module: nn.Module,
        forward: Callable,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """
        Overrides ``call_module`` to save execution information to
        ``exec_info``. Note that ``call_module`` is called during symbolic
        tracing for each non-root module.

        ``module`` is appended to ``exec_info.module_forward_order``. If it
        has any parameters, a :class:`_ParamUsageInfo` for it is appended to
        the current module's usage list. ``module`` then becomes the current
        module, with its own usage list reset to empty, for the duration of
        the original ``call_module``.

        Args:
            call_module (Callable): Original ``call_module`` to override.
            exec_info (_ExecutionInfo): Used to record execution information.
            module (nn.Module): Module corresponding to this ``call_module``.
            forward (Callable): ``forward()`` method of ``module`` to be called
                for this ``call_module``.
            args (Tuple[Any, ...]): Positional arguments for ``forward``.
            kwargs (Dict[str, Any]): Keyword arguments for ``forward``.

        Returns:
            Same return value as ``call_module``.
        """
        exec_info.module_forward_order.append(module)
        named_params = list(module.named_parameters())
        curr_module = exec_info.curr_module
        if named_params:
            assert (
                curr_module in exec_info.module_to_param_usage_infos
            ), "The current module should have already been processed by a patched `call_module`"
            exec_info.module_to_param_usage_infos[curr_module].append(
                _ParamUsageInfo(module, named_params)
            )
        exec_info.curr_module = module
        # A module that is called again starts over with an empty usage list.
        exec_info.module_to_param_usage_infos[module] = []
        try:
            return call_module(module, forward, args, kwargs)
        finally:
            # Restore the caller even if `call_module` raised. Otherwise an
            # exception caught further up would leave `curr_module` pointing
            # at this submodule, and later usages would be attributed to it.
            exec_info.curr_module = curr_module

    def _patched_create_proxy(
        self,
        create_proxy: Callable,
        exec_info: _ExecutionInfo,
        fqn_to_param: Dict[str, nn.Parameter],
        # Below are the expected arguments to `create_proxy()`
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
        proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
    ) -> torch.fx.Proxy:
        """
        Overrides ``create_proxy`` to save execution information to
        ``exec_info``. Note that ``create_proxy`` is called during symbolic
        tracing for each leaf function/method/module.

        For ``'call_function'`` and ``'call_method'`` nodes, only the direct
        elements of ``args`` are inspected: an element counts as a parameter
        use if it is a ``torch.fx.Proxy`` whose node target is a key of
        ``fqn_to_param``. ``kwargs`` and nested containers are not inspected.
        For ``'call_module'`` nodes, all of the current module's
        ``named_parameters()`` are recorded. Other kinds record nothing.

        Args:
            create_proxy (Callable): Original ``create_proxy`` to override.
            exec_info (_ExecutionInfo): Used to record execution information.
            fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the
                root module's ``named_parameters()`` with FQN as key and
                parameter as value.
            kind (str): Kind of the target method ('call_function',
                'call_method', 'get_attr', 'call_module', 'placeholder', or
                'output'). See :class:`torch.fx.Graph` for details. This is
                passed to ``create_proxy``.
            target (torch.fx.node.Target): Contains the string name of the
                function/method/module. This is passed to ``create_proxy``.
            args (Tuple[Any, ...]): Positional arguments for the function/
                method/module. This is passed to ``create_proxy``.
            kwargs (Dict[str, Any]): Keyword arguments for the function/method/
                module. This is passed to ``create_proxy``.
            name (Optional[str]): An optional string name for the ``Node``
                created in ``create_proxy``. This is passed to
                ``create_proxy``.
            type_expr (Optional[Any]): An optional type annotation representing
                the Python type that the output of the node has. This is passed
                to ``create_proxy``.
            proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]):
                An alternative proxy constructor used in ``create_proxy``. This
                is passed to ``create_proxy``.

        Returns:
            torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object.
        """
        # The proxy is created first; recording below only reads the inputs.
        proxy = create_proxy(
            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
        )
        curr_module = exec_info.curr_module
        if kind in ("call_function", "call_method"):
            if args is not None:
                named_params: List[Tuple[str, nn.Parameter]] = []
                for arg in args:
                    if (
                        isinstance(arg, torch.fx.Proxy)
                        and arg.node.target in fqn_to_param
                    ):
                        param = fqn_to_param[arg.node.target]  # type: ignore[index]
                        named_params.append((arg.node.target, param))  # type: ignore[arg-type]
                        self._record_first_visit(exec_info, param)
                if named_params:
                    # Parameters used directly by this op form one group
                    # attributed to the current module.
                    exec_info.module_to_param_usage_infos[curr_module].append(
                        _ParamUsageInfo(curr_module, named_params)
                    )
        elif kind == "call_module":
            # NOTE(review): presumably `curr_module` is the module being
            # called here, since the patched `call_module` sets it before
            # delegating to the original one; confirm against the tracer.
            named_params = list(curr_module.named_parameters())
            if named_params:
                exec_info.module_to_param_usage_infos[curr_module].append(
                    _ParamUsageInfo(curr_module, named_params)
                )
            for _, param in named_params:
                self._record_first_visit(exec_info, param)
        return proxy
239