# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""
from typing import cast, List, Optional, Union

import torch
from torch import Tensor

from .optimizer import (
    _default_to_fused_or_foreach,
    _device_dtype_check_for_fused,
    _differentiable_doc,
    _foreach_doc,
    _fused_doc,
    _maximize_doc,
    _use_grad_for_differentiable,
    DeviceDict,
    Optimizer,
)


__all__ = ["SGD", "sgd"]


class SGD(Optimizer):  # noqa: D101
    def __init__(
        self,
        params,
        lr: Union[float, Tensor] = 1e-3,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov=False,
        *,
        maximize: bool = False,
        foreach: Optional[bool] = None,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):  # noqa: D107
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            maximize=maximize,
            foreach=foreach,
            differentiable=differentiable,
            fused=fused,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

        if fused:
            self._step_supports_amp_scaling = True
            self._need_device_dtype_check_for_fused = True
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("differentiable", False)
            group.setdefault("fused", False)

    def _init_group(self, group, params, grads, momentum_buffer_list):
        has_sparse_grad = False

        for p in group["params"]:
            if p.grad is not None:
                if group["fused"] and getattr(
                    self, "_need_device_dtype_check_for_fused", True
                ):
                    _device_dtype_check_for_fused(p)
                    self._need_device_dtype_check_for_fused = False
                params.append(p)
                grads.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                if group["momentum"] != 0:
                    state = self.state[p]
                    momentum_buffer_list.append(state.get("momentum_buffer"))

        return has_sparse_grad

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params: List[Tensor] = []
            grads: List[Tensor] = []
            momentum_buffer_list: List[Optional[Tensor]] = []

            has_sparse_grad = self._init_group(
                group, params, grads, momentum_buffer_list
            )

            sgd(
                params,
                grads,
                momentum_buffer_list,
                weight_decay=group["weight_decay"],
                momentum=group["momentum"],
                lr=group["lr"],
                dampening=group["dampening"],
                nesterov=group["nesterov"],
                maximize=group["maximize"],
                has_sparse_grad=has_sparse_grad,
                foreach=group["foreach"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

            if group["momentum"] != 0:
                # update momentum_buffers in state
                for p, momentum_buffer in zip(params, momentum_buffer_list):
                    state = self.state[p]
                    state["momentum_buffer"] = momentum_buffer

        return loss


SGD.__doc__ = (
    r"""Implements stochastic gradient descent (optionally with momentum).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)},                          \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
            \:\textit{ nesterov,}\:\textit{ maximize}                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1})                       \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                           \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                             \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0                                               \\
            &\hspace{10mm}\textbf{if} \: t > 1                                                   \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t           \\
            &\hspace{10mm}\textbf{else}                                                          \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t                                           \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov}                                       \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t                               \\
            &\hspace{10mm}\textbf{else}                                                   \\[-1.ex]
            &\hspace{15mm} g_t \leftarrow \textbf{b}_t                                           \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}                                        \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t                   \\[-1.ex]
            &\hspace{5mm}\textbf{else}                                                    \\[-1.ex]
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t                   \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \: \theta_t                                                      \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3)
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        {_maximize_doc}
        {_foreach_doc}
        {_differentiable_doc}
        {_fused_doc}
    """
    + r"""

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.

        Moreover, the initial value of the momentum buffer is set to the
        gradient value at the first step. This is in contrast to some other
        frameworks that initialize it to all zeros.

    """
)


def sgd(
    params: List[Tensor],
    d_p_list: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: Optional[bool] = None,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if foreach is None and fused is None:
        # why must we be explicit about an if statement for torch.jit.is_scripting here?
        # because JIT can't handle Optionals nor fancy conditionals when scripting
        if not torch.jit.is_scripting():
            fused, foreach = _default_to_fused_or_foreach(
                params, differentiable=False, use_fused=False
            )
        else:
            foreach = False
            fused = False
    if foreach is None:
        foreach = False
    if fused is None:
        fused = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    elif fused and not torch.jit.is_scripting():
        func = _fused_sgd
    else:
        func = _single_tensor_sgd

    func(
        params,
        d_p_list,
        momentum_buffer_list,
        weight_decay=weight_decay,
        momentum=momentum,
        lr=lr,
        dampening=dampening,
        nesterov=nesterov,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )


def _single_tensor_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
):
    assert grad_scale is None and found_inf is None

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(grad).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)

            if nesterov:
                grad = grad.add(buf, alpha=momentum)
            else:
                grad = buf

        param.add_(grad, alpha=-lr)


def _multi_tensor_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
):
    assert grad_scale is None and found_inf is None

    if len(params) == 0:
        return

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=True  # type: ignore[list-item]
    )

    for (
        device_params_,
        device_grads_,
        device_momentum_buffer_list,
    ), indices in grouped_tensors.values():
        device_params: List[Tensor] = cast(List[Tensor], device_params_)
        device_grads: List[Tensor] = cast(List[Tensor], device_grads_)

        device_has_sparse_grad = has_sparse_grad and any(
            grad.is_sparse for grad in device_grads
        )

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        if momentum != 0:
            bufs: List[Tensor] = []

            all_states_with_momentum_buffer = True
            for i in range(len(device_momentum_buffer_list)):
                if device_momentum_buffer_list[i] is None:
                    all_states_with_momentum_buffer = False
                    break
                else:
                    bufs.append(cast(Tensor, device_momentum_buffer_list[i]))

            if all_states_with_momentum_buffer:
                torch._foreach_mul_(bufs, momentum)
                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
            else:
                bufs = []
                for i in range(len(device_momentum_buffer_list)):
                    if device_momentum_buffer_list[i] is None:
                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[
                            indices[i]
                        ] = torch.clone(device_grads[i]).detach()
                    else:
                        buf = cast(Tensor, device_momentum_buffer_list[i])
                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

                    bufs.append(buf)

            if nesterov:
                torch._foreach_add_(device_grads, bufs, alpha=momentum)
            else:
                device_grads = bufs

        if not device_has_sparse_grad:
            # handle internal item() call if lr is a tensor
            if isinstance(lr, torch.Tensor) and torch._utils.is_compiling():
                grads_x_lr = torch._foreach_mul(device_grads, -lr)
                torch._foreach_add_(device_params, grads_x_lr)
            else:
                torch._foreach_add_(device_params, device_grads, alpha=-lr)
        else:
            # foreach APIs don't support sparse
            for i in range(len(device_params)):
                device_params[i].add_(device_grads[i], alpha=-lr)


def _fused_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
) -> None:
    if not params:
        return
    if has_sparse_grad:
        raise RuntimeError("`_fused_sgd` does not support sparse gradients")
    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )

    no_momentum_buffer = momentum == 0
    is_first_step = (
        all(t is None for t in momentum_buffer_list) and not no_momentum_buffer
    )
    if is_first_step:
        for i, g in enumerate(grads):
            momentum_buffer_list[i] = torch.empty_like(g)
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=False  # type: ignore[list-item]
    )
    for (device, _), (
        (device_params_, device_grads_, device_momentum_buffer_list),
        _,
    ) in grouped_tensors.items():
        device_params: List[Tensor] = cast(List[Tensor], device_params_)
        device_grads: List[Tensor] = cast(List[Tensor], device_grads_)
        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            device_grad_scale = grad_scale_dict.setdefault(
                device, grad_scale.to(device)
            )
        if found_inf_dict is not None and found_inf is not None:
            device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device))
        torch._fused_sgd_(
            device_params,
            device_grads,
            []
            if no_momentum_buffer
            else cast(List[Tensor], device_momentum_buffer_list),
            weight_decay=weight_decay,
            momentum=momentum,
            lr=lr,
            dampening=dampening,
            nesterov=nesterov,
            maximize=maximize,
            is_first_step=is_first_step,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
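

# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the upstream module). The docstring note
# above says the momentum buffer is initialized to the first gradient and then
# follows b_t = mu * b_{t-1} + (1 - tau) * g_t, with the step
# theta_t = theta_{t-1} - lr * b_t. The self-check below re-derives that
# recursion by hand on a toy quadratic and compares it against the SGD class,
# then drives the functional ``sgd`` API directly. All variable names here
# (p_ref, buf, momentum_buffers, ...) are illustrative, not part of the API.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    lr_, mu, tau = 0.1, 0.9, 0.0

    torch.manual_seed(0)
    p_opt = torch.randn(3, requires_grad=True)
    p_ref = p_opt.detach().clone()
    buf: Optional[Tensor] = None

    opt = SGD([p_opt], lr=lr_, momentum=mu, dampening=tau)

    for _ in range(5):
        opt.zero_grad()
        (p_opt**2).sum().backward()
        opt.step()

        # Manual recursion from the docstring: buffer starts at the gradient,
        # then b_t = mu * b_{t-1} + (1 - tau) * g_t and theta_t = theta_{t-1} - lr * b_t.
        g = 2 * p_ref  # gradient of sum(p**2) at the reference parameters
        buf = g.clone() if buf is None else buf.mul_(mu).add_(g, alpha=1 - tau)
        p_ref -= lr_ * buf

    assert torch.allclose(p_opt.detach(), p_ref)

    # The functional API can also be called directly; the caller owns the
    # momentum buffers and passes them explicitly (None means "not yet created").
    params_ = [torch.randn(2)]
    grads_ = [torch.ones_like(params_[0])]
    momentum_buffers: List[Optional[Tensor]] = [None]
    sgd(
        params_,
        grads_,
        momentum_buffers,
        weight_decay=0.0,
        momentum=mu,
        lr=lr_,
        dampening=tau,
        nesterov=False,
        maximize=False,
    )
    assert momentum_buffers[0] is not None  # buffer was created on the first step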