xref: /aosp_15_r20/external/executorch/extension/training/pybindings/_training_lib.pyi (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9from __future__ import annotations
10
11from typing import Any, Dict, List, Optional, Sequence, Tuple
12
13from executorch.exir._warnings import experimental
14from torch import Tensor
15
@experimental("This API is experimental and subject to change without notice.")
class ExecuTorchSGD:
    """SGD optimizer.

    This is an interface declaration only; the method bodies are ``...`` and
    the implementation is not part of this file. ``get_sgd_optimizer`` is
    annotated to return instances of this class.

    .. warning::

        This API is experimental and subject to change without notice.
    """

    def step(self, named_gradients: Dict[str, Tensor]) -> None:
        """Take a step in the direction of the gradients.

        Args:
            named_gradients: Mapping from name to gradient tensor.
                NOTE(review): presumably keyed by the same names as the
                ``named_parameters`` given to ``get_sgd_optimizer`` --
                confirm against the binding.
        """
        ...
28
@experimental("This API is experimental and subject to change without notice.")
def get_sgd_optimizer(
    named_parameters: Dict[str, Tensor],
    lr: float,
    momentum: float = 0,
    dampening: float = 0,
    weight_decay: float = 0,
    nesterov: bool = False,
) -> ExecuTorchSGD:
    """Create an SGD optimizer over the given named parameters.

    The optimizer operates on the passed-in ``named_parameters`` according to
    the specified hyperparameters.

    .. warning::

        This API is experimental and subject to change without notice.

    Args:
        named_parameters: Mapping from parameter name to the tensor the
            optimizer operates on.
        lr: Learning rate. Required; it has no default.
        momentum: Momentum hyperparameter. Defaults to 0.
        dampening: Dampening hyperparameter. Defaults to 0.
        weight_decay: Weight-decay hyperparameter. Defaults to 0.
        nesterov: Nesterov hyperparameter (boolean). Defaults to False.

    Returns:
        The ``ExecuTorchSGD`` optimizer.

    NOTE(review): the hyperparameter names match ``torch.optim.SGD``;
    presumably they carry the same meaning here -- confirm against the
    binding's implementation.
    """
    ...
46