xref: /aosp_15_r20/external/pytorch/torch/optim/adamw.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3from typing import cast, List, Optional, Tuple, Union
4
5import torch
6from torch import Tensor
7
8from .optimizer import (
9    _capturable_doc,
10    _default_to_fused_or_foreach,
11    _device_dtype_check_for_fused,
12    _differentiable_doc,
13    _disable_dynamo_if_unsupported,
14    _foreach_doc,
15    _fused_doc,
16    _get_capturable_supported_devices,
17    _get_scalar_dtype,
18    _get_value,
19    _maximize_doc,
20    _stack_if_compiling,
21    _use_grad_for_differentiable,
22    _view_as_real,
23    DeviceDict,
24    Optimizer,
25    ParamsT,
26)
27
28
# Public names of this module: the ``AdamW`` optimizer class and ``adamw``,
# its functional counterpart.
__all__ = [
    "AdamW",
    "adamw",
]
30
31
class AdamW(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
        *,
        maximize: bool = False,
        foreach: Optional[bool] = None,
        capturable: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        # Validate hyper-parameters up front so a bad configuration fails at
        # construction time instead of on the first ``step()``.
        if isinstance(lr, Tensor):
            # The foreach implementation only accepts a tensor lr together with
            # capturable=True (``_multi_tensor_adamw`` raises otherwise).
            if foreach and not capturable:
                raise ValueError(
                    "lr as a Tensor is not supported for capturable=False and foreach=True"
                )
            if lr.numel() != 1:
                raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            foreach=foreach,
            maximize=maximize,
            capturable=capturable,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            # Only the fused kernel consumes ``grad_scale``/``found_inf`` (the other
            # implementations assert they are None); this flag advertises that.
            # Presumably read by AMP's grad scaler -- confirm in torch.amp.
            self._step_supports_amp_scaling = True
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        # Fill in defaults for any options missing from the loaded state, and
        # convert non-tensor ``step`` entries into tensors.
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            fused = group.setdefault("fused", None)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    # Same placement rule as ``_init_group``: on the param's device
                    # for capturable/fused, otherwise a host (CPU) tensor.
                    p_state["step"] = (
                        torch.tensor(
                            step_val,
                            dtype=_get_scalar_dtype(is_fused=fused),
                            device=p.device,
                        )
                        if group["capturable"] or group["fused"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        amsgrad,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    ):
        """Collect the tensors of ``group`` needed for one step, creating state lazily.

        For every parameter that has a gradient, appends the parameter, its
        gradient, its moment buffers and its step counter to the corresponding
        output lists (``max_exp_avg_sqs`` only when ``group["amsgrad"]``).

        Returns:
            True if any parameter with a gradient is complex.

        Raises:
            RuntimeError: on sparse gradients, on a ``step`` that requires grad in
                differentiable mode, or when the group combines a tensor lr with
                foreach=True and capturable=False.
        """
        # Foreach without capturable does not support a tensor lr. The condition
        # depends only on the group, so evaluate it once. It is raised on the
        # first parameter that has a gradient, before any optimizer state for
        # that parameter is created.
        tensor_lr_unsupported = (
            group["foreach"]
            and isinstance(group["lr"], Tensor)
            and not group["capturable"]
        )

        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            if tensor_lr_unsupported:
                raise RuntimeError(
                    "lr as a Tensor is not supported for capturable=False and foreach=True"
                )
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("AdamW does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                if group["fused"]:
                    _device_dtype_check_for_fused(p)
                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
                # This is because kernel launches are costly on CUDA and XLA.
                state["step"] = (
                    torch.zeros(
                        (),
                        dtype=_get_scalar_dtype(is_fused=group["fused"]),
                        device=p.device,
                    )
                    if group["capturable"] or group["fused"]
                    else torch.tensor(0.0, dtype=_get_scalar_dtype())
                )
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state["max_exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])

            if group["amsgrad"]:
                max_exp_avg_sqs.append(state["max_exp_avg_sq"])
            if group["differentiable"] and state["step"].requires_grad:
                raise RuntimeError(
                    "`requires_grad` is not supported for `step` in differentiable mode"
                )

            state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            # The closure recomputes the loss, so it needs autograd enabled.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            max_exp_avg_sqs: List[Tensor] = []
            state_steps: List[Tensor] = []
            amsgrad: bool = group["amsgrad"]
            beta1, beta2 = cast(Tuple[float, float], group["betas"])

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                amsgrad,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
                has_complex=has_complex,
            )

        return loss
244
245
AdamW.__doc__ = (
    r"""Implements AdamW algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{(lr)}, \: \beta_1, \beta_2
                \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
                \: \epsilon \text{ (epsilon)}                                                    \\
            &\hspace{13mm}      \lambda \text{(weight decay)},  \: \textit{amsgrad},
                \: \textit{maximize}                                                             \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
                \text{ (second moment)}, \: v_0^{max}\leftarrow 0                          \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\

            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1}         \\
            &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{5mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{5mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
            &\hspace{5mm}\textbf{if} \: amsgrad                                                  \\
            &\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max}, v_t)                 \\
            &\hspace{10mm}\widehat{v_t} \leftarrow   v_t^{max}/\big(1-\beta_2^t \big)            \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}\widehat{v_t} \leftarrow   v_t/\big(1-\beta_2^t \big)                  \\
            &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
            is not yet supported for all our implementations. Please use a float
            LR if you are not also specifying fused=True or capturable=True.
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        {_maximize_doc}
        {_foreach_doc}
        {_capturable_doc}
        {_differentiable_doc}
        {_fused_doc}
    .. Note::
        A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """
)
315
316
317def _single_tensor_adamw(
318    params: List[Tensor],
319    grads: List[Tensor],
320    exp_avgs: List[Tensor],
321    exp_avg_sqs: List[Tensor],
322    max_exp_avg_sqs: List[Tensor],
323    state_steps: List[Tensor],
324    grad_scale: Optional[Tensor],
325    found_inf: Optional[Tensor],
326    *,
327    amsgrad: bool,
328    beta1: float,
329    beta2: float,
330    lr: Union[Tensor, float],
331    weight_decay: float,
332    eps: float,
333    maximize: bool,
334    capturable: bool,
335    differentiable: bool,
336    has_complex: bool,
337):
338    assert grad_scale is None and found_inf is None
339
340    if torch.jit.is_scripting():
341        # this assert is due to JIT being dumb and not realizing that the ops below
342        # have overloads to handle both float and Tensor lrs, so we just assert it's
343        # a float since most people using JIT are using floats
344        assert isinstance(lr, float)
345
346    for i, param in enumerate(params):
347        grad = grads[i] if not maximize else -grads[i]
348        exp_avg = exp_avgs[i]
349        exp_avg_sq = exp_avg_sqs[i]
350        step_t = state_steps[i]
351
352        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
353        if not torch._utils.is_compiling() and capturable:
354            capturable_supported_devices = _get_capturable_supported_devices()
355            assert (
356                param.device.type == step_t.device.type
357                and param.device.type in capturable_supported_devices
358            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
359
360        if torch.is_complex(param):
361            grad = torch.view_as_real(grad)
362            exp_avg = torch.view_as_real(exp_avg)
363            exp_avg_sq = torch.view_as_real(exp_avg_sq)
364            if amsgrad:
365                max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
366            param = torch.view_as_real(param)
367
368        # update step
369        step_t += 1
370
371        # Perform stepweight decay
372        param.mul_(1 - lr * weight_decay)
373
374        # Decay the first and second moment running average coefficient
375        exp_avg.lerp_(grad, 1 - beta1)
376        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
377
378        if capturable or differentiable:
379            step = step_t
380
381            bias_correction1 = 1 - beta1**step
382            bias_correction2 = 1 - beta2**step
383
384            step_size = lr / bias_correction1
385            step_size_neg = step_size.neg()
386
387            bias_correction2_sqrt = bias_correction2.sqrt()
388
389            if amsgrad:
390                # Maintains the maximum of all 2nd moment running avg. till now
391                if differentiable:
392                    max_exp_avg_sq = max_exp_avg_sqs[i].clone()
393                else:
394                    max_exp_avg_sq = max_exp_avg_sqs[i]
395
396                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))
397
398                # Uses the max. for normalizing running avg. of gradient
399                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
400                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
401                denom = (
402                    max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
403                ).add_(eps / step_size_neg)
404            else:
405                denom = (
406                    exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
407                ).add_(eps / step_size_neg)
408
409            param.addcdiv_(exp_avg, denom)
410        else:
411            step = _get_value(step_t)
412
413            bias_correction1 = 1 - beta1**step
414            bias_correction2 = 1 - beta2**step
415
416            step_size = lr / bias_correction1
417
418            bias_correction2_sqrt = bias_correction2**0.5
419
420            if amsgrad:
421                # Maintains the maximum of all 2nd moment running avg. till now
422                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
423
424                # Use the max. for normalizing running avg. of gradient
425                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
426            else:
427                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
428
429            param.addcdiv_(exp_avg, denom, value=-step_size)
430
431        # Lastly, switch back to complex view
432        if amsgrad and torch.is_complex(params[i]):
433            max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])
434
435
def _multi_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[Tensor, float],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    """Apply one AdamW step using the multi-tensor ``torch._foreach_*`` ops.

    Tensors are grouped by (device, dtype) with
    ``Optimizer._group_tensors_by_device_and_dtype`` and each group is updated
    with a handful of foreach calls instead of a Python loop over parameters.
    Parameters, moment buffers and ``state_steps`` are updated in place.

    Raises:
        RuntimeError: if ``lr`` is a Tensor while ``capturable`` is False.
        AssertionError: if ``capturable`` is set and a param/step pair is on an
            unsupported or mismatched device (checked only outside
            ``torch.compile``), if ``differentiable`` is set, or if
            ``grad_scale``/``found_inf`` are given.
    """
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError(
            "lr as a Tensor is not supported for capturable=False and foreach=True"
        )

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    assert not differentiable, "_foreach ops don't support autograd"

    assert grad_scale is None and found_inf is None

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_exp_avgs_,
        device_exp_avg_sqs_,
        device_max_exp_avg_sqs_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        # NOTE(review): ``_view_as_real`` presumably swaps complex entries of these
        # lists for their real views in place -- confirm in optimizer.py.
        if has_complex:
            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                _view_as_real(
                    device_params,
                    device_grads,
                    device_exp_avgs,
                    device_exp_avg_sqs,
                    device_max_exp_avg_sqs,
                )
            else:
                _view_as_real(
                    device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
                )

        # Out-of-place so the caller's gradient tensors are left untouched.
        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        # Perform stepweight decay (decoupled: params are scaled by 1 - lr * wd)
        if weight_decay != 0:
            torch._foreach_mul_(device_params, 1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            device_exp_avg_sqs, device_grads, device_grads, 1 - beta2
        )

        # Delete the local intermediate since it won't be used anymore to save on peak memory
        del device_grads

        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]

        # capturable: every quantity stays a device tensor derived from the step
        # tensors (no host sync); otherwise steps are read back as Python scalars.
        if capturable:
            bias_correction1 = torch._foreach_pow(beta1, device_state_steps)
            bias_correction2 = torch._foreach_pow(beta2, device_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            # we do not negate bias_correction1 as it'll need to be negated later anyway
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            torch._foreach_div_(bias_correction1, lr)
            torch._foreach_reciprocal_(bias_correction1)

            torch._foreach_sqrt_(bias_correction2)

            # Re-assign for clarity as we maintain minimal intermediates: we'll have
            # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
            # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
            step_size = bias_correction1
            bias_correction2_sqrt = bias_correction2

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)

                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_div_(exp_avg_sq_sqrt, step_size)

            # at this point, exp_avg_sq_sqrt = - (1 - beta1^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
            # so addcdiv_ adds exp_avg * step_size / [sqrt(...) + eps] to the params.
            torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
        else:
            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in device_state_steps
            ]
            bias_correction2 = [
                1 - beta2 ** _get_value(step) for step in device_state_steps
            ]

            # Per-tensor step_size = -lr / (1 - beta1^t). NOTE(review):
            # _stack_if_compiling presumably stacks these into a tensor only under
            # torch.compile -- confirm in optimizer.py.
            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [
                bc**0.5 for bc in bias_correction2  # type: ignore[arg-type]
            ]

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)

                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_addcdiv_(
                device_params,
                device_exp_avgs,
                exp_avg_sq_sqrt,
                step_size,  # type: ignore[arg-type]
            )
616
617
def _fused_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[Tensor, float],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
    has_complex: bool,  # Needed for consistency.
) -> None:
    """Apply one AdamW step with the fused ``torch._fused_adamw_`` kernel.

    Tensors are grouped by (device, dtype) and each group is handed to one fused
    kernel call. ``grad_scale``, ``found_inf`` and a tensor ``lr`` that is not on
    CPU are copied to each group's device on first use and cached, so every
    device gets at most one copy. Step counters are incremented before the
    kernel runs; when ``found_inf`` is given it is subtracted from them
    afterwards.

    Raises:
        RuntimeError: if ``differentiable`` is True.
        AssertionError: if ``grad_scale``/``found_inf`` are given for MPS tensors.
    """
    if not params:
        return
    if differentiable:
        raise RuntimeError("Adam with fused=True does not support differentiable=True")

    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )

    # We only shuffle around the lr when it is a Tensor and not on CPU, otherwise, we prefer
    # treating it as a scalar.
    lr_dict: Optional[DeviceDict] = (
        {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
    )

    def _on_device(cache, tensor, device):
        # Return ``tensor``'s copy on ``device``; the copy is made (non-blocking)
        # only on a cache miss instead of eagerly on every group.
        if device not in cache:
            cache[device] = tensor.to(device=device, non_blocking=True)
        return cache[device]

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            device_params_,
            device_grads_,
            device_exp_avgs_,
            device_exp_avg_sqs_,
            device_max_exp_avg_sqs,
            device_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        if device.type == "mps":  # type: ignore[union-attr]
            assert found_inf is None and grad_scale is None

        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            device_grad_scale = _on_device(grad_scale_dict, grad_scale, device)
        if found_inf is not None:
            device_found_inf = _on_device(found_inf_dict, found_inf, device)
        # Resolve the lr for every group instead of re-binding ``lr``: a group
        # whose device is already cached (the lr's own device, or another dtype on
        # a device seen earlier) must not reuse a copy made for a different device.
        device_lr = lr if lr_dict is None else _on_device(lr_dict, lr, device)
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adamw_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,  # type: ignore[arg-type]
            device_state_steps,
            amsgrad=amsgrad,
            lr=device_lr,  # type: ignore[arg-type]
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        # NOTE(review): found_inf is presumably 1.0 when the AMP scaler saw
        # inf/NaN grads, so this rolls back the increment above -- confirm.
        if device_found_inf is not None:
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )
715
716
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamw)
def adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.

    Dispatches to ``_fused_adamw`` when ``fused`` is True, otherwise to
    ``_multi_tensor_adamw`` when ``foreach`` is True, otherwise to
    ``_single_tensor_adamw``. When neither ``fused`` nor ``foreach`` is given,
    ``foreach`` is picked by ``_default_to_fused_or_foreach`` (fused is never
    auto-selected), except that a tensor ``lr`` with ``capturable=False`` keeps
    the single-tensor path.

    Raises:
        RuntimeError: if ``state_steps`` holds non-Tensor entries (checked only
            outside ``torch.compile``), or if foreach/fused is requested under
            TorchScript.
    """
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # The extra ``not torch.jit.is_scripting()`` guards presumably keep TorchScript
    # from compiling the fused/foreach paths at all -- confirm before simplifying.
    if fused and not torch.jit.is_scripting():
        func = _fused_adamw
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamw
    else:
        func = _single_tensor_adamw

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
        has_complex=has_complex,
    )
802