xref: /aosp_15_r20/external/pytorch/torch/optim/adagrad.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Union

import torch
from torch import Tensor

from .optimizer import (
    _default_to_fused_or_foreach,
    _device_dtype_check_for_fused,
    _differentiable_doc,
    _foreach_doc,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["Adagrad", "adagrad"]


class Adagrad(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        lr_decay: float = 0,
        weight_decay: float = 0,
        initial_accumulator_value: float = 0,
        eps: float = 1e-10,
        foreach: Optional[bool] = None,
        *,
        maximize: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= lr_decay:
            raise ValueError(f"Invalid lr_decay value: {lr_decay}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= initial_accumulator_value:
            raise ValueError(
                f"Invalid initial_accumulator_value value: {initial_accumulator_value}"
            )
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")

        defaults = dict(
            lr=lr,
            lr_decay=lr_decay,
            eps=eps,
            weight_decay=weight_decay,
            initial_accumulator_value=initial_accumulator_value,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
            self._need_device_dtype_check_for_fused = True

        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["step"] = (
                    torch.zeros(
                        (),
                        dtype=_get_scalar_dtype(is_fused=group["fused"]),
                        device=p.device,
                    )
                    if group["fused"]
                    else torch.tensor(0.0, dtype=_get_scalar_dtype())
                )
                init_value = (
                    complex(initial_accumulator_value, initial_accumulator_value)
                    if torch.is_complex(p)
                    else initial_accumulator_value
                )
                state["sum"] = torch.full_like(
                    p, init_value, memory_format=torch.preserve_format
                )

    def __setstate__(self, state):
        super().__setstate__(state)
        # define "fused" up front so mypy does not flag: Name "fused" may be undefined
        fused = None
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            fused = group.setdefault("fused", None)

        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            for s in state_values:
                s["step"] = torch.tensor(
                    float(s["step"]), dtype=_get_scalar_dtype(is_fused=fused)
                )

    def share_memory(self):
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["sum"].share_memory_()

    def _init_group(self, group, params_with_grad, grads, state_sums, state_steps):
        has_sparse_grad, has_complex = False, False
        for p in group["params"]:
            if p.grad is not None:
                if group["fused"] and getattr(
                    self,
                    "_need_device_dtype_check_for_fused",
                    True,
                ):
                    _device_dtype_check_for_fused(p, cuda_unsupported=True)
                    self._need_device_dtype_check_for_fused = False
                has_sparse_grad |= p.grad.is_sparse
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                grads.append(p.grad)
                state = self.state[p]
                state_sums.append(state["sum"])
                state_steps.append(state["step"])

        return has_sparse_grad, has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
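        # A hedged sketch of the ``closure`` pattern the docstring refers to
        # (``optimizer``, ``model``, ``loss_fn``, ``input`` and ``target`` are
        # assumed to exist in the calling code):
        #
        #     def closure():
        #         optimizer.zero_grad()
        #         loss = loss_fn(model(input), target)
        #         loss.backward()
        #         return loss
        #
        #     loss = optimizer.step(closure)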
        loss = None

        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            state_sums: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_sparse_grad, has_complex = self._init_group(
                group, params_with_grad, grads, state_sums, state_steps
            )

            adagrad(
                params_with_grad,
                grads,
                state_sums,
                state_steps,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                lr_decay=group["lr_decay"],
                eps=group["eps"],
                has_sparse_grad=has_sparse_grad,
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                has_complex=has_complex,
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        return loss


Adagrad.__doc__ = (
    r"""Implements Adagrad algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)},                          \\
            &\hspace{12mm}    \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)}\\
            &\textbf{initialize} :  state\_sum_0 \leftarrow \tau                          \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm} \tilde{\gamma}    \leftarrow \gamma / (1 + (t-1) \eta)                 \\
            &\hspace{5mm} \textbf{if} \: \lambda \neq 0                                          \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                             \\
            &\hspace{5mm}state\_sum_t  \leftarrow  state\_sum_{t-1} + g^2_t                      \\
            &\hspace{5mm}\theta_t \leftarrow
                \theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon}            \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning
    and Stochastic Optimization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        lr_decay (float, optional): learning rate decay (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        initial_accumulator_value (float, optional): initial value of the
            sum of squares of gradients (default: 0)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-10)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        fused (bool, optional): whether the fused implementation (CPU only) is used.
            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
            are supported (default: None). Please note that the fused implementation does not
            support sparse or complex gradients.

    .. _Adaptive Subgradient Methods for Online Learning and Stochastic
        Optimization: http://jmlr.org/papers/v12/duchi11a.html

    """
)
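
# A minimal usage sketch (illustrative only; ``model``, ``loss_fn`` and ``data_loader``
# are assumed to be defined by the caller):
#
#     optimizer = Adagrad(model.parameters(), lr=1e-2, weight_decay=1e-4)
#     for input, target in data_loader:
#         optimizer.zero_grad()
#         loss = loss_fn(model(input), target)
#         loss.backward()
#         optimizer.step()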


def adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    # kwonly args with defaults are not supported by functions compiled with torchscript (issue #70627),
    # so these are left as ordinary (non-keyword-only) args with defaults for now, since the functional
    # API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.
    """
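    # A hedged example of calling this functional API directly under ``torch.no_grad()``
    # (illustrative placeholder tensors; normally the :class:`Adagrad` optimizer manages
    # this state and calls this function from ``step()``):
    #
    #     with torch.no_grad():
    #         p = torch.nn.Parameter(torch.ones(3))
    #         adagrad([p], [torch.ones(3)], [torch.zeros(3)], [torch.tensor(0.0)],
    #                 lr=0.1, weight_decay=0.0, lr_decay=0.0, eps=1e-10, maximize=False)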
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adagrad
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adagrad
    else:
        func = _single_tensor_adagrad

    func(
        params,
        grads,
        state_sums,
        state_steps,
        lr=lr,
        weight_decay=weight_decay,
        lr_decay=lr_decay,
        eps=eps,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )


def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    return torch.sparse_coo_tensor(grad_indices, values, size)


def _single_tensor_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
    has_complex: bool,
):
    assert grad_scale is None and found_inf is None
    for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps):
        # update step
        step_t += 1
        step = _get_value(step_t)
        grad = grad if not maximize else -grad

        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError(
                    "weight_decay option is not compatible with sparse gradients"
                )
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(
                _make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr
            )
        else:
            is_complex = torch.is_complex(param)
            if is_complex:
                grad = torch.view_as_real(grad)
                state_sum = torch.view_as_real(state_sum)
                param = torch.view_as_real(param)
            state_sum.addcmul_(grad, grad, value=1)
            if differentiable:
                std = state_sum.sqrt() + eps
            else:
                std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
            if is_complex:
                param = torch.view_as_complex(param)
                state_sum = torch.view_as_complex(state_sum)
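
# For reference, the dense (non-sparse, non-complex) branch above amounts to the
# following per-parameter update (a sketch of the math, not a public API):
#
#     state_sum += grad * grad
#     param -= (lr / (1 + (step - 1) * lr_decay)) * grad / (state_sum.sqrt() + eps)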


def _multi_tensor_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
    has_complex: bool,
):
    assert not differentiable, "_foreach ops don't support autograd"
    assert grad_scale is None and found_inf is None

    # Foreach functions will throw errors if given empty lists
    if len(params) == 0:
        return

    grouped_tensorlists = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, state_sums, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_state_sums_,
        device_state_steps_,
    ), _ in grouped_tensorlists.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_state_sums = cast(List[Tensor], device_state_sums_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        device_has_sparse_grad = has_sparse_grad and any(
            grad.is_sparse for grad in device_grads
        )

        if device_has_sparse_grad:
            _single_tensor_adagrad(
                device_params,
                device_grads,
                device_state_sums,
                device_state_steps,
                lr=lr,
                weight_decay=weight_decay,
                lr_decay=lr_decay,
                eps=eps,
                has_sparse_grad=True,
                maximize=maximize,
                differentiable=differentiable,
                has_complex=has_complex,
                grad_scale=grad_scale,
                found_inf=found_inf,
            )
            continue

        # Handle complex parameters
        if has_complex:
            _view_as_real(device_params, device_grads, device_state_sums)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        minus_clr = [
            -lr / (1 + (_get_value(step) - 1) * lr_decay) for step in device_state_steps
        ]

        torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1)

        std = torch._foreach_sqrt(device_state_sums)
        torch._foreach_add_(std, eps)

        if weight_decay != 0 or maximize:
            # Again, re-use the intermediate memory (device_grads) already allocated
            torch._foreach_mul_(device_grads, minus_clr)
            numerator = device_grads
        else:
            numerator = torch._foreach_mul(device_grads, minus_clr)  # type: ignore[assignment]

        torch._foreach_addcdiv_(device_params, numerator, std)


def _fused_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
    has_complex: bool,
) -> None:
    if not params:
        return
    if has_sparse_grad or has_complex:
        raise RuntimeError("`fused` does not support sparse grad or complex param")

    if differentiable:
        raise RuntimeError(
            "adagrad with fused=True does not support differentiable=True"
        )

    grad_scale_dict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else None
    )
    found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, state_sums, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            device_params_,
            device_grads_,
            device_state_sums_,
            device_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_state_sums = cast(List[Tensor], device_state_sums_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None and grad_scale_dict is not None:
            if device not in grad_scale_dict:
                grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)  # type: ignore[index]
            device_grad_scale = grad_scale_dict[device]  # type: ignore[index]
        if found_inf is not None and found_inf_dict is not None:
            # Key the cache by device, mirroring the grad_scale handling above.
            if device not in found_inf_dict:
                found_inf_dict[device] = found_inf.to(device, non_blocking=True)  # type: ignore[index]
            device_found_inf = found_inf_dict[device]  # type: ignore[index]
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adagrad_(
            device_params,
            device_grads,
            device_state_sums,
            device_state_steps,
            lr=lr,
            lr_decay=lr_decay,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )