# xref: /aosp_15_r20/external/pytorch/torch/optim/adamax.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["Adamax", "adamax"]


class Adamax(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 2e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        *,
        maximize: bool = False,
        differentiable: bool = False,
        capturable: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("Adamax does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = (
                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                    if group["capturable"]
                    else torch.tensor(0.0, dtype=_get_scalar_dtype())
                )
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                state["exp_inf"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            exp_avgs.append(state["exp_avg"])
            exp_infs.append(state["exp_inf"])
            state_steps.append(state["step"])

        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
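        # Illustrative use of the closure argument (a sketch only; `optimizer`,
        # `model`, `inputs`, `targets`, and `loss_fn` are placeholders, not
        # defined in this module):
        #
        #     def closure():
        #         optimizer.zero_grad()
        #         loss = loss_fn(model(inputs), targets)
        #         loss.backward()
        #         return loss
        #
        #     optimizer.step(closure)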
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_infs: List[Tensor] = []
            state_steps: List[Tensor] = []

            beta1, beta2 = group["betas"]
            eps = group["eps"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            foreach = group["foreach"]
            maximize = group["maximize"]
            differentiable = group["differentiable"]
            capturable = group["capturable"]

            has_complex = self._init_group(
                group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
            )

            adamax(
                params_with_grad,
                grads,
                exp_avgs,
                exp_infs,
                state_steps,
                eps=eps,
                beta1=beta1,
                beta2=beta2,
                lr=lr,
                weight_decay=weight_decay,
                foreach=foreach,
                maximize=maximize,
                differentiable=differentiable,
                capturable=capturable,
                has_complex=has_complex,
            )

        return loss


Adamax.__doc__ = (
    r"""Implements Adamax algorithm (a variant of Adam based on infinity norm).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)},
                \: \lambda \text{ (weight decay)},                                                \\
            &\hspace{13mm}    \epsilon \text{ (epsilon)}                                          \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ (first moment)},
                u_0 \leftarrow 0 \text{ (infinity norm)}                                  \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm}if \: \lambda \neq 0                                                    \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}m_t      \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t               \\
            &\hspace{5mm}u_t      \leftarrow   \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon)   \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            the running average of the gradient and the exponentially weighted
            infinity norm (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        {_capturable_doc}

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980

    """
)

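# A minimal usage sketch of the Adamax class (illustrative only; `model`,
# `loss_fn`, and `data` are placeholders, not part of this module):
#
#     optimizer = Adamax(model.parameters(), lr=2e-3, betas=(0.9, 0.999))
#     for inputs, targets in data:
#         optimizer.zero_grad()
#         loss = loss_fn(model(inputs), targets)
#         loss.backward()
#         optimizer.step()
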
def _single_tensor_adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        exp_avg = exp_avgs[i]
        exp_inf = exp_infs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_inf = torch.view_as_real(exp_inf)

        # Update biased first moment estimate.
        exp_avg.lerp_(grad, 1 - beta1)
        # Update the exponentially weighted infinity norm.
        if not differentiable:
            torch.maximum(
                exp_inf.mul_(beta2),
                grad.abs().add_(eps),
                out=exp_inf,
            )
        else:
            norm_buf = torch.cat(
                [exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)],
                0,
            )
            exp_inf.copy_(torch.amax(norm_buf, 0, keepdim=False))

        if capturable:
            # why jump through extra hoops and negate bias_correction? check out #121238
            # once fixed, we should use bias_correction with addcdiv value=-1 for readability
            neg_bias_correction = beta1**step_t - 1
            neg_bias_correction.div_(lr)
            denom = exp_inf * neg_bias_correction
            param.addcdiv_(exp_avg, denom)
        else:
            bias_correction = 1 - beta1 ** _get_value(step_t)
            clr = lr / bias_correction

            param.addcdiv_(exp_avg, exp_inf, value=-clr)

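# For reference, the non-capturable branch of `_single_tensor_adamax` above
# amounts to the following per-tensor update (a sketch; `t` is the step count
# after the increment):
#
#     exp_avg = exp_avg + (1 - beta1) * (grad - exp_avg)            # m_t (lerp)
#     exp_inf = torch.maximum(beta2 * exp_inf, grad.abs() + eps)    # u_t
#     param   = param - (lr / (1 - beta1 ** t)) * exp_avg / exp_inf
#
# The capturable branch computes the same quantity but folds lr into the
# denominator, i.e. param += exp_avg / (exp_inf * (beta1 ** t - 1) / lr),
# which is algebraically identical (see #121238 for why the negated form is used).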

def _multi_tensor_adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    assert not differentiable, "_foreach ops don't support autograd"

    if len(params) == 0:
        return

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_infs, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_infs_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
        grouped_exp_infs = cast(List[Tensor], grouped_exp_infs_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs
            )

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If the steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling
        # t.add(1) over and over; that wraps the Python scalar 1 into a Tensor on every call, which is
        # slower than wrapping it into a Tensor once here. The alpha kwarg is required to ensure we hit
        # the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            if maximize:
                # Re-use the intermediate memory (grouped_grads) already allocated for maximize
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
            else:
                grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                    grouped_grads, grouped_params, alpha=weight_decay
                )

        # Update biased first moment estimate.
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

        # Update the exponentially weighted infinity norm.
        torch._foreach_mul_(grouped_exp_infs, beta2)

        # in this case, we need to introduce a copy of the grads
        # since one has not been introduced previously
        if not maximize and weight_decay == 0:
            grouped_grads = torch._foreach_abs(grouped_grads)  # type: ignore[assignment]
        else:
            torch._foreach_abs_(grouped_grads)

        torch._foreach_add_(grouped_grads, eps)
        torch._foreach_maximum_(grouped_exp_infs, grouped_grads)

        bias_corrections: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            bias_corrections = torch._foreach_pow(beta1, grouped_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_corrections, 1)
            torch._foreach_div_(bias_corrections, lr)

            denom = torch._foreach_mul(grouped_exp_infs, bias_corrections)
            torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom)
        else:
            bias_corrections = [
                1 - beta1 ** _get_value(step) for step in grouped_state_steps
            ]
            step_size = [(_get_value(lr) / bc) * -1 for bc in bias_corrections]
            torch._foreach_addcdiv_(
                grouped_params, grouped_exp_avgs, grouped_exp_infs, step_size
            )

@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax)
def adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript (see issue #70627),
    # so these are kept as ordinary kwargs for now, since the functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
):
    r"""Functional API that performs the Adamax algorithm computation.

    See :class:`~torch.optim.Adamax` for details.
    """

    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamax
    else:
        func = _single_tensor_adamax

    func(
        params,
        grads,
        exp_avgs,
        exp_infs,
        state_steps,
        eps=eps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        capturable=capturable,
    )

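# A minimal sketch of driving the functional `adamax` API directly (illustrative
# only; the Adamax class above normally manages this state for you and runs the
# update with gradients disabled via the _use_grad_for_differentiable decorator):
#
#     p = torch.zeros(3)
#     params, grads = [p], [torch.ones(3)]
#     exp_avgs, exp_infs = [torch.zeros(3)], [torch.zeros(3)]
#     state_steps = [torch.tensor(0.0)]
#     adamax(params, grads, exp_avgs, exp_infs, state_steps,
#            eps=1e-8, beta1=0.9, beta2=0.999, lr=2e-3, weight_decay=0.0)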