xref: /aosp_15_r20/external/pytorch/torch/optim/rprop.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
3r"""Implementation for the Resilient backpropagation."""
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["Rprop", "rprop"]


class Rprop(Optimizer):  # noqa: D101
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        etas: Tuple[float, float] = (0.5, 1.2),
        step_sizes: Tuple[float, float] = (1e-6, 50),
        *,
        capturable: bool = False,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
    ):  # noqa: D107
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 < etas[0] < 1.0 < etas[1]:
            raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")

        defaults = dict(
            lr=lr,
            etas=etas,
            step_sizes=step_sizes,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(self, group, params, grads, prevs, step_sizes, state_steps):
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params.append(p)
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError("Rprop does not support sparse gradients")

            grads.append(grad)
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = (
                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                    if group["capturable"]
                    else torch.zeros((), dtype=_get_scalar_dtype())
                )

                state["prev"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                if p.dtype.is_complex:
                    # Complex numbers are treated as two independent real numbers,
                    # so the step_size must not start at zero for the imaginary part
                    # (illustrated in the sketch after this method).
                    state["step_size"] = torch.full_like(
                        grad, complex(group["lr"], group["lr"])
                    )
                else:
                    state["step_size"] = torch.full_like(grad, group["lr"])

            prevs.append(state["prev"])
            step_sizes.append(state["step_size"])
            state_steps.append(state["step"])

        return has_complex
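
    # Illustrative sketch (not executed by the optimizer) of the complex-number handling
    # above: a complex parameter is treated as two independent real numbers, so the
    # initial step_size carries `lr` in both its real and imaginary components and every
    # slot of its real view starts at `lr`. The values below are assumed for illustration.
    #
    #   >>> # xdoctest: +SKIP
    #   >>> lr = 0.01
    #   >>> p = torch.zeros(2, dtype=torch.complex64)
    #   >>> step_size = torch.full_like(p, complex(lr, lr))
    #   >>> torch.view_as_real(step_size)  # shape (2, 2), every entry equals lr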

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params: List[Tensor] = []
            grads: List[Tensor] = []
            prevs: List[Tensor] = []
            step_sizes: List[Tensor] = []
            state_steps: List[Tensor] = []

            etaminus, etaplus = group["etas"]
            step_size_min, step_size_max = group["step_sizes"]
            foreach = group["foreach"]
            maximize = group["maximize"]

            has_complex = self._init_group(
                group, params, grads, prevs, step_sizes, state_steps
            )

            rprop(
                params,
                grads,
                prevs,
                step_sizes,
                state_steps,
                step_size_min=step_size_min,
                step_size_max=step_size_max,
                etaminus=etaminus,
                etaplus=etaplus,
                foreach=foreach,
                maximize=maximize,
                differentiable=group["differentiable"],
                capturable=group["capturable"],
                has_complex=has_complex,
            )

        return loss


Rprop.__doc__ = (
    r"""Implements the resilient backpropagation algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \theta_0 \in \mathbf{R}^d \text{ (params)},f(\theta)
                \text{ (objective)},                                                             \\
            &\hspace{13mm}      \eta_{+/-} \text{ (etaplus, etaminus)}, \Gamma_{max/min}
                \text{ (step sizes)}                                                             \\
            &\textbf{initialize} :   g^0_{prev} \leftarrow 0,
                \: \eta_0 \leftarrow \text{lr (learning rate)}                                   \\
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm} \textbf{for} \text{  } i = 0, 1, \ldots, d-1 \: \textbf{do}            \\
            &\hspace{10mm}  \textbf{if} \:   g^i_{prev} g^i_t  > 0                               \\
            &\hspace{15mm}  \eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+},
                \Gamma_{max})                                                                    \\
            &\hspace{10mm}  \textbf{else if}  \:  g^i_{prev} g^i_t < 0                           \\
            &\hspace{15mm}  \eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-},
                \Gamma_{min})                                                                    \\
            &\hspace{15mm}  g^i_t \leftarrow 0                                                   \\
            &\hspace{10mm}  \textbf{else}  \:                                                    \\
            &\hspace{15mm}  \eta^i_t \leftarrow \eta^i_{t-1}                                     \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1}- \eta_t \mathrm{sign}(g_t)             \\
            &\hspace{5mm}g_{prev} \leftarrow  g_t                                                \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to the paper
    `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm
    <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        etas (Tuple[float, float], optional): pair of (etaminus, etaplus), the
            multiplicative decrease and increase factors for the step size
            (default: (0.5, 1.2))
        step_sizes (Tuple[float, float], optional): a pair of minimal and
            maximal allowed step sizes (default: (1e-6, 50))
        {_foreach_doc}
        {_capturable_doc}
        {_maximize_doc}
        {_differentiable_doc}

    """
)
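

# A minimal usage sketch (illustrative only; `model`, `input`, `target`, and `loss_fn`
# are placeholder names, not part of this module). Rprop only uses the sign of the
# gradient, so it is typically run on full batches rather than small mini-batches.
#
#   >>> # xdoctest: +SKIP
#   >>> optimizer = torch.optim.Rprop(model.parameters(), lr=0.01, etas=(0.5, 1.2))
#   >>> optimizer.zero_grad()
#   >>> loss_fn(model(input), target).backward()
#   >>> optimizer.step()
#
# `step` also accepts an optional closure that re-evaluates the model and returns the
# loss, mirroring the other optimizers in torch.optim:
#
#   >>> # xdoctest: +SKIP
#   >>> def closure():
#   ...     optimizer.zero_grad()
#   ...     loss = loss_fn(model(input), target)
#   ...     loss.backward()
#   ...     return loss
#   >>> optimizer.step(closure)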


def _single_tensor_rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    state_steps: List[Tensor],
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        prev = prevs[i]
        step_size = step_sizes[i]
        step = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        step += 1

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            prev = torch.view_as_real(prev)
            param = torch.view_as_real(param)
            step_size = torch.view_as_real(step_size)
        if differentiable:
            sign = grad.mul(prev.clone()).sign()
        else:
            sign = grad.mul(prev).sign()

        # Map the sign of prev * grad to a multiplicative factor: etaplus where the
        # gradient kept its sign, etaminus where it flipped, and 1 where the product is 0.
        if capturable:
            sign.copy_(torch.where(sign.gt(0), etaplus, sign))
            sign.copy_(torch.where(sign.lt(0), etaminus, sign))
            sign.copy_(torch.where(sign.eq(0), 1, sign))
        else:
            sign[sign.gt(0)] = etaplus
            sign[sign.lt(0)] = etaminus
            sign[sign.eq(0)] = 1

        # Update step sizes by the chosen factors and clamp to the allowed range
        # (a standalone sketch of this adaptation follows this function).
        step_size.mul_(sign).clamp_(step_size_min, step_size_max)

        # Where the sign flipped (factor == etaminus), zero out the gradient:
        # for dir < 0, dfdx = 0; for dir >= 0, dfdx stays unchanged.
        grad = grad.clone(memory_format=torch.preserve_format)
        if capturable:
            grad.copy_(torch.where(sign.eq(etaminus), 0, grad))
        else:
            grad[sign.eq(etaminus)] = 0

        # update parameters
        param.addcmul_(grad.sign(), step_size, value=-1)
        prev.copy_(grad)
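
# A standalone sketch (illustrative only; `prev_grad`, `grad`, `step_size`, `param`,
# and the eta/limit scalars are assumed to exist) of the per-element adaptation the
# loop above implements: the sign of prev_grad * grad selects the multiplicative
# factor, and the result is clamped to [step_size_min, step_size_max] before the
# sign-based parameter update.
#
#   >>> # xdoctest: +SKIP
#   >>> factor = (prev_grad * grad).sign()
#   >>> factor[factor > 0] = etaplus
#   >>> factor[factor < 0] = etaminus
#   >>> factor[factor == 0] = 1
#   >>> step_size = (step_size * factor).clamp_(step_size_min, step_size_max)
#   >>> # where the sign flipped, the gradient is zeroed before the parameter update
#   >>> param -= step_size * grad.sign()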


def _multi_tensor_rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    state_steps: List[Tensor],
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices()
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, prevs, step_sizes, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_prevs_,
        grouped_step_sizes_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_prevs = cast(List[Tensor], grouped_prevs_)
        grouped_step_sizes = cast(List[Tensor], grouped_step_sizes_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        # Handle complex params
        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_prevs, grouped_step_sizes
            )

        signs = torch._foreach_mul(grouped_grads, grouped_prevs)
        if maximize:
            torch._foreach_neg_(signs)

        # At the end of the step, grouped_prevs will contain the current grads, so we reuse
        # grouped_prevs memory instead of creating a new buffer, but, for clarity, we reassign
        # to keep referring to the buffer as grouped_grads.
        torch._foreach_copy_(grouped_prevs, grouped_grads)
        if maximize:
            torch._foreach_neg_(grouped_prevs)
        grouped_grads = grouped_prevs

        torch._foreach_sign_(signs)
        if capturable:
            for sign in signs:
                sign.copy_(torch.where(sign.gt(0), etaplus, sign))
                sign.copy_(torch.where(sign.lt(0), etaminus, sign))
                sign.copy_(torch.where(sign.eq(0), 1, sign))
        else:
            for sign in signs:
                sign[sign.gt(0)] = etaplus
                sign[sign.lt(0)] = etaminus
                sign[sign.eq(0)] = 1

        # Update step sizes by the chosen factors and clamp to the allowed range
        torch._foreach_mul_(grouped_step_sizes, signs)
        for step_size in grouped_step_sizes:
            step_size.clamp_(step_size_min, step_size_max)

        # Where the sign flipped (factor == etaminus), zero out the gradient:
        # for dir < 0, dfdx = 0; for dir >= 0, dfdx stays unchanged.
        grouped_grads = list(grouped_grads)
        for i in range(len(grouped_grads)):
            grouped_grads[i].copy_(
                torch.where(signs[i].eq(etaminus), 0, grouped_grads[i])
            )

        # explicitly del signs as it's not used after here to save memory
        del signs

        # update parameters
        grad_signs = [grad.sign() for grad in grouped_grads]
        torch._foreach_addcmul_(
            grouped_params, grad_signs, grouped_step_sizes, value=-1
        )

        # Logically, you may expect grouped_prevs to get updated to grouped_grads, but that
        # has basically already happened since we've been using grouped_prevs' memory to store
        # the updated grouped_grads!
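
# A minimal sketch of the _foreach semantics relied on above (shapes and values are
# illustrative): each torch._foreach_* op applies its element-wise operation to
# corresponding tensors of the input lists in one horizontally fused call, instead of
# a Python loop over individual tensors.
#
#   >>> # xdoctest: +SKIP
#   >>> xs = [torch.ones(2), torch.ones(3)]
#   >>> ys = [torch.full((2,), 2.0), torch.full((3,), 3.0)]
#   >>> prods = torch._foreach_mul(xs, ys)  # like [x * y for x, y in zip(xs, ys)]
#   >>> torch._foreach_add_(xs, 1)          # in-place: like [x.add_(1) for x in xs]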


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rprop)
def rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript, see issue #70627
    # setting this as a kwarg for now as the functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    maximize: bool = False,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
):
    r"""Functional API that performs the Rprop algorithm computation.

    See :class:`~torch.optim.Rprop` for details.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rprop
    else:
        func = _single_tensor_rprop

    func(
        params,
        grads,
        prevs,
        step_sizes,
        state_steps,
        step_size_min=step_size_min,
        step_size_max=step_size_max,
        etaminus=etaminus,
        etaplus=etaplus,
        capturable=capturable,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
    )
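

# A minimal sketch of driving the functional API directly (illustrative only; in normal
# use the Rprop class above manages this state). The lists mirror the per-parameter
# state kept by the optimizer: previous gradients, adaptive step sizes, and step counts.
#
#   >>> # xdoctest: +SKIP
#   >>> params = [torch.randn(3, requires_grad=True)]
#   >>> grads = [torch.randn(3)]
#   >>> prevs = [torch.zeros(3)]
#   >>> step_sizes = [torch.full((3,), 0.01)]
#   >>> state_steps = [torch.tensor(0.0)]
#   >>> with torch.no_grad():
#   ...     rprop(
#   ...         params, grads, prevs, step_sizes, state_steps,
#   ...         step_size_min=1e-6, step_size_max=50.0, etaminus=0.5, etaplus=1.2,
#   ...     )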