# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the NAdam algorithm."""
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _stack_if_compiling,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["NAdam", "nadam"]


class NAdam(Optimizer):  # noqa: D101
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 2e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        momentum_decay: float = 4e-3,
        decoupled_weight_decay: bool = False,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        capturable: bool = False,
        differentiable: bool = False,
    ):  # noqa: D107
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= momentum_decay:
            raise ValueError(f"Invalid momentum_decay value: {momentum_decay}")
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            momentum_decay=momentum_decay,
            decoupled_weight_decay=decoupled_weight_decay,
            maximize=maximize,
            foreach=foreach,
            capturable=capturable,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            group.setdefault("decoupled_weight_decay", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0:
                    if not torch.is_tensor(p_state["step"]):
                        step_val = float(p_state["step"])
                        p_state["step"] = (
                            torch.tensor(
                                step_val, dtype=_get_scalar_dtype(), device=p.device
                            )
                            if group["capturable"]
                            else torch.tensor(step_val, dtype=_get_scalar_dtype())
                        )
                    if not torch.is_tensor(p_state["mu_product"]):
                        mu_prod_val = p_state["mu_product"]
                        p_state["mu_product"] = (
                            torch.tensor(
                                mu_prod_val, dtype=_get_scalar_dtype(), device=p.device
                            )
                            if group["capturable"]
                            else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype())
                        )

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        exp_avgs,
        exp_avg_sqs,
        mu_products,
        state_steps,
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("NAdam does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    # note(crcrpar): [special device hosting for step]
                    # Deliberately host `step` and `mu_product` on CPU if capturable is False.
                    # This is because kernel launches are costly on CUDA and XLA.
                    state["step"] = (
                        torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                        if group["capturable"]
                        else torch.tensor(0.0, dtype=_get_scalar_dtype())
                    )
                    state["mu_product"] = (
                        torch.ones((), dtype=_get_scalar_dtype(), device=p.device)
                        if group["capturable"]
                        else torch.tensor(1.0, dtype=_get_scalar_dtype())
                    )
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                mu_products.append(state["mu_product"])
                state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
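        # A minimal sketch of the closure pattern described in the docstring
        # above (``model``, ``criterion``, ``inputs``, and ``target`` are
        # placeholders, not part of this module):
        #
        #     def closure():
        #         optimizer.zero_grad()
        #         loss = criterion(model(inputs), target)
        #         loss.backward()
        #         return loss
        #
        #     optimizer.step(closure)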
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            mu_products: List[Tensor] = []
            state_steps: List[Tensor] = []
            beta1, beta2 = cast(Tuple[float, float], group["betas"])

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                mu_products,
                state_steps,
            )

            nadam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                mu_products,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                momentum_decay=group["momentum_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                decoupled_weight_decay=group["decoupled_weight_decay"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                has_complex=has_complex,
            )

        return loss


NAdam.__doc__ = (
    r"""Implements NAdam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)},
                \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)}                   \\
            &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)}    \\
            &\hspace{13mm} \: \textit{decoupled\_weight\_decay}, \:\textit{maximize}             \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ ( first moment)},
                v_0 \leftarrow 0 \text{ ( second moment)}                                 \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1}                                       \\
            &\hspace{5mm} \textbf{if} \: \lambda \neq 0                                          \\
            &\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay}                       \\
            &\hspace{15mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1}                    \\
            &\hspace{10mm}\textbf{else}                                                          \\
            &\hspace{15mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                             \\
            &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2}  0.96^{t \psi} \big)     \\
            &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\
            &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{5mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex]
            & \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i})                         \\
            &\hspace{5mm}\widehat{v_t} \leftarrow   v_t/\big(1-\beta_2^t \big)                   \\
            &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        momentum_decay (float, optional): momentum decay (default: 4e-3)
        decoupled_weight_decay (bool, optional): whether to use decoupled weight
            decay as in AdamW to obtain NAdamW (default: False)
        {_foreach_doc}
        {_maximize_doc}
        {_capturable_doc}
        {_differentiable_doc}

    .. _Incorporating Nesterov Momentum into Adam:
        https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101

    """
)
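
# A minimal usage sketch (illustrative only; the linear model, input shape, and
# hyperparameter values are placeholders, not part of this module):
#
#     import torch
#
#     model = torch.nn.Linear(10, 1)
#     optimizer = torch.optim.NAdam(model.parameters(), lr=2e-3)
#     loss = model(torch.randn(8, 10)).sum()
#     loss.backward()
#     optimizer.step()
#     optimizer.zero_grad()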


def _single_tensor_nadam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    mu_products: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    momentum_decay: float,
    eps: float,
    decoupled_weight_decay: bool,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        mu_product = mu_products[i]
        step_t = state_steps[i]

        if torch.is_complex(param):
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == mu_product.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), (
                f"If capturable=True, params, mu_products and state_steps must be "
                f"on supported devices: {capturable_supported_devices}."
            )

        # update step
        step_t += 1

        if capturable:
            step = step_t
        else:
            step = _get_value(step_t)

        bias_correction2 = 1 - beta2**step

        if weight_decay != 0:
            if decoupled_weight_decay:
                # Perform stepweight decay
                param.mul_(1 - lr * weight_decay)
            else:
                grad = grad.add(param, alpha=weight_decay)

        # calculate the momentum cache \mu^{t} and \mu^{t+1}
        mu = beta1 * (1.0 - 0.5 * (0.96 ** (step * momentum_decay)))
        mu_next = beta1 * (1.0 - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))

        # update mu_product
        mu_product *= mu

        # decay the first and second moment running average coefficient
        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        denom = exp_avg_sq.div(bias_correction2).sqrt()

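        # Both branches below apply the Nesterov-style update from the class
        # docstring in two addcdiv pieces:
        #   param -= lr * ( (1 - mu) * grad / (1 - mu_product)
        #                   + mu_next * exp_avg / (1 - mu_product * mu_next) ) / denom
        # where denom = sqrt(exp_avg_sq / (1 - beta2**step)) + eps.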
        if differentiable or capturable:
            denom = denom.add(eps)
            # Make autograd track the operations
            # by updating the grad and exp_avg directly and not using the
            # scalar "value" argument of addcdiv.
            mu_product_next = mu_product * mu_next
            grad = grad * (-lr * (1.0 - mu) / (1.0 - mu_product))
            exp_avg = exp_avg * (-lr * mu_next / (1.0 - mu_product_next))
            param.addcdiv_(grad, denom)
            param.addcdiv_(exp_avg, denom)
        else:
            mu_product_next = _get_value(mu_product) * mu_next
            denom.add_(eps)
            param.addcdiv_(
                grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product)))
            )
            param.addcdiv_(
                exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next)
            )


def _multi_tensor_nadam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    mu_products: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    momentum_decay: float,
    eps: float,
    decoupled_weight_decay: bool,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == mp.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, mp, step in zip(params, mu_products, state_steps)
        ), f"If capturable=True, params, mu_products, and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_avg_sqs_,
        grouped_mu_products_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
        grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
        grouped_mu_products = cast(List[Tensor], grouped_mu_products_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        # handle complex
        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs
            )

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            if decoupled_weight_decay:
                # Perform stepweight decay
                torch._foreach_mul_(grouped_params, 1 - lr * weight_decay)
            else:
                # Re-use the intermediate memory (grouped_grads) already allocated for maximize
                if maximize:
                    torch._foreach_add_(
                        grouped_grads, grouped_params, alpha=weight_decay
                    )
                else:
                    grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                        grouped_grads, grouped_params, alpha=weight_decay
                    )

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

        torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2
        )

        exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)

        bias_correction_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]
        mus: Union[Tuple[Tensor, ...], List[Tensor]]
        mu_nexts: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay))
            exponent = torch._foreach_mul(grouped_state_steps, momentum_decay)
            mus = torch._foreach_pow(0.96, exponent)
            torch._foreach_mul_(mus, -0.5)
            torch._foreach_add_(mus, 1.0)
            torch._foreach_mul_(mus, beta1)

            # mu_nexts will be beta1 * (1 - 0.5 * 0.96 ** ((step + 1) * momentum_decay))
            torch._foreach_add_(exponent, momentum_decay)
            mu_nexts = torch._foreach_pow(0.96, exponent)
            torch._foreach_mul_(mu_nexts, -0.5)
            torch._foreach_add_(mu_nexts, 1.0)
            torch._foreach_mul_(mu_nexts, beta1)

            # save peak memory as we don't need exponent anymore
            del exponent

            bias_correction_sqrt = torch._foreach_pow(beta2, grouped_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction_sqrt, 1.0)
            torch._foreach_neg_(bias_correction_sqrt)
            torch._foreach_sqrt_(bias_correction_sqrt)
        else:
            bias_correction_sqrt = [
                (1 - beta2 ** _get_value(step)) ** 0.5 for step in grouped_state_steps
            ]
            mus = [
                beta1 * (1.0 - 0.5 * (0.96 ** (_get_value(step) * momentum_decay)))
                for step in grouped_state_steps
            ]
            mu_nexts = [
                beta1
                * (1.0 - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay)))
                for step in grouped_state_steps
            ]

        # update mu_products
        torch._foreach_mul_(grouped_mu_products, mus)

        torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
        torch._foreach_add_(exp_avg_sq_sqrt, eps)

        # explicitly delete bias_correction refs to save memory
        del bias_correction_sqrt

        if capturable:
            # Build up the step_size multiplier for grad, reusing mus' memory
            torch._foreach_sub_(mus, 1.0)
            torch._foreach_mul_(mus, lr)
            # foreach_sub doesn't allow a scalar as the first arg
            denom = torch._foreach_sub(grouped_mu_products, 1.0)
            torch._foreach_neg_(denom)
            torch._foreach_div_(mus, denom)
            # - lr * (1 - mu) / (1 - mu_product)
            step_size_grads = mus
            # explicitly delete denom to save memory
            del denom

            # Build up the step_size multiplier for exp_avg, reusing mu_nexts' memory
            denom = torch._foreach_mul(grouped_mu_products, mu_nexts)
            torch._foreach_mul_(mu_nexts, lr)
            # foreach_sub doesn't allow a scalar as the first arg, but it's okay because
            # we need a negative here anyway
            torch._foreach_sub_(denom, 1.0)
            torch._foreach_div_(mu_nexts, denom)
            # - lr * mu_next / (1 - mu_product * mu_next)
            step_size_expavg = mu_nexts
            # explicitly delete denom to save memory
            del denom
            # we cannot inplace into step_size_grads because it is a list of ScalarTensors
            # and multiplying with grouped_grads would result in a list of bigger Tensors
            numerator = torch._foreach_mul(step_size_grads, grouped_grads)
            torch._foreach_addcmul_(numerator, step_size_expavg, grouped_exp_avgs)

            # finally, update params
            torch._foreach_addcdiv_(grouped_params, numerator, exp_avg_sq_sqrt)
        else:
            step_size_grads = _stack_if_compiling(
                [
                    (_get_value(lr) * (1.0 - mu) / (1.0 - _get_value(mu_product))) * -1
                    for mu_product, mu in zip(grouped_mu_products, mus)
                ]
            )
            step_size_expavg = _stack_if_compiling(
                [
                    (
                        _get_value(lr)
                        * mu_next
                        / (1.0 - _get_value(mu_product) * mu_next)
                    )
                    * -1
                    for mu_product, mu_next in zip(grouped_mu_products, mu_nexts)
                ]
            )

            torch._foreach_addcdiv_(
                grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads  # type: ignore[arg-type]
            )
            torch._foreach_addcdiv_(
                grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg  # type: ignore[arg-type]
            )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam)
def nadam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    mu_products: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    decoupled_weight_decay: bool = False,
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    has_complex: bool = False,
    maximize: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    momentum_decay: float,
    eps: float,
):
    r"""Functional API that performs NAdam algorithm computation.

    See :class:`~torch.optim.NAdam` for details.
    """
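    # The checks below enforce the calling convention: `state_steps` and
    # `mu_products` must be lists of singleton tensors (one per parameter),
    # matching the per-parameter lists `params`, `grads`, `exp_avgs`, and
    # `exp_avg_sqs` assembled by NAdam._init_group above.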
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if not all(isinstance(t, torch.Tensor) for t in mu_products):
        raise RuntimeError(
            "API has changed, `mu_products` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_nadam
    else:
        func = _single_tensor_nadam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        mu_products,
        state_steps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        momentum_decay=momentum_decay,
        maximize=maximize,
        decoupled_weight_decay=decoupled_weight_decay,
        eps=eps,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )