xref: /aosp_15_r20/external/pytorch/torch/optim/adam.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3from typing import cast, List, Optional, Tuple, Union
4
5import torch
6from torch import Tensor
7
8from .optimizer import (
9    _capturable_doc,
10    _default_to_fused_or_foreach,
11    _device_dtype_check_for_fused,
12    _differentiable_doc,
13    _disable_dynamo_if_unsupported,
14    _foreach_doc,
15    _fused_doc,
16    _get_capturable_supported_devices,
17    _get_scalar_dtype,
18    _get_value,
19    _maximize_doc,
20    _stack_if_compiling,
21    _use_grad_for_differentiable,
22    _view_as_real,
23    DeviceDict,
24    Optimizer,
25    ParamsT,
26)
27
28
# Public names of this module: the optimizer class and its functional counterpart.
__all__ = ["Adam", "adam"]
30
31
32class Adam(Optimizer):
33    def __init__(
34        self,
35        params: ParamsT,
36        lr: Union[float, Tensor] = 1e-3,
37        betas: Tuple[float, float] = (0.9, 0.999),
38        eps: float = 1e-8,
39        weight_decay: float = 0,
40        amsgrad: bool = False,
41        *,
42        foreach: Optional[bool] = None,
43        maximize: bool = False,
44        capturable: bool = False,
45        differentiable: bool = False,
46        fused: Optional[bool] = None,
47    ):
48        if isinstance(lr, Tensor):
49            if foreach and not capturable:
50                raise ValueError(
51                    "lr as a Tensor is not supported for capturable=False and foreach=True"
52                )
53            if lr.numel() != 1:
54                raise ValueError("Tensor lr must be 1-element")
55        if not 0.0 <= lr:
56            raise ValueError(f"Invalid learning rate: {lr}")
57        if not 0.0 <= eps:
58            raise ValueError(f"Invalid epsilon value: {eps}")
59        if not 0.0 <= betas[0] < 1.0:
60            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
61        if not 0.0 <= betas[1] < 1.0:
62            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
63        if not 0.0 <= weight_decay:
64            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
65
66        defaults = dict(
67            lr=lr,
68            betas=betas,
69            eps=eps,
70            weight_decay=weight_decay,
71            amsgrad=amsgrad,
72            maximize=maximize,
73            foreach=foreach,
74            capturable=capturable,
75            differentiable=differentiable,
76            fused=fused,
77        )
78        super().__init__(params, defaults)
79
80        if fused:
81            if differentiable:
82                raise RuntimeError("`fused` does not support `differentiable`")
83            self._step_supports_amp_scaling = True
84            # TODO(crcrpar): [low prec params & their higher prec copy]
85            # Support AMP with FP16/BF16 model params which would need
86            # higher prec copy of params to do update math in higher prec to
87            # alleviate the loss of information.
88            if foreach:
89                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
90
91    def __setstate__(self, state):
92        super().__setstate__(state)
93        for group in self.param_groups:
94            group.setdefault("amsgrad", False)
95            group.setdefault("maximize", False)
96            group.setdefault("foreach", None)
97            group.setdefault("capturable", False)
98            group.setdefault("differentiable", False)
99            fused = group.setdefault("fused", None)
100            for p in group["params"]:
101                p_state = self.state.get(p, [])
102                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
103                    step_val = float(p_state["step"])
104                    p_state["step"] = (
105                        torch.tensor(
106                            step_val,
107                            dtype=_get_scalar_dtype(is_fused=fused),
108                            device=p.device,
109                        )
110                        if group["capturable"] or group["fused"]
111                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
112                    )
113
114    def _init_group(
115        self,
116        group,
117        params_with_grad,
118        grads,
119        exp_avgs,
120        exp_avg_sqs,
121        max_exp_avg_sqs,
122        state_steps,
123    ):
124        has_complex = False
125        for p in group["params"]:
126            if p.grad is not None:
127                has_complex |= torch.is_complex(p)
128                params_with_grad.append(p)
129                if p.grad.is_sparse:
130                    raise RuntimeError(
131                        "Adam does not support sparse gradients, please consider SparseAdam instead"
132                    )
133                grads.append(p.grad)
134
135                state = self.state[p]
136                # Lazy state initialization
137                if len(state) == 0:
138                    if group["fused"]:
139                        _device_dtype_check_for_fused(p)
140                    # note(crcrpar): [special device hosting for step]
141                    # Deliberately host `step` on CPU if both capturable and fused are off.
142                    # This is because kernel launches are costly on CUDA and XLA.
143                    state["step"] = (
144                        torch.zeros(
145                            (),
146                            dtype=_get_scalar_dtype(is_fused=group["fused"]),
147                            device=p.device,
148                        )
149                        if group["capturable"] or group["fused"]
150                        else torch.tensor(0.0, dtype=_get_scalar_dtype())
151                    )
152                    # Exponential moving average of gradient values
153                    state["exp_avg"] = torch.zeros_like(
154                        p, memory_format=torch.preserve_format
155                    )
156                    # Exponential moving average of squared gradient values
157                    state["exp_avg_sq"] = torch.zeros_like(
158                        p, memory_format=torch.preserve_format
159                    )
160                    if group["amsgrad"]:
161                        # Maintains max of all exp. moving avg. of sq. grad. values
162                        state["max_exp_avg_sq"] = torch.zeros_like(
163                            p, memory_format=torch.preserve_format
164                        )
165
166                exp_avgs.append(state["exp_avg"])
167                exp_avg_sqs.append(state["exp_avg_sq"])
168
169                if group["amsgrad"]:
170                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])
171                if group["differentiable"] and state["step"].requires_grad:
172                    raise RuntimeError(
173                        "`requires_grad` is not supported for `step` in differentiable mode"
174                    )
175
176                # Foreach without capturable does not support a tensor lr
177                if (
178                    group["foreach"]
179                    and torch.is_tensor(group["lr"])
180                    and not group["capturable"]
181                ):
182                    raise RuntimeError(
183                        "lr as a Tensor is not supported for capturable=False and foreach=True"
184                    )
185
186                state_steps.append(state["step"])
187        return has_complex
188
189    @_use_grad_for_differentiable
190    def step(self, closure=None):
191        """Perform a single optimization step.
192
193        Args:
194            closure (Callable, optional): A closure that reevaluates the model
195                and returns the loss.
196        """
197        self._cuda_graph_capture_health_check()
198
199        loss = None
200        if closure is not None:
201            with torch.enable_grad():
202                loss = closure()
203
204        for group in self.param_groups:
205            params_with_grad: List[Tensor] = []
206            grads: List[Tensor] = []
207            exp_avgs: List[Tensor] = []
208            exp_avg_sqs: List[Tensor] = []
209            max_exp_avg_sqs: List[Tensor] = []
210            state_steps: List[Tensor] = []
211            beta1, beta2 = group["betas"]
212
213            has_complex = self._init_group(
214                group,
215                params_with_grad,
216                grads,
217                exp_avgs,
218                exp_avg_sqs,
219                max_exp_avg_sqs,
220                state_steps,
221            )
222
223            adam(
224                params_with_grad,
225                grads,
226                exp_avgs,
227                exp_avg_sqs,
228                max_exp_avg_sqs,
229                state_steps,
230                amsgrad=group["amsgrad"],
231                has_complex=has_complex,
232                beta1=beta1,
233                beta2=beta2,
234                lr=group["lr"],
235                weight_decay=group["weight_decay"],
236                eps=group["eps"],
237                maximize=group["maximize"],
238                foreach=group["foreach"],
239                capturable=group["capturable"],
240                differentiable=group["differentiable"],
241                fused=group["fused"],
242                grad_scale=getattr(self, "grad_scale", None),
243                found_inf=getattr(self, "found_inf", None),
244            )
245
246        return loss
247
248
# The math below must match `_single_tensor_adam`: with amsgrad, the running max
# is taken over the raw second moment v_t (`max_exp_avg_sq`), and the bias
# correction 1 - beta2^t is applied to that max afterwards.
Adam.__doc__ = (
    r"""Implements Adam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)}          \\
            &\hspace{13mm}      \lambda \text{ (weight decay)},  \: \textit{amsgrad},
                \:\textit{maximize}                                                              \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ (first moment)},
                v_0\leftarrow 0 \text{ (second moment)},\: v_0^{max}\leftarrow 0          \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\

            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                           \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{5mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{5mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
            &\hspace{5mm}\textbf{if} \: amsgrad                                                  \\
            &\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max},v_t)                  \\
            &\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big)              \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}\widehat{v_t} \leftarrow   v_t/\big(1-\beta_2^t \big)                  \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
            is not yet supported for all our implementations. Please use a float
            LR if you are not also specifying fused=True or capturable=True.
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        {_foreach_doc}
        {_maximize_doc}
        {_capturable_doc}
        {_differentiable_doc}
        {_fused_doc}
    .. Note::
        A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """
)
318
319
def _single_tensor_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    """Apply one Adam step to each parameter with a plain Python loop.

    The function updates ``params``, ``exp_avgs``, ``exp_avg_sqs``, ``state_steps``
    and, with amsgrad, ``max_exp_avg_sqs`` in place. Complex tensors are updated
    through ``torch.view_as_real`` views. Each ``max_exp_avg_sqs`` entry is swapped
    to its real view for the update and then back to a complex view.

    With ``capturable`` or ``differentiable``, the bias corrections are tensor ops
    on the step tensor. Otherwise the step is read via ``_get_value`` first.
    ``grad_scale`` and ``found_inf`` must be None.
    """
    assert grad_scale is None and found_inf is None

    if torch.jit.is_scripting():
        # TorchScript does not realize the ops below have overloads for both a
        # float and a Tensor lr, so scripted callers are restricted to a float.
        assert isinstance(lr, float)

    for idx, p in enumerate(params):
        g = -grads[idx] if maximize else grads[idx]
        m = exp_avgs[idx]
        v = exp_avg_sqs[idx]
        step_t = state_steps[idx]
        is_cplx = torch.is_complex(p)

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if capturable and not torch._utils.is_compiling():
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                p.device.type == step_t.device.type
                and p.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        step_t += 1

        if weight_decay != 0:
            g = g.add(p, alpha=weight_decay)

        if is_cplx:
            # Work on real views; in-place writes land in the complex storage.
            g = torch.view_as_real(g)
            m = torch.view_as_real(m)
            v = torch.view_as_real(v)
            if amsgrad:
                max_exp_avg_sqs[idx] = torch.view_as_real(max_exp_avg_sqs[idx])
            p = torch.view_as_real(p)

        # Exponential moving averages of the gradient and its square.
        m.lerp_(g, 1 - beta1)
        v.mul_(beta2).addcmul_(g, g.conj(), value=1 - beta2)

        if capturable or differentiable:
            # Bias corrections stay tensor ops on step_t.
            bias_correction1 = 1 - beta1**step_t
            bias_correction2 = 1 - beta2**step_t
            step_size_neg = (lr / bias_correction1).neg()
            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # In differentiable mode, read from a clone before the in-place
                # copy_ (presumably to keep autograd history intact).
                prev_max = (
                    max_exp_avg_sqs[idx].clone()
                    if differentiable
                    else max_exp_avg_sqs[idx]
                )
                max_exp_avg_sqs[idx].copy_(torch.maximum(prev_max, v))
                second_moment = max_exp_avg_sqs[idx]
            else:
                second_moment = v

            # step_size_neg is a 1-element Tensor, which addcdiv_'s `value` cannot
            # take, so it is folded into the denominator instead.
            denom = (
                second_moment.sqrt() / (bias_correction2_sqrt * step_size_neg)
            ).add_(eps / step_size_neg)
            p.addcdiv_(m, denom)
        else:
            step = _get_value(step_t)
            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step
            step_size = lr / bias_correction1
            bias_correction2_sqrt = bias_correction2**0.5

            if amsgrad:
                # Keep the running max of the second moment in place.
                torch.maximum(max_exp_avg_sqs[idx], v, out=max_exp_avg_sqs[idx])
                second_moment = max_exp_avg_sqs[idx]
            else:
                second_moment = v

            denom = (second_moment.sqrt() / bias_correction2_sqrt).add_(eps)
            p.addcdiv_(m, denom, value=-step_size)

        # Only the list entry was replaced by a real view; restore it.
        if amsgrad and is_cplx:
            max_exp_avg_sqs[idx] = torch.view_as_complex(max_exp_avg_sqs[idx])
437
438
def _multi_tensor_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    """Adam update built from ``torch._foreach_*`` ops, one pass per tensor group.

    The parallel lists are grouped with ``Optimizer._group_tensors_by_device_and_dtype``.
    Each group is then updated with list-wide foreach ops. The function mutates
    ``params``, ``exp_avgs``, ``exp_avg_sqs``, ``state_steps`` and, with amsgrad,
    ``max_exp_avg_sqs`` in place, and returns ``None``.

    Raises:
        RuntimeError: if ``lr`` is a Tensor while ``capturable`` is False.
        AssertionError: if ``capturable`` is set outside compilation and a
            param/step pair is not on a supported device; if ``grad_scale`` or
            ``found_inf`` is given; or if ``differentiable`` is True.
    """
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError(
            "lr as a Tensor is not supported for capturable=False and foreach=True"
        )

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    assert grad_scale is None and found_inf is None

    assert not differentiable, "_foreach ops don't support autograd"

    # Bucket the six parallel lists; each dict value unpacks below as
    # (per-group tensor lists, <second item, unused here — presumably indices>).
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_exp_avgs_,
        device_exp_avg_sqs_,
        device_max_exp_avg_sqs_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        # Handle complex parameters
        # (_view_as_real's return value is unused, so it presumably swaps complex
        # entries for real views inside these lists in place — confirm in optimizer.py)
        if has_complex:
            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                _view_as_real(
                    device_params,
                    device_grads,
                    device_exp_avgs,
                    device_exp_avg_sqs,
                    device_max_exp_avg_sqs,
                )
            else:
                _view_as_real(
                    device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
                )

        # Out-of-place negation: from here on device_grads holds fresh tensors.
        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                # Without maximize, device_grads presumably still aliases the
                # caller's gradient tensors, so add out of place to avoid mutating them.
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            device_exp_avg_sqs, device_grads, device_grads, 1 - beta2
        )

        # Delete the local intermediate since it won't be used anymore to save on peak memory
        del device_grads

        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]

        if capturable:
            # Device-side path: all bias-correction math stays in tensor ops on the steps.
            bias_correction1 = torch._foreach_pow(beta1, device_state_steps)
            bias_correction2 = torch._foreach_pow(beta2, device_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            # we do not negate bias_correction1 as it'll need to be negated later anyway
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            torch._foreach_div_(bias_correction1, lr)
            torch._foreach_reciprocal_(bias_correction1)

            torch._foreach_sqrt_(bias_correction2)

            # Re-assign for clarity as we maintain minimal intermediates: we'll have
            # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
            # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
            step_size = bias_correction1
            bias_correction2_sqrt = bias_correction2

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)  # type: ignore[assignment]

                # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_div_(exp_avg_sq_sqrt, step_size)

            # at this point, exp_avg_sq_sqrt = - (1 - beta1^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
            torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
        else:
            # Host-side path: per-tensor bias corrections computed from each step's value.
            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in device_state_steps
            ]
            bias_correction2 = [
                1 - beta2 ** _get_value(step) for step in device_state_steps
            ]

            # Negative step sizes, -lr / (1 - beta1^t), one per tensor in the group.
            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2]  # type: ignore[arg-type]

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_addcdiv_(
                device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size  # type: ignore[arg-type]
            )
618
619
620def _fused_adam(
621    params: List[Tensor],
622    grads: List[Tensor],
623    exp_avgs: List[Tensor],
624    exp_avg_sqs: List[Tensor],
625    max_exp_avg_sqs: List[Tensor],
626    state_steps: List[Tensor],
627    grad_scale: Optional[Tensor],
628    found_inf: Optional[Tensor],
629    *,
630    amsgrad: bool,
631    has_complex: bool,  # Needed for consistency.
632    beta1: float,
633    beta2: float,
634    lr: Union[float, Tensor],
635    weight_decay: float,
636    eps: float,
637    maximize: bool,
638    capturable: bool,  # Needed for consistency.
639    differentiable: bool,
640) -> None:
641    if not params:
642        return
643    if differentiable:
644        raise RuntimeError("Adam with fused=True does not support differentiable=True")
645
646    grad_scale_dict: DeviceDict = (
647        {grad_scale.device: grad_scale} if grad_scale is not None else {}
648    )
649    found_inf_dict: DeviceDict = (
650        {found_inf.device: found_inf} if found_inf is not None else {}
651    )
652
653    # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
654    # treating it as a scalar.
655    lr_dict: Optional[DeviceDict] = (
656        {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
657    )
658    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
659        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
660    )
661    for (device, _), (
662        (
663            device_params_,
664            device_grads_,
665            device_exp_avgs_,
666            device_exp_avg_sqs_,
667            device_max_exp_avg_sqs,
668            device_state_steps_,
669        ),
670        _,
671    ) in grouped_tensors.items():
672        device_params = cast(List[Tensor], device_params_)
673        device_grads = cast(List[Tensor], device_grads_)
674        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
675        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
676        device_state_steps = cast(List[Tensor], device_state_steps_)
677
678        if device.type == "mps":  # type: ignore[union-attr]
679            assert found_inf is None and grad_scale is None
680
681        device_grad_scale, device_found_inf = None, None
682        if grad_scale is not None:
683            device_grad_scale = grad_scale_dict.setdefault(
684                device, grad_scale.to(device, non_blocking=True)
685            )
686        if found_inf is not None:
687            device_found_inf = found_inf_dict.setdefault(
688                device, found_inf.to(device, non_blocking=True)
689            )
690        if lr_dict is not None and device not in lr_dict:
691            lr_dict[device] = lr.to(device=device, non_blocking=True)  # type: ignore[union-attr]
692            lr = lr_dict[device]
693        torch._foreach_add_(device_state_steps, 1)
694        torch._fused_adam_(
695            device_params,
696            device_grads,
697            device_exp_avgs,
698            device_exp_avg_sqs,
699            device_max_exp_avg_sqs,  # type: ignore[arg-type]
700            device_state_steps,
701            amsgrad=amsgrad,
702            lr=lr,  # type: ignore[arg-type]
703            beta1=beta1,
704            beta2=beta2,
705            weight_decay=weight_decay,
706            eps=eps,
707            maximize=maximize,
708            grad_scale=device_grad_scale,
709            found_inf=device_found_inf,
710        )
711        if device_found_inf is not None:
712            torch._foreach_sub_(
713                device_state_steps, [device_found_inf] * len(device_state_steps)
714            )
715
716
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.

    Dispatches to ``_fused_adam``, ``_multi_tensor_adam`` (foreach) or
    ``_single_tensor_adam``, in that order of precedence. When neither ``fused``
    nor ``foreach`` is given, ``foreach`` is chosen by
    ``_default_to_fused_or_foreach`` (fused is never auto-enabled). It is then
    switched off again for a Tensor ``lr`` without ``capturable``.

    Raises:
        RuntimeError: if ``state_steps`` contains a non-Tensor (checked only when
            not compiling), or if foreach or fused is requested under TorchScript.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # Fused wins over foreach when both are set. The `not torch.jit.is_scripting()`
    # terms are kept inline in the conditions (presumably so TorchScript can drop
    # the foreach/fused branches) — do not hoist them into a local variable.
    if fused and not torch.jit.is_scripting():
        func = _fused_adam
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adam
    else:
        func = _single_tensor_adam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
804