# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["ASGD", "asgd"]


class ASGD(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        lambd: float = 1e-4,
        alpha: float = 0.75,
        t0: float = 1e6,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
        capturable: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            lambd=lambd,
            alpha=alpha,
            t0=t0,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
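            # State loaded from an older checkpoint may store step/eta/mu as plain
            # Python scalars; convert them to tensors here so the rest of the code
            # can assume tensor-valued state.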
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0:
                    if not torch.is_tensor(p_state["step"]):
                        step_val = float(p_state["step"])
                        p_state["step"] = torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                    if not torch.is_tensor(p_state["eta"]):
                        p_state["eta"] = torch.tensor(
                            p_state["eta"], dtype=_get_scalar_dtype(), device=p.device
                        )
                    if not torch.is_tensor(p_state["mu"]):
                        p_state["mu"] = torch.tensor(
                            p_state["mu"], dtype=_get_scalar_dtype(), device=p.device
                        )

    def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("ASGD does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = torch.zeros(
                        (), device=p.device, dtype=_get_scalar_dtype()
                    )
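                    # eta starts at the group's lr (cloned so a Tensor lr is not
                    # aliased), mu starts at 1, and ax holds the running average
                    # of the parameter.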
                    state["eta"] = (
                        torch.as_tensor(
                            group["lr"], device=p.device, dtype=_get_scalar_dtype()
                        )
                        .clone()
                        .detach()
                    )
                    state["mu"] = torch.ones(
                        (), device=p.device, dtype=_get_scalar_dtype()
                    )
                    state["ax"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                mus.append(state["mu"])
                axs.append(state["ax"])
                etas.append(state["eta"])
                state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            mus: List[Tensor] = []
            axs: List[Tensor] = []
            etas: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_complex = self._init_group(
                group, params_with_grad, grads, mus, axs, etas, state_steps
            )

            asgd(
                params_with_grad,
                grads,
                axs,
                mus,
                etas,
                state_steps,
                lambd=group["lambd"],
                lr=group["lr"],
                t0=group["t0"],
                alpha=group["alpha"],
                weight_decay=group["weight_decay"],
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                capturable=group["capturable"],
                has_complex=has_complex,
            )

        return loss


ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.

    It has been proposed in `Acceleration of stochastic approximation by
    averaging`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        lambd (float, optional): decay term (default: 1e-4)
        alpha (float, optional): power for eta update (default: 0.75)
        t0 (float, optional): point at which to start averaging (default: 1e6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        {_capturable_doc}

    .. _Acceleration of stochastic approximation by averaging:
        https://dl.acm.org/citation.cfm?id=131098

    """
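
# A minimal usage sketch (``model``, ``criterion`` and ``data_loader`` below are
# illustrative names, not part of this module):
#
#     optimizer = ASGD(model.parameters(), lr=1e-2)
#     for input, target in data_loader:
#         optimizer.zero_grad()
#         loss = criterion(model(input), target)
#         loss.backward()
#         optimizer.step()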


def _single_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        mu = mus[i]
        ax = axs[i]
        eta = etas[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type
                == mu.device.type
                == eta.device.type
                == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), (
                f"If capturable=True, params, mus, etas, and state_steps must be "
                f"on supported devices: {capturable_supported_devices}."
            )

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            param = torch.view_as_real(param)
            ax = torch.view_as_real(ax)

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

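        # in-place decayed update: param <- param * (1 - lambd * eta) - eta * grad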
        if capturable:
            param.mul_(1 - lambd * eta)
            param.addcmul_(grad, eta, value=-1)  # update parameter
        else:
            eta_value = _get_value(eta)
            param.mul_(1 - lambd * eta_value)  # decay term
            param.add_(grad, alpha=-eta_value)  # update parameter

        # averaging
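        # mu stays at 1 while the step count is still below roughly t0, so ax simply
        # tracks the parameter; once averaging kicks in, ax += mu * (param - ax)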
        if capturable or mu.item() != 1:
            ax.add_(param.sub(ax).mul_(mu))
        else:
            ax.copy_(param)

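        # update eta and mu for the next step:
        #   eta_t = lr / (1 + lambd * lr * t) ** alpha
        #   mu_t  = 1 / max(1, t - t0)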
        if capturable:
            eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
            mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
        else:
            step = _get_value(step_t)
            new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
            eta.copy_(new_eta)
            new_mu = torch.as_tensor(1 / max(1, step - t0))
            mu.copy_(new_mu)


def _multi_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == mu.device.type == eta.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, mu, eta, step in zip(params, mus, etas, state_steps)
        ), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, axs, mus, etas, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            grouped_params_,
            grouped_grads_,
            grouped_axs_,
            grouped_mus_,
            grouped_etas_,
            grouped_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_axs = cast(List[Tensor], grouped_axs_)
        grouped_mus = cast(List[Tensor], grouped_mus_)
        grouped_etas = cast(List[Tensor], grouped_etas_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        if has_complex:
            _view_as_real(grouped_params, grouped_grads, grouped_axs)

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        # intermediate = grad + param * lambd
        intermediate: Union[Tuple[Tensor, ...], List[Tensor]]
        if weight_decay != 0:
            if maximize:
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
                intermediate = grouped_grads
            else:
                intermediate = torch._foreach_add(
                    grouped_grads, grouped_params, alpha=weight_decay
                )

            torch._foreach_add_(intermediate, grouped_params, alpha=lambd)
        else:
            intermediate = torch._foreach_add(
                grouped_grads, grouped_params, alpha=lambd
            )

        # update param
        # param * (1 - lambd * eta) - eta * grad
        # => param - param * lambd * eta - eta * grad
        # => param - eta * intermediate
        torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1)
        del intermediate

        # update grouped_axs
        # averaging: ax = ax + mu * (param - ax)
        # Note (mlazos): We can't use lerp here since it requires weight to be float64
        # and our grouping code requires dtypes to match for all tensors in a group (and it should, since
        # we use the mus in other places)
        # all dtypes need to match, so we could introduce a cast in a loop
        # but since this only adds one additional kernel launch, this looks like the cleaner
        # and faster solution
        intermediate = torch._foreach_sub(grouped_params, grouped_axs)
        torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)
        del intermediate

        new_etas: Union[Tuple[Tensor, ...], List[Tensor]]
        new_mus: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # update grouped_mus
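            # mu = 1 / max(step - t0, 1), computed out of place and then copied
            # into the state tensors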
            new_mus = torch._foreach_sub(grouped_state_steps, t0)
            torch._foreach_maximum_(new_mus, 1.0)
            torch._foreach_reciprocal_(new_mus)
            torch._foreach_copy_(grouped_mus, new_mus)
            del new_mus

            # update eta = lr / ((1 + lambd * lr * step)^alpha)
            new_etas = torch._foreach_mul(grouped_state_steps, lambd)
            torch._foreach_mul_(new_etas, lr)
            torch._foreach_add_(new_etas, 1)
            torch._foreach_pow_(new_etas, alpha)
            torch._foreach_reciprocal_(new_etas)
            torch._foreach_mul_(new_etas, lr)
            torch._foreach_copy_(grouped_etas, new_etas)
        else:
            new_etas = [
                torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
                for step in grouped_state_steps
            ]
            new_mus = [
                torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device)
                for step in grouped_state_steps
            ]
            torch._foreach_copy_(grouped_etas, new_etas)
            torch._foreach_copy_(grouped_mus, new_mus)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
def asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
):
    r"""Functional API that performs asgd algorithm computation.

    See :class:`~torch.optim.ASGD` for details.
    """
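    # When ``foreach`` is unspecified, _default_to_fused_or_foreach decides whether
    # the multi-tensor (foreach) implementation can be used; the single-tensor loop
    # remains the fallback.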
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_asgd
    else:
        func = _single_tensor_asgd

    func(
        params,
        grads,
        axs,
        mus,
        etas,
        state_steps,
        lambd=lambd,
        lr=lr,
        t0=t0,
        alpha=alpha,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )