# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _device_dtype_check_for_fused,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _fused_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _stack_if_compiling,
    _use_grad_for_differentiable,
    _view_as_real,
    DeviceDict,
    Optimizer,
    ParamsT,
)


__all__ = ["AdamW", "adamw"]


class AdamW(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
        *,
        maximize: bool = False,
        foreach: Optional[bool] = None,
        capturable: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        if isinstance(lr, Tensor):
            if foreach and not capturable:
                raise ValueError(
                    "lr as a Tensor is not supported for capturable=False and foreach=True"
                )
            if lr.numel() != 1:
                raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            foreach=foreach,
            maximize=maximize,
            capturable=capturable,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling = True
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            fused = group.setdefault("fused", None)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val,
                            dtype=_get_scalar_dtype(is_fused=fused),
                            device=p.device,
                        )
                        if group["capturable"] or group["fused"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        amsgrad,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("AdamW does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                if group["fused"]:
                    _device_dtype_check_for_fused(p)
                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
                # This is because kernel launches are costly on CUDA and XLA.
                state["step"] = (
                    torch.zeros(
                        (),
                        dtype=_get_scalar_dtype(is_fused=group["fused"]),
                        device=p.device,
                    )
                    if group["capturable"] or group["fused"]
                    else torch.tensor(0.0, dtype=_get_scalar_dtype())
                )
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state["max_exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])

            if group["amsgrad"]:
                max_exp_avg_sqs.append(state["max_exp_avg_sq"])
            if group["differentiable"] and state["step"].requires_grad:
                raise RuntimeError(
                    "`requires_grad` is not supported for `step` in differentiable mode"
                )

            # Foreach without capturable does not support a tensor lr
            if (
                group["foreach"]
                and isinstance(group["lr"], Tensor)
                and not group["capturable"]
            ):
                raise RuntimeError(
                    "lr as a Tensor is not supported for capturable=False and foreach=True"
                )

            state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            max_exp_avg_sqs: List[Tensor] = []
            state_steps: List[Tensor] = []
            amsgrad: bool = group["amsgrad"]
            beta1, beta2 = cast(Tuple[float, float], group["betas"])

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                amsgrad,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
                has_complex=has_complex,
            )

        return loss


AdamW.__doc__ = (
    r"""Implements AdamW algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
                \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
                \: \epsilon \text{ (epsilon)} \\
            &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
                \: \textit{maximize} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
                \text{ (second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\

            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
            &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
            &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
            &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
            &\hspace{5mm}\textbf{if} \: amsgrad \\
            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
                \widehat{v_t}) \\
            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
            is not yet supported for all our implementations. Please use a float
            LR if you are not also specifying fused=True or capturable=True.
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        {_maximize_doc}
        {_foreach_doc}
        {_capturable_doc}
        {_differentiable_doc}
        {_fused_doc}
    .. Note::
        A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """
)
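
# The sketch below is illustrative only and not part of the optimizer API: it shows the
# usage pattern the docstring above describes (zero_grad -> backward -> step) on an
# assumed toy ``torch.nn.Linear`` model with random data.
def _example_adamw_usage() -> None:
    model = torch.nn.Linear(4, 2)
    optimizer = AdamW(
        model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2
    )
    inputs, targets = torch.randn(8, 4), torch.randn(8, 2)
    for _ in range(3):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
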
def _single_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[Tensor, float],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    assert grad_scale is None and found_inf is None

    if torch.jit.is_scripting():
        # this assert is needed because JIT does not realize that the ops below have
        # overloads to handle both float and Tensor lrs, so we just assert it's a float
        # since most people using JIT are using floats
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)
            if amsgrad:
                max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
            param = torch.view_as_real(param)

        # update step
        step_t += 1

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if capturable or differentiable:
            step = step_t

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step

            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()

            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                if differentiable:
                    max_exp_avg_sq = max_exp_avg_sqs[i].clone()
                else:
                    max_exp_avg_sq = max_exp_avg_sqs[i]

                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))

                # Uses the max. for normalizing running avg. of gradient
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                denom = (
                    max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)
            else:
                denom = (
                    exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)

            param.addcdiv_(exp_avg, denom)
        else:
            step = _get_value(step_t)

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = bias_correction2**0.5

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)

        # Lastly, switch back to complex view
        if amsgrad and torch.is_complex(params[i]):
            max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])
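
# Plain-tensor restatement of one AdamW step, mirroring the update performed by
# ``_single_tensor_adamw`` above (default hyperparameters, no amsgrad/maximize).
# This is a reader-facing sketch only; the names and toy values are assumptions.
def _example_single_step_math() -> None:
    lr, beta1, beta2, eps, weight_decay = 1e-3, 0.9, 0.999, 1e-8, 1e-2
    param = torch.tensor([1.0, -2.0])
    grad = torch.tensor([0.1, 0.3])
    exp_avg = torch.zeros_like(param)
    exp_avg_sq = torch.zeros_like(param)
    step = 1

    param = param * (1 - lr * weight_decay)                   # decoupled weight decay
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad            # first moment
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad**2   # second moment
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    denom = (exp_avg_sq / bias_correction2).sqrt() + eps
    param = param - lr * (exp_avg / bias_correction1) / denom
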
def _multi_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[Tensor, float],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError(
            "lr as a Tensor is not supported for capturable=False and foreach=True"
        )

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    assert not differentiable, "_foreach ops don't support autograd"

    assert grad_scale is None and found_inf is None

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_exp_avgs_,
        device_exp_avg_sqs_,
        device_max_exp_avg_sqs_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        if has_complex:
            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                _view_as_real(
                    device_params,
                    device_grads,
                    device_exp_avgs,
                    device_exp_avg_sqs,
                    device_max_exp_avg_sqs,
                )
            else:
                _view_as_real(
                    device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
                )

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        # Perform stepweight decay
        if weight_decay != 0:
            torch._foreach_mul_(device_params, 1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            device_exp_avg_sqs, device_grads, device_grads, 1 - beta2
        )

        # Delete the local intermediate since it won't be used anymore to save on peak memory
        del device_grads

        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]

        if capturable:
            bias_correction1 = torch._foreach_pow(beta1, device_state_steps)
            bias_correction2 = torch._foreach_pow(beta2, device_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            # we do not negate bias_correction1 as it'll need to be negated later anyway
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            torch._foreach_div_(bias_correction1, lr)
            torch._foreach_reciprocal_(bias_correction1)

            torch._foreach_sqrt_(bias_correction2)

            # Re-assign for clarity as we maintain minimal intermediates: we'll have
            # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
            # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
            step_size = bias_correction1
            bias_correction2_sqrt = bias_correction2

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)

                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_div_(exp_avg_sq_sqrt, step_size)

            # at this point, exp_avg_sq_sqrt = - (1 - beta1^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
            torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
        else:
            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in device_state_steps
            ]
            bias_correction2 = [
                1 - beta2 ** _get_value(step) for step in device_state_steps
            ]

            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [
                bc**0.5 for bc in bias_correction2  # type: ignore[arg-type]
            ]

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)

                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_addcdiv_(
                device_params,
                device_exp_avgs,
                exp_avg_sq_sqrt,
                step_size,  # type: ignore[arg-type]
            )
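
# Consistency sketch (not part of the module's tests): the foreach path above is a
# performance variant of ``_single_tensor_adamw`` and should produce matching updates.
# The toy parameters and loss below are assumptions used purely for illustration.
def _example_foreach_matches_single_tensor() -> None:
    torch.manual_seed(0)
    base = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
    results = []
    for foreach in (False, True):
        params = [p.detach().clone().requires_grad_(True) for p in base]
        opt = AdamW(params, lr=1e-2, foreach=foreach)
        loss = sum((p * p).sum() for p in params)
        loss.backward()
        opt.step()
        results.append([p.detach().clone() for p in params])
    assert all(torch.allclose(a, b) for a, b in zip(results[0], results[1]))
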
def _fused_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[Tensor, float],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
    has_complex: bool,  # Needed for consistency.
) -> None:
    if not params:
        return
    if differentiable:
        raise RuntimeError(
            "AdamW with fused=True does not support differentiable=True"
        )

    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )

    # We only shuffle around the lr when it is a Tensor and on CUDA; otherwise we prefer
    # treating it as a scalar.
    lr_dict: Optional[DeviceDict] = (
        {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
    )

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            device_params_,
            device_grads_,
            device_exp_avgs_,
            device_exp_avg_sqs_,
            device_max_exp_avg_sqs,
            device_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        if device.type == "mps":  # type: ignore[union-attr]
            assert found_inf is None and grad_scale is None

        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            device_grad_scale = grad_scale_dict.setdefault(
                device, grad_scale.to(device, non_blocking=True)
            )
        if found_inf is not None:
            device_found_inf = found_inf_dict.setdefault(
                device, found_inf.to(device, non_blocking=True)
            )
        if lr_dict is not None and device not in lr_dict:
            lr = lr_dict.setdefault(
                device, lr.to(device=device, non_blocking=True)  # type: ignore[union-attr]
            )
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adamw_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,  # type: ignore[arg-type]
            device_state_steps,
            amsgrad=amsgrad,
            lr=lr,  # type: ignore[arg-type]
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )
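
# AMP interop sketch: because the fused path sets ``_step_supports_amp_scaling``, the
# fused kernel consumes ``grad_scale``/``found_inf`` directly and the step counters are
# rolled back when infs are found (the ``_foreach_sub_`` above). The model and data here
# are assumptions; the sketch only shows the intended wiring and is gated on CUDA.
def _example_fused_with_grad_scaler() -> None:
    if not torch.cuda.is_available():
        return
    model = torch.nn.Linear(16, 16).cuda()
    optimizer = AdamW(model.parameters(), lr=1e-3, fused=True)
    scaler = torch.cuda.amp.GradScaler()
    data = torch.randn(4, 16, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(data).float().pow(2).mean()
    scaler.scale(loss).backward()
    # hands grad_scale/found_inf to the fused optimizer instead of unscaling in Python
    scaler.step(optimizer)
    scaler.update()
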
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamw)
def adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.
    """
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adamw
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamw
    else:
        func = _single_tensor_adamw

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
        has_complex=has_complex,
    )
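
# Functional-API sketch: ``adamw`` above can be driven directly with manually managed
# state lists, which is how the class-based optimizer calls it. Everything below
# (tensor shapes, hyperparameters) is an assumed toy setup for illustration only.
def _example_functional_adamw() -> None:
    params = [torch.zeros(3)]
    grads = [torch.ones(3)]
    exp_avgs = [torch.zeros(3)]
    exp_avg_sqs = [torch.zeros(3)]
    max_exp_avg_sqs: List[Tensor] = []  # unused because amsgrad=False
    state_steps = [torch.tensor(0.0)]   # singleton step tensors, as required above
    adamw(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        foreach=False,
        amsgrad=False,
        beta1=0.9,
        beta2=0.999,
        lr=1e-3,
        weight_decay=1e-2,
        eps=1e-8,
        maximize=False,
    )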