# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Tuple

import torch
import torch.optim._functional as F
from torch import Tensor


__all__: List[str] = []


# Define a TorchScript-compatible functional AdamW optimizer that is used in
# a functional way: instead of reading `param.grad` when updating parameters,
# we explicitly let the distributed optimizer pass gradients to the `step`
# function. This separates the gradients from the parameters and lets a
# multithreaded trainer update the parameters without data races when
# accumulating into the same `.grad`. (A minimal usage sketch is included at
# the bottom of this file.)
# NOTE: This should only be used by distributed optimizer internals and is
# not meant to be exposed to the user.
@torch.jit.script
class _FunctionalAdamW:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
        maximize: bool = False,
        foreach: bool = False,
        fused: bool = False,
        _allow_empty_param_list: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.defaults = {
            "lr": lr,
            "eps": eps,
            "beta1": betas[0],
            "beta2": betas[1],
            "weight_decay": weight_decay,
        }
        self.amsgrad = amsgrad
        self.maximize = maximize
        self.foreach = foreach
        self.fused = fused
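        # Per-parameter optimizer state ("step", "exp_avg", "exp_avg_sq", and,
        # when amsgrad is enabled, "max_exp_avg_sq"), keyed by the parameter
        # tensor itself.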
        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

        # NOTE: we only have one param_group and don't allow the user to add
        # additional param groups, as that is not a common use case.
        self.param_group = {"params": params}

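    # Apply one AdamW update to a single parameter, using the gradient handed
    # in by the caller instead of `param.grad`. State for the parameter is
    # created lazily on first use.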
    def step_param(self, param: Tensor, grad: Optional[Tensor]):
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        state_steps: List[Tensor] = []
        has_complex = torch.is_complex(param)
        if grad is not None:
            params_with_grad.append(param)
            grads.append(grad)
        # Lazy state initialization
        if param not in self.state:
            self.state[param] = {}
            state = self.state[param]
            state["step"] = torch.tensor(0.0)
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(
                param, memory_format=torch.preserve_format
            )
            # Exponential moving average of squared gradient values
            state["exp_avg_sq"] = torch.zeros_like(
                param, memory_format=torch.preserve_format
            )
            if self.amsgrad:
                # Maintains max of all exp. moving avg. of sq. grad. values
                state["max_exp_avg_sq"] = torch.zeros_like(
                    param, memory_format=torch.preserve_format
                )

        state = self.state[param]

        exp_avgs.append(state["exp_avg"])
        exp_avg_sqs.append(state["exp_avg_sq"])

        if self.amsgrad:
            max_exp_avg_sqs.append(state["max_exp_avg_sq"])

        state_steps.append(state["step"])
        with torch.no_grad():
            F.adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=self.amsgrad,
                maximize=self.maximize,
                beta1=self.defaults["beta1"],
                beta2=self.defaults["beta2"],
                lr=self.defaults["lr"],
                weight_decay=self.defaults["weight_decay"],
                eps=self.defaults["eps"],
                foreach=self.foreach,
                fused=self.fused,
                grad_scale=None,
                found_inf=None,
                has_complex=has_complex,
            )

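    # Apply an AdamW update to every parameter in the param group. `gradients`
    # must be aligned with `self.param_group["params"]`; parameters whose
    # gradient is None are skipped.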
    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group["params"]
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        state_steps: List[Tensor] = []

        if len(params) != len(gradients):
            raise ValueError(
                "the number of gradients passed in does not match the number of parameters! "
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        has_complex = False
        for param, gradient in zip(self.param_group["params"], gradients):
            if gradient is not None:
                has_complex |= torch.is_complex(param)
                params_with_grad.append(param)
                grads.append(gradient)
                # Lazy state initialization
                if param not in self.state:
                    self.state[param] = {}
                    state = self.state[param]
                    state["step"] = torch.tensor(0.0)
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                    if self.amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            param, memory_format=torch.preserve_format
                        )

                state = self.state[param]

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])

                if self.amsgrad:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

                state_steps.append(state["step"])

        with torch.no_grad():
            F.adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=self.amsgrad,
                maximize=self.maximize,
                beta1=self.defaults["beta1"],
                beta2=self.defaults["beta2"],
                lr=self.defaults["lr"],
                weight_decay=self.defaults["weight_decay"],
                eps=self.defaults["eps"],
                foreach=self.foreach,
                fused=self.fused,
                grad_scale=None,
                found_inf=None,
                has_complex=has_complex,
            )

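
# A minimal standalone sketch of the functional pattern described at the top of
# this file: the caller hands gradients to the optimizer explicitly instead of
# the optimizer reading them from `param.grad`. The tensors below are
# illustrative placeholders only.
if __name__ == "__main__":
    param = torch.randn(4, requires_grad=True)
    grad = torch.randn(4)  # gradient supplied by the caller, not read from param.grad
    opt = _FunctionalAdamW([param])
    opt.step_param(param, grad)  # update a single parameter as its gradient arrives
    opt.step([grad])  # or update every parameter in the (single) param group at once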