# mypy: allow-untyped-defs
from typing import cast, List, Optional, Union

import torch
from torch import Tensor

from .optimizer import (
    _default_to_fused_or_foreach,
    _device_dtype_check_for_fused,
    _differentiable_doc,
    _foreach_doc,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["Adagrad", "adagrad"]


class Adagrad(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        lr_decay: float = 0,
        weight_decay: float = 0,
        initial_accumulator_value: float = 0,
        eps: float = 1e-10,
        foreach: Optional[bool] = None,
        *,
        maximize: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= lr_decay:
            raise ValueError(f"Invalid lr_decay value: {lr_decay}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= initial_accumulator_value:
            raise ValueError(
                f"Invalid initial_accumulator_value value: {initial_accumulator_value}"
            )
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")

        defaults = dict(
            lr=lr,
            lr_decay=lr_decay,
            eps=eps,
            weight_decay=weight_decay,
            initial_accumulator_value=initial_accumulator_value,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
            self._need_device_dtype_check_for_fused = True

        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["step"] = (
                    torch.zeros(
                        (),
                        dtype=_get_scalar_dtype(is_fused=group["fused"]),
                        device=p.device,
                    )
                    if group["fused"]
                    else torch.tensor(0.0, dtype=_get_scalar_dtype())
                )
                init_value = (
                    complex(initial_accumulator_value, initial_accumulator_value)
                    if torch.is_complex(p)
                    else initial_accumulator_value
                )
                state["sum"] = torch.full_like(
                    p, init_value, memory_format=torch.preserve_format
                )

    def __setstate__(self, state):
        super().__setstate__(state)
        # define "fused" for
        # MYPY error: Name "fused" may be undefined
        fused = None
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            fused = group.setdefault("fused", None)

        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            for s in state_values:
                s["step"] = torch.tensor(
                    float(s["step"]), dtype=_get_scalar_dtype(is_fused=fused)
                )

    def share_memory(self):
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["sum"].share_memory_()

    def _init_group(self, group, params_with_grad, grads, state_sums, state_steps):
        has_sparse_grad, has_complex = False, False
        for p in group["params"]:
            if p.grad is not None:
                if group["fused"] and getattr(
                    self,
                    "_need_device_dtype_check_for_fused",
                    True,
                ):
                    _device_dtype_check_for_fused(p, cuda_unsupported=True)
                    self._need_device_dtype_check_for_fused = False
                has_sparse_grad |= p.grad.is_sparse
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                grads.append(p.grad)
                state = self.state[p]
                state_sums.append(state["sum"])
                state_steps.append(state["step"])

        return has_sparse_grad, has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None

        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            state_sums: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_sparse_grad, has_complex = self._init_group(
                group, params_with_grad, grads, state_sums, state_steps
            )

            adagrad(
                params_with_grad,
                grads,
                state_sums,
                state_steps,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                lr_decay=group["lr_decay"],
                eps=group["eps"],
                has_sparse_grad=has_sparse_grad,
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                has_complex=has_complex,
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        return loss


Adagrad.__doc__ = (
    r"""Implements Adagrad algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
            &\hspace{12mm} \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)} \\
            &\textbf{initialize} : state\_sum_0 \leftarrow \tau \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\tilde{\gamma} \leftarrow \gamma / (1 + (t-1) \eta) \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\
            &\hspace{5mm}\theta_t \leftarrow
                \theta_{t-1} - \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon} \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning
    and Stochastic Optimization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        lr_decay (float, optional): learning rate decay (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        initial_accumulator_value (float, optional): initial value of the
            sum of squares of gradients (default: 0)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-10)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        fused (bool, optional): whether the fused implementation (CPU only) is used.
            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
            are supported. (default: None). Please note that the fused implementation does not
            support sparse or complex gradients.

    .. _Adaptive Subgradient Methods for Online Learning and Stochastic
        Optimization: http://jmlr.org/papers/v12/duchi11a.html

    """
)
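

# Illustrative usage sketch (not part of the module's public API and not called
# anywhere): a minimal training step driving the `Adagrad` class documented above.
# The function name and all local names (`model`, `inputs`, `targets`, ...) are
# hypothetical placeholders introduced only for this example.
def _example_adagrad_usage() -> None:
    model = torch.nn.Linear(4, 1)
    optimizer = Adagrad(
        model.parameters(), lr=1e-2, lr_decay=1e-4, weight_decay=1e-5
    )
    inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
    for _ in range(3):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()  # accumulates squared grads in state["sum"] and updates params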


def adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    # kwonly args with defaults are not supported by functions compiled with torchscript (issue #70627)
    # setting these as kwargs for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither has been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adagrad
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adagrad
    else:
        func = _single_tensor_adagrad

    func(
        params,
        grads,
        state_sums,
        state_steps,
        lr=lr,
        weight_decay=weight_decay,
        lr_decay=lr_decay,
        eps=eps,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
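

# Illustrative sketch (not called anywhere) of using the functional `adagrad` API
# above directly, outside of an `Optimizer` instance. The function name and local
# variables are hypothetical; the caller owns the `state_sums`/`state_steps` tensors
# and must keep them alive between calls, which `torch.optim.Adagrad` normally does.
def _example_functional_adagrad() -> None:
    param = torch.randn(4, requires_grad=True)
    param.grad = torch.randn(4)
    state_sum = torch.zeros_like(param)  # running sum of squared gradients
    state_step = torch.tensor(0.0)       # step count, as a singleton tensor
    with torch.no_grad():                # parameter updates must not be tracked by autograd
        adagrad(
            [param],
            [param.grad],
            [state_sum],
            [state_step],
            lr=1e-2,
            weight_decay=0.0,
            lr_decay=0.0,
            eps=1e-10,
            maximize=False,
        )
    # After the call: state_step == 1, state_sum holds grad**2, and `param` has
    # taken one Adagrad step.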


def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    return torch.sparse_coo_tensor(grad_indices, values, size)


def _single_tensor_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
    has_complex: bool,
):
    assert grad_scale is None and found_inf is None
    for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps):
        # update step
        step_t += 1
        step = _get_value(step_t)
        grad = grad if not maximize else -grad

        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError(
                    "weight_decay option is not compatible with sparse gradients"
                )
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(
                _make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr
            )
        else:
            is_complex = torch.is_complex(param)
            if is_complex:
                grad = torch.view_as_real(grad)
                state_sum = torch.view_as_real(state_sum)
                param = torch.view_as_real(param)
            state_sum.addcmul_(grad, grad, value=1)
            if differentiable:
                std = state_sum.sqrt() + eps
            else:
                std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
            if is_complex:
                param = torch.view_as_complex(param)
                state_sum = torch.view_as_complex(state_sum)
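

# Illustrative sketch (not called anywhere) of the sparse-gradient branch in
# `_single_tensor_adagrad` above. Adagrad accepts sparse gradients, e.g. from an
# `nn.Embedding` built with `sparse=True`. The function name and local variables
# are hypothetical; note that `weight_decay` must remain 0 with sparse gradients,
# per the check above.
def _example_sparse_adagrad() -> None:
    embedding = torch.nn.Embedding(100, 8, sparse=True)
    optimizer = Adagrad(embedding.parameters(), lr=1e-2)
    indices = torch.tensor([1, 5, 5, 7])  # duplicate indices are merged by coalesce()
    loss = embedding(indices).pow(2).sum()
    loss.backward()   # embedding.weight.grad is a sparse COO tensor
    optimizer.step()  # takes the sparse path in `_single_tensor_adagrad`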


def _multi_tensor_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
    has_complex: bool,
):
    assert not differentiable, "_foreach ops don't support autograd"
    assert grad_scale is None and found_inf is None

    # Foreach functions will throw errors if given empty lists
    if len(params) == 0:
        return

    grouped_tensorlists = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, state_sums, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_state_sums_,
        device_state_steps_,
    ), _ in grouped_tensorlists.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_state_sums = cast(List[Tensor], device_state_sums_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        device_has_sparse_grad = has_sparse_grad and any(
            grad.is_sparse for grad in device_grads
        )

        if device_has_sparse_grad:
            _single_tensor_adagrad(
                device_params,
                device_grads,
                device_state_sums,
                device_state_steps,
                lr=lr,
                weight_decay=weight_decay,
                lr_decay=lr_decay,
                eps=eps,
                has_sparse_grad=True,
                maximize=maximize,
                differentiable=differentiable,
                has_complex=has_complex,
                grad_scale=grad_scale,
                found_inf=found_inf,
            )
            continue

        # Handle complex parameters
        if has_complex:
            _view_as_real(device_params, device_grads, device_state_sums)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        minus_clr = [
            -lr / (1 + (_get_value(step) - 1) * lr_decay) for step in device_state_steps
        ]

        torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1)

        std = torch._foreach_sqrt(device_state_sums)
        torch._foreach_add_(std, eps)

        if weight_decay != 0 or maximize:
            # Again, re-use the intermediate memory (device_grads) already allocated
            torch._foreach_mul_(device_grads, minus_clr)
            numerator = device_grads
        else:
            numerator = torch._foreach_mul(device_grads, minus_clr)  # type: ignore[assignment]

        torch._foreach_addcdiv_(device_params, numerator, std)
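

# Illustrative sketch (not called anywhere) of opting into the foreach path above
# explicitly. With `foreach=True`, the update is batched per device/dtype through
# the `torch._foreach_*` ops in `_multi_tensor_adagrad`. The function name and
# local variables are hypothetical.
def _example_foreach_adagrad() -> None:
    params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
    optimizer = Adagrad(params, lr=1e-2, foreach=True)
    loss = torch.stack([(p * p).sum() for p in params]).sum()
    loss.backward()
    optimizer.step()  # dispatches to `_multi_tensor_adagrad`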


def _fused_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
    has_complex: bool,
) -> None:
    if not params:
        return
    if has_sparse_grad or has_complex:
        raise RuntimeError("`fused` does not support sparse grad or complex param")

    if differentiable:
        raise RuntimeError(
            "adagrad with fused=True does not support differentiable=True"
        )

    grad_scale_dict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else None
    )
    found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, state_sums, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            device_params_,
            device_grads_,
            device_state_sums_,
            device_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_state_sums = cast(List[Tensor], device_state_sums_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None and grad_scale_dict is not None:
            if device not in grad_scale_dict:
                grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)  # type: ignore[index]
            device_grad_scale = grad_scale_dict[device]  # type: ignore[index]
        if found_inf is not None and found_inf_dict is not None:
            if device not in found_inf_dict:
                found_inf_dict[device] = found_inf.to(device, non_blocking=True)  # type: ignore[index]
            device_found_inf = found_inf_dict[device]  # type: ignore[index]
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adagrad_(
            device_params,
            device_grads,
            device_state_sums,
            device_state_steps,
            lr=lr,
            lr_decay=lr_decay,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )
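

# Illustrative sketch (not called anywhere) of the fused path, which per the class
# docstring is CPU-only here. This assumes a build that ships the fused CPU kernel
# reached through `torch._fused_adagrad_` in `_fused_adagrad` above; the function
# name and local variables are hypothetical.
def _example_fused_adagrad() -> None:
    params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
    optimizer = Adagrad(params, lr=1e-2, fused=True)
    loss = torch.stack([(p * p).sum() for p in params]).sum()
    loss.backward()
    optimizer.step()  # dispatches to `_fused_adagrad`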