from typing import Optional, Tuple, TypeVar, Union

from torch import Tensor

# ---------------------------------------------------------------------------
# Reusable type aliases for annotating layer arguments and return values.
# ---------------------------------------------------------------------------

# Generic templates: "a single T, or a tuple of T".
#
# Many arguments may be given either as one scalar, which PyTorch broadcasts
# internally to every dimension, or as an explicit tuple. The numbered
# templates ``_scalar_or_tuple_<N>_t`` pin the tuple to exactly N elements
# (N = 1..6), while ``_scalar_or_tuple_any_t`` accepts a tuple of any length.
# Subscript a template with the element type to get a concrete alias; for
# example ``_scalar_or_tuple_2_t[int]`` is ``Union[int, Tuple[int, int]]``.
T = TypeVar("T")
_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]]
_scalar_or_tuple_1_t = Union[T, Tuple[T]]
_scalar_or_tuple_2_t = Union[T, Tuple[T, T]]
_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]]
_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]]
_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]]
_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]]

# Integer size parameters, e.g. kernel size or padding.
_size_any_t = _scalar_or_tuple_any_t[int]
_size_1_t = _scalar_or_tuple_1_t[int]
_size_2_t = _scalar_or_tuple_2_t[int]
_size_3_t = _scalar_or_tuple_3_t[int]
_size_4_t = _scalar_or_tuple_4_t[int]
_size_5_t = _scalar_or_tuple_5_t[int]
_size_6_t = _scalar_or_tuple_6_t[int]

# Size parameters whose entries may be None, e.g. adaptive pooling output
# sizes. Note that the scalar branch may itself be None.
_size_any_opt_t = _scalar_or_tuple_any_t[Optional[int]]
_size_2_opt_t = _scalar_or_tuple_2_t[Optional[int]]
_size_3_opt_t = _scalar_or_tuple_3_t[Optional[int]]

# Per-dimension float ratios used to rescale an input, e.g. upsampling
# scale factors.
_ratio_any_t = _scalar_or_tuple_any_t[float]
_ratio_2_t = _scalar_or_tuple_2_t[float]
_ratio_3_t = _scalar_or_tuple_3_t[float]

# A single Tensor, or a tuple of Tensors of any length.
_tensor_list_t = _scalar_or_tuple_any_t[Tensor]

# Return type of max pooling operations that may or may not also return
# indices: either a single Tensor or a pair of Tensors (presumably
# (output, indices) -- TODO confirm against the pooling call sites).
# NOTE(review): typing.Literal is available since Python 3.8 (PEP 586);
# @overload signatures keyed on a Literal return-indices flag could
# eventually replace this alias with precise per-call return types.
_maybe_indices_t = _scalar_or_tuple_2_t[Tensor]