# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
"""Averaged Stochastic Gradient Descent (ASGD): optimizer class and functional API."""
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["ASGD", "asgd"]


# The public class docstring is assigned below via ``ASGD.__doc__`` so that the
# shared option docs from ``.optimizer`` can be interpolated into it.
class ASGD(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        lambd: float = 1e-4,
        alpha: float = 0.75,
        t0: float = 1e6,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
        capturable: bool = False,
    ):
        """Validate hyperparameters and register them as per-group defaults.

        Raises:
            ValueError: if ``lr`` is a Tensor with more than one element, or if
                ``lr`` or ``weight_decay`` is negative. ``lambd``, ``alpha`` and
                ``t0`` are not validated here.
        """
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            lambd=lambd,
            alpha=alpha,
            t0=t0,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state and normalize it to the current layout.

        Group options that may be absent from the saved state get their
        default values. Per-parameter ``step``, ``eta`` and ``mu`` entries that
        were saved as plain numbers are converted into scalar tensors on the
        parameter's device.
        """
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0:
                    if not torch.is_tensor(p_state["step"]):
                        step_val = float(p_state["step"])
                        p_state["step"] = torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                    if not torch.is_tensor(p_state["eta"]):
                        p_state["eta"] = torch.tensor(
                            p_state["eta"], dtype=_get_scalar_dtype(), device=p.device
                        )
                    if not torch.is_tensor(p_state["mu"]):
                        p_state["mu"] = torch.tensor(
                            p_state["mu"], dtype=_get_scalar_dtype(), device=p.device
                        )

    def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
        """Gather the parameters of ``group`` that have gradients, plus their state.

        The output lists are filled in place, index-aligned per parameter.
        State is created lazily on a parameter's first step:
        ``step`` = 0, ``eta`` = ``lr``, ``mu`` = 1, ``ax`` = zeros shaped like
        the parameter.

        Returns:
            True if any collected parameter is complex.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("ASGD does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = torch.zeros(
                        (), device=p.device, dtype=_get_scalar_dtype()
                    )
                    # Copy lr so later in-place updates of eta never alias a
                    # user-provided lr tensor.
                    state["eta"] = (
                        torch.as_tensor(
                            group["lr"], device=p.device, dtype=_get_scalar_dtype()
                        )
                        .clone()
                        .detach()
                    )
                    state["mu"] = torch.ones(
                        (), device=p.device, dtype=_get_scalar_dtype()
                    )
                    state["ax"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                mus.append(state["mu"])
                axs.append(state["ax"])
                etas.append(state["eta"])
                state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            mus: List[Tensor] = []
            axs: List[Tensor] = []
            etas: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_complex = self._init_group(
                group, params_with_grad, grads, mus, axs, etas, state_steps
            )

            asgd(
                params_with_grad,
                grads,
                axs,
                mus,
                etas,
                state_steps,
                lambd=group["lambd"],
                lr=group["lr"],
                t0=group["t0"],
                alpha=group["alpha"],
                weight_decay=group["weight_decay"],
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                capturable=group["capturable"],
                has_complex=has_complex,
            )

        return loss


ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.

    It has been proposed in `Acceleration of stochastic approximation by
    averaging`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        lambd (float, optional): decay term (default: 1e-4)
        alpha (float, optional): power for eta update (default: 0.75)
        t0 (float, optional): point at which to start averaging (default: 1e6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        {_capturable_doc}

    .. _Acceleration of stochastic approximation by averaging:
        https://dl.acm.org/citation.cfm?id=131098

    """


def _single_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply one ASGD update to each parameter in turn (for-loop implementation).

    Per parameter, with ``step`` already incremented:
        grad  <- grad + weight_decay * param            (if weight_decay != 0)
        param <- param * (1 - lambd * eta) - eta * grad
        ax    <- ax + mu * (param - ax)                  (ax <- param when mu == 1)
        eta   <- lr / (1 + lambd * lr * step) ** alpha
        mu    <- 1 / max(1, step - t0)

    ``params``, ``axs``, ``mus``, ``etas`` and ``state_steps`` are updated in
    place. ``differentiable`` and ``has_complex`` are accepted for signature
    parity with ``_multi_tensor_asgd`` and are not read here; complexity is
    checked per parameter instead.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        mu = mus[i]
        ax = axs[i]
        eta = etas[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type
                == mu.device.type
                == eta.device.type
                == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), (
                f"If capturable=True, params, mus, etas, and state_steps must be "
                f"on supported devices: {capturable_supported_devices}."
            )

        # Complex tensors are updated through real views so the arithmetic
        # below applies to the real and imaginary parts independently.
        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            param = torch.view_as_real(param)
            ax = torch.view_as_real(ax)

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if capturable:
            # Keep eta as a tensor so no host-side value is read.
            param.mul_(1 - lambd * eta)
            param.addcmul_(grad, eta, value=-1)  # update parameter
        else:
            eta_value = _get_value(eta)
            param.mul_(1 - lambd * eta_value)  # decay term
            param.add_(grad, alpha=-eta_value)  # update parameter

        # averaging
        if capturable or mu.item() != 1:
            ax.add_(param.sub(ax).mul_(mu))
        else:
            # mu == 1 means "no averaging yet": ax simply tracks param.
            ax.copy_(param)

        # eta and mu are recomputed from the new step count for the next step.
        if capturable:
            eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
            mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
        else:
            step = _get_value(step_t)
            new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
            eta.copy_(new_eta)
            new_mu = torch.as_tensor(1 / max(1, step - t0))
            mu.copy_(new_mu)


def _multi_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply one ASGD update using ``torch._foreach_*`` ops (foreach implementation).

    Tensors are grouped by device and dtype and each group is updated with
    batched ops. The math matches ``_single_tensor_asgd``; the parameter
    update is rearranged as ``param <- param - eta * (grad + lambd * param)``
    (with ``weight_decay * param`` added to ``grad`` first when nonzero).

    ``params``, ``axs``, ``mus``, ``etas`` and ``state_steps`` are updated in place.

    Raises:
        AssertionError: if ``differentiable`` is True, or if ``capturable`` is
            True and the tensors are not on a supported, common device type.
    """
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == mu.device.type == eta.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, mu, eta, step in zip(params, mus, etas, state_steps)
        ), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, axs, mus, etas, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            grouped_params_,
            grouped_grads_,
            grouped_axs_,
            grouped_mus_,
            grouped_etas_,
            grouped_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_axs = cast(List[Tensor], grouped_axs_)
        grouped_mus = cast(List[Tensor], grouped_mus_)
        grouped_etas = cast(List[Tensor], grouped_etas_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        if has_complex:
            _view_as_real(grouped_params, grouped_grads, grouped_axs)

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        # intermediate = grad + param * (weight_decay + lambd)
        intermediate: Union[Tuple[Tensor, ...], List[Tensor]]
        if weight_decay != 0:
            if maximize:
                # grouped_grads already holds fresh tensors from _foreach_neg,
                # so updating them in place does not modify the caller's grads.
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
                intermediate = grouped_grads
            else:
                intermediate = torch._foreach_add(
                    grouped_grads, grouped_params, alpha=weight_decay
                )

            torch._foreach_add_(intermediate, grouped_params, alpha=lambd)
        else:
            intermediate = torch._foreach_add(
                grouped_grads, grouped_params, alpha=lambd
            )

        # update param
        # param * (1 - lambd * eta) - eta * grad
        # => param - param * lambd * eta - eta * grad
        # => param - eta * intermediate
        torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1)
        del intermediate

        # update grouped_axs
        # averaging: ax = ax + mu * (param - ax)
        # Note (mlazos): We can't use lerp here since it requires weight to be float64
        # and our grouping code requires dtypes to match for all tensors in a group (and it should, since
        # we use the mus in other places)
        # all dtypes need to match, so we could introduce a cast in a loop
        # but since this only adds one additional kernel launch, this looks like the cleaner
        # and faster solution
        intermediate = torch._foreach_sub(grouped_params, grouped_axs)
        torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)
        del intermediate

        new_etas: Union[Tuple[Tensor, ...], List[Tensor]]
        new_mus: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # update grouped_mus: mu = 1 / max(step - t0, 1)
            new_mus = torch._foreach_sub(grouped_state_steps, t0)
            torch._foreach_maximum_(new_mus, 1.0)
            torch._foreach_reciprocal_(new_mus)
            torch._foreach_copy_(grouped_mus, new_mus)
            del new_mus

            # update eta = lr / ((1 + lambd * lr * step)^alpha)
            new_etas = torch._foreach_mul(grouped_state_steps, lambd)
            torch._foreach_mul_(new_etas, lr)
            torch._foreach_add_(new_etas, 1)
            torch._foreach_pow_(new_etas, alpha)
            torch._foreach_reciprocal_(new_etas)
            torch._foreach_mul_(new_etas, lr)
            torch._foreach_copy_(grouped_etas, new_etas)
        else:
            # Read each step counter once through _get_value, the same way
            # _single_tensor_asgd does, and derive both eta and mu from that
            # value. Previously eta was computed from the raw step tensor while
            # mu used _get_value, so the foreach and for-loop paths computed
            # eta with different arithmetic.
            step_values = [_get_value(step) for step in grouped_state_steps]
            new_etas = [
                torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
                for step in step_values
            ]
            new_mus = [
                torch.as_tensor(1 / max(1, step - t0), device=device)
                for step in step_values
            ]
            torch._foreach_copy_(grouped_etas, new_etas)
            torch._foreach_copy_(grouped_mus, new_mus)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
def asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
):
    r"""Functional API that performs asgd algorithm computation.

    See :class:`~torch.optim.ASGD` for details.

    Dispatches to the foreach implementation when ``foreach`` is true and to
    the for-loop implementation otherwise. When ``foreach`` is None, the
    choice is made by ``_default_to_fused_or_foreach``.

    Raises:
        RuntimeError: if ``foreach`` is requested under ``torch.jit.script``.
    """
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_asgd
    else:
        func = _single_tensor_asgd

    func(
        params,
        grads,
        axs,
        mus,
        etas,
        state_steps,
        lambd=lambd,
        lr=lr,
        t0=t0,
        alpha=alpha,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )