# mypy: allow-untyped-defs
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 13
import functools

import torch
import torch._C._onnx as _C_onnx
from torch.onnx import (
    _constants,
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset11 as opset11,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import jit_utils, registration


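# Registers each symbolic function below for ONNX opset 13.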
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)


@_onnx_symbolic("aten::softmax")
@symbolic_helper.parse_args("v", "i", "none")
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
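    # From opset 13 on, ONNX Softmax computes along a single axis (matching
    # aten::softmax), so no flatten/reshape workaround is needed here.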
    softmax = g.op("Softmax", input, axis_i=dim)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        softmax = g.op(
            "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )

    return softmax


@_onnx_symbolic("aten::log_softmax")
@symbolic_helper.parse_args("v", "i", "none")
def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    return_op = g.op("LogSoftmax", input, axis_i=dim)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return_op = g.op(
            "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    return return_op


@_onnx_symbolic("aten::frobenius_norm")
@symbolic_helper.parse_args("v", "v", "i")
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
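    # With no dims given, reduce over all elements with ReduceL2; otherwise
    # compute sqrt(sum(x * x)) over the requested dims.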
    dim_val = symbolic_helper._maybe_get_const(dim, "is")
    if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0:
        return g.op("ReduceL2", self, keepdims_i=0)
    sqr = g.op("Mul", self, self)
    sumsqr = symbolic_helper._reducesum_helper(g, sqr, dim, keepdims_i=keepdim)
    return g.op("Sqrt", sumsqr)


@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
        if _outputs is None:
            return split_out
        # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
        if (
            symbolic_helper._is_packed_list(split_size_or_sizes)
            and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs
        ):
            split_sizes = [
                symbolic_helper._unsqueeze_helper(g, v, [0])
                for v in symbolic_helper._unpack_list(split_size_or_sizes)
            ]

            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            res = []
            for i in range(_outputs):
                end = g.op(
                    "Add", start, split_sizes[i]
                )  # split_sizes is a list of same length as _outputs
                res.append(g.op("Slice", self, start, end, axis))
                start = end
            return res
        return [
            g.op(
                "SequenceAt",
                split_out,
                g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
            )
            for i in range(_outputs)
        ]

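    # In opset 13, Split takes the split sizes as an optional tensor input
    # (they were an attribute in earlier opsets), so the sizes may be dynamic.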
    split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value")
    if split_val.dim() > 0:
        return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs)
    split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")

    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        if _outputs is not None:
            size = split_size * _outputs
        else:
            raise errors.SymbolicValueError(
                "Unknown dimension size not supported", self
            )
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    splits = g.op("Constant", value_t=torch.tensor(splits))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::split_with_sizes")
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    return split(g, self, split_sizes, dim, _outputs)


@_onnx_symbolic("aten::unsafe_split")
def unsafe_split(
    g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
):
    return split(g, self, split_size_or_sizes, dim, _outputs)


@_onnx_symbolic("aten::unsafe_split_with_sizes")
def unsafe_split_with_sizes(
    g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
):
    return split_with_sizes(g, self, split_sizes, dim, _outputs)


@_onnx_symbolic("aten::tensor_split")
@symbolic_helper.parse_args("v", "v", "i", "i")
def tensor_split(
    g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None
):
    axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
    axis = opset11.unsqueeze(g, axis, 0)
    const_1 = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long))

    if symbolic_helper._is_split_static(indices_or_sections, _outputs):
        split_val = symbolic_helper._node_get(indices_or_sections.node(), "value")

        if split_val.dim() > 0:
            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            res = []
            assert _outputs is not None
            for i in range(_outputs - 1):
                end = g.op(
                    "Gather",
                    indices_or_sections,
                    g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
                    axis_i=0,
                )
                res.append(g.op("Slice", self, start, end, axis))
                start = end

            end = symbolic_helper._size_helper(g, self, axis)
            res.append(g.op("Slice", self, start, end, axis))
            return res

        split_size = symbolic_helper._get_const(
            indices_or_sections, "i", "indices_or_sections"
        )

        size = symbolic_helper._get_tensor_dim_size(self, dim)
        if size is None:
            if _outputs is not None:
                size = split_size * _outputs
            else:
                raise errors.SymbolicValueError(
                    "Unknown dimension size not supported", self
                )

        min_split_size = size // split_size
        num_splits_one_extra = size % split_size

        splits = num_splits_one_extra * [min_split_size + 1]
        leftover = (split_size - num_splits_one_extra) * [min_split_size]

        splits = g.op(
            "Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long)
        )
        return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)

    if (
        symbolic_helper._is_tensor(indices_or_sections)
        and symbolic_helper._get_tensor_rank(indices_or_sections) == 1
    ):
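        # indices_or_sections is a dynamic 1-D tensor of split points: build a
        # Loop that slices between consecutive indices and collects the pieces
        # into a sequence, then append the final slice after the last index.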
        loop_len = symbolic_helper._size_helper(
            g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0))
        )
        loop_len = opset11.unsqueeze(g, loop_len, 0)
        loop_condition = g.op("Cast", const_1, to_i=_C_onnx.TensorProtoDataType.BOOL)

        # To make the first slice in the loop below work,
        # we prepend a zero so that it serves as the initial slice start.
        padding_0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
        indices_or_sections = g.op("Concat", padding_0, indices_or_sections, axis_i=0)

        final_splits = g.op("SequenceEmpty")
        # Loop inputs
        loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
            g, "Loop", loop_len, loop_condition, final_splits, outputs=1, n_blocks=1
        )

        loop_block = loop_context.block
        block_input_iter = utils._add_input_to_block(loop_block)
        cond = utils._add_input_to_block(loop_block)
        final_splits = utils._add_input_to_block(loop_block)

        start = loop_context.op(
            "Gather", indices_or_sections, block_input_iter, axis_i=0
        )
        end = loop_context.op(
            "Gather",
            indices_or_sections,
            loop_context.op("Add", block_input_iter, const_1),
            axis_i=0,
        )

        slice = loop_context.op("Slice", self, start, end, axis)
        final_splits = loop_context.op("SequenceInsert", final_splits, slice)

        # Loop outputs
        cond_out = loop_context.op("Identity", loop_condition)
        utils._add_output_to_block(loop_block, cond_out)
        utils._add_output_to_block(loop_block, final_splits)

        loop_out = loop.node().output()
        start = g.op(
            "Gather",
            indices_or_sections,
            g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long)),
            axis_i=0,
        )
        start = opset11.unsqueeze(g, start, 0)
        end = symbolic_helper._size_helper(g, self, axis)

        last_slice = g.op("Slice", self, start, end, axis)

        return g.op("SequenceInsert", loop_out, last_slice)

    else:  # scalar tensor
        dim_size = symbolic_helper._size_helper(g, self, axis)
        min_split_size = g.op("Div", dim_size, indices_or_sections)
        min_split_size_plus_1 = g.op(
            "Add",
            min_split_size,
            const_1,
        )
        num_splits_one_extra = g.op("Mod", dim_size, indices_or_sections)
        splits = g.op("Tile", min_split_size_plus_1, num_splits_one_extra)
        leftover = g.op(
            "Tile",
            min_split_size,
            g.op(
                "Sub",
                opset11.unsqueeze(g, indices_or_sections, 0),
                num_splits_one_extra,
            ),
        )

        splits = g.op("Concat", splits, leftover, axis_i=0)
        if _outputs is None:
            return g.op("SplitToSequence", self, splits, axis_i=dim)
        return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    if _outputs is None:
        return g.op(
            "SplitToSequence",
            self,
            g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
            axis_i=dim,
            keepdims_i=0,
        )

    splits = g.op("Constant", value_t=torch.tensor([1] * _outputs))
    outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
    outputs = [outputs] if _outputs == 1 else outputs
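    # Opset 13 Squeeze takes the axes as a tensor input rather than an attribute.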
    squeezed_outputs = [
        g.op("Squeeze", out, g.op("Constant", value_t=torch.tensor([dim])))
        for out in outputs
    ]
    return squeezed_outputs


@_onnx_symbolic("aten::nonzero_numpy")
# Emitted from `torch.nonzero(x, as_tuple=True)`
def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None):
    return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs)


@_onnx_symbolic("aten::where")
@symbolic_helper.parse_args("v", "v", "v", "i")
def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None):
    # Assumes that torch.where's first argument takes only Bool and Byte tensors.
    if not symbolic_helper._is_bool(condition):
        condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    if self is None:
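        # With only a condition, aten::where(condition) behaves like
        # torch.nonzero(condition, as_tuple=True), hence the nonzero + unbind.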
        condition = opset9.nonzero(g, condition)
        return symbolic_helper._unbind_helper(
            g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs
        )
    return g.op("Where", condition, self, other)


@_onnx_symbolic("aten::fake_quantize_per_channel_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
def fake_quantize_per_channel_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    axis,
    quant_min=-128,
    quant_max=127,
):
    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to be in the range (0, 127).
    #   https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise errors.SymbolicValueError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )
    # ONNX defines zero_point to be int8 or uint8
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
    if (quant_min, quant_max) == (0, 127):
        quantized = g.op(
            "Clip",
            quantized,
            opset9.unused(g),
            g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
        )
    return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)


@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
):
    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to be in the range (0, 127).
    #   https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise errors.SymbolicValueError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
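    # QuantizeLinear expects a float32 scale, so cast it when needed.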
    if (
        _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
        != _type_utils.JitScalarType.FLOAT
    ):
        scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
    if (quant_min, quant_max) == (0, 127):
        quantized = g.op(
            "Clip",
            quantized,
            opset9.unused(g),
            g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
        )
    return g.op("DequantizeLinear", quantized, scale, zero_point)


def _reduce_op_symbolic(onnx_op_name):
    def symbolic(g, self, dim=None, keepdim=None):
        self = symbolic_helper._maybe_cast_reduce_op_input(g, self)
        if dim is None:
            # all-reduce path
            return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)
        else:
            keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
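            # In opset 13 ReduceSum takes the axes as a tensor input, so `dim`
            # is passed as a graph value rather than an `axes_i` attribute.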
            return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)

    return symbolic


@_onnx_symbolic(
    "aten::sum",
    decorate=[symbolic_helper._apply_params("ReduceSum", "sum")],
)
def _reduce_with_dtype(onnx_op, name):
    symbolic = _reduce_op_symbolic(onnx_op)

    @symbolic_helper._overload_by_arg_count
    def reduce(g, *args, **kwargs):
        @symbolic_helper.parse_args("v", "none")
        def reduce_nodim(g, self, dtype):
            dtype_onnx = None
            if dtype.node().kind() == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
                self = g.op("Cast", self, to_i=dtype_onnx)
            elif dtype.node().kind() != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype", dtype)
            result = symbolic(g, self)
            if dtype_onnx is not None:
                result_dtype_onnx = _type_utils.JitScalarType.from_value(
                    result
                ).onnx_type()
                if result_dtype_onnx != dtype_onnx:
                    result = g.op("Cast", result, to_i=dtype_onnx)
            return result

        @symbolic_helper.parse_args("v", "v", "i", "none")
        def reduce_dim(g, self, dim, keepdim, dtype):
            dtype_onnx = None
            if dtype.node().kind() == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
                self = g.op("Cast", self, to_i=dtype_onnx)
            elif dtype.node().kind() != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype", dtype)
            result = symbolic(g, self, dim, keepdim)
            if dtype_onnx is not None:
                result_dtype_onnx = _type_utils.JitScalarType.from_value(
                    result
                ).onnx_type()
                if result_dtype_onnx != dtype_onnx:
                    result = g.op("Cast", result, to_i=dtype_onnx)
            return result

        return reduce_nodim, reduce_dim

    return reduce


# Ported from
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/core.py#L6097
# NOTE: Supporting aten::unflatten before opset 13 needs a helper function to adjust for ONNX op changes in Concat, Slice, ...
@_onnx_symbolic("aten::unflatten")
def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size):
    input_dim = symbolic_helper._get_tensor_rank(input)
    if input_dim is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
        )

    # dim could be negative
    input_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64))
    dim = g.op("Add", input_dim, dim)
    dim = g.op("Mod", dim, input_dim)

    input_size = g.op("Shape", input)

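    # The final shape is concat(input_size[:dim], unflattened_size, input_size[dim+1:]),
    # assembled below with Slice and Concat.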
    head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))
    head_end_idx = g.op(
        "Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
    )
    head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx)

    dim_plus_one = g.op(
        "Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
    )
    tail_start_idx = g.op(
        "Reshape",
        dim_plus_one,
        g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)),
    )
    tail_end_idx = g.op(
        "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
    )
    tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx)

    final_shape = g.op(
        "Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0
    )

    return symbolic_helper._reshape_helper(g, input, final_shape)


@_onnx_symbolic("aten::unsafe_chunk")
@symbolic_helper.parse_args("v", "i", "i", "i")
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
    if _outputs is None:
        return g.op(
            "SplitToSequence",
            self,
            g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
            axis_i=dim,
            keepdims_i=0,
        )

    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size")
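    # Chunk size is ceil(size / chunks); the last chunk holds any remainder.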
    split_size = (size + chunks - 1) // chunks
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)

    # TODO: So far we don't have a module using this method. We'll keep
    # this as a constant unless we see a request for dynamic split sizes
    # in any user's modules.
    splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::tile")
def tile(g: jit_utils.GraphContext, self, dims):
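    # torch.tile expects `dims` and the input shape to have the same length;
    # the two If blocks below left-pad whichever one is shorter with 1s.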
    self_shape = g.op("Shape", self)
    self_rank = g.op("Size", self_shape)
    dims_rank = g.op("Size", dims)
    diff = g.op("Sub", self_rank, dims_rank)
    const_zero = g.op("Constant", value_t=torch.tensor([0]))

    # 1. If dims is shorter than self.shape, pad dims with 1
    dims_shorter_than_self_shape = g.op("Greater", diff, const_zero)
    (
        if_op_greater,
        (if_context_greater, else_context_greater),
        _,
    ) = jit_utils.add_op_with_blocks(
        g, "If", dims_shorter_than_self_shape, n_blocks=2, outputs=1
    )
    const_one = if_context_greater.op("Constant", value_t=torch.LongTensor([1]))
    diff_1d_greater = if_context_greater.op("Reshape", diff, const_one)
    expand_ones_greater = if_context_greater.op("Expand", const_one, diff_1d_greater)
    dims_ = if_context_greater.op("Concat", expand_ones_greater, dims, axis_i=0)
    utils._add_output_to_block(if_context_greater.block, dims_)
    identity_dim = else_context_greater.op("Identity", dims)
    utils._add_output_to_block(else_context_greater.block, identity_dim)
    dims_final = if_op_greater.node().output()

    # 2. If dims is longer than self.shape, pad self.shape with 1
    dims_longer_than_self_shape = g.op("Less", diff, const_zero)
    (
        if_op_less,
        (if_context_less, else_context_less),
        _,
    ) = jit_utils.add_op_with_blocks(
        g, "If", dims_longer_than_self_shape, n_blocks=2, outputs=1
    )
    const_one = if_context_less.op("Constant", value_t=torch.LongTensor([1]))
    diff_1d_less = if_context_less.op(
        "Reshape",
        if_context_less.op("Abs", diff),
        const_one,
    )
    expand_ones_less = if_context_less.op("Expand", const_one, diff_1d_less)
    self_final_shape = if_context_less.op(
        "Concat", expand_ones_less, self_shape, axis_i=0
    )
    self_ = if_context_less.op("Reshape", self, self_final_shape)
    utils._add_output_to_block(if_context_less.block, self_)
    identity_self = else_context_less.op("Identity", self)
    utils._add_output_to_block(else_context_less.block, identity_self)
    self_final = if_op_less.node().output()

    dims_final = g.op("Cast", dims_final, to_i=_C_onnx.TensorProtoDataType.INT64)
    return g.op("Tile", self_final, dims_final)


@_onnx_symbolic("aten::repeat_interleave")
def repeat_interleave(
    g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None
):
    repeats_dim = symbolic_helper._get_tensor_rank(repeats)
    repeats_sizes = symbolic_helper._get_tensor_sizes(repeats)
    input_sizes = symbolic_helper._get_tensor_sizes(self)
    if repeats_dim is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.",
            self,
        )
    if repeats_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats size.",
            self,
        )
    if input_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown input size.",
            self,
        )

    final_dim = dim
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if symbolic_helper._is_none(dim):
        self = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1]))
        )
        dim = torch.tensor(0, dtype=torch.int64)
    else:
        dim = symbolic_helper._maybe_get_scalar(dim)

    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    output_sizes = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            output_sizes[idx], input_sizes[idx] = 0, -1

    # Check if all indices should be repeated the same number of times.
    if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
        return symbolic_helper._repeat_interleave_single_value_repeat_helper(
            g, self, repeats, dim
        )

    cond_dynamic_repeats = repeats_dim == 1 and repeats_sizes[0] is None
    # If input size is dynamic or repeats vector is dynamic
    if output_sizes[dim] == 0 or cond_dynamic_repeats:
        reps = symbolic_helper._size_helper(g, self, dim)
        reps = opset11.unsqueeze(g, reps, 0)

        # Check if repeats is dynamic
        # As repeats is dynamic, we use a where node as a substitute for the if statement
        # If repeats_dim == 1, expand repeats; otherwise use the original tensor
        if cond_dynamic_repeats:
            repeat_dim = symbolic_helper._size_helper(
                g, repeats, g.op("Constant", value_t=torch.LongTensor([0]))
            )
            repeat_cond = g.op(
                "Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1]))
            )
            repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats)
    # There are cases where repeats is a 1-d tensor with multiple values, but dim
    # is provided along one of the dynamic axes. A simple example would be
    # input.shape -> [1, 1, *] where * represents a dynamic axis, and dim = 2.
    # PyTorch can perform the repeat interleaving when the value of * matches
    # the number of elements in repeats, e.g. if * -> 2 the number of repeats
    # must be 2 as well.
    else:
        return opset9.repeat_interleave(g, self, repeats, final_dim)

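    # Split both the repeats and the input into per-index pieces, then expand
    # each piece inside the Loop below and concatenate the results.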
    reps_like = g.op(
        "ConstantOfShape",
        g.op("Shape", repeats),
        value_t=torch.tensor([1], dtype=torch.long),
    )
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, self, reps_like, dim)

    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor
    # Loop is of the following pattern
    # input (trip_count, cond)
    #   int trip_count = ...;
    #   bool cond = ...;
    #   for (int i=0; i < trip_count && cond; ++i) {
    #     cond = ...;
    #   }

    # Loop conditions
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    loop_len = reps

    # Create an empty sequence to store final expansions
    final_splits = g.op("SequenceEmpty")

    # Loop inputs
    loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
        g, "Loop", loop_len, loop_condition, final_splits, n_blocks=1
    )

    loop_block = loop_context.block
    block_input_iter = utils._add_input_to_block(loop_block)
    cond = utils._add_input_to_block(loop_block)
    final_splits = utils._add_input_to_block(loop_block)

    r_split = loop_context.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_context.op("SequenceAt", i_splits, block_input_iter)

    i_split = opset11.unsqueeze(loop_context, i_split, dim + 1)
    r_concat = [
        loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[: dim + 1])),
        r_split,
        loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1 :])),
    ]
    r_concat = loop_context.op("Concat", *r_concat, axis_i=0)
    i_split = opset9.expand(loop_context, i_split, r_concat, None)
    i_split = symbolic_helper._reshape_helper(
        loop_context, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes))
    )
    final_splits = loop_context.op("SequenceInsert", final_splits, i_split)

    # Loop outputs
    cond_out = loop_context.op(
        "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
    )
    utils._add_output_to_block(loop_block, cond_out)
    utils._add_output_to_block(loop_block, final_splits)

    loop_out = loop.node().output()
    loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
    return loop_out


@_onnx_symbolic("aten::diagonal")
@symbolic_helper.parse_args("v", "i", "i", "i")
def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2):
    rank = symbolic_helper._get_tensor_rank(self)
    # Replace negative indexing when rank is known
    if rank is not None:
        dim1 = dim1 if dim1 >= 0 else dim1 + rank
        dim2 = dim2 if dim2 >= 0 else dim2 + rank

    dim1_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))
    )
    dim2_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2]))
    )
    # Create appropriate mask
    mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0)
    mask = opset9.zeros(g, mask_shape, None, None, None)
    mask = g.op("EyeLike", mask, k_i=offset)
    # dim1 and dim2 are appended as the last two dimensions of the shape

    if rank is not None:
        axes = list(range(rank))
        axes.remove(dim1)
        axes.remove(dim2)
        self = g.op("Transpose", self, perm_i=axes + [dim1, dim2])
    else:
        return symbolic_helper._unimplemented("diagonal", "unknown input rank")

    # Multiply input and mask to calculate values along diagonal
    # The mask has ones at the positions where diagonal values are to be taken
    # For example:
    # [[1.1, 1.2, 1.3],   *    [[1, 0, 0]   =   [[1.1, 0, 0],
    #  [2.1, 2.2, 2.3],         [0, 1, 0]        [0, 2.2, 0],
    #  [3.1, 3.2, 3.3]]         [0, 0, 1]]       [0, 0, 3.3]]
    result = g.op("Mul", self, mask)
    result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0)

    # Calculate gather indices based on offset and dims
    # If offset is greater than zero, set offset to zero as this aids in
    # calculation of selection window
    offset_op = g.op("Constant", value_t=torch.LongTensor([offset]))
    if offset >= 0:
        diag_size = g.op(
            "Max",
            g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
        offset = 0
    else:
        diag_size = g.op(
            "Max",
            g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
    diag_size = g.op("Concat", diag_size, axis_i=0)

    # Calculate which diagonal values to select
    # For example, in cases with offsets:
    # [[0, 1.1, 0]
    #  [0, 0, 2.2]]
    # we need to select the last two columns, so we create a tensor
    # with all columns that are to be selected
    # So in this example, it is [1, 2]
    select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None)
    select_window = g.op(
        "CumSum",
        select_window_ones_fill,
        g.op("Constant", value_t=torch.LongTensor([0])),
    )
    select_window = g.op(
        "Add",
        select_window,
        g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])),
    )

    gather_shape = [
        opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis])))
        for axis in list(range(rank))[:-2]
    ]
    gather_shape.append(diag_size)
    gather_shape = g.op("Concat", *gather_shape, axis_i=0)
    gather_indices = opset9.zeros(g, gather_shape, 4, None, None)

    # The offset value may exceed the number of rows/columns, causing the diagonal
    # to overrun; in that case diag_size is zero.
    # For example, if
    #       offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows)
    #       diag_size = max(min(2, (4-9)), 0) = 0, based on the calculation above
    # Cases with diagonal overrun always result in diag_size = max(0, negative value) = 0
    # Without diagonal overrun, we select the appropriate rows/columns along which the
    # diagonal values are calculated. With diagonal overrun, we return a tensor whose
    # dimension along the overrun row/column is 0, i.e. an empty tensor.
    overrun_cond = g.op(
        "Not",
        g.op(
            "Equal",
            diag_size,
            g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)),
        ),
    )

    if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
        g, "If", overrun_cond, n_blocks=2
    )

    gather_indices_if_block = if_context.op("Add", gather_indices, select_window)
    gather_indices_if_block = symbolic_helper._unsqueeze_helper(
        if_context, gather_indices_if_block, [rank - 1]
    )
    final_non_overrun = if_context.op(
        "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
    )
    final_overrun = opset9.zeros(else_context, gather_shape, 6, None, None)
    utils._add_output_to_block(if_context.block, final_non_overrun)
    utils._add_output_to_block(else_context.block, final_overrun)
    return if_op


# Quantized ops
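# These follow a dequantize -> float op -> quantize pattern: the quantized inputs
# and weights are unpacked with dequantize_helper, the floating-point opset 9 op
# is emitted, and the result is re-quantized with the given output scale and
# zero point.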


@_onnx_symbolic("quantized::linear")
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.linear(g, input, weight, bias)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::linear_relu")
def quantized_linear_relu(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.linear(g, input, weight, bias)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv1d_relu")
def quantized_conv1d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv2d_relu")
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv3d_relu")
def quantized_conv3d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv1d")
def quantized_conv1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv2d")
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv3d")
def quantized_conv3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv_transpose1d")
def quantized_conv_transpose1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv_transpose2d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv_transpose2d")
def quantized_conv_transpose2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv_transpose2d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv_transpose3d")
def quantized_conv_transpose3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv_transpose3d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)