1# mypy: allow-untyped-defs 2import logging 3import warnings 4from copy import deepcopy 5from typing import ( 6 Any, 7 Callable, 8 Collection, 9 Dict, 10 List, 11 Mapping, 12 Optional, 13 overload, 14 Union, 15) 16 17import torch 18import torch.nn as nn 19from torch import optim 20from torch.distributed._shard.sharded_tensor import ShardedTensor 21from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 22 23 24__all__: List[str] = [] 25 26logger = logging.getLogger(__name__) 27 28 29class _NamedOptimizer(optim.Optimizer): 30 """ 31 ``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by parameter key. 32 33 We replace the original key (number) in an optim to the 34 fully qualified name (FQN) string. User can initialize the optim as they 35 initialize a PyTorch optim, the only difference is that they also need to 36 pass in the FQN of each parameters. 37 38 Args: 39 named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]): 40 Mapping from FQN to parameter. 41 optimizer_class (optim.Optimizer): 42 The class of optimizer to instantiate. 43 param_groups (Collection[Mapping[str, Any]]): 44 `param_groups` to pass to optimizer if specified. 45 The key of the inner map needs to be FQNs. 46 Default: None 47 module (nn.Module): the module whose parameters to updated 48 by the optimizer. 49 args: arguments to pass to the optimizer constructor. 50 kwargs: arguments to pass to the optimizer constructor. 51 52 Example:: 53 >>> # xdoctest: +SKIP("distributed") 54 >>> from torch import optim 55 >>> from torch.distributed.optim import _NamedOptimizer 56 >>> 57 >>> # Define the named optimizer. 58 >>> m = Model(...) 59 >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD) 60 >>> # Forward pass + backward pass. 61 >>> named_optim.step() 62 >>> ... 63 >>> # Call state_dict for the named optimizer returns a FQN state_dict. 64 >>> named_optim.state_dict() 65 66 Warning: This API is still in development and subject to change. 67 68 TODO: Add tutorial for _NamedOptimizer. 69 TODO: Add documentation in the docstring for the public attributes 70 like self.param_groups and self.named_parameters. 71 """ 72 73 def __init__( 74 self, 75 named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]], 76 optimizer_class: optim.Optimizer, 77 param_groups: Optional[Collection[Mapping[str, Any]]] = None, 78 module: Optional[nn.Module] = None, 79 *args, 80 **kwargs, 81 ) -> None: 82 torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer") 83 self.param_groups: Collection[Mapping[str, Any]] = param_groups # type: ignore[assignment] 84 self._param_groups_check() 85 self.named_parameters = dict(named_parameters) 86 params_for_optimizer = ( 87 self.named_parameters.values() if param_groups is None else param_groups 88 ) 89 self._optimizer = optimizer_class( # type: ignore[operator] 90 params_for_optimizer, 91 *args, 92 **kwargs, 93 ) 94 self.module = module 95 if param_groups is None: 96 self.ordered_param_keys = list(self.named_parameters.keys()) 97 else: 98 warnings.warn( 99 "Since we pass in param_groups, we will use param_groups to " 100 "initialize the optimizer, not all parameters of the module." 101 ) 102 param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type] 103 ordered_param_keys = [] 104 for group in param_groups: 105 for param in group["params"]: 106 if param not in param_to_key: 107 raise ValueError( 108 f"Expect param name {param} found in param group but is missing." 109 ) 110 ordered_param_keys.append(param_to_key[param]) 111 self.ordered_param_keys = ordered_param_keys 112 # Update param_groups from optimizer. 113 self.param_groups = self._optimizer.param_groups 114 115 def _param_groups_check(self): 116 if self.param_groups is not None: 117 for param_group in self.param_groups: 118 assert isinstance(param_group, dict), "param group must be a dict" 119 assert "params" in param_group, "param group must contain key params" 120 params = param_group["params"] 121 if isinstance(params, torch.Tensor): 122 params = [params] 123 params = list(params) 124 for param in params: 125 if not isinstance(param, torch.Tensor): 126 raise TypeError( 127 "optimizer can only optimize Tensors, " 128 "but one of the params is " + torch.typename(param) 129 ) 130 param_group["params"] = params 131 132 def state_dict(self) -> Dict[str, Any]: 133 """ 134 Return the ``state_dict`` of the optimizer. 135 136 Instead of using number to index 137 parameters, we will use module fully qualified name (FQN) as the key. 138 """ 139 state_dict = self._optimizer.state_dict() 140 param_groups = state_dict["param_groups"] 141 142 ret_state = { 143 self.ordered_param_keys[st_key]: state_val 144 for st_key, state_val in state_dict["state"].items() 145 } 146 147 ret_groups = [] 148 for group in param_groups: 149 param_keys = [] 150 for param in group["params"]: 151 param_keys.append(self.ordered_param_keys[param]) 152 ret_group = {"params": sorted(param_keys)} 153 for k, v in group.items(): 154 if k != "params": 155 ret_group[k] = deepcopy(v) 156 ret_groups.append(ret_group) 157 158 return self._post_state_dict({"state": ret_state, "param_groups": ret_groups}) 159 160 @overload 161 def step(self, closure: None = ...) -> None: 162 ... 163 164 @overload 165 def step(self, closure: Callable[[], float]) -> float: 166 ... 167 168 def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: 169 """ 170 Perform a single optimization step. 171 172 This will call :meth:`torch.optim.Optimizer.step` on the wrapped 173 optimizer. 174 """ 175 return self._optimizer.step(closure=closure) 176 177 @property 178 def state(self) -> Mapping[torch.Tensor, Any]: # type: ignore[override] 179 return self._optimizer.state 180 181 def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: 182 """ 183 Define the default behavior to load a state_dict for ``_NamedOptimizer``. 184 185 Sample Code 186 ``` 187 my_model = MyModule() 188 optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad) 189 ... 190 191 optim_state_dict = optimizer.state_dict() 192 ... 193 ... 194 195 optimizer.load_state_dict(optim_state_dict) 196 ... 197 ``` 198 Args: 199 state_dict (Dict[str, Any]) : A ``state_dict`` to load into the optimizer. 200 Note that this state dict update is performed in place. 201 202 .. note:: PyTorch is using lazy init to initialize the optim states. 203 So it is possible that there is no optim state when user call 204 ``load_state_dict`` and for ``_NamedOptimizer`` we make it stricter 205 that users can only call ``load_state_dict`` after the state is initialized. 206 By doing this, we can validate the optim ``state_dict`` to be loaded. 207 """ 208 new_state_dict = self._optimizer.state_dict() 209 state_dict = self._pre_load_state_dict(state_dict) 210 state = state_dict["state"] 211 new_state = new_state_dict["state"] 212 if len(new_state) == 0: 213 raise ValueError( 214 "Expects the optim to be initialized before load but found not initialized." 215 ) 216 217 for idx, param_key in enumerate(self.ordered_param_keys): 218 # When the conditional training is performed, not all parameters are updated in the optim. 219 if param_key not in state.keys(): 220 continue 221 if len(state[param_key]) != len(new_state[idx]): 222 raise ValueError( 223 f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}" 224 ) 225 # Iterate through all optimizer states. 226 for state_key, state_val in new_state[idx].items(): 227 if state_key not in state[param_key]: 228 raise ValueError( 229 f"Expects state {state_key} for parameter {param_key} but not found." 230 ) 231 232 src_state_val = state[param_key][state_key] 233 if isinstance(state_val, ShardedTensor): 234 assert isinstance(src_state_val, ShardedTensor) 235 num_shards = len(state_val.local_shards()) 236 num_new_shards = len(src_state_val.local_shards()) 237 if num_shards != num_new_shards: 238 raise ValueError( 239 f"Expects equal number of shards as {num_new_shards} but found {num_shards} for {param_key}/{state_key}" 240 ) 241 for shard, src_shard in zip( 242 state_val.local_shards(), src_state_val.local_shards() 243 ): 244 shard.tensor.detach().copy_(src_shard.tensor) 245 elif isinstance(state_val, torch.Tensor): 246 assert isinstance(src_state_val, torch.Tensor) 247 state_val.detach().copy_(src_state_val) 248 else: 249 new_state[idx][state_key] = deepcopy(src_state_val) 250 251 # Load param_groups of state_dict 252 src_param_groups = state_dict["param_groups"] 253 new_param_groups = new_state_dict["param_groups"] 254 255 src_group_map = {} 256 for group in src_param_groups: 257 param_keys = list(group["params"]) 258 src_group_map[_gen_param_group_key(param_keys)] = group 259 new_group_map = {} 260 for new_group in new_param_groups: 261 param_keys = [] 262 for param_key in new_group["params"]: 263 param_keys.append(self.ordered_param_keys[param_key]) # type: ignore[call-overload] 264 new_group_map[_gen_param_group_key(param_keys)] = new_group 265 for group_key, new_group in new_group_map.items(): 266 # When not all parameters are used in training or receive gradient, aka., not all parameters 267 # would be in the param_group. Thus we skip the group_key here. 268 if group_key not in src_group_map: 269 continue 270 src_group = src_group_map[group_key] 271 if len(src_group) != len(new_group): 272 raise ValueError( 273 f"Expects equal param_group size as {len(new_group)} for group {group_key} but found {len(src_group)}." 274 ) 275 for k in src_group: 276 if k not in new_group: 277 raise ValueError( 278 f"Expects group key {k} to be in group {group_key} in `state_dict` but is missing." 279 ) 280 if k != "params": 281 new_group[k] = deepcopy(src_group[k]) 282 283 self._optimizer.load_state_dict(new_state_dict) 284 285 def add_param_group(self, param_group: Mapping[str, Any]) -> None: 286 """ 287 Add a param group to the :class:`_NamedOptimizer` s `param_groups`. 288 289 Warning: This API is still in development and subject to change. 290 """ 291 assert isinstance(param_group, dict), "param group must be a dict" 292 293 params = param_group["params"] 294 if isinstance(params, torch.Tensor): 295 param_group["params"] = [params] 296 else: 297 param_group["params"] = list(params) 298 299 param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type] 300 for param in param_group["params"]: 301 if param not in param_to_key: 302 raise ValueError("some parameters are not in the module") 303 self.ordered_param_keys.append(param_to_key[param]) 304 305 self._optimizer.add_param_group(param_group) 306 # Update param_groups from optimizer. 307 self.param_groups = self._optimizer.param_groups 308 309 def init_state(self) -> None: 310 """ 311 Run a dummy optimizer step, which allows to initialize optimizer state because we do lazy init for most optimizers. 312 313 This allows doing in-place loading of optimizer state from a checkpoint. 314 """ 315 for param in self.named_parameters.values(): 316 if param.requires_grad: 317 t = torch.zeros_like(param) 318 param.grad = torch.autograd.Variable(t) 319 # Calling ``step`` will load the initial state for optimizer states. 320 self.step(closure=None) 321 322 def _pre_load_state_dict(self, state_dict) -> Dict[str, Any]: 323 # TODO(chienchin): This API should be FSDP agnostic and should support 324 # general user hooks. 325 if isinstance(self.module, FSDP): 326 return FSDP.optim_state_dict_to_load( 327 self.module, self._optimizer, state_dict, is_named_optimizer=True 328 ) 329 return state_dict 330 331 def _post_state_dict(self, state_dict) -> Dict[str, Any]: 332 # TODO(chienchin): This API should be FSDP agnostic and should support 333 # general user hooks. 334 if isinstance(self.module, FSDP): 335 FSDP.optim_state_dict(self.module, self._optimizer, state_dict) 336 return state_dict 337 338 339def _gen_param_group_key(param_keys: List[str]) -> str: 340 """Concatenate all param keys as a unique indentifier for one param group.""" 341 return "/".join(sorted(param_keys)) 342