xref: /aosp_15_r20/external/pytorch/torch/nn/parameter.pyi (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing_extensions import TypeGuard
3
4from torch import device, dtype, Tensor
5
# Stub for ``Parameter``: a ``Tensor`` subclass whose constructor accepts an
# existing ``Tensor`` (``data``) and a ``requires_grad`` flag. Both arguments
# are optional; their runtime default values are elided (``...``) here.
class Parameter(Tensor):
    def __init__(
        self,
        data: Tensor = ...,
        requires_grad: bool = ...,
    ) -> None: ...
8
# Stub for ``is_lazy``. The ``TypeGuard`` return type (PEP 647) tells a
# static type checker that, wherever this call returns True, ``param`` may be
# treated as ``UninitializedParameter | UninitializedBuffer``.
# Both names are forward references to classes declared further down in this
# stub, which a ``.pyi`` file permits.
def is_lazy(
    param: Tensor,
) -> TypeGuard[UninitializedParameter | UninitializedBuffer]: ...
12
13class UninitializedParameter(Tensor):
14    def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...
15    def materialize(
16        self,
17        shape: tuple[int, ...],
18        device: device | None = None,
19        dtype: dtype | None = None,
20    ) -> None: ...
21
# Stub for ``Buffer``: a ``Tensor`` subclass that declares a ``persistent``
# ``bool`` attribute and also accepts ``persistent`` in its constructor,
# alongside optional ``data`` and ``requires_grad`` arguments. Runtime default
# values are elided (``...``) here.
# NOTE(review): presumably ``persistent`` controls whether the buffer is
# saved with its owning module's state -- confirm against the runtime code.
class Buffer(Tensor):
    persistent: bool
    # ``-> None`` follows the PEP 484 convention for constructors and matches
    # the ``Parameter``/``UninitializedParameter`` stubs in this file.
    def __init__(
        self,
        data: Tensor = ...,
        requires_grad: bool = ...,
        persistent: bool = ...,
    ) -> None: ...
30
31class UninitializedBuffer(Tensor):
32    persistent: bool
33    def __init__(
34        self,
35        data: Tensor = ...,
36        requires_grad: bool = ...,
37        persistent: bool = ...,
38    ): ...
39    def materialize(
40        self,
41        shape: tuple[int, ...],
42        device: device | None = None,
43        dtype: dtype | None = None,
44    ) -> None: ...
45