# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the Resilient backpropagation."""
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["Rprop", "rprop"]


class Rprop(Optimizer):  # noqa: D101
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        etas: Tuple[float, float] = (0.5, 1.2),
        step_sizes: Tuple[float, float] = (1e-6, 50),
        *,
        capturable: bool = False,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
    ):  # noqa: D107
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 < etas[0] < 1.0 < etas[1]:
            raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")

        defaults = dict(
            lr=lr,
            etas=etas,
            step_sizes=step_sizes,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(self, group, params, grads, prevs, step_sizes, state_steps):
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params.append(p)
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError("Rprop does not support sparse gradients")

            grads.append(grad)
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = (
                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                    if group["capturable"]
                    else torch.zeros((), dtype=_get_scalar_dtype())
                )

                state["prev"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                if p.dtype.is_complex:
                    # Complex numbers are handled as two independent real numbers,
                    # so the step_size must not be zero for the imaginary part.
                    state["step_size"] = torch.full_like(
                        grad, complex(group["lr"], group["lr"])
                    )
                else:
                    state["step_size"] = torch.full_like(grad, group["lr"])

            prevs.append(state["prev"])
            step_sizes.append(state["step_size"])
            state_steps.append(state["step"])

        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
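
        Example (an illustrative sketch; ``model``, ``loss_fn``, ``input`` and
        ``target`` are placeholder names, not defined in this module)::

            def closure():
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)
                loss.backward()
                return loss

            optimizer.step(closure)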
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params: List[Tensor] = []
            grads: List[Tensor] = []
            prevs: List[Tensor] = []
            step_sizes: List[Tensor] = []
            state_steps: List[Tensor] = []

            etaminus, etaplus = group["etas"]
            step_size_min, step_size_max = group["step_sizes"]
            foreach = group["foreach"]
            maximize = group["maximize"]

            has_complex = self._init_group(
                group, params, grads, prevs, step_sizes, state_steps
            )

            rprop(
                params,
                grads,
                prevs,
                step_sizes,
                state_steps,
                step_size_min=step_size_min,
                step_size_max=step_size_max,
                etaminus=etaminus,
                etaplus=etaplus,
                foreach=foreach,
                maximize=maximize,
                differentiable=group["differentiable"],
                capturable=group["capturable"],
                has_complex=has_complex,
            )

        return loss


Rprop.__doc__ = (
    r"""Implements the resilient backpropagation algorithm.

    .. math::
        \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \theta_0 \in \mathbf{R}^d \text{ (params)}, f(\theta)
                \text{ (objective)}, \\
            &\hspace{13mm} \eta_{+/-} \text{ (etaplus, etaminus)}, \Gamma_{max/min}
                \text{ (step sizes)} \\
            &\textbf{initialize} : g^0_{prev} \leftarrow 0,
                \: \eta_0 \leftarrow \text{lr (learning rate)} \\
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{for} \text{ } i = 0, 1, \ldots, d-1 \: \textbf{do} \\
            &\hspace{10mm}\textbf{if} \: g^i_{prev} g^i_t > 0 \\
            &\hspace{15mm}\eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+},
                \Gamma_{max}) \\
            &\hspace{10mm}\textbf{else if} \: g^i_{prev} g^i_t < 0 \\
            &\hspace{15mm}\eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-},
                \Gamma_{min}) \\
            &\hspace{15mm}g^i_t \leftarrow 0 \\
            &\hspace{10mm}\textbf{else} \: \\
            &\hspace{15mm}\eta^i_t \leftarrow \eta^i_{t-1} \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \eta_t \mathrm{sign}(g_t) \\
            &\hspace{5mm}g_{prev} \leftarrow g_t \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
        \end{aligned}

    For further details regarding the algorithm we refer to the paper
    `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm
    <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.
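
    Example (a minimal usage sketch; ``model``, ``loss_fn``, ``input`` and
    ``target`` below are placeholders for user-defined objects)::

        optimizer = torch.optim.Rprop(model.parameters(), lr=1e-2)
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()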
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        etas (Tuple[float, float], optional): pair of (etaminus, etaplus), the
            multiplicative decrease and increase factors, respectively
            (default: (0.5, 1.2))
        step_sizes (Tuple[float, float], optional): a pair of minimal and
            maximal allowed step sizes (default: (1e-6, 50))
        {_foreach_doc}
        {_capturable_doc}
        {_maximize_doc}
        {_differentiable_doc}

    """
)


def _single_tensor_rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    state_steps: List[Tensor],
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        prev = prevs[i]
        step_size = step_sizes[i]
        step = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        step += 1

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            prev = torch.view_as_real(prev)
            param = torch.view_as_real(param)
            step_size = torch.view_as_real(step_size)
        if differentiable:
            sign = grad.mul(prev.clone()).sign()
        else:
            sign = grad.mul(prev).sign()

        if capturable:
            sign.copy_(torch.where(sign.gt(0), etaplus, sign))
            sign.copy_(torch.where(sign.lt(0), etaminus, sign))
            sign.copy_(torch.where(sign.eq(0), 1, sign))
        else:
            sign[sign.gt(0)] = etaplus
            sign[sign.lt(0)] = etaminus
            sign[sign.eq(0)] = 1

        # update step sizes with the sign-based factors and clamp to the allowed range
        step_size.mul_(sign).clamp_(step_size_min, step_size_max)

        # for dir<0, dfdx=0
        # for dir>=0 dfdx=dfdx
        grad = grad.clone(memory_format=torch.preserve_format)
        if capturable:
            grad.copy_(torch.where(sign.eq(etaminus), 0, grad))
        else:
            grad[sign.eq(etaminus)] = 0

        # update parameters
        param.addcmul_(grad.sign(), step_size, value=-1)
        prev.copy_(grad)


def _multi_tensor_rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    state_steps: List[Tensor],
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices()
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, prevs, step_sizes, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_prevs_,
        grouped_step_sizes_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_prevs = cast(List[Tensor], grouped_prevs_)
        grouped_step_sizes = cast(List[Tensor], grouped_step_sizes_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1)
        # over and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if
        # we just wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        # Handle complex params
        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_prevs, grouped_step_sizes
            )

        signs = torch._foreach_mul(grouped_grads, grouped_prevs)
        if maximize:
            torch._foreach_neg_(signs)

        # At the end of the step, grouped_prevs will contain the current grads, so we reuse
        # grouped_prevs memory instead of creating a new buffer, but, for clarity, we reassign
        # to keep referring to the buffer as grouped_grads.
        torch._foreach_copy_(grouped_prevs, grouped_grads)
        if maximize:
            torch._foreach_neg_(grouped_prevs)
        grouped_grads = grouped_prevs

        torch._foreach_sign_(signs)
        if capturable:
            for sign in signs:
                sign.copy_(torch.where(sign.gt(0), etaplus, sign))
                sign.copy_(torch.where(sign.lt(0), etaminus, sign))
                sign.copy_(torch.where(sign.eq(0), 1, sign))
        else:
            for sign in signs:
                sign[sign.gt(0)] = etaplus
                sign[sign.lt(0)] = etaminus
                sign[sign.eq(0)] = 1

        # update step sizes with the sign-based factors and clamp to the allowed range
        torch._foreach_mul_(grouped_step_sizes, signs)
        for step_size in grouped_step_sizes:
            step_size.clamp_(step_size_min, step_size_max)

        # for dir<0, dfdx=0
        # for dir>=0 dfdx=dfdx
        grouped_grads = list(grouped_grads)
        for i in range(len(grouped_grads)):
            grouped_grads[i].copy_(
                torch.where(signs[i].eq(etaminus), 0, grouped_grads[i])
            )

        # explicitly del signs as it's not used after here to save memory
        del signs

        # update parameters
        grad_signs = [grad.sign() for grad in grouped_grads]
        torch._foreach_addcmul_(
            grouped_params, grad_signs, grouped_step_sizes, value=-1
        )

        # Logically, you may expect grouped_prevs to get updated to grouped_grads, but that's
        # basically already happened since we've been using grouped_prevs' memory to store
        # the updated grouped_grads!


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rprop)
def rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    maximize: bool = False,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
):
    r"""Functional API that performs rprop algorithm computation.

    See :class:`~torch.optim.Rprop` for details.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rprop
    else:
        func = _single_tensor_rprop

    func(
        params,
        grads,
        prevs,
        step_sizes,
        state_steps,
        step_size_min=step_size_min,
        step_size_max=step_size_max,
        etaminus=etaminus,
        etaplus=etaplus,
        capturable=capturable,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
    )
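

# Minimal self-contained usage sketch: fit a single scalar parameter so that
# (x - 3)^2 is minimized with Rprop. The toy objective, the parameter name, and
# the iteration count are illustrative choices, not part of this module's API.
# Run with `python -m torch.optim.rprop` so the relative imports resolve.
if __name__ == "__main__":
    x = torch.zeros(1, requires_grad=True)
    optimizer = Rprop([x], lr=1e-2)
    for _ in range(200):
        optimizer.zero_grad()
        loss = (x - 3.0).pow(2).sum()
        loss.backward()
        optimizer.step()
    # x should now be close to 3.0 and the loss close to 0.
    print(f"x = {x.item():.4f}, loss = {loss.item():.6f}")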