# mypy: allow-untyped-defs
from typing import Dict, List, Optional

import torch
import torch.optim._functional as F
from torch import Tensor


__all__: List[str] = []


# Define a TorchScript-compatible functional SGD optimizer
# that we use in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# from the parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same `.grad`.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
@torch.jit.script
class _FunctionalSGD:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-2,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        nesterov: bool = False,
        maximize: bool = False,
        foreach: bool = False,
        fused: bool = False,
        _allow_empty_param_list: bool = False,
    ):
        self.defaults = {
            "lr": lr,
            "momentum": momentum,
            "dampening": dampening,
            "weight_decay": weight_decay,
        }
        self.nesterov = nesterov
        self.maximize = maximize
        self.foreach = foreach
        self.fused = fused
        # Per-parameter optimizer state (e.g. the momentum buffer), keyed by the
        # parameter tensor itself.
        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

        # NOTE: we only have one param_group and don't allow the user to add
        # additional param groups, as it's not a common use case.
        self.param_group = {"params": params}

    def step_param(self, param: Tensor, grad: Optional[Tensor]):
        """Similar to self.step, but operates on a single parameter and
        its gradient.
        """
        # TODO: Once the step_param interface is robust, refactor step to call
        # step_param on each param.
        weight_decay = self.defaults["weight_decay"]
        momentum = self.defaults["momentum"]
        dampening = self.defaults["dampening"]
        lr = self.defaults["lr"]
        params = [param]
        momentum_buffer_list: List[Optional[Tensor]] = []
        grads = []

        has_sparse_grad = False
        if grad is not None:
            grads.append(grad)
            if grad.is_sparse:
                has_sparse_grad = True
            if param not in self.state:
                self.state[param] = {}
            state = self.state[param]
            if "momentum_buffer" not in state:
                momentum_buffer_list.append(None)
            else:
                momentum_buffer_list.append(state["momentum_buffer"])

        with torch.no_grad():
            F.sgd(
                params,
                grads,
                momentum_buffer_list,
                weight_decay=weight_decay,
                momentum=momentum,
                lr=lr,
                dampening=dampening,
                nesterov=self.nesterov,
                maximize=self.maximize,
                has_sparse_grad=has_sparse_grad,
                foreach=self.foreach,
                fused=self.fused,
                grad_scale=None,
                found_inf=None,
            )
        # update momentum_buffer in state
        state = self.state[param]
        momentum_buffer = momentum_buffer_list[0]
        if momentum_buffer is not None:
            state["momentum_buffer"] = momentum_buffer

    def step(self, gradients: List[Optional[Tensor]]):
        """Perform a single optimization step on all parameters in the
        param_group, using the explicitly supplied gradients.
        """
        params = self.param_group["params"]
        params_with_grad = []
        grads = []
        momentum_buffer_list: List[Optional[Tensor]] = []
        lr = self.defaults["lr"]
        weight_decay = self.defaults["weight_decay"]
        momentum = self.defaults["momentum"]
        dampening = self.defaults["dampening"]

        if len(params) != len(gradients):
            raise ValueError(
                "the number of gradients passed in does not equal the number of parameters! "
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        has_sparse_grad = False
        for param, gradient in zip(params, gradients):
            if gradient is not None:
                params_with_grad.append(param)
                grads.append(gradient)
                if gradient.is_sparse:
                    has_sparse_grad = True

                if param not in self.state:
                    self.state[param] = {}

                state = self.state[param]
                if "momentum_buffer" not in state:
                    momentum_buffer_list.append(None)
                else:
                    momentum_buffer_list.append(state["momentum_buffer"])

        with torch.no_grad():
            F.sgd(
                params_with_grad,
                grads,
                momentum_buffer_list,
                weight_decay=weight_decay,
                momentum=momentum,
                lr=lr,
                dampening=dampening,
                nesterov=self.nesterov,
                maximize=self.maximize,
                has_sparse_grad=has_sparse_grad,
                foreach=self.foreach,
                fused=self.fused,
                grad_scale=None,
                found_inf=None,
            )

        # update momentum_buffers in state
        for i, p in enumerate(params_with_grad):
            state = self.state[p]
            momentum_buffer = momentum_buffer_list[i]
            if momentum_buffer is not None:
                state["momentum_buffer"] = momentum_buffer
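

# Illustrative sketch (not part of the module's API, and this class is not meant
# to be used directly by end users): the distributed optimizer is expected to
# drive this class roughly as follows, constructing it with a flat list of
# parameter tensors and passing gradients explicitly instead of relying on
# `param.grad`. The variable names, tensor shapes, and hyperparameter values
# below are made up purely for illustration.
#
#     params = [torch.randn(4, requires_grad=True), torch.randn(4, requires_grad=True)]
#     optim = _FunctionalSGD(params, lr=0.1, momentum=0.9)
#
#     # Batched update: one (possibly None) gradient per parameter, in order.
#     grads: List[Optional[Tensor]] = [torch.ones(4), None]
#     optim.step(grads)  # updates params[0]; params[1] is skipped since its grad is None
#
#     # Single-parameter variant: update one parameter with its gradient.
#     optim.step_param(params[1], torch.ones(4))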