# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Tuple

import torch
import torch.optim._functional as F
from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible functional Rprop optimizer
# that we use in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# from the parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same `.grad`.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
@torch.jit.script
class _FunctionalRprop:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-2,
        etas: Tuple[float, float] = (0.5, 1.2),
        step_sizes: Tuple[float, float] = (1e-6, 50),
        foreach: bool = False,
        maximize: bool = False,
        _allow_empty_param_list: bool = False,
    ):
        self.defaults = {
            "lr": lr,
        }
        self.etas = etas
        self.step_sizes = step_sizes
        self.foreach = foreach
        self.maximize = maximize

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

        # NOTE: we only have one param_group and don't allow the user to add
        # additional param groups, as it's not a common use case.
        self.param_group = {"params": params}

        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group["params"]
        params_with_grad = []
        grads = []
        prevs = []
        step_sizes = []
        state_steps = []
        lr = self.defaults["lr"]
        etaminus, etaplus = self.etas
        step_size_min, step_size_max = self.step_sizes

        if len(params) != len(gradients):
            raise ValueError(
                "the number of gradients passed in does not equal the number of parameters! "
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        # Gather the parameters that actually received gradients together with
        # their per-parameter state, lazily initializing the state on first use.
        has_complex = False
        for param, gradient in zip(params, gradients):
            if gradient is not None:
                has_complex |= torch.is_complex(param)
                params_with_grad.append(param)
                grads.append(gradient)
                # Lazy state initialization
                if param not in self.state:
                    self.state[param] = {}
                    state = self.state[param]
                    state["step"] = torch.tensor(0.0)
                    state["prev"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                    state["step_size"] = torch.full_like(gradient, lr)

                state = self.state[param]
                prevs.append(state["prev"])
                step_sizes.append(state["step_size"])
                state_steps.append(state["step"])

        # Apply the Rprop update in place; autograd does not need to track it.
        with torch.no_grad():
            F.rprop(
                params_with_grad,
                grads,
                prevs,
                step_sizes,
                state_steps,
                step_size_min=step_size_min,
                step_size_max=step_size_max,
                etaminus=etaminus,
                etaplus=etaplus,
                foreach=self.foreach,
                maximize=self.maximize,
                has_complex=has_complex,
            )
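

# --- Usage sketch (illustration only; not part of the upstream module) ---
# The note above explains that gradients are handed to `step` explicitly
# instead of being read from `param.grad`. The guarded block below is a
# minimal sketch of that flow, assuming toy parameters and a simple quadratic
# loss; the `example_*` names are hypothetical and exist only for this sketch.
if __name__ == "__main__":
    # Two toy parameters; with the distributed optimizer these would be the
    # parameters owned by the local worker.
    example_params = [
        torch.randn(4, requires_grad=True),
        torch.randn(2, requires_grad=True),
    ]
    opt = _FunctionalRprop(example_params)

    # Compute gradients out of band (here via autograd.grad) so nothing is
    # accumulated into `.grad`, then hand them directly to `step`.
    example_loss = (example_params[0] ** 2).sum() + (example_params[1] ** 2).sum()
    example_grads: List[Optional[Tensor]] = [
        g for g in torch.autograd.grad(example_loss, example_params)
    ]
    opt.step(example_grads)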