# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Tuple

import torch
import torch.optim._functional as F
from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible functional Adamax optimizer
# where we use this optimizer in a functional way.
# Instead of reading `param.grad` when updating parameters,
# we explicitly let the distributed optimizer pass gradients to
# the `step` function. This way, we can separate the gradients
# from the parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same `.grad`.
# A short usage sketch appears at the end of this file.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
@torch.jit.script
class _FunctionalAdamax:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        foreach: bool = False,
        maximize: bool = False,
        _allow_empty_param_list: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.defaults = {
            "lr": lr,
            "eps": eps,
            "beta1": betas[0],
            "beta2": betas[1],
            "weight_decay": weight_decay,
        }
        self.foreach = foreach
        self.maximize = maximize
        self.state = torch.jit.annotate(
            Dict[torch.Tensor, Dict[str, torch.Tensor]], {}
        )

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

        # NOTE: we only have one param_group and don't allow the user to add
        # additional param groups, as that is not a common use case.
        self.param_group = {"params": params}

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group["params"]
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_infs = []
        state_steps: List[Tensor] = []

        if len(params) != len(gradients):
            raise ValueError(
                "the number of gradients passed in does not equal the number of parameters! "
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        has_complex = False
        for param, gradient in zip(self.param_group["params"], gradients):
            if gradient is not None:
                has_complex |= torch.is_complex(param)
                params_with_grad.append(param)
                grads.append(gradient)
                # Lazy state initialization
                if param not in self.state:
                    self.state[param] = {}
                    state = self.state[param]
                    state["step"] = torch.tensor(0.0)
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                    # Exponentially weighted infinity norm of gradient values
                    state["exp_inf"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )

                state = self.state[param]

                exp_avgs.append(state["exp_avg"])
                exp_infs.append(state["exp_inf"])
                state_steps.append(state["step"])

        with torch.no_grad():
            F.adamax(
                params_with_grad,
                grads,
                exp_avgs,
                exp_infs,
                state_steps,
                eps=self.defaults["eps"],
                beta1=self.defaults["beta1"],
                beta2=self.defaults["beta2"],
                lr=self.defaults["lr"],
                weight_decay=self.defaults["weight_decay"],
                foreach=self.foreach,
                maximize=self.maximize,
                has_complex=has_complex,
            )