# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Tuple

import torch
import torch.optim._functional as F
from torch import Tensor


__all__: List[str] = []


# Define a TorchScript-compatible Functional Adam Optimizer that is
# used in a functional way.
# Instead of reading `param.grad` when updating parameters,
# we explicitly let the distributed optimizer pass gradients to
# the `step` function. This separates gradients from parameters,
# so a multithreaded trainer can update the parameters without
# data races from accumulating into the same `.grad`.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
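#
# Illustrative sketch only (kept in comments so nothing runs at import time;
# `gradients_from_trainer` below is a hypothetical gradient source, not part
# of this module): a distributed optimizer hands each incoming gradient
# directly to `step_param` instead of writing it into `param.grad`:
#
#   params = [torch.randn(2, 2) for _ in range(3)]
#   optim = _FunctionalAdam(params, lr=1e-3)
#   for param, grad in zip(params, gradients_from_trainer):
#       optim.step_param(param, grad)
#
# Equivalently, `optim.step(gradients)` applies one update to all parameters
# at once, with `gradients[i]` aligned to `params[i]` and `None` to skip one.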
@torch.jit.script
class _FunctionalAdam:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        amsgrad: bool = False,
        maximize: bool = False,
        foreach: bool = False,
        fused: bool = False,
        _allow_empty_param_list: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.defaults = {
            "lr": lr,
            "eps": eps,
            "beta1": betas[0],
            "beta2": betas[1],
            "weight_decay": weight_decay,
        }
        self.amsgrad = amsgrad
        self.maximize = maximize
        self.foreach = foreach
        self.fused = fused
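        # Per-parameter optimizer state, keyed by the parameter tensor itself;
        # the explicit annotation lets TorchScript type the initially empty dict.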
        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

        # NOTE: we only have one param_group and don't allow the user to add
        # additional param groups, as that is not a common use case.
        self.param_group = {"params": params}

    def step_param(self, param: Tensor, grad: Optional[Tensor]):
        """
        Similar to step, but operates on a single parameter and optionally a
        gradient tensor.
        """
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        state_steps: List[Tensor] = []
        has_complex = torch.is_complex(param)
        if grad is not None:
            params_with_grad.append(param)
            grads.append(grad)
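        # Lazy state initialization: create the step counter and moment buffers
        # the first time this parameter is seen (mirrors the `step` path below).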
        if param not in self.state:
            self.state[param] = {}
            state = self.state[param]
            state["step"] = torch.tensor(0.0)
            state["exp_avg"] = torch.zeros_like(
                param, memory_format=torch.preserve_format
            )
            state["exp_avg_sq"] = torch.zeros_like(
                param, memory_format=torch.preserve_format
            )
            if self.amsgrad:
                state["max_exp_avg_sq"] = torch.zeros_like(
                    param, memory_format=torch.preserve_format
                )

        state = self.state[param]
        exp_avgs.append(state["exp_avg"])
        exp_avg_sqs.append(state["exp_avg_sq"])

        if self.amsgrad:
            max_exp_avg_sqs.append(state["max_exp_avg_sq"])

        state_steps.append(state["step"])
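        # Delegate the actual update to the functional Adam kernel; the
        # single-element lists keep the call shape identical to the batched
        # `step` path.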
        with torch.no_grad():
            F.adam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=self.amsgrad,
                has_complex=has_complex,
                maximize=self.maximize,
                beta1=self.defaults["beta1"],
                beta2=self.defaults["beta2"],
                lr=self.defaults["lr"],
                weight_decay=self.defaults["weight_decay"],
                eps=self.defaults["eps"],
                foreach=self.foreach,
                fused=self.fused,
                grad_scale=None,
                found_inf=None,
            )

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group["params"]
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        state_steps: List[Tensor] = []
        has_complex = False

        if len(params) != len(gradients):
            raise ValueError(
                "the number of gradients passed in does not match the number of parameters! "
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        for param, gradient in zip(self.param_group["params"], gradients):
            if gradient is not None:
                has_complex |= torch.is_complex(param)
                params_with_grad.append(param)
                grads.append(gradient)
                # Lazy state initialization
                if param not in self.state:
                    self.state[param] = {}
                    state = self.state[param]
                    state["step"] = torch.tensor(0.0)
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                    if self.amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            param, memory_format=torch.preserve_format
                        )

                state = self.state[param]

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])

                if self.amsgrad:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

                state_steps.append(state["step"])

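        # One batched functional Adam call updates every parameter that received
        # a gradient; gradients come from the caller, never from `param.grad`.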
        with torch.no_grad():
            F.adam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=self.amsgrad,
                has_complex=has_complex,
                maximize=self.maximize,
                beta1=self.defaults["beta1"],
                beta2=self.defaults["beta2"],
                lr=self.defaults["lr"],
                weight_decay=self.defaults["weight_decay"],
                eps=self.defaults["eps"],
                foreach=self.foreach,
                fused=self.fused,
                grad_scale=None,
                found_inf=None,
            )