# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import inspect
import math
import sys
import typing
import warnings
from typing import Any, Callable, Literal, NoReturn, Sequence, TypeVar as _TypeVar
from typing_extensions import Concatenate as _Concatenate, ParamSpec as _ParamSpec

import torch
import torch._C._onnx as _C_onnx
from torch import _C

# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import _constants, _type_utils, errors, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils


if typing.TYPE_CHECKING:
    from torch.types import Number

_T = _TypeVar("_T")
_U = _TypeVar("_U")
_P = _ParamSpec("_P")

# ---------------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------------

_ValueDescriptor = Literal[
    "v",
    "i",
    "is",
    "f",
    "fs",
    "b",
    "s",
    "t",
    "none",
]


def _parse_arg(
    value,
    desc: _ValueDescriptor,
    arg_name: str | None = None,
    node_name: str | None = None,
):
    if desc == "none":
        return value
    if desc == "v" or not _is_value(value):
        return value

    node = value.node()
    if node.mustBeNone():
        return None
    if node.kind() == "onnx::Constant":
        node_val = _node_get(node, "value")
        if desc == "i":
            return int(node_val)
        elif desc == "f":
            return float(node_val)
        elif desc == "b":
            return bool(node_val)
        elif desc == "s":
            return str(node_val)
        elif desc == "t":
            return node_val
        elif desc == "is":
            return [int(v) for v in node_val]
        elif desc == "fs":
            return [float(v) for v in node_val]
        else:
            raise errors.SymbolicValueError(
                f"ONNX symbolic does not understand the Constant node '{node}' "
                f"specified with descriptor '{desc}'.",
                value,
            )
    elif node.kind() == "prim::ListConstruct":
        if desc == "is":
            for v in node.inputs():
                element_node = v.node()
                if element_node.kind() != "onnx::Constant":
                    raise errors.SymbolicValueError(
                        f"Failed to export a node '{element_node}' "
                        f"(in list node {node}) "
                        f"because it is not constant. "
                        f"Please try to make things (e.g. kernel sizes) static if possible.",
                        value,
                    )
            return [int(_node_get(v.node(), "value")) for v in value.node().inputs()]
        else:
            raise errors.SymbolicValueError(
                f"ONNX symbolic does not know how to unpack the ListConstruct node that "
                f"is not a list of integers: '{node}'",
                value,
            )

    if arg_name is None or node_name is None:
        raise errors.SymbolicValueError(
            f"Expected node type 'onnx::Constant', got '{node.kind()}'.",
            value,
        )

    raise errors.SymbolicValueError(
        "Expected node type 'onnx::Constant' "
        f"for argument '{arg_name}' of node '{node_name}', got '{node.kind()}'.",
        value,
    )
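

# Example (illustrative): the descriptor strings accepted by `_parse_arg` map
# onnx::Constant / prim::ListConstruct nodes to Python values. Assuming `value`
# wraps an onnx::Constant holding tensor(3):
#
#     _parse_arg(value, "i")  # -> 3 (int)
#     _parse_arg(value, "f")  # -> 3.0 (float)
#     _parse_arg(value, "v")  # -> the torch._C.Value itself, unchanged
#
# For a prim::ListConstruct of constant integers, "is" yields a Python list of ints.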


def _node_get(node: _C.Node, key: str):
    """Gets an attribute of a node; the return type depends on the attribute's kind."""
    assert isinstance(node, _C.Node)
    sel = node.kindOf(key)
    return getattr(node, sel)(key)


def _is_onnx_constant(value: _C.Value):
    """Whether a Value is an ONNX constant."""
    return value.node().kind() == "onnx::Constant"


def _maybe_get_const(
    value: _C.Value | torch.Tensor | Number | Sequence | None,
    descriptor: _ValueDescriptor,
):
    # NOTE: prim::Constant at this stage usually means something not compatible in ONNX,
    # otherwise it'd be converted to onnx::Constant
    # TODO(justinchuby): Replace isinstance with _is_value once we figure out mypy
    if isinstance(value, _C.Value) and _is_onnx_constant(value):
        return _parse_arg(value, descriptor)
    return value


def _maybe_get_scalar(value):
    value_t = _maybe_get_const(value, "t")
    if isinstance(value_t, torch.Tensor) and value_t.shape == ():
        return value_t
    return value


def _get_const(value, desc, arg_name):
    if not _is_constant(value):
        raise errors.SymbolicValueError(
            f"ONNX symbolic expected a constant value of the '{arg_name}' argument, "
            f"got '{value}'",
            value,
        )
    return _parse_arg(value, desc)


def _unpack_list(list_value: _C.Value) -> list[_C.Value]:
    list_node = list_value.node()
    if list_node.kind() != "prim::ListConstruct":
        raise errors.SymbolicValueError(
            f"ONNX symbolic expected node type prim::ListConstruct, "
            f"got '{list_node}'.",
            list_value,
        )
    return list(list_node.inputs())


def _unpack_tuple(tuple_value: _C.Value) -> tuple[_C.Value, ...]:
    tuple_node = tuple_value.node()
    if not _is_tuple_construct(tuple_value):
        raise errors.SymbolicValueError(
            f"ONNX symbolic expected node type 'prim::TupleConstruct', "
            f"got '{tuple_node.kind()}'.",
            tuple_value,
        )
    return tuple(tuple_node.inputs())


def _unpack_quantized_tensor(tuple_value: _C.Value) -> tuple[_C.Value, ...]:
    """Unpacks a quantized tensor into a tuple of tensor and scale/zero_point.

    Args:
        tuple_value: A tuple of tensor, scale, zero_point, and optionally axis.

    Returns:
        A tuple of tensor, scale, zero_point, and optionally axis.
    """
    tuple_node = tuple_value.node()
    # A quantized tensor is represented as tuple of the form (tensor, scale, zero_point, <axis>)
    if not _is_tuple_construct(tuple_value):
        raise errors.SymbolicValueError(
            f"ONNX symbolic expected the output of `{tuple_node}` to be a quantized "
            f"tensor. This is likely due to missing support for quantized "
            f"`{tuple_node.kind()}`. Please create an issue on {_constants.PYTORCH_GITHUB_ISSUES_URL}",
            tuple_value,
        )
    unpacked = tuple(tuple_node.inputs())
    assert len(unpacked) == 3 or len(unpacked) == 4
    return unpacked


# Check if list_value is the output of prim::ListConstruct.
# This is usually called before _unpack_list to ensure the list can be unpacked.
def _is_packed_list(list_value: Any) -> bool:
    return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct"


def parse_args(
    *arg_descriptors: _ValueDescriptor,
) -> Callable[[Callable[_Concatenate[_U, _P], _T]], Callable[_Concatenate[_U, _P], _T]]:
    """A decorator which converts args from torch._C.Value to built-in types.

    For example:

    ```
    @parse_args('v', 'i', 'fs')
    def foo(g, a, b, c):
        assert isinstance(a, torch._C.Value)
        assert isinstance(b, int)
        assert isinstance(c, list)
        assert isinstance(c[0], float)
    ```

    Args:
        arg_descriptors: list of str, where each element is
            a string that specifies the type to convert to. Valid descriptors:
            "v": no conversion, keep torch._C.Value.
            "i": int
            "is": list of int
            "f": float
            "fs": list of float
            "b": bool
            "s": str
            "t": torch.Tensor
            "none": the variable is unused
    """

    def decorator(
        fn: Callable[_Concatenate[_U, _P], _T],
    ) -> Callable[_Concatenate[_U, _P], _T]:
        fn._arg_descriptors = arg_descriptors  # type: ignore[attr-defined]

        @functools.wraps(fn)
        def wrapper(g: _U, *args: _P.args, **kwargs: _P.kwargs) -> _T:
            # some args may be optional, so the length may be smaller
            FILE_BUG_MSG = (
                "If you believe this is not due to custom symbolic implementation within your code or "
                "an external library, please file an issue at "
                "https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to report this bug."
            )
            assert len(arg_descriptors) >= len(args), (
                f"A mismatch between the number of arguments ({len(args)}) and "
                f"their descriptors ({len(arg_descriptors)}) was found at symbolic function '{fn.__name__}'. "
                f"{FILE_BUG_MSG}"
            )

            try:
                sig = inspect.signature(fn)
                arg_names = list(sig.parameters.keys())[1:]
                fn_name = fn.__name__
            except Exception:
                # FIXME(justinchuby): Avoid catching Exception.
                # Catch a more specific exception instead.
                arg_names = [None] * len(args)  # type: ignore[list-item]
                fn_name = None
            args = [
                _parse_arg(arg, arg_desc, arg_name, fn_name)  # type: ignore[method-assign]
                for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)
            ]
            # only support _outputs in kwargs
            assert len(kwargs) <= 1, (
                f"Symbolic function {fn.__name__}'s '**kwargs' can contain a single "
                f"key/value entry. "
                f"{FILE_BUG_MSG}"
            )

            if len(kwargs) == 1:
                assert "_outputs" in kwargs, (
                    f"Symbolic function {fn.__name__}'s '**kwargs' can only contain "
                    f"'_outputs' key at '**kwargs'. "
                    f"{FILE_BUG_MSG}"
                )
            return fn(g, *args, **kwargs)

        return wrapper

    return decorator
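

# Example (illustrative sketch, not a registered symbolic): a hypothetical clamp-like
# symbolic that relies on `parse_args` to unwrap its min/max arguments, where `g` is
# the jit_utils.GraphContext passed in by the exporter:
#
#     @parse_args("v", "f", "f")
#     def _example_clamp(g, self, min_val, max_val):
#         # `self` stays a torch._C.Value; `min_val`/`max_val` arrive as Python floats.
#         return g.op("Clip", self, min_f=min_val, max_f=max_val)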


def quantized_args(
    *arg_q_descriptors: bool,
    scale: float | None = None,
    zero_point: int | None = None,
    quantize_output: bool = True,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """A decorator which extends support for the quantized version of the base operator.

    Quantization is detected by examining the arguments that are annotated by
    `arg_q_descriptors`.

    If quantization is detected, the base operator symbolic function will be wrapped with
    argument de-quantization and output quantization.

    Otherwise, only the base symbolic function will be invoked.

    For example:

    ```
    @quantized_args(True, False)
    def foo(g, x, y):
        return x + y
    ```

    is equivalent to

    ```
    def q_foo(g, x, y):
        if is_quantized_tensor(x):
            x = dequantize(x)
            out = foo(g, x, y)
            return quantize(out)
        else:
            return foo(g, x, y)
    ```

    Args:
        arg_q_descriptors: A sequence of bool, where each element represents whether the
            argument is a QTensor for the quantized version of this operator. It defaults
            to False for unspecified (variable length) arguments.
        scale: Quantized output scale. If None, derive from
            the first quantized input scale.
        zero_point: Quantized output zero point. If None,
            derive from the first quantized input zero point.
        quantize_output: If True, quantize the output of the base operator. Default is True.
    """

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(g, *args, **kwargs):
            nonlocal scale
            nonlocal zero_point
            if scale is not None:
                _scale = g.op("Constant", value_t=torch.tensor(scale))
            else:
                _scale = None
            if zero_point is not None:
                _zero_point = g.op("Constant", value_t=torch.tensor(zero_point))
            else:
                _zero_point = None

            # Support variable length arguments by marking unspecified ones as non-quantized
            arg_q_descriptors_extended = arg_q_descriptors + (False,) * (
                len(args) - len(arg_q_descriptors)
            )
            descriptor_args = tuple(zip(arg_q_descriptors_extended, args))

            def _is_arg_quantized(descriptor, arg):
                return descriptor and _is_value(arg) and _is_tuple_construct(arg)

            # Run the regular symbolic function if none of the arguments is a QTensor.
            is_quantized = []
            for descriptor, arg in descriptor_args:
                # ListConstruct
                if _is_packed_list(arg):
                    for arg_input in arg.node().inputs():
                        is_quantized.append(_is_arg_quantized(descriptor, arg_input))
                else:
                    is_quantized.append(_is_arg_quantized(descriptor, arg))

            if not any(is_quantized):
                return fn(g, *args, **kwargs)

            # Dequantize arguments that are quantized
            non_quantized_args = []
            for descriptor, arg in descriptor_args:
                if _is_arg_quantized(descriptor, arg):
                    # Quantized arg is a tuple of (value, scale, zero_point)
                    dequantized_arg, arg_scale, arg_zero_point, _ = dequantize_helper(
                        g, arg
                    )
                    non_quantized_args.append(dequantized_arg)
                    # Set scale and zero_point to the first quantized input if not already set
                    if _scale is None:
                        _scale = arg_scale
                    if _zero_point is None:
                        _zero_point = arg_zero_point
                # ListConstruct
                elif _is_packed_list(arg):
                    for arg_input in arg.node().inputs():
                        if _is_arg_quantized(descriptor, arg_input):
                            # Quantized arg is a tuple of (value, scale, zero_point)
                            (
                                dequantized_arg,
                                arg_scale,
                                arg_zero_point,
                                _,
                            ) = dequantize_helper(g, arg_input)
                            # Set scale and zero_point to the first quantized input if not already set
                            if _scale is None:
                                _scale = arg_scale
                            if _zero_point is None:
                                _zero_point = arg_zero_point
                            arg_input.replaceAllUsesWith(dequantized_arg)
                    non_quantized_args.append(arg)
                else:
                    # Non-quantized arg
                    non_quantized_args.append(arg)
            # TODO(justinchuby): Only a single output is supported for now. We may want to
            # support multiple outputs in the future.
            output = fn(g, *non_quantized_args, **kwargs)

            assert _scale is not None, "Bug: Scale must be set for quantized operator"
            assert (
                _zero_point is not None
            ), "Bug: Zero point must be set for quantized operator"

            if quantize_output:
                return quantize_helper(g, output, _scale, _zero_point)
            return output

        return wrapper

    return decorator
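

# Example (illustrative): decorating a symbolic so its quantized overload is derived
# automatically. Assuming a hypothetical elementwise op where only the first argument
# may be a quantized tuple:
#
#     @quantized_args(True, False)
#     @parse_args("v", "v")
#     def _example_add(g, self, other):
#         return g.op("Add", self, other)
#
# When `self` is a (tensor, scale, zero_point[, axis]) TupleConstruct, the wrapper
# dequantizes it, runs the body above, and re-quantizes the output with the first
# quantized input's scale/zero_point (unless `scale`/`zero_point` overrides were given).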


def _scalar(x: Any) -> Number | None:
    """Convert a scalar tensor into a Python value."""
    if isinstance(x, torch.Tensor) and x.shape == ():
        return x.item()
    return None


def _if_scalar_type_as(self, tensor):
    """
    Convert self into the same type of tensor, as necessary.

    We only support implicit casting for scalars, so we never
    actually need to insert an ONNX cast operator here; just
    fix up the scalar.
    """
    if isinstance(self, _C.Value):
        return self

    scalar_type = _type_utils.JitScalarType.from_value(
        tensor, _type_utils.JitScalarType.UNDEFINED
    )
    if scalar_type != _type_utils.JitScalarType.UNDEFINED:
        ty = scalar_type.scalar_name().lower()
        return getattr(self, ty)()
    return self


def _is_none(x: Any) -> bool:
    return x is None or (x.node().mustBeNone() if isinstance(x, _C.Value) else False)


def _is_value(x: Any) -> bool:
    return isinstance(x, _C.Value)


def _is_constant(value: Any) -> bool:
    return not _is_value(value) or value.node().kind() in {
        "onnx::Constant",
        "prim::Constant",
    }


def _is_tensor(x: _C.Value) -> bool:
    return x.type().isSubtypeOf(_C.TensorType.get())


# Note: _C.JitType is not exposed to Python and cannot be checked at runtime.
def _as_list_type(jit_type: _C.JitType) -> _C.ListType | None:
    if isinstance(jit_type, _C.ListType):
        return jit_type
    return None


def _is_list(x: _C.Value) -> bool:
    return _as_list_type(x.type()) is not None


def _is_tensor_list(x: _C.Value) -> bool:
    x_type = _as_list_type(x.type())
    if x_type is None:
        return False
    return isinstance(x_type.getElementType(), _C.TensorType)


def _is_scalar_list(x: _C.Value) -> bool:
    """Checks if x is a scalar list, for example: List[float], List[int].

    Besides checking the type is ListType, we also check if the data type is
    a valid ONNX data type.
    """
    x_type = _as_list_type(x.type())
    if x_type is None:
        return False
    scalar_type = _type_utils.JitScalarType.from_value(x)
    return scalar_type.onnx_compatible()


def _is_tuple_construct(x: _C.Value) -> bool:
    return x.node().kind() == "prim::TupleConstruct"


def is_complex_value(x: _C.Value) -> bool:
    assert _is_value(x)
    return _type_utils.JitScalarType.from_value(
        x, _type_utils.JitScalarType.UNDEFINED
    ) in {
        _type_utils.JitScalarType.COMPLEX32,
        _type_utils.JitScalarType.COMPLEX64,
        _type_utils.JitScalarType.COMPLEX128,
    }


def _get_tensor_rank(x: _C.Value) -> int | None:
    if not _is_tensor(x) or x.type() is None:
        return None
    x_type = x.type()
    x_type = typing.cast(_C.TensorType, x_type)
    return x_type.dim()


def _get_tensor_sizes(x: _C.Value, allow_nonstatic: bool = True):
    if not _is_tensor(x) or x.type() is None:
        return None
    x_type = x.type()
    x_type = typing.cast(_C.TensorType, x_type)
    if allow_nonstatic:
        # Each individual symbol is returned as None.
        # e.g. [1, "a", "b"] -> [1, None, None]
        return x_type.varyingSizes()
    # Returns None if any symbol exists in the sizes.
    # e.g. [1, "a", "b"] -> None
    return x_type.sizes()
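

# Example (illustrative): the shape queries above degrade gracefully when the traced
# type is only partially known. For a value typed as Float(2, *, 3), i.e. one dynamic
# dimension:
#
#     _get_tensor_rank(x)         # -> 3
#     _get_tensor_sizes(x)        # -> [2, None, 3]  (dynamic dims become None)
#     _get_tensor_dim_size(x, 1)  # -> None
#
# Both return None outright when the value is not a tensor or carries no type at all.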


def _get_tensor_dim_size(x: _C.Value, dim: int) -> int | None:
    sizes = _get_tensor_sizes(x)
    return sizes[dim] if sizes else None


def _get_dim_for_cross(x: _C.Value, dim: int | None):
    if dim == -1:
        tensor_rank = _get_tensor_rank(x)
        assert tensor_rank is not None
        return dim + tensor_rank
    # If dim is not given, it defaults to the first dimension found with the size 3
    if dim is None:
        sizes = _get_tensor_sizes(x)
        assert sizes is not None
        for index, size in enumerate(sizes):
            if size is not None and size == 3:
                return index
    return dim


def _unimplemented(op: str, msg: str, value: _C.Value | None = None) -> None:
    # For BC reasons, the behavior for Caffe2 does not raise an exception for unimplemented operators
    if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
        _onnx_unsupported(f"{op}, {msg}", value)


def _onnx_unsupported(op_name: str, value: _C.Value | None = None) -> NoReturn:
    message = (
        f"Unsupported: ONNX export of operator {op_name}. "
        f"Please feel free to request support or submit a pull request "
        f"on PyTorch GitHub: {_constants.PYTORCH_GITHUB_ISSUES_URL}"
    )
    if isinstance(value, _C.Value):
        raise errors.SymbolicValueError(
            message,
            value,
        )
    raise errors.OnnxExporterError(message)


def _onnx_opset_unsupported(
    op_name: str,
    current_opset: int,
    supported_opset: int,
    value: _C.Value | None = None,
) -> NoReturn:
    message = (
        f"Unsupported: ONNX export of {op_name} in opset {current_opset}. "
        f"Please try opset version {supported_opset}."
    )
    if isinstance(value, _C.Value):
        raise errors.SymbolicValueError(
            message,
            value,
        )
    raise errors.OnnxExporterError(message)


def _onnx_opset_unsupported_detailed(
    op_name: str,
    current_opset: int,
    supported_opset: int,
    reason: str,
    value: _C.Value | None = None,
) -> NoReturn:
    message = (
        f"Unsupported: ONNX export of {op_name} in "
        f"opset {current_opset}. {reason}. Please try opset version {supported_opset}."
    )
    if isinstance(value, _C.Value):
        raise errors.SymbolicValueError(
            message,
            value,
        )
    raise errors.OnnxExporterError(message)


def _block_list_in_opset(name: str):
    def symbolic_fn(*args, **kwargs):
        raise errors.OnnxExporterError(
            f"ONNX export failed on {name}, which is not implemented for opset "
            f"{GLOBALS.export_onnx_opset_version}. "
            "Try exporting with other opset versions."
        )

    return symbolic_fn


def _try_get_scalar_type(*args) -> _type_utils.JitScalarType | None:
    for arg in args:
        scalar_type = _type_utils.JitScalarType.from_value(
            arg, _type_utils.JitScalarType.UNDEFINED
        )
        if scalar_type != _type_utils.JitScalarType.UNDEFINED:
            return scalar_type
    return None


def _type_promote_from_values(*args) -> _type_utils.JitScalarType:
    undef = _type_utils.JitScalarType.UNDEFINED
    jit_types = [_try_get_scalar_type(arg) for arg in args]
    if len(jit_types) == 0:
        return undef
    if len(jit_types) == 1:
        return jit_types[0]  # type: ignore[return-value]
    new_dtype = jit_types[0].dtype()  # type: ignore[union-attr]
    for t in jit_types:
        new_dtype = torch.promote_types(new_dtype, t.dtype())  # type: ignore[union-attr]
    return _type_utils.JitScalarType.from_dtype(new_dtype)


def _maybe_cast_to_type(
    g: jit_utils.GraphContext, value, jit_type: _type_utils.JitScalarType
):
    if (
        _type_utils.JitScalarType.from_value(value, _type_utils.JitScalarType.UNDEFINED)
        != jit_type
    ):
        return g.op(
            "Cast",
            value,
            to_i=jit_type.onnx_type(),
        )
    return value


def _select_helper(g: jit_utils.GraphContext, self, dim, index, apply_reshape=True):
    index_const = _maybe_get_scalar(index)
    index_dim = _get_tensor_rank(index)
    if not _is_value(index_const):
        # Index is a constant scalar. Make it a size 1 constant tensor.
        index = g.op("Constant", value_t=torch.LongTensor([index_const]))
    elif index_dim is not None and apply_reshape:
        if index_dim == 0:
            # Index is a scalar. Reshape it to a size 1 tensor.
            index = _reshape_helper(
                g, index, g.op("Constant", value_t=torch.LongTensor([1]))
            )

    index_scalar_type = _type_utils.JitScalarType.from_value(
        index, _type_utils.JitScalarType.UNDEFINED
    )
    if index_scalar_type not in {
        _type_utils.JitScalarType.INT64,
        _type_utils.JitScalarType.INT,
    }:
        index = g.op("Cast", index, to_i=_C_onnx.TensorProtoDataType.INT64)
    return g.op("Gather", self, index, axis_i=dim)


def _slice_helper(
    g: jit_utils.GraphContext,
    input,
    axes,
    starts,
    ends,
    steps=None,
):
    if g.opset <= 9:
        from torch.onnx.symbolic_opset9 import _slice as _slice9

        return _slice9(g, input, axes, starts, ends)
    else:
        from torch.onnx.symbolic_opset10 import _slice as _slice10

        return _slice10(g, input, axes, starts, ends, steps)


def _is_fp(value) -> bool:
    return _type_utils.JitScalarType.from_value(
        value, _type_utils.JitScalarType.UNDEFINED
    ) in {
        _type_utils.JitScalarType.FLOAT,
        _type_utils.JitScalarType.DOUBLE,
        _type_utils.JitScalarType.HALF,
        _type_utils.JitScalarType.BFLOAT16,
    }


def _is_bool(value) -> bool:
    return _type_utils.JitScalarType.from_value(
        value, _type_utils.JitScalarType.UNDEFINED
    ) in {_type_utils.JitScalarType.BOOL}


def _generate_wrapped_number(g: jit_utils.GraphContext, scalar):
    """Creates a wrapped number based on https://github.com/pytorch/pytorch/issues/9515.

    A Tensor is considered a "wrapped number" if it is
    auto-wrapped from a C++ or Python number type. Integer types are
    wrapped as 0-dim int64 tensors and floating-point types are
    wrapped as 0-dim double tensors.

    The input to this function is a constant value. If the data type
    is a floating point type, it is converted to a 0-dim double
    tensor; otherwise it is converted to a 0-dim tensor of its original type.
    """
    assert not isinstance(scalar, torch.Tensor)
    if isinstance(scalar, float):
        return g.op("Constant", value_t=torch.tensor(scalar, dtype=torch.double))
    return g.op("Constant", value_t=torch.tensor(scalar))


def _sort_helper(g: jit_utils.GraphContext, input, dim, decending=True, out=None):
    if out is not None:
        _unimplemented("Sort", "Out parameter is not supported")
    shape_ = g.op("Shape", input)
    dim_size_ = g.op(
        "Gather",
        shape_,
        g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)),
    )
    if g.opset <= 10:
        if not decending:
            _unimplemented("Sort", "Ascending is not supported")
        return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2)
    else:
        return g.op(
            "TopK", input, dim_size_, axis_i=dim, largest_i=decending, outputs=2
        )


def _topk_helper(
    g: jit_utils.GraphContext, input, k, dim, largest=True, sorted=False, out=None
):
    if out is not None:
        _unimplemented("TopK", "Out parameter is not supported")
    if not _is_value(k):
        k = g.op("Constant", value_t=torch.tensor([k], dtype=torch.int64))
    else:
        k = _reshape_helper(g, k, g.op("Constant", value_t=torch.tensor([1])))
        if _try_get_scalar_type(k) != _type_utils.JitScalarType.INT64:
            k = g.op("Cast", k, to_i=_C_onnx.TensorProtoDataType.INT64)
    if g.opset <= 10:
        if not largest:
            _unimplemented("TopK", "Ascending is not supported")
        return g.op("TopK", input, k, axis_i=dim, outputs=2)
    else:
        return g.op(
            "TopK", input, k, axis_i=dim, largest_i=largest, sorted_i=sorted, outputs=2
        )


def _lt_helper(g: jit_utils.GraphContext, input, other):
    if g.opset <= 8:
        from torch.onnx.symbolic_opset8 import lt as _lt8

        return _lt8(g, input, other)
    else:
        from torch.onnx.symbolic_opset9 import lt as _lt9

        return _lt9(g, input, other)


def _interpolate_warning(interpolate_mode):
    onnx_op = (
        "onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample"
    )
    warnings.warn(
        "You are trying to export the model with "
        + onnx_op
        + " for ONNX opset version "
        "" + str(GLOBALS.export_onnx_opset_version) + ". "
        "This operator might cause results to not match the expected results by PyTorch.\n"
        "ONNX's Upsample/Resize operator did not match PyTorch's interpolation until opset 11. "
        "Attributes to determine how to transform the input were added in onnx:Resize in opset 11 "
        "to support PyTorch's behavior (like coordinate_transformation_mode and nearest_mode).\n"
        "We recommend using opset 11 and above for models using this operator."
    )


def _unsqueeze_helper(g: jit_utils.GraphContext, input, axes_i):
    if _is_constant(axes_i[0]):
        if g.opset >= 13:
            axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
            return g.op("Unsqueeze", input, axes)
        return g.op("Unsqueeze", input, axes_i=axes_i)
    # Tensor type
    if g.opset < 13:
        raise errors.SymbolicValueError(
            "Opset version must be >= 13 for Unsqueeze with dynamic axes.", input
        )
    return g.op("Unsqueeze", input, axes_i[0])


def _squeeze_helper(g: jit_utils.GraphContext, input, axes_i):
    if _is_constant(axes_i[0]):
        if g.opset >= 13:
            axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
            return g.op("Squeeze", input, axes)
        return g.op("Squeeze", input, axes_i=axes_i)
    # Tensor type
    if g.opset < 13:
        raise errors.SymbolicValueError(
            "Opset version must be >= 13 for Squeeze with dynamic axes.", input
        )
    axes_t = axes_i[0]
    axes_rank = _get_tensor_rank(axes_t)
    assert axes_rank is not None
    if axes_rank > 1:
        raise errors.SymbolicValueError(
            "For Squeeze with axes as an input, the axes rank must be one in the ONNX spec.",
            input,
        )
    elif axes_rank == 0:
        # The axes is a scalar. Unsqueeze it to a rank 1 tensor.
        axes_t = _unsqueeze_helper(g, axes_t, [0])
        return g.op("Squeeze", input, axes_t)
    return g.op("Squeeze", input, axes_t)


def _reducesum_helper(
    g: jit_utils.GraphContext,
    input,
    axes_i=None,
    keepdims_i=1,
    noop_with_empty_axes_i=0,
):
    keepdims_i = _maybe_get_const(keepdims_i, "i")
    if g.opset >= 13:
        if axes_i:
            if not _is_value(axes_i):
                axes_i = g.op(
                    "Constant", value_t=torch.tensor(axes_i, dtype=torch.long)
                )
            return g.op(
                "ReduceSum",
                input,
                axes_i,
                keepdims_i=keepdims_i,
                noop_with_empty_axes_i=noop_with_empty_axes_i,
            )
        return g.op(
            "ReduceSum",
            input,
            keepdims_i=keepdims_i,
            noop_with_empty_axes_i=noop_with_empty_axes_i,
        )
    else:
        return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i)


def _interpolate_size_to_scales(g: jit_utils.GraphContext, input, output_size, dim):
    output_size = _maybe_get_const(output_size, "is")
    if _is_value(output_size):
        offset = 2
        offsets = g.op("Constant", value_t=torch.ones(offset, dtype=torch.float32))
        dividend = g.op("Cast", output_size, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        divisor = _slice_helper(
            g, g.op("Shape", input), axes=[0], ends=[sys.maxsize], starts=[offset]
        )
        divisor = g.op("Cast", divisor, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        scale_dims = g.op("Div", dividend, divisor)
        scales = g.op("Concat", offsets, scale_dims, axis_i=0)
    else:
        scales_constant = [
            1.0
            if i < 2
            else float(output_size[-(dim - i)])
            / float(input.type().sizes()[-(dim - i)])
            for i in range(0, dim)
        ]
        scales = g.op(
            "Constant", value_t=torch.tensor(scales_constant, dtype=torch.float32)
        )
    return scales


def _interpolate_get_scales_if_available(g: jit_utils.GraphContext, scales):
    available_scales = _maybe_get_const(scales[0], "fs") != -1 and not _is_none(
        scales[0]
    )

    if not available_scales:
        return None

    offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32))
    scales_list = g.op(
        "Constant", value_t=torch.tensor(_maybe_get_const(scales[0], "fs"))
    )
    scales = g.op("Concat", offsets, scales_list, axis_i=0)
    return scales
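

# Example (illustrative): the interpolate helpers in this module build the `scales`
# input of onnx::Resize as a per-dimension vector. The batch and channel dimensions
# are never resized, hence the constant `torch.ones(2, dtype=torch.float32)` prefix.
# For a 4-D NCHW input upsampled by a factor of 2 in both spatial dims, the exported
# graph computes roughly:
#
#     offsets = Constant(value=[1.0, 1.0])
#     scales  = Concat(offsets, Constant(value=[2.0, 2.0]), axis=0)  # -> [1, 1, 2, 2]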


def _get_interpolate_attributes(g: jit_utils.GraphContext, mode, args):
    if mode == "nearest":
        align_corners = None
        scales = args[0:]
    else:
        align_corners = args[0]
        scales = args[1:]
    scales = _interpolate_get_scales_if_available(g, scales)
    return scales, align_corners


def _interpolate_get_scales(g: jit_utils.GraphContext, scale_factor, dim):
    offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32))
    scale_factor_rank = _get_tensor_rank(scale_factor)
    if isinstance(scale_factor.type(), _C.ListType) or (
        scale_factor_rank is not None and scale_factor_rank > 0
    ):
        return g.op("Concat", offsets, scale_factor, axis_i=0)
    else:
        scale_factor = _unsqueeze_helper(g, scale_factor, [0])
        scale_factor = g.op(
            "Cast", scale_factor, to_i=_C_onnx.TensorProtoDataType.FLOAT
        )
        scales = [scale_factor for i in range(dim - 2)]
    scale_factor = g.op("Concat", offsets, *scales, axis_i=0)
    return scale_factor


def _interpolate_get_scales_and_mode(
    g: jit_utils.GraphContext, input, size, scale_factor, mode, align_corners
):
    mode = _maybe_get_const(mode, "s")
    if "linear" in mode:
        mode = "linear"
    if "cubic" in mode:
        mode = "cubic"
    _interpolate_warning(mode)

    align_corners = _maybe_get_const(align_corners, "b")
    if isinstance(align_corners, bool) and align_corners:
        return _unimplemented("interpolate", "align_corners == True")

    if not input.type().dim():
        return _unimplemented("interpolate", "missing input shape")
    dim = input.type().dim()

    if not _is_none(scale_factor):
        scale_factor = _interpolate_get_scales(g, scale_factor, dim)
    elif not _is_none(size):
        if not _is_packed_list(size):
            is_scalar = _maybe_get_const(size, "t").dim() == 0
            if is_scalar:
                size = _unsqueeze_helper(g, size, [0])
                size = [size for i in range(dim - 2)]
                size = g.op("Concat", *size, axis_i=0)
        scale_factor = _interpolate_size_to_scales(g, input, size, dim)
    else:
        return _unimplemented(
            "interpolate", "Both size and scales are None in __interpolate"
        )
    return scale_factor, mode


def _argmin_argmax_helper(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
    op_name: str,
):
    def op_wrapper(input, axis_i, keepdims_i):
        if g.opset >= 12:
            return g.op(
                op_name,
                input,
                axis_i=axis_i,
                keepdims_i=keepdims_i,
                select_last_index_i=False,
            )
        return g.op(op_name, input, axis_i=axis_i, keepdims_i=keepdims_i)

    if _is_none(dim):
        flattened = _reshape_helper(
            g, input, g.op("Constant", value_t=torch.tensor([-1]))
        )
        output = op_wrapper(flattened, axis_i=0, keepdims_i=False)
        if keepdim:
            input_shape = g.op("Shape", input)
            input_shape_shape = g.op("Shape", input_shape)
            new_shape = g.op(
                "ConstantOfShape",
                input_shape_shape,
                value_t=torch.tensor([1], dtype=torch.int64),
            )
            output = g.op("Reshape", output, new_shape)
        return output

    dim = _parse_arg(dim, "i")
    return op_wrapper(input, axis_i=dim, keepdims_i=keepdim)


def _interpolate_helper(name, dim, interpolate_mode):
    @quantized_args(True, False, False)
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = _get_interpolate_attributes(g, interpolate_mode, args)
        align_corners = _maybe_get_scalar(align_corners)
        coordinate_transformation_mode = (
            "asymmetric"
            if interpolate_mode == "nearest"
            else "align_corners"
            if align_corners
            else "half_pixel"
        )

        if scales is None:
            input_size = g.op("Shape", input)
            input_size_beg = _slice_helper(
                g, input_size, axes=[0], ends=[2], starts=[0]
            )
            output_size = g.op(
                "Cast", output_size, to_i=_C_onnx.TensorProtoDataType.INT64
            )
            output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)

            if g.opset >= 13:
                empty_roi = _optional_input_placeholder_tensor(g)
                empty_scales = _optional_input_placeholder_tensor(g)
            else:
                empty_roi = g.op(
                    "Constant", value_t=torch.tensor([], dtype=torch.float32)
                )
                empty_scales = g.op(
                    "Constant", value_t=torch.tensor([], dtype=torch.float32)
                )

            return g.op(
                "Resize",
                input,
                empty_roi,
                empty_scales,
                output_size,
                coordinate_transformation_mode_s=coordinate_transformation_mode,
                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                mode_s=interpolate_mode,  # nearest, linear, or cubic
                nearest_mode_s="floor",  # only valid when mode="nearest"
            )
        else:
            if g.opset >= 13:
                empty_roi = _optional_input_placeholder_tensor(g)
            else:
                empty_roi = g.op(
                    "Constant", value_t=torch.tensor([], dtype=torch.float32)
                )

            return g.op(
                "Resize",
                input,
                empty_roi,
                scales,
                coordinate_transformation_mode_s=coordinate_transformation_mode,
                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                mode_s=interpolate_mode,  # nearest, linear, or cubic
                nearest_mode_s="floor",  # only valid when mode="nearest"
            )

    return symbolic_fn


def __interpolate_helper(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
):
    mode = _maybe_get_const(mode, "s")
    if "linear" in mode:
        mode = "linear"
    if "cubic" in mode:
        mode = "cubic"
    align_corners = _maybe_get_const(align_corners, "b")
    align_corners = False if not isinstance(align_corners, bool) else align_corners
    coordinate_transformation_mode = (
        "asymmetric"
        if mode == "nearest"
        else "align_corners"
        if align_corners
        else "half_pixel"
    )

    if not _is_none(size):
        input_size = g.op("Shape", input)
        input_size = _slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])
        # In some cases size is not a packed list but a scalar.
        # We need to also verify that (_maybe_get_const(size, "t").dim() == 0),
        # but this information is not always available. Try to get the dim,
        # and if that fails assume that it is not a scalar.
        try:
            is_scalar = not _is_packed_list(size) and (
                _maybe_get_const(size, "t").dim() == 0
            )
        except AttributeError:
            is_scalar = not _is_packed_list(size)
            if not is_scalar:
                warnings.warn(
                    "Cannot verify if the output_size is a scalar "
                    "while exporting interpolate. Assuming that it is not a scalar."
                )

        if is_scalar:
            rank = _get_tensor_rank(input)
            if rank is None:
                return _unimplemented(
                    "interpolate (with a scalar output_size)",
                    "missing input shape (try giving an array of output_size values)",
                )
            size = _unsqueeze_helper(g, size, [0])
            size = [size for i in range(rank - 2)]
            size = g.op("Concat", *size, axis_i=0)
        size = g.op("Cast", size, to_i=_C_onnx.TensorProtoDataType.INT64)
        size = g.op("Concat", input_size, size, axis_i=0)

        if g.opset >= 13:
            empty_roi = _optional_input_placeholder_tensor(g)
            empty_scales = _optional_input_placeholder_tensor(g)
        else:
            empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
            empty_scales = g.op(
                "Constant", value_t=torch.tensor([], dtype=torch.float32)
            )

        return g.op(
            "Resize",
            input,
            empty_roi,
            empty_scales,
            size,
            coordinate_transformation_mode_s=coordinate_transformation_mode,
            cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
            mode_s=mode,  # nearest, linear, or cubic
            nearest_mode_s="floor",  # only valid when mode="nearest"
        )
    else:  # if not _is_none(scales)
        rank = _get_tensor_rank(input)
        if rank is None:
            return _unimplemented("interpolate (with scales)", "missing input shape")

        if g.opset >= 13:
            empty_roi = _optional_input_placeholder_tensor(g)
        else:
            empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

        scales = _interpolate_get_scales(g, scale_factor, rank)
        return g.op(
            "Resize",
            input,
            empty_roi,
            scales,
            coordinate_transformation_mode_s=coordinate_transformation_mode,
            cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
            mode_s=mode,  # nearest, linear, or cubic
            nearest_mode_s="floor",  # only valid when mode="nearest"
        )


def _unbind_helper(g: jit_utils.GraphContext, self, dim, _outputs):
    if g.opset < 11:
        from torch.onnx.symbolic_opset9 import unbind
    elif g.opset <= 12:
        from torch.onnx.symbolic_opset11 import unbind  # type: ignore[no-redef]
    else:
        from torch.onnx.symbolic_opset13 import unbind  # type: ignore[no-redef]
    return unbind(g, self, dim, _outputs)


def _scatter_helper(g: jit_utils.GraphContext, self, dim, index, src):
    if g.opset <= 10:
        from torch.onnx.symbolic_opset9 import scatter
    else:
        # for mypy, scatter was imported two lines above
        from torch.onnx.symbolic_opset11 import scatter  # type: ignore[no-redef]
    return scatter(g, self, dim, index, src)


def _repeat_interleave_split_helper(g: jit_utils.GraphContext, self, reps, dim):
    if g.opset <= 12:
        split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps)
    else:
        from torch.onnx.symbolic_opset13 import split

        repeats = g.op("Constant", value_t=torch.tensor([1] * reps))
        split_out = split(g, self, repeats, dim, _outputs=reps)
    return split_out if reps > 1 else [split_out]


def _repeat_interleave_single_value_repeat_helper(
    g: jit_utils.GraphContext, self, repeats, dim
):
    from torch.onnx.symbolic_opset9 import flatten, unsqueeze

    if not _is_tensor(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))

    const_repeats: bool = _is_constant(repeats)
    reps = _maybe_get_const(repeats, "t")

    # Convert 'repeats' to 1-d if it is 0-d.
    if _get_tensor_rank(repeats) == 0:
        repeats = g.op("Reshape", repeats, g.op("Constant", value_t=torch.tensor([1])))

    # Create a new dim of size 1, then expand it to be 'repeats' long, and finally collapse it.
    unsqueezed = unsqueeze(g, self, dim + 1)

    # repeats_per_dim is 1 for all dims except for the new unsqueezed dim, where it has value 'repeats'.
    if const_repeats:
        # 'Repeats' is a constant, 'repeats_per_dim' can be a constant.
        onehot = torch.ones(_get_tensor_rank(unsqueezed), dtype=torch.int64)  # type: ignore[arg-type]
        onehot[dim + 1] = reps
        repeats_per_dim = g.op("Constant", value_t=onehot)
    else:
        # 'Repeats' is a variable, 'repeats_per_dim' cannot be a constant.
        onehot = g.op(
            "OneHot",
            unsqueeze(g, dim + 1, 0),  # indices, must be >= 1-dimensional
            g.op(
                "Constant", value_t=torch.tensor(_get_tensor_rank(unsqueezed))
            ),  # depth
            g.op(
                "Concat", g.op("Constant", value_t=torch.tensor([1])), repeats, axis_i=0
            ),  # on/off values
        )
        repeats_per_dim = flatten(g, onehot, 0, 1)

    tiled = g.op("Tile", unsqueezed, repeats_per_dim)
    return flatten(g, tiled, dim, dim + 1)


def _arange_cast_helper(
    g: jit_utils.GraphContext, end, start=None, step=None, dtype=None
) -> tuple[
    _type_utils.JitScalarType,
    _C.Value | None,
    _C.Value | None,
    _C.Value | None,
]:
    def _is_all_integral(scalars):
        for scalar in scalars:
            scalar_type = _type_utils.JitScalarType.from_value(
                scalar, _type_utils.JitScalarType.UNDEFINED
            )
            if (
                scalar_type != _type_utils.JitScalarType.INT64
                and scalar_type != _type_utils.JitScalarType.UNDEFINED
            ):
                return False
        return True

    # This logic is based on torch.arange docs. If "dtype" is provided,
    # infer input types from dtype. If not, then check if any of start, stop,
    # or step are floating point, and infer the type from torch.get_default_dtype().
    # Otherwise, the dtype is inferred to be torch.int64.
    if dtype is None or (_is_value(dtype) and _is_none(dtype)):
        if _is_all_integral([start, end, step]):
            scalar_type = _type_utils.JitScalarType.INT64
        else:
            scalar_type = _type_utils.JitScalarType.from_dtype(
                torch.get_default_dtype()
            )
    else:
        assert isinstance(dtype, int)
        # TODO(justinchuby): Check if dtype is indeed an int.
        scalar_type = _type_utils.JitScalarType(dtype)

    start = g.op("Cast", start, to_i=scalar_type.onnx_type()) if start else None
    end = g.op("Cast", end, to_i=scalar_type.onnx_type()) if end else None
    step = g.op("Cast", step, to_i=scalar_type.onnx_type()) if step else None
    return scalar_type, end, start, step


def _arange_helper(g: jit_utils.GraphContext, *args):
    if g.opset <= 10:
        from torch.onnx.symbolic_opset9 import arange
    else:
        from torch.onnx.symbolic_opset11 import arange  # type: ignore[no-redef]
    return arange(g, *args)


def _size_helper(g: jit_utils.GraphContext, self, dim):
    full_shape = g.op("Shape", self)
    from torch.onnx.symbolic_opset9 import select

    return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim)
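

# Example (illustrative): the dtype resolution performed by `_arange_cast_helper`
# above when no explicit dtype is given. With integral start/end/step constants the
# result stays INT64; if any of them is floating point, the type falls back to
# torch.get_default_dtype() (typically torch.float32):
#
#     torch.arange(0, 5)      # exported with INT64 outputs
#     torch.arange(0.0, 5.0)  # exported with FLOAT outputs (the default dtype)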


def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index):
    # 1. reshape index => [1, ..., 1, dim, 1, ..., 1]
    # 2. expand index => [..., dim, ...], same shape as self except for dim.
    # 3. expand value as well.
    # 4. apply onnx::scatter.

    from torch.onnx.symbolic_opset9 import expand

    if g.opset <= 10:
        from torch.onnx.symbolic_opset9 import scatter
    else:
        # for mypy, scatter was imported two lines above
        from torch.onnx.symbolic_opset11 import scatter  # type: ignore[no-redef]

    if self.type().dim() is None:
        return _unimplemented("index_fill", "input rank not accessible")
    self_dim = self.type().dim()
    dim_value = _parse_arg(dim, "i")
    if dim_value < 0:
        dim_value += self_dim
    unsqueezed_index = _unsqueeze_helper(
        g, index, [i for i in range(self_dim) if i != dim_value]
    )
    expanded_index_shape = scatter(
        g, g.op("Shape", self), 0, _unsqueeze_helper(g, dim, [0]), g.op("Shape", index)
    )
    expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None)
    return expanded_index_shape, expanded_index


# By default, when any value in the 'shape' input is equal to zero
# the corresponding dimension value is copied from the input tensor dynamically.
# allowzero=1 indicates that if any value in the 'shape' input is set to zero,
# the zero value is honored, similar to NumPy.
# allowzero=1 is only supported for opset version >= 14.
def _reshape_helper(g: jit_utils.GraphContext, input, shape, allowzero=0):
    shape = _maybe_get_const(shape, "is")
    if not _is_value(shape):
        shape = g.op("Constant", value_t=torch.LongTensor(shape))
    if g.opset <= 13:
        if allowzero == 1:
            _onnx_opset_unsupported(
                "Reshape with allowzero=1", GLOBALS.export_onnx_opset_version, 14, input
            )
        return g.op("Reshape", input, shape)
    else:
        return g.op("Reshape", input, shape, allowzero_i=allowzero)


def _batchnorm_helper(
    g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var
):
    from torch.onnx.symbolic_opset9 import _var_mean

    batch_size = _get_tensor_dim_size(input, 0)
    channel_size = _get_tensor_dim_size(input, 1)

    if weight is None or _is_none(weight):
        if channel_size is None:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of batch_norm for unknown channel size.",
                input,
            )
        weight_value = torch.tensor(
            [1.0] * channel_size,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        )
        weight = g.op("Constant", value_t=weight_value)
    if bias is None or _is_none(bias):
        if channel_size is None:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of batch_norm for unknown channel size.",
                input,
            )
        bias_value = torch.tensor(
            [0.0] * channel_size,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        )
        bias = g.op("Constant", value_t=bias_value)
    # If track_running_stats is set to False, batch statistics are instead used during evaluation time
    if (
        running_mean is None
        or _is_none(running_mean)
        or running_var is None
        or _is_none(running_var)
    ):
        assert batch_size is not None and channel_size is not None
        reshape_in = _reshape_helper(
            g,
            input,
            g.op(
                "Constant",
                value_t=torch.tensor([batch_size, channel_size, -1], dtype=torch.int64),
            ),
        )
        trans_in = g.op("Transpose", reshape_in, perm_i=[0, 2, 1])
        running_var, running_mean = _var_mean(
            g,
            trans_in,
            g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)),
            False,
            False,
        )
    return weight, bias, running_mean, running_var


def _avgpool_helper(
    tuple_fn: Callable[[Any], Sequence[int]],
    padding: int | Sequence[int],
    kernel_size,
    stride,
    divisor_override,
    name,
) -> tuple[int, ...]:
    if divisor_override and divisor_override.node().kind() != "prim::Constant":
        _unimplemented(name, "divisor_override")
    return tuple(tuple_fn(padding))


def check_training_mode(op_train_mode: int, op_name: str) -> None:
    """Warns the user if the model's training mode and the export mode do not agree."""
    if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE:
        return

    if op_train_mode:
        op_mode_enum = _C_onnx.TrainingMode.TRAINING
    else:
        op_mode_enum = _C_onnx.TrainingMode.EVAL
    if op_mode_enum == GLOBALS.training_mode:
        # The modes agree. Do nothing
        return

    op_mode_text = f"train={bool(op_train_mode)}"
    # Setting the model mode could result in op_mode != GLOBALS.training_mode
    # if the model is a FuncModule. In this case we warn the user of
    # the state and export depending on op_mode.
    # This is to support use-cases of fixing certain layer weights
    # in training.
    warnings.warn(
        f"ONNX export mode is set to {GLOBALS.training_mode}, but operator '{op_name}' "
        f"is set to {op_mode_text}. Exporting with {op_mode_text}."
    )


def _flatten_helper(g: jit_utils.GraphContext, input, start_dim, end_dim, dim):
    input_size = g.op("Shape", input)
    slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim])
    slices = [slice1, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))]
    if end_dim < dim - 1:
        slice3 = _slice_helper(
            g, input_size, axes=[0], starts=[end_dim + 1], ends=[dim]
        )
        slices = [
            slice1,
            g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
            slice3,
        ]

    final_shape = g.op("Concat", *slices, axis_i=0)
    from torch.onnx.symbolic_opset9 import _reshape_from_tensor

    return _reshape_from_tensor(g, input, final_shape)


def _is_split_static(split_size_or_sizes, _outputs):
    if _outputs is None:
        return False
    if (
        _is_value(split_size_or_sizes)
        and split_size_or_sizes.node().kind() != "onnx::Constant"
    ):
        return False
    return True


def _optional_input_placeholder_tensor(g):
    n = g.op("prim::Constant")
    n.setType(_C.OptionalType.ofTensor())
    return n


def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name):
    rank = _get_tensor_rank(self)
    if rank is not None and any(
        _get_tensor_dim_size(self, i) == 0 for i in range(rank)
    ):
        # If the input tensor is empty, according to the ONNX ReduceSum definition,
        # set keepdims=1 so that the resulting tensor has the same rank as the input.
        return g.op(op_name, self, keepdims_i=1)
    return g.op(op_name, self, keepdims_i=0)
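

# Example (illustrative): the quantized-tensor convention used by `dequantize_helper`
# and `quantize_helper` below. A quantized value flows through the graph as a
# prim::TupleConstruct of (uint8/int8 tensor, scale, zero_point[, axis]);
# `dequantize_helper` lowers it to onnx::DequantizeLinear and `quantize_helper` packs
# a new tuple via onnx::QuantizeLinear, so a quantized unary op is roughly:
#
#     value, scale, zero_point, axis = dequantize_helper(g, q_input)
#     out = g.op("Relu", value)  # any floating-point computation
#     q_out = quantize_helper(g, out, scale, zero_point, axis)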


def dequantize_helper(
    g: jit_utils.GraphContext,
    qtensor: _C.Value,
    qdtype: _C_onnx.TensorProtoDataType | None = None,
) -> tuple[_C.Value, _C.Value, _C.Value, _C.Value | None]:
    """Appends to graph `g` ONNX nodes that dequantize `qtensor` into `tensor`.

    Args:
        g: Graph, the ONNX IR graph that is under construction.
        qtensor: torch._C.Value, either a tuple of (quantized_tensor, scale, zero_point)
            for per tensor quantization, or
            (quantized_tensor, scale, zero_point, axis) for per channel quantization,
            representing the quantized tensor.
        qdtype: torch.onnx.TensorProtoDataType default None, if not None, represents the
            data type of the quantized tensor. It must be either
            torch.onnx.TensorProtoDataType.UINT8 or torch.onnx.TensorProtoDataType.INT8.
    """
    unpacked_qtensors = _unpack_quantized_tensor(qtensor)
    tensor, scale, zero_point = unpacked_qtensors[:3]
    axis = unpacked_qtensors[3] if len(unpacked_qtensors) >= 4 else None
    axis_i = _get_const(axis, "i", "axis")
    input_qdtype = _type_utils.JitScalarType.from_value(tensor)
    if qdtype is None:
        if input_qdtype is not None:
            qdtype = input_qdtype.onnx_type()
        else:
            qdtype = _C_onnx.TensorProtoDataType.UINT8
    value = g.op("Cast", tensor, to_i=qdtype)
    scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    zero_point = g.op("Cast", zero_point, to_i=qdtype)

    if axis_i is not None and GLOBALS.export_onnx_opset_version < 13:
        _onnx_opset_unsupported_detailed(
            "DequantizeLinear",
            GLOBALS.export_onnx_opset_version,
            13,
            "Attribute axis is not supported.",
            qtensor,
        )

    return (
        g.op("DequantizeLinear", value, scale, zero_point, axis_i=axis_i),
        scale,
        zero_point,
        axis,
    )


def quantize_helper(
    g: jit_utils.GraphContext,
    tensor: _C.Value,
    scale: _C.Value,
    zero_point: _C.Value,
    axis: _C.Value | None = None,
) -> _C.Value:
    """Appends to graph `g` ONNX nodes that quantize `tensor` based on `scale`, `zero_point` and `axis`.

    Args:
        g: Graph, the ONNX IR graph that is under construction.
        tensor: torch._C.Value, representing the tensor to be quantized.
        scale: torch._C.Value, quantized scale.
        zero_point: torch._C.Value, quantized zero point.
        axis: Optional[torch._C.Value] default None, if None, represents per tensor quantization.
            Otherwise, represents per channel quantization, along given axis.

    Returns:
        A TupleConstruct storing information of the quantized tensor.
    """
    if (
        axis is not None
        and not _is_none(axis)
        and GLOBALS.export_onnx_opset_version < 13
    ):
        _onnx_opset_unsupported_detailed(
            "QuantizeLinear",
            GLOBALS.export_onnx_opset_version,
            13,
            "Attribute axis is not supported.",
            tensor,
        )

    assert scale is not None
    if (
        _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
        != _type_utils.JitScalarType.FLOAT
    ):
        scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)

    assert zero_point is not None
    if _type_utils.JitScalarType.from_value(
        zero_point, _type_utils.JitScalarType.UNDEFINED
    ) not in {
        _type_utils.JitScalarType.UINT8,
        _type_utils.JitScalarType.INT8,
    }:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    output = g.op(
        "QuantizeLinear",
        tensor,
        scale,
        zero_point,
        axis_i=_get_const(axis, "i", "axis"),
    )
    args = [output, scale, zero_point]
    if axis is not None and not _is_none(axis):
        args.append(axis)
    return g.op("prim::TupleConstruct", *args)


def requantize_bias_helper(
    g: jit_utils.GraphContext, bias, input_scale, weight_scale, axis=None
):
    """In PyTorch, bias is float and is quantized to int32 implicitly inside the quantized ATen op kernel.

    In ONNX we need to make the quantization explicit because operators expect all of their inputs to be quantized.
    Since int32 is not a supported output type of the ONNX operator `QuantizeLinear`, quantization is exported using
    regular operators.
    """
    bias_scale = g.op("Mul", weight_scale, input_scale)
    bias_scale_shape = g.op("Shape", bias_scale)
    bias_zero_point = g.op(
        "ConstantOfShape", bias_scale_shape, value_t=torch.tensor([0], dtype=torch.int)
    )
    q_bias = g.op(
        "Cast", g.op("Div", bias, bias_scale), to_i=_C_onnx.TensorProtoDataType.INT32
    )
    axis_args = []
    if axis is not None and not _is_none(axis):
        axis_args.append(axis)
    return g.op("prim::TupleConstruct", q_bias, bias_scale, bias_zero_point, *axis_args)


def args_have_same_dtype(args):
    assert args
    base_dtype = _type_utils.JitScalarType.from_value(args[0])
    has_same_dtype = all(
        _type_utils.JitScalarType.from_value(elem) == base_dtype for elem in args
    )
    return has_same_dtype


def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs):
    """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are a superset of ONNX in terms of data types.

    This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
    operator data types. For example, `Cast<int>(Clip<float>(Cast<float>(INPUT)))` can be used to mimic
    `Clip<int>(INPUT)` (opset version < 12).

    Args:
        g (torch._C.Graph): graph to write the ONNX representation into.
        op_name (str): operator name in ONNX.
        *args (tuple): operands to the operator.
        **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default)
            indicating the smallest opset version to trigger such casting behavior and "target_float_t"
            (optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of internal operator.

    Returns:
        Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator.
    """
    opset_before = kwargs.pop("opset_before", None)
    target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT)

    inputs = list(args)
    dtype_0 = _type_utils.JitScalarType.from_value(inputs[0])

    require_cast = not _is_fp(inputs[0]) and (
        opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
    )

    if require_cast:
        for input in inputs:
            if input.isCompleteTensor():
                input_scalar_type = _type_utils.JitScalarType.from_value(input)
                if input_scalar_type != dtype_0:
                    raise errors.SymbolicValueError(
                        f"Inputs of {op_name} must have same dtype. "
                        f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}",
                        input,
                    )
        for i, input in enumerate(inputs):
            if input.isCompleteTensor() and not _is_fp(input):
                inputs[i] = g.op(
                    "Cast",
                    input,
                    to_i=target_float_t.onnx_type(),
                )

    self = g.op(op_name, *inputs, **kwargs)

    if require_cast:
        self = g.op("Cast", self, to_i=dtype_0.onnx_type())

    return self


def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self):
    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if scalar_type != _type_utils.JitScalarType.UNDEFINED:
        # This check only covers traced modules where dtype is present
        # PyTorch reduce ops cast all other integral types to int64
        if not _is_fp(self) and scalar_type != _type_utils.JitScalarType.INT64:
            self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64)
    return self


def _apply_params(*args, **kwargs):
    """Returns a decorator that calls the decorated (higher-order) function with the given parameters."""

    def _apply(fn):
        return fn(*args, **kwargs)

    return _apply


def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self):
    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if scalar_type != _type_utils.JitScalarType.UNDEFINED:
        # This check only covers traced modules where dtype is present.
        # PyTorch reduce ops cast all other integral types to int64.
        if not _is_fp(self) and scalar_type != _type_utils.JitScalarType.INT64:
            self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64)
    return self


def _apply_params(*args, **kwargs):
    """Returns a decorator that calls the decorated (higher-order) function with the given parameters."""

    def _apply(fn):
        return fn(*args, **kwargs)

    return _apply


def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True):
    def symbolic(g, self, dim=None, keepdim=None):
        self = _maybe_cast_reduce_op_input(g, self)
        if dim is None or dim == ():
            # dim can be 0, which would make `not dim` evaluate to True, so we check
            # for None or an empty tuple explicitly instead of using `not dim`.
            # all-reduce path
            return _handle_reduce_dim_none(g, self, onnx_op_name)
        else:
            # dim-reduce path
            keepdim = _get_const(keepdim, "i", "keepdim")
            if g.opset < 18:
                desc = "is" if allow_multi_dim_support else "i"
                dim = _get_const(dim, desc, "dim")
                dim_list = dim if allow_multi_dim_support else [dim]
                return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim)
            else:
                if _is_value(dim):
                    axes = dim
                else:
                    if allow_multi_dim_support:
                        axes = g.op(
                            "Constant", value_t=torch.tensor(dim, dtype=torch.long)
                        )
                    else:
                        axes = g.op(
                            "Constant", value_t=torch.tensor([dim], dtype=torch.long)
                        )
                return g.op(onnx_op_name, self, axes, keepdims_i=keepdim)

    return symbolic


def _overload_by_arg_count(fn):
    @functools.wraps(fn)
    def wrapper(g, *args):
        overloads = fn(g, *args)
        for overload in overloads:
            arg_descriptors = overload._arg_descriptors
            if len(arg_descriptors) == len(args):
                return overload(g, *args)
        return _unimplemented(f"aten::{fn.__name__}", f"with {len(args)} arguments")

    return wrapper


def _reduce_with_dtype_helper(
    onnx_op: str, name: str, allow_multi_dim_support: bool = True
):
    symbolic = _reduce_op_symbolic_helper(
        onnx_op, allow_multi_dim_support=allow_multi_dim_support
    )

    @_overload_by_arg_count
    def reduce(g, *args, **kwargs):
        @quantized_args(True)
        @parse_args("v", "none")
        def reduce_nodim(g, self, dtype):
            dtype_onnx = None
            if dtype.node().kind() == "onnx::Constant":
                dtype = _get_const(dtype, "i", "dtype")
                dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
                self = g.op("Cast", self, to_i=dtype_onnx)
            elif dtype.node().kind() != "prim::Constant":
                return _unimplemented(name, "dtype", dtype)
            result = symbolic(g, self)
            if dtype_onnx is not None:
                result_dtype_onnx = _type_utils.JitScalarType.from_value(
                    result
                ).onnx_type()
                if result_dtype_onnx != dtype_onnx:
                    result = g.op("Cast", result, to_i=dtype_onnx)
            return result

        dim_desc = "is" if allow_multi_dim_support else "i"

        @quantized_args(True)
        @parse_args("v", dim_desc, "i", "none")  # type: ignore[arg-type]
        def reduce_dim(g, self, dim, keepdim, dtype):
            dtype_onnx = None
            if dtype.node().kind() == "onnx::Constant":
                dtype = _get_const(dtype, "i", "dtype")
                dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
                self = g.op("Cast", self, to_i=dtype_onnx)
            elif dtype.node().kind() != "prim::Constant":
                return _unimplemented(name, "dtype", dtype)
            result = symbolic(g, self, dim, keepdim)
            if dtype_onnx is not None:
                result_dtype_onnx = _type_utils.JitScalarType.from_value(
                    result
                ).onnx_type()
                if result_dtype_onnx != dtype_onnx:
                    result = g.op("Cast", result, to_i=dtype_onnx)
            return result

        return reduce_nodim, reduce_dim

    return reduce
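

# Illustrative sketch (hypothetical, not a registered symbolic): how the reduce helpers above
# are meant to be combined. `_reduce_with_dtype_helper` builds two overloads, `(self, dtype)`
# and `(self, dim, keepdim, dtype)`, and `_overload_by_arg_count` picks the matching one by
# the number of arguments the JIT passes at export time.
def _example_build_mean_symbolic():
    # Returns a callable that dispatches to either the all-reduce or the dim-reduce
    # form of ONNX ReduceMean, honoring an optional dtype argument.
    return _reduce_with_dtype_helper("ReduceMean", "mean")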
"keepdim") 1832 dim = _get_const(dim_or_y, "i", "dim") 1833 if g.opset < 18: 1834 max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim) 1835 else: 1836 axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) 1837 max = g.op("ReduceMax", self, axes, keepdims_i=keepdim) 1838 indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim) 1839 return max, indices 1840 1841 1842def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): 1843 # torch.min(input) 1844 if dim_or_y is None and keepdim is None: 1845 return g.op("ReduceMin", self, keepdims_i=0) 1846 # torch.min(input, other) 1847 if keepdim is None: 1848 return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12) 1849 # torch.min(input, dim, keepdim) 1850 else: 1851 keepdim = _get_const(keepdim, "i", "keepdim") 1852 dim = _get_const(dim_or_y, "i", "dim") 1853 if g.opset < 18: 1854 min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim) 1855 else: 1856 axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) 1857 min = g.op("ReduceMin", self, axes, keepdims_i=keepdim) 1858 indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim) 1859 return min, indices 1860 1861 1862def _numel_helper(g: jit_utils.GraphContext, self): 1863 shape = g.op("Shape", self) 1864 return g.op("ReduceProd", shape, keepdims_i=0) 1865 1866 1867@parse_args("v", "is", "i", "i") 1868def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim): 1869 if g.opset < 18: 1870 if dim is None: 1871 mean = g.op("ReduceMean", input, keepdims_i=0) 1872 t_mean = mean 1873 num_elements = _numel_helper(g, input) 1874 else: 1875 mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim) 1876 t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1) 1877 redudced_dims = g.op("Shape", input) 1878 # dim could contain one or multiple dimensions 1879 redudced_dims = g.op( 1880 "Gather", 1881 redudced_dims, 1882 g.op("Constant", value_t=torch.tensor(dim)), 1883 axis_i=0, 1884 ) 1885 num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) 1886 sub_v = g.op("Sub", input, t_mean) 1887 sqr_sub = g.op("Mul", sub_v, sub_v) 1888 keepdim_mean = 0 if dim is None else keepdim 1889 var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean) 1890 # Correct bias in calculating variance, by dividing it over (N - correction) instead on N 1891 if correction is None: 1892 correction = 1 1893 if correction != 0: 1894 num_elements = g.op( 1895 "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT 1896 ) 1897 one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float)) 1898 mul = g.op("Mul", var, num_elements) 1899 var = g.op("Div", mul, g.op("Sub", num_elements, one)) 1900 return var, mean 1901 else: 1902 axes = None 1903 if dim is None: 1904 mean = g.op("ReduceMean", input, keepdims_i=0) 1905 t_mean = mean 1906 num_elements = _numel_helper(g, input) 1907 else: 1908 axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) 1909 mean = g.op("ReduceMean", input, axes, keepdims_i=keepdim) 1910 t_mean = g.op("ReduceMean", input, axes, keepdims_i=1) 1911 redudced_dims = g.op("Shape", input) 1912 # dim could contain one or multiple dimensions 1913 redudced_dims = g.op( 1914 "Gather", 1915 redudced_dims, 1916 g.op("Constant", value_t=torch.tensor(dim)), 1917 axis_i=0, 1918 ) 1919 num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) 1920 sub_v = g.op("Sub", input, t_mean) 1921 sqr_sub = g.op("Mul", sub_v, sub_v) 1922 keepdim_mean = 0 if 


@parse_args("v", "is", "i", "i")
def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim):
    if g.opset < 18:
        if dim is None:
            mean = g.op("ReduceMean", input, keepdims_i=0)
            t_mean = mean
            num_elements = _numel_helper(g, input)
        else:
            mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim)
            t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1)
            reduced_dims = g.op("Shape", input)
            # dim could contain one or multiple dimensions
            reduced_dims = g.op(
                "Gather",
                reduced_dims,
                g.op("Constant", value_t=torch.tensor(dim)),
                axis_i=0,
            )
            num_elements = g.op("ReduceProd", reduced_dims, keepdims_i=0)
        sub_v = g.op("Sub", input, t_mean)
        sqr_sub = g.op("Mul", sub_v, sub_v)
        keepdim_mean = 0 if dim is None else keepdim
        var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean)
        # Correct the bias in the variance by dividing by (N - correction) instead of N.
        if correction is None:
            correction = 1
        if correction != 0:
            num_elements = g.op(
                "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT
            )
            one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float))
            mul = g.op("Mul", var, num_elements)
            var = g.op("Div", mul, g.op("Sub", num_elements, one))
        return var, mean
    else:
        axes = None
        if dim is None:
            mean = g.op("ReduceMean", input, keepdims_i=0)
            t_mean = mean
            num_elements = _numel_helper(g, input)
        else:
            axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
            mean = g.op("ReduceMean", input, axes, keepdims_i=keepdim)
            t_mean = g.op("ReduceMean", input, axes, keepdims_i=1)
            reduced_dims = g.op("Shape", input)
            # dim could contain one or multiple dimensions
            reduced_dims = g.op(
                "Gather",
                reduced_dims,
                g.op("Constant", value_t=torch.tensor(dim)),
                axis_i=0,
            )
            num_elements = g.op("ReduceProd", reduced_dims, keepdims_i=0)
        sub_v = g.op("Sub", input, t_mean)
        sqr_sub = g.op("Mul", sub_v, sub_v)
        keepdim_mean = 0 if dim is None else keepdim
        if axes is None:
            var = g.op("ReduceMean", sqr_sub, keepdims_i=keepdim_mean)
        else:
            var = g.op("ReduceMean", sqr_sub, axes, keepdims_i=keepdim_mean)
        # Correct the bias in the variance by dividing by (N - correction) instead of N.
        if correction is None:
            correction = 1
        if correction != 0:
            num_elements = g.op(
                "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT
            )
            one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float))
            mul = g.op("Mul", var, num_elements)
            var = g.op("Div", mul, g.op("Sub", num_elements, one))
        return var, mean
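

# Illustrative sketch (hypothetical, eager-mode analogue): the bias correction applied by
# `_var_mean_helper`. The biased variance mean((x - mean)**2) is rescaled by N / (N - correction),
# which should agree with torch.var(x, correction=correction) for an all-reduce.
def _example_corrected_var(x: torch.Tensor, correction: int = 1) -> torch.Tensor:
    n = x.numel()
    biased_var = ((x - x.mean()) ** 2).mean()
    if correction == 0:
        return biased_var
    return biased_var * n / (n - correction)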


def _embedding_bag_helper(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    if scale_grad_by_freq and GLOBALS.export_training:
        return _onnx_unsupported(
            "embedding_bag with scale_grad_by_freq for training mode"
        )
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")

    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    zero = g.op("Constant", value_t=torch.tensor([0]))

    indices_len = _unsqueeze_helper(
        g,
        _size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))),
        [0],
    )
    if not include_last_offset:
        offsets = [offsets, indices_len]
        offsets = g.op("Concat", *offsets, axis_i=0)

    # Offsets holds the starting index position of each bag. So we create a list of index slices
    # (determined by offsets) and gather those indices into indices_row. Then we use this subset
    # of indices to gather from the embedding matrix.
    # The embeddings output is a loop scan output, so we can avoid creating a sequence
    # and inserting elements into it.
    offsets_starts = _slice_helper(
        g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1]
    )
    offsets_ends = _slice_helper(
        g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1]
    )

    loop_len = _size_helper(g, offsets_ends, g.op("Constant", value_t=torch.tensor(0)))

    loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
        g, "Loop", loop_len, loop_condition, n_blocks=1
    )
    loop_block = loop_context.block

    # FIXME(justinchuby): We need to handle what happens when we call b.op on a node return
    block_input_iter = utils._add_input_to_block(loop_block)
    cond = utils._add_input_to_block(loop_block)

    indices_start = loop_context.op(
        "Gather", offsets_starts, block_input_iter, axis_i=0
    )
    indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0)
    indices_start = _unsqueeze_helper(loop_context, indices_start, [0])
    indices_end = _unsqueeze_helper(loop_context, indices_end, [0])

    indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero)
    embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0)
    if not _is_none(per_sample_weights):
        per_sample_weights_row = loop_context.op(
            "Slice", per_sample_weights, indices_start, indices_end, zero
        )
        per_sample_weights_row = _unsqueeze_helper(
            loop_context, per_sample_weights_row, [1]
        )
        embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row)
    if mode == 0:
        embeddings = _reducesum_helper(
            loop_context, embeddings, axes_i=[0], keepdims_i=0
        )
    elif mode == 1:
        if loop_context.opset < 18:
            embeddings = loop_context.op(
                "ReduceMean", embeddings, axes_i=[0], keepdims_i=0
            )
        else:
            axes = loop_context.op(
                "Constant", value_t=torch.tensor([0], dtype=torch.long)
            )
            embeddings = loop_context.op("ReduceMean", embeddings, axes, keepdims_i=0)
    else:
        if loop_context.opset < 18:
            embeddings = loop_context.op(
                "ReduceMax", embeddings, axes_i=[0], keepdims_i=0
            )
        else:
            axes = loop_context.op(
                "Constant", value_t=torch.tensor([0], dtype=torch.long)
            )
            embeddings = loop_context.op("ReduceMax", embeddings, axes, keepdims_i=0)

    cond_out = loop_context.op(
        "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
    )
    utils._add_output_to_block(loop_block, cond_out)
    utils._add_output_to_block(loop_block, embeddings)

    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return loop.node().output(), None, None, None
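

# Illustrative sketch (hypothetical, eager-mode analogue of one Loop iteration above): each bag
# gathers the rows selected by indices[start:end] and reduces them according to `mode`
# (0 = sum, 1 = mean, otherwise max), mirroring what torch.nn.functional.embedding_bag computes
# per bag when per_sample_weights is not given.
def _example_embedding_bag_row(
    embedding_matrix: torch.Tensor,
    indices: torch.Tensor,
    start: int,
    end: int,
    mode: int,
) -> torch.Tensor:
    rows = embedding_matrix[indices[start:end]]
    if mode == 0:
        return rows.sum(dim=0)
    if mode == 1:
        return rows.mean(dim=0)
    return rows.max(dim=0).values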


def _linalg_vector_norm_helper(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: float,
    dim: Sequence[int] | None,
    keepdim: bool,
    dtype: torch._C.Value,
):
    axes = None
    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html
    if _is_none(dim):
        self = _reshape_helper(g, self, [-1])
        keepdim = False
    elif g.opset >= 18:
        axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))

    if ord == math.inf:
        if g.opset < 18:
            result = g.op(
                "ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim
            )
        else:
            if axes is None:
                result = g.op("ReduceMax", g.op("Abs", self), keepdims_i=keepdim)
            else:
                result = g.op("ReduceMax", g.op("Abs", self), axes, keepdims_i=keepdim)
    elif ord == -math.inf:
        if g.opset < 18:
            result = g.op(
                "ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim
            )
        else:
            if axes is None:
                result = g.op("ReduceMin", g.op("Abs", self), keepdims_i=keepdim)
            else:
                result = g.op("ReduceMin", g.op("Abs", self), axes, keepdims_i=keepdim)
    elif ord == 0:
        if g.opset < 11:
            return _onnx_opset_unsupported_detailed(
                "linalg_vector_norm", 9, 11, "ord=0 not supported", self
            )
        else:
            if dim is None:
                self = _reshape_helper(
                    g,
                    self,
                    g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)),
                )
                keepdim = False

            cond_op = g.op(
                "Not",
                g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))),
            )
            cond_op = g.op(
                "Cast",
                cond_op,
                to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
            )
            return _reducesum_helper(g, cond_op, axes_i=dim, keepdims_i=keepdim)
    elif ord == 1:
        if g.opset < 18:
            result = _reduce_op_symbolic_helper("ReduceL1")(
                g, self, dim=dim, keepdim=keepdim
            )
        else:
            if axes is None:
                result = _reduce_op_symbolic_helper("ReduceL1")(
                    g, self, keepdim=keepdim
                )
            else:
                result = _reduce_op_symbolic_helper("ReduceL1")(
                    g, self, axes, keepdim=keepdim
                )
    elif ord == 2:
        if g.opset < 18:
            result = _reduce_op_symbolic_helper("ReduceL2")(
                g, self, dim=dim, keepdim=keepdim
            )
        else:
            if axes is None:
                result = _reduce_op_symbolic_helper("ReduceL2")(
                    g, self, keepdim=keepdim
                )
            else:
                result = _reduce_op_symbolic_helper("ReduceL2")(
                    g, self, axes, keepdim=keepdim
                )
    else:
        ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32))
        result = _reducesum_helper(
            g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim
        )
        result = g.op(
            "Pow",
            result,
            g.op(
                "Div",
                g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)),
                ord_op,
            ),
        )

    if not _is_none(dtype):
        dtype = _get_const(dtype, "i", "dtype")
        result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type())  # type: ignore[arg-type]
    return result
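

# Illustrative sketch (hypothetical, eager-mode analogue): the general-ord branch above computes
# sum(|x| ** ord) ** (1 / ord), i.e. the p-norm, which should agree with
# torch.linalg.vector_norm for finite, nonzero ord.
def _example_vector_norm(x: torch.Tensor, ord: float) -> torch.Tensor:
    return x.abs().pow(ord).sum().pow(1.0 / ord)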


# Deprecated. Internally use _type_utils.ScalarType
# TODO: remove these once we support Type's in the JIT IR and we can once again
# use the unified toType operator
cast_pytorch_to_onnx = {
    "Byte": _C_onnx.TensorProtoDataType.UINT8,
    "Char": _C_onnx.TensorProtoDataType.INT8,
    "Double": _C_onnx.TensorProtoDataType.DOUBLE,
    "Float": _C_onnx.TensorProtoDataType.FLOAT,
    "Half": _C_onnx.TensorProtoDataType.FLOAT16,
    "Int": _C_onnx.TensorProtoDataType.INT32,
    "Long": _C_onnx.TensorProtoDataType.INT64,
    "Short": _C_onnx.TensorProtoDataType.INT16,
    "Bool": _C_onnx.TensorProtoDataType.BOOL,
    "ComplexFloat": _C_onnx.TensorProtoDataType.COMPLEX64,
    "ComplexDouble": _C_onnx.TensorProtoDataType.COMPLEX128,
    "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16,
    "Undefined": _C_onnx.TensorProtoDataType.UNDEFINED,
}

# Deprecated. Internally use _type_utils.ScalarType
scalar_name_to_pytorch = {
    "uint8_t": "Byte",
    "int8_t": "Char",
    "double": "Double",
    "float": "Float",
    "half": "Half",
    "int": "Int",
    "int64_t": "Long",
    "int16_t": "Short",
    "bool": "Bool",
    "complex64": "ComplexFloat",
    "complex128": "ComplexDouble",
    "qint8": "QInt8",
    "quint8": "QUInt8",
    "qint32": "QInt32",
    "bfloat16": "BFloat16",
}


# Deprecated. Internally use _type_utils.ScalarType
# This indicates each scalar type's corresponding
# torch type. Related source:
# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h
scalar_type_to_pytorch_type = [
    torch.uint8,  # 0
    torch.int8,  # 1
    torch.short,  # 2
    torch.int,  # 3
    torch.int64,  # 4
    torch.half,  # 5
    torch.float,  # 6
    torch.double,  # 7
    torch.complex32,  # 8
    torch.complex64,  # 9
    torch.complex128,  # 10
    torch.bool,  # 11
    torch.qint8,  # 12
    torch.quint8,  # 13
    torch.qint32,  # 14
    torch.bfloat16,  # 15
]

# Deprecated. Internally use _type_utils.ScalarType
# source of truth is
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp
pytorch_name_to_type = {
    "Byte": torch.uint8,
    "Char": torch.int8,
    "Double": torch.double,
    "Float": torch.float,
    "Half": torch.half,
    "Int": torch.int,
    "Long": torch.int64,
    "Short": torch.short,
    "Bool": torch.bool,
    "ComplexFloat": torch.complex64,
    "ComplexDouble": torch.complex128,
    "QInt8": torch.qint8,
    "QUInt8": torch.quint8,
    "QInt32": torch.qint32,
    "BFloat16": torch.bfloat16,
}


# Deprecated. Internally use _type_utils.ScalarType
scalar_type_to_onnx = [
    cast_pytorch_to_onnx["Byte"],  # 0
    cast_pytorch_to_onnx["Char"],  # 1
    cast_pytorch_to_onnx["Short"],  # 2
    cast_pytorch_to_onnx["Int"],  # 3
    cast_pytorch_to_onnx["Long"],  # 4
    cast_pytorch_to_onnx["Half"],  # 5
    cast_pytorch_to_onnx["Float"],  # 6
    cast_pytorch_to_onnx["Double"],  # 7
    cast_pytorch_to_onnx["Undefined"],  # 8
    cast_pytorch_to_onnx["ComplexFloat"],  # 9
    cast_pytorch_to_onnx["ComplexDouble"],  # 10
    cast_pytorch_to_onnx["Bool"],  # 11
    cast_pytorch_to_onnx["Char"],  # 12
    cast_pytorch_to_onnx["Byte"],  # 13
    cast_pytorch_to_onnx["Int"],  # 14
    cast_pytorch_to_onnx["BFloat16"],  # 15
]

# Global set to store the list of quantized operators in the network.
# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX.
_quantized_ops: set[int] = set()