# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
from __future__ import annotations

import functools
import sys

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import jit_utils, registration


# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 12

__all__ = [
    "argmax",
    "argmin",
    "binary_cross_entropy_with_logits",
    "celu",
    "cross_entropy_loss",
    "dropout",
    "einsum",
    "ge",
    "le",
    "native_dropout",
    "nll_loss",
    "nll_loss2d",
    "nll_loss_nd",
    "outer",
    "pow",
    "tensordot",
    "unfold",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)


def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
    if not tensors:
        raise RuntimeError("Einsum inputs are empty.")
    # ONNX does not support bool for Einsum inputs.
    if symbolic_helper._is_bool(tensors[0]):
        tensors = [
            g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
            for tensor in tensors
        ]
        return g.op(
            "Cast",
            g.op("Einsum", *tensors, equation_s=equation),
            to_i=_C_onnx.TensorProtoDataType.BOOL,
        )
    else:
        return g.op("Einsum", *tensors, equation_s=equation)


@_onnx_symbolic("aten::einsum")
@symbolic_helper.parse_args("s", "v", "is")
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
    tensors = symbolic_helper._unpack_list(tensor_list)
    return _einsum_helper(g, equation, tensors)


@_onnx_symbolic("aten::outer")
@symbolic_helper.parse_args("v", "v")
def outer(g: jit_utils.GraphContext, input, other):
    # make sure to cast other to self's type
    if _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    ) != _type_utils.JitScalarType.from_value(input):
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
        )
    return _einsum_helper(g, "i,j->ij", [input, other])


def _dropout_returns_masked_input_and_mask(
    g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> tuple[torch._C.Value, torch._C.Value | None]:
    symbolic_helper.check_training_mode(train, "dropout")
    # In eval mode, dropout is a no-op. That is, if the node's
    # train param is set to False, dropout just returns its input.
    if not train:
        return input, None
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
    r, mask = g.op("Dropout", input, p, t, outputs=2)
    return r, mask


@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "b")
def dropout(g: jit_utils.GraphContext, input, p, train):
    masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
    return masked


@_onnx_symbolic("aten::native_dropout")
@symbolic_helper.parse_args("v", "f", "b")
def native_dropout(g: jit_utils.GraphContext, input, p, train):
    return _dropout_returns_masked_input_and_mask(g, input, p, train)


@_onnx_symbolic("aten::nll_loss")
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    # In the ONNX NegativeLogLikelihoodLoss specification, ignore_index is optional and has no default value.
    # Therefore we need to set the ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )

    return nllloss


@_onnx_symbolic("aten::nll_loss2d")
def nll_loss2d(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@_onnx_symbolic("aten::nll_loss_nd")
def nll_loss_nd(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@_onnx_symbolic("aten::cross_entropy_loss")
def cross_entropy_loss(
    g: jit_utils.GraphContext,
    self,
    target,
    weight,
    reduction,
    ignore_index,
    label_smoothing,
):
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
    if label_smoothing is not None and label_smoothing > 0.0:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX does not support label_smoothing", self
        )

    # In the ONNX SoftmaxCrossEntropyLoss specification, ignore_index is optional and has no default value.
    # Therefore we need to set the ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )

    return celoss


@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def binary_cross_entropy_with_logits(
    g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
):
    p = g.op("Constant", value_t=torch.tensor([1]))
    sig_x = opset9.sigmoid(g, input)
    log_sig_x = opset9.log(g, sig_x)
    sub_1_x = opset9.sub(g, p, sig_x)
    sub_1_y = opset9.sub(g, p, target)
    log_1_x = opset9.log(g, sub_1_x)
    if pos_weight is None or symbolic_helper._is_none(pos_weight):
        output = opset9.neg(
            g,
            opset9.add(
                g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
            ),
        )
    else:
        output = opset9.neg(
            g,
            opset9.add(
                g,
                opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
                opset9.mul(g, sub_1_y, log_1_x),
            ),
        )

    if weight is not None and not symbolic_helper._is_none(weight):
        output = opset9.mul(g, weight, output)

    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return g.op("ReduceSum", output, keepdims_i=0)
    else:
        return symbolic_helper._onnx_unsupported(
            "binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
            input,
        )


@_onnx_symbolic("aten::celu")
def celu(g: jit_utils.GraphContext, self, alpha):
    alpha = symbolic_helper._maybe_get_const(alpha, "f")
    # if the input is of type double cast it to float
    if (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.DOUBLE
    ):
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        out = g.op("Celu", self, alpha_f=alpha)
        return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)

    return g.op("Celu", self, alpha_f=alpha)


@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
def argmax(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")


@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
def argmin(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")


@_onnx_symbolic("aten::pow")
def pow(g: jit_utils.GraphContext, self, exponent):
    return g.op("Pow", self, exponent)


@_onnx_symbolic("aten::ge")
def ge(g: jit_utils.GraphContext, input, other):
    return g.op("GreaterOrEqual", input, other)


@_onnx_symbolic("aten::le")
def le(g: jit_utils.GraphContext, input, other):
    return g.op("LessOrEqual", input, other)


@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "v", "v")
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
    const_size = symbolic_helper._maybe_get_const(size, "i")
    const_step = symbolic_helper._maybe_get_const(step, "i")
    if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
        const_step
    ):
        return opset9.unfold(g, input, dimension, const_size, const_step)

    sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
    if sizedim is not None:
        low_start = g.op("Constant", value_t=torch.tensor(0))
        low_end = g.op("Constant", value_t=torch.tensor(sizedim))
        hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
        low_indices = g.op("Range", low_start, low_end, step)
        hi_indices = g.op("Range", size, hi_end, step)

        low_size = symbolic_helper._size_helper(
            g, low_indices, g.op("Constant", value_t=torch.tensor(0))
        )
        hi_size = symbolic_helper._size_helper(
            g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
        )

        ndim = symbolic_helper._get_tensor_rank(input)
        assert ndim is not None
        perm = list(range(0, ndim))
        perm.append(perm.pop(dimension))

        unsqueeze_list = []
        loop_condition = g.op("Constant", value_t=torch.tensor(1))
        loop_condition = g.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
        loop_len = g.op("Min", low_size, hi_size)

        loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
            g, "Loop", loop_len, loop_condition, n_blocks=1
        )

        loop_block = loop_context.block
        block_input_iter = utils._add_input_to_block(loop_block)
        # FIXME(justinchuby): cond is unused?
        cond = utils._add_input_to_block(loop_block)

        starts = loop_context.op("Gather", low_indices, block_input_iter)
        ends = loop_context.op("Gather", hi_indices, block_input_iter)
        axes = loop_context.op("Constant", value_t=torch.tensor([2]))
        starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
        ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
        stack = loop_context.op("Slice", input, starts, ends, axes)

        unsqueeze = symbolic_helper._unsqueeze_helper(
            loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
        )
        unsqueeze_list.append(unsqueeze)
        concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)

        cond_out = loop_context.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
        utils._add_output_to_block(loop_block, cond_out)
        utils._add_output_to_block(loop_block, concat)

        loop_output = loop.node().output()
        perm = [0, 1, 2, 3, 4]
        perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
        transpose = g.op("Transpose", loop_output, perm_i=perm)
        squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])

        return squeeze

    return symbolic_helper._unimplemented("Unfold", "input size not accessible")


@_onnx_symbolic("aten::tensordot")
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
    if out is not None:
        symbolic_helper._unimplemented(
            "Tensordot", "Out parameter is not supported for tensordot."
        )

    dim_count_a = symbolic_helper._get_tensor_rank(input_a)
    if dim_count_a is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
            input_a,
        )

    dim_count_b = symbolic_helper._get_tensor_rank(input_b)
    if dim_count_b is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
            input_b,
        )

    dims_a = [
        (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
        for i in range(len(dims_a))
    ]
    dims_b = [
        (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
        for i in range(len(dims_b))
    ]

    left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
    left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]

    new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
    new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)

    input_shape = g.op("Shape", new_input_a)
    left_sizes_a = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
    )
    shape_sizes = [
        left_sizes_a,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", output_a)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", new_input_b)
    left_sizes_b = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
    )
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
    )
    shape_sizes = [
        slices,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    input_shape = g.op("Shape", output_b)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))

    shape_sizes = [left_sizes_a, left_sizes_b]
    return opset9._reshape_from_tensor(g, output, shape_sizes)
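

# Editorial sketch, not part of the exporter: it mirrors, in eager PyTorch, the
# decomposition that the tensordot symbolic above emits in ONNX (permute the
# contracted dims to the back of `a` and the front of `b`, flatten both
# operands to 2-D, contract with an "ij,jk->ik" einsum, then restore the
# leftover dims) and checks it against torch.tensordot. The names and shapes
# below (_a, _b, _dims_a, _dims_b) are arbitrary illustrative values.
if __name__ == "__main__":
    _a = torch.randn(2, 3, 4)
    _b = torch.randn(3, 4, 5)
    _dims_a, _dims_b = [1, 2], [0, 1]

    _left_a = [i for i in range(_a.dim()) if i not in _dims_a]
    _left_b = [i for i in range(_b.dim()) if i not in _dims_b]

    # Contracted dims go to the back of `a` and to the front of `b`,
    # then each operand is flattened into a 2-D matrix.
    _a2 = _a.permute(_left_a + _dims_a).flatten(start_dim=len(_left_a))
    _b2 = _b.permute(_dims_b + _left_b).flatten(end_dim=len(_dims_b) - 1)

    # 2-D contraction followed by reshaping back to the leftover dims.
    _out = torch.einsum("ij,jk->ik", _a2, _b2).reshape(
        [_a.shape[d] for d in _left_a] + [_b.shape[d] for d in _left_b]
    )
    assert torch.allclose(_out, torch.tensordot(_a, _b, dims=(_dims_a, _dims_b)))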