# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["Adamax", "adamax"]


class Adamax(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 2e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        *,
        maximize: bool = False,
        differentiable: bool = False,
        capturable: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("Adamax does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = (
                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                    if group["capturable"]
                    else torch.tensor(0.0, dtype=_get_scalar_dtype())
                )
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                state["exp_inf"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            exp_avgs.append(state["exp_avg"])
            exp_infs.append(state["exp_inf"])
            state_steps.append(state["step"])

        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_infs: List[Tensor] = []
            state_steps: List[Tensor] = []

            beta1, beta2 = group["betas"]
            eps = group["eps"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            foreach = group["foreach"]
            maximize = group["maximize"]
            differentiable = group["differentiable"]
            capturable = group["capturable"]

            has_complex = self._init_group(
                group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
            )

            adamax(
                params_with_grad,
                grads,
                exp_avgs,
                exp_infs,
                state_steps,
                eps=eps,
                beta1=beta1,
                beta2=beta2,
                lr=lr,
                weight_decay=weight_decay,
                foreach=foreach,
                maximize=maximize,
                differentiable=differentiable,
                capturable=capturable,
                has_complex=has_complex,
            )

        return loss
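

# A minimal usage sketch of the class above (illustrative only; ``model``,
# ``loss_fn``, ``data``, and ``target`` are assumed to exist in the caller's
# code; user code normally constructs this as ``torch.optim.Adamax``):
#
#     optimizer = Adamax(model.parameters(), lr=2e-3, betas=(0.9, 0.999))
#     optimizer.zero_grad()
#     loss = loss_fn(model(data), target)
#     loss.backward()
#     optimizer.step()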


Adamax.__doc__ = (
    r"""Implements Adamax algorithm (a variant of Adam based on infinity norm).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)},
                \: \lambda \text{ (weight decay)}, \\
            &\hspace{13mm} \epsilon \text{ (epsilon)} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
                u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}if \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
            &\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        {_capturable_doc}

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980

    """
)
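

# ``step`` may also be driven with a closure that re-evaluates the model, as the
# ``closure`` argument documented above describes. A hedged sketch, again assuming
# ``model``, ``loss_fn``, ``data``, and ``target`` exist in the caller's code:
#
#     def closure():
#         optimizer.zero_grad()
#         loss = loss_fn(model(data), target)
#         loss.backward()
#         return loss
#
#     optimizer.step(closure)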


def _single_tensor_adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        exp_avg = exp_avgs[i]
        exp_inf = exp_infs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_inf = torch.view_as_real(exp_inf)

        # Update biased first moment estimate.
        exp_avg.lerp_(grad, 1 - beta1)
        # Update the exponentially weighted infinity norm.
        if not differentiable:
            torch.maximum(
                exp_inf.mul_(beta2),
                grad.abs().add_(eps),
                out=exp_inf,
            )
        else:
            norm_buf = torch.cat(
                [exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)],
                0,
            )
            exp_inf.copy_(torch.amax(norm_buf, 0, keepdim=False))

        if capturable:
            # why jump through extra hoops and negate bias_correction? check out #121238
            # once fixed, we should use bias_correction with addcdiv value=-1 for readability
            neg_bias_correction = beta1**step_t - 1
            neg_bias_correction.div_(lr)
            denom = exp_inf * neg_bias_correction
            param.addcdiv_(exp_avg, denom)
        else:
            bias_correction = 1 - beta1 ** _get_value(step_t)
            clr = lr / bias_correction

            param.addcdiv_(exp_avg, exp_inf, value=-clr)


def _multi_tensor_adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    assert not differentiable, "_foreach ops don't support autograd"

    if len(params) == 0:
        return

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

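    # Group tensors by device and dtype so each group can be updated with
    # horizontally fused _foreach_* ops over the whole list at once.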
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_infs, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_infs_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
        grouped_exp_infs = cast(List[Tensor], grouped_exp_infs_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs
            )

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            if maximize:
                # Re-use the intermediate memory (grouped_grads) already allocated for maximize
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
            else:
                grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                    grouped_grads, grouped_params, alpha=weight_decay
                )

        # Update biased first moment estimate.
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

        # Update the exponentially weighted infinity norm.
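        # The single-tensor update max(beta2 * u_{t-1}, |g_t| + eps) is decomposed
        # here into in-place foreach ops: scale exp_inf by beta2, form |grad| + eps,
        # then take the elementwise maximum.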
        torch._foreach_mul_(grouped_exp_infs, beta2)

        # in this case, we need to introduce a copy of the grads
        # since one has not been introduced previously
        if not maximize and weight_decay == 0:
            grouped_grads = torch._foreach_abs(grouped_grads)  # type: ignore[assignment]
        else:
            torch._foreach_abs_(grouped_grads)

        torch._foreach_add_(grouped_grads, eps)
        torch._foreach_maximum_(grouped_exp_infs, grouped_grads)

        bias_corrections: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            bias_corrections = torch._foreach_pow(beta1, grouped_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_corrections, 1)
            torch._foreach_div_(bias_corrections, lr)

            denom = torch._foreach_mul(grouped_exp_infs, bias_corrections)
            torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom)
        else:
            bias_corrections = [
                1 - beta1 ** _get_value(step) for step in grouped_state_steps
            ]
            step_size = [(_get_value(lr) / bc) * -1 for bc in bias_corrections]
            torch._foreach_addcdiv_(
                grouped_params, grouped_exp_avgs, grouped_exp_infs, step_size
            )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax)
def adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
):
    r"""Functional API that performs adamax algorithm computation.

    See :class:`~torch.optim.Adamax` for details.
    """

    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamax
    else:
        func = _single_tensor_adamax

    func(
        params,
        grads,
        exp_avgs,
        exp_infs,
        state_steps,
        eps=eps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        capturable=capturable,
    )
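

if __name__ == "__main__":
    # Minimal sanity-check sketch of the functional API above (illustrative only,
    # not part of the module's public surface). The state tensors are built by
    # hand, mirroring what Adamax._init_group sets up: zeroed exp_avg/exp_inf
    # buffers and a singleton step tensor per parameter.
    p = torch.randn(4)
    g = torch.randn(4)
    exp_avg = torch.zeros_like(p)
    exp_inf = torch.zeros_like(p)
    step = torch.tensor(0.0)

    adamax(
        [p],
        [g],
        [exp_avg],
        [exp_inf],
        [step],
        eps=1e-8,
        beta1=0.9,
        beta2=0.999,
        lr=2e-3,
        weight_decay=0.0,
    )
    print(step)  # the step counter is incremented in place to 1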