# ${generated_comment}
# mypy: allow-untyped-defs

from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    overload,
    Sequence,
    Tuple,
    Union,
)

from torch import Tensor
from torch.types import _dtype, _int, _size

from .common_types import (
    _ratio_any_t,
    _size_1_t,
    _size_2_opt_t,
    _size_2_t,
    _size_3_opt_t,
    _size_3_t,
    _size_any_t,
)

# 'TypedDict' is a new accepted type that represents a dictionary with a fixed set of allowed keys.
# It is standards-track but not in `typing` yet. We leave this here to be uncommented once the feature
# is widespread.

# from mypy_extensions import TypedDict

# GRID_SAMPLE_INTERPOLATION_MODES = TypedDict('GRID_SAMPLE_INTERPOLATION_MODES', {'bilinear': int, 'nearest': int})
# GRID_SAMPLE_PADDING_MODES = TypedDict('GRID_SAMPLE_PADDING_MODES', {'zeros': int, 'border': int, 'reflection': int})

GRID_SAMPLE_INTERPOLATION_MODES = Dict[str, int]
GRID_SAMPLE_PADDING_MODES = Dict[str, int]

# These stubs were generated by running stubgen (`stubgen --parse-only functional.py`), followed by manual cleaning.
#
# The 'BroadcastingList{1,2,3}' types were replaced by `_size` or `_output_ratio`, as appropriate.
# This was necessary because the JIT uses BroadcastingList* types, but static checking with mypy, etc., requires a
# `Sequence` type. There is no way to express the expected lengths of these lists in the current Python typing system.
#
# Functions created via `_add_docstr` in `functional.py` were merely typed as `Any` by `stubgen`, so those were
# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code
# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system
# to encode the type semantics of `_add_docstr`, should that system ever become widespread.
def fractional_max_pool2d_with_indices(
    input: Tensor,
    kernel_size: _size,
    output_size: Optional[_size] = ...,
    output_ratio: Optional[_ratio_any_t] = ...,
    return_indices: bool = ...,
    _random_samples: Optional[Tensor] = ...,
) -> Tuple[Tensor, Tensor]: ...
def fractional_max_pool3d_with_indices(
    input: Tensor,
    kernel_size: _size,
    output_size: Optional[_size] = ...,
    output_ratio: Optional[_ratio_any_t] = ...,
    return_indices: bool = ...,
    _random_samples: Optional[Tensor] = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_pool1d_with_indices(
    input: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    dilation: _size = ...,
    ceil_mode: bool = ...,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_pool2d_with_indices(
    input: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    dilation: _size = ...,
    ceil_mode: bool = ...,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_pool3d_with_indices(
    input: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    dilation: _size = ...,
    ceil_mode: bool = ...,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_unpool1d(
    input: Tensor,
    indices: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    output_size: Optional[_size] = ...,
) -> Tensor: ...
def max_unpool2d(
    input: Tensor,
    indices: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    output_size: Optional[_size] = ...,
) -> Tensor: ...
def max_unpool3d(
    input: Tensor,
    indices: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    output_size: Optional[_size] = ...,
) -> Tensor: ...
def lp_pool1d(
    input: Tensor,
    norm_type: float,
    kernel_size: _size_1_t,
    stride: Union[Optional[_size], Optional[int]] = ...,
    ceil_mode: bool = ...,
) -> Tensor: ...
def lp_pool2d(
    input: Tensor,
    norm_type: float,
    kernel_size: _size_2_t,
    stride: Union[Optional[_size], Optional[int]] = ...,
    ceil_mode: bool = ...,
) -> Tensor: ...
def lp_pool3d(
    input: Tensor,
    norm_type: float,
    kernel_size: _size_3_t,
    stride: Union[Optional[_size], Optional[int]] = ...,
    ceil_mode: bool = ...,
) -> Tensor: ...
def adaptive_max_pool1d_with_indices(
    input: Tensor,
    output_size: _size,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def adaptive_max_pool2d_with_indices(
    input: Tensor,
    output_size: _size_2_opt_t,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def adaptive_max_pool3d_with_indices(
    input: Tensor,
    output_size: _size_3_opt_t,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def adaptive_avg_pool2d(input: Tensor, output_size: _size_2_opt_t) -> Tensor: ...
def adaptive_avg_pool3d(input: Tensor, output_size: _size_3_opt_t) -> Tensor: ...
def dropout(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def alpha_dropout(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def dropout1d(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def dropout2d(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def dropout3d(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def feature_alpha_dropout(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def threshold(
    input: Tensor,
    threshold: float,
    value: float,
    inplace: bool = ...,
) -> Tensor: ...
def relu(input: Tensor, inplace: bool = ...) -> Tensor: ...
def glu(input: Tensor, dim: int = ...) -> Tensor: ...
def hardtanh(
    input: Tensor,
    min_val: float = ...,
    max_val: float = ...,
    inplace: bool = ...,
) -> Tensor: ...
def relu6(input: Tensor, inplace: bool = ...) -> Tensor: ...
def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
def selu(input: Tensor, inplace: bool = ...) -> Tensor: ...
def celu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
def leaky_relu(
    input: Tensor,
    negative_slope: float = ...,
    inplace: bool = ...,
) -> Tensor: ...
def rrelu(
    input: Tensor,
    lower: float = ...,
    upper: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def tanhshrink(input: Any): ...
def softsign(input: Any): ...
def softmin(
    input: Tensor,
    dim: Optional[int] = ...,
    _stacklevel: int = ...,
    dtype: Optional[_dtype] = ...,
) -> Tensor: ...
def softmax(
    input: Tensor,
    dim: Optional[int] = ...,
    _stacklevel: int = ...,
    dtype: Optional[_dtype] = ...,
) -> Tensor: ...
def gumbel_softmax(
    logits: Tensor,
    tau: float = ...,
    hard: bool = ...,
    eps: float = ...,
    dim: int = ...,
) -> Tensor: ...
def log_softmax(
    input: Tensor,
    dim: Optional[int] = ...,
    _stacklevel: int = ...,
    dtype: Optional[_dtype] = ...,
) -> Tensor: ...
def tanh(input: Any): ...
def sigmoid(input: Any) -> Tensor: ...
def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: ...
def silu(input: Tensor, inplace: bool = False) -> Tensor: ...
def mish(input: Tensor, inplace: bool = False) -> Tensor: ...
def hardswish(input: Tensor, inplace: bool = False) -> Tensor: ...
def embedding(
    input: Tensor,
    weight: Tensor,
    padding_idx: Optional[int] = ...,
    max_norm: Optional[float] = ...,
    norm_type: float = ...,
    scale_grad_by_freq: bool = ...,
    sparse: bool = ...,
) -> Tensor: ...
def embedding_bag(
    input: Tensor,
    weight: Tensor,
    offsets: Optional[Tensor] = ...,
    max_norm: Optional[float] = ...,
    norm_type: float = ...,
    scale_grad_by_freq: bool = ...,
    mode: str = ...,
    sparse: bool = ...,
    per_sample_weights: Optional[Tensor] = ...,
    include_last_offset: bool = ...,
    padding_idx: Optional[int] = ...,
) -> Tensor: ...
def batch_norm(
    input: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    training: bool = ...,
    momentum: float = ...,
    eps: float = ...,
) -> Tensor: ...
def instance_norm(
    input: Tensor,
    running_mean: Optional[Tensor] = ...,
    running_var: Optional[Tensor] = ...,
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    use_input_stats: bool = ...,
    momentum: float = ...,
    eps: float = ...,
) -> Tensor: ...
def layer_norm(
    input: Tensor,
    normalized_shape: Sequence[int],
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    eps: float = ...,
) -> Tensor: ...
def rms_norm(
    input: Tensor,
    normalized_shape: Sequence[int],
    weight: Optional[Tensor] = ...,
    eps: Optional[float] = ...,
) -> Tensor: ...
def group_norm(
    input: Tensor,
    num_groups: int,
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    eps: float = ...,
) -> Tensor: ...
def local_response_norm(
    input: Tensor,
    size: int,
    alpha: float = ...,
    beta: float = ...,
    k: float = ...,
) -> Tensor: ...
def ctc_loss(
    log_probs: Tensor,
    targets: Tensor,
    input_lengths: Tensor,
    target_lengths: Tensor,
    blank: int = ...,
    reduction: str = ...,
    zero_infinity: bool = ...,
) -> Tensor: ...
def nll_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    ignore_index: int = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def poisson_nll_loss(
    input: Tensor,
    target: Tensor,
    log_input: bool = ...,
    full: bool = ...,
    size_average: Optional[bool] = ...,
    eps: float = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def gaussian_nll_loss(
    input: Tensor,
    target: Tensor,
    var: Tensor,
    full: Optional[bool] = ...,
    eps: Optional[float] = ...,
    reduction: Optional[str] = ...,
) -> Tensor: ...
def kl_div(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    log_target: bool = ...,
) -> Tensor: ...
def cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    ignore_index: int = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    label_smoothing: float = ...,
) -> Tensor: ...
def binary_cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def binary_cross_entropy_with_logits(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    pos_weight: Optional[Tensor] = ...,
) -> Tensor: ...
def smooth_l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    beta: float = ...,
) -> Tensor: ...
def huber_loss(
    input: Tensor,
    target: Tensor,
    reduction: str = ...,
    delta: float = ...,
) -> Tensor: ...
def l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def mse_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def margin_ranking_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def hinge_embedding_loss(
    input: Tensor,
    target: Tensor,
    margin: float = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def multilabel_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def soft_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def multilabel_soft_margin_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def cosine_embedding_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def multi_margin_loss(
    input: Tensor,
    target: Tensor,
    p: int = ...,
    margin: float = ...,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def upsample(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
    mode: str = ...,
    align_corners: Optional[Any] = ...,
): ...
def interpolate(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
    mode: str = ...,
    align_corners: Optional[Any] = ...,
    recompute_scale_factor: Optional[Any] = ...,
    antialias: bool = ...,
): ...
def upsample_nearest(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
): ...
def upsample_bilinear(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
): ...
def grid_sample(
    input: Tensor,
    grid: Tensor,
    mode: str = ...,
    padding_mode: str = ...,
    align_corners: Optional[Any] = ...,
) -> Tensor: ...
def affine_grid(
    theta: Tensor,
    size: List[int],
    align_corners: Optional[Any] = ...,
) -> Tensor: ...
def triplet_margin_loss(
    anchor: Tensor,
    positive: Tensor,
    negative: Tensor,
    margin: float = ...,
    p: float = ...,
    eps: float = ...,
    swap: bool = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def triplet_margin_with_distance_loss(
    anchor: Tensor,
    positive: Tensor,
    negative: Tensor,
    *,
    distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ...,
    margin: float = ...,
    swap: bool = ...,
    reduction: str = ...,
) -> Tensor: ...
def normalize(
    input: Tensor,
    p: float = ...,
    dim: int = ...,
    eps: float = ...,
    out: Optional[Tensor] = ...,
) -> Tensor: ...
def assert_int_or_pair(
    arg: Any,
    arg_name: Any,
    message: Any,
) -> None: ...
def unfold(
    input: Tensor,
    kernel_size: _size_any_t,
    dilation: _size_any_t = ...,
    padding: _size_any_t = ...,
    stride: _size_any_t = ...,
) -> Tensor: ...
def fold(
    input: Tensor,
    output_size: _size_any_t,
    kernel_size: _size_any_t,
    dilation: _size_any_t = ...,
    padding: _size_any_t = ...,
    stride: _size_any_t = ...,
) -> Tensor: ...
def _canonical_mask(
    mask: Optional[Tensor],
    mask_name: str,
    other_type: Optional[_dtype],
    other_name: str,
    target_type: _dtype,
    check_other: bool = True,
) -> Optional[Tensor]: ...
def _none_or_dtype(input: Optional[Tensor]) -> Optional[_dtype]: ...
def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]: ...

${imported_hints}

${dispatched_hints}