# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Tuple

import torch
import torch.optim._functional as F
from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible Functional AdamW Optimizer
# where we use this optimizer in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# from the parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same `.grad`.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
@torch.jit.script
class _FunctionalAdamW:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
        maximize: bool = False,
        foreach: bool = False,
        fused: bool = False,
        _allow_empty_param_list: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.defaults = {
            "lr": lr,
            "eps": eps,
            "beta1": betas[0],
            "beta2": betas[1],
            "weight_decay": weight_decay,
        }
        self.amsgrad = amsgrad
        self.maximize = maximize
        self.foreach = foreach
        self.fused = fused
        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

        # NOTE: we only have one param_group and don't allow the user to add
        # additional param groups as it's not a common use case.
        self.param_group = {"params": params}

    def step_param(self, param: Tensor, grad: Optional[Tensor]):
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        state_steps: List[Tensor] = []
        has_complex = torch.is_complex(param)
        if grad is not None:
            params_with_grad.append(param)
            grads.append(grad)
        # Lazy state initialization
        if param not in self.state:
            self.state[param] = {}
            state = self.state[param]
            state["step"] = torch.tensor(0.0)
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(
                param, memory_format=torch.preserve_format
            )
            # Exponential moving average of squared gradient values
            state["exp_avg_sq"] = torch.zeros_like(
                param, memory_format=torch.preserve_format
            )
            if self.amsgrad:
                # Maintains max of all exp. moving avg. of sq. grad. values
                state["max_exp_avg_sq"] = torch.zeros_like(
                    param, memory_format=torch.preserve_format
                )

        state = self.state[param]

        exp_avgs.append(state["exp_avg"])
        exp_avg_sqs.append(state["exp_avg_sq"])

        if self.amsgrad:
            max_exp_avg_sqs.append(state["max_exp_avg_sq"])

        state_steps.append(state["step"])
        with torch.no_grad():
            F.adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=self.amsgrad,
                maximize=self.maximize,
                beta1=self.defaults["beta1"],
                beta2=self.defaults["beta2"],
                lr=self.defaults["lr"],
                weight_decay=self.defaults["weight_decay"],
                eps=self.defaults["eps"],
                foreach=self.foreach,
                fused=self.fused,
                grad_scale=None,
                found_inf=None,
                has_complex=has_complex,
            )

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group["params"]
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        state_steps: List[Tensor] = []

        if len(params) != len(gradients):
            raise ValueError(
                "the number of gradients passed in does not equal the number of parameters! "
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        has_complex = False
        for param, gradient in zip(self.param_group["params"], gradients):
            if gradient is not None:
                has_complex |= torch.is_complex(param)
                params_with_grad.append(param)
                grads.append(gradient)
                # Lazy state initialization
                if param not in self.state:
                    self.state[param] = {}
                    state = self.state[param]
                    state["step"] = torch.tensor(0.0)
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                    if self.amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            param, memory_format=torch.preserve_format
                        )

                state = self.state[param]

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])

                if self.amsgrad:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

                state_steps.append(state["step"])

        with torch.no_grad():
            F.adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=self.amsgrad,
                maximize=self.maximize,
                beta1=self.defaults["beta1"],
                beta2=self.defaults["beta2"],
                lr=self.defaults["lr"],
                weight_decay=self.defaults["weight_decay"],
                eps=self.defaults["eps"],
                foreach=self.foreach,
                fused=self.fused,
                grad_scale=None,
                found_inf=None,
                has_complex=has_complex,
            )
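

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the optimizer itself): a minimal example of
# how distributed optimizer internals might drive `_FunctionalAdamW`, passing
# gradients explicitly instead of reading `param.grad`. The tensors, shapes,
# and gradient values below are made up for illustration; the only API assumed
# is what is defined above (`step` and `step_param`). Guarded so it never runs
# on import.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Two toy parameters that a trainer would be responsible for updating.
    example_params = [torch.randn(4, 4, requires_grad=True) for _ in range(2)]
    opt = _FunctionalAdamW(example_params, 1e-3)

    # Gradients are produced elsewhere (e.g. by a distributed autograd pass)
    # and handed to the optimizer explicitly.
    example_grads: List[Optional[Tensor]] = [torch.randn(4, 4), torch.randn(4, 4)]

    # Update every parameter in one call; the i-th gradient pairs with the
    # i-th parameter, and `None` entries are skipped.
    opt.step(example_grads)

    # Or update a single parameter as soon as its gradient becomes ready,
    # e.g. to overlap gradient communication with the optimizer step.
    opt.step_param(example_params[0], example_grads[0])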