# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""
from typing import cast, List, Optional, Union

import torch
from torch import Tensor

from .optimizer import (
    _default_to_fused_or_foreach,
    _device_dtype_check_for_fused,
    _differentiable_doc,
    _foreach_doc,
    _fused_doc,
    _maximize_doc,
    _use_grad_for_differentiable,
    DeviceDict,
    Optimizer,
)


__all__ = ["SGD", "sgd"]


class SGD(Optimizer):  # noqa: D101
    def __init__(
        self,
        params,
        lr: Union[float, Tensor] = 1e-3,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov=False,
        *,
        maximize: bool = False,
        foreach: Optional[bool] = None,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):  # noqa: D107
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            maximize=maximize,
            foreach=foreach,
            differentiable=differentiable,
            fused=fused,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

        if fused:
            self._step_supports_amp_scaling = True
            self._need_device_dtype_check_for_fused = True
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("differentiable", False)
            group.setdefault("fused", False)

    def _init_group(self, group, params, grads, momentum_buffer_list):
        has_sparse_grad = False

        for p in group["params"]:
            if p.grad is not None:
                if group["fused"] and getattr(
                    self, "_need_device_dtype_check_for_fused", True
                ):
                    _device_dtype_check_for_fused(p)
                    self._need_device_dtype_check_for_fused = False
                params.append(p)
                grads.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                if group["momentum"] != 0:
                    state = self.state[p]
                    momentum_buffer_list.append(state.get("momentum_buffer"))

        return has_sparse_grad

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
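
        A minimal sketch of passing a closure; ``model``, ``loss_fn``, ``input``
        and ``target`` are assumed to be defined elsewhere.

        Example:
            >>> # xdoctest: +SKIP
            >>> def closure():
            ...     optimizer.zero_grad()
            ...     loss = loss_fn(model(input), target)
            ...     loss.backward()
            ...     return loss
            >>> optimizer.step(closure)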
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params: List[Tensor] = []
            grads: List[Tensor] = []
            momentum_buffer_list: List[Optional[Tensor]] = []

            has_sparse_grad = self._init_group(
                group, params, grads, momentum_buffer_list
            )

            sgd(
                params,
                grads,
                momentum_buffer_list,
                weight_decay=group["weight_decay"],
                momentum=group["momentum"],
                lr=group["lr"],
                dampening=group["dampening"],
                nesterov=group["nesterov"],
                maximize=group["maximize"],
                has_sparse_grad=has_sparse_grad,
                foreach=group["foreach"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

            if group["momentum"] != 0:
                # update momentum_buffers in state
                for p, momentum_buffer in zip(params, momentum_buffer_list):
                    state = self.state[p]
                    state["momentum_buffer"] = momentum_buffer

        return loss


SGD.__doc__ = (
    r"""Implements stochastic gradient descent (optionally with momentum).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)},                          \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
            \:\textit{ nesterov,}\:\textit{ maximize}                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                           \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0                                               \\
            &\hspace{10mm}\textbf{if} \: t > 1                                                   \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t           \\
            &\hspace{10mm}\textbf{else}                                                          \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t                                           \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov}                                       \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t                             \\
            &\hspace{10mm}\textbf{else}                                                   \\[-1.ex]
            &\hspace{15mm} g_t  \leftarrow  \textbf{b}_t                                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}                                          \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t                   \\[-1.ex]
            &\hspace{5mm}\textbf{else}                                                    \\[-1.ex]
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t                   \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3)
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        {_maximize_doc}
        {_foreach_doc}
        {_differentiable_doc}
        {_fused_doc}
    """
    + r"""

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.

        Moreover, the initial value of the momentum buffer is set to the
        gradient value at the first step. This is in contrast to some other
        frameworks that initialize it to all zeros.

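    As a scalar sketch of the update performed here (illustrative only; the
    names below are not part of any API), a momentum step after the first one
    amounts to::

        # buf is seeded with the first gradient on the first step
        g = g + weight_decay * p                    # optional L2 penalty
        buf = momentum * buf + (1 - dampening) * g  # b_t in the algorithm above
        d_p = g + momentum * buf if nesterov else buf
        p = p - lr * d_p                            # sign flips when maximizing
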
    """
)


def sgd(
    params: List[Tensor],
    d_p_list: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    # kw-only args with defaults are not supported by functions compiled with torchscript (see issue #70627),
    # so these are kept as ordinary keyword args for now; the functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: Optional[bool] = None,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.
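
    A minimal sketch of calling the functional form directly; the hyperparameter
    values are arbitrary and ``model`` is assumed to exist.

    Example:
        >>> # xdoctest: +SKIP
        >>> params = [p for p in model.parameters() if p.grad is not None]
        >>> grads = [p.grad for p in params]
        >>> bufs = [None for _ in params]
        >>> sgd(params, grads, bufs, weight_decay=0.0, momentum=0.9, lr=0.1,
        ...     dampening=0.0, nesterov=False, maximize=False)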
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither has been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake: we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if foreach is None and fused is None:
        # why must we be explicit about an if statement for torch.jit.is_scripting here?
        # because JIT can't handle Optionals nor fancy conditionals when scripting
        if not torch.jit.is_scripting():
            fused, foreach = _default_to_fused_or_foreach(
                params, differentiable=False, use_fused=False
            )
        else:
            foreach = False
            fused = False
    if foreach is None:
        foreach = False
    if fused is None:
        fused = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    elif fused and not torch.jit.is_scripting():
        func = _fused_sgd
    else:
        func = _single_tensor_sgd

    func(
        params,
        d_p_list,
        momentum_buffer_list,
        weight_decay=weight_decay,
        momentum=momentum,
        lr=lr,
        dampening=dampening,
        nesterov=nesterov,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )


def _single_tensor_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
):
    assert grad_scale is None and found_inf is None

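    # Plain per-parameter loop: negate the gradient when maximizing, apply the
    # optional weight decay, update the momentum buffer in place, then step.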
    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(grad).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)

            if nesterov:
                grad = grad.add(buf, alpha=momentum)
            else:
                grad = buf

        param.add_(grad, alpha=-lr)


def _multi_tensor_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
):
    assert grad_scale is None and found_inf is None

    if len(params) == 0:
        return

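    # Group tensors by (device, dtype) so each group can be updated with a
    # single sequence of _foreach_* calls; the returned indices map back into
    # momentum_buffer_list so newly created buffers are visible to the caller.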
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=True  # type: ignore[list-item]
    )

    for (
        device_params_,
        device_grads_,
        device_momentum_buffer_list,
    ), indices in grouped_tensors.values():
        device_params: List[Tensor] = cast(List[Tensor], device_params_)
        device_grads: List[Tensor] = cast(List[Tensor], device_grads_)

        device_has_sparse_grad = has_sparse_grad and any(
            grad.is_sparse for grad in device_grads
        )

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        if momentum != 0:
            bufs: List[Tensor] = []

            all_states_with_momentum_buffer = True
            for i in range(len(device_momentum_buffer_list)):
                if device_momentum_buffer_list[i] is None:
                    all_states_with_momentum_buffer = False
                    break
                else:
                    bufs.append(cast(Tensor, device_momentum_buffer_list[i]))

            if all_states_with_momentum_buffer:
                torch._foreach_mul_(bufs, momentum)
                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
            else:
                bufs = []
                for i in range(len(device_momentum_buffer_list)):
                    if device_momentum_buffer_list[i] is None:
                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[
                            indices[i]
                        ] = torch.clone(device_grads[i]).detach()
                    else:
                        buf = cast(Tensor, device_momentum_buffer_list[i])
                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

                    bufs.append(buf)

            if nesterov:
                torch._foreach_add_(device_grads, bufs, alpha=momentum)
            else:
                device_grads = bufs

        if not device_has_sparse_grad:
            # handle internal item() call if lr is a tensor
            if isinstance(lr, torch.Tensor) and torch._utils.is_compiling():
                grads_x_lr = torch._foreach_mul(device_grads, -lr)
                torch._foreach_add_(device_params, grads_x_lr)
            else:
                torch._foreach_add_(device_params, device_grads, alpha=-lr)
        else:
            # foreach APIs don't support sparse
            for i in range(len(device_params)):
                device_params[i].add_(device_grads[i], alpha=-lr)


def _fused_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
) -> None:
    if not params:
        return
    if has_sparse_grad:
        raise RuntimeError("`_fused_sgd` does not support sparse gradients")
    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )

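    # Momentum buffers are created lazily: on the first step they are allocated
    # as empty tensors and the fused kernel initializes them in place
    # (signalled via is_first_step).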
    no_momentum_buffer = momentum == 0
    is_first_step = (
        all(t is None for t in momentum_buffer_list) and not no_momentum_buffer
    )
    if is_first_step:
        for i, g in enumerate(grads):
            momentum_buffer_list[i] = torch.empty_like(g)
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=False  # type: ignore[list-item]
    )
    for (device, _), (
        (device_params_, device_grads_, device_momentum_buffer_list),
        _,
    ) in grouped_tensors.items():
        device_params: List[Tensor] = cast(List[Tensor], device_params_)
        device_grads: List[Tensor] = cast(List[Tensor], device_grads_)
        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            device_grad_scale = grad_scale_dict.setdefault(
                device, grad_scale.to(device)
            )
        if found_inf_dict is not None and found_inf is not None:
            device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device))
        torch._fused_sgd_(
            device_params,
            device_grads,
            []
            if no_momentum_buffer
            else cast(List[Tensor], device_momentum_buffer_list),
            weight_decay=weight_decay,
            momentum=momentum,
            lr=lr,
            dampening=dampening,
            nesterov=nesterov,
            maximize=maximize,
            is_first_step=is_first_step,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )