# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the RAdam algorithm."""
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["RAdam", "radam"]


class RAdam(Optimizer):  # noqa: D101
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        decoupled_weight_decay: bool = False,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        capturable: bool = False,
        differentiable: bool = False,
    ):  # noqa: D107
        # Validate hyperparameters eagerly so misconfiguration fails at construction.
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        # Per-group defaults handed to the base Optimizer constructor.
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            maximize=maximize,
            foreach=foreach,
            capturable=capturable,
            decoupled_weight_decay=decoupled_weight_decay,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        for group in self.param_groups:
            # Fill in flags that may be absent from a pickled state dict.
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("decoupled_weight_decay", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                # A non-tensor `step` is converted into a scalar tensor; it is placed
                # on the param's device only when the group is capturable.
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
    ):
        """Collect the params of ``group`` that have a gradient and their state.

        For every parameter with ``p.grad is not None``, appends the parameter,
        its gradient, its ``exp_avg`` / ``exp_avg_sq`` buffers and its ``step``
        tensor to the corresponding output lists, lazily creating zero-filled
        state on first use.

        Returns:
            bool: True if any collected parameter is complex.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    # `step` lives on the param's device only when capturable.
                    state["step"] = (
                        torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                        if group["capturable"]
                        else torch.tensor(0.0, dtype=_get_scalar_dtype())
                    )
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                state_steps.append(state["step"])

        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            state_steps: List[Tensor] = []
            beta1, beta2 = cast(Tuple[float, float], group["betas"])

            has_complex = self._init_group(
                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
            )

            # Dispatch to the functional implementation for this param group.
            radam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                decoupled_weight_decay=group["decoupled_weight_decay"],
                has_complex=has_complex,
            )

        return loss


# The class docstring is assembled from a raw LaTeX/RST string plus an f-string
# that splices in the shared argument docs imported from .optimizer.
RAdam.__doc__ = (
    r"""Implements RAdam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \: \beta_1, \beta_2
                \text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \:
                \lambda \text{ (weightdecay)}, \:\textit{maximize} \\
            &\hspace{13mm} \epsilon \text{ (epsilon)}, \textit{decoupled\_weight\_decay} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
                v_0 \leftarrow 0 \text{ ( second moment)}, \\
            &\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1 \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{6mm}\textbf{if} \: \textit{maximize}: \\
            &\hspace{12mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{6mm}\textbf{else} \\
            &\hspace{12mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{6mm} \theta_t \leftarrow \theta_{t-1} \\
            &\hspace{6mm} \textbf{if} \: \lambda \neq 0 \\
            &\hspace{12mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\
            &\hspace{18mm} \theta_t \leftarrow \theta_{t} - \gamma \lambda \theta_{t} \\
            &\hspace{12mm}\textbf{else} \\
            &\hspace{18mm} g_t \leftarrow g_t + \lambda \theta_{t} \\
            &\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
            &\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
            &\hspace{6mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
            &\hspace{6mm}\rho_t \leftarrow \rho_{\infty} -
                2 t \beta^t_2 /\big(1-\beta_2^t \big) \\[0.1ex]
            &\hspace{6mm}\textbf{if} \: \rho_t > 5 \\
            &\hspace{12mm} l_t \leftarrow \frac{\sqrt{ (1-\beta^t_2) }}{ \sqrt{v_t} +\epsilon } \\
            &\hspace{12mm} r_t \leftarrow
                \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\
            &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} r_t l_t \\
            &\hspace{6mm}\textbf{else} \\
            &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_.

    This implementation provides an option to use either the original weight_decay implementation as in Adam
    (where the weight_decay is applied to the gradient) or the one from AdamW (where weight_decay is applied
    to the weight) through the decoupled_weight_decay option. When decoupled_weight_decay is set to False
    (default), it uses the original Adam style weight decay, otherwise, it uses the AdamW style which
    corresponds more closely to the `author's implementation`_ in the RAdam paper. Further information
    about decoupled weight decay can be found in `Decoupled Weight Decay Regularization`_.

    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        decoupled_weight_decay (bool, optional): whether to use decoupled weight
            decay as in AdamW to obtain RAdamW (default: False)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        {_capturable_doc}

    .. _On the variance of the adaptive learning rate and beyond:
        https://arxiv.org/abs/1908.03265
    .. _author's implementation:
        https://github.com/LiyuanLucasLiu/RAdam
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101

    """
)


def _single_tensor_radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    decoupled_weight_decay: bool,
    differentiable: bool,
    maximize: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply one RAdam step to each parameter with a per-tensor Python loop.

    Mutates ``params``, ``exp_avgs``, ``exp_avg_sqs`` and ``state_steps`` in
    place. ``has_complex`` is accepted for signature parity with the foreach
    path; complex params are detected per tensor here instead.
    """
    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # Complex tensors are viewed as real, so real and imaginary parts are
        # updated as independent elements.
        if torch.is_complex(param):
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)

        # update step (in place on the state tensor)
        step_t += 1
        # Capturable keeps `step` as a tensor; otherwise use a Python scalar.
        step = step_t if capturable else _get_value(step_t)

        if weight_decay != 0:
            if decoupled_weight_decay:
                param.mul_(1 - lr * weight_decay)
            else:
                grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step

        # correcting bias for the first moving moment
        bias_corrected_exp_avg = exp_avg / bias_correction1

        # maximum length of the approximated SMA
        rho_inf = 2 / (1 - beta2) - 1
        # compute the length of the approximated SMA
        rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2

        def _compute_rect():
            # Variance rectification term r_t from the class docstring.
            return (
                (rho_t - 4)
                * (rho_t - 2)
                * rho_inf
                / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
            ) ** 0.5

        def _compute_adaptive_lr():
            # l_t = sqrt(1 - beta2^t) / (sqrt(v_t) + eps)
            exp_avg_sq_sqrt = exp_avg_sq.sqrt()
            if differentiable:
                # Out-of-place so the sqrt result is not modified under autograd.
                exp_avg_sq_sqrt = exp_avg_sq_sqrt.add(eps)
            else:
                exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)

            return (bias_correction2**0.5) / exp_avg_sq_sqrt

        # Compute the variance rectification term and update parameters accordingly
        if capturable:
            # rho_t is a tensor here, so the branch is selected with torch.where
            # rather than a Python `if` (presumably to avoid a host sync under
            # graph capture). NaNs from the unselected branch are discarded.
            update = torch.where(
                rho_t > 5.0, _compute_rect() * _compute_adaptive_lr(), 1.0
            )
            param.add_(bias_corrected_exp_avg * lr * update, alpha=-1.0)
        else:
            if rho_t > 5.0:
                param.add_(
                    bias_corrected_exp_avg
                    * lr
                    * _compute_adaptive_lr()
                    * _compute_rect(),
                    alpha=-1.0,
                )
            else:
                param.add_(bias_corrected_exp_avg * lr, alpha=-1.0)


def _multi_tensor_radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    decoupled_weight_decay: bool,
    differentiable: bool,
    maximize: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply one RAdam step using ``torch._foreach_*`` (multi-tensor) ops.

    Tensors are grouped by device and dtype with
    ``Optimizer._group_tensors_by_device_and_dtype`` and each group is updated
    in place. ``differentiable=True`` is not supported (asserted below).
    """
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_avg_sqs_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
        grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs
            )

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # maximum length of the approximated SMA
        rho_inf = 2 / (1 - beta2) - 1
        # compute the length of the approximated SMA
        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
        rho_t_list: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # Scratch use: bias_correction1 temporarily holds 1 - beta2^t (it is
            # recomputed as 1 - beta1^t further below), and bias_correction2 is
            # built up in place into rho_t = rho_inf - 2 t beta2^t / (1 - beta2^t).
            bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps)
            torch._foreach_neg_(bias_correction1)
            torch._foreach_add_(bias_correction1, 1)
            bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
            torch._foreach_mul_(bias_correction2, grouped_state_steps)
            torch._foreach_mul_(bias_correction2, 2)
            torch._foreach_div_(bias_correction2, bias_correction1)
            torch._foreach_neg_(bias_correction2)
            torch._foreach_add_(bias_correction2, rho_inf)
            rho_t_list = bias_correction2
        else:
            rho_t_list = [
                rho_inf
                - 2
                * _get_value(step)
                * (beta2 ** _get_value(step))
                / (1 - beta2 ** _get_value(step))
                for step in grouped_state_steps
            ]

        if weight_decay != 0:
            if decoupled_weight_decay:
                torch._foreach_mul_(grouped_params, 1 - lr * weight_decay)
            else:
                # Re-use the intermediate memory (grouped_grads) already allocated for maximize
                if maximize:
                    torch._foreach_add_(
                        grouped_grads, grouped_params, alpha=weight_decay
                    )
                else:
                    grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                        grouped_grads, grouped_params, alpha=weight_decay
                    )

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

        torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2
        )

        # Delete the local intermediate since it won't be used anymore to save on peak memory
        del grouped_grads

        if capturable:
            # num = (rho_t - 4)(rho_t - 2) rho_inf
            num = torch._foreach_sub(rho_t_list, 4)
            sub2 = torch._foreach_sub(rho_t_list, 2)
            torch._foreach_mul_(num, sub2)
            del sub2
            torch._foreach_mul_(num, rho_inf)
            # rho_inf is rebound to (rho_inf - 4)(rho_inf - 2); the original
            # value is not needed after this point.
            rho_inf = (rho_inf - 4) * (rho_inf - 2)
            denom = torch._foreach_mul(rho_t_list, rho_inf)
            torch._foreach_div_(num, denom)
            del denom
            torch._foreach_sqrt_(num)

            # TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884
            rect = [
                torch.where(rho_t > 5.0, n, 0.0) for n, rho_t in zip(num, rho_t_list)
            ]
            del num
            del rho_t_list
            # unrect_step_size ends up as -lr / (1 - beta1^t) where rect == 0, else 0.
            unrect_step_size = [torch.where(rect > 0, 0.0, 1.0) for rect in rect]
            torch._foreach_mul_(unrect_step_size, lr)

            bias_correction1 = torch._foreach_pow(beta1, grouped_state_steps)
            torch._foreach_neg_(bias_correction1)
            torch._foreach_add_(bias_correction1, 1)

            torch._foreach_div_(unrect_step_size, bias_correction1)
            torch._foreach_neg_(unrect_step_size)

            # bias_correction2 ends up as -lr * rect * sqrt(1 - beta2^t) / (1 - beta1^t).
            bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
            torch._foreach_neg_(bias_correction2)
            torch._foreach_add_(bias_correction2, 1)
            torch._foreach_sqrt_(bias_correction2)
            torch._foreach_mul_(bias_correction2, lr)
            torch._foreach_mul_(bias_correction2, rect)
            del rect
            torch._foreach_neg_(bias_correction2)
            torch._foreach_div_(bias_correction2, bias_correction1)
            del bias_correction1
        else:
            rect = [
                (
                    (rho_t - 4)  # type: ignore[arg-type]
                    * (rho_t - 2)
                    * rho_inf
                    / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
                )
                ** 0.5
                if rho_t > 5
                else 0
                for rho_t in rho_t_list
            ]
            unrectified = [0 if rect > 0 else 1.0 for rect in rect]

            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in grouped_state_steps
            ]
            unrect_step_size = [
                (lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)
            ]
            bias_correction2 = [
                ((1 - beta2 ** _get_value(step)) ** 0.5) * (lr * rect / bc) * -1
                for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1)
            ]

        buffer = torch._foreach_sqrt(grouped_exp_avg_sqs)
        torch._foreach_add_(buffer, eps)
        torch._foreach_div_(buffer, bias_correction2)
        torch._foreach_reciprocal_(buffer)
        torch._foreach_add_(buffer, unrect_step_size)

        # Here, buffer = sqrt(1 - beta2^t) * rect_step_size / (sqrt(v) + eps) + unrect_step_size
        torch._foreach_addcmul_(grouped_params, grouped_exp_avgs, buffer)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_radam)
def radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    decoupled_weight_decay: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    maximize: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
):
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.RAdam` for details.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # foreach=None defers to the default picked by _default_to_fused_or_foreach.
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_radam
    else:
        func = _single_tensor_radam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        decoupled_weight_decay=decoupled_weight_decay,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )