xref: /aosp_15_r20/external/pytorch/torch/nn/common_types.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Optional, Tuple, TypeVar, Union
2
3from torch import Tensor
4
5
# Create some useful type aliases

# Templates for arguments which can be supplied either as a tuple, or as a single
# scalar which PyTorch will internally broadcast to a tuple.
# Comes in several variants: a tuple of unknown length (``_any``), and
# fixed-length tuples of 1 up to 6 elements (``_1`` ... ``_6``).
# Each template is generic in ``T``; subscript it to obtain a concrete alias,
# e.g. ``_scalar_or_tuple_2_t[int]`` is ``Union[int, Tuple[int, int]]``.
T = TypeVar("T")
_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]]
_scalar_or_tuple_1_t = Union[T, Tuple[T]]
_scalar_or_tuple_2_t = Union[T, Tuple[T, T]]
_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]]
_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]]
_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]]
_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]]
19
# For arguments which represent size parameters (e.g., kernel size, padding).
# Each alias accepts either a single ``int`` or a tuple of ``int``s of the
# length given in its name; e.g. ``_size_2_t`` accepts both ``3`` and ``(3, 5)``.
_size_any_t = _scalar_or_tuple_any_t[int]
_size_1_t = _scalar_or_tuple_1_t[int]
_size_2_t = _scalar_or_tuple_2_t[int]
_size_3_t = _scalar_or_tuple_3_t[int]
_size_4_t = _scalar_or_tuple_4_t[int]
_size_5_t = _scalar_or_tuple_5_t[int]
_size_6_t = _scalar_or_tuple_6_t[int]
28
# For arguments which represent optional size parameters (e.g., adaptive pool
# parameters). Like the ``_size_*`` aliases, but ``None`` is accepted both as
# the scalar and as any tuple element; e.g. ``_size_2_opt_t`` accepts ``7``,
# ``None``, ``(7, 7)`` and ``(None, 7)``.
_size_any_opt_t = _scalar_or_tuple_any_t[Optional[int]]
_size_2_opt_t = _scalar_or_tuple_2_t[Optional[int]]
_size_3_opt_t = _scalar_or_tuple_3_t[Optional[int]]
33
# For arguments that represent a ratio to adjust each dimension of an input
# with (e.g., upsampling parameters). Each alias accepts a single ``float`` or a
# tuple of ``float``s; declared in the same order (any, 2, 3) as the other
# alias groups in this module.
_ratio_any_t = _scalar_or_tuple_any_t[float]
_ratio_2_t = _scalar_or_tuple_2_t[float]
_ratio_3_t = _scalar_or_tuple_3_t[float]
38
# A single Tensor, or a tuple of Tensors of any length
# (i.e. ``Union[Tensor, Tuple[Tensor, ...]]``).
# NOTE(review): despite the name, a ``List[Tensor]`` does not match this alias;
# only a bare Tensor or a tuple of Tensors does.
_tensor_list_t = _scalar_or_tuple_any_t[Tensor]
40
# For the return value of max pooling operations that may or may not return
# indices: either a single Tensor, or a 2-tuple of Tensors (presumably
# ``(output, indices)`` -- confirm against the pooling functions).
# ``typing.Literal`` (standard since Python 3.8, PEP 586) combined with
# ``@overload`` on the flag that requests indices could eventually replace this
# alias with precise per-call return types.
_maybe_indices_t = _scalar_or_tuple_2_t[Tensor]
45