# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the NAdam algorithm."""
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _stack_if_compiling,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["NAdam", "nadam"]


class NAdam(Optimizer):  # noqa: D101
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 2e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        momentum_decay: float = 4e-3,
        decoupled_weight_decay: bool = False,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        capturable: bool = False,
        differentiable: bool = False,
    ):  # noqa: D107
        # Validate hyperparameters at construction time. ``lr`` may be a Python
        # float or a 1-element Tensor; every other range check is on plain numbers.
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= momentum_decay:
            raise ValueError(f"Invalid momentum_decay value: {momentum_decay}")
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            momentum_decay=momentum_decay,
            decoupled_weight_decay=decoupled_weight_decay,
            maximize=maximize,
            foreach=foreach,
            capturable=capturable,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):  # noqa: D105
        # Fill in defaults for group options absent from the loaded state (e.g. a
        # checkpoint that predates them), and convert any non-tensor ``step`` /
        # ``mu_product`` entries into scalar tensors. Those tensors are placed on
        # the parameter's device only when ``capturable`` is set; otherwise they
        # stay on CPU, matching the layout created by ``_init_group``.
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            group.setdefault("decoupled_weight_decay", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0:
                    if not torch.is_tensor(p_state["step"]):
                        step_val = float(p_state["step"])
                        p_state["step"] = (
                            torch.tensor(
                                step_val, dtype=_get_scalar_dtype(), device=p.device
                            )
                            if group["capturable"]
                            else torch.tensor(step_val, dtype=_get_scalar_dtype())
                        )
                    if not torch.is_tensor(p_state["mu_product"]):
                        mu_prod_val = p_state["mu_product"]
                        p_state["mu_product"] = (
                            torch.tensor(
                                mu_prod_val, dtype=_get_scalar_dtype(), device=p.device
                            )
                            if group["capturable"]
                            else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype())
                        )

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        exp_avgs,
        exp_avg_sqs,
        mu_products,
        state_steps,
    ):
        """Gather the tensors of ``group`` whose parameters have a gradient.

        The output lists (``params_with_grad``, ``grads``, ``exp_avgs``,
        ``exp_avg_sqs``, ``mu_products``, ``state_steps``) are appended to in
        place, index-aligned per parameter. Optimizer state (``step``,
        ``mu_product``, ``exp_avg``, ``exp_avg_sq``) is created lazily the first
        time a parameter is seen.

        Returns:
            bool: ``True`` if any parameter that has a gradient is complex.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("NAdam does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    # note(crcrpar): [special device hosting for step]
                    # Deliberately host `step` and `mu_product` on CPU if capturable is False.
                    # This is because kernel launches are costly on CUDA and XLA.
                    state["step"] = (
                        torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                        if group["capturable"]
                        else torch.tensor(0.0, dtype=_get_scalar_dtype())
                    )
                    # Running product of the momentum schedule; starts at 1 so the
                    # first multiplication yields mu_1.
                    state["mu_product"] = (
                        torch.ones((), dtype=_get_scalar_dtype(), device=p.device)
                        if group["capturable"]
                        else torch.tensor(1.0, dtype=_get_scalar_dtype())
                    )
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                mu_products.append(state["mu_product"])
                state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure`` if one was given, otherwise ``None``.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            mu_products: List[Tensor] = []
            state_steps: List[Tensor] = []
            beta1, beta2 = cast(Tuple[float, float], group["betas"])

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                mu_products,
                state_steps,
            )

            nadam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                mu_products,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                momentum_decay=group["momentum_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                decoupled_weight_decay=group["decoupled_weight_decay"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                has_complex=has_complex,
            )

        return loss


NAdam.__doc__ = (
    r"""Implements NAdam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)},
                \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)}                   \\
            &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)}    \\
            &\hspace{13mm} \: \textit{decoupled\_weight\_decay}, \:\textit{maximize}             \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ ( first moment)},
                v_0 \leftarrow 0 \text{ ( second moment)}                                 \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1}                                       \\
            &\hspace{5mm} \textbf{if} \: \lambda \neq 0                                          \\
            &\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay}                       \\
            &\hspace{15mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1}        \\
            &\hspace{10mm}\textbf{else}                                                          \\
            &\hspace{15mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                             \\
            &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2}  0.96^{t \psi} \big)     \\
            &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\
            &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{5mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex]
            & \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i})                         \\
            &\hspace{5mm}\widehat{v_t} \leftarrow   v_t/\big(1-\beta_2^t \big)                   \\
            &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient; acts as an L2
            penalty unless ``decoupled_weight_decay`` is True (default: 0)
        momentum_decay (float, optional): momentum decay :math:`\psi` used in the
            momentum schedule :math:`\mu_t` (default: 4e-3)
        decoupled_weight_decay (bool, optional): whether to use decoupled weight
            decay as in AdamW (see `Decoupled Weight Decay Regularization`_) to
            obtain NAdamW (default: False)
        {_foreach_doc}
        {_maximize_doc}
        {_capturable_doc}
        {_differentiable_doc}

    .. _Incorporating Nesterov Momentum into Adam:
        https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101

    """
)


def _single_tensor_nadam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    mu_products: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    momentum_decay: float,
    eps: float,
    decoupled_weight_decay: bool,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    """Apply one NAdam update to each parameter, one tensor at a time.

    The input lists are index-aligned per parameter. ``params``, ``exp_avgs``,
    ``exp_avg_sqs``, ``mu_products`` and ``state_steps`` are updated in place.
    Complex tensors are updated through ``torch.view_as_real`` views.

    When ``differentiable`` or ``capturable`` is set, the step-size multipliers
    are applied with tensor arithmetic so autograd can track them. Otherwise
    they are computed as host scalars and passed as ``addcdiv_``'s ``value``.
    """
    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        mu_product = mu_products[i]
        step_t = state_steps[i]

        if torch.is_complex(param):
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == mu_product.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), (
                f"If capturable=True, params, mu_products and state_steps must be "
                f"on supported devices: {capturable_supported_devices}."
            )

        # update step
        step_t += 1

        # Under capturable the step stays a device tensor (no host read);
        # otherwise it is read back as a Python number.
        if capturable:
            step = step_t
        else:
            step = _get_value(step_t)

        bias_correction2 = 1 - beta2**step

        if weight_decay != 0:
            if decoupled_weight_decay:
                # Perform stepweight decay
                param.mul_(1 - lr * weight_decay)
            else:
                # Out-of-place so the caller's gradient tensor is not modified.
                grad = grad.add(param, alpha=weight_decay)

        # calculate the momentum cache \mu^{t} and \mu^{t+1}
        mu = beta1 * (1.0 - 0.5 * (0.96 ** (step * momentum_decay)))
        mu_next = beta1 * (1.0 - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))

        # update mu_product
        mu_product *= mu

        # decay the first and second moment running average coefficient
        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        denom = exp_avg_sq.div(bias_correction2).sqrt()

        if differentiable or capturable:
            denom = denom.add(eps)
            # Make autograd track the operations
            # by updating the grad and exp_avg directly and not using the
            # scalar "value" argument of addcdiv.
            mu_product_next = mu_product * mu_next
            grad = grad * (-lr * (1.0 - mu) / (1.0 - mu_product))
            exp_avg = exp_avg * (-lr * mu_next / (1.0 - mu_product_next))
            param.addcdiv_(grad, denom)
            param.addcdiv_(exp_avg, denom)
        else:
            mu_product_next = _get_value(mu_product) * mu_next
            denom.add_(eps)
            # theta -= lr * (1 - mu_t) / (1 - prod_{i<=t} mu_i) * g_t / denom
            param.addcdiv_(
                grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product)))
            )
            # theta -= lr * mu_{t+1} / (1 - prod_{i<=t+1} mu_i) * m_t / denom
            param.addcdiv_(
                exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next)
            )


def _multi_tensor_nadam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    mu_products: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    momentum_decay: float,
    eps: float,
    decoupled_weight_decay: bool,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    """Apply one NAdam update using batched ``torch._foreach_*`` kernels.

    Tensors are bucketed with ``Optimizer._group_tensors_by_device_and_dtype``,
    and each bucket is updated with foreach ops. The state lists are mutated
    in place, as in ``_single_tensor_nadam``. ``differentiable=True`` is not
    supported (asserted). With ``capturable=True``, all params, mu_products
    and state_steps must live on a supported device (asserted when not
    compiling).
    """
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == mp.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, mp, step in zip(params, mu_products, state_steps)
        ), f"If capturable=True, params, mu_products, and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_avg_sqs_,
        grouped_mu_products_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
        grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
        grouped_mu_products = cast(List[Tensor], grouped_mu_products_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        # handle complex
        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs
            )

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            if decoupled_weight_decay:
                # Perform stepweight decay
                torch._foreach_mul_(grouped_params, 1 - lr * weight_decay)
            else:
                # Re-use the intermediate memory (grouped_grads) already allocated for maximize
                if maximize:
                    torch._foreach_add_(
                        grouped_grads, grouped_params, alpha=weight_decay
                    )
                else:
                    grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                        grouped_grads, grouped_params, alpha=weight_decay
                    )

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

        torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2
        )

        exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)

        bias_correction_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]
        mus: Union[Tuple[Tensor, ...], List[Tensor]]
        mu_nexts: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # Everything stays on device as tensors: no host reads of the step.
            # mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay))
            exponent = torch._foreach_mul(grouped_state_steps, momentum_decay)
            mus = torch._foreach_pow(0.96, exponent)
            torch._foreach_mul_(mus, -0.5)
            torch._foreach_add_(mus, 1.0)
            torch._foreach_mul_(mus, beta1)

            # mu_nexts will be beta1 * (1 - 0.5 * 0.96 ** ((step + 1) * momentum_decay))
            torch._foreach_add_(exponent, momentum_decay)
            mu_nexts = torch._foreach_pow(0.96, exponent)
            torch._foreach_mul_(mu_nexts, -0.5)
            torch._foreach_add_(mu_nexts, 1.0)
            torch._foreach_mul_(mu_nexts, beta1)

            # save peak memory as we don't need exponent anymore
            del exponent

            # bias_correction_sqrt will be sqrt(1 - beta2 ** step)
            bias_correction_sqrt = torch._foreach_pow(beta2, grouped_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction_sqrt, 1.0)
            torch._foreach_neg_(bias_correction_sqrt)
            torch._foreach_sqrt_(bias_correction_sqrt)
        else:
            # Same quantities as above, computed on the host as Python floats.
            bias_correction_sqrt = [
                (1 - beta2 ** _get_value(step)) ** 0.5 for step in grouped_state_steps
            ]
            mus = [
                beta1 * (1.0 - 0.5 * (0.96 ** (_get_value(step) * momentum_decay)))
                for step in grouped_state_steps
            ]
            mu_nexts = [
                beta1
                * (1.0 - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay)))
                for step in grouped_state_steps
            ]

        # update mu_products
        torch._foreach_mul_(grouped_mu_products, mus)

        # denominator: sqrt(v_t / (1 - beta2 ** t)) + eps
        torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
        torch._foreach_add_(exp_avg_sq_sqrt, eps)

        # explicitly delete bias_correction refs to save memory
        del bias_correction_sqrt

        if capturable:
            # Build up the step_size multiplier for grad, reusing mus' memory
            torch._foreach_sub_(mus, 1.0)
            torch._foreach_mul_(mus, lr)
            # foreach_sub doesn't allow a scalar as the first arg
            denom = torch._foreach_sub(grouped_mu_products, 1.0)
            torch._foreach_neg_(denom)
            torch._foreach_div_(mus, denom)
            # - lr * (1 - mu) / (1 - mu_product)
            step_size_grads = mus
            # explicitly delete denom to save memory
            del denom

            # Build up the step_size multiplier for exp_avg, reusing mu_nexts' memory
            denom = torch._foreach_mul(grouped_mu_products, mu_nexts)
            torch._foreach_mul_(mu_nexts, lr)
            # foreach_sub doesn't allow a scalar as the first arg, but it's okay because
            # we need a negative here anyway
            torch._foreach_sub_(denom, 1.0)
            torch._foreach_div_(mu_nexts, denom)
            # - lr * mu_next / (1 - mu_product * mu_next)
            step_size_expavg = mu_nexts
            # explicitly delete denom to save memory
            del denom

            # we cannot inplace into step_size_grads cuz it is a list of ScalarTensors
            # and mul'ing with grouped_grads will result in a list of bigger Tensors
            numerator = torch._foreach_mul(step_size_grads, grouped_grads)
            torch._foreach_addcmul_(numerator, step_size_expavg, grouped_exp_avgs)

            # finally, update params
            torch._foreach_addcdiv_(grouped_params, numerator, exp_avg_sq_sqrt)
        else:
            # grouped_mu_products already includes this step's mu (updated above).
            step_size_grads = _stack_if_compiling(
                [
                    (_get_value(lr) * (1.0 - mu) / (1.0 - _get_value(mu_product))) * -1
                    for mu_product, mu in zip(grouped_mu_products, mus)
                ]
            )
            step_size_expavg = _stack_if_compiling(
                [
                    (
                        _get_value(lr)
                        * mu_next
                        / (1.0 - _get_value(mu_product) * mu_next)
                    )
                    * -1
                    for mu_product, mu_next in zip(grouped_mu_products, mu_nexts)
                ]
            )

            torch._foreach_addcdiv_(
                grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads  # type: ignore[arg-type]
            )
            torch._foreach_addcdiv_(
                grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg  # type: ignore[arg-type]
            )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam)
def nadam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    mu_products: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    decoupled_weight_decay: bool = False,
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    has_complex: bool = False,
    maximize: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    momentum_decay: float,
    eps: float,
):
    r"""Functional API that performs NAdam algorithm computation.

    See :class:`~torch.optim.NAdam` for details.

    The tensor lists are index-aligned per parameter and are updated in place.
    ``state_steps`` and ``mu_products`` must contain tensors, not Python numbers.
    When ``foreach`` is ``None``, the implementation is chosen by
    ``_default_to_fused_or_foreach``.

    Raises:
        RuntimeError: if ``state_steps`` or ``mu_products`` contain non-tensors,
            or if ``foreach`` is requested under ``torch.jit.script``.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if not all(isinstance(t, torch.Tensor) for t in mu_products):
        raise RuntimeError(
            "API has changed, `mu_products` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # NOTE(review): the `not torch.jit.is_scripting()` below is redundant at
    # runtime given the raise above; presumably it is kept so TorchScript can
    # statically prune the foreach branch. Confirm before simplifying.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_nadam
    else:
        func = _single_tensor_nadam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        mu_products,
        state_steps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        momentum_decay=momentum_decay,
        maximize=maximize,
        decoupled_weight_decay=decoupled_weight_decay,
        eps=eps,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )