# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the RMSprop algorithm."""
from typing import cast, List, Optional, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["RMSprop", "rmsprop"]


# The user-facing class documentation is attached after the class body via
# ``RMSprop.__doc__ = ...`` so that shared doc fragments can be interpolated.
class RMSprop(Optimizer):  # noqa: D101
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        alpha: float = 0.99,
        eps: float = 1e-8,
        weight_decay: float = 0,
        momentum: float = 0,
        centered: bool = False,
        capturable: bool = False,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
    ):  # noqa: D107
        # Validate hyper-parameters up front; every numeric option must be
        # non-negative and a Tensor learning rate must hold a single element.
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= momentum:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= alpha:
            raise ValueError(f"Invalid alpha value: {alpha}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            alpha=alpha,
            eps=eps,
            centered=centered,
            weight_decay=weight_decay,
            capturable=capturable,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        for group in self.param_groups:
            # Fill in options that may be absent from the restored state.
            group.setdefault("momentum", 0)
            group.setdefault("centered", False)
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                # Convert a non-tensor ``step`` into a scalar tensor; it is
                # placed on the parameter's device only when capturable.
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        square_avgs,
        momentum_buffer_list,
        grad_avgs,
        state_steps,
    ):
        """Gather the tensors needed to update one parameter group.

        For every parameter in ``group["params"]`` that has a gradient, the
        parameter, its gradient and its per-parameter state are appended to
        the supplied output lists. State is created lazily the first time a
        parameter is seen.

        Args:
            group: the parameter group dict being processed.
            params_with_grad: output list receiving parameters with a gradient.
            grads: output list receiving the matching gradients.
            square_avgs: output list receiving the ``square_avg`` state.
            momentum_buffer_list: output list receiving ``momentum_buffer``
                state; only filled when ``group["momentum"] > 0``.
            grad_avgs: output list receiving ``grad_avg`` state; only filled
                when ``group["centered"]`` is true.
            state_steps: output list receiving the ``step`` counters.

        Returns:
            ``True`` if any parameter with a gradient is complex.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)

            if p.grad.is_sparse:
                raise RuntimeError("RMSprop does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                # The step counter lives on the parameter's device only when
                # capturable; otherwise it is created without a device argument.
                state["step"] = (
                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                    if group["capturable"]
                    else torch.zeros((), dtype=_get_scalar_dtype())
                )
                state["square_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                if group["momentum"] > 0:
                    state["momentum_buffer"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                if group["centered"]:
                    state["grad_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
            square_avgs.append(state["square_avg"])
            state_steps.append(state["step"])

            if group["momentum"] > 0:
                momentum_buffer_list.append(state["momentum_buffer"])
            if group["centered"]:
                grad_avgs.append(state["grad_avg"])

        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            square_avgs: List[Tensor] = []
            grad_avgs: List[Tensor] = []
            momentum_buffer_list: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                square_avgs,
                momentum_buffer_list,
                grad_avgs,
                state_steps,
            )

            rmsprop(
                params_with_grad,
                grads,
                square_avgs,
                grad_avgs,
                momentum_buffer_list,
                state_steps,
                lr=group["lr"],
                alpha=group["alpha"],
                eps=group["eps"],
                weight_decay=group["weight_decay"],
                momentum=group["momentum"],
                centered=group["centered"],
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                capturable=group["capturable"],
                has_complex=has_complex,
            )

        return loss


RMSprop.__doc__ = (
    r"""Implements RMSprop algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \alpha \text{ (alpha)},\: \gamma \text{ (lr)},
                \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
            &\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},\: centered\\
            &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \:
                \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0 \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}if \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}v_t \leftarrow \alpha v_{t-1} + (1 - \alpha) g^2_t
                \hspace{8mm} \\
            &\hspace{5mm} \tilde{v_t} \leftarrow v_t \\
            &\hspace{5mm}if \: centered \\
            &\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t \\
            &\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} - \big(g^{ave}_{t} \big)^2 \\
            &\hspace{5mm}if \: \mu > 0 \\
            &\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} +
                g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \\
            &\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t \\
            &\hspace{5mm} else \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} -
                \gamma g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \hspace{3mm} \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to
    `lecture notes <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_ by G. Hinton.
    and centered version `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
    The implementation here takes the square root of the gradient average before
    adding epsilon (note that TensorFlow interchanges these two operations). The effective
    learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma`
    is the scheduled learning rate and :math:`v` is the weighted moving average
    of the squared gradient.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        momentum (float, optional): momentum factor (default: 0)
        alpha (float, optional): smoothing constant (default: 0.99)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        centered (bool, optional) : if ``True``, compute the centered RMSProp,
            the gradient is normalized by an estimation of its variance
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_maximize_doc}
        {_capturable_doc}
        {_differentiable_doc}

    """
)


def _single_tensor_rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    grad_avgs: List[Tensor],
    momentum_buffer_list: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply one RMSprop update, looping over the parameters one at a time.

    Parameters, ``square_avgs``, ``state_steps`` and (when used)
    ``grad_avgs`` / ``momentum_buffer_list`` are updated in place.
    ``grad_avgs`` is only indexed when ``centered`` is true and
    ``momentum_buffer_list`` only when ``momentum > 0``. ``has_complex`` is
    accepted for signature parity with the multi-tensor variant and is not
    read here; complex parameters are detected per tensor.
    """
    for i, param in enumerate(params):
        step = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        grad = grads[i]
        grad = grad if not maximize else -grad
        square_avg = square_avgs[i]

        step += 1

        if weight_decay != 0:
            # Out-of-place add: the caller's gradient tensor is left untouched.
            grad = grad.add(param, alpha=weight_decay)

        # Complex tensors are processed through their real views, so the
        # in-place updates below write through to the original storage.
        is_complex_param = torch.is_complex(param)
        if is_complex_param:
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
            square_avg = torch.view_as_real(square_avg)

        # square_avg <- alpha * square_avg + (1 - alpha) * grad^2
        square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

        if centered:
            grad_avg = grad_avgs[i]
            if is_complex_param:
                grad_avg = torch.view_as_real(grad_avg)
            grad_avg.lerp_(grad, 1 - alpha)
            # addcmul allocates a new tensor, so the in-place sqrt_ does not
            # modify square_avg.
            avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_()
        else:
            avg = square_avg.sqrt()

        # Use the out-of-place add when ``differentiable`` is set; otherwise
        # reuse the freshly allocated ``avg`` tensor.
        if differentiable:
            avg = avg.add(eps)
        else:
            avg = avg.add_(eps)

        if momentum > 0:
            buf = momentum_buffer_list[i]
            if is_complex_param:
                buf = torch.view_as_real(buf)
            buf.mul_(momentum).addcdiv_(grad, avg)
            param.add_(buf, alpha=-lr)
        else:
            param.addcdiv_(grad, avg, value=-lr)


def _multi_tensor_rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    grad_avgs: List[Tensor],
    momentum_buffer_list: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply one RMSprop update using ``torch._foreach_*`` operations.

    Tensors are grouped by device and dtype and each group is updated with
    foreach calls. Does nothing when ``params`` is empty. Asserts that
    ``differentiable`` is false.
    """
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices()
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, square_avgs, grad_avgs, momentum_buffer_list, state_steps]  # type: ignore[list-item]
    )
    for (
        (
            grouped_params_,
            grouped_grads_,
            grouped_square_avgs_,
            grouped_grad_avgs_,
            grouped_momentum_buffer_list_,
            grouped_state_steps_,
        )
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_square_avgs = cast(List[Tensor], grouped_square_avgs_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        if has_complex:
            # Only pass the state lists that are actually in use for this
            # configuration to _view_as_real.
            state_and_grads = [grouped_grads, grouped_square_avgs]
            if momentum > 0:
                grouped_momentum_buffer_list = cast(
                    List[Tensor], grouped_momentum_buffer_list_
                )
                state_and_grads.append(grouped_momentum_buffer_list)
            if centered:
                grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
                state_and_grads.append(grouped_grad_avgs)
            _view_as_real(grouped_params, *state_and_grads)

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            # Re-use the intermediate memory (grouped_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
            else:
                grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                    grouped_grads, grouped_params, alpha=weight_decay
                )

        # square_avg <- alpha * square_avg + (1 - alpha) * grad^2
        torch._foreach_mul_(grouped_square_avgs, alpha)
        torch._foreach_addcmul_(
            grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha
        )

        if centered:
            grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
            torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha)
            avg = torch._foreach_addcmul(
                grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1
            )
            torch._foreach_sqrt_(avg)
            torch._foreach_add_(avg, eps)
        else:
            avg = torch._foreach_sqrt(grouped_square_avgs)
            torch._foreach_add_(avg, eps)

        if momentum > 0:
            grouped_momentum_buffer_list = cast(
                List[Tensor], grouped_momentum_buffer_list_
            )
            torch._foreach_mul_(grouped_momentum_buffer_list, momentum)
            torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg)
            # If LR is a tensor, the else branch will internally call item()
            # which will cause silent incorrectness if we are capturing
            if capturable and isinstance(lr, torch.Tensor):
                momentum_lr = torch._foreach_mul(grouped_momentum_buffer_list, -lr)
                torch._foreach_add_(grouped_params, momentum_lr)
            else:
                torch._foreach_add_(
                    grouped_params, grouped_momentum_buffer_list, alpha=-lr
                )
        else:
            # If LR is a tensor, the else branch will internally call item()
            # which will cause silent incorrectness if we are capturing
            if capturable and isinstance(lr, torch.Tensor):
                # Fold -lr into the denominator: grad / (avg / -lr) == -lr * grad / avg
                torch._foreach_div_(avg, -lr)
                torch._foreach_addcdiv_(grouped_params, grouped_grads, avg)
            else:
                torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop)
def rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    grad_avgs: List[Tensor],
    momentum_buffer_list: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
):
    r"""Functional API that performs rmsprop algorithm computation.

    See :class:`~torch.optim.RMSprop` for details.

    Dispatches to the multi-tensor (foreach) implementation when ``foreach``
    is true and to the single-tensor implementation otherwise. When
    ``foreach`` is ``None`` the choice is delegated to
    ``_default_to_fused_or_foreach``.

    Raises:
        RuntimeError: if ``state_steps`` contains a non-tensor element (checked
            only when not compiling), or if ``foreach`` is requested while
            scripting with ``torch.jit.script``.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rmsprop
    else:
        func = _single_tensor_rmsprop

    func(
        params,
        grads,
        square_avgs,
        grad_avgs,
        momentum_buffer_list,
        state_steps,
        lr=lr,
        alpha=alpha,
        eps=eps,
        weight_decay=weight_decay,
        momentum=momentum,
        centered=centered,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )