# mypy: allow-untyped-defs
from typing import Dict, List, Optional

import torch
import torch.optim._functional as F
from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible Functional RMSprop Optimizer
# that is used in a functional way.
# Instead of reading `param.grad` when updating parameters,
# we explicitly let the distributed optimizer pass gradients to
# the `step` function. This way, gradients and parameters stay
# separate, and multithreaded trainers can update the parameters
# without data races from accumulating into the same `.grad`.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
@torch.jit.script
class _FunctionalRMSprop:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-2,
        alpha: float = 0.99,
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        momentum: float = 0.0,
        centered: bool = False,
        foreach: bool = False,
        maximize: bool = False,
        _allow_empty_param_list: bool = False,
    ):
        self.defaults = {
            "lr": lr,
            "alpha": alpha,
            "eps": eps,
            "weight_decay": weight_decay,
            "momentum": momentum,
        }
        self.centered = centered
        self.foreach = foreach
        self.maximize = maximize

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

        # NOTE: we only have one param_group and don't allow the user to add
        # additional param groups, as that is not a common use case.
        self.param_group = {"params": params}

        self.state = torch.jit.annotate(
            Dict[torch.Tensor, Dict[str, torch.Tensor]], {}
        )

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group["params"]
        params_with_grad = []
        grads = []
        square_avgs = []
        grad_avgs = []
        momentum_buffer_list = []
        state_steps = []
        lr = self.defaults["lr"]
        alpha = self.defaults["alpha"]
        eps = self.defaults["eps"]
        momentum = self.defaults["momentum"]
        weight_decay = self.defaults["weight_decay"]

        if len(params) != len(gradients):
            raise ValueError(
                "the number of gradients passed in does not equal the number of parameters! "
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        has_complex = False
        for param, gradient in zip(params, gradients):
            if gradient is not None:
                has_complex |= torch.is_complex(param)
                params_with_grad.append(param)
                grads.append(gradient)
                # Lazy state initialization
                if param not in self.state:
                    self.state[param] = {}
                    state = self.state[param]
                    state["step"] = torch.tensor(0.0)
                    state["square_avg"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                    if momentum > 0:
                        state["momentum_buffer"] = torch.zeros_like(
                            param, memory_format=torch.preserve_format
                        )
                    if self.centered:
                        state["grad_avg"] = torch.zeros_like(
                            param, memory_format=torch.preserve_format
                        )

                state = self.state[param]
                square_avgs.append(state["square_avg"])
                if momentum > 0:
                    momentum_buffer_list.append(state["momentum_buffer"])
                if self.centered:
                    grad_avgs.append(state["grad_avg"])

                state_steps.append(state["step"])

        # Apply the RMSprop update to all collected parameters in one functional call.
        with torch.no_grad():
            F.rmsprop(
                params_with_grad,
                grads,
                square_avgs,
                grad_avgs,
                momentum_buffer_list,
                state_steps,
                lr=lr,
                alpha=alpha,
                eps=eps,
                weight_decay=weight_decay,
                momentum=momentum,
                centered=self.centered,
                foreach=self.foreach,
                maximize=self.maximize,
                has_complex=has_complex,
            )