1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3from typing import cast, List, Optional, Tuple, Union 4 5import torch 6from torch import Tensor 7 8from .optimizer import ( 9 _capturable_doc, 10 _default_to_fused_or_foreach, 11 _device_dtype_check_for_fused, 12 _differentiable_doc, 13 _disable_dynamo_if_unsupported, 14 _foreach_doc, 15 _fused_doc, 16 _get_capturable_supported_devices, 17 _get_scalar_dtype, 18 _get_value, 19 _maximize_doc, 20 _stack_if_compiling, 21 _use_grad_for_differentiable, 22 _view_as_real, 23 DeviceDict, 24 Optimizer, 25 ParamsT, 26) 27 28 29__all__ = ["Adam", "adam"] 30 31 32class Adam(Optimizer): 33 def __init__( 34 self, 35 params: ParamsT, 36 lr: Union[float, Tensor] = 1e-3, 37 betas: Tuple[float, float] = (0.9, 0.999), 38 eps: float = 1e-8, 39 weight_decay: float = 0, 40 amsgrad: bool = False, 41 *, 42 foreach: Optional[bool] = None, 43 maximize: bool = False, 44 capturable: bool = False, 45 differentiable: bool = False, 46 fused: Optional[bool] = None, 47 ): 48 if isinstance(lr, Tensor): 49 if foreach and not capturable: 50 raise ValueError( 51 "lr as a Tensor is not supported for capturable=False and foreach=True" 52 ) 53 if lr.numel() != 1: 54 raise ValueError("Tensor lr must be 1-element") 55 if not 0.0 <= lr: 56 raise ValueError(f"Invalid learning rate: {lr}") 57 if not 0.0 <= eps: 58 raise ValueError(f"Invalid epsilon value: {eps}") 59 if not 0.0 <= betas[0] < 1.0: 60 raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") 61 if not 0.0 <= betas[1] < 1.0: 62 raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") 63 if not 0.0 <= weight_decay: 64 raise ValueError(f"Invalid weight_decay value: {weight_decay}") 65 66 defaults = dict( 67 lr=lr, 68 betas=betas, 69 eps=eps, 70 weight_decay=weight_decay, 71 amsgrad=amsgrad, 72 maximize=maximize, 73 foreach=foreach, 74 capturable=capturable, 75 differentiable=differentiable, 76 fused=fused, 77 ) 78 super().__init__(params, defaults) 79 80 if fused: 81 if differentiable: 82 raise RuntimeError("`fused` does not support `differentiable`") 83 self._step_supports_amp_scaling = True 84 # TODO(crcrpar): [low prec params & their higher prec copy] 85 # Support AMP with FP16/BF16 model params which would need 86 # higher prec copy of params to do update math in higher prec to 87 # alleviate the loss of information. 88 if foreach: 89 raise RuntimeError("`fused` and `foreach` cannot be `True` together.") 90 91 def __setstate__(self, state): 92 super().__setstate__(state) 93 for group in self.param_groups: 94 group.setdefault("amsgrad", False) 95 group.setdefault("maximize", False) 96 group.setdefault("foreach", None) 97 group.setdefault("capturable", False) 98 group.setdefault("differentiable", False) 99 fused = group.setdefault("fused", None) 100 for p in group["params"]: 101 p_state = self.state.get(p, []) 102 if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): 103 step_val = float(p_state["step"]) 104 p_state["step"] = ( 105 torch.tensor( 106 step_val, 107 dtype=_get_scalar_dtype(is_fused=fused), 108 device=p.device, 109 ) 110 if group["capturable"] or group["fused"] 111 else torch.tensor(step_val, dtype=_get_scalar_dtype()) 112 ) 113 114 def _init_group( 115 self, 116 group, 117 params_with_grad, 118 grads, 119 exp_avgs, 120 exp_avg_sqs, 121 max_exp_avg_sqs, 122 state_steps, 123 ): 124 has_complex = False 125 for p in group["params"]: 126 if p.grad is not None: 127 has_complex |= torch.is_complex(p) 128 params_with_grad.append(p) 129 if p.grad.is_sparse: 130 raise RuntimeError( 131 "Adam does not support sparse gradients, please consider SparseAdam instead" 132 ) 133 grads.append(p.grad) 134 135 state = self.state[p] 136 # Lazy state initialization 137 if len(state) == 0: 138 if group["fused"]: 139 _device_dtype_check_for_fused(p) 140 # note(crcrpar): [special device hosting for step] 141 # Deliberately host `step` on CPU if both capturable and fused are off. 142 # This is because kernel launches are costly on CUDA and XLA. 143 state["step"] = ( 144 torch.zeros( 145 (), 146 dtype=_get_scalar_dtype(is_fused=group["fused"]), 147 device=p.device, 148 ) 149 if group["capturable"] or group["fused"] 150 else torch.tensor(0.0, dtype=_get_scalar_dtype()) 151 ) 152 # Exponential moving average of gradient values 153 state["exp_avg"] = torch.zeros_like( 154 p, memory_format=torch.preserve_format 155 ) 156 # Exponential moving average of squared gradient values 157 state["exp_avg_sq"] = torch.zeros_like( 158 p, memory_format=torch.preserve_format 159 ) 160 if group["amsgrad"]: 161 # Maintains max of all exp. moving avg. of sq. grad. values 162 state["max_exp_avg_sq"] = torch.zeros_like( 163 p, memory_format=torch.preserve_format 164 ) 165 166 exp_avgs.append(state["exp_avg"]) 167 exp_avg_sqs.append(state["exp_avg_sq"]) 168 169 if group["amsgrad"]: 170 max_exp_avg_sqs.append(state["max_exp_avg_sq"]) 171 if group["differentiable"] and state["step"].requires_grad: 172 raise RuntimeError( 173 "`requires_grad` is not supported for `step` in differentiable mode" 174 ) 175 176 # Foreach without capturable does not support a tensor lr 177 if ( 178 group["foreach"] 179 and torch.is_tensor(group["lr"]) 180 and not group["capturable"] 181 ): 182 raise RuntimeError( 183 "lr as a Tensor is not supported for capturable=False and foreach=True" 184 ) 185 186 state_steps.append(state["step"]) 187 return has_complex 188 189 @_use_grad_for_differentiable 190 def step(self, closure=None): 191 """Perform a single optimization step. 192 193 Args: 194 closure (Callable, optional): A closure that reevaluates the model 195 and returns the loss. 196 """ 197 self._cuda_graph_capture_health_check() 198 199 loss = None 200 if closure is not None: 201 with torch.enable_grad(): 202 loss = closure() 203 204 for group in self.param_groups: 205 params_with_grad: List[Tensor] = [] 206 grads: List[Tensor] = [] 207 exp_avgs: List[Tensor] = [] 208 exp_avg_sqs: List[Tensor] = [] 209 max_exp_avg_sqs: List[Tensor] = [] 210 state_steps: List[Tensor] = [] 211 beta1, beta2 = group["betas"] 212 213 has_complex = self._init_group( 214 group, 215 params_with_grad, 216 grads, 217 exp_avgs, 218 exp_avg_sqs, 219 max_exp_avg_sqs, 220 state_steps, 221 ) 222 223 adam( 224 params_with_grad, 225 grads, 226 exp_avgs, 227 exp_avg_sqs, 228 max_exp_avg_sqs, 229 state_steps, 230 amsgrad=group["amsgrad"], 231 has_complex=has_complex, 232 beta1=beta1, 233 beta2=beta2, 234 lr=group["lr"], 235 weight_decay=group["weight_decay"], 236 eps=group["eps"], 237 maximize=group["maximize"], 238 foreach=group["foreach"], 239 capturable=group["capturable"], 240 differentiable=group["differentiable"], 241 fused=group["fused"], 242 grad_scale=getattr(self, "grad_scale", None), 243 found_inf=getattr(self, "found_inf", None), 244 ) 245 246 return loss 247 248 249Adam.__doc__ = ( 250 r"""Implements Adam algorithm. 251 252 .. math:: 253 \begin{aligned} 254 &\rule{110mm}{0.4pt} \\ 255 &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 256 \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\ 257 &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad}, 258 \:\textit{maximize} \\ 259 &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, 260 v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex] 261 &\rule{110mm}{0.4pt} \\ 262 &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ 263 264 &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ 265 &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ 266 &\hspace{5mm}\textbf{else} \\ 267 &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ 268 &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ 269 &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ 270 &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ 271 &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ 272 &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ 273 &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ 274 &\hspace{5mm}\textbf{if} \: amsgrad \\ 275 &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, 276 \widehat{v_t}) \\ 277 &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ 278 \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ 279 &\hspace{5mm}\textbf{else} \\ 280 &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ 281 \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ 282 &\rule{110mm}{0.4pt} \\[-1.ex] 283 &\bf{return} \: \theta_t \\[-1.ex] 284 &\rule{110mm}{0.4pt} \\[-1.ex] 285 \end{aligned} 286 287 For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_. 288 """ 289 + rf""" 290 Args: 291 params (iterable): iterable of parameters to optimize or dicts defining 292 parameter groups 293 lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR 294 is not yet supported for all our implementations. Please use a float 295 LR if you are not also specifying fused=True or capturable=True. 296 betas (Tuple[float, float], optional): coefficients used for computing 297 running averages of gradient and its square (default: (0.9, 0.999)) 298 eps (float, optional): term added to the denominator to improve 299 numerical stability (default: 1e-8) 300 weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 301 amsgrad (bool, optional): whether to use the AMSGrad variant of this 302 algorithm from the paper `On the Convergence of Adam and Beyond`_ 303 (default: False) 304 {_foreach_doc} 305 {_maximize_doc} 306 {_capturable_doc} 307 {_differentiable_doc} 308 {_fused_doc} 309 .. Note:: 310 A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. 311 .. _Adam\: A Method for Stochastic Optimization: 312 https://arxiv.org/abs/1412.6980 313 .. _On the Convergence of Adam and Beyond: 314 https://openreview.net/forum?id=ryQu7f-RZ 315 316 """ 317) 318 319 320def _single_tensor_adam( 321 params: List[Tensor], 322 grads: List[Tensor], 323 exp_avgs: List[Tensor], 324 exp_avg_sqs: List[Tensor], 325 max_exp_avg_sqs: List[Tensor], 326 state_steps: List[Tensor], 327 grad_scale: Optional[Tensor], 328 found_inf: Optional[Tensor], 329 *, 330 amsgrad: bool, 331 has_complex: bool, 332 beta1: float, 333 beta2: float, 334 lr: Union[float, Tensor], 335 weight_decay: float, 336 eps: float, 337 maximize: bool, 338 capturable: bool, 339 differentiable: bool, 340): 341 assert grad_scale is None and found_inf is None 342 343 if torch.jit.is_scripting(): 344 # this assert is due to JIT being dumb and not realizing that the ops below 345 # have overloads to handle both float and Tensor lrs, so we just assert it's 346 # a float since most people using JIT are using floats 347 assert isinstance(lr, float) 348 349 for i, param in enumerate(params): 350 grad = grads[i] if not maximize else -grads[i] 351 exp_avg = exp_avgs[i] 352 exp_avg_sq = exp_avg_sqs[i] 353 step_t = state_steps[i] 354 355 # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] 356 if not torch._utils.is_compiling() and capturable: 357 capturable_supported_devices = _get_capturable_supported_devices() 358 assert ( 359 param.device.type == step_t.device.type 360 and param.device.type in capturable_supported_devices 361 ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 362 363 # update step 364 step_t += 1 365 366 if weight_decay != 0: 367 grad = grad.add(param, alpha=weight_decay) 368 369 if torch.is_complex(param): 370 grad = torch.view_as_real(grad) 371 exp_avg = torch.view_as_real(exp_avg) 372 exp_avg_sq = torch.view_as_real(exp_avg_sq) 373 if amsgrad: 374 max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i]) 375 param = torch.view_as_real(param) 376 377 # Decay the first and second moment running average coefficient 378 exp_avg.lerp_(grad, 1 - beta1) 379 exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) 380 381 if capturable or differentiable: 382 step = step_t 383 384 bias_correction1 = 1 - beta1**step 385 bias_correction2 = 1 - beta2**step 386 387 step_size = lr / bias_correction1 388 step_size_neg = step_size.neg() 389 390 bias_correction2_sqrt = bias_correction2.sqrt() 391 392 if amsgrad: 393 # Maintains the maximum of all 2nd moment running avg. till now 394 if differentiable: 395 max_exp_avg_sq = max_exp_avg_sqs[i].clone() 396 else: 397 max_exp_avg_sq = max_exp_avg_sqs[i] 398 399 max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq)) 400 401 # Uses the max. for normalizing running avg. of gradient 402 # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write 403 # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor) 404 denom = ( 405 max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg) 406 ).add_(eps / step_size_neg) 407 else: 408 denom = ( 409 exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg) 410 ).add_(eps / step_size_neg) 411 412 param.addcdiv_(exp_avg, denom) 413 else: 414 step = _get_value(step_t) 415 416 bias_correction1 = 1 - beta1**step 417 bias_correction2 = 1 - beta2**step 418 419 step_size = lr / bias_correction1 420 421 bias_correction2_sqrt = bias_correction2**0.5 422 423 if amsgrad: 424 # Maintains the maximum of all 2nd moment running avg. till now 425 torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) 426 427 # Use the max. for normalizing running avg. of gradient 428 denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps) 429 else: 430 denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) 431 432 param.addcdiv_(exp_avg, denom, value=-step_size) 433 434 # Lastly, switch back to complex view 435 if amsgrad and torch.is_complex(params[i]): 436 max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i]) 437 438 439def _multi_tensor_adam( 440 params: List[Tensor], 441 grads: List[Tensor], 442 exp_avgs: List[Tensor], 443 exp_avg_sqs: List[Tensor], 444 max_exp_avg_sqs: List[Tensor], 445 state_steps: List[Tensor], 446 grad_scale: Optional[Tensor], 447 found_inf: Optional[Tensor], 448 *, 449 amsgrad: bool, 450 has_complex: bool, 451 beta1: float, 452 beta2: float, 453 lr: Union[float, Tensor], 454 weight_decay: float, 455 eps: float, 456 maximize: bool, 457 capturable: bool, 458 differentiable: bool, 459): 460 if len(params) == 0: 461 return 462 463 if isinstance(lr, Tensor) and not capturable: 464 raise RuntimeError( 465 "lr as a Tensor is not supported for capturable=False and foreach=True" 466 ) 467 468 # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] 469 if not torch._utils.is_compiling() and capturable: 470 capturable_supported_devices = _get_capturable_supported_devices( 471 supports_xla=False 472 ) 473 assert all( 474 p.device.type == step.device.type 475 and p.device.type in capturable_supported_devices 476 for p, step in zip(params, state_steps) 477 ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 478 479 assert grad_scale is None and found_inf is None 480 481 assert not differentiable, "_foreach ops don't support autograd" 482 483 grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( 484 [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] 485 ) 486 for ( 487 device_params_, 488 device_grads_, 489 device_exp_avgs_, 490 device_exp_avg_sqs_, 491 device_max_exp_avg_sqs_, 492 device_state_steps_, 493 ), _ in grouped_tensors.values(): 494 device_params = cast(List[Tensor], device_params_) 495 device_grads = cast(List[Tensor], device_grads_) 496 device_exp_avgs = cast(List[Tensor], device_exp_avgs_) 497 device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) 498 device_state_steps = cast(List[Tensor], device_state_steps_) 499 500 # Handle complex parameters 501 if has_complex: 502 if amsgrad: 503 device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) 504 _view_as_real( 505 device_params, 506 device_grads, 507 device_exp_avgs, 508 device_exp_avg_sqs, 509 device_max_exp_avg_sqs, 510 ) 511 else: 512 _view_as_real( 513 device_params, device_grads, device_exp_avgs, device_exp_avg_sqs 514 ) 515 516 if maximize: 517 device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] 518 519 # Update steps 520 # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over 521 # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just 522 # wrapped it once now. The alpha is required to assure we go to the right overload. 523 if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: 524 torch._foreach_add_( 525 device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 526 ) 527 else: 528 torch._foreach_add_(device_state_steps, 1) 529 530 if weight_decay != 0: 531 # Re-use the intermediate memory (device_grads) already allocated for maximize 532 if maximize: 533 torch._foreach_add_(device_grads, device_params, alpha=weight_decay) 534 else: 535 device_grads = torch._foreach_add( # type: ignore[assignment] 536 device_grads, device_params, alpha=weight_decay 537 ) 538 539 # Decay the first and second moment running average coefficient 540 torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1) 541 542 torch._foreach_mul_(device_exp_avg_sqs, beta2) 543 torch._foreach_addcmul_( 544 device_exp_avg_sqs, device_grads, device_grads, 1 - beta2 545 ) 546 547 # Delete the local intermediate since it won't be used anymore to save on peak memory 548 del device_grads 549 550 bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]] 551 bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]] 552 bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]] 553 554 if capturable: 555 bias_correction1 = torch._foreach_pow(beta1, device_state_steps) 556 bias_correction2 = torch._foreach_pow(beta2, device_state_steps) 557 # foreach_sub doesn't allow a scalar as the first arg 558 torch._foreach_sub_(bias_correction1, 1) 559 torch._foreach_sub_(bias_correction2, 1) 560 # we do not negate bias_correction1 as it'll need to be negated later anyway 561 torch._foreach_neg_(bias_correction2) 562 563 # foreach_div doesn't allow a scalar as the first arg 564 torch._foreach_div_(bias_correction1, lr) 565 torch._foreach_reciprocal_(bias_correction1) 566 567 torch._foreach_sqrt_(bias_correction2) 568 569 # Re-assign for clarity as we maintain minimal intermediates: we'll have 570 # step_size = - lr / (1 - beta1 ^ t) where t = num_steps 571 # bias_correction2_sqrt = sqrt(1 - beta2 ^ t) 572 step_size = bias_correction1 573 bias_correction2_sqrt = bias_correction2 574 575 if amsgrad: 576 device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) 577 # Maintains the maximum of all 2nd moment running avg. till now 578 torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) # type: ignore[assignment] 579 580 # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad 581 exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) 582 else: 583 exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) 584 585 torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) 586 torch._foreach_add_(exp_avg_sq_sqrt, eps) 587 torch._foreach_div_(exp_avg_sq_sqrt, step_size) 588 589 # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr 590 torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt) 591 else: 592 bias_correction1 = [ 593 1 - beta1 ** _get_value(step) for step in device_state_steps 594 ] 595 bias_correction2 = [ 596 1 - beta2 ** _get_value(step) for step in device_state_steps 597 ] 598 599 step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1]) 600 601 bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2] # type: ignore[arg-type] 602 603 if amsgrad: 604 device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) 605 # Maintains the maximum of all 2nd moment running avg. till now 606 torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) 607 608 # Use the max. for normalizing running avg. of gradient 609 exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) 610 else: 611 exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) 612 613 torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) 614 torch._foreach_add_(exp_avg_sq_sqrt, eps) 615 torch._foreach_addcdiv_( 616 device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size # type: ignore[arg-type] 617 ) 618 619 620def _fused_adam( 621 params: List[Tensor], 622 grads: List[Tensor], 623 exp_avgs: List[Tensor], 624 exp_avg_sqs: List[Tensor], 625 max_exp_avg_sqs: List[Tensor], 626 state_steps: List[Tensor], 627 grad_scale: Optional[Tensor], 628 found_inf: Optional[Tensor], 629 *, 630 amsgrad: bool, 631 has_complex: bool, # Needed for consistency. 632 beta1: float, 633 beta2: float, 634 lr: Union[float, Tensor], 635 weight_decay: float, 636 eps: float, 637 maximize: bool, 638 capturable: bool, # Needed for consistency. 639 differentiable: bool, 640) -> None: 641 if not params: 642 return 643 if differentiable: 644 raise RuntimeError("Adam with fused=True does not support differentiable=True") 645 646 grad_scale_dict: DeviceDict = ( 647 {grad_scale.device: grad_scale} if grad_scale is not None else {} 648 ) 649 found_inf_dict: DeviceDict = ( 650 {found_inf.device: found_inf} if found_inf is not None else {} 651 ) 652 653 # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer 654 # treating it as a scalar. 655 lr_dict: Optional[DeviceDict] = ( 656 {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None 657 ) 658 grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( 659 [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] 660 ) 661 for (device, _), ( 662 ( 663 device_params_, 664 device_grads_, 665 device_exp_avgs_, 666 device_exp_avg_sqs_, 667 device_max_exp_avg_sqs, 668 device_state_steps_, 669 ), 670 _, 671 ) in grouped_tensors.items(): 672 device_params = cast(List[Tensor], device_params_) 673 device_grads = cast(List[Tensor], device_grads_) 674 device_exp_avgs = cast(List[Tensor], device_exp_avgs_) 675 device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) 676 device_state_steps = cast(List[Tensor], device_state_steps_) 677 678 if device.type == "mps": # type: ignore[union-attr] 679 assert found_inf is None and grad_scale is None 680 681 device_grad_scale, device_found_inf = None, None 682 if grad_scale is not None: 683 device_grad_scale = grad_scale_dict.setdefault( 684 device, grad_scale.to(device, non_blocking=True) 685 ) 686 if found_inf is not None: 687 device_found_inf = found_inf_dict.setdefault( 688 device, found_inf.to(device, non_blocking=True) 689 ) 690 if lr_dict is not None and device not in lr_dict: 691 lr_dict[device] = lr.to(device=device, non_blocking=True) # type: ignore[union-attr] 692 lr = lr_dict[device] 693 torch._foreach_add_(device_state_steps, 1) 694 torch._fused_adam_( 695 device_params, 696 device_grads, 697 device_exp_avgs, 698 device_exp_avg_sqs, 699 device_max_exp_avg_sqs, # type: ignore[arg-type] 700 device_state_steps, 701 amsgrad=amsgrad, 702 lr=lr, # type: ignore[arg-type] 703 beta1=beta1, 704 beta2=beta2, 705 weight_decay=weight_decay, 706 eps=eps, 707 maximize=maximize, 708 grad_scale=device_grad_scale, 709 found_inf=device_found_inf, 710 ) 711 if device_found_inf is not None: 712 torch._foreach_sub_( 713 device_state_steps, [device_found_inf] * len(device_state_steps) 714 ) 715 716 717@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam) 718def adam( 719 params: List[Tensor], 720 grads: List[Tensor], 721 exp_avgs: List[Tensor], 722 exp_avg_sqs: List[Tensor], 723 max_exp_avg_sqs: List[Tensor], 724 state_steps: List[Tensor], 725 # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 726 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim 727 foreach: Optional[bool] = None, 728 capturable: bool = False, 729 differentiable: bool = False, 730 fused: Optional[bool] = None, 731 grad_scale: Optional[Tensor] = None, 732 found_inf: Optional[Tensor] = None, 733 has_complex: bool = False, 734 *, 735 amsgrad: bool, 736 beta1: float, 737 beta2: float, 738 lr: Union[float, Tensor], 739 weight_decay: float, 740 eps: float, 741 maximize: bool, 742): 743 r"""Functional API that performs Adam algorithm computation. 744 745 See :class:`~torch.optim.Adam` for details. 746 """ 747 # Respect when the user inputs False/True for foreach or fused. We only want to change 748 # the default when neither have been user-specified. Note that we default to foreach 749 # and pass False to use_fused. This is not a mistake--we want to give the fused impl 750 # bake-in time before making it the default, even if it is typically faster. 751 if fused is None and foreach is None: 752 _, foreach = _default_to_fused_or_foreach( 753 params, differentiable, use_fused=False 754 ) 755 # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False. 756 if foreach and isinstance(lr, Tensor) and not capturable: 757 foreach = False 758 if fused is None: 759 fused = False 760 if foreach is None: 761 foreach = False 762 763 # this check is slow during compilation, so we skip it 764 # if it's strictly needed we can add this check back in dynamo 765 if not torch._utils.is_compiling() and not all( 766 isinstance(t, torch.Tensor) for t in state_steps 767 ): 768 raise RuntimeError( 769 "API has changed, `state_steps` argument must contain a list of singleton tensors" 770 ) 771 772 if foreach and torch.jit.is_scripting(): 773 raise RuntimeError("torch.jit.script not supported with foreach optimizers") 774 if fused and torch.jit.is_scripting(): 775 raise RuntimeError("torch.jit.script not supported with fused optimizers") 776 777 if fused and not torch.jit.is_scripting(): 778 func = _fused_adam 779 elif foreach and not torch.jit.is_scripting(): 780 func = _multi_tensor_adam 781 else: 782 func = _single_tensor_adam 783 784 func( 785 params, 786 grads, 787 exp_avgs, 788 exp_avg_sqs, 789 max_exp_avg_sqs, 790 state_steps, 791 amsgrad=amsgrad, 792 has_complex=has_complex, 793 beta1=beta1, 794 beta2=beta2, 795 lr=lr, 796 weight_decay=weight_decay, 797 eps=eps, 798 maximize=maximize, 799 capturable=capturable, 800 differentiable=differentiable, 801 grad_scale=grad_scale, 802 found_inf=found_inf, 803 ) 804