# mypy: allow-untyped-defs
import itertools
import warnings
from enum import auto, Enum
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp._common_utils import _FSDPState, _get_param_to_fqns
from torch.distributed.fsdp._flat_param import FlatParamHandle


class _ExecOrderWarnStatus(Enum):
    """Used internally for execution order validation."""

    NONE = auto()  # no deviation yet
    WARNING = auto()  # deviated this iteration; currently issuing warnings
    WARNED = auto()  # deviated in a previous iteration


class _ExecOrderData:
    """
    This contains the data structures to track the execution order. We track
    the pre-forward order on the *first* iteration for forward prefetching
    (which thus assumes a static graph) and the post-forward order on *every*
    iteration for backward prefetching (which thus does not assume a static
    graph but may provide an incorrect order).
    """
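
    # A minimal illustration (hypothetical handles ``h0``/``h1``): if the
    # first iteration's forward runs ``h0`` then ``h1``, then
    # ``handles_pre_forward_order == [h0, h1]`` is frozen and reused for
    # forward prefetching, while ``handles_post_forward_order`` is rebuilt
    # every iteration and cleared in :meth:`next_iter`.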

    def __init__(
        self,
        debug_level: dist.DebugLevel,
        backward_prefetch_limit: int,
        forward_prefetch_limit: int,
    ) -> None:
        # Tracks the (static) pre-forward order for execution order validation
        # and forward prefetching
        self.handles_pre_forward_order: List[FlatParamHandle] = []
        # Tracks the post-forward order for pre-backward prefetching
        self.handles_post_forward_order: List[Optional[FlatParamHandle]] = []
        self._iter = 0

        # Gives the max number of backward/forward all-gathers prefetched by a
        # single module
        self._backward_prefetch_limit = backward_prefetch_limit
        self._forward_prefetch_limit = forward_prefetch_limit

        # Data structures for execution order validation
        self._checking_order: bool = debug_level == dist.DebugLevel.DETAIL
        self.process_group: Optional[dist.ProcessGroup] = None
        self.world_size: Optional[int] = None
        self.all_handles: List[FlatParamHandle] = []
        # Names are prefixed from the root module
        self.param_to_fqn: Dict[nn.Parameter, List[str]] = {}
        # Current index in the pre-forward execution order
        self.current_order_index = 0
        self.warn_status = _ExecOrderWarnStatus.NONE

    def init(
        self,
        state: _FSDPState,
        root_module: nn.Module,
        process_group: dist.ProcessGroup,
    ) -> None:
        """
        Initializes the data structures needed for checking the forward order.
        This should be called after a root FSDP instance has been set during
        lazy initialization.
        """
        self.process_group = process_group
        self.rank = process_group.rank()
        self.world_size = process_group.size()
        # Fix an order over the handles, which should be the same across ranks
        for handle in traversal_utils._get_fsdp_handles(root_module):
            index = len(self.all_handles)
            self.all_handles.append(handle)
            handle._handle_index = index
        self.param_to_fqn = _get_param_to_fqns(root_module)
        # TODO (awgu): We can broadcast the metadata of rank 0's `all_handles`
        # to check that all ranks have the same handles in the same order.
        # https://github.com/pytorch/pytorch/issues/79620

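    # A minimal usage sketch (hypothetical ``state``/``root_module``; assumes
    # a process group is already initialized). The real call sites live in
    # FSDP's lazy initialization:
    #
    #   exec_order_data = _ExecOrderData(
    #       dist.get_debug_level(),
    #       backward_prefetch_limit=1,
    #       forward_prefetch_limit=1,
    #   )
    #   exec_order_data.init(state, root_module, process_group)
    #
    # Afterward, every handle in ``all_handles`` carries a ``_handle_index``
    # that is expected to match across ranks.
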
    @property
    def is_first_iter(self) -> bool:
        return self._iter == 0

    def get_handle_to_backward_prefetch(
        self,
        current_handle: FlatParamHandle,
    ) -> Optional[FlatParamHandle]:
        """
        Returns the handle to backward prefetch given the current handle:
        the handle up to ``_backward_prefetch_limit`` positions earlier in
        the post-forward order. Returns ``None`` if there is no valid handle
        to prefetch.
        """
        current_index = current_handle._post_forward_index
        if current_index is None:
            return None
        target_index = current_index - 1
        target_handle: Optional[FlatParamHandle] = None
        for _ in range(self._backward_prefetch_limit):
            if target_index < 0:
                break
            target_handle = self.handles_post_forward_order[target_index]
            target_index -= 1
        return target_handle

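    # Worked example (hypothetical handles): with
    # ``handles_post_forward_order == [h0, h1, h2]`` and
    # ``_backward_prefetch_limit == 1``, calling this with ``h2`` returns
    # ``h1``; with a limit of 2 it returns ``h0``. Backward runs roughly in
    # reverse post-forward order, so these are the next handles needed.
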
    def get_handle_to_forward_prefetch(
        self,
        current_handle: FlatParamHandle,
    ) -> Optional[FlatParamHandle]:
        """
        Returns the handle to forward prefetch given the current handle:
        the handle up to ``_forward_prefetch_limit`` positions later in the
        pre-forward order. Returns ``None`` if there is no valid handle to
        prefetch.
        """
        current_index = current_handle._pre_forward_order_index
        if current_index is None:
            return None
        target_index = current_index + 1
        target_handle: Optional[FlatParamHandle] = None
        for _ in range(self._forward_prefetch_limit):
            if target_index >= len(self.handles_pre_forward_order):
                break
            target_handle = self.handles_pre_forward_order[target_index]
            target_index += 1
        return target_handle

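    # Worked example (hypothetical handles): with
    # ``handles_pre_forward_order == [h0, h1, h2]`` and
    # ``_forward_prefetch_limit == 2``, calling this with ``h0`` returns
    # ``h2``: the loop walks forward up to the limit and returns the last
    # valid handle, not a list of intermediate handles.
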
    def record_post_forward(self, handle: Optional[FlatParamHandle]) -> None:
        """
        Records ``handle`` in the post-forward order, where ``handle`` is the
        handle used in a module's forward. If ``handle`` is ``None``, then it
        is omitted.

        Unlike :meth:`record_pre_forward`, this records the order *every*
        iteration with the expectation that the recorded order is reset in
        :meth:`next_iter`.
        """
        if not handle:
            return
        # Assign the post-forward index only on the first usage of the handle;
        # repeated usages still append to preserve the full order. Compare
        # against ``None`` explicitly so that a valid index of 0 is not
        # overwritten.
        if handle._post_forward_index is not None:
            self.handles_post_forward_order.append(handle)
            return
        index = len(self.handles_post_forward_order)
        handle._post_forward_index = index
        self.handles_post_forward_order.append(handle)

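    # Example (hypothetical handle ``h``): calling ``record_post_forward(h)``
    # twice (e.g. a shared module run twice in one forward) yields
    # ``handles_post_forward_order == [h, h]`` while ``h._post_forward_index``
    # stays 0, assigned on the first use only.
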
    def record_pre_forward(
        self, handle: Optional[FlatParamHandle], is_training: bool
    ) -> None:
        """
        Records ``handle`` in the pre-forward order, where ``handle`` is the
        handle used in a module's forward. If ``handle`` is ``None``, then it
        is omitted.

        On the first iteration, this checks the execution order across ranks.
        See :meth:`_check_order` for details.
        """
        if not handle:
            return
        self._check_order(handle, is_training)
        # Fix the order after the first iteration and only record the first
        # usage of a handle
        if not self.is_first_iter or handle._pre_forward_order_index is not None:
            return
        index = len(self.handles_pre_forward_order)
        handle._pre_forward_order_index = index
        self.handles_pre_forward_order.append(handle)

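    # Example (hypothetical handle ``h``): on iteration 0 the first
    # ``record_pre_forward(h, is_training=True)`` sets
    # ``h._pre_forward_order_index`` and appends ``h``; on later iterations
    # the call only runs ``_check_order`` and leaves the frozen order
    # untouched.
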
    def _check_order(self, handle: FlatParamHandle, is_training: bool) -> None:
        """
        Checks the forward execution order as long as ``is_training`` is
        ``True`` since checking in eval mode is not supported. Checking runs
        only when the distributed debug level is ``DETAIL``.

        - On the first iteration, this uses all-gathers to check that all
          ranks are all-gathering the same handles and hence the same
          ``FlatParameter`` s, raising an error if not.
        - On subsequent iterations, this checks that each rank is locally
          consistent with its own forward order from the first iteration,
          issuing a warning if not. This issues a warning on the first
          deviating iteration and stops warning thereafter.
        """
        # Do not check order in eval mode since the post-backward callback does
        # not run, so it cannot be used to mark the end of an iteration
        if not is_training or not self._checking_order:
            return
        if self.is_first_iter:
            msg_prefix = "Forward order differs across ranks:"
            optional_local_indices: Tuple[
                Optional[int], ...
            ] = self._get_handle_indices(handle)
            device = handle.device  # guaranteed to be non-CPU
            num_valid_indices = sum(
                (index is not None) for index in optional_local_indices
            )
            tensor_kwargs: Dict[str, Union[torch.dtype, torch.device]] = {
                "dtype": torch.int32,
                "device": device,
            }
            world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs)  # type: ignore[arg-type, call-overload]
            local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs)  # type: ignore[arg-type, call-overload]
            dist.all_gather_into_tensor(
                world_num_valid_indices,
                local_num_valid_indices,
                group=self.process_group,
            )
            # Copy the entire tensor from D2H once to avoid per-element D2H copies
            world_num_valid_indices = world_num_valid_indices.cpu()
            # Check that all ranks plan to all-gather the same number of
            # parameters
            # TODO (awgu): Since every module has at most one handle in the
            # current implementation, this should never raise the error.
            assert self.world_size is not None  # mypy
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                # TODO(voz): Don't graph break on this - dynamo hates the n1 != n2
                # tensor comparison control flow.
                # https://github.com/pytorch/pytorch/issues/107055
                for (r1, n1), (r2, n2) in itertools.combinations(
                    (
                        (rank, world_num_valid_indices[rank])
                        for rank in range(self.world_size)
                    ),
                    2,
                ):
                    if n1 != n2:
                        raise RuntimeError(
                            f"{msg_prefix} rank {r1} is all-gathering {n1} parameters "
                            f"while rank {r2} is all-gathering {n2} parameters"
                        )
            world_indices = torch.zeros(  # type: ignore[call-overload]
                self.world_size * num_valid_indices, **tensor_kwargs
            )
            local_indices = torch.tensor(optional_local_indices, **tensor_kwargs)  # type: ignore[arg-type]
            dist.all_gather_into_tensor(
                world_indices, local_indices, group=self.process_group
            )
            # Copy the entire tensor from D2H once to avoid per-element D2H copies
            world_indices = world_indices.cpu()
            # Check that all ranks plan to all-gather the same index parameters
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                # TODO(voz): Don't graph break on this - dynamo hates the i1 != i2
                # tensor comparison control flow.
                # https://github.com/pytorch/pytorch/issues/107055
                for (r1, i1), (r2, i2) in itertools.combinations(
                    (
                        (
                            rank,
                            world_indices[
                                rank
                                * num_valid_indices : (rank + 1)
                                * num_valid_indices
                            ],
                        )
                        for rank in range(self.world_size)
                    ),
                    2,
                ):
                    if i1 != i2:
                        r1_param_names = self._get_names_from_handle_indices(i1)
                        r2_param_names = self._get_names_from_handle_indices(i2)
                        raise RuntimeError(
                            f"{msg_prefix} rank {r1} is all-gathering parameters "
                            f"for {r1_param_names} while rank {r2} is all-gathering "
                            f"parameters for {r2_param_names}"
                        )
        else:
            # Only issue warnings on the first deviating iteration and stop
            # checking thereafter to avoid flooding the console
            if self.warn_status == _ExecOrderWarnStatus.WARNED:
                return
            msg_prefix = None  # non-`None` means we should warn
            if self.current_order_index >= len(self.handles_pre_forward_order):
                # This iteration sees extra all-gather(s) compared to the first
                msg_prefix = (
                    "Expected to not all-gather any more parameters in the "
                    "forward but trying to all-gather parameters for "
                )
            else:
                expected_handle = self.handles_pre_forward_order[
                    self.current_order_index
                ]
                if expected_handle != handle:
                    expected_param_names = self._get_names_from_handles(expected_handle)
                    msg_prefix = (
                        f"Expected to all-gather for {expected_param_names} "
                        "but trying to all-gather parameters for "
                    )
            if msg_prefix is not None:
                param_names = self._get_names_from_handles(handle)
                msg_suffix = (
                    f"{param_names}"
                    if param_names
                    else "a newly-added parameter since construction time"
                )
                warnings.warn(
                    "Forward order differs from that of the first iteration "
                    f"on rank {self.rank}. Collectives are unchecked and may "
                    f"give incorrect results or hang.\n{msg_prefix}{msg_suffix}"
                )
                self.warn_status = _ExecOrderWarnStatus.WARNING
            self.current_order_index += 1

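    # Example of the warning path (hypothetical modules ``A`` and ``B``): if
    # iteration 0 all-gathered A's handle then B's, but a later iteration
    # all-gathers B's first, then that rank warns once ("Expected to
    # all-gather for [A's FQNs] but trying to all-gather parameters for
    # [B's FQNs]") and stays silent after :meth:`next_iter` promotes
    # ``WARNING`` to ``WARNED``.
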
    def _get_handle_indices(
        self,
        handle: FlatParamHandle,
    ) -> Tuple[Optional[int], ...]:
        """
        Returns the handle indices (i.e. indices into ``self.all_handles``)
        corresponding to ``handle``. An entry in the returned tuple is
        ``None`` if the handle is invalid.
        """
        indices: List[Optional[int]] = []
        if handle:
            indices.append(handle._handle_index)
        return tuple(indices)

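    # Example (hypothetical handle ``h``): after ``init`` assigned
    # ``h._handle_index = 3``, ``_get_handle_indices(h)`` returns ``(3,)``;
    # a falsy handle yields ``()``.
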
    def _get_names_from_handle_indices(
        self,
        handle_indices: Tuple[int, ...],
    ) -> List[List[str]]:
        """
        Returns a list of FQNs for each handle in ``handle_indices``. If a
        handle index is invalid, then its FQNs are omitted from the returned
        list.
        """
        fqns: List[List[str]] = []
        for index in handle_indices:
            if index is None or index < 0 or index >= len(self.all_handles):
                continue
            handle = self.all_handles[index]
            flat_param = handle.flat_param
            fqns.append(self.param_to_fqn[flat_param])
        return fqns

    def _get_names_from_handles(
        self,
        handle: FlatParamHandle,
    ) -> List[List[str]]:
        """
        Returns a list containing the FQNs for ``handle``'s
        ``FlatParameter``. If the handle is invalid or its parameter is not
        registered, then the returned list is empty.
        """
        fqns: List[List[str]] = []
        if handle:
            flat_param = handle.flat_param
            if flat_param in self.param_to_fqn:
                fqns.append(self.param_to_fqn[flat_param])
        return fqns

    def next_iter(self):
        """
        Advances the internal data structures per iteration. This should be
        called in the post-backward callback since that marks the true end of
        an iteration.
        """
        self._iter += 1
        self.handles_post_forward_order.clear()
        if self._checking_order:
            self.current_order_index = 0
            if self.warn_status == _ExecOrderWarnStatus.WARNING:
                self.warn_status = _ExecOrderWarnStatus.WARNED
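

# A sketch of the per-iteration lifecycle this class assumes (the driving
# calls live elsewhere in FSDP's runtime; ``handle`` names are illustrative):
#
#   record_pre_forward(handle, is_training)   # each module's pre-forward hook
#   get_handle_to_forward_prefetch(handle)    # uses the frozen pre-forward order
#   record_post_forward(handle)               # each module's post-forward hook
#   get_handle_to_backward_prefetch(handle)   # uses this iteration's post-forward order
#   next_iter()                               # post-backward callback ends the iteration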
366