# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 18.

Note [ONNX Operators that are added/updated in opset 18]

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
New operators:
    BitwiseAnd
    CenterCropPad
    Col2Im
    Mish

Updated operators:
    OptionalGetElement
    OptionalHasElement
    Pad
    Resize
    ScatterElements
    ScatterND
    Split
"""

import functools
from typing import List, Optional, Sequence, Tuple

import torch
from torch import _C
from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration


# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

__all__ = [
    "col2im",
]

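# Binds opset=18 so that every @_onnx_symbolic registration below targets opset 18.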
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)


@_onnx_symbolic("aten::__and_")
@_onnx_symbolic("aten::bitwise_and")
def __and_(g: jit_utils.GraphContext, self, other):
    # Do type promotion manually; it is not applied to scalars, and
    # torch.bitwise_and(tensor, scalar) does not promote on its own.
    args = [self, other]
    # Promote from the arguments with a known, non-zero rank; fall back to all
    # arguments if there are none.
    prom_args = [arg for arg in args if symbolic_helper._get_tensor_rank(arg)]
    if len(prom_args) == 0:
        prom_args = args
    promotion_jit_type = symbolic_helper._type_promote_from_values(*prom_args)
    self = symbolic_helper._maybe_cast_to_type(g, self, promotion_jit_type)
    other = symbolic_helper._maybe_cast_to_type(g, other, promotion_jit_type)
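    # BitwiseAnd (new in opset 18) is only defined for integer types, so boolean
    # inputs still lower to the logical And operator.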
    if promotion_jit_type == _type_utils.JitScalarType.BOOL:
        return g.op("And", self, other)
    return g.op("BitwiseAnd", self, other)


@_onnx_symbolic("aten::col2im")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
def col2im(
    g,
    input: _C.Value,
    output_size: _C.Value,
    kernel_size: _C.Value,
    dilation: Sequence[int],
    padding: Sequence[int],
    stride: Sequence[int],
):
    # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
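    # e.g. padding = [1, 2] becomes adjusted_padding = [1, 1, 2, 2]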
    adjusted_padding = []
    for pad in padding:
        for _ in range(2):
            adjusted_padding.append(pad)

    num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
    if not adjusted_padding:
        adjusted_padding = [0, 0] * num_dimensional_axis

    if not dilation:
        dilation = [1] * num_dimensional_axis

    if not stride:
        stride = [1] * num_dimensional_axis

    return g.op(
        "Col2Im",
        input,
        output_size,
        kernel_size,
        dilations_i=dilation,
        pads_i=adjusted_padding,
        strides_i=stride,
    )
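
# For example, a fold/unfold round trip such as
#     cols = torch.nn.functional.unfold(x, kernel_size=(2, 2))  # x: (N, C, 4, 4)
#     y = torch.nn.functional.fold(cols, output_size=(4, 4), kernel_size=(2, 2))
# dispatches the fold call to aten::col2im, which the symbolic above exports as
# a single Col2Im node with dilations_i=[1, 1], pads_i=[0, 0, 0, 0] and
# strides_i=[1, 1].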


@_onnx_symbolic(
    "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
)
@_onnx_symbolic(
    "aten::prod",
    decorate=[
        symbolic_helper._apply_params(
            "ReduceProd", "prod", allow_multi_dim_support=False
        )
    ],
)
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
    return symbolic_helper._reduce_with_dtype_helper(
        onnx_op, name, allow_multi_dim_support
    )


@_onnx_symbolic("aten::native_layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
def _native_layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
) -> Tuple[_C.Value, _C.Value, _C.Value]:
    return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps)


@_onnx_symbolic("aten::glu")
@symbolic_helper.parse_args("v", "i")
def _glu(g: jit_utils.GraphContext, input, dim):
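    # GLU splits the input into two equal halves along `dim` and gates the first
    # half with the sigmoid of the second, so the split dimension must be even.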
    dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
    if dim_size is not None:
        assert dim_size % 2 == 0

    first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2)
    return g.op("Mul", first, g.op("Sigmoid", second))


@_onnx_symbolic("aten::max")
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
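# e.g. torch.max(x, 1, keepdim=True) returns (values, indices), while
# torch.max(x, y) is an elementwise maximum.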
# TODO(justinchuby): Support multiple quantized args in output
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)


@_onnx_symbolic("aten::maximum")
@symbolic_helper.quantized_args(True, True)
def maximum(g: jit_utils.GraphContext, input, other):
    return max(g, input, dim_or_y=other)


@_onnx_symbolic("aten::min")
# TODO(justinchuby): Support multiple quantized args in output
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)


@_onnx_symbolic("aten::minimum")
@symbolic_helper.quantized_args(True, True)
def minimum(g: jit_utils.GraphContext, input, other):
    return min(g, input, dim_or_y=other)


@_onnx_symbolic("aten::amax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
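    # Since opset 18, ReduceMax takes the reduction axes as an input tensor
    # rather than as an attribute.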
    axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
    return g.op("ReduceMax", self, axes, keepdims_i=keepdim)


@_onnx_symbolic("aten::amin")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
    axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
    return g.op("ReduceMin", self, axes, keepdims_i=keepdim)


@_onnx_symbolic("aten::aminmax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i")
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
    if not symbolic_helper._is_none(dim):
        dim = symbolic_helper._get_const(dim, "i", "dim")
        axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
        return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op(
            "ReduceMax", self, axes, keepdims_i=keepdim
        )
    else:
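        # Without an explicit axes input, ReduceMin/ReduceMax reduce over all
        # dimensions (the opset 18 default when noop_with_empty_axes is 0).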
        return g.op("ReduceMin", self, keepdims_i=keepdim), g.op(
            "ReduceMax", self, keepdims_i=keepdim
        )


@_onnx_symbolic("aten::var_mean")
def _var_mean(g: jit_utils.GraphContext, input, *args):
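    # A single extra argument corresponds to the var_mean(input, unbiased)
    # overload; dim and keepdim are not provided in that case.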
    if len(args) == 1:
        return symbolic_helper._var_mean_helper(g, input, None, args[0], None)
    else:
        return symbolic_helper._var_mean_helper(g, input, *args)


@_onnx_symbolic("aten::logsumexp")
@symbolic_helper.parse_args("v", "is", "i")
def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
    if dim is None:
        return g.op("ReduceLogSumExp", input, keepdims_i=0)
    else:
        axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
        return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim)


@_onnx_symbolic("aten::linalg_matrix_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
def _linalg_matrix_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: torch._C.Value,
    dim: List[int],
    keepdim: bool,
    dtype: torch._C.Value,
):
    return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)


@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    return symbolic_helper._embedding_bag_helper(
        g,
        embedding_matrix,
        indices,
        offsets,
        scale_grad_by_freq,
        mode,
        sparse,
        per_sample_weights,
        include_last_offset,
        padding_idx,
    )


@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
def linalg_vector_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: float,
    dim: Optional[Sequence[int]],
    keepdim: bool,
    dtype: torch._C.Value,
):
    return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)
266