# mypy: allow-untyped-defs
import logging
import warnings
from copy import deepcopy
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    List,
    Mapping,
    Optional,
    overload,
    Union,
)

import torch
import torch.nn as nn
from torch import optim
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


__all__: List[str] = []

logger = logging.getLogger(__name__)


class _NamedOptimizer(optim.Optimizer):
    """
    ``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by parameter key.

    We replace the original numeric keys in the optimizer ``state_dict`` with
    fully qualified name (FQN) strings. Users can initialize this optimizer the
    same way they initialize a PyTorch optimizer; the only difference is that
    they also need to pass in the FQN of each parameter.

    Args:
        named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]):
            Mapping from FQN to parameter.
        optimizer_class (optim.Optimizer):
            The class of optimizer to instantiate.
        param_groups (Collection[Mapping[str, Any]]):
            ``param_groups`` to pass to the optimizer if specified.
            The keys of the inner maps need to be FQNs.
            Default: None
        module (nn.Module): the module whose parameters are updated
            by the optimizer.
        args: arguments to pass to the optimizer constructor.
        kwargs: arguments to pass to the optimizer constructor.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch import optim
        >>> from torch.distributed.optim import _NamedOptimizer
        >>>
        >>> # Define the named optimizer.
        >>> m = Model(...)
        >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD)
        >>> # Forward pass + backward pass.
        >>> named_optim.step()
        >>> ...
        >>> # Calling state_dict on the named optimizer returns a state_dict keyed by FQN.
        >>> named_optim.state_dict()

    Warning: This API is still in development and subject to change.

    TODO: Add tutorial for _NamedOptimizer.
    TODO: Add documentation in the docstring for the public attributes
          like self.param_groups and self.named_parameters.
    """

    def __init__(
        self,
        named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]],
        optimizer_class: optim.Optimizer,
        param_groups: Optional[Collection[Mapping[str, Any]]] = None,
        module: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ) -> None:
        torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer")
        self.param_groups: Collection[Mapping[str, Any]] = param_groups  # type: ignore[assignment]
        self._param_groups_check()
        self.named_parameters = dict(named_parameters)
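        # If no explicit param_groups were given, optimize every named parameter;
        # otherwise pass the user-provided groups through to the wrapped optimizer.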
        params_for_optimizer = (
            self.named_parameters.values() if param_groups is None else param_groups
        )
        self._optimizer = optimizer_class(  # type: ignore[operator]
            params_for_optimizer,
            *args,
            **kwargs,
        )
        self.module = module
        if param_groups is None:
            self.ordered_param_keys = list(self.named_parameters.keys())
        else:
            warnings.warn(
                "Since param_groups were passed in, they will be used to "
                "initialize the optimizer, not all parameters of the module."
            )
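            # Build a reverse map from parameter object to FQN so that the order
            # of keys mirrors the order of parameters in the given param_groups;
            # the wrapped optimizer indexes its state by that order.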
            param_to_key = {param: key for key, param in self.named_parameters.items()}  # type: ignore[misc, has-type]
            ordered_param_keys = []
            for group in param_groups:
                for param in group["params"]:
                    if param not in param_to_key:
                        raise ValueError(
                            f"Expected param {param} from param_groups to be in named_parameters, but it is missing."
                        )
                    ordered_param_keys.append(param_to_key[param])
            self.ordered_param_keys = ordered_param_keys
        # Update param_groups from optimizer.
        self.param_groups = self._optimizer.param_groups

    def _param_groups_check(self):
        if self.param_groups is not None:
            for param_group in self.param_groups:
                assert isinstance(param_group, dict), "param group must be a dict"
                assert "params" in param_group, "param group must contain key params"
                params = param_group["params"]
                if isinstance(params, torch.Tensor):
                    params = [params]
                params = list(params)
                for param in params:
                    if not isinstance(param, torch.Tensor):
                        raise TypeError(
                            "optimizer can only optimize Tensors, "
                            "but one of the params is " + torch.typename(param)
                        )
                param_group["params"] = params

    def state_dict(self) -> Dict[str, Any]:
        """
        Return the ``state_dict`` of the optimizer.

        Instead of using numbers to index parameters, we use the module's
        fully qualified name (FQN) as the key.
        """
        state_dict = self._optimizer.state_dict()
        param_groups = state_dict["param_groups"]

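        # The wrapped optimizer keys its per-parameter state by integer index;
        # translate those indices back to FQNs via ordered_param_keys.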
        ret_state = {
            self.ordered_param_keys[st_key]: state_val
            for st_key, state_val in state_dict["state"].items()
        }

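        # Rewrite each param group so that it lists parameter FQNs (sorted)
        # instead of integer indices, deep-copying the remaining group options.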
        ret_groups = []
        for group in param_groups:
            param_keys = []
            for param in group["params"]:
                param_keys.append(self.ordered_param_keys[param])
            ret_group = {"params": sorted(param_keys)}
            for k, v in group.items():
                if k != "params":
                    ret_group[k] = deepcopy(v)
            ret_groups.append(ret_group)

        return self._post_state_dict({"state": ret_state, "param_groups": ret_groups})

    @overload
    def step(self, closure: None = ...) -> None:
        ...

    @overload
    def step(self, closure: Callable[[], float]) -> float:
        ...

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """
        Perform a single optimization step.

        This will call :meth:`torch.optim.Optimizer.step` on the wrapped
        optimizer.
        """
        return self._optimizer.step(closure=closure)

    @property
    def state(self) -> Mapping[torch.Tensor, Any]:  # type: ignore[override]
        return self._optimizer.state

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """
        Define the default behavior to load a state_dict for ``_NamedOptimizer``.

        Sample Code
        ```
            my_model = MyModule()
            optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad)
            ...

            optim_state_dict = optimizer.state_dict()
            ...
            ...

            optimizer.load_state_dict(optim_state_dict)
            ...
        ```
        Args:
            state_dict (Dict[str, Any]) : A ``state_dict`` to load into the optimizer.
                Note that this state dict update is performed in place.

        .. note:: PyTorch initializes optimizer states lazily, so it is possible
            that no optimizer state exists when a user calls ``load_state_dict``.
            For ``_NamedOptimizer`` we are stricter: ``load_state_dict`` can only
            be called after the state has been initialized. This lets us validate
            the optimizer ``state_dict`` being loaded.
        """
        new_state_dict = self._optimizer.state_dict()
        state_dict = self._pre_load_state_dict(state_dict)
        state = state_dict["state"]
        new_state = new_state_dict["state"]
        if len(new_state) == 0:
            raise ValueError(
                "Expects the optimizer to be initialized before load but found it uninitialized."
            )

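        # Copy state in place: each FQN in the loaded state_dict is matched to
        # the corresponding integer index in the wrapped optimizer's state.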
        for idx, param_key in enumerate(self.ordered_param_keys):
            # With conditional training, not all parameters are updated in the
            # optimizer, so some keys may be absent from the loaded state.
            if param_key not in state.keys():
                continue
            if len(state[param_key]) != len(new_state[idx]):
                raise ValueError(
                    f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}"
                )
            # Iterate through all optimizer states.
            for state_key, state_val in new_state[idx].items():
                if state_key not in state[param_key]:
                    raise ValueError(
                        f"Expects state {state_key} for parameter {param_key} but not found."
                    )

                src_state_val = state[param_key][state_key]
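                # ShardedTensor state is copied shard by shard, dense tensor
                # state is copied in place, and anything else (e.g. step counts)
                # is deep-copied.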
                if isinstance(state_val, ShardedTensor):
                    assert isinstance(src_state_val, ShardedTensor)
                    num_shards = len(state_val.local_shards())
                    num_new_shards = len(src_state_val.local_shards())
                    if num_shards != num_new_shards:
                        raise ValueError(
                            f"Expects equal number of shards as {num_new_shards} but found {num_shards} for {param_key}/{state_key}"
                        )
                    for shard, src_shard in zip(
                        state_val.local_shards(), src_state_val.local_shards()
                    ):
                        shard.tensor.detach().copy_(src_shard.tensor)
                elif isinstance(state_val, torch.Tensor):
                    assert isinstance(src_state_val, torch.Tensor)
                    state_val.detach().copy_(src_state_val)
                else:
                    new_state[idx][state_key] = deepcopy(src_state_val)

        # Load param_groups of state_dict
        src_param_groups = state_dict["param_groups"]
        new_param_groups = new_state_dict["param_groups"]

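        # Param groups have no stable identifiers, so key each group by the
        # sorted concatenation of its parameter FQNs (see _gen_param_group_key)
        # to match loaded groups against the existing ones.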
        src_group_map = {}
        for group in src_param_groups:
            param_keys = list(group["params"])
            src_group_map[_gen_param_group_key(param_keys)] = group
        new_group_map = {}
        for new_group in new_param_groups:
            param_keys = []
            for param_key in new_group["params"]:
                param_keys.append(self.ordered_param_keys[param_key])  # type: ignore[call-overload]
            new_group_map[_gen_param_group_key(param_keys)] = new_group
        for group_key, new_group in new_group_map.items():
            # Not all parameters are necessarily used in training or receive
            # gradients, so a matching group may be absent from the loaded
            # state_dict. Skip such group keys.
            if group_key not in src_group_map:
                continue
            src_group = src_group_map[group_key]
            if len(src_group) != len(new_group):
                raise ValueError(
                    f"Expects equal param_group size as {len(new_group)} for group {group_key} but found {len(src_group)}."
                )
            for k in src_group:
                if k not in new_group:
                    raise ValueError(
                        f"Expects group key {k} to be in group {group_key} in `state_dict` but is missing."
                    )
                if k != "params":
                    new_group[k] = deepcopy(src_group[k])

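        # new_state_dict has been updated in place with the loaded values;
        # hand it back to the wrapped optimizer.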
        self._optimizer.load_state_dict(new_state_dict)

    def add_param_group(self, param_group: Mapping[str, Any]) -> None:
        """
        Add a param group to the :class:`_NamedOptimizer`'s ``param_groups``.

        Warning: This API is still in development and subject to change.
        """
        assert isinstance(param_group, dict), "param group must be a dict"

        params = param_group["params"]
        if isinstance(params, torch.Tensor):
            param_group["params"] = [params]
        else:
            param_group["params"] = list(params)

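        # Extend the FQN ordering with the newly added parameters so that
        # state_dict/load_state_dict can map their indices back to names.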
        param_to_key = {param: key for key, param in self.named_parameters.items()}  # type: ignore[misc, has-type]
        for param in param_group["params"]:
            if param not in param_to_key:
                raise ValueError("some parameters are not in the module")
            self.ordered_param_keys.append(param_to_key[param])

        self._optimizer.add_param_group(param_group)
        # Update param_groups from optimizer.
        self.param_groups = self._optimizer.param_groups

    def init_state(self) -> None:
        """
        Run a dummy optimizer step, which allows us to initialize the optimizer state because most optimizers do lazy init.

        This allows in-place loading of optimizer state from a checkpoint.
        """
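        # Attach zero gradients to every trainable parameter so the following
        # step() materializes the lazily created optimizer state.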
        for param in self.named_parameters.values():
            if param.requires_grad:
                t = torch.zeros_like(param)
                param.grad = torch.autograd.Variable(t)
        # Calling ``step`` will create the initial optimizer states.
        self.step(closure=None)

    def _pre_load_state_dict(self, state_dict) -> Dict[str, Any]:
        # TODO(chienchin): This API should be FSDP agnostic and should support
        # general user hooks.
        if isinstance(self.module, FSDP):
            return FSDP.optim_state_dict_to_load(
                self.module, self._optimizer, state_dict, is_named_optimizer=True
            )
        return state_dict

    def _post_state_dict(self, state_dict) -> Dict[str, Any]:
        # TODO(chienchin): This API should be FSDP agnostic and should support
        # general user hooks.
        if isinstance(self.module, FSDP):
            FSDP.optim_state_dict(self.module, self._optimizer, state_dict)
        return state_dict


def _gen_param_group_key(param_keys: List[str]) -> str:
    """Concatenate all param keys as a unique identifier for one param group."""
    return "/".join(sorted(param_keys))
