1# mypy: allow-untyped-defs 2"""This file exports ONNX ops for opset 18. 3 4Note [ONNX Operators that are added/updated in opset 18] 5 6~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 7https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set 8New operators: 9 BitwiseAnd 10 CenterCropPad 11 Col2Im 12 Mish 13 OptionalGetElement 14 OptionalHasElement 15 Pad 16 Resize 17 ScatterElements 18 ScatterND 19 Split 20""" 21 22import functools 23from typing import List, Optional, Sequence, Tuple 24 25import torch 26from torch import _C 27from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9 28from torch.onnx._internal import jit_utils, registration 29 30 31# EDITING THIS FILE? READ THIS FIRST! 32# see Note [Edit Symbolic Files] in symbolic_helper.py 33 34__all__ = [ 35 "col2im", 36] 37 38_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18) 39 40 41@_onnx_symbolic("aten::__and_") 42@_onnx_symbolic("aten::bitwise_and") 43def __and_(g: jit_utils.GraphContext, self, other): 44 # do type promotion (scalars don't seem to apply) 45 args = [self, other] 46 # type promotion doesn't happen with torch.bitwise_and(tensor, scalar) 47 prom_args = [arg for arg in args if symbolic_helper._get_tensor_rank(arg)] 48 if len(prom_args) == 0: 49 prom_args = args 50 promotion_jit_type = symbolic_helper._type_promote_from_values(*prom_args) 51 self = symbolic_helper._maybe_cast_to_type(g, self, promotion_jit_type) 52 other = symbolic_helper._maybe_cast_to_type(g, other, promotion_jit_type) 53 if promotion_jit_type == _type_utils.JitScalarType.BOOL: 54 return g.op("And", self, other) 55 return g.op("BitwiseAnd", self, other) 56 57 58@_onnx_symbolic("aten::col2im") 59@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is") 60def col2im( 61 g, 62 input: _C.Value, 63 output_size: _C.Value, 64 kernel_size: _C.Value, 65 dilation: Sequence[int], 66 padding: Sequence[int], 67 stride: Sequence[int], 68): 69 # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in] 70 adjusted_padding = [] 71 for pad in padding: 72 for _ in range(2): 73 adjusted_padding.append(pad) 74 75 num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0] 76 if not adjusted_padding: 77 adjusted_padding = [0, 0] * num_dimensional_axis 78 79 if not dilation: 80 dilation = [1] * num_dimensional_axis 81 82 if not stride: 83 stride = [1] * num_dimensional_axis 84 85 return g.op( 86 "Col2Im", 87 input, 88 output_size, 89 kernel_size, 90 dilations_i=dilation, 91 pads_i=adjusted_padding, 92 strides_i=stride, 93 ) 94 95 96@_onnx_symbolic( 97 "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")] 98) 99@_onnx_symbolic( 100 "aten::prod", 101 decorate=[ 102 symbolic_helper._apply_params( 103 "ReduceProd", "prod", allow_multi_dim_support=False 104 ) 105 ], 106) 107def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): 108 return symbolic_helper._reduce_with_dtype_helper( 109 onnx_op, name, allow_multi_dim_support 110 ) 111 112 113@_onnx_symbolic("aten::native_layer_norm") 114@symbolic_helper.quantized_args(True, False, False, False) 115@symbolic_helper.parse_args("v", "is", "v", "v", "f") 116def _native_layer_norm( 117 g: jit_utils.GraphContext, 118 input: _C.Value, 119 normalized_shape: Sequence[int], 120 weight: _C.Value, 121 bias: _C.Value, 122 eps: float, 123) -> Tuple[_C.Value, _C.Value, _C.Value]: 124 return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps) 125 126 127@_onnx_symbolic("aten::glu") 128@symbolic_helper.parse_args("v", "i") 129def _glu(g: jit_utils.GraphContext, input, dim): 130 dim_size = symbolic_helper._get_tensor_dim_size(input, dim) 131 if dim_size is not None: 132 assert dim_size % 2 == 0 133 134 first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2) 135 return g.op("Mul", first, g.op("Sigmoid", second)) 136 137 138@_onnx_symbolic("aten::max") 139# torch.max (same for torch.min) actually has two interfaces smashed together: 140# torch.max(x, dim, keepdim) and torch.max(x, y) 141# TODO(justinchuby): Support multiple quantized args in output 142def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): 143 return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) 144 145 146@_onnx_symbolic("aten::maximum") 147@symbolic_helper.quantized_args(True, True) 148def maximum(g: jit_utils.GraphContext, input, other): 149 return max(g, input, dim_or_y=other) 150 151 152@_onnx_symbolic("aten::min") 153# TODO(justinchuby): Support multiple quantized args in output 154def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): 155 return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) 156 157 158@_onnx_symbolic("aten::minimum") 159@symbolic_helper.quantized_args(True, True) 160def minimum(g: jit_utils.GraphContext, input, other): 161 return min(g, input, dim_or_y=other) 162 163 164@_onnx_symbolic("aten::amax") 165@symbolic_helper.quantized_args(True) 166@symbolic_helper.parse_args("v", "is", "i") 167def amax(g: jit_utils.GraphContext, self, dim, keepdim): 168 axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) 169 return g.op("ReduceMax", self, axes, keepdims_i=keepdim) 170 171 172@_onnx_symbolic("aten::amin") 173@symbolic_helper.quantized_args(True) 174@symbolic_helper.parse_args("v", "is", "i") 175def amin(g: jit_utils.GraphContext, self, dim, keepdim): 176 axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) 177 return g.op("ReduceMin", self, axes, keepdims_i=keepdim) 178 179 180@_onnx_symbolic("aten::aminmax") 181@symbolic_helper.quantized_args(True) 182@symbolic_helper.parse_args("v", "v", "i") 183def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): 184 if not symbolic_helper._is_none(dim): 185 dim = symbolic_helper._get_const(dim, "i", "dim") 186 axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) 187 return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op( 188 "ReduceMax", self, axes, keepdims_i=keepdim 189 ) 190 else: 191 return g.op("ReduceMin", self, keepdims_i=keepdim), g.op( 192 "ReduceMax", self, keepdims_i=keepdim 193 ) 194 195 196@_onnx_symbolic("aten::var_mean") 197def _var_mean(g: jit_utils.GraphContext, input, *args): 198 if len(args) == 1: 199 return symbolic_helper._var_mean_helper(g, input, None, args[0], None) 200 else: 201 return symbolic_helper._var_mean_helper(g, input, *args) 202 203 204@_onnx_symbolic("aten::logsumexp") 205@symbolic_helper.parse_args("v", "is", "i") 206def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): 207 if dim is None: 208 return g.op("ReduceLogSumExp", input, keepdims_i=0) 209 else: 210 axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) 211 return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim) 212 213 214@_onnx_symbolic("aten::linalg_matrix_norm") 215@symbolic_helper.parse_args("v", "v", "is", "b", "v") 216def _linalg_matrix_norm( 217 g: jit_utils.GraphContext, 218 self: torch._C.Value, 219 ord: torch._C.Value, 220 dim: List[int], 221 keepdim: bool, 222 dtype: torch._C.Value, 223): 224 return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) 225 226 227@_onnx_symbolic("aten::embedding_bag") 228@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") 229def embedding_bag( 230 g: jit_utils.GraphContext, 231 embedding_matrix, 232 indices, 233 offsets, 234 scale_grad_by_freq, 235 mode, 236 sparse, 237 per_sample_weights, 238 include_last_offset, 239 padding_idx, 240): 241 return symbolic_helper._embedding_bag_helper( 242 g, 243 embedding_matrix, 244 indices, 245 offsets, 246 scale_grad_by_freq, 247 mode, 248 sparse, 249 per_sample_weights, 250 include_last_offset, 251 padding_idx, 252 ) 253 254 255@_onnx_symbolic("aten::linalg_vector_norm") 256@symbolic_helper.parse_args("v", "f", "is", "b", "v") 257def linalg_vector_norm( 258 g: jit_utils.GraphContext, 259 self: torch._C.Value, 260 ord: float, 261 dim: Optional[Sequence[int]], 262 keepdim: bool, 263 dtype: torch._C.Value, 264): 265 return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) 266