# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple

from executorch.exir._warnings import experimental
from torch import Tensor


@experimental("This API is experimental and subject to change without notice.")
class ExecuTorchSGD:
    """SGD Optimizer.

    Instances of this class are returned by :func:`get_sgd_optimizer`.

    .. warning::

        This API is experimental and subject to change without notice.
    """

    def step(self, named_gradients: Dict[str, Tensor]) -> None:
        """Apply one optimization step using the given gradients.

        Args:
            named_gradients: Gradient tensors keyed by name. Presumably the
                keys are expected to match the keys of the
                ``named_parameters`` passed to :func:`get_sgd_optimizer`
                (TODO: confirm against the native implementation).
        """
        ...


@experimental("This API is experimental and subject to change without notice.")
def get_sgd_optimizer(
    named_parameters: Dict[str, Tensor],
    lr: float,
    momentum: float = 0,
    dampening: float = 0,
    weight_decay: float = 0,
    nesterov: bool = False,
) -> ExecuTorchSGD:
    """Creates an SGD optimizer that operates on the passed in
    named_parameters according to the specified hyperparameters.

    .. warning::

        This API is experimental and subject to change without notice.

    The hyperparameter names, and the defaults of the optional ones, match
    ``torch.optim.SGD``. Presumably they carry the same semantics
    (TODO: confirm against the native implementation).

    Args:
        named_parameters: Parameter tensors keyed by name; the returned
            optimizer operates on these tensors.
        lr: Learning rate.
        momentum: Momentum factor. Defaults to 0.
        dampening: Dampening applied to momentum. Defaults to 0.
        weight_decay: Weight decay coefficient. Defaults to 0.
        nesterov: Whether to use Nesterov momentum. Defaults to False.

    Returns:
        An :class:`ExecuTorchSGD` that operates on ``named_parameters``.
    """
    ...