# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the RMSprop algorithm."""
from typing import cast, List, Optional, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["RMSprop", "rmsprop"]


class RMSprop(Optimizer):  # noqa: D101
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        alpha: float = 0.99,
        eps: float = 1e-8,
        weight_decay: float = 0,
        momentum: float = 0,
        centered=False,
        capturable=False,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
    ):  # noqa: D107
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= momentum:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= alpha:
            raise ValueError(f"Invalid alpha value: {alpha}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            alpha=alpha,
            eps=eps,
            centered=centered,
            weight_decay=weight_decay,
            capturable=capturable,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("momentum", 0)
            group.setdefault("centered", False)
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
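            # Older checkpoints stored `step` as a plain Python number; convert it
            # to a tensor here so the rest of the code can assume tensor-valued steps.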
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        square_avgs,
        momentum_buffer_list,
        grad_avgs,
        state_steps,
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)

            if p.grad.is_sparse:
                raise RuntimeError("RMSprop does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = (
                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                    if group["capturable"]
                    else torch.zeros((), dtype=_get_scalar_dtype())
                )
                state["square_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                if group["momentum"] > 0:
                    state["momentum_buffer"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                if group["centered"]:
                    state["grad_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
            square_avgs.append(state["square_avg"])
            state_steps.append(state["step"])

            if group["momentum"] > 0:
                momentum_buffer_list.append(state["momentum_buffer"])
            if group["centered"]:
                grad_avgs.append(state["grad_avg"])

        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            square_avgs: List[Tensor] = []
            grad_avgs: List[Tensor] = []
            momentum_buffer_list: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                square_avgs,
                momentum_buffer_list,
                grad_avgs,
                state_steps,
            )

            rmsprop(
                params_with_grad,
                grads,
                square_avgs,
                grad_avgs,
                momentum_buffer_list,
                state_steps,
                lr=group["lr"],
                alpha=group["alpha"],
                eps=group["eps"],
                weight_decay=group["weight_decay"],
                momentum=group["momentum"],
                centered=group["centered"],
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                capturable=group["capturable"],
                has_complex=has_complex,
            )

        return loss


RMSprop.__doc__ = (
    r"""Implements RMSprop algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \alpha \text{ (alpha)},\: \gamma \text{ (lr)},
                \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)}                   \\
            &\hspace{13mm}   \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},\: centered\\
            &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \:
                \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm}if \: \lambda \neq 0                                                    \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}v_t           \leftarrow   \alpha v_{t-1} + (1 - \alpha) g^2_t
                \hspace{8mm}                                                                     \\
            &\hspace{5mm} \tilde{v_t} \leftarrow v_t                                             \\
            &\hspace{5mm}if \: centered                                                          \\
            &\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t            \\
            &\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} -  \big(g^{ave}_{t} \big)^2        \\
            &\hspace{5mm}if \: \mu > 0                                                           \\
            &\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} +
                g_t/ \big(\sqrt{\tilde{v_t}} +  \epsilon \big)                                   \\
            &\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t                \\
            &\hspace{5mm} else                                                                   \\
            &\hspace{10mm}\theta_t      \leftarrow   \theta_{t-1} -
                \gamma  g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big)  \hspace{3mm}              \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to the
    `lecture notes <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_ by G. Hinton
    and to the centered version described in `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
    The implementation here takes the square root of the gradient average before
    adding epsilon (note that TensorFlow interchanges these two operations). The effective
    learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma`
    is the scheduled learning rate and :math:`v` is the weighted moving average
    of the squared gradient.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        momentum (float, optional): momentum factor (default: 0)
        alpha (float, optional): smoothing constant (default: 0.99)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        centered (bool, optional): if ``True``, compute the centered RMSprop,
            in which the gradient is normalized by an estimation of its variance
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_maximize_doc}
        {_capturable_doc}
        {_differentiable_doc}

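    Example (a minimal usage sketch; ``model``, ``loss_fn``, ``input`` and
    ``target`` are placeholders for user-defined objects):

        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-2, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()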
    """
)


def _single_tensor_rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    grad_avgs: List[Tensor],
    momentum_buffer_list: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        step = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        grad = grads[i]
        grad = grad if not maximize else -grad
        square_avg = square_avgs[i]

        step += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        is_complex_param = torch.is_complex(param)
        if is_complex_param:
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
            square_avg = torch.view_as_real(square_avg)

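        # Exponential moving average of squared gradients:
        # square_avg <- alpha * square_avg + (1 - alpha) * grad * grad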
        square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

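        # Centered variant: additionally track a running mean of the gradients and
        # subtract its square, so the denominator estimates the gradient variance
        # rather than the uncentered second moment.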
        if centered:
            grad_avg = grad_avgs[i]
            if is_complex_param:
                grad_avg = torch.view_as_real(grad_avg)
            grad_avg.lerp_(grad, 1 - alpha)
            avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_()
        else:
            avg = square_avg.sqrt()

        if differentiable:
            avg = avg.add(eps)
        else:
            avg = avg.add_(eps)

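        # With momentum, accumulate the preconditioned gradient into a buffer and
        # step along the buffer; otherwise apply the scaled update directly.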
        if momentum > 0:
            buf = momentum_buffer_list[i]
            if is_complex_param:
                buf = torch.view_as_real(buf)
            buf.mul_(momentum).addcdiv_(grad, avg)
            param.add_(buf, alpha=-lr)
        else:
            param.addcdiv_(grad, avg, value=-lr)


def _multi_tensor_rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    grad_avgs: List[Tensor],
    momentum_buffer_list: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices()
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, square_avgs, grad_avgs, momentum_buffer_list, state_steps]  # type: ignore[list-item]
    )
    for (
        (
            grouped_params_,
            grouped_grads_,
            grouped_square_avgs_,
            grouped_grad_avgs_,
            grouped_momentum_buffer_list_,
            grouped_state_steps_,
        )
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_square_avgs = cast(List[Tensor], grouped_square_avgs_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

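        # Complex parameters (and their state) are viewed as real tensors so the
        # foreach kernels below can operate on them directly.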
        if has_complex:
            state_and_grads = [grouped_grads, grouped_square_avgs]
            if momentum > 0:
                grouped_momentum_buffer_list = cast(
                    List[Tensor], grouped_momentum_buffer_list_
                )
                state_and_grads.append(grouped_momentum_buffer_list)
            if centered:
                grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
                state_and_grads.append(grouped_grad_avgs)
            _view_as_real(grouped_params, *state_and_grads)

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            # Re-use the intermediate memory (grouped_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
            else:
                grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                    grouped_grads, grouped_params, alpha=weight_decay
                )

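        # Foreach equivalent of the single-tensor update:
        # square_avg <- alpha * square_avg + (1 - alpha) * grad * grad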
        torch._foreach_mul_(grouped_square_avgs, alpha)
        torch._foreach_addcmul_(
            grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha
        )

        if centered:
            grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
            torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha)
            avg = torch._foreach_addcmul(
                grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1
            )
            torch._foreach_sqrt_(avg)
            torch._foreach_add_(avg, eps)
        else:
            avg = torch._foreach_sqrt(grouped_square_avgs)
            torch._foreach_add_(avg, eps)

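        # Mirror of the single-tensor parameter update: step via the momentum
        # buffer when momentum > 0, otherwise apply grad / avg scaled by lr.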
        if momentum > 0:
            grouped_momentum_buffer_list = cast(
                List[Tensor], grouped_momentum_buffer_list_
            )
            torch._foreach_mul_(grouped_momentum_buffer_list, momentum)
            torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg)
            # If LR is a tensor, the else branch will internally call item()
            # which will cause silent incorrectness if we are capturing
            if capturable and isinstance(lr, torch.Tensor):
                momentum_lr = torch._foreach_mul(grouped_momentum_buffer_list, -lr)
                torch._foreach_add_(grouped_params, momentum_lr)
            else:
                torch._foreach_add_(
                    grouped_params, grouped_momentum_buffer_list, alpha=-lr
                )
        else:
            # If LR is a tensor, the else branch will internally call item()
            # which will cause silent incorrectness if we are capturing
            if capturable and isinstance(lr, torch.Tensor):
                torch._foreach_div_(avg, -lr)
                torch._foreach_addcdiv_(grouped_params, grouped_grads, avg)
            else:
                torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop)
def rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    grad_avgs: List[Tensor],
    momentum_buffer_list: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
):
    r"""Functional API that performs the RMSprop algorithm computation.

    See :class:`~torch.optim.RMSprop` for details.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

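    # If the caller did not pick an implementation, choose between the
    # single-tensor and multi-tensor (foreach) paths based on the inputs;
    # the fused path is not requested here (use_fused=False).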
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rmsprop
    else:
        func = _single_tensor_rmsprop

    func(
        params,
        grads,
        square_avgs,
        grad_avgs,
        momentum_buffer_list,
        state_steps,
        lr=lr,
        alpha=alpha,
        eps=eps,
        weight_decay=weight_decay,
        momentum=momentum,
        centered=centered,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )