# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
import logging
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
)

import torch
import torch._dynamo.compiled_autograd as ca
import torch.nn as nn
from torch._logging import warning_once
from torch.autograd import Variable
from torch.autograd.graph import _MultiHandle
from torch.distributed._composable_state import (
    _get_module_state,
    _insert_module_state,
    _State,
)
from torch.distributed.utils import _to_kwargs
from torch.utils._pytree import tree_flatten, tree_map

from ._fsdp_api import MixedPrecisionPolicy
from ._fsdp_common import _cast_fp_tensor, TrainingState
from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup


if TYPE_CHECKING:
    from ._fsdp_param import FSDPParam


logger = logging.getLogger("torch.distributed._composable.fsdp")


class FSDPStateContext:
    """This has state shared across FSDP states."""

    def __init__(self) -> None:
        # All FSDP states in the root state's module tree
        self.all_states: List[FSDPState] = []
        # Iteration's forward root runs the once-per-forward logic; this root
        # may not be the overall root set by lazy initialization in cases where
        # only a submodule runs forward (e.g. encoder-only for eval)
        self.iter_forward_root: Optional[FSDPState] = None
        # Final callback should only be queued once per backward
        self.post_backward_final_callback_queued: bool = False
        # Whether to finalize backward in this backward's final callback
        self.is_last_backward: bool = True
        # Optional user-provided event recorded after the optimizer step, which
        # the all-gather streams wait on in the root pre-forward
        self.post_optim_event: Optional[torch.cuda.Event] = None
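        # A minimal usage sketch (illustrative only; `optim` and the public
        # setter name are assumptions, not defined in this file): after the
        # optimizer step, user code records a CUDA event on the current stream
        # and hands it to the root state, e.g. via something like
        # `FSDPModule.set_post_optim_event`:
        #
        #   optim.step()
        #   post_optim_event = torch.cuda.current_stream().record_event()
        #   root_module.set_post_optim_event(post_optim_event)
        #
        # The next root pre-forward then waits on this event instead of on the
        # entire current stream (see `_root_pre_forward`).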


def disable_if_config_true(func):
    @functools.wraps(func)
    def fsdp_hook_wrapper(*args, **kwargs):
        if torch._dynamo.config.skip_fsdp_hooks:
            return torch._dynamo.disable(func, recursive=True)(*args, **kwargs)
        else:
            return func(*args, **kwargs)

    return fsdp_hook_wrapper
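

# Note (illustrative): `torch._dynamo.config.skip_fsdp_hooks` decides whether
# the hooks decorated above are hidden from `torch.compile`. A rough sketch of
# toggling it (the flag's default may differ across versions):
#
#   import torch._dynamo
#   torch._dynamo.config.skip_fsdp_hooks = True   # run FSDP hooks eagerly (hidden from Dynamo)
#   torch._dynamo.config.skip_fsdp_hooks = False  # let Dynamo trace the hooks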


class FSDPState(_State):
    def __init__(self) -> None:
        super().__init__()
        self._fsdp_param_group: Optional[FSDPParamGroup] = None
        self._is_root: Optional[bool] = None  # root set during lazy init
        self._state_ctx = FSDPStateContext()
        self._comm_ctx = FSDPCommContext()
        self._training_state: TrainingState = TrainingState.IDLE
        self._states_to_forward_prefetch: List[FSDPState] = []
        self._states_to_backward_prefetch: List[FSDPState] = []
        self._modules_to_run_forward: Set[nn.Module] = set()

    # Define a separate init since `__init__` is called in the contract
    def init(
        self,
        modules: Tuple[nn.Module, ...],
        device: torch.device,
        mp_policy: MixedPrecisionPolicy,
    ) -> None:
        for module in modules:
            _insert_module_state(module, self)
        self._modules = modules
        self._device = device
        self._mp_policy = mp_policy
        if len(modules) == 1:
            self._pre_forward_hook_handle = modules[0].register_forward_pre_hook(
                self._pre_forward, prepend=True, with_kwargs=True
            )
            self._post_forward_hook_handle = modules[0].register_forward_hook(
                self._post_forward, prepend=False
            )
        else:
            hook_handle = _register_group_forward_hooks(
                modules,
                self._pre_forward,
                self._post_forward,
                self._modules_to_run_forward,
            )
            self._pre_forward_hook_handle = hook_handle
            self._post_forward_hook_handle = hook_handle
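
    # Rough wiring sketch (illustrative; the composable `fully_shard` API is
    # what normally drives this, and the exact call sequence is an assumption,
    # not taken from this file):
    #
    #   state = FSDPState()
    #   state.init(
    #       (module,), device=torch.device("cuda"), mp_policy=MixedPrecisionPolicy()
    #   )
    #
    # `init` inserts the state on each module and registers the forward
    # pre/post hooks that drive unshard/reshard around forward.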

    def _root_pre_forward(
        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        self._lazy_init()
        if self._state_ctx.iter_forward_root is not None:
            return args, kwargs
        if not ca.compiled_autograd_enabled:
            logger.debug("FSDP::root_pre_forward")
        self._state_ctx.iter_forward_root = self
        with torch.profiler.record_function("FSDP::root_pre_forward"):
            # Wait for optimizer before implicitly prefetched all-gathers
            if (event := self._state_ctx.post_optim_event) is not None:
                self._comm_ctx.all_gather_copy_in_stream.wait_event(event)
                self._comm_ctx.all_gather_stream.wait_event(event)
                self._state_ctx.post_optim_event = None
            else:
                current_stream = torch.cuda.current_stream()
                self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream)
                self._comm_ctx.all_gather_stream.wait_stream(current_stream)
            if self._device.type == "cuda":
                with torch.profiler.record_function("FSDP::inputs_to_device"):
                    args_tuple, kwargs_tuple = _to_kwargs(
                        args, kwargs, self._device, False
                    )  # same as DDP
                args, kwargs = args_tuple[0], kwargs_tuple[0]
        return args, kwargs

    def _lazy_init(self) -> None:
        """
        Lazy initialization runs once all modules' parallelisms have been
        finalized (e.g. FSDP has been applied to all desired modules). At that
        point, we can determine which state is the root: the 1st state to run
        forward becomes the root.
        """
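        # Illustrative ordering requirement (module names assumed, not from
        # this file): if FSDP is applied to both `model.encoder` and `model`,
        # then `model(x)` must run before any standalone `model.encoder(x)`
        # call; otherwise the encoder's state lazily initializes itself as the
        # root, and this method later raises a RuntimeError for `model`.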
        if self._is_root is not None:
            return  # no-op: already initialized
        self._is_root = True
        if len(self._modules) > 1:
            raise RuntimeError(
                f"FSDP requires a single root module but got {self._modules}"
            )
        root_module = self._modules[0]
        visited_states: Set[FSDPState] = set()
        for module_name, module in root_module.named_modules():
            if (state := _get_module_fsdp_state(module)) is None:
                continue
            if module is not root_module:
                if state not in visited_states and state._is_root is not None:
                    raise RuntimeError(
                        "FSDP state has already been lazily initialized for "
                        f"{module_name}\nFSDP requires running forward through "
                        "the root module first"
                    )
                state._is_root = False
            self._state_ctx.all_states.append(state)
            visited_states.add(state)
        if self._fsdp_param_group:
            # For the root, do not reshard after forward since for training,
            # the parameters would be freed and all-gathered immediately
            self._fsdp_param_group.post_forward_mesh_info = None
        self._init_fqns()
        self._init_shared_state()
        # Run parameter group lazy inits after initializing FQNs for improved
        # error messages
        for state in self._state_ctx.all_states:
            if state._fsdp_param_group:
                state._fsdp_param_group.lazy_init()

    def _init_shared_state(self) -> None:
        self._comm_ctx.lazy_init()
        for state in self._state_ctx.all_states:
            state._state_ctx = self._state_ctx
            state._comm_ctx = self._comm_ctx
            if fsdp_param_group := state._fsdp_param_group:
                fsdp_param_group.comm_ctx = self._comm_ctx

    def _init_fqns(self) -> None:
        """Sets module and parameter FQN attributes for debugging."""
        assert self._is_root
        root_module = self._modules[0]
        param_to_fsdp_param: Dict[nn.Parameter, FSDPParam] = {}
        module_to_fsdp_param_group: Dict[nn.Module, FSDPParamGroup] = {}
        for state in self._state_ctx.all_states:
            if fsdp_param_group := state._fsdp_param_group:
                for fsdp_param in fsdp_param_group.fsdp_params:
                    param_to_fsdp_param[fsdp_param.sharded_param] = fsdp_param
                for module in fsdp_param_group.modules:
                    module_to_fsdp_param_group[module] = fsdp_param_group
        for param_name, param in root_module.named_parameters():
            if param in param_to_fsdp_param:
                param_to_fsdp_param[param]._param_fqn = param_name
        for module_name, module in root_module.named_modules():
            if module in module_to_fsdp_param_group:
                module_fqn = module_to_fsdp_param_group[module]._module_fqn
                if module_fqn is None:
                    module_to_fsdp_param_group[module]._module_fqn = module_name
                else:
                    assert isinstance(module_fqn, str), f"{module_fqn}"
                    module_fqn += f", {module_name}"
                    module_to_fsdp_param_group[module]._module_fqn = module_fqn
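
    # For example (illustrative names): a param group created over both
    # `layers.0` and `layers.1` ends up with `_module_fqn == "layers.0, layers.1"`,
    # and each of its parameters gets a `_param_fqn` such as "layers.0.weight".
    # These attributes exist purely for debugging and error messages.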

    @disable_if_config_true
    def _pre_forward(
        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        # When composing with module-hook-based activation checkpointing, the
        # pre-backward hook is responsible for the unshard
        if self._training_state == TrainingState.PRE_BACKWARD:
            return args, kwargs
        self._training_state = TrainingState.FORWARD
        args, kwargs = self._root_pre_forward(module, args, kwargs)
        if self._mp_policy.cast_forward_inputs and self._mp_policy.param_dtype:
            with torch.profiler.record_function("FSDP::cast_forward_inputs"):
                cast_fn = functools.partial(
                    _cast_fp_tensor, self._mp_policy.param_dtype
                )
                args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs)
        if self._fsdp_param_group:
            args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
        for fsdp_state in self._states_to_forward_prefetch:
            if (target_param_group := fsdp_state._fsdp_param_group) is not None:
                FSDPParamGroup._prefetch_unshard(target_param_group, "forward")
        return args, kwargs

    @disable_if_config_true
    def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any:
        # When composing with module-hook-based activation checkpointing, the
        # post-backward hook is responsible for the reshard
        if self._training_state == TrainingState.PRE_BACKWARD:
            return output
        if self._fsdp_param_group:
            output = self._fsdp_param_group.post_forward(module, input, output)
        output = self._register_pre_backward_hook(output)
        self._training_state = TrainingState.IDLE
        if self._state_ctx.iter_forward_root is self:
            if all_gather_state := self._comm_ctx.all_gather_state:
                # Free the last all-gather result if needed; refer to
                # [Note: Overlapping all-gather copy-in and all-gather]
                self._comm_ctx.all_gather_copy_in_stream.wait_event(
                    all_gather_state.event
                )
                self._comm_ctx.all_gather_stream.wait_event(all_gather_state.event)
                self._comm_ctx.all_gather_state = None  # free the all-gather result
            self._state_ctx.iter_forward_root = None
        if self._mp_policy.output_dtype is not None:
            with torch.profiler.record_function("FSDP::cast_forward_outputs"):
                output = tree_map(
                    functools.partial(_cast_fp_tensor, self._mp_policy.output_dtype),
                    output,
                )
        return output

    def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor:
        self._training_state = TrainingState.PRE_BACKWARD
        self._register_root_post_backward_final_callback()
        if self._fsdp_param_group:
            default_prefetch = len(self._states_to_backward_prefetch) == 0
            self._fsdp_param_group.pre_backward(default_prefetch)
        for fsdp_state in self._states_to_backward_prefetch:
            if (target_param_group := fsdp_state._fsdp_param_group) is not None:
                FSDPParamGroup._prefetch_unshard(target_param_group, "backward")
        return grad
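
    # Note (illustrative): `_states_to_forward_prefetch` and
    # `_states_to_backward_prefetch` are empty by default, in which case the
    # param group falls back to implicit prefetching. User code would typically
    # populate them through public setters on the wrapped module, e.g.
    # `set_modules_to_forward_prefetch` / `set_modules_to_backward_prefetch`
    # (those names are assumed here, not defined in this file).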

    def _root_post_backward_final_callback(self) -> None:
        if not ca.compiled_autograd_enabled:
            logger.debug("FSDP::root_post_backward")
        with torch.profiler.record_function("FSDP::root_post_backward_callback"):
            for state in self._state_ctx.all_states:
                if state._fsdp_param_group and state._fsdp_param_group.is_unsharded:
                    # Run post-backward in case forward inputs did not require
                    # gradient so the autograd backward did not run
                    state._fsdp_param_group.post_backward()
                state._training_state = TrainingState.IDLE
                if state._fsdp_param_group:
                    state._fsdp_param_group._training_state = TrainingState.IDLE
                if self._state_ctx.is_last_backward:
                    state._finalize_backward()
            if self._state_ctx.is_last_backward:
                self._comm_ctx.post_forward_order.clear()
                if self._comm_ctx.reduce_scatter_state is not None:
                    torch.cuda.current_stream().wait_event(
                        self._comm_ctx.reduce_scatter_state.event
                    )
                    self._comm_ctx.reduce_scatter_state = None
            self._state_ctx.post_backward_final_callback_queued = False
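
    # Gradient-accumulation sketch (illustrative; `set_is_last_backward` and
    # `set_requires_gradient_sync` are assumed to be public toggles on the
    # wrapped module, not defined in this file):
    #
    #   for i, microbatch in enumerate(microbatches):
    #       is_last = i == len(microbatches) - 1
    #       model.set_is_last_backward(is_last)
    #       model.set_requires_gradient_sync(is_last)
    #       model(microbatch).sum().backward()
    #   optim.step()
    #
    # Only the last backward finalizes per-state backward bookkeeping and waits
    # on the pending reduce-scatter event above.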

    def _finalize_backward(self) -> None:
        if self._modules_to_run_forward:
            msg = (
                f"{len(self._modules_to_run_forward)} of the {len(self._modules)} "
                f"modules passed to fully_shard did not run forward before backward, "
                "which is error-prone since FSDP post-forward/pre-backward logic "
                "will not run for these modules. We recommend passing only modules "
                "that run forward together. Modules that did not run forward: "
                f"{list(self._modules_to_run_forward)}"
            )
            warning_once(logger, msg, stacklevel=2)
            # Clear since we want the next forward to run
            self._modules_to_run_forward.clear()
        if self._fsdp_param_group:
            self._fsdp_param_group.finalize_backward()

    def _register_pre_backward_hook(self, output: Any) -> Any:
        if not torch.is_grad_enabled():
            return output
        flat_outputs, _ = tree_flatten(output)
        for t in flat_outputs:
            if torch.is_tensor(t) and t.requires_grad:
                t.register_hook(self._pre_backward)
        return output
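
    # Note: only tensor outputs that require grad get the pre-backward hook, so
    # under `torch.no_grad()` (or if every output is detached) `_pre_backward`
    # never fires for this state and the backward-side logic above is skipped.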

    def _register_root_post_backward_final_callback(self):
        if self._state_ctx.post_backward_final_callback_queued:
            return
        self._state_ctx.post_backward_final_callback_queued = True
        Variable._execution_engine.queue_callback(
            self._root_post_backward_final_callback
        )
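
    # `Variable._execution_engine.queue_callback` runs the given callback once
    # the current autograd backward pass finishes, which is what makes
    # `_root_post_backward_final_callback` a once-per-backward finalization
    # step rather than a per-module hook.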


def _get_module_fsdp_state(module: nn.Module) -> Optional[FSDPState]:
    state = _get_module_state(module)
    if isinstance(state, FSDPState):
        return state
    return None


def _register_group_forward_hooks(
    modules: Sequence[nn.Module],
    pre_hook: Callable,
    post_hook: Callable,
    modules_to_run: Set[nn.Module],
):
    """
    Registers group forward pre- and post-hooks. The pre-hook runs upon the
    first module's pre-forward, and the post-hook runs upon the last module's
    post-forward. If at least one module does not run forward, then the
    post-hook does not run.
    """
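    # For example (illustrative): with `modules = (m1, m2)`, the wrapped
    # pre-hook fires for whichever of m1/m2 runs forward first (seeding
    # `modules_to_run` with both), and the wrapped post-hook fires only after
    # both have finished forward; `always_call=True` ensures the post-hook
    # still runs if a module's forward raises.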
    modules_set = set(modules)

    @disable_if_config_true
    @functools.wraps(pre_hook)
    def wrapped_pre_hook(*args: Any, **kwargs: Any):
        if len(modules_to_run) == 0:  # first to run
            modules_to_run.update(modules_set)
            return pre_hook(*args, **kwargs)

    @disable_if_config_true
    def get_wrapped_post_hook(module: nn.Module):
        @functools.wraps(post_hook)
        def wrapped_post_hook(*args: Any, **kwargs: Any):
            modules_to_run.discard(module)
            if len(modules_to_run) == 0:
                return post_hook(*args, **kwargs)

        return wrapped_post_hook

    pre_handles = [
        module.register_forward_pre_hook(
            wrapped_pre_hook, prepend=True, with_kwargs=True
        )
        for module in modules
    ]
    post_handles = [
        module.register_forward_hook(
            get_wrapped_post_hook(module), prepend=False, always_call=True
        )
        for module in modules
    ]
    return _MultiHandle(tuple(pre_handles + post_handles))
384