# xref: /aosp_15_r20/external/pytorch/torch/nn/functional.pyi.in (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# ${generated_comment}
# mypy: allow-untyped-defs
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerfrom typing import (
5*da0073e9SAndroid Build Coastguard Worker    Any,
6*da0073e9SAndroid Build Coastguard Worker    Callable,
7*da0073e9SAndroid Build Coastguard Worker    Dict,
8*da0073e9SAndroid Build Coastguard Worker    List,
9*da0073e9SAndroid Build Coastguard Worker    Literal,
10*da0073e9SAndroid Build Coastguard Worker    Optional,
11*da0073e9SAndroid Build Coastguard Worker    overload,
12*da0073e9SAndroid Build Coastguard Worker    Sequence,
13*da0073e9SAndroid Build Coastguard Worker    Tuple,
14*da0073e9SAndroid Build Coastguard Worker    Union,
15*da0073e9SAndroid Build Coastguard Worker)
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
18*da0073e9SAndroid Build Coastguard Workerfrom torch.types import _dtype, _int, _size
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Workerfrom .common_types import (
21*da0073e9SAndroid Build Coastguard Worker    _ratio_any_t,
22*da0073e9SAndroid Build Coastguard Worker    _size_1_t,
23*da0073e9SAndroid Build Coastguard Worker    _size_2_opt_t,
24*da0073e9SAndroid Build Coastguard Worker    _size_2_t,
25*da0073e9SAndroid Build Coastguard Worker    _size_3_opt_t,
26*da0073e9SAndroid Build Coastguard Worker    _size_3_t,
27*da0073e9SAndroid Build Coastguard Worker    _size_any_t,
28*da0073e9SAndroid Build Coastguard Worker)
29*da0073e9SAndroid Build Coastguard Worker
# 'TypedDict' (PEP 589) represents a dictionary with a fixed set of allowed keys.
# It was standards-track but not yet in `typing` when this was written (it has since been added to `typing`
# in Python 3.8). We leave this here to be uncommented once the feature is widespread.

# from mypy_extensions import TypedDict

# GRID_SAMPLE_INTERPOLATION_MODES = TypedDict('GRID_SAMPLE_INTERPOLATION_MODES', {'bilinear': int, 'nearest': int})
# GRID_SAMPLE_PADDING_MODES = TypedDict('GRID_SAMPLE_PADDING_MODES', {'zeros': int, 'border': int, 'reflection': int})
38*da0073e9SAndroid Build Coastguard Worker
# Plain `Dict[str, int]` aliases, used in place of the TypedDict definitions
# that are kept commented out in this file.
GRID_SAMPLE_INTERPOLATION_MODES = Dict[str, int]
GRID_SAMPLE_PADDING_MODES = Dict[str, int]
41*da0073e9SAndroid Build Coastguard Worker
# These stubs were generated by running stubgen (`stubgen --parse-only functional.py`), followed by manual cleaning.
#
# The 'BroadcastingList{1,2,3}' types were replaced by `_size` or a ratio type (`_ratio_any_t`), as appropriate.
# This was necessary since the JIT uses BroadcastingList* types but static checking with mypy etc. requires a `Sequence`
# type. There is no way to express the expected lengths of these lists in the current Python typing system.
#
# Functions created via `_add_docstr` in `functional.py` were merely typed as `Any` by `stubgen`, so those were
# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code
# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system
# to encode the type semantics of `_add_docstr`, should that system ever become widespread.
# Stubs for the fractional max-pool helpers that expose indices. Both share one
# signature and are annotated to return a 2-tuple of Tensors. `_random_samples`
# is private by its leading underscore.
def fractional_max_pool2d_with_indices(
    input: Tensor,
    kernel_size: _size,
    output_size: Optional[_size] = ...,
    output_ratio: Optional[_ratio_any_t] = ...,
    return_indices: bool = ...,
    _random_samples: Optional[Tensor] = ...,
) -> Tuple[Tensor, Tensor]: ...
def fractional_max_pool3d_with_indices(
    input: Tensor,
    kernel_size: _size,
    output_size: Optional[_size] = ...,
    output_ratio: Optional[_ratio_any_t] = ...,
    return_indices: bool = ...,
    _random_samples: Optional[Tensor] = ...,
) -> Tuple[Tensor, Tensor]: ...
# Stubs for the 1-d/2-d/3-d max-pool helpers that expose indices. All three
# share an identical signature and are annotated to return a 2-tuple of Tensors.
def max_pool1d_with_indices(
    input: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    dilation: _size = ...,
    ceil_mode: bool = ...,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_pool2d_with_indices(
    input: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    dilation: _size = ...,
    ceil_mode: bool = ...,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_pool3d_with_indices(
    input: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    dilation: _size = ...,
    ceil_mode: bool = ...,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
# Stubs for the 1-d/2-d/3-d max-unpool functions. Each takes the input plus an
# `indices` Tensor and returns a single Tensor.
def max_unpool1d(
    input: Tensor,
    indices: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    output_size: Optional[_size] = ...,
) -> Tensor: ...
def max_unpool2d(
    input: Tensor,
    indices: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    output_size: Optional[_size] = ...,
) -> Tensor: ...
def max_unpool3d(
    input: Tensor,
    indices: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    output_size: Optional[_size] = ...,
) -> Tensor: ...
# Stubs for Lp-norm pooling in 1-d/2-d/3-d. `stride` accepts a size sequence,
# a single int, or None. It is written as one flat Optional[Union[...]] rather
# than the equivalent but redundant Union[Optional[...], Optional[...]].
def lp_pool1d(
    input: Tensor,
    norm_type: float,
    kernel_size: _size_1_t,
    stride: Optional[Union[_size, int]] = ...,
    ceil_mode: bool = ...,
) -> Tensor: ...
def lp_pool2d(
    input: Tensor,
    norm_type: float,
    kernel_size: _size_2_t,
    stride: Optional[Union[_size, int]] = ...,
    ceil_mode: bool = ...,
) -> Tensor: ...
def lp_pool3d(
    input: Tensor,
    norm_type: float,
    kernel_size: _size_3_t,
    stride: Optional[Union[_size, int]] = ...,
    ceil_mode: bool = ...,
) -> Tensor: ...
# Stubs for adaptive max pooling that exposes indices. The 2-d/3-d variants
# take the `_size_*_opt_t` aliases for `output_size`; all three are annotated
# to return a 2-tuple of Tensors.
def adaptive_max_pool1d_with_indices(
    input: Tensor,
    output_size: _size,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def adaptive_max_pool2d_with_indices(
    input: Tensor,
    output_size: _size_2_opt_t,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def adaptive_max_pool3d_with_indices(
    input: Tensor,
    output_size: _size_3_opt_t,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
# Stubs for 2-d/3-d adaptive average pooling; each returns a single Tensor.
def adaptive_avg_pool2d(input: Tensor, output_size: _size_2_opt_t) -> Tensor: ...
def adaptive_avg_pool3d(input: Tensor, output_size: _size_3_opt_t) -> Tensor: ...
# Stubs for the dropout family. All six share one signature: a float `p`
# plus `training` and `inplace` flags, returning a Tensor.
def dropout(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def alpha_dropout(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def dropout1d(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def dropout2d(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def dropout3d(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def feature_alpha_dropout(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
# Stubs for activation functions that take scalar hyper-parameters and/or an
# `inplace` flag. Every one returns a Tensor.
def threshold(
    input: Tensor,
    threshold: float,
    value: float,
    inplace: bool = ...,
) -> Tensor: ...
def relu(input: Tensor, inplace: bool = ...) -> Tensor: ...
def glu(input: Tensor, dim: int = ...) -> Tensor: ...
def hardtanh(
    input: Tensor,
    min_val: float = ...,
    max_val: float = ...,
    inplace: bool = ...,
) -> Tensor: ...
def relu6(input: Tensor, inplace: bool = ...) -> Tensor: ...
def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
def selu(input: Tensor, inplace: bool = ...) -> Tensor: ...
def celu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
def leaky_relu(
    input: Tensor,
    negative_slope: float = ...,
    inplace: bool = ...,
) -> Tensor: ...
def rrelu(
    input: Tensor,
    lower: float = ...,
    upper: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
# Stubs for `tanhshrink` and `softsign`. The input stays typed as `Any`, so
# existing callers are unaffected. The return type is declared explicitly as
# Tensor rather than left implicitly untyped, consistent with
# `sigmoid(input: Any) -> Tensor` in this file.
def tanhshrink(input: Any) -> Tensor: ...
def softsign(input: Any) -> Tensor: ...
# Stubs for the softmax family. `dim` is Optional, `_stacklevel` is private by
# its leading underscore, and `dtype` is an optional `_dtype`.
# `gumbel_softmax` takes `logits` and requires a concrete int `dim`.
def softmin(
    input: Tensor,
    dim: Optional[int] = ...,
    _stacklevel: int = ...,
    dtype: Optional[_dtype] = ...,
) -> Tensor: ...
def softmax(
    input: Tensor,
    dim: Optional[int] = ...,
    _stacklevel: int = ...,
    dtype: Optional[_dtype] = ...,
) -> Tensor: ...
def gumbel_softmax(
    logits: Tensor,
    tau: float = ...,
    hard: bool = ...,
    eps: float = ...,
    dim: int = ...,
) -> Tensor: ...
def log_softmax(
    input: Tensor,
    dim: Optional[int] = ...,
    _stacklevel: int = ...,
    dtype: Optional[_dtype] = ...,
) -> Tensor: ...
# Stubs for simple activations. `tanh` declares its Tensor return type
# explicitly, as `sigmoid` already did. The last four take an `inplace` flag
# whose default is the literal False.
def tanh(input: Any) -> Tensor: ...
def sigmoid(input: Any) -> Tensor: ...
def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: ...
def silu(input: Tensor, inplace: bool = False) -> Tensor: ...
def mish(input: Tensor, inplace: bool = False) -> Tensor: ...
def hardswish(input: Tensor, inplace: bool = False) -> Tensor: ...
# Stubs for embedding lookups. Both take index `input` and a `weight` Tensor.
# `embedding_bag` adds optional `offsets`/`per_sample_weights` Tensors and a
# string `mode`.
def embedding(
    input: Tensor,
    weight: Tensor,
    padding_idx: Optional[int] = ...,
    max_norm: Optional[float] = ...,
    norm_type: float = ...,
    scale_grad_by_freq: bool = ...,
    sparse: bool = ...,
) -> Tensor: ...
def embedding_bag(
    input: Tensor,
    weight: Tensor,
    offsets: Optional[Tensor] = ...,
    max_norm: Optional[float] = ...,
    norm_type: float = ...,
    scale_grad_by_freq: bool = ...,
    mode: str = ...,
    sparse: bool = ...,
    per_sample_weights: Optional[Tensor] = ...,
    include_last_offset: bool = ...,
    padding_idx: Optional[int] = ...,
) -> Tensor: ...
# Stubs for normalization functions. `batch_norm` requires `running_mean` and
# `running_var` positionally, though each may be None. `instance_norm` gives
# both a default. `layer_norm`/`rms_norm` take `normalized_shape` as a
# sequence of ints.
def batch_norm(
    input: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    training: bool = ...,
    momentum: float = ...,
    eps: float = ...,
) -> Tensor: ...
def instance_norm(
    input: Tensor,
    running_mean: Optional[Tensor] = ...,
    running_var: Optional[Tensor] = ...,
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    use_input_stats: bool = ...,
    momentum: float = ...,
    eps: float = ...,
) -> Tensor: ...
def layer_norm(
    input: Tensor,
    normalized_shape: Sequence[int],
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    eps: float = ...,
) -> Tensor: ...
def rms_norm(
    input: Tensor,
    normalized_shape: Sequence[int],
    weight: Optional[Tensor] = ...,
    eps: Optional[float] = ...,
) -> Tensor: ...
def group_norm(
    input: Tensor,
    num_groups: int,
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    eps: float = ...,
) -> Tensor: ...
def local_response_norm(
    input: Tensor,
    size: int,
    alpha: float = ...,
    beta: float = ...,
    k: float = ...,
) -> Tensor: ...
# Stubs for `ctc_loss`, `nll_loss` and `poisson_nll_loss`. Each takes a string
# `reduction`. The two NLL variants also accept Optional[bool]
# `size_average`/`reduce` keyword parameters.
def ctc_loss(
    log_probs: Tensor,
    targets: Tensor,
    input_lengths: Tensor,
    target_lengths: Tensor,
    blank: int = ...,
    reduction: str = ...,
    zero_infinity: bool = ...,
) -> Tensor: ...
def nll_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    ignore_index: int = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def poisson_nll_loss(
    input: Tensor,
    target: Tensor,
    log_input: bool = ...,
    full: bool = ...,
    size_average: Optional[bool] = ...,
    eps: float = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `gaussian_nll_loss`.
# NOTE(review): `eps` and `reduction` are typed as plain float/str (previously
# Optional). The runtime definition in functional.py declares
# `eps: float = 1e-6` and `reduction: str = "mean"`, and None is not a usable
# value for either, so the type checker should reject it. Confirm against the
# implementation if it changes. `full` keeps its Optional annotation so
# existing callers are not narrowed.
def gaussian_nll_loss(
    input: Tensor,
    target: Tensor,
    var: Tensor,
    full: Optional[bool] = ...,
    eps: float = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stubs for KL divergence and the cross-entropy losses. Each takes `input` and
# `target` Tensors plus the Optional[bool] `size_average`/`reduce` keywords and
# a string `reduction`.
def kl_div(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    log_target: bool = ...,
) -> Tensor: ...
def cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    ignore_index: int = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    label_smoothing: float = ...,
) -> Tensor: ...
def binary_cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def binary_cross_entropy_with_logits(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    pos_weight: Optional[Tensor] = ...,
) -> Tensor: ...
# Stubs for the L1/L2-style regression losses. `huber_loss` is the only one
# without the Optional[bool] `size_average`/`reduce` keywords.
def smooth_l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    beta: float = ...,
) -> Tensor: ...
def huber_loss(
    input: Tensor,
    target: Tensor,
    reduction: str = ...,
    delta: float = ...,
) -> Tensor: ...
def l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def mse_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
424*da0073e9SAndroid Build Coastguard Workerdef margin_ranking_loss(
425*da0073e9SAndroid Build Coastguard Worker    input1: Tensor,
426*da0073e9SAndroid Build Coastguard Worker    input2: Tensor,
427*da0073e9SAndroid Build Coastguard Worker    target: Tensor,
428*da0073e9SAndroid Build Coastguard Worker    margin: float = ...,
429*da0073e9SAndroid Build Coastguard Worker    size_average: Optional[bool] = ...,
430*da0073e9SAndroid Build Coastguard Worker    reduce: Optional[bool] = ...,
431*da0073e9SAndroid Build Coastguard Worker    reduction: str = ...,
432*da0073e9SAndroid Build Coastguard Worker) -> Tensor: ...
433*da0073e9SAndroid Build Coastguard Workerdef hinge_embedding_loss(
434*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
435*da0073e9SAndroid Build Coastguard Worker    target: Tensor,
436*da0073e9SAndroid Build Coastguard Worker    margin: float = ...,
437*da0073e9SAndroid Build Coastguard Worker    size_average: Optional[bool] = ...,
438*da0073e9SAndroid Build Coastguard Worker    reduce: Optional[bool] = ...,
439*da0073e9SAndroid Build Coastguard Worker    reduction: str = ...,
440*da0073e9SAndroid Build Coastguard Worker) -> Tensor: ...
# Stub for ``torch.nn.functional.multilabel_margin_loss``: ``input`` and
# ``target`` in, a single Tensor out. The three trailing options share the
# reduction-control trio used by the other loss stubs; defaults are elided.
def multilabel_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.soft_margin_loss``: ``input`` and ``target``
# in, a single Tensor out; defaults are elided with ``...``.
def soft_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.multilabel_soft_margin_loss``: ``input`` and
# ``target`` plus an optional ``weight`` Tensor; returns a single Tensor.
# Defaults are elided with ``...``.
def multilabel_soft_margin_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.cosine_embedding_loss``: two inputs, a
# ``target`` and a float ``margin``; returns a single Tensor. Defaults are
# elided with ``...``.
def cosine_embedding_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.multi_margin_loss``: ``input``/``target``
# with an int ``p``, a float ``margin`` and an optional ``weight`` Tensor;
# returns a single Tensor. Defaults are elided with ``...``.
def multi_margin_loss(
    input: Tensor,
    target: Tensor,
    p: int = ...,
    margin: float = ...,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.upsample``. Argument types are deliberately
# left loose (``Any``), but the result is annotated ``Tensor``: an unannotated
# return makes a type checker infer ``Any`` at every call site, which hides
# misuse of the result. The runtime function returns a Tensor.
def upsample(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
    mode: str = ...,
    align_corners: Optional[Any] = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.interpolate``. Parameter types stay loose
# (``Any``) for backward compatibility, while the return is pinned to
# ``Tensor`` (the runtime result type) instead of the implicit ``Any`` an
# unannotated stub would give every caller.
def interpolate(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
    mode: str = ...,
    align_corners: Optional[Any] = ...,
    recompute_scale_factor: Optional[Any] = ...,
    antialias: bool = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.upsample_nearest``. Arguments remain loosely
# typed; the return is annotated ``Tensor`` so callers no longer receive an
# implicit ``Any``.
def upsample_nearest(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.upsample_bilinear``. Arguments remain loosely
# typed; the return is annotated ``Tensor`` so callers no longer receive an
# implicit ``Any``.
def upsample_bilinear(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.grid_sample``: samples ``input`` at the
# locations given by ``grid``; both are Tensors and so is the result.
# ``mode``/``padding_mode`` are strings; ``align_corners`` is left untyped
# (``Optional[Any]``). Defaults are elided with ``...``.
def grid_sample(
    input: Tensor,
    grid: Tensor,
    mode: str = ...,
    padding_mode: str = ...,
    align_corners: Optional[Any] = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.affine_grid``: takes a ``theta`` Tensor and a
# target ``size`` given as a plain list of ints; returns a Tensor. Only
# ``align_corners`` has a (elided) default.
def affine_grid(
    theta: Tensor,
    size: List[int],
    align_corners: Optional[Any] = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.triplet_margin_loss``: an ``anchor`` with a
# ``positive`` and a ``negative`` Tensor, plus float ``margin``/``p``/``eps``
# and a bool ``swap``; returns a single Tensor. All tuning options are
# positional-or-keyword with elided (``...``) defaults.
def triplet_margin_loss(
    anchor: Tensor,
    positive: Tensor,
    negative: Tensor,
    margin: float = ...,
    p: float = ...,
    eps: float = ...,
    swap: bool = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.triplet_margin_with_distance_loss``.
# Only ``anchor``/``positive``/``negative`` may be passed positionally; the
# bare ``*`` makes every option after them keyword-only. ``distance_function``
# is an optional callable mapping two Tensors to a Tensor.
def triplet_margin_with_distance_loss(
    anchor: Tensor,
    positive: Tensor,
    negative: Tensor,
    *,
    distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ...,
    margin: float = ...,
    swap: bool = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.normalize``: a Tensor ``input`` with float
# ``p``/``eps`` and an int ``dim``; returns a Tensor. ``out`` accepts an
# optional Tensor -- presumably a destination buffer (confirm at runtime).
def normalize(
    input: Tensor,
    p: float = ...,
    dim: int = ...,
    eps: float = ...,
    out: Optional[Tensor] = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.assert_int_or_pair``. All three arguments are
# untyped (``Any``) and the function is declared to return ``None``. Judging
# by its name it presumably validates ``arg`` -- confirm against the runtime.
def assert_int_or_pair(
    arg: Any,
    arg_name: Any,
    message: Any,
) -> None: ...
# Stub for ``torch.nn.functional.unfold``: a Tensor ``input`` and a required
# ``kernel_size``; ``dilation``/``padding``/``stride`` are optional. All size
# arguments use ``_size_any_t``. Returns a Tensor.
def unfold(
    input: Tensor,
    kernel_size: _size_any_t,
    dilation: _size_any_t = ...,
    padding: _size_any_t = ...,
    stride: _size_any_t = ...,
) -> Tensor: ...
# Stub for ``torch.nn.functional.fold``: like ``unfold`` but with an extra
# required ``output_size`` ahead of ``kernel_size``. All size arguments use
# ``_size_any_t``; returns a Tensor.
def fold(
    input: Tensor,
    output_size: _size_any_t,
    kernel_size: _size_any_t,
    dilation: _size_any_t = ...,
    padding: _size_any_t = ...,
    stride: _size_any_t = ...,
) -> Tensor: ...
# Private helper stub. Takes an optional ``mask`` Tensor with its name, the
# optional dtype of some other argument with that argument's name, and a
# ``target_type`` dtype; returns an optional Tensor. ``check_other`` is the
# only parameter here with a concrete default (``True``). NOTE(review): it
# presumably converts ``mask`` toward ``target_type`` -- confirm at runtime.
def _canonical_mask(
    mask: Optional[Tensor],
    mask_name: str,
    other_type: Optional[_dtype],
    other_name: str,
    target_type: _dtype,
    check_other: bool = True,
) -> Optional[Tensor]: ...
# Private helper stub: maps an optional Tensor to an optional dtype.
# Presumably yields None when ``input`` is None -- confirm at runtime.
def _none_or_dtype(
    input: Optional[Tensor],
) -> Optional[_dtype]: ...
# Stub for ``torch.nn.functional.multi_head_attention_forward``. Thirteen
# positional parameters are required; the rest carry concrete defaults.
# Returns a pair whose second element is an optional Tensor (presumably the
# attention weights, absent when not requested -- confirm at runtime).
def multi_head_attention_forward(
    # Required: the query/key/value Tensors and head configuration.
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    # Required (but possibly None): packed input projection and extra biases.
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    # Required: output projection (bias may be None).
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    # Optional behaviour switches and masks.
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    # Optional separate per-input projection weights.
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    # Optional precomputed key/value Tensors.
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]: ...
605*da0073e9SAndroid Build Coastguard Worker
606*da0073e9SAndroid Build Coastguard Worker${imported_hints}
607*da0073e9SAndroid Build Coastguard Worker
608*da0073e9SAndroid Build Coastguard Worker${dispatched_hints}
609