# Owner(s): ["module: onnx"]

from __future__ import annotations

import functools
import io
import itertools
import os
import unittest
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Type, Union

import numpy as np

import onnx
import onnx_test_common
import parameterized
import torchvision
from model_defs import (
    lstm_flattening_result,
    rnn_model_with_packed_sequence,
    word_language_model,
)
from pytorch_test_common import (
    BATCH_SIZE,
    RNN_BATCH_SIZE,
    RNN_HIDDEN_SIZE,
    RNN_INPUT_SIZE,
    RNN_SEQUENCE_LENGTH,
    skipDtypeChecking,
    skipIfQuantizationBackendQNNPack,
    skipIfUnsupportedMaxOpsetVersion,
    skipIfUnsupportedMinOpsetVersion,
    skipIfUnsupportedOpsetVersion,
    skipScriptTest,
    skipShapeChecking,
    skipTraceTest,
)

import torch
from torch import Tensor
from torch.nn.utils import rnn as rnn_utils
from torch.onnx import errors, verification
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoLapack


def _init_test_generalized_rcnn_transform():
    min_size = 100
    max_size = 200
    image_mean = [0.485, 0.456, 0.406]
    image_std = [0.229, 0.224, 0.225]
    transform = torchvision.models.detection.transform.GeneralizedRCNNTransform(
        min_size, max_size, image_mean, image_std
    )
    return transform


def _init_test_rpn():
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    rpn_anchor_generator = torchvision.models.detection.rpn.AnchorGenerator(
        anchor_sizes, aspect_ratios
    )
    out_channels = 256
    rpn_head = torchvision.models.detection.rpn.RPNHead(
        out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
    )
    rpn_fg_iou_thresh = 0.7
    rpn_bg_iou_thresh = 0.3
    rpn_batch_size_per_image = 256
    rpn_positive_fraction = 0.5
    rpn_pre_nms_top_n = dict(training=2000, testing=1000)
    rpn_post_nms_top_n = dict(training=2000, testing=1000)
    rpn_nms_thresh = 0.7
    rpn_score_thresh = 0.0

    rpn = torchvision.models.detection.rpn.RegionProposalNetwork(
        rpn_anchor_generator,
        rpn_head,
        rpn_fg_iou_thresh,
        rpn_bg_iou_thresh,
        rpn_batch_size_per_image,
        rpn_positive_fraction,
        rpn_pre_nms_top_n,
        rpn_post_nms_top_n,
        rpn_nms_thresh,
        score_thresh=rpn_score_thresh,
    )
    return rpn


def _construct_tensor_for_quantization_test(
    shape: Tuple[int, ...],
    offset: Optional[Union[int, float]] = None,
    max_val: Optional[Union[int, float]] = None,
) -> Tensor:
98    """Helper function to generate weights and test inputs in a deterministic way.
99
100    Due to difference in implementation details between PyTorch and ONNXRuntime, randomly generated
101    test data for quantization tests can be flaky. To help stablize the test, this helper function is
102    used to generate weights and test inputs in a deterministic way.
103
104    Args:
105        shape (Tuple[int]): Shape for tensor to construct.
106        offset (Optional[Union[int, float]]): Offset to be added to the generated tensor.
107        max_val (Optional[Union[int, float]]): If any element within tensor has a larger absolute value than
108            max_val, the tensor will be scaled by max_val / tensor.abs().max(). This step is done after
109            applying offset.
110    """
    tensor = torch.arange(np.prod(shape), dtype=torch.float).view(shape)
    if offset is not None:
        tensor = tensor + offset
    if max_val is not None and tensor.abs().max() > max_val:
        tensor = tensor * max_val / tensor.abs().max()
    return tensor
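

# Hedged example (added for illustration; not part of the original suite): a
# sanity check of the deterministic construction above. For shape=(2, 2),
# offset=1 and max_val=2, arange gives [[0, 1], [2, 3]], the offset gives
# [[1, 2], [3, 4]], and scaling by max_val / 4 yields [[0.5, 1.0], [1.5, 2.0]].
def _example_quantization_tensor_construction() -> None:
    t = _construct_tensor_for_quantization_test((2, 2), offset=1, max_val=2)
    assert torch.equal(t, torch.tensor([[0.5, 1.0], [1.5, 2.0]]))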


def _parameterized_class_attrs_and_values(
    min_opset_version: int, max_opset_version: int
):
    attrs = ("opset_version", "is_script", "keep_initializers_as_inputs")
    input_values = []
    input_values.extend(itertools.product((7, 8), (True, False), (True,)))
    # Valid opset versions are defined in torch/onnx/_constants.py.
    # Versions are intentionally set statically, to not be affected by changes elsewhere.
    if min_opset_version < 9:
        raise ValueError("min_opset_version must be >= 9")
    input_values.extend(
        itertools.product(
            range(min_opset_version, max_opset_version + 1),
            (True, False),
            (True, False),
        )
    )
    return {"attrs": attrs, "input_values": input_values}
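

# Hedged example (illustrative, not from the original file): with
# min_opset_version=9 and max_opset_version=10, the helper above yields the
# legacy combinations itertools.product((7, 8), (True, False), (True,))
# followed by itertools.product(range(9, 11), (True, False), (True, False)),
# i.e. 12 (opset_version, is_script, keep_initializers_as_inputs) tuples.
def _example_parameterized_class_values() -> None:
    values = _parameterized_class_attrs_and_values(9, 10)["input_values"]
    assert (7, True, True) in values and (10, False, False) in values
    assert len(values) == 12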


def _parametrize_rnn_args(arg_name):
    options = {
        "layers": {1: "unilayer", 3: "trilayer"},
        "bidirectional": {True: "bidirectional", False: "forward"},
        "initial_state": {True: "with_initial_state", False: "no_initial_state"},
        "packed_sequence": {
            0: "without_sequence_lengths",
            1: "with_variable_length_sequences",
            2: "with_batch_first_sequence_lengths",
        },
        "dropout": {0.2: "with_dropout", 0.0: "without_dropout"},
    }

    return {
        "arg_str": arg_name,
        "arg_values": options[arg_name].keys(),
        "name_fn": lambda val: options[arg_name][val],
    }
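

# Hedged usage sketch (assumed from the dict keys above, which mirror the
# keyword arguments of common_utils.parametrize): RNN tests are expected to
# be decorated roughly as
#
#     @common_utils.parametrize(**_parametrize_rnn_args("layers"))
#     def test_rnn(self, layers):
#         ...
#
# so that, e.g., "layers" runs with values 1 and 3, and the generated test
# names are suffixed "unilayer" and "trilayer" respectively.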


@parameterized.parameterized_class(
    **_parameterized_class_attrs_and_values(
        onnx_test_common.MIN_ONNX_OPSET_VERSION, onnx_test_common.MAX_ONNX_OPSET_VERSION
    ),
    class_name_func=onnx_test_common.parameterize_class_name,
)
@common_utils.instantiate_parametrized_tests
class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
    def test_fuse_conv_bn1d(self):
        class Fuse(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv1d(16, 33, 3, stride=2)
                self.bn = torch.nn.BatchNorm1d(33)

            def forward(self, x):
                out = self.conv(x)
                return self.bn(out)

        model = Fuse()
        x = torch.randn(20, 16, 50, requires_grad=True)
        self.run_test(model, (x,))

    def test_fuse_conv_bn2d(self):
        class Fuse(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 2, kernel_size=1, stride=2, padding=3, bias=False
                )
                self.bn = torch.nn.BatchNorm2d(2)

            def forward(self, x):
                out = self.conv(x)
                return self.bn(out)

        model = Fuse()
        x = torch.randn(2, 3, 2, 2, requires_grad=True)
        self.run_test(model, (x,))

    def test_fuse_conv_bn3d(self):
        class Fuse(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv3d(
                    3, 2, (3, 5, 2), stride=(2, 1, 1), padding=(3, 2, 0), bias=False
                )
                self.bn = torch.nn.BatchNorm3d(2)

            def forward(self, x):
                out = self.conv(x)
                return self.bn(out)

        model = Fuse()
        x = torch.randn(2, 3, 10, 50, 100, requires_grad=True)
        self.run_test(model, (x,), rtol=1e-3, atol=1e-6)

    def test_fuse_conv_in_block(self):
        class Fuse(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv1d(
                    in_channels=5,
                    out_channels=5,
                    kernel_size=3,
                    stride=1,
                    padding=2,
                    dilation=1,
                )
                self.bn = torch.nn.BatchNorm1d(5)

            def forward(self, x):
                results_available = True

                if x.sum() > -1:
                    results_available = False

                if results_available:
                    x = self.conv(x)
                    x = self.bn(x)

                return x

        model = Fuse()
        x = torch.randn(2, 5, 9, requires_grad=True)
        self.run_test(
            torch.jit.script(model),
            (x,),
            input_names=["x"],
            dynamic_axes={"x": [0, 2]},
            rtol=1e-3,
            atol=1e-6,
        )

    def test_conv_tbc(self):
        from torch.nn.modules.utils import _single

        class ConvTBC(torch.nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, padding=0):
                super().__init__()
                self.in_channels = in_channels
                self.out_channels = out_channels
                self.kernel_size = _single(kernel_size)
                self.padding = _single(padding)

                self.weight = torch.nn.Parameter(
                    Tensor(self.kernel_size[0], in_channels, out_channels)
                )
                self.bias = torch.nn.Parameter(Tensor(out_channels))
                self.reset_parameters()

            def reset_parameters(self):
                torch.nn.init.xavier_normal_(self.weight)
                torch.nn.init.zeros_(self.bias)

            def conv_tbc(self, input):
                return torch.conv_tbc(
                    input.contiguous(), self.weight, self.bias, self.padding[0]
                )

            def forward(self, input):
                return self.conv_tbc(input)

        in_channels = 3
        out_channels = 5
        kernel_size = 5
        model = ConvTBC(in_channels, out_channels, kernel_size, padding=0)
        x = torch.randn(10, 7, in_channels, requires_grad=True)
        self.run_test(model, (x,), atol=1e-5)

    def test_reshape_constant_fold(self):
        class Reshape(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                scale_1 = self.weight.reshape(1, -1, 1, 1)
                return x * scale_1

        x = torch.randn(4, 5)
        self.run_test(Reshape(), (x,), rtol=1e-3, atol=1e-5)

    def run_word_language_model(self, model_name):
        ntokens = 50
        emsize = 5
        nhid = 5
        nlayers = 5
        dropout = 0.2
        tied = False
        batchsize = 5
        if model_name == "GRU":
            model = word_language_model.RNNModelWithTensorHidden(
                model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize
            )
        elif model_name == "LSTM":
            model = word_language_model.RNNModelWithTupleHidden(
                model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize
            )
        else:
            model = word_language_model.RNNModel(
                model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize
            )
        x = torch.arange(0, ntokens).long().view(-1, batchsize)
        # Only the CPU version is supported, since the tracer does not work with GPU RNNs.
        self.run_test(model, (x, model.hidden))

    def get_image(self, rel_path: str, size: Tuple[int, int]) -> Tensor:
        from PIL import Image
        from torchvision import transforms

        data_dir = os.path.join(os.path.dirname(__file__), "assets")
        path = os.path.join(data_dir, *rel_path.split("/"))
        image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)

        return transforms.ToTensor()(image)

    def get_test_images(self) -> Tuple[List[Tensor], List[Tensor]]:
        return (
            [self.get_image("grace_hopper_517x606.jpg", (100, 320))],
            [self.get_image("rgb_pytorch.png", (250, 380))],
        )

    def test_paste_mask_in_image(self):
        masks = torch.rand(10, 1, 26, 26)
        boxes = torch.rand(10, 4)
        boxes[:, 2:] += torch.rand(10, 2)
        boxes *= 50
        o_im_s = (100, 100)
        from torchvision.models.detection.roi_heads import paste_masks_in_image

        out = paste_masks_in_image(masks, boxes, o_im_s)
        jit_trace = torch.jit.trace(
            paste_masks_in_image,
            (masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])]),
        )
        out_trace = jit_trace(
            masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])]
        )

        assert torch.all(out.eq(out_trace))

        masks2 = torch.rand(20, 1, 26, 26)
        boxes2 = torch.rand(20, 4)
        boxes2[:, 2:] += torch.rand(20, 2)
        boxes2 *= 100
        o_im_s2 = (200, 200)
        from torchvision.models.detection.roi_heads import paste_masks_in_image

        out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
        out_trace2 = jit_trace(
            masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])]
        )

        assert torch.all(out2.eq(out_trace2))

    def test_heatmaps_to_keypoints(self):
        maps = torch.rand(10, 1, 26, 26)
        rois = torch.rand(10, 4)
        from torchvision.models.detection.roi_heads import heatmaps_to_keypoints

        out = heatmaps_to_keypoints(maps, rois)
        jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
        out_trace = jit_trace(maps, rois)

        assert torch.all(out[0].eq(out_trace[0]))
        assert torch.all(out[1].eq(out_trace[1]))

        maps2 = torch.rand(20, 2, 21, 21)
        rois2 = torch.rand(20, 4)
        from torchvision.models.detection.roi_heads import heatmaps_to_keypoints

        out2 = heatmaps_to_keypoints(maps2, rois2)
        out_trace2 = jit_trace(maps2, rois2)

        assert torch.all(out2[0].eq(out_trace2[0]))
        assert torch.all(out2[1].eq(out_trace2[1]))

    def test_word_language_model_RNN_TANH(self):
        self.run_word_language_model("RNN_TANH")

    def test_word_language_model_RNN_RELU(self):
        self.run_word_language_model("RNN_RELU")

    @skipScriptTest()  # scripting prim::unchecked_cast prim::setattr
    def test_word_language_model_LSTM(self):
        self.run_word_language_model("LSTM")

    def test_word_language_model_GRU(self):
        self.run_word_language_model("GRU")

    def test_index_1d(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[0]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

    def test_index_2d_1dimslice(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[0:1, :]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

    def test_index_2d_sliceint(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[1, :]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

    def test_index_2d_neg_slice(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[0:-1, :]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_index_mask(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[torch.tensor([0, 1, 0], dtype=torch.uint8)]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[torch.tensor([0, 1, 0], dtype=torch.bool)]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_data(self):
        class Data(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x.new_zeros(x.data.size())

        x = torch.randn(3, 4)
        self.run_test(Data(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
        self.run_test(Data(), x, remained_onnx_input_idx=[])

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_mask_nd(self):
        class MyModel(torch.nn.Module):
            def forward(self, input):
                return input[input > 0]

        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), m1)

    @skipScriptTest()
    def test_dict(self):
        class MyModel(torch.nn.Module):
            def forward(self, x_in):
                x_out = {}
                x_out["test_key_out"] = torch.add(
                    x_in[list(x_in.keys())[0]],  # noqa: RUF015
                    list(x_in.keys())[0],  # noqa: RUF015
                )
                return x_out

        x = {torch.tensor(1.0): torch.randn(1, 2, 3)}
        self.run_test(MyModel(), (x,))

    @skipScriptTest()
    def test_dict_str(self):
        class MyModel(torch.nn.Module):
            def forward(self, x_in):
                x_out = {}
                x_out["test_key_out"] = torch.add(x_in["test_key_in"], 2.0)
                return x_out

        x = {"test_key_in": torch.randn(1, 2, 3)}
        self.run_test(MyModel(), (x,))

    @skipScriptTest()  # User-defined class not supported
    def test_dict_output(self):
        class DictModelOutput(OrderedDict):
            tensor_out: Tensor
            tuple_out: Optional[Tuple[Tensor]] = None
            list_out: Optional[List[Tensor]] = None

        class MyModel(torch.nn.Module):
            def forward(self, a, b, c, d):
                return DictModelOutput(
                    tensor_out=a,
                    tuple_out=(b, c),
                    list_out=[d],
                )

        a = torch.randn(2, 3)
        b = torch.randn(2, 3)
        c = torch.randn(2, 3)
        d = torch.randn(2, 3)
        self.run_test(MyModel(), (a, b, c, d))

    def test_tuple_output(self):
        class MyModel(torch.nn.Module):
            def forward(self, a, b, c, d):
                return a, (b, c), d

        a = torch.randn(2, 3)
        b = torch.randn(2, 3)
        c = torch.randn(2, 3)
        d = torch.randn(2, 3)
        self.run_test(MyModel(), (a, b, c, d))

    def test_nested_tuple_output(self):
        class MyModel(torch.nn.Module):
            def forward(self, a, b, c, d):
                return a, ((b,), (c, d))

        a = torch.randn(2, 3)
        b = torch.randn(2, 3)
        c = torch.randn(2, 3)
        d = torch.randn(2, 3)
        self.run_test(MyModel(), (a, b, c, d))

    def test_tuple_input(self):
        class TupleModel(torch.nn.Module):
            def forward(self, a: Tuple[Tensor, Tensor]):
                return a

        x = (torch.randn(3, 4), torch.randn(4, 3))
        self.run_test(TupleModel(), input_args=(x,))

    def test_tuple_primitive_input(self):
        class TupleModel(torch.nn.Module):
            def forward(self, a: Tuple[int, Tensor], b):
                return a[0], a[1] + b

        x = (3, torch.randn(4, 3))
        y = torch.randn(4, 3)
        self.run_test(TupleModel(), input_args=(x, y))

    def test_nested_tuple_input(self):
        class NestedTupleModel(torch.nn.Module):
            def forward(self, a, b: Tuple[Tensor, Tuple[Tensor, Tensor]]):
                return a + b[0] + b[1][0] + b[1][1]

        x = torch.randn(4, 5)
        y = (torch.randn(4, 5), (torch.randn(1, 5), torch.randn(4, 1)))
        self.run_test(NestedTupleModel(), input_args=(x, y))

    @skipScriptTest()  # Needs https://github.com/pytorch/rfcs/pull/21
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_mixed_optional_default_none(self):
        class Model(torch.nn.Module):
            def forward(
                self,
                x,
                y: Optional[Tensor] = None,
                z: Optional[Tensor] = None,
            ):
                if y is not None:
                    return x + y
                if z is not None:
                    return x + z
                return x

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)
        model = Model()
        # Without kwargs dict.
        self.run_test(model, (x, y, None))
        self.run_test(model, (x, None, z))
        # With kwargs dict.
        self.run_test(model, (x,), {"y": y, "z": None})
        self.run_test(model, (x,), {"y": None, "z": z})
        self.run_test(model, (x,), {"z": z})
        self.run_test(model, (x,), {"y": y})

    @skipScriptTest()  # tracing eliminates None inputs so it works differently. See _script version below.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_mixed_optional_default_tensor(self):
        class Model(torch.nn.Module):
            def forward(
                self,
                x,
                y: Optional[Tensor] = torch.ones(2, 3),
                z: Optional[Tensor] = torch.zeros(2, 3),
            ):
                if y is not None:
                    return x + y
                if z is not None:
                    return x + z
                return x

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)
        model = Model()

        self.run_test(model, (x, y, None))
        self.run_test(model, (x, None, z))

    @skipTraceTest()  # tracing is verified with a different set of inputs. See above.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_mixed_optional_default_tensor_script(self):
        class Model(torch.nn.Module):
            def forward(
                self,
                x,
                y: Optional[Tensor] = torch.ones(2, 3),
                z: Optional[Tensor] = torch.zeros(2, 3),
            ):
                if y is not None:
                    return x + y
                if z is not None:
                    return x + z
                return x

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)
        model = torch.jit.script(Model())

        self.run_test(model, (x, y, z), input_names=("x", "y", "z"))
        self.run_test(model, (x,), {"y": y, "z": z}, input_names=("x", "y", "z"))
        self.run_test(model, (x,), {"y": y}, input_names=("x", "y"))

        for example_inputs, example_kwargs in (
            ((x, y, None), {}),
            ((x, None, z), {}),
            ((x,), {"y": y, "z": None}),
            ((x,), {"y": None, "z": z}),
        ):
            with self.assertRaisesRegex(
                ValueError, "args contained 1 None's after flattening."
            ):
                self.run_test(
                    model, example_inputs, example_kwargs, input_names=("x", "y", "z")
                )

    @skipScriptTest()  # Needs https://github.com/pytorch/rfcs/pull/21
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_all_optional_default_none(self):
        class Model(torch.nn.Module):
            def forward(self, x: Optional[Tensor] = None, y: Optional[Tensor] = None):
                if x is not None:
                    return x
                if y is not None:
                    return y
                else:
                    return torch.tensor(-1.0)

        x = torch.randn(2, 3)
        model = Model()
        self.run_test(model, (x, None))
        self.run_test(
            model,
            (),
            {"x": x, "y": None},
            # y disappears in tracing.
            input_names=("x",),
        )

    @skipScriptTest()  # tracing eliminates None inputs so it works differently. See _script version below.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_all_optional_default_tensor(self):
        class Model(torch.nn.Module):
            def forward(
                self,
                x: Optional[Tensor] = torch.ones(2, 3),
                y: Optional[Tensor] = torch.zeros(2, 3),
            ):
                if x is not None:
                    return x
                elif y is not None:
                    return y
                else:
                    return torch.tensor(-1.0)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        model = Model()
        self.run_test(model, (x, None))
        self.run_test(model, (None, y))
        # Tracing means y is never used, so it is removed from the exported
        # model inputs, and we fail when trying to run ORT.
        with self.assertRaisesRegex(ValueError, "got too many positional inputs"):
            self.run_test(model, (x, y))

    @skipTraceTest()  # tracing is verified with a different set of inputs. See above.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_all_optional_default_tensor_script(self):
        class Model(torch.nn.Module):
            def forward(
                self,
                x: Optional[Tensor] = torch.ones(2, 3),
                y: Optional[Tensor] = torch.zeros(2, 3),
            ):
                if x is not None:
                    return x
                elif y is not None:
                    return y
                else:
                    return torch.tensor(-1.0)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        model = torch.jit.script(Model())

        # Optional supports None inputs
        self.run_test(model, (x,))
        # NOTE: default values are not supported in ONNX, so torch and ONNX have
        # different behavior
        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
            self.run_test(model, (), {"y": y}, input_names=["y"])

        self.run_test(model, (x, y))
        self.run_test(model, (), {"x": x, "y": y}, input_names=("x", "y"))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_logit(self):
        class Logit(torch.nn.Module):
            def __init__(self, eps):
                super().__init__()
                self.eps = eps

            def forward(self, x):
                return x.logit(self.eps)

        model = Logit(eps=1e-6)
        self.run_test(model, torch.randn(1, 3, 640, 640))

    class Atleast1d(torch.nn.Module):
        def forward(self, t, w, x, y, z):
            return torch.atleast_1d((t, w, x, y, z))

    class Atleast2d(torch.nn.Module):
        def forward(self, t, w, x, y, z):
            return torch.atleast_2d((t, w, x, y, z))

    class Atleast3d(torch.nn.Module):
        def forward(self, t, w, x, y, z):
            return torch.atleast_3d((t, w, x, y, z))

    class Atleast1dTensor(torch.nn.Module):
        def forward(self, x):
            return torch.atleast_1d(x)

    class Atleast2dTensor(torch.nn.Module):
        def forward(self, x):
            return torch.atleast_2d(x)

    class Atleast3dTensor(torch.nn.Module):
        def forward(self, x):
            return torch.atleast_3d(x)

    @skipScriptTest()  # tracing uses prim::ListUnpack to avoid onnx::SequenceConstruct
    @skipIfUnsupportedMinOpsetVersion(11)
    @common_utils.parametrize("module_class", (Atleast1d, Atleast2d, Atleast3d))
    def test_atleast_nd_list_input(self, module_class: torch.nn.Module):
        inputs = (
            torch.tensor(1.0),
            torch.randn(2),
            torch.randn(2, 3),
            torch.randn(2, 3, 4),
            torch.randn(2, 3, 4, 5),
        )
        self.run_test(module_class(), inputs)

    @skipScriptTest()  # tracing uses prim::ListUnpack to avoid onnx::SequenceConstruct
    @skipIfUnsupportedMinOpsetVersion(11)
    @common_utils.parametrize(
        "module_class", (Atleast1dTensor, Atleast2dTensor, Atleast3dTensor)
    )
    @common_utils.parametrize(
        "inputs",
        [
            torch.tensor(1.0),
            torch.randn(2),
            torch.randn(2, 3),
            torch.randn(2, 3, 4),
            torch.randn(2, 3, 4, 5),
        ],
    )
    def test_atleast_nd_single_tensor_input(
        self, module_class: torch.nn.Module, inputs: torch.Tensor
    ):
        self.run_test(module_class(), inputs)

    @skipScriptTest()  # Needs https://github.com/pytorch/rfcs/pull/21
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_mixed_optional(self):
        class Model(torch.nn.Module):
            def forward(self, x, y: Optional[Tensor]):
                if y is not None:
                    return x + y
                return x

        x = torch.randn(2, 3)
        model = Model()
        self.run_test(model, (x, None))
        self.run_test(model, (x, x))

    @skipScriptTest()  # Needs https://github.com/pytorch/rfcs/pull/21
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_tuple_of_optional(self):
        class Model(torch.nn.Module):
            def forward(self, x, y: Tuple[Optional[Tensor], Optional[Tensor]]):
                if y[0] is not None:
                    return x + y[0]
                if y[1] is not None:
                    return x + y[1]
                return x

        x = torch.randn(2, 3)
        y1 = torch.randn(2, 3)
        self.run_test(Model(), (x, (None, y1)))

    @skipScriptTest()  # tracing eliminates None inputs so it works differently. See _script version below.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_tuple_of_optional_default_tensor(self):
        class Model(torch.nn.Module):
            def forward(
                self,
                x,
                y: Tuple[Optional[Tensor], Optional[Tensor]] = (
                    torch.zeros(2, 3),
                    torch.zeros(2, 3),
                ),
            ):
                y0, y1 = y
                if y0 is not None:
                    return x + y0
                if y1 is not None:
                    return x + y1
                return x

        x = torch.randn(2, 3)
        y1 = torch.randn(2, 3)
        self.run_test(Model(), (x, (None, y1)))
    @skipTraceTest()  # tracing is verified with a different set of inputs. See above.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_tuple_of_optional_default_tensor_script(self):
        class Model(torch.nn.Module):
            def forward(
                self,
                x,
                y: Tuple[Optional[Tensor], Optional[Tensor]] = (
                    torch.zeros(2, 3),
                    torch.zeros(2, 3),
                ),
            ):
                y0, y1 = y
                if y0 is not None:
                    return x + y0
                if y1 is not None:
                    return x + y1
                return x

        x = torch.randn(2, 3)
        y0 = torch.randn(2, 3)
        y1 = torch.randn(2, 3)
        model = torch.jit.script(Model())
        with self.assertRaisesRegex(
            ValueError, "args contained 1 None's after flattening."
        ):
            self.run_test(model, (x, (None, y1)))
        self.run_test(model, (x, (y0, y1)))
        # export succeeds, but running ORT through run_test would fail because the exported model
        # has the inputs flattened into 3 inputs.
        torch.onnx.export(
            model, (x, {"y": (y0, y1)}), io.BytesIO(), opset_version=self.opset_version
        )

    def test_primitive_input_integer(self):
        class Model(torch.nn.Module):
            def forward(self, x: int, y):
                return x + y

        x = 3
        y = torch.randint(10, (2, 3, 4))
        self.run_test(Model(), (x, y))

    @skipDtypeChecking
    def test_primitive_input_floating(self):
        class Model(torch.nn.Module):
            def forward(self, x: float, y):
                return x + y

        x = 3.0
        y = torch.randn(2, 3, 4)
        self.run_test(Model(), (x, y))

    def test_primitive_input_bool(self):
        class Model(torch.nn.Module):
            def forward(self, flag: bool, x, y):
                if flag:
                    return x
                else:
                    return y

        flag = True
        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 3, 4)
        self.run_test(torch.jit.script(Model()), (flag, x, y))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_cste_script(self):
        class MyModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.zeros(x.size(0)), torch.ones(
                    (x.size(1), x.size(0)), dtype=torch.int64
                )

        x = torch.randn(3, 4)
        self.run_test(MyModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
        self.run_test(MyModel(), x, remained_onnx_input_idx=[])

    def test_scalar_tensor(self):
        class test(torch.nn.Module):
            def forward(self, input):
                return torch.scalar_tensor(input.size(0)), torch.scalar_tensor(
                    input.size(1), dtype=torch.int64
                )

        x = torch.randn(2, 3, 4)
        y = torch.randn(7, 8, 9)
        model = test()
        self.run_test(
            model,
            x,
            additional_test_inputs=[y],
            input_names=["input_1"],
            dynamic_axes={"input_1": [0, 1, 2]},
        )

    def test_tensor(self):
        class ScalarInputModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor(input.shape[1])

        x = torch.randn(3, 4)
        self.run_test(
            ScalarInputModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        self.run_test(ScalarInputModel(), x, remained_onnx_input_idx=[])

        class TensorInputModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor([input.shape[0], input.shape[1]])

        x = torch.randn(3, 4)
        self.run_test(
            TensorInputModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        self.run_test(TensorInputModel(), x, remained_onnx_input_idx=[])

        class FloatInputModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor([float(input)])

        x = torch.randn(1)
        self.run_test(FloatInputModel(), x)

        class InputWithDtypeModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor(input.shape[1], dtype=torch.long)

        x = torch.randn(3, 4)
        self.run_test(
            InputWithDtypeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        self.run_test(InputWithDtypeModel(), x, remained_onnx_input_idx=[])

        class MixedInputModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor([input.shape[0], int(input)])

        x = torch.randn(1)
        self.run_test(MixedInputModel(), x)

    def test_hardtanh(self):
        model = torch.nn.Hardtanh(-1.5, 2.5)
        x = torch.arange(-5, 5).to(dtype=torch.float32)
        self.run_test(model, x)

    def test_hardtanh_script_with_default_values(self):
        class MyModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.nn.functional.hardtanh(x)

        x = torch.arange(-5, 5).to(dtype=torch.float32)
        self.run_test(MyModel(), x)

    def test_hardswish(self):
        model = torch.nn.Hardswish()

        x = torch.rand(3, 3).to(dtype=torch.float32)
        self.run_test(model, x)

        # Testing edge cases
        x = torch.tensor(3).to(dtype=torch.float32)
        self.run_test(model, x)
        x = torch.tensor(-3).to(dtype=torch.float32)
        self.run_test(model, x)

    def test_hardswish_script(self):
        class MyModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.nn.functional.hardswish(x)

        x = torch.rand(3, 3).to(dtype=torch.float32)
        self.run_test(MyModel(), x)

    def test_hardsigmoid(self):
        model = torch.nn.Hardsigmoid()

        x = torch.rand(3, 3).to(dtype=torch.float32)
        self.run_test(model, x)

        # corner cases
        x = torch.tensor(3).to(dtype=torch.float32)
        self.run_test(model, x)
        x = torch.tensor(-3).to(dtype=torch.float32)
        self.run_test(model, x)

    def test_tanhshrink(self):
        model = torch.nn.Tanhshrink()

        x = torch.rand(3, 3).to(dtype=torch.float32)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_hardshrink(self):
        model = torch.nn.Hardshrink()

        x = torch.rand(3, 3).to(dtype=torch.float32)
        self.run_test(model, x)

        # Testing edge cases
        x = torch.tensor(0.5).to(dtype=torch.float32)
        self.run_test(model, x)
        x = torch.tensor(-0.5).to(dtype=torch.float32)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_hardshrink_dtype(self):
        x = torch.rand(3, 3).to(dtype=torch.float64)
        self.run_test(torch.nn.Hardshrink(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_softshrink(self):
        model = torch.nn.Softshrink()

        x = torch.rand(3, 3).to(dtype=torch.float32)
        self.run_test(model, x)

        # Testing edge cases
        x = torch.tensor(0.5).to(dtype=torch.float32)
        self.run_test(model, x)
        x = torch.tensor(-0.5).to(dtype=torch.float32)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_softshrink_dtype(self):
        x = torch.rand(3, 3).to(dtype=torch.float64)
        self.run_test(torch.nn.Softshrink(), x)

    def test_clamp(self):
        class ClampModel(torch.nn.Module):
            def forward(self, x):
                return x.clamp(-0.5, 0.5)

        x = torch.randn(3, 4)
        self.run_test(ClampModel(), x)

        class ClampMinModel(torch.nn.Module):
            def forward(self, x):
                return x.clamp(min=-0.5)

        x = torch.randn(3, 4)
        self.run_test(ClampMinModel(), x)

        class ClampMaxModel(torch.nn.Module):
            def forward(self, x):
                return x.clamp(max=0.5)

        x = torch.randn(3, 4)
        self.run_test(ClampMaxModel(), x)

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_clamp_dyn(self):
        class ClampMaxModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x.clamp(None, x.size(0))

        x = torch.arange(16).view(4, 4).float()
        self.run_test(ClampMaxModel(), x)

        class ClampMinModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x.clamp(x.size(0), None)

        x = torch.arange(16).view(4, 4).float()
        self.run_test(ClampMinModel(), x)

        class ClampMinMaxModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x.clamp(x.size(0), x.size(1))

        x = torch.arange(16).view(2, 8).float()
        self.run_test(ClampMinMaxModel(), x)

        class ClampTensorModel(torch.nn.Module):
            def forward(self, x, min, max):
                return x.clamp(min, max)

        x = torch.randn(3, 4)
        y = torch.randn(3, 4)
        z = torch.randn(3, 4)
        self.run_test(ClampTensorModel(), (x, y, z))

        class ClampTensorMinModel(torch.nn.Module):
            def forward(self, x, min):
                return x.clamp(min=min)

        self.run_test(ClampTensorMinModel(), (x, y))

        class ClampTensorMaxModel(torch.nn.Module):
            def forward(self, x, max):
                return x.clamp(max=max)

        self.run_test(ClampTensorMaxModel(), (x, z))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_full_trace(self):
        class FullModel(torch.nn.Module):
            def forward(self, x):
                return torch.full((3, 4), x, dtype=torch.long)

        x = torch.tensor(12)
        self.run_test(FullModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_full_script(self):
        class FullModelScripting(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.full((3, 4), x, dtype=torch.long)

        x = torch.tensor(12)
        self.run_test(FullModelScripting(), x)

    def test_fuse_addmm(self):
        class AddmmModel(torch.nn.Module):
            def forward(self, x):
                return torch.mm(x, x) + x

        x = torch.ones(3, 3)
        self.run_test(AddmmModel(), x)

    def test_maxpool(self):
        model = torch.nn.MaxPool1d(2, stride=1)
        x = torch.randn(20, 16, 50)
        self.run_test(model, x)

    def test_conv(self):
        class TraceModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2)
                self.conv2 = torch.nn.Conv2d(
                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
                )
                self.conv3 = torch.nn.Conv3d(
                    16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)
                )

            def forward(self, input1, input2, input3):
                return self.conv1(input1), self.conv2(input2), self.conv3(input3)

        x1 = torch.randn(20, 16, 50)
        x2 = torch.randn(20, 16, 50, 50)
        x3 = torch.randn(20, 16, 10, 50, 50)

        self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)

    def test_conv_str_padding(self):
        class TraceModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv1d(16, 33, 3, padding="valid")
                self.conv2 = torch.nn.Conv2d(
                    16, 33, (3, 5), stride=1, padding="valid", dilation=(3, 1)
                )
                self.conv3 = torch.nn.Conv3d(
                    16, 33, (3, 5, 2), stride=1, padding="same"
                )

            def forward(self, input1, input2, input3):
                return self.conv1(input1), self.conv2(input2), self.conv3(input3)

        x1 = torch.randn(20, 16, 50)
        x2 = torch.randn(20, 16, 50, 50)
        x3 = torch.randn(20, 16, 10, 50, 50)

        self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)

    def test_conv_shape_inference(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = torch.nn.Conv2d(
                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
                )

            def forward(self, input):
                return self.conv2(input) + 2

        x = torch.randn(20, 16, 50, 100)
        self.run_test(
            Model(), x, atol=10e-5, input_names=["x"], dynamic_axes={"x": [0]}
        )

    def test_conv_transpose(self):
        class TraceModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2)
                self.conv2 = torch.nn.ConvTranspose2d(
                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
                )
                self.conv3 = torch.nn.ConvTranspose3d(
                    16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)
                )

            def forward(self, input1, input2, input3):
                return self.conv1(input1), self.conv2(input2), self.conv3(input3)

        x1 = torch.randn(20, 16, 10)
        x2 = torch.randn(20, 16, 10, 10)
        x3 = torch.randn(20, 16, 10, 10, 10)

        self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)

    def test_numpy_T(self):
        class NumpyTranspose(torch.nn.Module):
            def forward(self, x):
                return x.T

        self.run_test(NumpyTranspose(), torch.randn(4, 7))

    # Conversion of Transpose depends on the input shape being known.
    # The following test only works when onnx shape inference is enabled.
1291    def test_transpose_infer_shape(self):
1292        class TransposeModule(torch.jit.ScriptModule):
1293            def __init__(self) -> None:
1294                super().__init__()
1295                self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)
1296
1297            @torch.jit.script_method
1298            def forward(self, x):
1299                x = self.conv(x)
1300                return x.transpose(0, 1)
1301
1302        x = torch.randn(32, 3, 64, 64)
1303        y = torch.randn(16, 3, 8, 64)
1304        self.run_test(
1305            TransposeModule(),
1306            x,
1307            input_names=["x"],
1308            dynamic_axes={"x": [0, 2]},
1309            additional_test_inputs=[y],
1310        )
1311
1312    def squeeze_model_tests(self, d, x1, x2):
1313        class Squeeze(torch.nn.Module):
1314            def __init__(self, d):
1315                super().__init__()
1316                self.d = d
1317
1318            def forward(self, x):
1319                if self.d is not None:
1320                    return torch.squeeze(x, dim=self.d)
1321                else:
1322                    return torch.squeeze(x)
1323
1324        x2 = [] if x2 is None else [x2]
1325        if len(x2) > 0:
1326            self.run_test(
1327                Squeeze(d),
1328                x1,
1329                input_names=["input"],
1330                dynamic_axes={"input": {0: "0", 1: "1", 2: "2"}},
1331                additional_test_inputs=x2,
1332            )
1333        else:
1334            self.run_test(Squeeze(d), x1)
1335
1336    def test_squeeze_without_no_op(self):
1337        x = torch.randn(2, 1, 4)
1338        self.squeeze_model_tests(1, x, None)
1339
1340    @skipIfUnsupportedMinOpsetVersion(11)
1341    def test_squeeze_dynamic(self):
1342        x_squeeze = torch.randn(2, 1, 4)
1343        x_noop = torch.randn(2, 2, 3)
1344        self.squeeze_model_tests(1, x_squeeze, x_noop)
1345
1346    def test_squeeze_neg_without_no_op(self):
1347        x = torch.randn(2, 1, 4)
1348        self.squeeze_model_tests(-2, x, None)
1349
1350    @skipIfUnsupportedMinOpsetVersion(11)
1351    def test_squeeze_neg(self):
1352        x_squeeze = torch.randn(2, 1, 4)
1353        x_noop = torch.randn(2, 2, 3)
1354        self.squeeze_model_tests(-2, x_squeeze, x_noop)
1355
1356    def test_squeeze_all_dims(self):
1357        x_squeeze = torch.randn(2, 1, 4)
1358        x_noop = torch.randn(2, 2, 3)
1359        self.squeeze_model_tests(None, x_squeeze, x_noop)
1360
1361    @skipIfUnsupportedMinOpsetVersion(11)
1362    def test_squeeze_no_op(self):
1363        x_noop = torch.randn(2, 1, 4)
1364        x_squeeze = torch.randn(2, 2, 1)
1365        self.squeeze_model_tests(2, x_noop, x_squeeze)
1366
1367    @skipIfUnsupportedMinOpsetVersion(11)
1368    def test_squeeze_runtime_dim(self):
1369        class Squeeze(torch.nn.Module):
1370            def forward(self, d1, d2):
1371                t = torch.zeros(d1[0], d2[0])
1372                return t.squeeze(0)
1373
1374        d1 = torch.tensor([1])
1375        d3 = torch.tensor([3])
1376        d4 = torch.tensor([4])
1377        self.run_test(Squeeze(), (d1, d4), additional_test_inputs=[(d3, d4)])
1378        self.run_test(Squeeze(), (d3, d4), additional_test_inputs=[(d1, d3)])
1379
1380    def test_squeeze(self):
1381        class Squeeze(torch.nn.Module):
1382            def forward(self, x):
1383                return torch.squeeze(x, dim=-2)
1384
1385        x = torch.randn(2, 1, 4)
1386        self.run_test(Squeeze(), x)
1387
1388    @skipIfUnsupportedMinOpsetVersion(13)
1389    def test_squeeze_dynamic_dim(self):
1390        class Squeeze(torch.nn.Module):
1391            def forward(self, x, dim: int):
1392                return torch.squeeze(x, dim)
1393
1394        x = torch.randn(2, 1, 4)
1395        dim = 1
1396        self.run_test(Squeeze(), (x, dim))
1397
1398    def test_unsqueeze(self):
1399        class Unsqueeze(torch.nn.Module):
1400            def forward(self, x):
1401                return torch.unsqueeze(x, dim=-2)
1402
1403        x = torch.randn(2, 3, 4)
1404        self.run_test(Unsqueeze(), x)
1405
1406    @skipIfUnsupportedMinOpsetVersion(13)
1407    def test_unsqueeze_dynamic_dim(self):
1408        class Unsqueeze(torch.nn.Module):
1409            def forward(self, x, dim: int):
1410                return torch.unsqueeze(x, dim)
1411
1412        x = torch.randn(2, 1, 4)
1413        dim = -1
1414        self.run_test(Unsqueeze(), (x, dim))
1415
1416    def test_maxpool_default_stride(self):
1417        class MaxPoolModel(torch.nn.Module):
1418            def forward(self, x):
1419                return torch.nn.functional.max_pool2d(x, 2)
1420
1421        model = MaxPoolModel()
1422        x = torch.randn(10, 20, 16, 50)
1423        self.run_test(model, x)
1424
1425    @skipIfUnsupportedMinOpsetVersion(8)
1426    def test_maxpool_adaptive(self):
1427        model = torch.nn.AdaptiveMaxPool1d((5), return_indices=False)
1428        x = torch.randn(20, 16, 50, requires_grad=True)
1429        y = torch.randn(32, 16, 50, requires_grad=True)
1430        self.run_test(
1431            model,
1432            x,
1433            input_names=["x"],
1434            dynamic_axes={"x": [0]},
1435            additional_test_inputs=[y],
1436        )
1437
1438    def test_maxpool_2d(self):
1439        model = torch.nn.MaxPool2d(5, padding=(1, 2))
1440        x = torch.randn(1, 20, 16, 50, requires_grad=True)
1441        self.run_test(model, x)
1442
1443    def test_maxpool_1d_ceil(self):
1444        model = torch.nn.MaxPool1d(3, 2, ceil_mode=True)
1445        x = torch.randn(20, 16, 50)
1446        self.run_test(model, x)
1447
1448    def test_maxpool_2d_ceil(self):
1449        model = torch.nn.MaxPool2d(3, 2, ceil_mode=True)
1450        x = torch.randn(20, 16, 50, 32)
1451        self.run_test(model, x)
1452
1453    def test_maxpool_3d_ceil(self):
1454        model = torch.nn.MaxPool3d(3, 2, ceil_mode=True)
1455        x = torch.randn(20, 16, 50, 44, 31)
1456        self.run_test(model, x)
1457
1458    @skipIfUnsupportedMinOpsetVersion(10)
1459    def test_maxpool_dynamic(self):
1460        class test(torch.nn.Module):
1461            def __init__(self, in_channels, out_channels):
1462                super().__init__()
1463                norm_layer = functools.partial(torch.nn.BatchNorm2d, eps=0.0009)
1464                self.avgpool = torch.nn.MaxPool2d((2, 2), stride=2, ceil_mode=True)
1465                self.conv = torch.nn.Conv2d(
1466                    in_channels, out_channels, kernel_size=1, stride=1, bias=False
1467                )
1468                self.norm = norm_layer(out_channels)
1469
1470            def forward(self, x):
1471                return self.norm(self.conv(self.avgpool(x)))
1472
1473        model = test(8, 16)
1474        inputs = torch.randn(2, 8, 64, 64)
1475        self.run_test(
1476            model,
1477            inputs,
1478            input_names=["input_0"],
1479            dynamic_axes={"input_0": {3: "x", 2: "y"}, "output_0": {3: "x", 2: "y"}},
1480            output_names=["output_0"],
1481        )
1482
1483    # TODO: Enable maxpool-ceil family after ONNX 1.15.1+ is bumped
1484    @skipIfUnsupportedMaxOpsetVersion(9)
1485    def test_maxpool_1d_ceil_corner(self):
1486        model = torch.nn.MaxPool1d(
1487            kernel_size=1, dilation=1, stride=2, ceil_mode=True, return_indices=False
1488        )
1489        x = torch.randn(1, 3, 32)
1490        self.run_test(model, x)
1491
1492    @skipIfUnsupportedMaxOpsetVersion(9)
1493    def test_maxpool_2d_ceil_corner(self):
1494        model = torch.nn.MaxPool2d(
1495            kernel_size=[1, 1],
1496            dilation=[1, 1],
1497            stride=[2, 2],
1498            ceil_mode=True,
1499            return_indices=False,
1500        )
1501        x = torch.randn(1, 3, 32, 32)
1502        self.run_test(model, x)
1503
1504    @skipIfUnsupportedMaxOpsetVersion(9)
1505    def test_maxpool_3d_ceil_corner(self):
1506        model = torch.nn.MaxPool3d(
1507            kernel_size=[7, 8, 4],
1508            dilation=[1, 1, 1],
1509            stride=[10, 11, 3],
1510            padding=[2, 2, 2],
1511            ceil_mode=True,
1512            return_indices=False,
1513        )
1514        x = torch.randn(1, 3, 51, 52, 45)
1515        self.run_test(model, x)
1516
1517    @skipIfUnsupportedMaxOpsetVersion(9)
1518    @skipIfUnsupportedMinOpsetVersion(8)
1519    def test_maxpool_1d_ceil_corner_with_indices(self):
1520        model = torch.nn.MaxPool1d(
1521            kernel_size=1, dilation=1, stride=2, ceil_mode=True, return_indices=True
1522        )
1523        x = torch.randn(1, 3, 32)
1524        self.run_test(model, x)
1525
1526    @skipIfUnsupportedMaxOpsetVersion(9)
1527    @skipIfUnsupportedMinOpsetVersion(8)
1528    def test_maxpool_2d_ceil_corner_with_indices(self):
1529        model = torch.nn.MaxPool2d(
1530            kernel_size=[1, 1],
1531            dilation=[1, 1],
1532            stride=[2, 2],
1533            ceil_mode=True,
1534            return_indices=True,
1535        )
1536        x = torch.randn(1, 3, 32, 32)
1537        self.run_test(model, x)
1538
1539    @skipIfUnsupportedMaxOpsetVersion(9)
1540    @skipIfUnsupportedMinOpsetVersion(8)
1541    def test_maxpool_3d_ceil_corner_with_indices(self):
1542        model = torch.nn.MaxPool3d(
1543            kernel_size=[7, 8, 4],
1544            dilation=[1, 1, 1],
1545            stride=[10, 11, 3],
1546            padding=[2, 2, 2],
1547            ceil_mode=True,
1548            return_indices=True,
1549        )
1550        x = torch.randn(1, 3, 51, 52, 45)
1551        self.run_test(model, x)
1552
1553    @skipIfUnsupportedMinOpsetVersion(8)
1554    def test_maxpool_with_indices(self):
1555        model = torch.nn.MaxPool1d(2, stride=1, return_indices=True)
1556        x = torch.randn(20, 16, 50)
1557        self.run_test(model, x)
1558
1559    @skipIfUnsupportedMinOpsetVersion(10)
1560    def test_maxpool_dilation(self):
1561        model = torch.nn.MaxPool1d(2, stride=1, dilation=2)
1562        x = torch.randn(20, 16, 50)
1563        self.run_test(model, x)
1564
1565    def test_avgpool_default_stride(self):
1566        class AvgPoolModel(torch.nn.Module):
1567            def forward(self, x):
1568                return torch.nn.functional.avg_pool2d(x, 2)
1569
1570        model = AvgPoolModel()
1571        x = torch.randn(10, 20, 16, 50)
1572        self.run_test(model, x)
1573
1574    def test_avgpool(self):
1575        model = torch.nn.AvgPool1d(2, stride=1)
1576        x = torch.randn(20, 16, 50)
1577        self.run_test(model, x)
1578
1579    def test_avgpool_1d_ceil(self):
1580        model = torch.nn.AvgPool1d(3, 2, ceil_mode=True)
1581        x = torch.randn(1, 1, 7)
1582        self.run_test(model, x)
1583
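    # For reference, with ceil_mode=True and no padding the eager output length
    # is ceil((L_in - kernel_size) / stride) + 1; e.g. AvgPool1d(3, 2,
    # ceil_mode=True) on length 8 gives ceil(5 / 2) + 1 = 4 values, where floor
    # mode would give 3. A minimal eager-only sanity sketch (deliberately not
    # exported, given the ORT ceil_mode discrepancy noted below; the method
    # name is illustrative):
    def test_avgpool_1d_ceil_output_size_eager(self):
        model = torch.nn.AvgPool1d(3, 2, ceil_mode=True)
        x = torch.randn(1, 1, 8)
        self.assertEqual(model(x).shape[-1], 4)
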
1584    # TODO: ceil_mode is not included in the test because of
1585    # https://github.com/microsoft/onnxruntime/issues/16203:
1586    # ORT and PyTorch compute the last output value differently in ceil_mode.
1587    @common_utils.parametrize(
1588        "padding",
1589        (0, 1),
1590    )
1591    @common_utils.parametrize(
1592        "count_include_pad",
1593        (True, False),
1594    )
1595    def test_avgpool_2d(self, padding, count_include_pad):
1596        model = torch.nn.AvgPool2d(
1597            3,
1598            3,
1599            padding=padding,
1600            count_include_pad=count_include_pad,
1601        )
1602        x = torch.randn(20, 16, 50, 32)
1603        self.run_test(model, x)
1604
1605    # TODO: ceil_mode is not included in the test because of
1606    # https://github.com/microsoft/onnxruntime/issues/16203:
1607    # ORT and PyTorch compute the last output value differently in ceil_mode.
1608    # The issue requires a fix in ONNX opset 21 (https://github.com/onnx/onnx/issues/5711),
1609    # and a fix in ORT is planned. After both fixes are in place, ceil_mode can be added to the test.
1610    @skipIfUnsupportedMinOpsetVersion(21)
1611    def test_avgpool_3d_ceil(self):
1612        model = torch.nn.AvgPool3d(3, 2, ceil_mode=True)
1613        x = torch.randn(20, 16, 50, 44, 31)
1614        y = torch.randn(32, 8, 50, 44, 31)
1615        self.run_test(
1616            model,
1617            x,
1618            input_names=["x"],
1619            dynamic_axes={"x": [0, 1]},
1620            additional_test_inputs=[y],
1621        )
1622
1623    @skipIfUnsupportedMinOpsetVersion(10)
1624    def test_avgpool_dynamic(self):
1625        class Net(torch.nn.Module):
1626            def __init__(self, in_channels, out_channels):
1627                super().__init__()
1628                norm_layer = functools.partial(torch.nn.BatchNorm2d, eps=0.0009)
1629                self.avgpool = torch.nn.AvgPool2d(
1630                    (2, 2), stride=2, ceil_mode=True, count_include_pad=False
1631                )
1632                self.conv = torch.nn.Conv2d(
1633                    in_channels, out_channels, kernel_size=1, stride=1, bias=False
1634                )
1635                self.norm = norm_layer(out_channels)
1636
1637            def forward(self, x):
1638                return self.norm(self.conv(self.avgpool(x)))
1639
1640        model = Net(8, 16)
1641        inputs = torch.randn(2, 8, 64, 64)
1642        self.run_test(
1643            model,
1644            inputs,
1645            input_names=["input_0"],
1646            dynamic_axes={"input_0": {3: "x", 2: "y"}, "output_0": {3: "x", 2: "y"}},
1647            output_names=["output_0"],
1648        )
1649
1650    @skipIfUnsupportedMinOpsetVersion(9)
1651    def test_floating_point(self):
1652        class FloatingPoint(torch.jit.ScriptModule):
1653            @torch.jit.script_method
1654            def forward(self, x):
1655                if x.is_floating_point():
1656                    return x.new_zeros(x.shape)
1657                return x.new_zeros(x.shape)
1658
1659        x = torch.randn(2, 3, 4)
1660        self.run_test(
1661            FloatingPoint(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
1662        )
1663        self.run_test(FloatingPoint(), x, remained_onnx_input_idx=[])
1664
1665        class FloatingPoint(torch.jit.ScriptModule):
1666            @torch.jit.script_method
1667            def forward(self, x):
1668                if x.size(0) > 1:
1669                    a = x + 2
1670                    if a.is_floating_point():
1671                        return x + 1
1672                    return x + 1
1673                return x
1674
1675        x = torch.randn(2, 3, 4)
1676        self.run_test(FloatingPoint(), x)
1677
1678    # Operator rank mismatch between outputs of two branches for opsets below 11.
1679    @skipIfUnsupportedMinOpsetVersion(11)
1680    def test_floating_point_infer_dtype(self):
1681        class FloatingPoint(torch.jit.ScriptModule):
1682            @torch.jit.script_method
1683            def forward(self, x):
1684                if x.size(0) > 1:
1685                    a = x + 2
1686                    if a.is_floating_point():
1687                        return x.new_zeros(x.shape[1:])
1688                    return x.new_zeros(x.shape)
1689                return x
1690
1691        x = torch.randn(2, 3, 4)
1692        self.run_test(
1693            FloatingPoint(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
1694        )
1695        self.run_test(FloatingPoint(), x, remained_onnx_input_idx=[])
1696
1697        class FloatingPoint(torch.jit.ScriptModule):
1698            @torch.jit.script_method
1699            def forward(self, x):
1700                if x.size(0) > 1:
1701                    a = x + 2
1702                    if a.is_floating_point():
1703                        return x + 1
1704                    return x
1705                return x
1706
1707        x = torch.randn(2, 3, 4).to(torch.int32)
1708        self.run_test(FloatingPoint(), x)
1709
1710    @skipIfUnsupportedMinOpsetVersion(12)
1711    def test_prim_min(self):
1712        @torch.jit.script
1713        def list_append(boxes: List[Tensor]):
1714            temp = []
1715            for i, b in enumerate(
1716                boxes
1717            ):  # enumerate creates a prim::min op in the torch graph
1718                temp.append(torch.full_like(b[:, 1], i))
1719            return temp[0]
1720
1721        class Min(torch.nn.Module):
1722            def forward(self, x):
1723                boxes = [x for _ in range(3)]
1724                return list_append(boxes)
1725
1726        x = torch.rand(5, 5)
1727        self.run_test(Min(), (x,))
1728
1729        class M(torch.jit.ScriptModule):
1730            @torch.jit.script_method
1731            def forward(self, x):
1732                i = 3
1733                return min(x[i], i)
1734
1735        x = torch.arange(6, dtype=torch.int64)
1736        self.run_test(M(), (x,))
1737
1738    def test_arithmetic(self):
1739        class ArithmeticModule(torch.nn.Module):
1740            def forward(self, x):
1741                x = x + 2
1742                x = x - 4
1743                x = x * 6
1744                x = x / 8
1745                return x
1746
1747        x = torch.randn(2, 3, 4)
1748        self.run_test(ArithmeticModule(), x)
1749
1750    def test_arithmetic_prim_long(self):
1751        class ArithmeticModule(torch.nn.Module):
1752            def forward(self, x, y: int):
1753                x = x + y
1754                x = x - y
1755                x = x * (y * 3)
1756                x = x / (y * 4)
1757                return x
1758
1759        x = torch.randn(2, 3, 4)
1760        y = 2
1761        self.run_test(ArithmeticModule(), (x, y))
1762
1763        class ArithmeticModule(torch.nn.Module):
1764            def forward(self, x):
1765                x = x + 2
1766                x = x - 3
1767                return x.shape[0]
1768
1769        x = torch.randn(2, 3, 4)
1770        self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[])
1771
1772    @skipDtypeChecking
1773    def test_arithmetic_prim_float(self):
1774        class ArithmeticModule(torch.nn.Module):
1775            def forward(self, x, y: float):
1776                x = x + y
1777                x = x - y
1778                x = x * (y * 3)
1779                x = x / (y * 4)
1780                return x
1781
1782        x = torch.randn(2, 3, 4)
1783        y = 2.5
1784        self.run_test(ArithmeticModule(), (x, y))
1785
1786        class ArithmeticModule(torch.nn.Module):
1787            def forward(self, x):
1788                x = x + 2
1789                x = x - 3
1790                return x.shape[1] / 2
1791
1792        x = torch.randn(2, 3, 4)
1793        self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[])
1794
1795    @skipDtypeChecking
1796    def test_arithmetic_prim_bool(self):
1797        class ArithmeticModule(torch.nn.Module):
1798            def forward(self, x, y: int, z: bool, t: float):
1799                x = x + y
1800                x = x - y
1801                if z:
1802                    x = x * (y * 3)
1803                    x = x / (y * 4)
1804                return x / t, z
1805
1806        x = torch.randn(2, 3, 4)
1807        y = 2
1808        z = False
1809        t = 2.5
1810        self.run_test(ArithmeticModule(), (x, y, z, t))
1811
1812        class ArithmeticModule(torch.nn.Module):
1813            def forward(self, x: int, y: int):
1814                return x == y
1815
1816        x = 3
1817        y = 2
1818        self.run_test(ArithmeticModule(), (x, y))
1819
1820    @skipScriptTest(
1821        15,
1822        reason="In trace: Outputs that are always None are removed. \
1823                In script: Outputs that are always None are removed before opset 15. \
1824                After opset 15, we replace the None in output with Optional node.",
1825    )
1826    def test_tuple_with_none_outputs(self):
1827        class TupleModel(torch.nn.Module):
1828            def forward(self, x):
1829                return (x, (x, None, (x, None)))
1830
1831        x = torch.randn(3, 4)
1832        self.run_test(TupleModel(), (x,))
1833
1834    # In scripting, the first transpose node does not carry shape and dtype info.
1835    # The following test only works when ONNX shape inference is enabled.
1836    def test_arithmetic_infer_dtype(self):
1837        class ArithmeticModule(torch.jit.ScriptModule):
1838            @torch.jit.script_method
1839            def forward(self, x):
1840                x = x.t()
1841                x = x + 2
1842                x = x - 4
1843                x = x * 6
1844                x = x / 8
1845                return x
1846
1847        x = torch.randn(2, 3)
1848        self.run_test(ArithmeticModule(), x)
1849
1850    @unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
1851    def test_floor_div(self):
1852        class FloorDivModule(torch.nn.Module):
1853            def forward(self, x, y):
1854                return (
1855                    x // 3,
1856                    x // 2.0,
1857                    x.to(dtype=torch.float64) // 3,
1858                    x.to(dtype=torch.float64) // 2.0,
1859                    x.to(dtype=torch.int64) // 3,
1860                    x.to(dtype=torch.int64) // 2.0,
1861                    x // (y + 1.0).to(dtype=torch.int64),
1862                    x // y,
1863                    x.to(dtype=torch.float64) // y.to(dtype=torch.int64),
1864                    x.to(dtype=torch.float64) // y.to(dtype=torch.float64),
1865                    x.to(dtype=torch.int64) // y.to(dtype=torch.int64),
1866                    x.to(dtype=torch.int64) // y,
1867                )
1868
1869        x = torch.arange(-2, 4).reshape(2, 3, 1)
1870        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
1871        self.run_test(FloorDivModule(), (x, y))
1872
1873    @unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
1874    def test_floor_div_script(self):
1875        class FloorDivModule(torch.jit.ScriptModule):
1876            @torch.jit.script_method
1877            def forward(self, x, y):
1878                return x // 3, x // 2.0, x // y
1879
1880        x = torch.arange(-2, 4).reshape(2, 3, 1)
1881        y = torch.randn(2, 3, 4)
1882        self.run_test(FloorDivModule(), (x, y))
1883
1884    @unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
1885    @skipIfUnsupportedMinOpsetVersion(9)
1886    def test_floordiv(self):
1887        class FloordivModule(torch.nn.Module):
1888            def forward(self, x):
1889                return x.new_zeros(x.size(2) // x.size(1))
1890
1891        x = torch.randn(2, 3, 4)
1892        self.run_test(
1893            FloordivModule(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
1894        )
1895        self.run_test(FloordivModule(), (x,), remained_onnx_input_idx=[])
1896
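    # Eager reference for the skipped floor-division tests above: torch floor
    # division rounds toward -inf (matching Python), so -2 // 3 is -1 rather
    # than the 0 that truncation toward zero would give. A minimal eager-only
    # sketch (the method name is illustrative):
    def test_floor_div_eager_reference(self):
        self.assertEqual(torch.tensor(-2) // torch.tensor(3), torch.tensor(-1))
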
1897    def test_div(self):
1898        class DivModule(torch.nn.Module):
1899            def forward(self, x, y):
1900                return x / y, torch.true_divide(x, y)
1901
1902        x = torch.randn(2, 3, 4).to(torch.int)
1903        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
1904        self.run_test(DivModule(), (x, y))
1905        self.run_test(DivModule(), (x.float(), y.float()))
1906
1907    # Note: div cannot (generally) be exported via scripting
1908    # since its type promotion logic depends on knowing the scalar types
1909    # of the input tensors. That is, the ONNX graph depends on the
1910    # data type of the inputs. This makes it appropriate for tracing only.
1911    def test_div_promotion_trace(self):
1912        class DivModule(torch.nn.Module):
1913            def forward(self, x, y):
1914                return x / y, torch.true_divide(x, y)
1915
1916        x = torch.randn(2, 3, 4).to(torch.int)
1917        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
1918
1919        with common_utils.set_default_dtype(torch.float):
1920            self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
1921
1922        with common_utils.set_default_dtype(torch.double):
1923            self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
1924
1925    # In scripting, x and y do not carry shape and dtype info.
1926    # The following test only works when ONNX shape inference is enabled.
1927    def test_div_promotion_script(self):
1928        class DivModule(torch.nn.Module):
1929            def forward(self, x, y):
1930                # Add transpose to hide shape/type information.
1931                # Otherwise shape and type are still available from the input.
1932                x = x.transpose(1, 2)
1933                y = y.transpose(1, 2)
1934                return x / y, torch.true_divide(x, y)
1935
1936        x = torch.randn(2, 3, 4).to(torch.int)
1937        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
1938
1939        # 1. x,y are int, and output is float.
1940        #    This can be handled by the default case, where both are cast to float.
1941        #    It works even if type of x, y are unknown.
1942        with common_utils.set_default_dtype(torch.float):
1943            self.run_test(torch.jit.script(DivModule()), (x, y))
1944
1945        # 2. x,y are int, and output is double.
1946        #    This can be handled by the default case, where both are cast to double.
1947        #    It works even if type of x, y are unknown.
1948        with common_utils.set_default_dtype(torch.double):
1949            self.run_test(torch.jit.script(DivModule()), (x, y))
1950
1951        # 3. x is int, y is double, and output is double.
1952        #    This can only be handled when both type of x and y are known.
1953        x = torch.randn(2, 3, 4).to(torch.int)
1954        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
1955        self.run_test(torch.jit.script(DivModule()), (x, y))
1956
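    # Eager reference for the promotion rules exercised above, relying only on
    # documented torch semantics: true division of integer tensors produces the
    # current default floating-point dtype, and mixed int/double inputs promote
    # to double. A minimal sketch (the method name is illustrative):
    def test_div_promotion_eager_reference(self):
        x = torch.ones(2, dtype=torch.int32)
        y = torch.full((2,), 2, dtype=torch.int32)
        with common_utils.set_default_dtype(torch.double):
            self.assertEqual((x / y).dtype, torch.double)
        self.assertEqual((x / y.to(torch.double)).dtype, torch.double)
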
1957    @skipDtypeChecking
1958    def test_div_rounding_mode(self):
1959        class TrueDivModule(torch.nn.Module):
1960            def forward(self, x, y):
1961                return (
1962                    x.div(y, rounding_mode=None),
1963                    torch.div(x, y, rounding_mode=None),
1964                )
1965
1966        class TruncDivModule(torch.nn.Module):
1967            def forward(self, x, y):
1968                return (
1969                    x.div(y, rounding_mode="trunc"),
1970                    torch.div(x, y, rounding_mode="trunc"),
1971                )
1972
1973        class FloorDivModule(torch.nn.Module):
1974            def forward(self, x, y):
1975                return (
1976                    x.div(y, rounding_mode="floor"),
1977                    torch.div(x, y, rounding_mode="floor"),
1978                )
1979
1980        modules = [TrueDivModule(), TruncDivModule(), FloorDivModule()]
1981
1982        x = (torch.randn(2, 3, 4) * 100).to(torch.int)
1983        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
1984
1985        for module in modules:
1986            self.run_test(module, (x, y))
1987            self.run_test(torch.jit.trace(module, (x, y)), (x, y))
1988            self.run_test(torch.jit.script(module), (x, y))
1989
1990        x = torch.randn(2, 3, 4)
1991        y = torch.rand(2, 3, 4) * 10.0 + 0.1
1992
1993        for module in modules:
1994            self.run_test(module, (x, y))
1995            self.run_test(torch.jit.trace(module, (x, y)), (x, y))
1996            self.run_test(torch.jit.script(module), (x, y))
1997
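    # The rounding modes only disagree on negative quotients: for -7 / 2 = -3.5,
    # "trunc" rounds toward zero (-3) while "floor" rounds toward -inf (-4).
    # A minimal eager-only sketch (the method name is illustrative):
    def test_div_rounding_mode_negative_reference(self):
        x = torch.tensor([-7.0])
        y = torch.tensor([2.0])
        self.assertEqual(x.div(y, rounding_mode="trunc"), torch.tensor([-3.0]))
        self.assertEqual(x.div(y, rounding_mode="floor"), torch.tensor([-4.0]))
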
1998    def test_slice_trace(self):
1999        class MyModule(torch.nn.Module):
2000            def forward(self, x):
2001                return x[0:1]
2002
2003        x = torch.randn(3)
2004        self.run_test(MyModule(), x)
2005
2006    def test_slice_neg(self):
2007        class NegSlice(torch.nn.Module):
2008            def forward(self, x):
2009                return x[-1:]
2010
2011        x = torch.randn(3, 4, 5)
2012        self.run_test(NegSlice(), x)
2013
2014    def test_slice_neg_large(self):
2015        class NegSlice(torch.nn.Module):
2016            def forward(self, x):
2017                return x[:, :, -3:-1, :, -1]
2018
2019        x = torch.randn(3, 4, 5, 6, 7)
2020        self.run_test(NegSlice(), x)
2021
2022    def test_slice_neg_large_negone(self):
2023        class NegSlice(torch.nn.Module):
2024            def forward(self, x):
2025                return x[:, :, :, :, -1]
2026
2027        x = torch.randn(3, 4, 5, 6, 7)
2028        self.run_test(NegSlice(), x)
2029
2030    @skipIfUnsupportedMinOpsetVersion(11)
2031    def test_slice_with_input_index(self):
2032        class InputIndexSlice(torch.nn.Module):
2033            def forward(self, x, y):
2034                x[: y.size(0), 0, :] = y
2035                return x
2036
2037        x = torch.zeros((56, 6, 256))
2038        y = torch.rand((22, 256))
2039        self.run_test(InputIndexSlice(), (x, y))
2040
2041    @skipIfUnsupportedMinOpsetVersion(11)
2042    @skipScriptTest()  # TorchScript doesn't support 1-d indices.
2043    def test_slice_with_1d_input_index(self):
2044        class InputIndexSlice(torch.nn.Module):
2045            def forward(self, x, y):
2046                x[:y, 0, :] = y
2047                return x
2048
2049        x = torch.zeros((56, 6, 256))
2050        y = torch.tensor([5], dtype=torch.int64)
2051        self.run_test(InputIndexSlice(), (x, y))
2052
2053    @skipIfUnsupportedMinOpsetVersion(11)
2054    def test_slice_with_input_step_size(self):
2055        class InputIndexSlice(torch.nn.Module):
2056            def forward(self, x, y, z):
2057                x[:y:z, 0::z, :] = 1
2058                return x
2059
2060        x = torch.zeros((56, 6, 256))
2061        y = torch.tensor(5, dtype=torch.int64)
2062        z = torch.tensor(2, dtype=torch.int64)
2063        self.run_test(InputIndexSlice(), (x, y, z))
2064
2065    @skipIfUnsupportedMinOpsetVersion(10)
2066    @skipScriptTest()  # scripting tuple/list append
2067    def test_slice_dynamic(self):
2068        class DynamicSliceExportMod(torch.nn.Module):
2069            def forward(self, x):
2070                results = []
2071                for i in range(4):
2072                    results.append(x[: x.size(0) - i, i : x.size(2), i:3])
2073                return tuple(results)
2074
2075        x = torch.rand(5, 5, 5)
2076        y = torch.randn(6, 7, 8)
2077        self.run_test(
2078            DynamicSliceExportMod(),
2079            x,
2080            additional_test_inputs=[y],
2081            input_names=["input_1"],
2082            output_names=["output_1"],
2083            dynamic_axes={"input_1": [0, 1, 2], "output_1": [0, 1, 2]},
2084        )
2085
2086    @skipIfUnsupportedMinOpsetVersion(10)
2087    def test_slice_dynamic_script(self):
2088        class DynamicSliceModel(torch.jit.ScriptModule):
2089            @torch.jit.script_method
2090            def forward(self, x):
2091                return x[1 : x.size(1)]
2092
2093        x = torch.rand(1, 2)
2094        self.run_test(DynamicSliceModel(), x)
2095
2096    @skipIfUnsupportedMinOpsetVersion(10)
2097    def test_slice_dynamic_shape_script(self):
2098        class DynamicSliceModel(torch.nn.Module):
2099            def forward(self, x):
2100                return x.new_zeros(x.shape[1 : x.size(2)])
2101
2102        x = torch.rand(1, 2, 3, 4)
2103        self.run_test(
2104            DynamicSliceModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]}
2105        )
2106        self.run_test(DynamicSliceModel(), x, remained_onnx_input_idx=[])
2107
2108    @skipIfUnsupportedMinOpsetVersion(10)
2109    @skipScriptTest()  # scripting tuple/list append
2110    def test_slice_dynamic_to_end(self):
2111        class DynamicSliceExportMod(torch.nn.Module):
2112            def forward(self, x):
2113                results = []
2114                for i in range(4):
2115                    results.append(x[:, i:, x.size(2) - 5])
2116                return tuple(results)
2117
2118        x = torch.rand(5, 5, 5)
2119        self.run_test(
2120            DynamicSliceExportMod(),
2121            x,
2122            dynamic_axes={"input_1": [0, 1, 2], "output_1": [0, 1, 2]},
2123        )
2124
2125    def test_square(self):
2126        class Square(torch.nn.Module):
2127            def forward(self, x):
2128                return torch.square(x)
2129
2130        x = torch.randn(2, 3, 4)
2131        self.run_test(Square(), x)
2132
2133    @skipIfUnsupportedMinOpsetVersion(9)
2134    def test_arange_dynamic(self):
2135        class ArangeModel(torch.nn.Module):
2136            def forward(self, input):
2137                return (
2138                    torch.arange(input.shape[0]),
2139                    torch.arange(12),
2140                    torch.arange(start=input.shape[0], end=input.shape[0] + 5),
2141                )
2142
2143        x = torch.randn(5, 3, 2)
2144        y = torch.randn(8, 3, 2)
2145        self.run_test(
2146            ArangeModel(),
2147            x,
2148            additional_test_inputs=[y],
2149            input_names=["input_1"],
2150            output_names=["output_1", "output_2", "output_3"],
2151            dynamic_axes={"input_1": [0], "output_1": [0]},
2152        )
2153        self.run_test(
2154            torch.jit.script(ArangeModel()),
2155            x,
2156            additional_test_inputs=[y],
2157            input_names=["input_1"],
2158            output_names=["output_1", "output_2", "output_3"],
2159            dynamic_axes={"input_1": [0], "output_1": [0]},
2160        )
2161
2162    @skipIfUnsupportedMinOpsetVersion(9)
2163    def test_dynamic_arange_out(self):
2164        class ArangeOutModel(torch.nn.Module):
2165            def forward(self, end):
2166                out_t = torch.tensor([1], dtype=torch.int64)
2167                return torch.arange(end, out=out_t)
2168
2169        x = torch.tensor(8)
2170        self.run_test(ArangeOutModel(), (x))
2171
2172    @skipIfUnsupportedMinOpsetVersion(9)
2173    def test_dynamic_arange_start_out(self):
2174        class ArangeStartOutModel(torch.nn.Module):
2175            def forward(self, start, end):
2176                out_t = torch.tensor([1], dtype=torch.int64)
2177                return torch.arange(start.size(0), end, out=out_t)
2178
2179        x = torch.randn(2, 3, 4)
2180        y = torch.tensor(8)
2181        self.run_test(
2182            ArangeStartOutModel(),
2183            (x, y),
2184            input_names=["x", "y"],
2185            dynamic_axes={"x": [0, 1, 2]},
2186        )
2187        self.run_test(ArangeStartOutModel(), (x, y), remained_onnx_input_idx=[1])
2188
2189    @skipIfUnsupportedMinOpsetVersion(9)
2190    def test_linspace(self):
2191        class LinspaceModel(torch.nn.Module):
2192            def forward(self, start, end, steps):
2193                return torch.linspace(start, end, steps)
2194
2195        x = torch.tensor(3, dtype=torch.float)
2196        y = torch.tensor(10, dtype=torch.float)
2197        z = torch.tensor(5, dtype=torch.int)
2198        self.run_test(LinspaceModel(), (x, y, z))
2199
2200    @skipIfUnsupportedMinOpsetVersion(9)
2201    def test_linspace_negative_start(self):
2202        class LinspaceModel(torch.nn.Module):
2203            def forward(self, start, end, steps):
2204                return torch.linspace(start, end, steps)
2205
2206        x = torch.tensor(-1, dtype=torch.float)
2207        y = torch.tensor(1, dtype=torch.float)
2208        z = torch.tensor(6, dtype=torch.int)
2209        self.run_test(LinspaceModel(), (x, y, z))
2210
2211    @skipIfUnsupportedMinOpsetVersion(9)
2212    def test_arange_with_floats_out(self):
2213        class ArangeModelEnd(torch.nn.Module):
2214            def forward(self, end):
2215                out_t = torch.tensor([1], dtype=torch.float)
2216                return torch.arange(end, out=out_t)
2217
2218        y = torch.tensor(8.5, dtype=torch.float)
2219        self.run_test(ArangeModelEnd(), (y))
2220
2221        class ArangeModelStep(torch.nn.Module):
2222            def forward(self, start, end):
2223                out_t = torch.tensor([1], dtype=torch.float)
2224                return torch.arange(start.size(0), end, 1.5, out=out_t)
2225
2226        x = torch.randn(2, 3, 4)
2227        y = torch.tensor(8.5, dtype=torch.float)
2228        self.run_test(
2229            ArangeModelStep(),
2230            (x, y),
2231            input_names=["x", "y"],
2232            dynamic_axes={"x": [0, 1, 2]},
2233        )
2234        self.run_test(ArangeModelStep(), (x, y), remained_onnx_input_idx=[1])
2235
2236    @skipIfUnsupportedMinOpsetVersion(9)
2237    def test_arange_with_floats(self):
2238        class ArangeModelEnd(torch.nn.Module):
2239            def forward(self, end):
2240                return torch.arange(end)
2241
2242        y = torch.tensor(8.5, dtype=torch.float)
2243        self.run_test(ArangeModelEnd(), (y))
2244
2245        class ArangeModelStep(torch.nn.Module):
2246            def forward(self, start, end):
2247                return torch.arange(start.size(0), end, 1.5)
2248
2249        x = torch.randn(2, 3, 4)
2250        y = torch.tensor(8.5, dtype=torch.float)
2251        self.run_test(
2252            ArangeModelStep(),
2253            (x, y),
2254            input_names=["x", "y"],
2255            dynamic_axes={"x": [0, 1, 2]},
2256        )
2257        self.run_test(ArangeModelStep(), (x, y), remained_onnx_input_idx=[1])
2258
2259        class ArangeModelStepNeg(torch.nn.Module):
2260            def forward(self, start, end):
2261                return torch.arange(end, start.size(0), -1.5)
2262
2263        x = torch.randn(2, 3, 4)
2264        y = torch.tensor(8.5, dtype=torch.float)
2265        self.run_test(
2266            ArangeModelStepNeg(),
2267            (x, y),
2268            input_names=["x", "y"],
2269            dynamic_axes={"x": [0, 1, 2]},
2270        )
2271        self.run_test(ArangeModelStepNeg(), (x, y), remained_onnx_input_idx=[1])
2272
2273        class ArangeModelStart(torch.nn.Module):
2274            def forward(self, start, end):
2275                return torch.arange(start.size(0), end)
2276
2277        x = torch.randn(2, 3, 4)
2278        y = torch.tensor(8.5, dtype=torch.float)
2279        self.run_test(
2280            ArangeModelStart(),
2281            (x, y),
2282            input_names=["x", "y"],
2283            dynamic_axes={"x": [0, 1, 2]},
2284        )
2285        self.run_test(ArangeModelStart(), (x, y), remained_onnx_input_idx=[1])
2286
2287    @skipIfUnsupportedMinOpsetVersion(9)
2288    def test_arange_with_floats_override(self):
2289        class ArangeModelEnd(torch.nn.Module):
2290            def forward(self, end):
2291                return torch.arange(end, dtype=torch.int64)
2292
2293        y = torch.tensor(8.5, dtype=torch.float)
2294        self.run_test(ArangeModelEnd(), (y))
2295
2296        class ArangeModelStep(torch.nn.Module):
2297            def forward(self, start, end):
2298                return torch.arange(start.size(0), end, 1.5, dtype=torch.int64)
2299
2300        x = torch.randn(2, 3, 4)
2301        y = torch.tensor(8.5, dtype=torch.float)
2302        self.run_test(
2303            ArangeModelStep(),
2304            (x, y),
2305            input_names=["x", "y"],
2306            dynamic_axes={"x": [0, 1, 2]},
2307        )
2308        self.run_test(ArangeModelStep(), (x, y), remained_onnx_input_idx=[1])
2309
2310    @skipIfUnsupportedMinOpsetVersion(11)
2311    def test_arange_out(self):
2312        class ArangeOutModel(torch.nn.Module):
2313            def forward(self, end):
2314                out_t = torch.tensor([1], dtype=torch.float)
2315                return torch.arange(end, out=out_t)
2316
2317        x = torch.tensor(8.5, dtype=torch.float)
2318        self.run_test(ArangeOutModel(), (x))
2319
2320    @skipIfUnsupportedMinOpsetVersion(11)
2321    def test_arange_start_out(self):
2322        class ArangeStartOutModel(torch.nn.Module):
2323            def forward(self, start, end):
2324                out_t = torch.tensor([1], dtype=torch.float)
2325                return torch.arange(start.size(0), end, out=out_t)
2326
2327        x = torch.randn(2, 3, 4)
2328        y = torch.tensor(8.5, dtype=torch.float)
2329        self.run_test(
2330            ArangeStartOutModel(),
2331            (x, y),
2332            input_names=["x", "y"],
2333            dynamic_axes={"x": [0, 1, 2]},
2334        )
2335        self.run_test(ArangeStartOutModel(), (x, y), remained_onnx_input_idx=[1])
2336
2337    @skipIfUnsupportedMinOpsetVersion(11)
2338    def test_arange_no_type(self):
2339        class ArangeModel(torch.nn.Module):
2340            def forward(self, end):
2341                return torch.arange(end), torch.arange(0, end)
2342
2343        x = torch.tensor(6.2, dtype=torch.float)
2344        self.run_test(ArangeModel(), x)
2345
2346    @skipIfUnsupportedMinOpsetVersion(9)
2347    def test_size(self):
2348        class SizeModel(torch.nn.Module):
2349            def forward(self, input):
2350                return (
2351                    torch.arange(input.size(0)),
2352                    torch.arange(input.size(-1)),
2353                    torch.ones(input.shape),
2354                )
2355
2356        x = torch.randn(5, 3, 2)
2357        self.run_test(SizeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
2358        self.run_test(SizeModel(), x, remained_onnx_input_idx=[])
2359
2360    @skipIfUnsupportedMinOpsetVersion(9)
2361    @skipScriptTest()  # x.stride() not scriptable
2362    def test_as_strided(self):
2363        class Model(torch.nn.Module):
2364            def forward(self, x):
2365                chunk_size = list(x.size())
2366                chunk_size[1] = chunk_size[1] * 2 - 1
2367                chunk_stride = list(x.stride())
2368                chunk_stride[1] = chunk_stride[1] // 2
2369                return x.as_strided(
2370                    (3, 3, 3), (1, 4, 2), storage_offset=2
2371                ), x.as_strided(chunk_size, chunk_stride)
2372
2373        x = torch.randn(5, 8, 7)
2374        self.run_test(Model(), x)
2375
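    # as_strided reference: element [i, j] of x.as_strided(size, stride,
    # storage_offset) reads flat storage index storage_offset + i * stride[0]
    # + j * stride[1]. A minimal eager-only sketch on a contiguous 1-D tensor
    # (the method name is illustrative):
    def test_as_strided_eager_reference(self):
        x = torch.arange(12.0)
        strided = x.as_strided((2, 3), (3, 1), storage_offset=1)
        self.assertEqual(strided, x[1:7].reshape(2, 3))
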
2376    @skipScriptTest()  # Ellipses followed by tensor indexing not scriptable
2377    def test_tensor_index_advanced_indexing_ellipsis(self):
2378        class MyModel(torch.nn.Module):
2379            def forward(self, input):
2380                return input[..., torch.tensor([2, 1]), torch.tensor([0, 3])]
2381
2382        m1 = torch.randn(3, 4, 5, 6, 7)
2383        self.run_test(MyModel(), (m1,))
2384
2385    def test_tensor_index_advanced_indexing(self):
2386        class MyModel(torch.nn.Module):
2387            def forward(self, input):
2388                return input[
2389                    :,
2390                    torch.tensor([[0, 2], [1, 1]]),
2391                    :,
2392                    torch.tensor([2, 1]),
2393                    torch.tensor([0, 3]),
2394                ]
2395
2396        m1 = torch.randn(3, 4, 5, 6, 7)
2397        self.run_test(MyModel(), (m1,))
2398
2399        class MyModel(torch.nn.Module):
2400            def forward(self, input):
2401                return input[
2402                    :, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]])
2403                ]
2404
2405        self.run_test(MyModel(), (m1,))
2406
2407        class MyModel(torch.nn.Module):
2408            def forward(self, input):
2409                return input[
2410                    :,
2411                    torch.tensor([0, 2]),
2412                    torch.tensor([1]),
2413                    2:4,
2414                    torch.tensor([[1], [4]]),
2415                ]
2416
2417        self.run_test(MyModel(), (m1,))
2418
2419    def test_tensor_index_advanced_indexing_consecutive(self):
2420        class MyModel(torch.nn.Module):
2421            def forward(self, input):
2422                return input[
2423                    :, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None
2424                ]
2425
2426        m1 = torch.randn(3, 4, 5, 6, 7)
2427        self.run_test(MyModel(), (m1,))
2428
2429    @skipIfUnsupportedMinOpsetVersion(11)
2430    def test_index_put(self):
2431        class IndexPutModel(torch.nn.Module):
2432            def forward(self, x, ind, update):
2433                x[ind] = update
2434                return x
2435
2436        x = torch.randn(3, 4)
2437        ind = torch.tensor([1], dtype=torch.long)
2438        update = torch.ones(4)
2439        self.run_test(IndexPutModel(), (x, ind, update))
2440
2441    @skipIfUnsupportedMinOpsetVersion(11)
2442    def test_index_put_singular(self):
2443        class IndexPutBoolModel(torch.nn.Module):
2444            def forward(self, mask, indices):
2445                mask[indices] = True
2446                return mask
2447
2448        mask = torch.zeros(100, dtype=torch.bool)
2449        indices = (torch.rand(25) * mask.shape[0]).to(torch.int64)
2450        self.run_test(IndexPutBoolModel(), (mask, indices))
2451
2452        class IndexPutFloatModel(torch.nn.Module):
2453            def forward(self, mask, indices):
2454                mask[indices] = torch.tensor(5.5)
2455                return mask
2456
2457        mask = torch.rand(100, dtype=torch.float)
2458        indices = (torch.rand(50) * mask.shape[0]).to(torch.int64)
2459        self.run_test(IndexPutFloatModel(), (mask, indices))
2460
2461    @skipIfUnsupportedMinOpsetVersion(11)
2462    def test_index_put_accumulate(self):
2463        class IndexPutModel(torch.nn.Module):
2464            def forward(self, x, ind, update):
2465                return x.index_put((ind,), update, accumulate=True)
2466
2467        x = torch.randn(3, 4)
2468        ind = torch.tensor([2], dtype=torch.long)
2469        update = torch.ones(4)
2470        self.run_test(IndexPutModel(), (x, ind, update))
2471
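    # index_put with accumulate=True adds the update into the selected rows,
    # and repeated indices accumulate instead of overwriting. A minimal
    # eager-only sketch (the method name is illustrative):
    def test_index_put_accumulate_eager_reference(self):
        x = torch.zeros(3, 2)
        out = x.index_put((torch.tensor([1, 1]),), torch.ones(2), accumulate=True)
        self.assertEqual(out[1], torch.tensor([2.0, 2.0]))
        self.assertEqual(out[0], torch.zeros(2))
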
2472    @skipIfUnsupportedMinOpsetVersion(11)
2473    def test_index_put_slice_index(self):
2474        class IndexPutModel(torch.nn.Module):
2475            def forward(self, x, update):
2476                x[1:2, 1:3, torch.tensor([1])] += update
2477                return x
2478
2479        x = torch.randn(3, 4, 5)
2480        update = torch.tensor([10, 15]).view(1, 2, 1)
2481        self.run_test(IndexPutModel(), (x, update))
2482
2483        class IndexPutModel2(torch.nn.Module):
2484            def forward(self, x, update):
2485                x[torch.tensor([0, 2]), torch.tensor([1, 2])] += update
2486                return x
2487
2488        x = torch.randn(3, 4, 5)
2489        update = torch.randn(2, 5)
2490        self.run_test(IndexPutModel2(), (x, update))
2491
2492        class IndexPutModel3(torch.nn.Module):
2493            def forward(self, x, update):
2494                x[torch.tensor([0, 2]), 1:2] += update
2495                return x
2496
2497        x = torch.randn(3, 4, 5)
2498        update = torch.tensor([10, 15]).view(2, 1, 1)
2499        self.run_test(IndexPutModel3(), (x, update))
2500
2501        class IndexPutModel4(torch.nn.Module):
2502            def forward(self, x, update):
2503                x[torch.tensor([0, 2]), 2] += update
2504                return x
2505
2506        x = torch.randn(3, 4, 5)
2507        update = torch.tensor([10, 15]).view(2, 1)
2508        self.run_test(IndexPutModel4(), (x, update))
2509
2510        class IndexPutModel5(torch.nn.Module):
2511            def forward(self, x, update):
2512                x[1:3, torch.tensor([0, 2]), 2] += update
2513                return x
2514
2515        x = torch.randn(3, 4, 5)
2516        update = torch.tensor([10, 15]).view(2, 1)
2517        self.run_test(IndexPutModel5(), (x, update))
2518
2519        class IndexPutModel6(torch.nn.Module):
2520            def forward(self, x, update):
2521                x[1:3, 0] = update
2522                return x
2523
2524        x = torch.randn(3, 4, 5)
2525        update = torch.arange(2 * 5).to(torch.float).view(2, 5)
2526        self.run_test(IndexPutModel6(), (x, update))
2527
2528        class IndexPutModel7(torch.nn.Module):
2529            def forward(self, x, update):
2530                x[1:, 0] = update
2531                return x
2532
2533        x = torch.randn(3, 4, 5)
2534        update = torch.arange(2 * 5).to(torch.float).view(2, 5)
2535        self.run_test(IndexPutModel7(), (x, update))
2536
2537        class IndexPutModel8(torch.nn.Module):
2538            def forward(self, x, update):
2539                x[:3, 0] = update
2540                return x
2541
2542        x = torch.randn(3, 4, 5)
2543        update = torch.arange(3 * 5).to(torch.float).view(3, 5)
2544        self.run_test(IndexPutModel8(), (x, update))
2545
2546        class IndexPutModel9(torch.nn.Module):
2547            def forward(self, poses):
2548                w = 32
2549                x = poses[:, :, 0] - (w - 1) // 2
2550                boxes = torch.zeros([poses.shape[0], 17, 4])
2551                boxes[:, :, 0] = x
2552                return boxes
2553
2554        x = torch.zeros([2, 17, 3], dtype=torch.int64)
2555        self.run_test(IndexPutModel9(), (x,))
2556
2557        class IndexPutModel10(torch.nn.Module):
2558            def forward(self, x, ind, update):
2559                x[ind, 1:3] = update.view(1, 1, 1, 5).expand(2, 2, 2, 5)
2560                return x
2561
2562        x = torch.randn(3, 4, 5)
2563        ind = torch.tensor([[0, 2], [1, 1]])
2564        update = torch.randn(5)
2565        self.run_test(IndexPutModel10(), (x, ind, update))
2566
2567    @skipIfUnsupportedMinOpsetVersion(11)
2568    @skipScriptTest()  # Ellipses followed by tensor indexing not scriptable
2569    def test_index_put_ellipsis(self):
2570        class IndexPutModel(torch.nn.Module):
2571            def forward(self, x, update):
2572                x[..., torch.tensor([2, 1, 3]), 2:4] += update
2573                return x
2574
2575        x = torch.randn(3, 4, 5, 6, 7)
2576        update = torch.randn(3, 1, 1, 3, 2)
2577        self.run_test(IndexPutModel(), (x, update))
2578
2579        class IndexPutModel2(torch.nn.Module):
2580            def forward(self, x, update):
2581                x[2, ..., torch.tensor([2, 1, 3]), 2:4] += update
2582                return x
2583
2584        x = torch.randn(3, 4, 5, 6, 7)
2585        update = torch.randn(4, 1, 3, 2)
2586        self.run_test(IndexPutModel2(), (x, update))
2587
2588    @unittest.skip(
2589        "regression in 1.18: https://github.com/microsoft/onnxruntime/issues/20855"
2590    )
2591    @skipIfUnsupportedMinOpsetVersion(11)
2592    def test_index_put_loop(self):
2593        @torch.jit.script
2594        def ngram_attention_bias(
2595            sequence_length: int, ngram: int, device: torch.device, dtype: torch.dtype
2596        ):
2597            bias = torch.ones(
2598                (ngram, sequence_length), device=device, dtype=dtype
2599            ) * float("-inf")
2600            for stream_idx in range(ngram):
2601                for i in range(sequence_length):
2602                    bias = bias * 2
2603                    bias[stream_idx, i] = 5
2604                    bias = bias * 5
2605                    bias[0, 0] = 5
2606
2607            for stream_idx in range(ngram):
2608                for i in range(sequence_length):
2609                    bias[stream_idx, i] = 5
2610                    bias[0, i] = 5
2611            return bias
2612
2613        class ScriptModel(torch.nn.Module):
2614            def __init__(self) -> None:
2615                super().__init__()
2616                self.ngram = 2
2617                self.max_target_positions = 512
2618
2619            def forward(self, hidden_states):
2620                seq_length, batch_size = hidden_states.shape[:2]
2621                predict_causal_mask = ngram_attention_bias(
2622                    self.max_target_positions,
2623                    self.ngram,
2624                    hidden_states.device,
2625                    hidden_states.dtype,
2626                )
2627                predict_causal_mask = predict_causal_mask[:, :seq_length]
2628                return predict_causal_mask
2629
2630        x = torch.randn(6, 2)
2631        y = torch.randn(4, 1)
2632        self.run_test(
2633            ScriptModel(),
2634            x,
2635            input_names=["x"],
2636            dynamic_axes={"x": {0: "seq_length", 1: "batch_size"}},
2637            additional_test_inputs=[y],
2638        )
2639
2640    @skipIfUnsupportedMinOpsetVersion(11)
2641    def test_copy_(self):
2642        class CopyModel(torch.nn.Module):
2643            def forward(self, x, data):
2644                x[1:3] = data
2645                return x
2646
2647        x = torch.randn(3, 4)
2648        update = torch.randn(2, 4)
2649        self.run_test(CopyModel(), (x, update))
2650
2651        # mixed slice and select
2652        class CopyModel2(torch.nn.Module):
2653            def forward(self, x, data):
2654                x[1:3, 0] = data
2655                return x
2656
2657        x = torch.randn(3, 4)
2658        update = torch.tensor([0], dtype=torch.float32)
2659        self.run_test(CopyModel2(), (x, update))
2660
2661        update = torch.tensor([2, 3], dtype=torch.float32)
2662        self.run_test(CopyModel2(), (x, update))
2663
2664        update = torch.randn(2)
2665        self.run_test(CopyModel2(), (x, update))
2666
2667        class CopyModel3(torch.nn.Module):
2668            def forward(self, x, data):
2669                x[1, 1:3] = data
2670                return x
2671
2672        x = torch.randn(3, 4)
2673        update = torch.tensor([0], dtype=torch.float32)
2674        self.run_test(CopyModel3(), (x, update))
2675
2676        update = torch.tensor([2, 3], dtype=torch.float32)
2677        self.run_test(CopyModel3(), (x, update))
2678
2679        update = torch.randn(2)
2680        self.run_test(CopyModel3(), (x, update))
2681
2682        class CopyModel4(torch.nn.Module):
2683            def forward(self, x, ind, data):
2684                x[ind] = data
2685                return x
2686
2687        x = torch.randn(3, 4)
2688        ind = torch.tensor(2)
2689        data = torch.randn(4)
2690        self.run_test(CopyModel4(), (x, ind, data))
2691
2692        class CopyModel5(torch.nn.Module):
2693            def forward(self, x, mask):
2694                if mask is not None:
2695                    x.copy_(mask)
2696                    return x
2697
2698        x = torch.randn(3, 4)
2699        mask = torch.randn(3, 1)
2700        self.run_test(CopyModel5(), (x, mask))
2701
2702    @skipIfUnsupportedMinOpsetVersion(11)
2703    @skipScriptTest()  # Model not scriptable (output with shape doesn't match the broadcast shape)
2704    def test_copy_tracing(self):
2705        class CopyModel(torch.nn.Module):
2706            def forward(self, x, data):
2707                x[1, 1:3] = data
2708                return x
2709
2710        x = torch.randn(3, 4)
2711        update = torch.randn(1, 2)
2712        self.run_test(CopyModel(), (x, update))
2713
2714    @skipIfUnsupportedMinOpsetVersion(11)
2715    def test_copy_ellipsis(self):
2716        class CopyModel(torch.nn.Module):
2717            def forward(self, x, update):
2718                x[..., 1] = update
2719                return x
2720
2721        x = torch.randn(2, 3, 4)
2722        update = torch.ones(1)
2723        self.run_test(CopyModel(), (x, update))
2724
2725        x = torch.randn(2, 3, 4, 5, 6)
2726        update = torch.ones(1)
2727        self.run_test(CopyModel(), (x, update))
2728
2729    @skipIfUnsupportedMinOpsetVersion(11)
2730    def test_copy_ellipsis_script(self):
2731        class CopyModel(torch.nn.Module):
2732            def forward(self, x, update):
2733                # Insert a reshape node so that, without ONNX shape inference,
2734                # x carries no shape/type info in scripting.
2735                x = x.reshape(4, 3, 5, 6)
2736                x[2, ..., 1:3] = update
2737                return x
2738
2739        x = torch.randn(3, 4, 5, 6)
2740
2741        update = torch.ones(1)
2742        self.run_test(CopyModel(), (x, update))
2743
2744    @skipIfUnsupportedMinOpsetVersion(10)
2745    def test_flip(self):
2746        class MyModule(torch.nn.Module):
2747            def forward(self, x):
2748                return torch.flip(x, dims=[0])
2749
2750        x = torch.tensor(np.arange(6.0).reshape(2, 3))
2751        self.run_test(MyModule(), x)
2752
2753    @skipIfUnsupportedMinOpsetVersion(9)
2754    def test_randint(self):
2755        class RandInt(torch.nn.Module):
2756            def forward(self, x):
2757                randint = torch.randint(1, 10, x.shape)
2758                x = 0 * randint + x
2759                return x
2760
2761        x = torch.randn(2, 3, 4)
2762        self.run_test(RandInt(), x)
2763
2764    @skipIfUnsupportedMinOpsetVersion(9)
2765    def test_randint_value(self):
2766        class RandInt(torch.nn.Module):
2767            def forward(self, x):
2768                # This randint call always returns 3
2769                return torch.randint(3, 4, x.shape) + x
2770
2771        x = torch.randn(2, 3, 4)
2772        self.run_test(RandInt(), x)
2773
2774    @skipIfUnsupportedMinOpsetVersion(9)
2775    def test_randint_like(self):
2776        class RandInt(torch.nn.Module):
2777            def forward(self, x):
2778                # This randint call always returns 3
2779                return torch.randint_like(x, 3, 4) + x
2780
2781        x = torch.randn(2, 3, 4)
2782        self.run_test(RandInt(), x)
2783
2784    def test_randn(self):
2785        class RandN(torch.nn.Module):
2786            def forward(self, x):
2787                return torch.mul(x, (torch.randn(2, 3, 4) + x).size(0))
2788
2789        x = torch.randn(2, 3, 4)
2790        self.run_test(RandN(), x)
2791
2792    def test_rand(self):
2793        class Rand(torch.nn.Module):
2794            def forward(self, x):
2795                return torch.mul(x, (torch.rand(2, 3, 4) + x).size(0))
2796
2797        x = torch.randn(2, 3, 4)
2798        self.run_test(Rand(), x)
2799
2800    def test_randn_dtype(self):
2801        class RandN(torch.nn.Module):
2802            def forward(self, x):
2803                # The resulting node's dtype should be double.
2804                return (
2805                    x.to(torch.float32)
2806                    * torch.randn(2, 3, 4, dtype=torch.double)
2807                    * torch.tensor(0, dtype=torch.float32)
2808                )
2809
2810        x = torch.randn(2, 3, 4)
2811        self.run_test(RandN(), x)
2812
2813    def test_rand_dtype(self):
2814        class Rand(torch.nn.Module):
2815            def forward(self, x):
2816                # The resulting node's dtype should be double.
2817                return (
2818                    x.to(torch.float32)
2819                    * torch.rand(2, 3, 4, dtype=torch.double)
2820                    * torch.tensor(0, dtype=torch.float32)
2821                )
2822
2823        x = torch.randn(2, 3, 4)
2824        self.run_test(Rand(), x)
2825
2826    @skipIfUnsupportedMinOpsetVersion(9)
2827    def test_randn_dynamic_size(self):
2828        class RandN(torch.nn.Module):
2829            def forward(self, x):
2830                return torch.mul(x, torch.randn(x.size()).size(1))
2831
2832        x = torch.randn(2, 3, 4)
2833        self.run_test(RandN(), x)
2834
2835    @skipIfUnsupportedMinOpsetVersion(9)
2836    def test_rand_dynamic_size(self):
2837        class Rand(torch.nn.Module):
2838            def forward(self, x):
2839                return torch.mul(x, torch.rand(x.size()).size(1))
2840
2841        x = torch.randn(2, 3, 4)
2842        self.run_test(Rand(), x)
2843
2844    def test_randn_like(self):
2845        class RandNLike(torch.nn.Module):
2846            def forward(self, x):
2847                return torch.mul(x, torch.randn_like(x).size(0))
2848
2849        x = torch.randn(2, 3, 4)
2850        self.run_test(RandNLike(), x)
2851        self.run_test(torch.jit.script(RandNLike()), x)
2852
2853    def test_rand_like(self):
2854        class RandLike(torch.nn.Module):
2855            def forward(self, x):
2856                return torch.mul(x, torch.rand_like(x).size(0))
2857
2858        x = torch.randn(2, 3, 4)
2859        self.run_test(RandLike(), x)
2860        self.run_test(torch.jit.script(RandLike()), x)
2861
2862    def test_randn_like_dtype(self):
2863        class RandNLike(torch.nn.Module):
2864            def forward(self, x):
2865                # The resulting node's dtype should be double.
2866                return (
2867                    x.to(torch.float32)
2868                    * torch.randn_like(x, dtype=torch.double)
2869                    * torch.tensor(0, dtype=torch.float32)
2870                )
2871
2872        x = torch.randn(2, 3, 4)
2873        self.run_test(RandNLike(), x)
2874
2875    def test_rand_like_dtype(self):
2876        class RandLike(torch.nn.Module):
2877            def forward(self, x):
2878                # The resulting node's dtype should be double.
2879                return (
2880                    x.to(torch.float32)
2881                    * torch.rand_like(x, dtype=torch.double)
2882                    * torch.tensor(0, dtype=torch.float32)
2883                )
2884
2885        x = torch.randn(2, 3, 4)
2886        self.run_test(RandLike(), x)
2887
2888    def test_bernoulli(self):
2889        class Bernoulli(torch.nn.Module):
2890            def forward(self, x):
2891                return torch.mul(x, torch.bernoulli(x).size(0))
2892
2893        x = torch.empty(3, 3).uniform_(0, 1)
2894        self.run_test(Bernoulli(), x)
2895
2896        x = torch.empty(2, 3, 3, dtype=torch.double).uniform_(0, 1)
2897        self.run_test(Bernoulli(), x)
2898
2899    def test_bernoulli_p(self):
2900        class Bernoulli_float(torch.nn.Module):
2901            def forward(self, x):
2902                return torch.mul(x, torch.bernoulli(x, 0.2).size(0))
2903
2904        class Bernoulli_tensor(torch.nn.Module):
2905            def forward(self, x):
2906                return torch.mul(x, torch.rand_like(x).bernoulli_(x).size(0))
2907
2908        x = torch.rand(3, 3)
2909        self.run_test(Bernoulli_float(), x)
2910        self.run_test(Bernoulli_tensor(), x)
2911
2912        x = torch.rand(2, 3, 3, dtype=torch.double)
2913        self.run_test(Bernoulli_float(), x)
2914        self.run_test(Bernoulli_tensor(), x)
2915
2916    @unittest.skip("Bug in ORT, skip test until rel-1.11.")
2917    @skipIfUnsupportedMinOpsetVersion(14)
2918    def test_reshape_allowzero(self):
2919        class ReshapeModel(torch.nn.Module):
2920            def forward(self, x):
2921                x = x.reshape(3, 4, 0)
2922                return x
2923
2924        x = torch.randn(0, 3, 4)
2925        self.run_test(ReshapeModel(), x)
2926
2927    def test_reshape_different_rank(self):
2928        class ReshapeModel(torch.nn.Module):
2929            def forward(self, x):
2930                x = x.reshape(-1, 2, 4, 4, 5, 5)
2931                return x
2932
2933        x = torch.randn(1, 32, 5, 5)
2934        self.run_test(ReshapeModel(), x)
2935
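    # reshape with -1 infers the free dimension from the element count:
    # 1 * 32 * 5 * 5 = 800 elements and 2 * 4 * 4 * 5 * 5 = 800, so the test
    # above resolves -1 to 800 / 800 = 1. A minimal eager-only sketch (the
    # method name is illustrative):
    def test_reshape_infer_dim_eager_reference(self):
        x = torch.randn(1, 32, 5, 5)
        self.assertEqual(
            x.reshape(-1, 2, 4, 4, 5, 5).shape, torch.Size([1, 2, 4, 4, 5, 5])
        )
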
2936    def _interpolate(self, x, mode, use_size, is_upsample, align_corners=False):
2937        class MyModel(torch.nn.Module):
2938            __constants__ = [
2939                "mode",
2940                "use_size",
2941                "is_upsample",
2942                "size",
2943                "scale",
2944                "size_array",
2945                "scale_array",
2946                "align_corners",
2947            ]
2948
2949            def __init__(self, mode, use_size, is_upsample, align_corners):
2950                super().__init__()
2951                self.mode = mode
2952                self.use_size = use_size
2953                self.is_upsample = is_upsample
2954                self.align_corners = align_corners
2955                self.scale = 2.0 if self.is_upsample else 0.5
2956                self.size = 24 if self.is_upsample else 2
2957                if x.dim() == 3:
2958                    self.scale_array = [2.3]
2959                    self.size_array = [16]
2960                elif x.dim() == 4:
2961                    self.scale_array = [2.3, 3.1]
2962                    self.size_array = [16, 32]
2963                else:
2964                    self.scale_array = [2.3, 3.1, 4.6]
2965                    self.size_array = [16, 32, 64]
2966
2967            def forward(self, x):
2968                if self.use_size:
2969                    if self.align_corners:
2970                        return torch.nn.functional.interpolate(
2971                            x, mode=self.mode, size=self.size, align_corners=True
2972                        ), torch.nn.functional.interpolate(
2973                            x, mode=self.mode, size=self.size_array, align_corners=True
2974                        )
2975                    return torch.nn.functional.interpolate(
2976                        x, mode=self.mode, size=self.size
2977                    ), torch.nn.functional.interpolate(
2978                        x, mode=self.mode, size=self.size_array
2979                    )
2980                if self.align_corners:
2981                    return torch.nn.functional.interpolate(
2982                        x,
2983                        mode=self.mode,
2984                        scale_factor=self.scale,
2985                        recompute_scale_factor=False,
2986                    ), torch.nn.functional.interpolate(
2987                        x,
2988                        mode=self.mode,
2989                        scale_factor=self.scale_array,
2990                        recompute_scale_factor=False,
2991                    )
2992                return torch.nn.functional.interpolate(
2993                    x,
2994                    mode=self.mode,
2995                    scale_factor=self.scale,
2996                    recompute_scale_factor=False,
2997                ), torch.nn.functional.interpolate(
2998                    x,
2999                    mode=self.mode,
3000                    scale_factor=self.scale_array,
3001                    recompute_scale_factor=False,
3002                )
3003
3004        model = MyModel(mode, use_size, is_upsample, align_corners)
3005        self.run_test(model, x, atol=1e-6)
3006
3007    def _interpolate_tests(self, is_upsample):
3008        # - cubic mode is not supported for opsets below 11;
3009        # - linear mode does not match for opsets below 11;
3010        modes = ["nearest", "linear", "bicubic"]
3011        if self.opset_version < 11:
3012            modes = ["nearest"]
3013        x = [
3014            torch.randn(1, 2, 6, requires_grad=True),
3015            torch.randn(1, 2, 4, 6, requires_grad=True),
3016            torch.randn(1, 2, 4, 4, 6, requires_grad=True),
3017        ]
3018
3019        for mode in modes:
3020            for xi in x:
3021                mode_i = mode
3022                # TODO: enable bicubic downsample when ORT precision loss fixed
3023                if mode == "bicubic" and xi.dim() != 4:
3024                    continue
3025                elif mode == "linear":
3026                    if xi.dim() == 3:
                        # TODO: enable when linear mode is implemented for 1d inputs in ORT
3028                        continue
3029                    elif xi.dim() == 4:
3030                        mode_i = "bilinear"
3031                    elif xi.dim() == 5:
                        # TODO: enable when linear mode is implemented for 3d inputs in ORT
3033                        mode_i = "trilinear"
3034                        continue
3035                self._interpolate(xi, mode_i, True, is_upsample)
3036                # test with align_corners if supported
3037                if mode != "nearest":
3038                    self._interpolate(xi, mode_i, True, is_upsample, True)
                # the following cases require dynamic sizes/scales,
                # which are not supported for opset_version < 9
3041                if self.opset_version >= 9:
3042                    self._interpolate(xi, mode_i, True, is_upsample)
3043                    # test with align_corners if supported
3044                    if mode != "nearest":
3045                        self._interpolate(xi, mode_i, False, is_upsample, True)
3046                    self._interpolate(xi, mode_i, False, is_upsample)
3047
    # ONNX export fails for scripted interpolate because dynamic sizes are not supported for opsets below 9.
3049    @skipIfUnsupportedMinOpsetVersion(9)
3050    def test_interpolate_upsample(self):
3051        self._interpolate_tests(True)
3052
3053    @skipIfUnsupportedMaxOpsetVersion(8)
    @skipScriptTest()  # Scripting is supported for opsets > 8; see test_interpolate_upsample
3055    def test_interpolate_upsample_trace(self):
3056        self._interpolate_tests(True)
3057
3058    @skipIfUnsupportedMinOpsetVersion(9)
3059    def test_interpolate_function_substitution(self):
3060        class ScriptModel(torch.jit.ScriptModule):
3061            @torch.jit.script_method
3062            def forward(self, x):
3063                return torch.nn.functional.interpolate(
3064                    x, mode="nearest", scale_factor=2.0
3065                )
3066
3067        class ScriptModule(torch.jit.ScriptModule):
3068            def __init__(self) -> None:
3069                super().__init__()
3070                self.submodule = ScriptModel()
3071
3072            @torch.jit.script_method
3073            def forward(self, input):
3074                return self.submodule(input)
3075
3076        x = torch.randn(1, 2, 4, 4, 6)
3077        self.run_test(ScriptModule(), (x,))
3078
3079        @torch.jit.script
3080        def script_method(x):
3081            return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2.0)
3082
3083        class TracingModule(torch.nn.Module):
3084            def forward(self, x):
3085                return script_method(x)
3086
3087        self.run_test(TracingModule(), (x,))
3088
3089    @skipIfUnsupportedMinOpsetVersion(10)
3090    def test_interpolate_downsample(self):
3091        self._interpolate_tests(False)
3092
3093    @skipIfUnsupportedMinOpsetVersion(11)
3094    def test_interpolate_half_pixel(self):
3095        # testing whether it uses "half_pixel" or "pytorch_half_pixel"
3096        # see https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize
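        # Per that spec, the two modes differ only when an output length is 1:
        #   half_pixel:         x_original = (x_resized + 0.5) / scale - 0.5
        #   pytorch_half_pixel: same formula, but x_original = 0 when length_resized == 1
        # The size[i] = 1 cases constructed below exercise exactly that difference.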
3097
3098        class MyModel(torch.nn.Module):
3099            def __init__(self, mode, size):
3100                super().__init__()
3101                self.mode = mode
3102                self.size = size
3103
3104            def forward(self, x):
3105                return torch.nn.functional.interpolate(
3106                    x, mode=self.mode, size=self.size
3107                )
3108
3109        modes = ["linear", "bicubic"]
3110        x = [
3111            torch.randn(1, 2, 6, requires_grad=True),
3112            torch.randn(1, 2, 4, 6, requires_grad=True),
3113            torch.randn(1, 2, 4, 4, 6, requires_grad=True),
3114        ]
3115        for mode in modes:
3116            for xi in x:
3117                mode_i = mode
3118                if mode == "bicubic" and xi.dim() != 4:
3119                    continue
3120                elif mode == "linear":
3121                    if xi.dim() == 4:
3122                        mode_i = "bilinear"
3123                    elif xi.dim() == 5:
3124                        mode_i = "trilinear"
3125                for i in range(xi.dim() - 2):
3126                    size = list(xi.shape[2:])
3127                    size[i] = 1
3128                    self.run_test(MyModel(mode_i, size), xi)
3129
3130    @skipIfUnsupportedMinOpsetVersion(11)
3131    def test_interpolate_no_shape(self):
3132        class MyModel(torch.jit.ScriptModule):
3133            @torch.jit.script_method
3134            def forward(self, x, y):
3135                x = torch.add(x, x)
3136                out1 = torch.nn.functional.interpolate(
3137                    x, mode="bilinear", size=(16, 16), align_corners=False
3138                )
3139                out2 = torch.nn.functional.interpolate(
3140                    x, mode="nearest", size=(int(y.size(0)), int(y.size(1)))
3141                )
3142                return out1, out2
3143
3144        x = torch.randn(1, 2, 4, 4, requires_grad=True)
3145        y = torch.randn(16, 16, requires_grad=True)
3146        self.run_test(
3147            MyModel(),
3148            (x, y),
3149            input_names=["x", "y"],
3150            dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1]},
3151        )
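        # In the static-shape export, y is used only for its size, so it is expected
        # to be pruned from the ONNX graph; remained_onnx_input_idx=[0] asserts that
        # only x remains an input.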
3152        self.run_test(MyModel(), (x, y), remained_onnx_input_idx=[0])
3153
3154    @skipScriptTest()  # scripting raises OnnxRuntimeError
3155    def test_interpolate_adaptive_pooling_error(self):
3156        x = torch.randn(1, 2, 6, requires_grad=True)
        with self.assertRaises(RuntimeError):
3158            self._interpolate(x, "area", True, True)
3159
        with self.assertRaises(RuntimeError):
3161            self._interpolate(x, "area", False, True)
3162
3163    def test_groupnorm(self):
3164        model = torch.nn.GroupNorm(3, 6, 0.002)
3165        x = torch.randn(4, 6, 36, 36, 18)
3166        self.run_test(model, x)
3167
3168        model = torch.nn.GroupNorm(1, 6, 0.002)
3169        x = torch.randn(4, 6, 180, 180)
3170        self.run_test(model, x)
3171
3172        model = torch.nn.GroupNorm(6, 6, 0.002)
3173        x = torch.randn(4, 6, 180, 180)
3174        self.run_test(model, x)
3175
3176    def test_groupnorm_noaffine(self):
3177        model = torch.nn.GroupNorm(4, 8, 0.002, affine=False)
3178        x = torch.randn(3, 8, 224, 224)
3179        self.run_test(model, x)
3180
3181        model = torch.nn.GroupNorm(1, 6, 0.002, affine=False)
3182        x = torch.randn(4, 6, 180, 180)
3183        self.run_test(model, x)
3184
3185        model = torch.nn.GroupNorm(6, 6, 0.002, affine=False)
3186        x = torch.randn(4, 6, 180, 180)
3187        self.run_test(model, x)
3188
3189    @skipIfUnsupportedMinOpsetVersion(9)
3190    def test_list_unpack_scripted(self):
3191        class ListUnpack(torch.nn.Module):
3192            def forward(self, x):
3193                a, b = x.shape
3194                return x.new_zeros((a, b))
3195
3196        x = torch.randn(2, 3)
3197        self.run_test(
3198            torch.jit.script(ListUnpack()),
3199            x,
3200            input_names=["x"],
3201            dynamic_axes={"x": [0, 1]},
3202        )
3203        self.run_test(torch.jit.script(ListUnpack()), x, remained_onnx_input_idx=[])
3204
3205    @skipIfUnsupportedMinOpsetVersion(9)
3206    def test_list_unpack_scripted_runs_without_error_with_constructed_list_as_input(
3207        self,
3208    ):
3209        class PackUnpack(torch.nn.Module):
3210            """Create and unpack a list of tensors.
3211
3212            When scripted, it should produce a graph similar to
3213
3214            ```
3215            graph(%self : __torch__.PackUnpack,
3216                %a.1 : Tensor,
3217                %b.1 : Tensor):
3218            %packed.1 : Tensor[] = prim::ListConstruct(%a.1, %b.1)
3219            %c.1 : Tensor, %8 : Tensor = prim::ListUnpack(%packed.1)
3220            return (%c.1)
3221            ```
3222            """
3223
3224            def forward(self, a, b):
3225                packed = [a, b]
3226                c, _ = packed
3227                return c
3228
3229        self.run_test(
3230            torch.jit.script(PackUnpack()),
3231            (torch.tensor(0), torch.tensor([42])),
3232            remained_onnx_input_idx=[0],
3233        )
3234
3235    @skipIfUnsupportedMinOpsetVersion(9)
3236    def test_list_unpack_slice_scripted(self):
3237        class ListUnpackSlice(torch.nn.Module):
3238            def forward(self, x):
3239                a, b = x.shape[2:]
3240                return x.new_zeros((a, b))
3241
3242        x = torch.randn(2, 3, 4, 5)
3243        self.run_test(
3244            torch.jit.script(ListUnpackSlice()),
3245            x,
3246            input_names=["x"],
3247            dynamic_axes={"x": [0, 1, 2, 3]},
3248        )
3249        self.run_test(
3250            torch.jit.script(ListUnpackSlice()), x, remained_onnx_input_idx=[]
3251        )
3252
3253    @skipDtypeChecking
3254    def test_pow(self):
3255        class PowModule(torch.nn.Module):
3256            def forward(self, x, y):
3257                return x.pow(y)
3258
3259        x = torch.randn(2, 3, 4)
3260        y = torch.randn(2, 3, 4)
3261        self.run_test(PowModule(), (x, y))
3262
3263        x = torch.randint(10, (2, 3, 4))
3264        y = torch.randint(10, (2, 3, 4)).to(dtype=torch.int32)
3265        self.run_test(PowModule(), (x, y))
3266
3267        x = torch.randint(10, (2, 3, 4))
3268        y = torch.randint(10, (2, 3, 4))
3269        self.run_test(PowModule(), (x, y))
3270
3271        x = torch.randn(2, 3, 4).to(dtype=torch.float64)
3272        y = torch.randint(10, (2, 3, 4))
3273        self.run_test(PowModule(), (x, y))
3274
3275        class PowModule2(torch.nn.Module):
3276            def forward(self, x):
3277                return torch.pow(2, x)
3278
3279        x = torch.randn(1, 10)
3280        self.run_test(PowModule2(), (x,))
3281
3282        x = torch.randint(10, (2, 3, 4))
3283        self.run_test(PowModule2(), (x,))
3284
3285        x = torch.randn(1, 10).to(dtype=torch.float64)
3286        self.run_test(PowModule2(), (x,))
3287
3288        class PowModule3(torch.nn.Module):
3289            def forward(self, x, y):
3290                return y[torch.pow(2, x)]
3291
3292        x = torch.randint(5, (2, 3, 4))
3293        y = torch.rand(100)
3294        self.run_test(PowModule3(), (x, y))
3295
    # Arithmetic ops (Add/Sub/Mul/Div/Gemm/Pow/Mod) with low-precision inputs, including uint8, fail in ORT.
    # Cast with to(dtype=torch.long) to avoid "ORT output type does not match expected type" errors.
    # This will be fixed in ONNX opset version 14.
3299    @skipIfUnsupportedMaxOpsetVersion(13)
3300    @skipDtypeChecking
3301    def test_arithmeticOps_with_low_precision(self):
3302        class AddModule(torch.nn.Module):
3303            def forward(self, x, y):
3304                return x + y
3305
3306        class SubModule(torch.nn.Module):
3307            def forward(self, x, y):
3308                return x - y
3309
3310        class MulModule(torch.nn.Module):
3311            def forward(self, x, y):
3312                return x * y
3313
3314        class DivModule(torch.nn.Module):
3315            def forward(self, x, y):
3316                return x / y
3317
3318        class PowModule(torch.nn.Module):
3319            def forward(self, x, y):
3320                return x.pow(y)
3321
3322        x = torch.tensor([2, 3, 5], dtype=torch.uint8)
3323        y = torch.tensor([2, 3, 5], dtype=torch.uint8)
3324        z = torch.tensor([1], dtype=torch.uint8)
3325        self.run_test(AddModule(), (x, y))
3326        self.run_test(SubModule(), (x, y))
3327        self.run_test(MulModule(), (x, y))
3328        self.run_test(DivModule(), (x, y))
3329        self.run_test(PowModule(), (x, z))
3330
3331        x = torch.tensor([2, 3, 5], dtype=torch.int8)
3332        y = torch.tensor([2, 3, 5], dtype=torch.int8)
3333        z = torch.tensor([1], dtype=torch.int8)
3334        self.run_test(AddModule(), (x, y))
3335        self.run_test(SubModule(), (x, y))
3336        self.run_test(MulModule(), (x, y))
3337        self.run_test(DivModule(), (x, y))
3338        self.run_test(PowModule(), (x, z))
3339
3340        x = torch.tensor([2, 3, 5], dtype=torch.int16)
3341        y = torch.tensor([2, 3, 5], dtype=torch.int16)
3342        z = torch.tensor([1], dtype=torch.int16)
3343        self.run_test(AddModule(), (x, y))
3344        self.run_test(SubModule(), (x, y))
3345        self.run_test(MulModule(), (x, y))
3346        self.run_test(DivModule(), (x, y))
3347        self.run_test(PowModule(), (x, z))
3348
3349        x = torch.tensor([2, 3, 5], dtype=torch.uint8)
3350        y = torch.tensor([2, 3, 5], dtype=torch.float32)
3351        z = torch.tensor([1], dtype=torch.float64)
3352        self.run_test(AddModule(), (x, y))
3353        self.run_test(SubModule(), (x, y))
3354        self.run_test(MulModule(), (x, y))
3355        self.run_test(DivModule(), (x, y))
3356        self.run_test(PowModule(), (x, z))
3357
3358        x = torch.tensor([2, 3, 5], dtype=torch.uint8)
3359        y = torch.tensor([2, 3, 5], dtype=torch.int64)
3360        z = torch.tensor([1], dtype=torch.int32)
3361        self.run_test(AddModule(), (x, y))
3362        self.run_test(SubModule(), (x, y))
3363        self.run_test(MulModule(), (x, y))
3364        self.run_test(DivModule(), (x, y))
3365        self.run_test(PowModule(), (x, z))
3366
3367    def test_mul_bool(self):
3368        class MyModel(torch.nn.Module):
3369            def forward(self, x, y):
3370                return torch.mul(x, y)
3371
3372        x_t = torch.tensor([True, False, True, False])
3373        y_t = torch.tensor([True, True, False, False])
3374        z_t = torch.tensor([1.0, 2.0, 3.0, 0.0])
3375        self.run_test(MyModel(), (x_t, y_t))
3376        self.run_test(MyModel(), (x_t, z_t))
3377        self.run_test(MyModel(), (z_t, y_t))
3378
    # fmod was added in opset version 10
3380    @skipIfUnsupportedMinOpsetVersion(10)
3381    @skipIfUnsupportedMaxOpsetVersion(13)
3382    def test_mod_with_low_precision(self):
3383        class ModModule(torch.nn.Module):
3384            def forward(self, x, y):
3385                return torch.fmod(x, y).to(dtype=torch.long)
3386
3387        x = torch.tensor([2, 3, 5], dtype=torch.uint8)
3388        y = torch.tensor([2, 3, 5], dtype=torch.uint8)
3389        self.run_test(ModModule(), (x, y))
3390
3391        x = torch.tensor([2, 3, 5], dtype=torch.int8)
3392        y = torch.tensor([2, 3, 5], dtype=torch.int8)
3393        self.run_test(ModModule(), (x, y))
3394
3395        x = torch.tensor([2, 3, 5], dtype=torch.int16)
3396        y = torch.tensor([2, 3, 5], dtype=torch.int16)
3397        self.run_test(ModModule(), (x, y))
3398
3399        x = torch.tensor([2, 3, 5], dtype=torch.uint8)
3400        y = torch.tensor([2, 3, 5], dtype=torch.int32)
3401        self.run_test(ModModule(), (x, y))
3402
3403        x = torch.tensor([2, 3, 5], dtype=torch.uint8)
3404        y = torch.tensor([2, 3, 5], dtype=torch.float64)
3405        self.run_test(ModModule(), (x, y))
3406
3407    @skipIfUnsupportedMinOpsetVersion(9)
3408    def test_empty_constant_shape(self):
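        # A shape of () yields a 0-d (scalar) tensor; the point of this test is that
        # the exporter handles an empty constant shape rather than an empty tensor.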
3409        class Zeros(torch.nn.Module):
3410            def forward(self, x):
3411                y = torch.zeros(())
3412                y += x
3413                return y
3414
3415        x = torch.tensor(42.0)
3416        self.run_test(Zeros(), x)
3417
3418        class Ones(torch.nn.Module):
3419            def forward(self, x):
3420                y = torch.ones(())
3421                y += x
3422                return y
3423
3424        x = torch.tensor(42.0)
3425        self.run_test(Ones(), x)
3426
3427        class Full(torch.nn.Module):
3428            def forward(self, x):
3429                y = torch.full((), 1.0)
3430                y += x
3431                return y
3432
3433        x = torch.tensor(42.0)
3434        self.run_test(Full(), x)
3435
3436        class Empty(torch.nn.Module):
3437            def forward(self, x):
3438                y = torch.empty(()).fill_(0)
3439                y += x
3440                return y
3441
3442        x = torch.tensor(42.0)
3443        self.run_test(Empty(), x)
3444
3445    def test_std(self):
3446        class StandardDeviation(torch.nn.Module):
3447            def forward(self, input):
3448                return torch.std(input, unbiased=False)
3449
3450        x = torch.randn(2, 3, 4)
3451        model = StandardDeviation()
3452        self.run_test(model, x)
3453
3454        class StandardDeviationUnbiased(torch.nn.Module):
3455            def forward(self, input):
3456                return torch.std(input, unbiased=True)
3457
3458        model = StandardDeviationUnbiased()
3459        self.run_test(model, x)
3460
3461    def test_std_along_dims(self):
3462        class StandardDeviation(torch.nn.Module):
3463            def forward(self, input):
3464                return torch.std(input, dim=(0, 1), unbiased=False)
3465
3466        x = torch.randn(2, 3, 4)
3467        model = StandardDeviation()
3468        self.run_test(model, x)
3469
3470        class StandardDeviationUnbiased(torch.nn.Module):
3471            def forward(self, input):
3472                return torch.std(input, dim=(0, 1), unbiased=True)
3473
3474        x = torch.randn(2, 3, 4)
3475        model = StandardDeviationUnbiased()
3476        self.run_test(model, x)
3477
3478    def test_std_keepdim(self):
3479        class StandardDeviation(torch.nn.Module):
3480            def forward(self, input):
3481                return torch.std(input, dim=(0, 1), unbiased=False, keepdim=True)
3482
3483        x = torch.randn(2, 3, 4)
3484        model = StandardDeviation()
3485        self.run_test(model, x)
3486
3487        class StandardDeviationUnbiased(torch.nn.Module):
3488            def forward(self, input):
3489                return torch.std(input, dim=(0, 1), unbiased=True, keepdim=True)
3490
3491        x = torch.randn(2, 3, 4)
3492        model = StandardDeviationUnbiased()
3493        self.run_test(model, x)
3494
3495    def test_std_correction(self):
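        # `correction` generalizes `unbiased`: the divisor is N - correction, so
        # unbiased=True corresponds to correction=1 and unbiased=False to correction=0.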
3496        class StandardDeviation(torch.nn.Module):
3497            def forward(self, input):
3498                return torch.std(input, dim=(0, 1), correction=3, keepdim=True)
3499
3500        x = torch.randn(2, 3, 4)
3501        model = StandardDeviation()
3502        self.run_test(model, x)
3503
3504    def test_var(self):
3505        class Variance(torch.nn.Module):
3506            def forward(self, input):
3507                return torch.var(input, unbiased=False)
3508
3509        x = torch.randn(2, 3, 4)
3510        model = Variance()
3511        self.run_test(model, x)
3512
3513        class VarianceUnbiased(torch.nn.Module):
3514            def forward(self, input):
3515                return torch.var(input, unbiased=True)
3516
3517        model = VarianceUnbiased()
3518        self.run_test(model, x)
3519
3520        class VarianceSqrt(torch.nn.Module):
3521            def forward(self, input):
3522                y = torch.var(input, 1)
3523                return torch.sqrt(y + 1e-8)
3524
3525        x = torch.randn(1, 2, 3, 300, 300)
3526        model = VarianceSqrt()
3527        self.run_test(model, x)
3528
3529    def test_var_along_dims(self):
3530        class Variance(torch.nn.Module):
3531            def forward(self, input):
3532                return torch.var(input, dim=(0, 1), unbiased=False)
3533
3534        x = torch.randn(2, 3, 4)
3535        model = Variance()
3536        self.run_test(model, x)
3537
3538        class VarianceUnbiased(torch.nn.Module):
3539            def forward(self, input):
3540                return torch.var(input, dim=(0, 1), unbiased=True)
3541
3542        x = torch.randn(2, 3, 4)
3543        model = VarianceUnbiased()
3544        self.run_test(model, x)
3545
3546    def test_var_keepdim(self):
3547        class Variance(torch.nn.Module):
3548            def forward(self, input):
3549                return torch.var(input, dim=(0, 1), unbiased=False, keepdim=True)
3550
3551        x = torch.randn(2, 3, 4)
3552        model = Variance()
3553        self.run_test(model, x)
3554
3555        class VarianceUnbiased(torch.nn.Module):
3556            def forward(self, input):
3557                return torch.var(input, dim=(0, 1), unbiased=True, keepdim=True)
3558
3559        x = torch.randn(2, 3, 4)
3560        model = VarianceUnbiased()
3561        self.run_test(model, x)
3562
3563    def test_var_correction(self):
3564        class Variance(torch.nn.Module):
3565            def forward(self, input):
3566                return torch.var(input, dim=(0, 1), correction=3, keepdim=True)
3567
3568        x = torch.randn(2, 3, 4)
3569        model = Variance()
3570        self.run_test(model, x)
3571
3572    def test_var_mean(self):
3573        class Variance(torch.nn.Module):
3574            def forward(self, input):
3575                return torch.var_mean(input, unbiased=False)
3576
3577        x = torch.randn(2, 3, 4)
3578        model = Variance()
3579        self.run_test(model, x)
3580
3581        class VarianceUnbiased(torch.nn.Module):
3582            def forward(self, input):
3583                return torch.var_mean(input, unbiased=True)
3584
3585        model = VarianceUnbiased()
3586        self.run_test(model, x)
3587
3588    def test_var_mean_along_dims(self):
3589        class Variance(torch.nn.Module):
3590            def forward(self, input):
3591                return torch.var_mean(input, dim=(0, 1), unbiased=False)
3592
3593        x = torch.randn(2, 3, 4)
3594        model = Variance()
3595        self.run_test(model, x)
3596
3597        class VarianceUnbiased(torch.nn.Module):
3598            def forward(self, input):
3599                return torch.var_mean(input, dim=(0, 1), unbiased=True)
3600
3601        x = torch.randn(2, 3, 4)
3602        model = VarianceUnbiased()
3603        self.run_test(model, x)
3604
3605    def test_var_mean_mixed_dims(self):
3606        class ReverseDims(torch.nn.Module):
3607            def forward(self, input):
3608                return torch.var_mean(input, dim=(2, 1), unbiased=False)
3609
3610        x = torch.randn(2, 3, 4)
3611        model = ReverseDims()
3612        self.run_test(model, x)
3613
3614        class SkipDims(torch.nn.Module):
3615            def forward(self, input):
3616                return torch.var_mean(input, dim=(0, 2), unbiased=False)
3617
3618        x = torch.randn(2, 3, 4)
3619        model = SkipDims()
3620        self.run_test(model, x)
3621
3622        class NonZeroDims(torch.nn.Module):
3623            def forward(self, input):
3624                return torch.var_mean(input, dim=(1, 2), unbiased=False)
3625
3626        x = torch.randn(2, 3, 4)
3627        model = NonZeroDims()
3628        self.run_test(model, x)
3629
3630    def test_var_mean_keepdim(self):
3631        class Variance(torch.nn.Module):
3632            def forward(self, input):
3633                return torch.var_mean(input, dim=(0, 1), unbiased=False, keepdim=True)
3634
3635        x = torch.randn(2, 3, 4)
3636        model = Variance()
3637        self.run_test(model, x)
3638
3639        class VarianceUnbiased(torch.nn.Module):
3640            def forward(self, input):
3641                return torch.var_mean(input, dim=(0, 1), unbiased=True, keepdim=True)
3642
3643        x = torch.randn(2, 3, 4)
3644        model = VarianceUnbiased()
3645        self.run_test(model, x)
3646
3647    def test_var_mean_correction(self):
3648        class Variance(torch.nn.Module):
3649            def forward(self, input):
3650                return torch.var_mean(input, dim=(0, 1), correction=3, keepdim=True)
3651
3652        x = torch.randn(2, 3, 4)
3653        model = Variance()
3654        self.run_test(model, x)
3655
3656    def test_std_mean(self):
3657        class StandardDeviation(torch.nn.Module):
3658            def forward(self, input):
3659                return torch.std_mean(input, unbiased=False)
3660
3661        x = torch.randn(2, 3, 4)
3662        model = StandardDeviation()
3663        self.run_test(model, x)
3664
3665        class StandardDeviationUnbiased(torch.nn.Module):
3666            def forward(self, input):
3667                return torch.std_mean(input, unbiased=True)
3668
3669        model = StandardDeviationUnbiased()
3670        self.run_test(model, x)
3671
3672    def test_std_mean_along_dims(self):
3673        class StandardDeviation(torch.nn.Module):
3674            def forward(self, input):
3675                return torch.std_mean(input, dim=(0, 1), unbiased=False)
3676
3677        x = torch.randn(2, 3, 4)
3678        model = StandardDeviation()
3679        self.run_test(model, x)
3680
3681        class VarianceUnbiased(torch.nn.Module):
3682            def forward(self, input):
3683                return torch.std_mean(input, dim=(0, 1), unbiased=True)
3684
3685        x = torch.randn(2, 3, 4)
3686        model = VarianceUnbiased()
3687        self.run_test(model, x)
3688
3689    def test_std_mean_keepdim(self):
3690        class StandardDeviation(torch.nn.Module):
3691            def forward(self, input):
3692                return torch.std_mean(input, dim=(0, 1), unbiased=False, keepdim=True)
3693
3694        x = torch.randn(2, 3, 4)
3695        model = StandardDeviation()
3696        self.run_test(model, x)
3697
3698        class StandardDeviationUnbiased(torch.nn.Module):
3699            def forward(self, input):
3700                return torch.std_mean(input, dim=(0, 1), unbiased=True, keepdim=True)
3701
3702        x = torch.randn(2, 3, 4)
3703        model = StandardDeviationUnbiased()
3704        self.run_test(model, x)
3705
3706    def test_std_mean_correction(self):
3707        class StandardDeviation(torch.nn.Module):
3708            def forward(self, input):
3709                return torch.var_mean(input, dim=(0, 1), correction=3, keepdim=True)
3710
3711        x = torch.randn(2, 3, 4)
3712        model = StandardDeviation()
3713        self.run_test(model, x)
3714
3715    def test_bitshift(self):
3716        class BitshiftModel(torch.nn.Module):
3717            def forward(self, input):
3718                return (
3719                    input >> 1,
3720                    input << 3,
3721                    input >> torch.tensor([1, 2]),
3722                    input << 4,
3723                )
3724
3725        input = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2)
3726        self.run_test(BitshiftModel(), input)
3727
3728    @skipIfUnsupportedMinOpsetVersion(18)
3729    def test_bitwise_and(self):
3730        class BitwiseAndModel(torch.nn.Module):
3731            def forward(self, input, other):
3732                return (
3733                    input & 20,
3734                    torch.bitwise_and(input, other),
3735                    other & torch.tensor([1, 2], dtype=torch.int32),
3736                )
3737
3738        input = torch.randint(0, 255, (3, 4, 2), dtype=torch.uint8)
3739        other = torch.randint(-128, 127, (3, 4, 2), dtype=torch.int8)
3740        self.run_test(BitwiseAndModel(), (input, other))
3741
    # uint8 is not implemented in ORT for the Mul op used when
    # exporting bitshift for opset_version < 10
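    # (For those older opsets the exporter is expected to lower `x << k` to a Mul by
    # 2**k, which is why the uint8 Mul limitation applies.)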
3744    @skipIfUnsupportedMinOpsetVersion(11)
3745    def test_bitshift_uint8(self):
3746        class BitshiftModel(torch.nn.Module):
3747            def forward(self, input, input2):
3748                return (
3749                    input >> 1,
3750                    input << 3,
3751                    input2 >> torch.tensor([1, 2], dtype=torch.uint8),
3752                    input2 << 4,
3753                )
3754
3755        input = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
3756        input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
3757        self.run_test(BitshiftModel(), (input, input2))
3758
3759    def test_narrow(self):
3760        class NarrowModel(torch.nn.Module):
3761            def forward(self, input):
3762                return torch.narrow(input, 0, 0, 2)
3763
3764        x = torch.randn(3, 3, requires_grad=True)
3765        self.run_test(NarrowModel(), x)
3766
3767    @skipIfUnsupportedMinOpsetVersion(11)
3768    def test_narrow_dynamic(self):
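        # The dynamic length (input.shape[0] - 1) requires Slice with runtime tensor
        # inputs, which ONNX only supports in newer opsets; hence the opset 11 guard.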
3769        class NarrowModel(torch.nn.Module):
3770            def forward(self, input):
3771                return torch.narrow(input, 0, 0, input.shape[0] - 1)
3772
3773        x = torch.randn(3, 3, requires_grad=True)
3774        self.run_test(NarrowModel(), x)
3775
3776    @skipIfUnsupportedMinOpsetVersion(9)
3777    def test_index_fill(self):
3778        class IndexFillModel(torch.nn.Module):
3779            def forward(self, input):
3780                index = torch.tensor([2, 0])
3781                return input.index_fill(2, index, -1)
3782
3783        x = torch.randn(3, 4, 5, requires_grad=True)
3784        self.run_test(IndexFillModel(), x)
3785
3786    @skipIfUnsupportedMinOpsetVersion(9)
3787    def test_index_copy(self):
3788        class IndexCopyModel(torch.nn.Module):
3789            def __init__(self, dim):
3790                super().__init__()
3791                self.dim = dim
3792
3793            def forward(self, input):
3794                index = torch.tensor([2, 0])
3795                source = torch.ones(3, 2, 5)
3796                return input.index_copy(self.dim, index, source)
3797
3798        x = torch.randn(3, 4, 5, requires_grad=True)
3799        for dim in (1, -2):
3800            self.run_test(IndexCopyModel(dim), x)
3801
3802    def test_select(self):
3803        class Select(torch.nn.Module):
3804            def forward(self, x):
3805                return x[:, 1]
3806
3807        x = torch.randn(3, 4)
3808        self.run_test(Select(), x)
3809
3810    def test_select_negative_index(self):
3811        class Select(torch.nn.Module):
3812            def forward(self, x):
3813                return x[:, -1]
3814
3815        x = torch.randn(3, 4)
3816        self.run_test(Select(), x)
3817
3818    def test_index_select_constant_scaler_index(self):
3819        class IndexSelectScalerIndexModel(torch.nn.Module):
3820            def forward(self, x):
3821                index = 2
3822                return torch.index_select(x, 1, torch.tensor(index))
3823
3824        x = torch.randn(3, 4)
3825        self.run_test(IndexSelectScalerIndexModel(), x)
3826
3827    def test_index_select_scaler_index(self):
3828        class IndexSelectScalerIndexModel(torch.nn.Module):
3829            def __init__(self, index_base):
3830                super().__init__()
3831                self.index_base = torch.tensor(index_base)
3832
3833            def forward(self, x, index_offset):
3834                index = self.index_base + index_offset
3835                return torch.index_select(x, 1, index)
3836
3837        x = torch.randn(3, 4)
3838        offset = 2
3839        index_offset = torch.tensor(offset)
3840        base = 1
3841        self.run_test(IndexSelectScalerIndexModel(base), (x, index_offset))
3842
3843    def test_take(self):
3844        class TakeModel(torch.nn.Module):
3845            def forward(self, x, y):
3846                return torch.take(x, y)
3847
3848        x = torch.randn(6, 4, 3, 3)
3849        y = torch.tensor([4, 1, 7, 15, 63])
3850        self.run_test(TakeModel(), (x, y))
3851
3852    def test_topk(self):
3853        class MyModule(torch.nn.Module):
3854            def forward(self, x):
3855                return torch.topk(x, 3)
3856
3857        x = torch.arange(1.0, 6.0, requires_grad=True)
3858        self.run_test(MyModule(), x)
3859
3860    @skipIfUnsupportedMinOpsetVersion(10)
3861    def test_topk_int32_k(self):
3862        class Model(torch.nn.Module):
3863            def forward(self, x, k):
3864                return torch.topk(x, k)
3865
3866        x = torch.arange(1.0, 6.0)
3867        k = torch.tensor(3, dtype=torch.int32)
3868        self.run_test(Model(), (x, k))
3869
3870    @skipIfUnsupportedMinOpsetVersion(11)
3871    def test_topk_smallest_unsorted(self):
3872        class MyModule(torch.nn.Module):
3873            def forward(self, x, k):
                # When sorted=False, the order of elements in the output tensors
                # is not expected to match between PyTorch and ORT
3876                topk_unsorted = torch.topk(x, k, largest=False, sorted=False)
3877                topk_sorted = torch.topk(x, k, largest=False, sorted=True)
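                # Sorting the unsorted values makes the comparison with ORT deterministic.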
3878                return topk_sorted, torch.sort(topk_unsorted.values).values
3879
3880        x = torch.arange(1.0, 6.0, requires_grad=True)
3881        k = torch.tensor(3)
3882        self.run_test(MyModule(), (x, k))
3883
3884    @skipIfUnsupportedMinOpsetVersion(10)
3885    def test_topk_script(self):
3886        class MyModuleDynamic(torch.jit.ScriptModule):
3887            @torch.jit.script_method
3888            def forward(self, x, k):
3889                return torch.topk(x, k)
3890
3891        x = torch.arange(1.0, 6.0, requires_grad=True)
3892        k = torch.tensor(3)
3893        self.run_test(MyModuleDynamic(), (x, k))
3894
    @skipScriptTest()  # Python builtin apply of FunctionMeta object is currently not supported in TorchScript.
3896    @skipIfUnsupportedMinOpsetVersion(11)  # Clip op min is an input since opset 11.
3897    def test_auto_grad(self):
3898        class MyClip(torch.autograd.Function):
3899            @staticmethod
3900            def forward(ctx, input, scalar):
3901                ctx.save_for_backward(input)
3902                return input.clamp(min=scalar)
3903
3904        class MyRelu(torch.autograd.Function):
3905            @staticmethod
3906            def forward(ctx, input):
3907                ctx.save_for_backward(input)
3908                return input.clamp(min=0)
3909
3910        def symbolic_python_op(g, *args, **kwargs):
3911            name = kwargs["name"]
3912            if name == "MyClip":
3913                return g.op("Clip", args[0], args[1])
3914            elif name == "MyRelu":
3915                return g.op("Relu", args[0])
3916            else:
3917                # TODO(justinchuby): Remove reference to internal names in symbolic_helper
3918                return torch.onnx.symbolic_helper._unimplemented(
3919                    "prim::PythonOp", "unknown node kind: " + name
3920                )
3921
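        # prim::PythonOp nodes carry the originating autograd.Function's class name in
        # the "name" kwarg; that is how symbolic_python_op dispatches to Clip vs. Relu.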
3922        torch.onnx.register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)
3923        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "prim::PythonOp", 1)
3924
3925        class MyClipModule(torch.nn.Module):
3926            def forward(self, x, min):
3927                return MyClip.apply(x, min)
3928
3929        x = torch.randn(3, 3)
3930        min = torch.tensor([0.0])
3931        self.run_test(MyClipModule(), (x, min))
3932
3933        class MyReluModule(torch.nn.Module):
3934            def forward(self, x):
3935                return MyRelu.apply(x)
3936
3937        x = torch.randn(3, 3)
3938        self.run_test(MyReluModule(), x)
3939
3940    def test_clip_int(self):
3941        class MyClipInt(torch.nn.Module):
3942            def forward(self, x):
3943                return torch.clamp(x, 0, 1)
3944
3945        self.run_test(MyClipInt(), torch.randn(3, 3).to(torch.int64))
3946
3947    def test_relu_int(self):
3948        self.run_test(torch.nn.ReLU(), torch.randn(3, 3).to(torch.int32))
3949
3950    def test_pad_int(self):
3951        class MyPadInt(torch.nn.Module):
3952            def forward(self, x):
3953                return torch.nn.functional.pad(x, (1, 1))
3954
3955        self.run_test(MyPadInt(), torch.randn(3, 3).to(torch.int32))
3956
3957    def test_min_int(self):
3958        class MyMinInt(torch.nn.Module):
3959            def forward(self, x):
3960                return torch.min(x, x + 1)
3961
3962        self.run_test(MyMinInt(), torch.randn(3, 3).to(torch.int32))
3963
3964    def test_max_int(self):
        class MyMaxInt(torch.nn.Module):
3966            def forward(self, x):
3967                return torch.max(x, x + 1)
3968
        self.run_test(MyMaxInt(), torch.randn(3, 3).to(torch.int32))
3970
3971    @skipIfUnsupportedOpsetVersion([7])
3972    def test_normalize(self):
3973        class Model(torch.nn.Module):
3974            def forward(self, x):
3975                return torch.nn.functional.normalize(x)
3976
3977        x = torch.randn(3, 3)
3978        self.run_test(Model(), x)
3979
3980    def test_norm_with_dtype(self):
3981        class Model(torch.nn.Module):
3982            def forward(self, x):
                # TODO(bowbao): There is a slight gap in today's test infrastructure
                # for directly testing aten ops. The OpInfo for `torch.norm` in
                # `common_methods_invocations.py` does not decompose to the aten op below.
3986                return torch.ops.aten.norm(
3987                    x, p=2, dim=[1], keepdim=True, dtype=torch.float64
3988                )
3989
3990        x = torch.randn(3, 3)
3991        self.run_test(Model(), x)
3992
3993    def test_layer_norm(self):
        # As layer_norm works on the last D dimensions, keep this test case
        # at least three-dimensional to prevent axis=2 from mapping to the
        # same axis as axis=-2.
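        # (For example, with normalized_shape=[10, 10, 10] on a 5-D input, the
        # normalized axes are 2, 3 and 4, so axis=2 and axis=-2 (i.e. 3) stay distinct.)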
3997        for elementwise_affine in (True, False):
3998            for bias in (True, False):
3999                model = torch.nn.LayerNorm(
4000                    [10, 10, 10], elementwise_affine=elementwise_affine, bias=bias
4001                )
4002                x = torch.randn(20, 5, 10, 10, 10)
4003                self.run_test(model, x)
4004
4005    def test_batchnorm1d(self):
4006        x = torch.randn(10, 10)
4007        model = torch.nn.BatchNorm1d(10, affine=True)
4008        self.run_test(model, x)
4009
4010        x = torch.randn(10, 10, 128)
4011        self.run_test(model, x)
4012
4013    def test_batchnorm1d_noaffine(self):
4014        x = torch.randn(10, 10)
4015        model = torch.nn.BatchNorm1d(10, affine=False)
4016        self.run_test(model, x)
4017
4018        x = torch.randn(10, 10, 128)
4019        self.run_test(model, x)
4020
4021    def test_batchnorm1d_norunningstats(self):
4022        x = torch.randn(10, 10)
4023        model = torch.nn.BatchNorm1d(10, track_running_stats=False)
4024        self.run_test(model, x)
4025
4026        x = torch.randn(10, 10, 128)
4027        self.run_test(model, x)
4028
4029    def test_batchnorm2d(self):
4030        x = torch.randn(10, 3, 128, 128)
4031        model = torch.nn.BatchNorm2d(3, affine=True)
4032        self.run_test(model, x)
4033
4034    def test_batchnorm2d_noaffine(self):
4035        x = torch.randn(10, 3, 128, 128)
4036        model = torch.nn.BatchNorm2d(3, affine=False)
4037        self.run_test(model, x)
4038
4039    def test_batchnorm2d_norunningstats(self):
4040        x = torch.randn(10, 3, 128, 128)
4041        model = torch.nn.BatchNorm2d(3, track_running_stats=False)
4042        self.run_test(model, x)
4043
4044    def test_batchnorm3d(self):
4045        x = torch.randn(10, 3, 64, 64, 64)
4046        model = torch.nn.BatchNorm3d(3, affine=True)
4047        self.run_test(model, x)
4048
4049    def test_batchnorm3d_noaffine(self):
4050        x = torch.randn(10, 3, 64, 64, 64)
4051        model = torch.nn.BatchNorm3d(3, affine=False)
4052        self.run_test(model, x)
4053
4054    @skipIfUnsupportedMinOpsetVersion(
4055        9
4056    )  # Because ConstantOfShape op is not supported for opset < 9
4057    def test_instancenorm1d_runningstats(self):
4058        x = torch.randn(10, 5, 128)
4059        model = torch.nn.InstanceNorm1d(5, affine=True, track_running_stats=True)
4060        self.run_test(model, x)
4061
4062        model = torch.nn.InstanceNorm1d(5, affine=False, track_running_stats=True)
4063        self.run_test(model, x)
4064
4065    def test_instancenorm1d_norunningstats(self):
4066        x = torch.randn(10, 5, 128)
4067        model = torch.nn.InstanceNorm1d(5, affine=True, track_running_stats=False)
4068        self.run_test(model, x)
4069
4070        model = torch.nn.InstanceNorm1d(5, affine=False, track_running_stats=False)
4071        self.run_test(model, x)
4072
4073    @skipIfUnsupportedMinOpsetVersion(
4074        9
4075    )  # Because ConstantOfShape op is not supported for opset < 9
4076    def test_instancenorm2d_runningstats(self):
4077        x = torch.randn(10, 3, 128, 128)
4078        model = torch.nn.InstanceNorm2d(3, affine=True, track_running_stats=True)
4079        self.run_test(model, x)
4080
4081        model = torch.nn.InstanceNorm2d(3, affine=False, track_running_stats=True)
4082        self.run_test(model, x)
4083
4084    def test_instancenorm2d_norunningstats(self):
4085        x = torch.randn(10, 3, 128, 128)
4086        model = torch.nn.InstanceNorm2d(3, affine=True, track_running_stats=False)
4087        self.run_test(model, x)
4088
4089        model = torch.nn.InstanceNorm2d(3, affine=False, track_running_stats=False)
4090        self.run_test(model, x)
4091
4092    @skipIfUnsupportedMinOpsetVersion(
4093        9
4094    )  # Because ConstantOfShape op is not supported for opset < 9
4095    def test_instancenorm3d_runningstats(self):
4096        x = torch.randn(10, 3, 64, 64, 64)
4097        model = torch.nn.InstanceNorm3d(3, affine=True, track_running_stats=True)
4098        self.run_test(model, x)
4099
4100        model = torch.nn.InstanceNorm3d(3, affine=False, track_running_stats=True)
4101        self.run_test(model, x)
4102
4103    def test_instancenorm3d_norunningstats(self):
4104        x = torch.randn(10, 3, 64, 64, 64)
4105        model = torch.nn.InstanceNorm3d(3, affine=True, track_running_stats=False)
4106        self.run_test(model, x)
4107
4108        model = torch.nn.InstanceNorm3d(3, affine=False, track_running_stats=False)
4109        self.run_test(model, x)
4110
4111    @skipIfUnsupportedMinOpsetVersion(9)
4112    def test_scatter_with_scalar(self):
4113        class ScatterModel(torch.nn.Module):
4114            def forward(self, input, indices):
4115                values = 1.0
4116                return input.scatter(1, indices, values)
4117
4118        input = torch.tensor(
4119            [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float64
4120        )
4121        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
4122        self.run_test(ScatterModel(), input_args=(input, indices))
4123
4124    @skipIfUnsupportedMinOpsetVersion(9)
4125    def test_scatter_with_scalar_different_types(self):
        # Tests the case where the scalar src (update values) type differs from
        # the self type. This happens only with a scalar src; PyTorch does not
        # allow the mismatch when src is a tensor.
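        # For example, scattering a float64 src tensor into a float32 input raises a
        # RuntimeError in PyTorch, whereas the scalar 1.0 below is accepted.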
4129        class ScatterModel(torch.nn.Module):
4130            def forward(self, input, indices):
4131                values = 1.0
4132                return input.scatter(1, indices, values)
4133
4134        input = torch.tensor(
4135            [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float32
4136        )
4137        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
4138        self.run_test(ScatterModel(), input_args=(input, indices))
4139
4140    @skipIfUnsupportedMinOpsetVersion(9)
4141    def test_scatter(self):
4142        class ScatterModel(torch.nn.Module):
4143            def forward(self, input, indices, values):
4144                return input.scatter(1, indices, values)
4145
4146        input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
4147        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
4148        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
4149        self.run_test(ScatterModel(), input_args=(input, indices, values))
4150
4151        input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
4152        indices = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64)
4153        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
4154        self.run_test(ScatterModel(), (input, indices, values))
4155
4156        input = torch.zeros(3, 4, 5, 6)
4157        indices = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64)
4158        indices = indices.view(3, 2, 1, 1).expand(3, 2, 5, 6)
4159        values = torch.arange(3 * 2 * 5 * 6, dtype=torch.float32).view(3, 2, 5, 6)
4160        self.run_test(ScatterModel(), (input, indices, values))
4161
4162        input = torch.zeros(3, 4, 2)
4163        indices = torch.tensor([[[1, 0], [0, 2]], [[1, 1], [0, 1]], [[2, 1], [2, 2]]])
4164        values = torch.arange(3 * 2 * 2, dtype=torch.float32).view(3, 2, 2)
4165        self.run_test(ScatterModel(), (input, indices, values))
4166
4167    @skipIfUnsupportedMinOpsetVersion(9)
4168    def test_scatter_add(self):
4169        class ScatterModel(torch.nn.Module):
4170            def forward(self, input, indices, values):
4171                return input.scatter_add(1, indices, values)
4172
4173        input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
4174        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
4175        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
4176        self.run_test(ScatterModel(), input_args=(input, indices, values))
4177
4178        @torch.jit.script
4179        def scatter_sum(src: Tensor, index: Tensor):
4180            size = src.size()
4181            out = torch.zeros(size, dtype=src.dtype)
4182            return out.scatter_add_(1, index, src)
4183
4184        class ScatterModel(torch.nn.Module):
4185            def forward(self, src, index):
4186                return scatter_sum(src, index)
4187
4188        src = torch.rand(3, 2)
4189        index = torch.tensor([[0, 1], [0, 1], [0, 1]], dtype=torch.int64)
4190        self.run_test(ScatterModel(), (src, index))
4191
4192    @skipIfUnsupportedMinOpsetVersion(16)
4193    def test_scatter_add_index_not_unique(self):
4194        class ScatterModel(torch.nn.Module):
4195            def forward(self, input, indices, values):
4196                return input.scatter_add(1, indices, values)
4197
4198        input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
4199        indices = torch.tensor([[0, 0], [1, 1], [2, 2]], dtype=torch.int64)
4200        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
4201        self.run_test(ScatterModel(), input_args=(input, indices, values))
4202
4203        @torch.jit.script
4204        def scatter_sum(src: Tensor, index: Tensor):
4205            size = src.size()
4206            out = torch.zeros(size, dtype=src.dtype)
4207            return out.scatter_add_(1, index, src)
4208
4209        class ScatterModel(torch.nn.Module):
4210            def forward(self, src, index):
4211                return scatter_sum(src, index)
4212
4213        src = torch.rand(3, 2)
4214        index = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64)
4215        self.run_test(ScatterModel(), (src, index))
4216
4217    @skipIfUnsupportedMinOpsetVersion(16)
4218    def test_scatter_add_different_size_index_src(self):
4219        class ScatterModel(torch.nn.Module):
4220            def forward(self, input, indices, src):
4221                return input.scatter_add(0, indices, src)
4222
4223        src = torch.ones((2, 5))
4224        input = torch.zeros(3, 5, dtype=src.dtype)
4225        indices = torch.tensor([[0, 1, 2, 0, 0]])
4226        self.run_test(ScatterModel(), input_args=(input, indices, src))
4227
4228    @common_utils.parametrize(
4229        "src, indices",
4230        [
4231            common_utils.subtest(
4232                [torch.ones((1, 5)), torch.tensor([[0, 1, 2, 0, 0]])],
4233                name="src_indices_dynamic_combination1",
4234            ),
4235            common_utils.subtest(
4236                [torch.ones((2, 5)), torch.tensor([[0, 1, 2, 0, 0], [1, 0, 2, 1, 2]])],
4237                name="src_indices_dynamic_combination2",
4238            ),
4239            common_utils.subtest(
4240                [torch.ones((3, 5)), torch.tensor([[0, 1, 2, 0, 0], [1, 0, 2, 1, 2]])],
4241                name="src_indices_dynamic_combination3",
4242            ),
4243            common_utils.subtest(
4244                [torch.ones((3, 5)), torch.tensor([[0, 1, 2, 0], [1, 0, 2, 1]])],
4245                name="src_indices_dynamic_combination4",
4246            ),
4247        ],
4248    )
4249    @skipIfUnsupportedMinOpsetVersion(16)
4250    def test_scatter_add_dynamic_index(self, src, indices):
4251        class ScatterModel(torch.nn.Module):
4252            def forward(self, input, indices, src):
4253                return input.scatter_add(0, indices, src)
4254
4255        input = torch.zeros(3, 5, dtype=src.dtype)
4256        self.run_test(
4257            ScatterModel(),
4258            input_args=(input, indices, src),
4259            input_names=["input", "indices", "src"],
4260            dynamic_axes={"indices": {0: "a", 1: "b"}, "src": {0: "c", 1: "d"}},
4261        )
4262
4263    @skipIfUnsupportedMinOpsetVersion(16)
4264    def test_scatter_reduce(self):
4265        class Model(torch.nn.Module):
4266            def __init__(self) -> None:
4267                super().__init__()
4268
4269            def forward(self, x, index, input):
4270                y_max = input.scatter_reduce(0, index, x, reduce="amax")
4271                y_sum = input.scatter_reduce(0, index, x, reduce="sum")
4272                y_min = input.scatter_reduce(0, index, x, reduce="amin")
4273                y_mul = input.scatter_reduce(0, index, x, reduce="prod")
4274                return y_max, y_sum, y_min, y_mul
4275
4276        model = Model()
4277        model.eval()
4278
4279        src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
4280        index = torch.tensor([0, 1, 0, 1, 2, 1])
4281        input = torch.tensor([1.0, 2.0, 3.0, 8.0])
4282
4283        self.run_test(model, (src, index, input))
4284
4285    @skipIfUnsupportedMinOpsetVersion(16)
4286    def test_scatter_reduce_self_rank_zero(self):
4287        class Model(torch.nn.Module):
4288            def __init__(self) -> None:
4289                super().__init__()
4290
4291            def forward(self, x, index, input):
4292                y_max = input.scatter_reduce(0, index, x, reduce="amax")
4293                y_sum = input.scatter_reduce(0, index, x, reduce="sum")
4294                y_min = input.scatter_reduce(0, index, x, reduce="amin")
4295                y_mul = input.scatter_reduce(0, index, x, reduce="prod")
4296                return y_max, y_sum, y_min, y_mul
4297
4298        model = Model()
4299        model.eval()
4300
4301        empty_tensor = torch.tensor([])
4302        empty_idx = torch.tensor([], dtype=torch.int64)
4303
4304        self.run_test(model, (empty_tensor, empty_idx, empty_tensor))
4305
4306    @skipIfUnsupportedMinOpsetVersion(9)
4307    def test_bucketize(self):
4308        class BucketModel(torch.nn.Module):
4309            def forward(self, input, boundaries):
4310                return torch.bucketize(input, boundaries), torch.bucketize(
4311                    input, boundaries, right=True
4312                )
4313
4314        input = torch.tensor([[2, 5, 10], [6, 8, 3]])
4315        boundaries = torch.tensor([1, 5, 7, 8, 10])
4316        self.run_test(BucketModel(), (input, boundaries))
4317
4318    @skipIfUnsupportedMinOpsetVersion(9)
4319    def test_one_hot(self):
4320        class OneHot(torch.nn.Module):
4321            def __init__(self, num_classes):
4322                super().__init__()
4323                self.num_classes = num_classes
4324
4325            def forward(self, x):
4326                return torch.nn.functional.one_hot(x, self.num_classes)
4327
4328        x = torch.arange(10)
        self.run_test(OneHot(15), (x,))
4330
4331        class OneHot(torch.nn.Module):
4332            def forward(self, x, num_classes):
4333                num_classes = num_classes.to(torch.int32)
4334                return torch.nn.functional.one_hot(x, num_classes[0])
4335
4336        x = torch.arange(10)
4337        num_classes = 15 * torch.ones(1)
4338        self.run_test(OneHot(), (x, num_classes))
4339
4340    @skipIfUnsupportedMinOpsetVersion(9)
4341    def test_gather(self):
4342        class GatherModel(torch.nn.Module):
4343            def forward(self, input, indices):
4344                return input.gather(1, indices)
4345
4346        input = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
4347        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
4348        self.run_test(GatherModel(), input_args=(input, indices))
4349
4350    @skipScriptTest()  # Scripting error: Cannot instantiate nn module
4351    def test_gather_constant_fold(self):
4352        class GatherModule(torch.nn.Module):
4353            def __init__(self) -> None:
4354                super().__init__()
4355                self.weight = torch.nn.Buffer(torch.ones(5))
4356                # torch.nn.Embedding is converted to ONNX::Gather.
                # Constant folding will be triggered for constant inputs.
4358                # This pattern is common for constant mask inputs in transformer models.
4359                self.embed = torch.nn.Embedding(8, 3)
4360
4361            def forward(self, x):
4362                # shape is of rank 0
4363                shape = self.weight.shape[0]
4364                m = 5 - shape
4365                y = torch.ones(1, 4, dtype=torch.long)
4366                return x.clamp(min=m), self.embed(y)
4367
4368        x = torch.randn(1)
4369        self.run_test(GatherModule(), (x,))
4370
4371        class GatherModule(torch.nn.Module):
4372            def __init__(self) -> None:
4373                super().__init__()
4374                self.weight = torch.nn.Buffer(torch.ones(2))
4375
4376            def forward(self, x):
4377                # shape is of rank 0
4378                shape = self.weight.shape[0]
4379                pad = [1, shape, shape, shape]
4380                zero_pad = torch.nn.ZeroPad2d(pad)
4381                return zero_pad(x)
4382
4383        x = torch.randn(1, 3, 2)
4384        self.run_test(GatherModule(), (x,))
4385
4386        class GatherModule(torch.nn.Module):
4387            def __init__(self) -> None:
4388                super().__init__()
4389                self.rb = torch.nn.Buffer(torch.randn(1, 1, 3, 1, 1))
4390
4391            def forward(self, x):
4392                x += self.rb[0]
4393                return x
4394
4395        x = torch.randn(1, 3, 224, 224)
4396        self.run_test(
4397            GatherModule(),
4398            (x,),
4399            dynamic_axes={
4400                "input": {0: "batch", 2: "height", 3: "width"},
4401                "output": {0: "batch", 1: "class", 2: "height", 3: "width"},
4402            },
4403            input_names=["input"],
4404            output_names=["output"],
4405        )
4406
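    # Illustrative sketch (not part of the original suite): the effect of
    # do_constant_folding can be observed by exporting the same module with
    # the flag on and off and comparing node counts. The counts are printed
    # rather than asserted, since they depend on the exporter version.
    def _example_constant_folding_node_count(self):
        class EmbedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.embed = torch.nn.Embedding(8, 3)

            def forward(self, x):
                y = torch.ones(1, 4, dtype=torch.long)
                return x + 1, self.embed(y)

        for fold in (True, False):
            f = io.BytesIO()
            torch.onnx.export(
                EmbedModule(), (torch.randn(1),), f, do_constant_folding=fold
            )
            f.seek(0)
            print(fold, len(onnx.load(f).graph.node))
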
4407    @skipIfUnsupportedOpsetVersion([13])
4408    @skipIfUnsupportedMinOpsetVersion(9)
4409    def test_expand(self):
4410        class ExpandModel(torch.nn.Module):
4411            def forward(self, input):
4412                return input.expand(2, 3, -1)
4413
4414        input = torch.randn(2, 1, 4)
4415        self.run_test(ExpandModel(), input_args=(input,))
4416
4417        class ExpandInferDimModel(torch.nn.Module):
4418            def forward(self, input):
4419                return input.expand(-1, input.size(0))
4420
4421        input = torch.randn(3, 1)
4422        self.run_test(ExpandInferDimModel(), input_args=(input,))
4423
4424        class ExpandTensorSizeModel(torch.nn.Module):
4425            def forward(self, input, size):
4426                return input.expand(size)
4427
4428        input = torch.randn(
4429            3,
4430        )
4431        size = torch.tensor(-1)
4432        self.run_test(ExpandTensorSizeModel(), input_args=(input, size))
4433
4434    @skipIfUnsupportedMinOpsetVersion(11)  # index_put is supported in opsets >= 11
4435    def test_dynamic_expand_as(self):
4436        class Model(torch.nn.Module):
4437            def forward(self, x):
4438                x[:, x.size(0) :] = 0
4439                return x
4440
4441        x = torch.ones(2, 5)
4442        x2 = torch.randn(3, 4)
4443        self.run_test(
4444            Model(),
4445            (x,),
4446            input_names=["x"],
4447            dynamic_axes={"x": [0, 1]},
4448            additional_test_inputs=[x2],
4449        )
4450
4451        class Model(torch.nn.Module):
4452            def forward(self, x):
4453                x[:, x.size(0) :] = torch.tensor([1, 2, 3])
4454                return x
4455
4456        x = torch.ones(2, 5, 3)
4457        x2 = torch.randn(3, 4, 3)
4458        self.run_test(
4459            Model(),
4460            (x,),
4461            input_names=["x"],
4462            dynamic_axes={"x": [0, 1, 2]},
4463            additional_test_inputs=[x2],
4464        )
4465
4466        class Model(torch.nn.Module):
4467            def forward(self, x):
4468                aa = torch.tensor([[0], [1], [2]])
4469                return aa.expand_as(x)
4470
4471        x = torch.ones(3, 2)
4472        x2 = torch.randn(3, 5)
4473        self.run_test(
4474            Model(),
4475            (x,),
4476            input_names=["x"],
4477            dynamic_axes={"x": [0, 1]},
4478            additional_test_inputs=[x2],
4479        )
4480
4481    def test_multinomial(self):
4482        class Multinomial(torch.nn.Module):
4483            def forward(self, weight):
4484                return torch.multinomial(weight, 3, replacement=True)
4485
4486        class MultinomialNoReplacement(torch.nn.Module):
4487            def forward(self, weight):
4488                return torch.multinomial(weight, 1)
4489
4490        weight = torch.tensor([[0, 10, 0, 0], [0, 0, 100, 0]], dtype=torch.float)
4491        self.run_test(Multinomial(), (weight,))
4492        self.run_test(MultinomialNoReplacement(), (weight,))
4493
4494    def _test_reduced_ops(self, op):
4495        class ReducedOpModule(torch.nn.Module):
4496            def forward(self, input):
4497                return op(input, dim=-1)
4498
4499        if op != torch.mean:  # torch.mean only supports float types
4500            x = torch.randint(10, (4, 4), dtype=torch.uint8)
4501            self.run_test(ReducedOpModule(), x)
4502
4503            x = torch.randint(10, (4, 4), dtype=torch.int8)
4504            self.run_test(ReducedOpModule(), x)
4505
4506            x = torch.randint(10, (4, 4), dtype=torch.int16)
4507            self.run_test(ReducedOpModule(), x)
4508
4509            x = torch.randint(10, (4, 4), dtype=torch.int32)
4510            self.run_test(ReducedOpModule(), x)
4511
4512            x = torch.randint(10, (4, 4), dtype=torch.int64)
4513            self.run_test(ReducedOpModule(), x)
4514
4515        # torch.mean only supports float types
4516        # ORT does not support ReduceProd for double
4517        if op != torch.prod and op != torch.mean:
4518            x = torch.randn(4, 5, dtype=torch.double)
4519            self.run_test(ReducedOpModule(), x)
4520
4521        if op != torch.prod:  # torch.prod not implemented for Half
4522            x = torch.randn(4, 4, dtype=torch.half)
4523            self.run_test(ReducedOpModule(), x)
4524
4525        x = torch.randn(4, 5, dtype=torch.float)
4526        self.run_test(ReducedOpModule(), x)
4527
4528    def test_reduced_sum(self):
4529        return self._test_reduced_ops(op=torch.sum)
4530
4531    def test_reduced_mean(self):
4532        return self._test_reduced_ops(op=torch.mean)
4533
4534    def test_reduced_prod(self):
4535        return self._test_reduced_ops(op=torch.prod)
4536
4537    def test_reduced_sum_dtypes(self):
4538        class NoDimModel(torch.nn.Module):
4539            def forward(self, input):
4540                return input.sum(dtype=torch.float)
4541
4542        class DimModel(torch.nn.Module):
4543            def forward(self, input):
4544                return input.sum(dim=-1, dtype=torch.float)
4545
4546        input = torch.randn((4, 4), dtype=torch.half)
4547        self.run_test(NoDimModel(), input)
4548        self.run_test(DimModel(), input)
4549
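    # Illustrative sketch (not part of the original suite): sum(dtype=...)
    # casts before reducing, so a half input with dtype=torch.float yields a
    # float32 result; the exported graph has to preserve this promotion.
    def _example_sum_dtype_promotion(self):
        input = torch.randn((4, 4), dtype=torch.half)
        out = input.sum(dim=-1, dtype=torch.float)
        assert out.dtype == torch.float32
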
4550    def test_reduced_min_max(self):
4551        class ReducedMinMaxModule(torch.nn.Module):
4552            def forward(self, input):
4553                return torch.min(input, dim=-1)[0], torch.max(input, dim=0)[0]
4554
4555        x = torch.randint(10, (4, 4), dtype=torch.int32)
4556        self.run_test(ReducedMinMaxModule(), x)
4557
4558        x = torch.randint(10, (4, 4), dtype=torch.int64)
4559        self.run_test(ReducedMinMaxModule(), x)
4560
4561        x = torch.randn(4, 5, dtype=torch.float)
4562        self.run_test(ReducedMinMaxModule(), x)
4563
4564    def test_reduce_log_sum_exp(self):
4565        class ReduceLogSumExpModel(torch.nn.Module):
4566            def forward(self, input):
4567                a = torch.logsumexp(input, dim=0)
4568                b = torch.logsumexp(input, dim=(0, 1))
4569                return a + b
4570
4571        x = torch.randn(4, 4, requires_grad=True)
4572        self.run_test(ReduceLogSumExpModel(), x)
4573
4574    def test_softmax(self):
4575        for i in range(-4, 3):
4576            model = torch.nn.Softmax(dim=i)
4577            input = torch.randn(3, 4, 5, 6)
4578            self.run_test(model, input)
4579
4580            class SoftmaxUnknownRank(torch.nn.Module):
4581                def __init__(self, i):
4582                    super().__init__()
4583                    self.softmax = torch.nn.Softmax(dim=i)
4584
4585                def forward(self, x):
4586                    return self.softmax(x.reshape(3, 4, 5, 6))
4587
4588            model = torch.jit.script(SoftmaxUnknownRank(i))
4589            self.run_test(model, input)
4590
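    # Illustrative sketch (not part of the original suite): the reference
    # semantics the exported Softmax must match, regardless of how the opset
    # lowers it (opset < 13 coerces the input to 2D; opset >= 13 reduces
    # along the given axis), which is why negative dims are exercised above.
    def _example_softmax_reference(self):
        x = torch.randn(3, 4, 5, 6)
        expected = torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True)
        torch.testing.assert_close(torch.nn.Softmax(dim=-1)(x), expected)
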
4591    def test_softmax_large_values(self):
4592        input = torch.tensor(
4593            [[-1e12, -1e12, -1e12], [1e12, 0.0, -5.0], [3.0, 4.0, 5.0]]
4594        )
4595        for i in range(-2, 1):
4596            model = torch.nn.Softmax(dim=i)
4597            self.run_test(model, input)
4598
4599            class SoftmaxUnknownRank(torch.nn.Module):
4600                def __init__(self, i):
4601                    super().__init__()
4602                    self.softmax = torch.nn.Softmax(dim=i)
4603
4604                def forward(self, x):
4605                    return self.softmax(x.reshape(3, 3))
4606
4607            model = torch.jit.script(SoftmaxUnknownRank(i))
4608            self.run_test(model, input)
4609
4610    def test_logsoftmax(self):
4611        for i in range(2, 7):
4612            model = torch.nn.LogSoftmax(dim=i - 1)
4613            dims = [2] * (i - 2) + [3, 4]
4614            input = torch.ones(*dims, requires_grad=True)
4615            self.run_test(model, input)
4616
4617    def test_logsoftmax_dim(self):
4618        for i in range(-4, 3):
4619            model = torch.nn.LogSoftmax(dim=i)
4620            input = torch.randn(3, 4, 5, 6)
4621            self.run_test(model, input)
4622
4623    def test_logsoftmax_dtype(self):
4624        class Model(torch.nn.Module):
4625            def forward(self, x):
4626                return torch.nn.functional.log_softmax(x, dim=1, dtype=torch.float64)
4627
4628        x = torch.randn(3, 4, 5, requires_grad=True)
4629        self.run_test(Model(), x)
4630
4631    def test_softplus(self):
4632        class BetaOneModel(torch.nn.Module):
4633            def forward(self, x):
4634                return torch.nn.functional.softplus(x)
4635
4636        x = torch.randn(3, 4, 5, requires_grad=True)
4637        self.run_test(BetaOneModel(), x)
4638
4639        class BetaModel(torch.nn.Module):
4640            def forward(self, x):
4641                return torch.nn.functional.softplus(x, beta=2)
4642
4643        x = torch.randn(3, 4, 5, requires_grad=True)
4644        self.run_test(BetaModel(), x)
4645
4646        class BetaFloatModel(torch.nn.Module):
4647            def forward(self, x):
4648                return torch.nn.functional.softplus(x, beta=1.7)
4649
4650        x = torch.randn(3, 4, 5, requires_grad=True)
4651        self.run_test(BetaFloatModel(), x)
4652
4653    @skipIfUnsupportedMinOpsetVersion(9)
4654    def test_lstm_no_hidden(self):
4655        class LSTMModel(torch.nn.Module):
4656            def __init__(self) -> None:
4657                super().__init__()
4658                self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16)
4659
4660            def forward(self, x):
4661                return self.rnn(x)
4662
4663        input = torch.randn((10, 16, 16))
4664        self.run_test(LSTMModel(), (input,))
4665
4666    @skipIfUnsupportedMinOpsetVersion(9)
4667    def test_lstm_proj_no_hidden(self):
4668        class LSTMModel(torch.nn.Module):
4669            def __init__(self) -> None:
4670                super().__init__()
4671                self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16, proj_size=8)
4672
4673            def forward(self, x):
4674                return self.rnn(x)
4675
4676        input = torch.randn((10, 16, 16))
4677        with self.assertRaises(RuntimeError):
4678            self.run_test(LSTMModel(), (input,))
4679
4680    @skipIfUnsupportedMinOpsetVersion(9)
4681    def test_lstm(self):
4682        class LSTMModel(torch.nn.Module):
4683            def __init__(self) -> None:
4684                super().__init__()
4685                self.rnn = torch.nn.LSTM(
4686                    RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
4687                )
4688
4689            def forward(self, x, h0, c0):
4690                return self.rnn(x, (h0, c0))
4691
4692        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
4693        h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
4694        c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
4695        self.run_test(LSTMModel(), (input, h0, c0))
4696
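    # Illustrative sketch (not part of the original suite): the output shapes
    # the exported LSTM has to reproduce -- (seq_len, batch, hidden) for the
    # sequence output and (num_layers * num_directions, batch, hidden) for
    # both h_n and c_n.
    def _example_lstm_output_shapes(self):
        rnn = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1)
        x = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        output, (h_n, c_n) = rnn(x)
        assert output.shape == (RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_HIDDEN_SIZE)
        assert h_n.shape == (1, BATCH_SIZE, RNN_HIDDEN_SIZE)
        assert c_n.shape == (1, BATCH_SIZE, RNN_HIDDEN_SIZE)
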
4697    @skipIfUnsupportedMinOpsetVersion(9)
4698    def test_lstm_cell(self):
4699        class LSTMCellModel(torch.nn.Module):
4700            def __init__(self, bias):
4701                super().__init__()
4702                self.lstm_cell = torch.nn.LSTMCell(
4703                    RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, bias=bias
4704                )
4705
4706            def forward(self, x, h0, c0):
4707                return self.lstm_cell(x, (h0, c0))
4708
4709        input = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE)
4710        h0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE)
4711        c0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE)
4712        for bias in [True, False]:
4713            self.run_test(LSTMCellModel(bias), (input, h0, c0))
4714
4715    @skipIfUnsupportedMinOpsetVersion(9)
4716    def test_lstm_default_init_state(self):
4717        class LSTMModel(torch.nn.Module):
4718            def __init__(self) -> None:
4719                super().__init__()
4720                self.rnn = torch.nn.LSTM(
4721                    RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
4722                )
4723
4724            def forward(self, x):
4725                return self.rnn(x)
4726
4727        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
4728        self.run_test(LSTMModel(), input)
4729
4730    @skipIfUnsupportedMinOpsetVersion(9)
4731    def test_lstm_fixed_batch_size(self):
4732        class LSTMModel(torch.nn.Module):
4733            def __init__(self) -> None:
4734                super().__init__()
4735                self.lstm = torch.nn.LSTM(
4736                    RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
4737                )
4738                self.RNN_HIDDEN_SIZE = RNN_HIDDEN_SIZE
4739
4740            def forward(self, input):
4741                batch_size = input.size()[1]
4742                h0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE])
4743                c0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE])
4744                return self.lstm(input, (h0, c0))
4745
4746        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
4747        # verify with a different input of the same batch size
4748        input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
4749        self.run_test(
4750            LSTMModel(), input, fixed_batch_size=True, additional_test_inputs=[input2]
4751        )
4752
4753    @skipIfUnsupportedMinOpsetVersion(9)
4754    def test_lstm_post_fix_init_state(self):
4755        class LSTMModel(torch.nn.Module):
4756            def __init__(self) -> None:
4757                super().__init__()
4758                self.lstm = torch.nn.LSTM(
4759                    RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
4760                )
4761                self.RNN_HIDDEN_SIZE = RNN_HIDDEN_SIZE
4762
4763            def forward(self, input):
4764                batch_size = input.size()[1]
4765                h0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE])
4766                c0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE])
4767                return self.lstm(input, (h0, c0))
4768
4769        model = LSTMModel()
4770        input = torch.randn(RNN_SEQUENCE_LENGTH, 1, RNN_INPUT_SIZE)
4771        # verify with a different input of a different batch size
4772        input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
4773        self.run_test(
4774            model,
4775            input,
4776            input_names=["input.1"],
4777            dynamic_axes={"input.1": {0: "seq", 1: "batch"}},
4778            additional_test_inputs=[input2],
4779        )
4780
4781    def test_lstm_constant_folding(self):
4782        class LstmNet(torch.nn.Module):
4783            def __init__(self, input_size, hidden_size, num_layers, bidirectional):
4784                super().__init__()
4785                self.lstm = torch.nn.LSTM(
4786                    input_size, hidden_size, num_layers, bidirectional=bidirectional
4787                )
4788
4789            def forward(self, input, initial_state: Tuple[Tensor, Tensor]):
4790                return self.lstm(input, initial_state)
4791
4792        def get_LstmNet_model_and_inputs(
4793            input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional
4794        ):
4795            num_directions = 2 if bidirectional else 1
4796            model = LstmNet(input_size, hidden_size, num_layers, bidirectional)
4797            input = torch.randn(seq_len, batch_size, input_size)
4798            h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
4799            c0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
4800            return model, (input, (h0, c0))
4801
4802        batch_size1 = 3
4803        model1, input1 = get_LstmNet_model_and_inputs(7, 3, 2, batch_size1, 5, True)
4804        self.run_test(model1, input1, do_constant_folding=True)
4805
4806        batch_size2 = 4
4807        model2, input2 = get_LstmNet_model_and_inputs(5, 4, 3, batch_size2, 7, False)
4808        self.run_test(model2, input2, do_constant_folding=True)
4809
4810    @skipIfUnsupportedMinOpsetVersion(9)
4811    def test_lstm_no_bias(self):
4812        class LstmNet(torch.nn.Module):
4813            def __init__(self, num_layers, bidirectional):
4814                super().__init__()
4815                self.lstm = torch.nn.LSTM(
4816                    RNN_INPUT_SIZE,
4817                    RNN_HIDDEN_SIZE,
4818                    num_layers,
4819                    bias=False,
4820                    bidirectional=bidirectional,
4821                )
4822
4823            def forward(self, input, initial_state: Tuple[Tensor, Tensor]):
4824                return self.lstm(input, initial_state)
4825
4826        def get_LstmNet_model_and_inputs(num_layers, bidirectional):
4827            input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
4828            num_directions = 2 if bidirectional else 1
4829            model = LstmNet(num_layers, bidirectional)
4830            h0 = torch.randn(num_layers * num_directions, BATCH_SIZE, RNN_HIDDEN_SIZE)
4831            c0 = torch.randn(num_layers * num_directions, BATCH_SIZE, RNN_HIDDEN_SIZE)
4832            return model, (input, (h0, c0))
4833
4834        num_layers = [1, 1, 2, 3]
4835        bidirectional = [True, False, True, False]
4836        models_and_inputs = [
4837            get_LstmNet_model_and_inputs(n, b)
4838            for n, b in zip(num_layers, bidirectional)
4839        ]
4840        for model, input in models_and_inputs:
4841            self.run_test(model, input)
4842
4843    @skipIfUnsupportedMinOpsetVersion(9)
4844    def test_lstm_sequence(self):
4845        class LstmNet(torch.nn.Module):
4846            def __init__(self) -> None:
4847                super().__init__()
4848                self.rnn1 = torch.nn.LSTM(8, 8, bidirectional=True, batch_first=True)
4849                self.linear1 = torch.nn.Linear(8 * 2, 8)
4850                self.rnn2 = torch.nn.LSTM(8, 8, bidirectional=True, batch_first=True)
4851                self.linear2 = torch.nn.Linear(8 * 2, 8)
4852
4853            def forward(self, input):
4854                rnn_output1, _ = self.rnn1(input)
4855                linear_output1 = self.linear1(rnn_output1)
4856                rnn_output2, _ = self.rnn2(linear_output1)
4857                linear_output2 = self.linear2(rnn_output2)
4858                return linear_output2
4859
4860        input = torch.zeros((1, 100, 8), dtype=torch.float32)
4861        self.run_test(
4862            LstmNet(),
4863            input,
4864            input_names=["input"],
4865            output_names=["output"],
4866            dynamic_axes={
4867                "input": {0: "batch_size", 1: "w", 2: "h"},
4868                "output": {0: "batch_size", 1: "w", 2: "h"},
4869            },
4870        )
4871
4872    @skipScriptTest()
4873    def test_rnn_no_bias(self):
4874        def make_model(layers, packed_sequence):
4875            batch_first = packed_sequence == 2
4876            model = torch.nn.RNN(
4877                RNN_INPUT_SIZE,
4878                RNN_HIDDEN_SIZE,
4879                layers,
4880                bidirectional=False,
4881                batch_first=batch_first,
4882                bias=False,
4883            )
4884
4885            if packed_sequence == 1:
4886                model = rnn_model_with_packed_sequence.RnnModelWithPackedSequence(
4887                    model, False
4888                )
4889            if packed_sequence == 2:
4890                model = rnn_model_with_packed_sequence.RnnModelWithPackedSequence(
4891                    model, True
4892                )
4893            return model
4894
4895        def make_input(batch_size, layers, packed_sequence):
4896            batch_first = packed_sequence == 2
4897            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
4898            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
4899            inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
4900            inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
4901            inputs = [inputs]
4902
4903            h0 = torch.randn(layers, batch_size, RNN_HIDDEN_SIZE)
4904            inputs.append(h0)
4905            if packed_sequence != 0:
4906                inputs.append(torch.IntTensor(seq_lengths))
4907            if len(inputs) == 1:
4908                input = inputs[0]
4909            else:
4910                input = tuple(inputs)
4911            return input
4912
4913        layers = [1, 3, 1, 3, 1, 3]
4914        packed_sequence = [0, 0, 1, 1, 2, 2]
4915        models = [make_model(l, p) for l, p in zip(layers, packed_sequence)]
4916        inputs = [
4917            make_input(RNN_BATCH_SIZE, l, p) for l, p in zip(layers, packed_sequence)
4918        ]
4919
4920        for model, input in zip(models, inputs):
4921            self.run_test(model, input)
4922
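    # Illustrative sketch (not part of the original suite): how the packed
    # inputs consumed by the packed-sequence models above are built --
    # pad_sequence, then pack_padded_sequence with lengths sorted in
    # decreasing order, which round-trips back through pad_packed_sequence.
    def _example_pack_padded_sequence(self):
        seq_lengths = [5, 3, 2]
        inputs = [torch.randn(length, RNN_INPUT_SIZE) for length in seq_lengths]
        padded = rnn_utils.pad_sequence(inputs)
        packed = rnn_utils.pack_padded_sequence(padded, seq_lengths)
        unpacked, lengths = rnn_utils.pad_packed_sequence(packed)
        assert unpacked.shape == (5, 3, RNN_INPUT_SIZE)
        assert lengths.tolist() == seq_lengths
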
4923    def test_gru_no_bias(self):
4924        class GruNet(torch.nn.Module):
4925            def __init__(self, input_size, hidden_size, num_layers, bidirectional):
4926                super().__init__()
4927                self.mygru = torch.nn.GRU(
4928                    input_size,
4929                    hidden_size,
4930                    num_layers,
4931                    bidirectional=bidirectional,
4932                    bias=False,
4933                )
4934
4935            def forward(self, input, initial_state):
4936                out = self.mygru(input, initial_state)
4937                return out
4938
4939        def get_GruNet_model_and_inputs(
4940            input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional
4941        ):
4942            num_directions = 2 if bidirectional else 1
4943            model = GruNet(input_size, hidden_size, num_layers, bidirectional)
4944            input = torch.randn(seq_len, batch_size, input_size)
4945            h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
4946            return model, (input, h0)
4947
4948        input_size = [7, 5]
4949        hidden_size = [3, 4]
4950        num_layers = [2, 3]
4951        batch_size = [3, 4]
4952        seq_len = [5, 7]
4953        bidirectional = [True, False]
4954        models_and_inputs = [
4955            get_GruNet_model_and_inputs(i, h, n, b, s, bi)
4956            for i, h, n, b, s, bi in zip(
4957                input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional
4958            )
4959        ]
4960        for model, input in models_and_inputs:
4961            self.run_test(model, input, do_constant_folding=True)
4962
4963    def test_gru_constant_folding(self):
4964        class GruNet(torch.nn.Module):
4965            def __init__(self, input_size, hidden_size, num_layers, bidirectional):
4966                super().__init__()
4967                self.mygru = torch.nn.GRU(
4968                    input_size, hidden_size, num_layers, bidirectional=bidirectional
4969                )
4970
4971            def forward(self, input, initial_state):
4972                out = self.mygru(input, initial_state)
4973                return out
4974
4975        def get_GruNet_model_and_inputs(
4976            input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional
4977        ):
4978            num_directions = 2 if bidirectional else 1
4979            model = GruNet(input_size, hidden_size, num_layers, bidirectional)
4980            input = torch.randn(seq_len, batch_size, input_size)
4981            h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
4982            return model, (input, h0)
4983
4984        batch_size1 = 3
4985        model1, input1 = get_GruNet_model_and_inputs(7, 3, 2, batch_size1, 5, True)
4986        self.run_test(model1, input1, do_constant_folding=True)
4987
4988        batch_size2 = 4
4989        model2, input2 = get_GruNet_model_and_inputs(5, 4, 3, batch_size2, 7, False)
4990        self.run_test(model2, input2, do_constant_folding=True)
4991
4992    @skipIfUnsupportedMinOpsetVersion(8)
4993    def test_max_tensors(self):
4994        class MaxModel(torch.nn.Module):
4995            def forward(self, input, other):
4996                return torch.max(input, other)
4997
4998        model = MaxModel()
4999        x = torch.randn(4, 4, requires_grad=True)
5000        y = torch.randn(4, 1, requires_grad=True)
5001        self.run_test(model, (x, y))
5002
5003    def test_amax_amin(self):
5004        class Model(torch.nn.Module):
5005            def forward(self, x):
5006                return torch.amax(x, dim=0, keepdim=True), torch.amin(
5007                    x, dim=[0, 1], keepdim=False
5008                )
5009
5010        model = Model()
5011        x = torch.randn(4, 4)
5012        self.run_test(model, x)
5013
5014    def test_aminmax(self):
5015        class Model(torch.nn.Module):
5016            def forward(self, x):
5017                return torch.aminmax(x, dim=1, keepdim=True), torch.aminmax(
5018                    x, keepdim=False
5019                )
5020
5021        model = Model()
5022        x = torch.randn(3, 4)
5023        self.run_test(model, x)
5024
5025    @skipIfUnsupportedMinOpsetVersion(9)
5026    def test_arange_end(self):
5027        class ArangeScript(torch.jit.ScriptModule):
5028            @torch.jit.script_method
5029            def forward(self, a):
5030                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
5031
5032        x = torch.randn(3, 4, requires_grad=True)
5033        outputs = ArangeScript()(x)
5034        self.run_test(ArangeScript(), x)
5035
5036        class ArangeModel(torch.nn.Module):
5037            def forward(self, a):
5038                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
5039
5040        self.run_test(ArangeModel(), x)
5041
5042    @skipIfUnsupportedMinOpsetVersion(11)
5043    def test_arange_end_notype(self):
5044        class ArangeScript(torch.jit.ScriptModule):
5045            @torch.jit.script_method
5046            def forward(self, a):
5047                return torch.arange(a.size(0))
5048
5049        x = torch.randn(3, 4, requires_grad=True)
5050        outputs = ArangeScript()(x)
5051        self.run_test(ArangeScript(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
5052        self.run_test(ArangeScript(), x, remained_onnx_input_idx=[])
5053
5054        class ArangeModel(torch.nn.Module):
5055            def forward(self, a):
5056                return torch.arange(a.size(0))
5057
5058        self.run_test(ArangeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
5059        self.run_test(ArangeModel(), x, remained_onnx_input_idx=[])
5060
5061    @skipIfUnsupportedMinOpsetVersion(9)
5062    def test_arange_start_end(self):
5063        class ArangeScript(torch.jit.ScriptModule):
5064            @torch.jit.script_method
5065            def forward(self, a):
5066                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
5067
5068        x = torch.randn(3, 4, requires_grad=True)
5069        self.run_test(ArangeScript(), x)
5070
5071        class ArangeModel(torch.nn.Module):
5072            def forward(self, a):
5073                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
5074
5075        self.run_test(ArangeModel(), x)
5076
5077    @skipIfUnsupportedMinOpsetVersion(11)
5078    def test_arange_start_end_notype(self):
5079        class ArangeScript(torch.jit.ScriptModule):
5080            @torch.jit.script_method
5081            def forward(self, a):
5082                return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a
5083
5084        x = torch.randn(3, 4, requires_grad=True)
5085        self.run_test(ArangeScript(), x)
5086
5087        class ArangeModel(torch.nn.Module):
5088            def forward(self, a):
5089                return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a
5090
5091        self.run_test(ArangeModel(), x)
5092
5093    @skipIfUnsupportedMinOpsetVersion(9)
5094    def test_arange_start_end_step(self):
5095        class ArangeScript(torch.jit.ScriptModule):
5096            @torch.jit.script_method
5097            def forward(self, a):
5098                return (
5099                    torch.arange(
5100                        2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float
5101                    ).view(-1, 1)
5102                    + a
5103                )
5104
5105        x = torch.randn(3, 4, requires_grad=True)
5106        self.run_test(ArangeScript(), x)
5107
5108        class ArangeModel(torch.nn.Module):
5109            def forward(self, a):
5110                return (
5111                    torch.arange(
5112                        2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float
5113                    ).view(-1, 1)
5114                    + a
5115                )
5116
5117        self.run_test(ArangeModel(), x)
5118
5119    @skipIfUnsupportedMinOpsetVersion(11)
5120    def test_arange_start_end_step_notype(self):
5121        class ArangeScript(torch.jit.ScriptModule):
5122            @torch.jit.script_method
5123            def forward(self, a):
5124                return (
5125                    torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1)
5126                    + a
5127                )
5128
5129        x = torch.randn(3, 4, requires_grad=True)
5130        self.run_test(ArangeScript(), x)
5131
5132        class ArangeModel(torch.nn.Module):
5133            def forward(self, a):
5134                return (
5135                    torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1)
5136                    + a
5137                )
5138
5139        self.run_test(ArangeModel(), x)
5140
5141    @skipIfUnsupportedMinOpsetVersion(9)
5142    def test__dim_arange(self):
5143        class DimArange(torch.nn.Module):
5144            def forward(self, input):
5145                return torch._dim_arange(input, 1)
5146
5147        x = torch.ones(5, 6)
5148        self.run_test(DimArange(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
5149        remained_onnx_input_idx = None if self.opset_version < 11 else []
5150        self.run_test(DimArange(), x, remained_onnx_input_idx=remained_onnx_input_idx)
5151
5152    def _test_compare_ops(self, model, num_inputs):
5153        x_float = torch.randn(1, 2, 3, 4, requires_grad=True)
5154        x_int = torch.randint(10, (3, 4), dtype=torch.int32)
5155        if num_inputs > 1:
5156            y_float = torch.randn(1, 2, 3, 4, requires_grad=True)
5157            y_int = torch.randint(10, (3, 4), dtype=torch.int32)
5158            self.run_test(model, (x_float, y_float))
5159            self.run_test(model, (x_float, y_int))
5160            self.run_test(model, (x_int, y_float))
5161            self.run_test(model, (x_int, y_int))
5162        else:
5163            self.run_test(model, x_float)
5164            self.run_test(model, x_int)
5165
5166    @skipIfUnsupportedMinOpsetVersion(9)
5167    def test_and_or_xor(self):
5168        class MyModel(torch.nn.Module):
5169            def forward(self, x, y):
5170                return x ^ y, x | y, x & y, ~x
5171
5172        x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5173        y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5174        self.run_test(MyModel(), input_args=(x, y))
5175
5176    @skipIfUnsupportedMinOpsetVersion(9)
5177    def test_logical_and(self):
5178        class AndModel(torch.nn.Module):
5179            def forward(self, x, y):
5180                return torch.logical_and(x, y)
5181
5182        x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5183        y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5184        self.run_test(AndModel(), input_args=(x, y))
5185
5186        x = torch.randint(10, (5, 5), dtype=torch.int32)
5187        y = torch.randint(10, (5, 5), dtype=torch.int32)
5188        self.run_test(AndModel(), input_args=(x, y))
5189
5190        x = torch.randint(10, (5, 5), dtype=torch.double)
5191        y = torch.randint(10, (5, 5), dtype=torch.double)
5192        self.run_test(AndModel(), input_args=(x, y))
5193
5194        x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
5195        y = torch.randint(10, (2, 3, 5), dtype=torch.long)
5196        self.run_test(AndModel(), input_args=(x, y))
5197
5198    @skipIfUnsupportedMinOpsetVersion(9)
5199    def test_logical_or(self):
5200        class OrModel(torch.nn.Module):
5201            def forward(self, x, y):
5202                return torch.logical_or(x, y)
5203
5204        x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5205        y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5206        self.run_test(OrModel(), input_args=(x, y))
5207
5208        x = torch.randint(10, (5, 5), dtype=torch.int32)
5209        y = torch.randint(10, (5, 5), dtype=torch.int32)
5210        self.run_test(OrModel(), input_args=(x, y))
5211
5212        x = torch.randint(10, (5, 5), dtype=torch.double)
5213        y = torch.randint(10, (5, 5), dtype=torch.double)
5214        self.run_test(OrModel(), input_args=(x, y))
5215
5216        x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
5217        y = torch.randint(10, (2, 3, 5), dtype=torch.long)
5218        self.run_test(OrModel(), input_args=(x, y))
5219
5220    @skipIfUnsupportedMinOpsetVersion(9)
5221    def test_logical_xor(self):
5222        class XorModel(torch.nn.Module):
5223            def forward(self, x, y):
5224                return torch.logical_xor(x, y)
5225
5226        x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5227        y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5228        self.run_test(XorModel(), input_args=(x, y))
5229
5230        x = torch.randint(10, (5, 5), dtype=torch.int32)
5231        y = torch.randint(10, (5, 5), dtype=torch.int32)
5232        self.run_test(XorModel(), input_args=(x, y))
5233
5234        x = torch.randint(10, (5, 5), dtype=torch.double)
5235        y = torch.randint(10, (5, 5), dtype=torch.double)
5236        self.run_test(XorModel(), input_args=(x, y))
5237
5238        x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
5239        y = torch.randint(10, (2, 3, 5), dtype=torch.long)
5240        self.run_test(XorModel(), input_args=(x, y))
5241
5242    @skipIfUnsupportedMinOpsetVersion(9)
5243    def test_logical_not(self):
5244        class NotModel(torch.nn.Module):
5245            def forward(self, x):
5246                return torch.logical_not(x)
5247
5248        x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5249        self.run_test(NotModel(), input_args=(x,))
5250
5251        x = torch.randint(10, (5, 5), dtype=torch.int32)
5252        self.run_test(NotModel(), input_args=(x,))
5253
5254        x = torch.randint(10, (5, 5), dtype=torch.double)
5255        self.run_test(NotModel(), input_args=(x,))
5256
5257        x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
5258        self.run_test(NotModel(), input_args=(x,))
5259
5260    @skipIfUnsupportedMinOpsetVersion(11)  # float equal added in opset 11
5261    def test_eq(self):
5262        class EqualModel(torch.nn.Module):
5263            def forward(self, input, other):
5264                return input == other
5265
5266        self._test_compare_ops(EqualModel(), 2)
5267
5268    def test_gt(self):
5269        class GreaterModel(torch.nn.Module):
5270            def forward(self, input, other):
5271                return input > other
5272
5273        self._test_compare_ops(GreaterModel(), 2)
5274
5275    @skipIfUnsupportedMinOpsetVersion(9)
5276    def test_ge(self):
5277        class GreaterOrEqualModel(torch.nn.Module):
5278            def forward(self, input, other):
5279                return input >= other
5280
5281        self._test_compare_ops(GreaterOrEqualModel(), 2)
5282
5283    def test_gt_scalar(self):
5284        class GreaterModel(torch.nn.Module):
5285            def forward(self, input):
5286                return input > 1
5287
5288        self._test_compare_ops(GreaterModel(), 1)
5289
5290    def test_gt_primitive(self):
5291        class GreaterModel(torch.nn.Module):
5292            def __init__(self) -> None:
5293                super().__init__()
5294                self.y: int = 2
5295
5296            def forward(self, x: int):
5297                return self.y > x
5298
5299        x = 3
5300        self.run_test(GreaterModel(), (x,))
5301
5302    @skipIfUnsupportedMinOpsetVersion(9)
5303    def test_ge_scalar(self):
5304        class GreaterOrEqualModel(torch.nn.Module):
5305            def forward(self, input):
5306                return input >= 1
5307
5308        self._test_compare_ops(GreaterOrEqualModel(), 1)
5309
5310    def test_lt(self):
5311        class LessModel(torch.nn.Module):
5312            def forward(self, input, other):
5313                return input < other
5314
5315        self._test_compare_ops(LessModel(), 2)
5316
5317    @skipIfUnsupportedMinOpsetVersion(9)
5318    def test_le(self):
5319        class LessOrEqualModel(torch.nn.Module):
5320            def forward(self, input, other):
5321                return input <= other
5322
5323        self._test_compare_ops(LessOrEqualModel(), 2)
5324
5325    def test_lt_scalar(self):
5326        class LessModel(torch.nn.Module):
5327            def forward(self, input):
5328                return input < 1
5329
5330        self._test_compare_ops(LessModel(), 1)
5331
5332    @skipIfUnsupportedMinOpsetVersion(9)
5333    def test_le_scalar(self):
5334        class LessOrEqualModel(torch.nn.Module):
5335            def forward(self, input):
5336                return input <= 1
5337
5338        self._test_compare_ops(LessOrEqualModel(), 1)
5339
5340    def test_matmul(self):
5341        class MatmulModel(torch.nn.Module):
5342            def forward(self, input, other):
5343                return torch.matmul(input, other)
5344
5345        x = torch.randn(3, 4, requires_grad=True)
5346        y = torch.randn(4, 5, requires_grad=True)
5347        self.run_test(MatmulModel(), (x, y))
5348
5349        x = torch.randint(10, (3, 4))
5350        y = torch.randint(10, (4, 5))
5351        self.run_test(MatmulModel(), (x, y))
5352
5353    def test_matmul_batch(self):
5354        class MatmulModel(torch.nn.Module):
5355            def forward(self, input, other):
5356                return torch.matmul(input, other)
5357
5358        x = torch.randn(2, 3, 4, requires_grad=True)
5359        y = torch.randn(2, 4, 5, requires_grad=True)
5360        self.run_test(MatmulModel(), (x, y))
5361
5362        x = torch.randint(10, (2, 3, 4))
5363        y = torch.randint(10, (2, 4, 5))
5364        self.run_test(MatmulModel(), (x, y))
5365
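    # Illustrative sketch (not part of the original suite): torch.matmul on
    # 3-d inputs is a batch of 2-d matmuls, so (2, 3, 4) @ (2, 4, 5) gives
    # (2, 3, 5); ONNX MatMul follows the same numpy-style batching.
    def _example_batched_matmul_shape(self):
        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 4, 5)
        assert torch.matmul(x, y).shape == (2, 3, 5)
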
5366    def _argmin_argmax_model(self, input):
5367        class ArgminArgmaxModel(torch.nn.Module):
5368            def forward(self, input):
5369                return (
5370                    torch.argmin(input),
5371                    torch.argmax(input),
5372                    torch.argmin(input, keepdim=True),
5373                    torch.argmax(input, keepdim=True),
5374                    torch.argmin(input, dim=0, keepdim=True),
5375                    torch.argmax(input, dim=1, keepdim=True),
5376                )
5377
5378        self.run_test(ArgminArgmaxModel(), input)
5379
5380    @skipIfUnsupportedMinOpsetVersion(9)
5381    def test_argmin_argmax(self):
5382        input = torch.randn(7, 3, 5)
5383        self._argmin_argmax_model(input)
5384
5385    # Argmin and Argmax with "select_last_index" are not supported before opset 12.
5386    # "select_last_index" was added in opset 12 to handle the corner case where the
5387    # same value appears multiple times in the tensor.
5388    @skipIfUnsupportedMinOpsetVersion(12)
5389    def test_argmin_argmax_select_last_index(self):
5390        input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]])
5391        self._argmin_argmax_model(input)
5392
5393        input = torch.ones(7, 3, 5)
5394        self._argmin_argmax_model(input)
5395
5396    def test_repeat(self):
5397        class RepeatModel(torch.nn.Module):
5398            def forward(self, x, y):
5399                x2 = x.repeat(y.shape[0], 1)
5400                y1 = y.view(-1, 1)
5401                return x2 + y1
5402
5403        x = torch.tensor([1, 2, 3])
5404        y = torch.tensor([4, 5, 8, 9])
5405        self.run_test(RepeatModel(), (x, y))
5406
5407    @skipIfUnsupportedMinOpsetVersion(9)
5408    def test_repeat_interleave(self):
5409        class FlattenModel(torch.nn.Module):
5410            def forward(self, x):
5411                return x.repeat_interleave(2)
5412
5413        for shape in ([3], [3, 4], [2, 3, 4]):
5414            x = torch.randn(shape)
5415            self.run_test(FlattenModel(), (x,))
5416
5417        class DimsModel(torch.nn.Module):
5418            def forward(self, x):
5419                return x.repeat_interleave(4, dim=1)
5420
5421        x = torch.tensor([[1, 2], [3, 4]])
5422        self.run_test(DimsModel(), (x,))
5423
5424        class DimsModel2(torch.nn.Module):
5425            def forward(self, x):
5426                repeats = torch.tensor([4])
5427                return torch.repeat_interleave(x, repeats, dim=1)
5428
5429        x = torch.tensor([[1, 2], [3, 4]])
5430        self.run_test(DimsModel2(), (x,))
5431
5432        class RepeatsDimsModel(torch.nn.Module):
5433            def forward(self, x):
5434                repeats = torch.tensor([1, 2])
5435                return torch.repeat_interleave(x, repeats, dim=0)
5436
5437        x = torch.tensor([[1, 2], [3, 4]])
5438        self.run_test(RepeatsDimsModel(), (x,))
5439
5440        class RepeatsDimsModel2(torch.nn.Module):
5441            def forward(self, x):
5442                repeats = torch.tensor([1, 2])
5443                return torch.repeat_interleave(x, repeats, dim=1)
5444
5445        x = torch.tensor([[1, 2], [3, 4]])
5446        self.run_test(RepeatsDimsModel2(), (x,))
5447
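    # Illustrative sketch (not part of the original suite): repeat_interleave
    # repeats each element in place, like numpy.repeat, unlike Tensor.repeat,
    # which tiles the whole tensor.
    def _example_repeat_interleave_semantics(self):
        x = torch.tensor([[1, 2], [3, 4]])
        torch_out = x.repeat_interleave(4, dim=1)
        np_out = np.repeat(x.numpy(), 4, axis=1)
        np.testing.assert_array_equal(torch_out.numpy(), np_out)
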
5448    @skipIfUnsupportedMinOpsetVersion(9)
5449    def test_repeat_interleave_noop(self):
5450        class Model(torch.nn.Module):
5451            def forward(self, x):
5452                return x.repeat_interleave(1, dim=1)
5453
5454        x = torch.randn(4, 1, 8)
5455        self.run_test(Model(), (x,))
5456
5457    @skipIfUnsupportedMinOpsetVersion(13)
5458    def test_dynamic_repeat_interleave(self):
5459        class SingleDynamicModel(torch.nn.Module):
5460            def forward(self, x):
5461                repeats = torch.tensor(4)
5462                return torch.repeat_interleave(x, repeats, dim=1)
5463
5464        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5465        another_x = torch.tensor([[7, 8], [5, 6]])
5466        self.run_test(
5467            SingleDynamicModel(),
5468            x,
5469            additional_test_inputs=[another_x],
5470            input_names=["input_1"],
5471            dynamic_axes={"input_1": {1: "w"}},
5472        )
5473
5474        class NegDynamicModel(torch.nn.Module):
5475            def forward(self, x):
5476                repeats = torch.tensor(4)
5477                return torch.repeat_interleave(x, repeats, dim=-1)
5478
5479        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5480        another_x = torch.tensor([[7, 8], [5, 6]])
5481        self.run_test(
5482            NegDynamicModel(),
5483            x,
5484            additional_test_inputs=[another_x],
5485            input_names=["input_1"],
5486            dynamic_axes={"input_1": {1: "w"}},
5487        )
5488
5489        class SingleDynamicModelFloat(torch.nn.Module):
5490            def forward(self, x):
5491                repeats = torch.tensor([4])
5492                return torch.repeat_interleave(x, repeats, dim=0)
5493
5494        x = torch.tensor([[1.1, 2.1], [3.1, 4.1]])
5495        another_x = torch.tensor([[7.1, 8.1], [5.1, 6.1]])
5496        self.run_test(
5497            SingleDynamicModelFloat(),
5498            x,
5499            additional_test_inputs=[another_x],
5500            input_names=["input_1"],
5501            dynamic_axes={"input_1": {0: "h"}},
5502        )
5503
5504        class DynamicRepeatsModel(torch.nn.Module):
5505            def forward(self, x, repeats):
5506                return torch.repeat_interleave(x, repeats, dim=1)
5507
5508        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5509        another_x = torch.tensor([[7, 8], [5, 6]])
5510        repeats = torch.tensor([2])
5511        another_repeats = torch.tensor([4])
5512        self.run_test(
5513            DynamicRepeatsModel(),
5514            (x, repeats),
5515            additional_test_inputs=[(another_x, another_repeats)],
5516            input_names=["input_1", "repeats_1"],
5517            dynamic_axes={"input_1": {1: "w"}, "repeats_1": {0: "r"}},
5518        )
5519
5520        class DynamicRepeatsModel2(torch.nn.Module):
5521            def forward(self, x, repeats):
5522                return torch.repeat_interleave(x, repeats, dim=1)
5523
5524        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5525        repeats = torch.tensor([2])
5526        another_repeats = torch.tensor([4])
5527        self.run_test(
5528            DynamicRepeatsModel2(),
5529            (x, repeats),
5530            additional_test_inputs=[(x, another_repeats)],
5531            input_names=["input_1", "repeats_1"],
5532            dynamic_axes={"repeats_1": {0: "r"}},
5533        )
5534
5535        class DynamicFlattenModel(torch.nn.Module):
5536            def forward(self, x):
5537                return x.repeat_interleave(2)
5538
5539        x = torch.tensor([1, 2, 3])
5540        self.run_test(
5541            DynamicFlattenModel(),
5542            x,
5543            input_names=["input_1"],
5544            dynamic_axes={"input_1": {0: "w"}},
5545        )
5546
5547    @skipIfUnsupportedMinOpsetVersion(13)
5548    def test_multiple_dynamic_repeat_interleave(self):
5549        class DynamicRepeatsModel(torch.nn.Module):
5550            def forward(self, x, repeats):
5551                return torch.repeat_interleave(x, repeats, dim=1)
5552
5553        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5554        repeats = torch.tensor([2, 3, 4])
5555        another_repeats = torch.tensor([4, 3, 2])
5556        self.run_test(
5557            DynamicRepeatsModel(),
5558            (x, repeats),
5559            additional_test_inputs=[(x, another_repeats)],
5560            input_names=["input_1", "repeats_1"],
5561            dynamic_axes={"repeats_1": {0: "r"}},
5562        )
5563
5564        class DynamicRepeatsModel2(torch.nn.Module):
5565            def forward(self, x, repeats):
5566                return torch.repeat_interleave(x, repeats, dim=0)
5567
5568        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5569        repeats = torch.tensor([2, 3])
5570        another_repeats = torch.tensor([4, 3])
5571        self.run_test(
5572            DynamicRepeatsModel2(),
5573            (x, repeats),
5574            additional_test_inputs=[(x, another_repeats)],
5575            input_names=["input_1", "repeats_1"],
5576            dynamic_axes={"repeats_1": {0: "r"}},
5577        )
5578
5579    def test_view(self):
5580        class ViewModel(torch.nn.Module):
5581            def forward(self, input):
5582                return input.view(4, 24)
5583
5584        x = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32)
5585        self.run_test(ViewModel(), x)
5586
5587    def test_view_dynamic(self):
5588        class ViewModel(torch.nn.Module):
5589            def forward(self, input, other):
5590                return input.view(other.shape)
5591
5592        x = torch.randn(2, 3, 4)
5593        shape = torch.randn(6, 4)
5594        self.run_test(
5595            ViewModel(),
5596            (x, shape),
5597            input_names=["x", "shape"],
5598            dynamic_axes={"x": [0, 1, 2], "shape": [0, 1]},
5599        )
5600        self.run_test(ViewModel(), (x, shape), remained_onnx_input_idx=[0])
5601
5602    def test_view_dynamic_zero_dim(self):
5603        class ViewModel(torch.nn.Module):
5604            def forward(self, input):
5605                input = input.view(-1, 2)
5606                return input.view(1, -1)
5607
5608        x = torch.ones(2)
5609        another_x = torch.empty((0,))
5610        self.run_test(
5611            ViewModel(),
5612            x,
5613            additional_test_inputs=[another_x],
5614            input_names=["input_1"],
5615            dynamic_axes={
5616                "input_1": [
5617                    0,
5618                ]
5619            },
5620        )
5621
5622    def test_view_as(self):
5623        class ViewModel(torch.nn.Module):
5624            def forward(self, input, other):
5625                return input.view_as(other)
5626
5627        x = torch.randn(2, 3, 4)
5628        y = torch.randn(6, 4)
5629        self.run_test(ViewModel(), (x, y))
5630
5631    def test_linear(self):
5632        class LinearModel(torch.nn.Module):
5633            def __init__(self) -> None:
5634                super().__init__()
5635                self.fc = torch.nn.Linear(16, 16)
5636
5637            def forward(self, x):
5638                out = self.fc(x)
5639                out = self.fc(out)
5640                return out
5641
5642        x = torch.randn(3, 16)
5643        self.run_test(LinearModel(), (x,))
5644
5645        class LinearModel(torch.nn.Module):
5646            def forward(self, input, weight, bias):
5647                return torch.nn.functional.linear(input, weight, bias)
5648
5649        # input of rank 2
5650        x = torch.randn(2, 2)
5651        y = torch.randn(2, 2)
5652        z = torch.randn(1)
5653        self.run_test(LinearModel(), (x, y, z))
5654
5655        # input of rank 3
5656        x = torch.randn(3, 3, 3)
5657        y = torch.randn(3, 3)
5658        z = torch.randn(1)
5659        self.run_test(LinearModel(), (x, y, z))
5660
5661    @skipScriptTest()
5662    def test_weight_norm(self):
5663        # addmm for 3-d inputs converts to onnx::MatMul
5664        model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1)
5665        x = torch.randn(3, 4, 5, requires_grad=True)
5666        self.run_test(model, x)
5667
5668        # addmm for 2-d inputs converts to onnx::Gemm
5669        model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1)
5670        x = torch.randn(4, 5, requires_grad=True)
5671        self.run_test(model, x)
5672
5673        model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3))
5674        x = torch.randn(1, 1, 5, requires_grad=True)
5675        self.run_test(model, x)
5676
5677        model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3), dim=-2)
5678        x = torch.randn(1, 1, 5, requires_grad=True)
5679        self.run_test(model, x)
5680
5681        model = torch.nn.utils.weight_norm(torch.nn.Conv1d(3, 6, 3), name="weight")
5682        x = torch.randn(3, 3, 5, requires_grad=True)
5683        self.run_test(model, x)
5684
5685    @skipScriptTest()
5686    def test_weight_norm_nodim(self):
5687        # addmm for 3-d inputs converts to onnx::MatMul
5688        model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None)
5689        x = torch.randn(3, 4, 5, requires_grad=True)
5690        self.run_test(model, x)
5691
5692        # addmm for 2-d inputs converts to onnx::Gemm
5693        model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None)
5694        x = torch.randn(4, 5, requires_grad=True)
5695        self.run_test(model, x)
5696
5697    def test_flatten(self):
5698        class FlattenModel(torch.nn.Module):
5699            def forward(self, input):
5700                return torch.flatten(input)
5701
5702        model = FlattenModel()
5703
5704        # flatten with 4d input
5705        x = torch.randint(10, (1, 2, 3, 4))
5706        self.run_test(model, x)
5707
5708        # flatten with 0d input
5709        x = torch.randn([])
5710        self.run_test(model, x)
5711
5712        # flatten with 1d input
5713        x = torch.randn(4)
5714        self.run_test(model, x)
5715
5716    def test_flatten2d(self):
5717        class FlattenModel(torch.nn.Module):
5718            def forward(self, input):
5719                return torch.flatten(input, 1)
5720
5721        x = torch.randint(10, (1, 2, 3, 4))
5722        self.run_test(FlattenModel(), x)
5723
5724    def test_flatten2d_neg(self):
5725        class FlattenModel(torch.nn.Module):
5726            def forward(self, x):
5727                return (
5728                    torch.flatten(x, 1, -1),
5729                    torch.flatten(x, 0, -2),
5730                    torch.flatten(x, 1, -2),
5731                )
5732
5733        x = torch.randint(10, (1, 2, 3, 4))
5734        self.run_test(FlattenModel(), x)
5735
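    # Illustrative sketch (not part of the original suite): negative start/end
    # dims count from the back, so for a (1, 2, 3, 4) input flatten(x, 1, -2)
    # merges dims 1..2 into a single axis of size 6.
    def _example_flatten_negative_dims(self):
        x = torch.randint(10, (1, 2, 3, 4))
        assert torch.flatten(x, 1, -1).shape == (1, 24)
        assert torch.flatten(x, 0, -2).shape == (6, 4)
        assert torch.flatten(x, 1, -2).shape == (1, 6, 4)
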
5736    @skipIfUnsupportedMinOpsetVersion(9)
5737    def test_flatten_dynamic_axes(self):
5738        class MyModule(torch.nn.Module):
5739            def forward(self, x):
5740                return torch.flatten(x, start_dim=2, end_dim=3)
5741
5742        batch_size = 3
5743        x = torch.randn(batch_size, 5, 4, 5)
5744        y = torch.randn(5, 5, 4, 5)
5745        model = MyModule()
5746        self.run_test(
5747            model,
5748            x,
5749            additional_test_inputs=[y],
5750            input_names=["input"],
5751            output_names=["output"],
5752            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
5753        )
5754
5755    @skipIfUnsupportedMinOpsetVersion(11)
5756    def test_getitem(self):
5757        class GetItemModel(torch.jit.ScriptModule):
5758            @torch.jit.script_method
5759            def forward(self, x, y, z, ind):
5760                # this will create prim::ListConstruct(x, y, z) + aten::__getitem__
5761                arr = [x, y, z]
5762                return arr[ind]
5763
5764        x = torch.randn(3, 4, 5)
5765        y = torch.randn(1, 4, 5)
5766        z = torch.randn(2, 4, 5)
5767        ind = torch.tensor(1, dtype=torch.long)
5768        self.run_test(GetItemModel(), (x, y, z, ind))
5769
5770        ind = torch.tensor(-2, dtype=torch.long)
5771        self.run_test(GetItemModel(), (x, y, z, ind))
5772
5773    @skipDtypeChecking
5774    def test_item(self):
5775        class M(torch.nn.Module):
5776            def forward(self, x, y, i: int):
5777                return int(x[y[i]].item())
5778
5779        x = torch.arange(6, dtype=torch.float)
5780        y = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)
5781        i = 3
5782        self.run_test(torch.jit.script(M()), (x, y, i))
5783
5784    @skipScriptTest()  # torch.nonzero(x, as_tuple=True) is not scriptable.
5785    @skipIfUnsupportedMinOpsetVersion(9)
5786    def test_nonzero(self):
5787        class NonzeroModel(torch.nn.Module):
5788            def forward(self, x):
5789                return x.nonzero(), x.nonzero(as_tuple=True)
5790
5791        x = torch.randn(60).index_fill_(0, torch.randint(0, 60, (20,)), 0).view(3, 4, 5)
5792        self.run_test(NonzeroModel(), (x,))
5793
5794    def test_unbind(self):
5795        class UnbindModel(torch.nn.Module):
5796            def forward(self, input):
5797                _, out, _ = input.unbind()
5798                return out
5799
5800        x = torch.randn(3, 4, 5)
5801        self.run_test(UnbindModel(), x)
5802
5803        class UnbindModel2(torch.nn.Module):
5804            def forward(self, input):
5805                _, out, _, _ = input.unbind(1)
5806                return out
5807
5808        x = torch.randn(3, 4, 5)
5809        self.run_test(UnbindModel2(), x)
5810
5811        class UnbindModel3(torch.nn.Module):
5812            def forward(self, input):
5813                _, out, _, _ = input.unbind(-2)
5814                return out
5815
5816        x = torch.randn(3, 4, 5)
5817        self.run_test(UnbindModel3(), x)
5818
5819    @skipIfUnsupportedMinOpsetVersion(11)
5820    def test_len(self):
5821        class LenModel(torch.jit.ScriptModule):
5822            @torch.jit.script_method
5823            def forward(self, input):
5824                return len(input.unbind()) + input
5825
5826        x = torch.randn(4, 5)
5827        self.run_test(
5828            LenModel(),
5829            x,
5830            input_names=["input"],
5831            dynamic_axes={"input": {0: "seq"}},
5832            additional_test_inputs=(torch.randn(5, 5),),
5833        )
5834
5835    @skipIfUnsupportedMinOpsetVersion(9)
5836    def test_len_list(self):
5837        class LenListModel(torch.jit.ScriptModule):
5838            @torch.jit.script_method
5839            def forward(self, input):
5840                return torch.ones(len(input.shape))
5841
5842        x = torch.randn(4, 5)
5843        self.run_test(LenListModel(), x, remained_onnx_input_idx=[])
5844
5845    @skipIfUnsupportedMinOpsetVersion(11)
5846    def test_unbind_dynamic(self):
5847        class UnbindModel(torch.jit.ScriptModule):
5848            @torch.jit.script_method
5849            def forward(self, input):
5850                return input.unbind()[1]
5851
5852        x = torch.randn(3, 4, 5)
5853        self.run_test(UnbindModel(), x)
5854
5855        class UnbindModel2(torch.jit.ScriptModule):
5856            @torch.jit.script_method
5857            def forward(self, input):
5858                return input.unbind(-1)[1]
5859
5860        x = torch.randn(3, 4, 5)
5861        self.run_test(UnbindModel2(), x)
5862
5863    @skipScriptTest()  # scripting tests run for opsets >= 11. See: test_split_script
5864    def test_split(self):
5865        class SplitModel(torch.nn.Module):
5866            def forward(self, input):
5867                return input.split([2, 1, 2]), input.split([3, 2])[0]
5868
5869        x = torch.randn(5, 4, 3)
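        # Shapes produced in eager mode (for reference):
        #   x.split([2, 1, 2]) -> (2, 4, 3), (1, 4, 3), (2, 4, 3)
        #   x.split([3, 2])[0] -> (3, 4, 3)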
5870        self.run_test(SplitModel(), x)
5871
5872        class SplitModel2(torch.nn.Module):
5873            def forward(self, input):
5874                return input.split([2, 1, 1], -2), input.split([2, 2], -2)[-1]
5875
5876        x = torch.randn(5, 4, 3)
5877        self.run_test(SplitModel2(), x)
5878
5879        class SplitModel3(torch.nn.Module):
5880            def forward(self, input):
5881                return input.split([2, 1, 2])
5882
5883        x = torch.randn(5, 4, 3)
5884        self.run_test(SplitModel3(), x)
5885
5886    @skipIfUnsupportedMinOpsetVersion(11)
5887    def test_split_script(self):
5888        class SplitModel(torch.nn.Module):
5889            def forward(self, input):
5890                return input.split([2, 1, 2]), input.split([3, 2])[0]
5891
5892        x = torch.randn(5, 4, 3)
5893        self.run_test(SplitModel(), x)
5894
5895        class SplitModel2(torch.nn.Module):
5896            def forward(self, input):
5897                return input.split([2, 1, 1], -2), input.split([2, 2], -2)[-1]
5898
5899        x = torch.randn(5, 4, 3)
5900        self.run_test(SplitModel2(), x)
5901
5902        class SplitModel3(torch.nn.Module):
5903            def forward(self, input):
5904                return input.split([2, 1, 2])
5905
5906        x = torch.randn(5, 4, 3)
5907        self.run_test(SplitModel3(), x)
5908
5909    @skipIfUnsupportedMinOpsetVersion(11)
5910    @skipScriptTest()
5911    def test_split_size_as_list(self):
5912        class SplitModel(torch.nn.Module):
5913            def forward(self, input, split_sizes: List[int]):
5914                out = []
5915                split_list: List[Tensor] = input.split(split_sizes)
5916
5917                for ob in split_list:
5918                    out.append(ob)  # noqa: PERF402
5919                return torch.cat(out, dim=0)
5920
5921        x = torch.randn(6, 4, 3)
5922        split_sizes = [torch.tensor(2), torch.tensor(4)]
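        # The 0-d tensors are accepted where ints are expected and are traced
        # into the List[int] argument as the sizes 2 and 4.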
5923        self.run_test(SplitModel(), (x, split_sizes))
5924
5925    @skipIfUnsupportedMinOpsetVersion(11)
5926    def test_split_size_with_slice(self):
5927        class SplitModule(torch.nn.Module):
5928            def forward(self, x, y, t):
5929                splits = (x.size(1), y.size(1))
5930                out, out2 = torch.split(t, splits, dim=1)
5931                return out, out2
5932
5933        x = torch.randn(2, 3)
5934        y = torch.randn(2, 4)
5935        t = torch.randn(2, 7)
5936        self.run_test(
5937            SplitModule(),
5938            (x, y, t),
5939            input_names=["x", "y", "t"],
5940            dynamic_axes={"x": [0, 1], "y": [0, 1], "t": [0, 1]},
5941        )
5942        self.run_test(SplitModule(), (x, y, t), remained_onnx_input_idx=[2])
5943
5944    @skipIfUnsupportedMinOpsetVersion(11)
5945    def test_split_dynamic(self):
5946        class SplitModel(torch.jit.ScriptModule):
5947            @torch.jit.script_method
5948            def forward(self, input):
5949                return input.split(2)[1]
5950
5951        x = torch.randn(5, 4, 3)
5952        self.run_test(SplitModel(), x)
5953
5954        class SplitModel2(torch.jit.ScriptModule):
5955            @torch.jit.script_method
5956            def forward(self, input):
5957                return input.split(2, -3)[1]
5958
5959        x = torch.randn(5, 4, 3)
5960        self.run_test(SplitModel2(), x)
5961
5962    @skipIfUnsupportedMinOpsetVersion(11)
5963    def test_split_dynamic_axes(self):
5964        class Split(torch.nn.Module):
5965            def forward(self, x):
5966                return x.split(1, dim=-1)
5967
5968        x = torch.randn(4, 384, 2)
5969        input_names = ["logits"]
5970        self.run_test(
5971            Split(),
5972            x,
5973            input_names=input_names,
5974            dynamic_axes={input_names[0]: {0: "batch"}},
5975        )
5976
5977    @skipIfUnsupportedMinOpsetVersion(11)
5978    def test_chunk(self):
5979        class ChunkModel(torch.nn.Module):
5980            def __init__(self, dim=1):
5981                super().__init__()
5982                self.dim = dim
5983
5984            def forward(self, x):
5985                return torch.chunk(x, 3, dim=self.dim)
5986
5987        model = ChunkModel()
5988        model.eval()
5989        model_neg_dim = ChunkModel(-1)
5990        model_neg_dim.eval()
5991        x = torch.randn(1, 18)
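        # torch.chunk rounds the chunk size up, so size 18 splits as 6/6/6
        # while the dynamic sizes 13-15 below split as 5/5/3, 5/5/4 and 5/5/5.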
5992
5993        for dim_size_ in range(13, 16):
5994            y = torch.randn(1, dim_size_)
5995            self.run_test(
5996                model,
5997                x,
5998                additional_test_inputs=[y],
5999                input_names=["x"],
6000                dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
6001            )
6002
6003            self.run_test(
6004                model_neg_dim,
6005                x,
6006                additional_test_inputs=[y],
6007                input_names=["x"],
6008                dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
6009            )
6010
6011    @skipIfUnsupportedMinOpsetVersion(11)
6012    def test_dynamic_chunk(self):
6013        class ChunkModel(torch.nn.Module):
6014            def __init__(self, dim=1):
6015                super().__init__()
6016                self.dim = dim
6017
6018            def forward(self, x):
6019                return torch.chunk(x, x.size(0), dim=self.dim)
6020
6021        model = ChunkModel()
6022        model.eval()
6023        model_neg_dim = ChunkModel(-1)
6024        model_neg_dim.eval()
6025        x = torch.randn(3, 18)
6026
6027        for dim_size_ in range(13, 16):
6028            y = torch.randn(3, dim_size_)
6029            self.run_test(
6030                model,
6031                x,
6032                additional_test_inputs=[y],
6033                input_names=["x"],
6034                dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
6035            )
6036
6037            self.run_test(
6038                model_neg_dim,
6039                x,
6040                additional_test_inputs=[y],
6041                input_names=["x"],
6042                dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
6043            )
6044
6045    def test_concat(self):
6046        class ConcatModel(torch.nn.Module):
6047            def forward(self, x, y, z):
6048                return torch.cat((x, y, z))
6049
6050        x = torch.randn(3, 4, 5)
6051        y = torch.randn(1, 4, 5)
6052        z = torch.randn(2, 4, 5)
6053        self.run_test(ConcatModel(), (x, y, z))
6054
6055    @skipIfUnsupportedMinOpsetVersion(11)
6056    def test_concat_dynamic(self):
6057        class ConcatDynamicModel(torch.jit.ScriptModule):
6058            @torch.jit.script_method
6059            def forward(self, x):
6060                return torch.cat(x.unbind())
6061
6062        x = torch.randn(4, 5, 6)
6063        self.run_test(ConcatDynamicModel(), x)
6064
6065    def test_stack(self):
6066        class StackModel(torch.nn.Module):
6067            def forward(self, x, y, z):
6068                return torch.stack((x, y, z), 1)
6069
6070        x = torch.randn(3, 4, 5)
6071        y = torch.randn(3, 4, 5)
6072        z = torch.randn(3, 4, 5)
6073        self.run_test(StackModel(), (x, y, z))
6074
6075    @skipIfUnsupportedMinOpsetVersion(11)
6076    def test_stack_dynamic(self):
6077        class StackDynamicModel(torch.jit.ScriptModule):
6078            @torch.jit.script_method
6079            def forward(self, x):
6080                return torch.stack(x.unbind(), 1)
6081
6082        x = torch.randn(4, 5, 6)
6083        self.run_test(StackDynamicModel(), x)
6084
6085    def test_loop_dynamic(self):
6086        class LoopModel(torch.jit.ScriptModule):
6087            @torch.jit.script_method
6088            def forward(self, x):
6089                for i in range(x.size(2)):
6090                    x = x + i
6091                return x
6092
6093        model = LoopModel()
6094        inputs = torch.zeros(1, 2, 3, dtype=torch.long)
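        # With x.size(2) == 3 the loop adds 0 + 1 + 2 == 3 to every element.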
6095        self.run_test(model, inputs)
6096
6097    @skipIfUnsupportedMinOpsetVersion(9)
6098    def test_loop_nested(self):
6099        class NestedLoopsModel(torch.jit.ScriptModule):
6100            @torch.jit.script_method
6101            def forward(self, x):
6102                for i in range(5):
6103                    a = 0
6104                    while a < 4:
6105                        a += 1
6106                    x = x + a
6107                return x
6108
6109        model = NestedLoopsModel()
6110        inputs = torch.zeros(1, 2, 3, dtype=torch.long)
6111        self.run_test(model, inputs)
6112
6113    @skipIfUnsupportedMinOpsetVersion(11)
6114    def test_loop_with_list(self):
6115        class ListLoopModel(torch.jit.ScriptModule):
6116            @torch.jit.script_method
6117            def forward(self, x):
6118                res = []
6119                res1 = []
6120                arr = x.split([3, 4, 1, 1, 2, 3, 2], 0)
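                # arr holds 7 tensors with first-dim sizes 3, 4, 1, 1, 2, 3, 2
                # (summing to 16, the length of the input below).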
6121                res2 = torch.zeros(3, 4, dtype=torch.long)
6122                res3 = []
6123                res4 = []
6124                for i in range(len(arr)):
6125                    res.append(arr[i].sum(0, False))
6126                    res1.append(arr[-1 - i].sum(0, False))
6127                    res2 += 1
6128                    res3 = res3 + [arr[i].sum(0, False)]
6129                    res4 += [arr[-1 - i].sum(0, False)]
6130                return res, res1, res2, torch.stack(res3), torch.stack(res4)
6131
6132        model = ListLoopModel()
6133        inputs = torch.randn(16)
6134        self.run_test(model, inputs)
6135
6136    @skipIfUnsupportedMinOpsetVersion(11)
6137    def test_loop_transpose(self):
6138        class LoopModel(torch.nn.Module):
6139            def forward(self, x):
6140                res = torch.zeros_like(x[0])
6141                for i in range(x.size(0)):
6142                    res += x[0].transpose(0, 1)
6143                return res
6144
6145        model = torch.jit.script(LoopModel())
6146        x = torch.randn(5, 3, 3)
6147        self.run_test(model, x)
6148
6149    @skipIfUnsupportedMinOpsetVersion(11)
6150    def test_loop_multi_dim(self):
6151        class LoopMultiDimModel(torch.jit.ScriptModule):
6152            @torch.jit.script_method
6153            def forward(self, x, y):
6154                for x_ in torch.flip(x.narrow(0, 0, 7), [0]):
6155                    y = x_[0][y]
6156                return y
6157
6158        model = LoopMultiDimModel()
6159        x = torch.randint(0, 5, (8, 1, 17), dtype=torch.long)
6160        y = torch.ones(1, dtype=torch.long)
6161        self.run_test(model, (x, y))
6162
6163    @skipIfUnsupportedMinOpsetVersion(11)
6164    def test_list(self):
6165        class ListModel(torch.jit.ScriptModule):
6166            @torch.jit.script_method
6167            def forward(self, x):
6168                tensors = x.unbind()
6169                res = []
6170                res.append(tensors[0])
6171                res.append(tensors[1])
6172                res.pop(1)
6173
6174                res.insert(0, tensors[1])
6175                res.append(tensors[2])
6176                res += [tensors[3], tensors[4]]
6177                res = res + [tensors[5]]
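                # After the appends, pop, insert and concatenations above,
                # res holds 6 tensors, so the model returns torch.ones(6).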
6178                return torch.ones(len(res))
6179
6180        model = ListModel()
6181        inputs = torch.randn(16, 1)
6182        self.run_test(model, inputs)
6183
6184    @skipIfUnsupportedMinOpsetVersion(11)
6185    def test_list_append(self):
6186        class ListModel(torch.nn.Module):
6187            def forward(self, x, y):
6188                res = []
6189                for i in range(x.size(0)):
6190                    res += [torch.matmul(x[i], y)]
6191                return res
6192
6193        model = torch.jit.script(ListModel())
6194        x = torch.randn(16, 3, 4)
6195        y = torch.randn(4, 5)
6196        self.run_test(model, (x, y))
6197
6198    @skipIfUnsupportedMinOpsetVersion(13)
6199    def test_list_append_nested(self):
6200        class ListModel(torch.nn.Module):
6201            def forward(self, x, y):
6202                res = []
6203                for i in range(x.size(0)):
6204                    for j in range(x.size(1)):
6205                        res += [torch.matmul(x[i][j], y)]
6206                return res
6207
6208        model = torch.jit.script(ListModel())
6209        x = torch.randn(4, 4, 3, 4)
6210        y = torch.randn(4, 5)
6211        self.run_test(model, (x, y))
6212
6213    @skipIfUnsupportedMinOpsetVersion(14)  # Needs onnx::Identity for sequences, added in opset 14
6214    def test_list_append_nested_2(self):
6215        class ListModel(torch.nn.Module):
6216            def forward(self, x):
6217                res = []
6218                res_replicate = []
6219                for i in range(x.size(0)):
6220                    if len(res) > 2:
6221                        for j in range(x.size(1)):
6222                            res.append(x[i][j])
6223                        res_replicate.append(res[-1])
6224                        res.append(res_replicate[-1])
6225                return res, res_replicate
6226
6227        model = torch.jit.script(ListModel())
6228        x = torch.randn(4, 4, 3, 4)
6229        self.run_test(model, (x,))
6230
6231    @skipIfUnsupportedMinOpsetVersion(13)
6232    def test_list_append_nested_mixed_dtype(self):
6233        class ListModel(torch.nn.Module):
6234            def forward(self, x, y):
6235                res = []
6236                for i in range(x.size(0)):
6237                    for j in range(x.size(1)):
6238                        if i == j:
6239                            res.append(x == y)
6240                        else:
6241                            res.append(x != y)
6242                return res
6243
6244        model = torch.jit.script(ListModel())
6245        x = torch.randn(4, 4, 3, 4)
6246        y = torch.randn(3, 4)
6247        self.run_test(model, (x, y))
6248
6249    @skipIfUnsupportedMinOpsetVersion(11)
6250    def test_list_pop(self):
6251        class ListModel(torch.nn.Module):
6252            def forward(self, x, y):
6253                res = []
6254                for i in range(x.size(0)):
6255                    res += [torch.matmul(x[i], y)]
6256                res.pop()
6257                return res
6258
6259        model = torch.jit.script(ListModel())
6260        x = torch.randn(16, 3, 4)
6261        y = torch.randn(4, 5)
6262        self.run_test(model, (x, y))
6263
6264    @skipIfUnsupportedMinOpsetVersion(13)
6265    def test_list_pop_nested(self):
6266        class ListModel(torch.nn.Module):
6267            def forward(self, x, y):
6268                res = []
6269                for i in range(x.size(0)):
6270                    for j in range(x.size(1)):
6271                        res += [torch.matmul(x[i][j], y)]
6272                        res.pop()
6273                    res += [torch.matmul(x[i][0], y)]
6274                return res
6275
6276        model = torch.jit.script(ListModel())
6277        x = torch.randn(4, 4, 3, 4)
6278        y = torch.randn(4, 5)
6279        self.run_test(model, (x, y))
6280
6281    @skipIfUnsupportedMinOpsetVersion(11)
6282    def test_list_del(self):
6283        class ListModel(torch.nn.Module):
6284            def forward(self, x, y):
6285                res = []
6286                for i in range(x.size(0)):
6287                    res += [torch.matmul(x[i], y)]
6288                del res[2]
6289                return res
6290
6291        model = torch.jit.script(ListModel())
6292        x = torch.randn(16, 3, 4)
6293        y = torch.randn(4, 5)
6294        self.run_test(model, (x, y))
6295
6296    @skipIfUnsupportedMinOpsetVersion(13)
6297    def test_list_del_nested(self):
6298        class ListModel(torch.nn.Module):
6299            def forward(self, x, y):
6300                res = []
6301                for i in range(x.size(0)):
6302                    for j in range(x.size(1)):
6303                        res += [torch.matmul(x[i][j], y)]
6304                        del res[i]
6305                    res += [torch.matmul(x[i][0], y)]
6306                return res
6307
6308        model = torch.jit.script(ListModel())
6309        x = torch.randn(4, 4, 3, 4)
6310        y = torch.randn(4, 5)
6311        self.run_test(model, (x, y))
6312
6313    @skipIfUnsupportedMinOpsetVersion(11)
6314    def test_list_set(self):
6315        class ListModel(torch.nn.Module):
6316            def forward(self, x, y):
6317                res = []
6318                for i in range(x.size(0)):
6319                    res.append(x[i])
6320                res[y] = x[y]
6321                return res
6322
6323        model = torch.jit.script(ListModel())
6324        x = torch.randn(12, 4)
6325        y = torch.tensor(2, dtype=torch.long)
6326        self.run_test(model, (x, y))
6327
6328    @skipIfUnsupportedMinOpsetVersion(13)
6329    def test_list_idx_sum(self):
6330        class ListModel(torch.nn.Module):
6331            def forward(self, x, y):
6332                indices = torch.arange(x.size(0))
6333                res = []
6334                for i in range(x.size(0)):
6335                    res.append(x[i])
6336                return res[torch.sum(indices[:y])]
6337
6338        model = torch.jit.script(ListModel())
6339        x = torch.randn(12, 4)
6340        y = torch.tensor(2, dtype=torch.long)
6341        self.run_test(model, (x, y))
6342
6343    @skipIfUnsupportedMinOpsetVersion(9)
6344    def test_tensor_factories(self):
6345        class TensorFactory(torch.nn.Module):
6346            def forward(self, x):
6347                return torch.zeros(x.size()) + torch.ones(x.size())
6348
6349        x = torch.randn(2, 3, 4)
6350        self.run_test(
6351            TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
6352        )
6353        self.run_test(TensorFactory(), x, remained_onnx_input_idx=[])
6354
6355    @skipIfUnsupportedMinOpsetVersion(9)
6356    def test_tensor_factories_script(self):
6357        class TensorFactory(torch.jit.ScriptModule):
6358            @torch.jit.script_method
6359            def forward(self, x):
6360                return torch.zeros(x.shape, dtype=torch.float) + torch.ones(
6361                    x.shape, dtype=torch.float
6362                )
6363
6364        x = torch.randn(2, 3, 4)
6365        self.run_test(
6366            TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
6367        )
6368        self.run_test(TensorFactory(), x, remained_onnx_input_idx=[])
6369
6370    @skipIfUnsupportedMinOpsetVersion(9)
6371    def test_tensor_like_factories_script(self):
6372        class TensorFactory(torch.jit.ScriptModule):
6373            @torch.jit.script_method
6374            def forward(self, x):
6375                zeros = torch.zeros_like(
6376                    x,
6377                    dtype=torch.float,
6378                    layout=torch.strided,
6379                    device=torch.device("cpu"),
6380                )
6381                ones = torch.ones_like(
6382                    x,
6383                    dtype=torch.float,
6384                    layout=torch.strided,
6385                    device=torch.device("cpu"),
6386                )
6387                return zeros + ones
6388
6389        x = torch.randn(2, 3, 4)
6390        self.run_test(
6391            TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
6392        )
6393        self.run_test(TensorFactory(), x, remained_onnx_input_idx=[])
6394
6395    @skipIfUnsupportedMinOpsetVersion(13)
6396    def test_tensor_split(self):
6397        class TensorSplitModel(torch.nn.Module):
6398            def forward(self, input):
6399                return (
6400                    input.tensor_split([1, 3]),
6401                    # test with output indexing.
6402                    input.tensor_split([2, 4])[0],
6403                    # test split on specific dim.
6404                    input.tensor_split([1, 3, 4], dim=-2),
6405                    # test split on specific dim and output indexing.
6406                    input.tensor_split([0, 2], dim=-2)[-1],
6407                    # test with an out-of-bounds end index (5).
6408                    input.tensor_split([2, 3, 5]),
6409                )
6410
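        # For the (5, 4, 3) input, tensor_split([1, 3]) slices dim 0 at the
        # given indices, yielding pieces of sizes 1, 2 and 2 (for reference).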
6411        self.run_test(TensorSplitModel(), torch.randn(5, 4, 3))
6412
6413    @skipIfUnsupportedMinOpsetVersion(13)
6414    def test_tensor_split_scalar(self):
6415        class TensorSplitModel(torch.nn.Module):
6416            def forward(self, x):
6417                return torch.tensor_split(x, x.size(1))
6418
6419        self.run_test(TensorSplitModel(), torch.randn(1, 2, 3))
6420
6421    @skipIfUnsupportedMinOpsetVersion(13)
6422    def test_tensor_split_dynamic_axes(self):
6423        class TensorSplitModel(torch.nn.Module):
6424            def forward(self, x):
6425                return x.tensor_split(1, dim=-1)
6426
6427        x = torch.randn(4, 384, 2)
6428        input_names = ["logits"]
6429        self.run_test(
6430            TensorSplitModel(),
6431            x,
6432            input_names=input_names,
6433            dynamic_axes={input_names[0]: {0: "batch"}},
6434        )
6435
6436    @skipIfUnsupportedMinOpsetVersion(9)
6437    def test_eye(self):
6438        class TensorFactory(torch.nn.Module):
6439            def forward(self, x):
6440                return (
6441                    torch.eye(x.size()[1], 3),
6442                    torch.eye(4, 4, dtype=torch.long),
6443                    torch.eye(x.size()[1], 2, dtype=torch.long),
6444                    torch.eye(x.shape[0]),
6445                    torch.eye(x.shape[0], dtype=torch.float64),
6446                )
6447
6448        x = torch.randn(2, 3, 4)
6449        another_x = torch.randn(5, 6, 7)
6450        self.run_test(
6451            TensorFactory(),
6452            x,
6453            additional_test_inputs=[another_x],
6454            input_names=["input_1"],
6455            dynamic_axes={"input_1": [0, 1, 2]},
6456        )
6457
6458    @skipIfUnsupportedMinOpsetVersion(13)
6459    def test_diagonal(self):
6460        class DiagonalModel(torch.nn.Module):
6461            def forward(self, x):
6462                return torch.diagonal(x)
6463
6464        x = torch.randn(2, 4, 5, 2)
6465        # Another test input to exercise dynamic behavior
6466        another_x = torch.randn(5, 6, 7, 8)
6467        self.run_test(
6468            DiagonalModel(),
6469            x,
6470            additional_test_inputs=[another_x],
6471            input_names=["input_1"],
6472            dynamic_axes={"input_1": [0, 1, 2, 3]},
6473        )
6474
6475        class DiagonalModelNegOffset(torch.nn.Module):
6476            def forward(self, x):
6477                return torch.diagonal(x, offset=-1)
6478
6479        x = torch.randn(2, 4, 5, 2)
6480        # Another test input to exercise dynamic behavior
6481        another_x = torch.randn(5, 6, 7, 8)
6482        self.run_test(
6483            DiagonalModelNegOffset(),
6484            x,
6485            additional_test_inputs=[another_x],
6486            input_names=["input_1"],
6487            dynamic_axes={"input_1": [0, 1, 2, 3]},
6488        )
6489
6490        class DiagonalModelPosOffset(torch.nn.Module):
6491            def forward(self, x):
6492                return torch.diagonal(x, offset=1)
6493
6494        x = torch.randn(2, 4, 5, 2)
6495        # Another test input to exercise dynamic behavior
6496        another_x = torch.randn(5, 6, 7, 8)
6497        self.run_test(
6498            DiagonalModelPosOffset(),
6499            x,
6500            additional_test_inputs=[another_x],
6501            input_names=["input_1"],
6502            dynamic_axes={"input_1": [0, 1, 2, 3]},
6503        )
6504
6505        class DiagonalModelWithDims(torch.nn.Module):
6506            def forward(self, x):
6507                return torch.diagonal(x, offset=-1, dim1=1, dim2=2)
6508
6509        x = torch.randn(2, 4, 5, 2)
6510        # Another test input to exercise dynamic behavior
6511        another_x = torch.randn(5, 6, 7, 8)
6512        self.run_test(
6513            DiagonalModelWithDims(),
6514            x,
6515            additional_test_inputs=[another_x],
6516            input_names=["input_1"],
6517            dynamic_axes={"input_1": [0, 1, 2, 3]},
6518        )
6519
6520        class DiagonalModelWithNegativeDims(torch.nn.Module):
6521            def forward(self, x):
6522                return torch.diagonal(x, offset=0, dim1=-2, dim2=-1)
6523
6524        x = torch.randn(2, 4, 5, 2)
6525        # Another test input to exercise dynamic behavior
6526        another_x = torch.randn(5, 6, 7, 8)
6527        self.run_test(
6528            DiagonalModelWithNegativeDims(),
6529            x,
6530            additional_test_inputs=[another_x],
6531            input_names=["input_1"],
6532            dynamic_axes={"input_1": [0, 1, 2, 3]},
6533        )
6534
6535        class DiagonalModelOffsetOverrun(torch.nn.Module):
6536            def forward(self, x):
6537                return torch.diagonal(x, offset=-2), torch.diagonal(x, offset=5)
6538
6539        x = torch.randn(2, 4, 5, 2)
6540        # Another test input to exercise dynamic behavior
6541        another_x = torch.randn(5, 6, 7, 8)
6542        self.run_test(
6543            DiagonalModelOffsetOverrun(),
6544            x,
6545            additional_test_inputs=[another_x],
6546            input_names=["input_1"],
6547            dynamic_axes={"input_1": [0, 1, 2, 3]},
6548        )
6549
6550    @skipIfUnsupportedMinOpsetVersion(9)
6551    def test_inplace_zero(self):
6552        class Zero_(torch.nn.Module):
6553            def forward(self, x):
6554                return x.zero_(), x
6555
6556        x = torch.randn(2, 3, 4)
6557        self.run_test(Zero_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6558        self.run_test(Zero_(), x, remained_onnx_input_idx=[])
6559
6560    @skipIfUnsupportedMinOpsetVersion(11)
6561    def test_inplace_zero_qkv(self):
6562        class Zero_(torch.nn.Module):
6563            def forward(self, x):
6564                return x[2:4].zero_()
6565
6566        x = torch.randn(24, 3, 4)
6567        self.run_test(Zero_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6568
6569    @skipIfUnsupportedMinOpsetVersion(9)
6570    def test_new_zeros(self):
6571        class Zero_(torch.nn.Module):
6572            def forward(self, x):
6573                return x.new_zeros(x.shape[1:2]), x.new_zeros(
6574                    x.shape[2:], dtype=torch.long
6575                )
6576
6577        x = torch.randn(2, 3, 4)
6578        self.run_test(Zero_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6579        self.run_test(Zero_(), x, remained_onnx_input_idx=[])
6580
6581    @skipIfUnsupportedMinOpsetVersion(9)
6582    def test_new_zeros_with_dtype(self):
6583        class MyModel(torch.nn.Module):
6584            def __init__(self) -> None:
6585                super().__init__()
6586                self.emb = torch.nn.Embedding(50, 64)
6587
6588            def forward(self, x):
6589                inp = x.new_zeros(x.shape)
6590                return self.emb(inp)
6591
6592        model = MyModel()
6593        x = torch.Tensor([[2, 5, 6], [3, 2, 5]]).to(torch.int64)
6594        self.run_test(model, x, input_names=["x"], dynamic_axes={"x": [0, 1]})
6595
6596    @skipIfUnsupportedMinOpsetVersion(9)
6597    def test_new_ones(self):
6598        class OnesModel(torch.nn.Module):
6599            def forward(self, x):
6600                return x.new_ones(x.shape[1:2]), x.new_ones(
6601                    x.shape[2:], dtype=torch.long
6602                )
6603
6604        x = torch.randn(2, 3, 4)
6605        self.run_test(OnesModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6606        self.run_test(OnesModel(), x, remained_onnx_input_idx=[])
6607
6608    @skipIfUnsupportedMinOpsetVersion(9)
6609    @skipScriptTest()  # torch.zeros/torch.ones with size tensor of dim != 0 not scriptable.
6610    def test_zeros_ones_with_tensor_input(self):
6611        class ZeroAndOnes(torch.nn.Module):
6612            def forward(self, x):
6613                return torch.zeros(x, 1), torch.ones(x, 1)
6614
6615        x = torch.tensor([2])
6616        self.run_test(ZeroAndOnes(), (x,))
6617
6618    @skipIfUnsupportedMinOpsetVersion(9)
6619    @skipShapeChecking
6620    def test_tolist(self):
6621        class List(torch.jit.ScriptModule):
6622            @torch.jit.script_method
6623            def forward(self, input):
6624                res: List[int] = input.tolist()
6625                return res
6626
6627        self.run_test(List(), (torch.randint(100, (1,)),))
6628
6629    @skipIfUnsupportedMinOpsetVersion(9)
6630    def test_list_pass(self):
6631        class Slice(torch.nn.Module):
6632            def forward(self, x, y):
6633                return x.new_zeros(x.shape[2:] + y.shape[1:])
6634
6635        x = torch.randn(2, 3, 4, 5)
6636        y = torch.randn(1, 2, 3, 4)
6637        self.run_test(
6638            Slice(),
6639            (x, y),
6640            input_names=["x", "y"],
6641            dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]},
6642        )
6643        self.run_test(Slice(), (x, y), remained_onnx_input_idx=[])
6644
6645        class Size(torch.nn.Module):
6646            def forward(self, x, y):
6647                return x.new_zeros(x.shape + y.shape)
6648
6649        x = torch.randn(2, 3, 4)
6650        y = torch.randn(1, 2, 3)
6651        self.run_test(
6652            Size(),
6653            (x, y),
6654            input_names=["x", "y"],
6655            dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
6656        )
6657        self.run_test(Size(), (x, y), remained_onnx_input_idx=[])
6658
6659        class Array(torch.nn.Module):
6660            def forward(self, x, y):
6661                arr1 = [x.shape[0], x.shape[1], 2]
6662                arr2 = [y.shape[0], y.shape[1]]
6663                return x.new_zeros(arr1 + arr2)
6664
6665        x = torch.randn(2, 3, 4)
6666        y = torch.randn(1, 2, 3)
6667        self.run_test(
6668            Array(),
6669            (x, y),
6670            input_names=["x", "y"],
6671            dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
6672        )
6673        self.run_test(Array(), (x, y), remained_onnx_input_idx=[])
6674
6675        class List(torch.nn.Module):
6676            def forward(self, x, y):
6677                l1 = list(x.shape)
6678                l2 = list(y.shape)
6679                return x.new_zeros(l1 + l2)
6680
6681        x = torch.randn(2, 3, 4)
6682        y = torch.randn(1, 2, 3)
6683        self.run_test(
6684            List(),
6685            (x, y),
6686            input_names=["x", "y"],
6687            dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
6688        )
6689        self.run_test(List(), (x, y), remained_onnx_input_idx=[])
6690
6691    @skipIfUnsupportedMinOpsetVersion(9)
6692    def test_new_empty(self):
6693        class Empty(torch.nn.Module):
6694            def forward(self, x):
6695                return (
6696                    x.new_empty(x.shape[0]).fill_(0),
6697                    x.new_empty(x.shape[0], dtype=torch.long) * 0,
6698                )
6699
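        # fill_(0) and * 0 make the uninitialized new_empty values
        # deterministic so they can be compared against ONNX Runtime.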
6700        x = torch.randn(2, 3, 4)
6701        self.run_test(Empty(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6702        self.run_test(Empty(), x, remained_onnx_input_idx=[])
6703
6704    @skipIfUnsupportedMinOpsetVersion(9)
6705    def test_new_full(self):
6706        class Full(torch.nn.Module):
6707            def forward(self, x):
6708                return x.new_full(x.shape[1:2], 5), x.new_full(
6709                    x.shape[0:1], 1.3, dtype=torch.long
6710                )
6711
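        # Note that with dtype=torch.long the fill value 1.3 is truncated to 1.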
6712        x = torch.randn(2, 3, 4)
6713        self.run_test(Full(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6714        self.run_test(Full(), x, remained_onnx_input_idx=[])
6715
6716    @skipIfUnsupportedMinOpsetVersion(9)
6717    def test_inplace_list(self):
6718        class Arithmetic(torch.jit.ScriptModule):
6719            @torch.jit.script_method
6720            def forward(self, x, y):
6721                return torch.cat([x.add_(3), y.fill_(0)])
6722
6723        x = torch.randn(2, 3)
6724        y = torch.randn(2, 3)
6725        self.run_test(
6726            Arithmetic(),
6727            (x, y),
6728            input_names=["x", "y"],
6729            dynamic_axes={"x": [0, 1], "y": [0, 1]},
6730        )
6731        self.run_test(Arithmetic(), (x, y), remained_onnx_input_idx=[0])
6732
6733    @skipIfUnsupportedMinOpsetVersion(9)
6734    def test_inplace_fill(self):
6735        class Fill_(torch.nn.Module):
6736            def forward(self, x):
6737                return x.fill_(3), x
6738
6739        x = torch.randn(2, 3, 4)
6740        self.run_test(Fill_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6741        self.run_test(Fill_(), x, remained_onnx_input_idx=[])
6742
6743    def test_inplace_arithmetic(self):
6744        class Arithmetic(torch.jit.ScriptModule):
6745            @torch.jit.script_method
6746            def forward(self, x, y):
6747                x.add_(3)
6748                y.mul_(x)
6749                return x, y
6750
6751        x = torch.randn(2, 3, 4)
6752        y = torch.randn(2, 3, 4)
6753        self.run_test(Arithmetic(), (x, y))
6754
6755    def test_inplace_arithmetic_half(self):
6756        class InplaceAddModel(torch.nn.Module):
6757            def forward(self, x, y):
6758                return x.add_(y)
6759
6760        class InplaceMulModel(torch.nn.Module):
6761            def forward(self, x, y):
6762                return x.mul_(y)
6763
6764        x = torch.randn(2, 2, dtype=torch.half)
6765        y = torch.randn(2, 2, dtype=torch.float)
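        # In-place ops keep x's fp16 dtype, hence the relaxed comparison
        # tolerances below.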
6766        self.run_test(InplaceAddModel(), (x, y), rtol=1e-2, atol=1e-2)
6767        self.run_test(InplaceMulModel(), (x, y), rtol=1e-2, atol=1e-2)
6768
6769    @skipIfUnsupportedMinOpsetVersion(9)
6770    def test_inplace_with_loop(self):
6771        class M(torch.nn.Module):
6772            def forward(self, x):
6773                a = torch.ones(
6774                    12,
6775                )
6776                for i in range(10):
6777                    a.add_(
6778                        torch.ones(
6779                            12,
6780                        )
6781                    )
6782                return a + x
6783
6784        m = M()
6785        x = torch.randn(
6786            12,
6787        )
6788        self.run_test(torch.jit.script(m), (x,))
6789
6790    @skipIfUnsupportedMinOpsetVersion(9)
6791    def test_inplace_with_loop_2(self):
6792        class M(torch.nn.Module):
6793            def forward(self, x):
6794                _bias = torch.ones(
6795                    12,
6796                )
6797                a = torch.ones(
6798                    12,
6799                )  # used in the loop, altered in place.
6800                a_ref = a  # aliases a; not used in the loop, but should be altered in place.
6801                b = x.clone()  # used in the loop, but reassigned rather than altered in place.
6802                b_ref = b  # not used in the loop, should not be altered.
6803                for i in range(10):
6804                    if i == 3:
6805                        for j in range(5):
6806                            a += _bias
6807                            _bias.add_(
6808                                torch.ones(
6809                                    12,
6810                                )
6811                            )
6812                            b = b + torch.ones(
6813                                12,
6814                            )
6815
6816                    _bias.add_(
6817                        torch.ones(
6818                            12,
6819                        )
6820                    )
6821                    a += _bias
6822                # TODO: value for a_ref is incorrect.
6823                # a_ref += torch.ones(12,)
6824                b_ref += torch.ones(
6825                    12,
6826                )
6827                return _bias + x, a, b, b_ref
6828
6829        m = M()
6830        x = torch.zeros(
6831            12,
6832        )
6833        self.run_test(torch.jit.script(m), (x,))
6834
6835    @skipIfUnsupportedMinOpsetVersion(11)
6836    def test_inplace_attr_with_loop(self):
6837        class M(torch.nn.Module):
6838            def __init__(self) -> None:
6839                super().__init__()
6840                self._bias = torch.arange(
6841                    12,
6842                )
6843
6844            def forward(self, x):
6845                self._bias = torch.arange(
6846                    12,
6847                )
6848                for i in range(10):
6849                    if i == 3:
6850                        for j in range(5):
6851                            self._bias += torch.arange(
6852                                12,
6853                            )
6854                return self._bias + x
6855
6856        m = M()
6857        x = torch.zeros(
6858            12,
6859        )
6860        self.run_test(torch.jit.script(m), (x,))
6861
6862    @skipIfUnsupportedMinOpsetVersion(11)
6863    def test_inplace_attr_copy_with_loop(self):
6864        class M(torch.nn.Module):
6865            def __init__(self) -> None:
6866                super().__init__()
6867                self._bias = torch.arange(
6868                    12,
6869                )
6870
6871            def forward(self, x):
6872                self._bias = torch.arange(
6873                    12,
6874                )
6875                for i in range(10):
6876                    if i == 3:
6877                        for j in range(5):
6878                            self._bias.copy_(
6879                                torch.arange(
6880                                    12,
6881                                )
6882                            )
6883                        self._bias.copy_(
6884                            self._bias
6885                            + torch.arange(
6886                                12,
6887                            )
6888                        )
6889
6890                    self._bias.copy_(
6891                        self._bias
6892                        + torch.arange(
6893                            12,
6894                        )
6895                    )
6896                return self._bias + x
6897
6898        m = M()
6899        x = torch.zeros(
6900            12,
6901        )
6902        self.run_test(torch.jit.script(m), (x,))
6903
6904    @skipIfUnsupportedMinOpsetVersion(14)  # Needs onnx::Identity for sequences, added in opset 14
6905    def test_inplace_sequence_with_loop(self):
6906        class M(torch.nn.Module):
6907            def process(self, beam_hyps: List[Tensor], done: Tensor, x):
6908                batch_size = x.shape[0]
6909                for i in range(batch_size):
6910                    if done[i]:
6911                        continue
6912
6913                    beam_idx = 0
6914                    for _, token in enumerate(x[i]):
6915                        beam_hyps.append(token)
6916                        beam_idx += 1
6917
6918                        if beam_idx == 6:
6919                            break
6920
6921                    done[i] = len(beam_hyps) > 4
6922
6923                return beam_hyps, done
6924
6925            def forward(self, x):
6926                beam_hyps: List[Tensor] = []
6927                batch_size = x.shape[0]
6928                cur_len = 0
6929                max_len = x.shape[1]
6930                done = torch.zeros(batch_size, dtype=torch.bool)
6931                while cur_len < max_len:
6932                    beam_hyps, done = self.process(beam_hyps, done, x[:, 0, :])
6933                    cur_len = cur_len + 1
6934
6935                return beam_hyps
6936
6937        m = torch.jit.script(M())
6938        x = torch.randn(8, 4, 3)
6939        self.run_test(m, (x,))
6940
6941    @skipScriptTest()  # Sort with dynamic dim not supported in ONNX
6942    def test_sort(self):
6943        class SortModel(torch.nn.Module):
6944            def forward(self, x):
6945                out = []
6946                for i in range(-2, 2):
6947                    out.append(torch.sort(x, dim=i, descending=True))
6948                return out
6949
6950        x = torch.randn(3, 4)
6951        self.run_test(SortModel(), x)
6952
6953    @skipIfUnsupportedMinOpsetVersion(11)
6954    @skipScriptTest()  # Sort with dynamic dim not supported in ONNX
6955    def test_sort_ascending(self):
6956        class SortModel(torch.nn.Module):
6957            def forward(self, x):
6958                out = []
6959                for i in range(-2, 2):
6960                    out.append(torch.sort(x, dim=i, descending=False))
6961                return out
6962
6963        x = torch.randn(3, 4)
6964        self.run_test(SortModel(), x)
6965
6966    @skipIfUnsupportedMinOpsetVersion(11)
6967    def test_argsort(self):
6968        class ArgSortModel(torch.nn.Module):
6969            def forward(self, x):
6970                return torch.argsort(x, dim=1, descending=False)
6971
6972        x = torch.randn(3, 4)
6973        self.run_test(ArgSortModel(), x)
6974
6975    @skipIfUnsupportedMinOpsetVersion(9)
6976    def test_masked_fill(self):
6977        class MaskedFillModel(torch.nn.Module):
6978            def forward(self, x):
6979                mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.bool)
6980                return x.masked_fill(mask, 2)
6981
6982        x = torch.zeros(4, 2, 3, requires_grad=True)
6983        self.run_test(MaskedFillModel(), x)
6984
6985        class MaskedFillModel2(torch.nn.Module):
6986            def forward(self, x):
6987                return x.masked_fill(x > 3, -1)
6988
6989        x = torch.arange(16).view(2, 2, 4).to(torch.float32)
6990        self.run_test(MaskedFillModel2(), x)
6991
6992    @skipIfUnsupportedMinOpsetVersion(9)
6993    def test_masked_fill_inplace(self):
6994        class MaskedFillModel(torch.jit.ScriptModule):
6995            @torch.jit.script_method
6996            def forward(self, x):
6997                mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.bool)
6998                x.masked_fill_(mask, 2)
6999                return x
7000
7001        x = torch.zeros(4, 2, 3, requires_grad=True)
7002        self.run_test(MaskedFillModel(), x)
7003
7004        class MaskedFillModel2(torch.jit.ScriptModule):
7005            @torch.jit.script_method
7006            def forward(self, x):
7007                x.masked_fill_(x > 3, -1)
7008                return x
7009
7010        x = torch.arange(16).view(2, 2, 4).to(torch.float32)
7011        self.run_test(MaskedFillModel2(), x)
7012
7013    @skipIfUnsupportedMinOpsetVersion(11)
7014    def test_masked_scatter(self):
7015        class MaskedScatterModel(torch.nn.Module):
7016            def forward(self, x):
7017                return torch.masked_scatter(x, x.ge(0.5), torch.ones(100, 100) * 5)
7018
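        # masked_scatter consumes source elements in order and only as many as
        # the mask selects, so the oversized (100, 100) source is valid.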
7019        x = torch.randn(3, 4, 5, requires_grad=True)
7020        self.run_test(MaskedScatterModel(), x)
7021
7022    @skipIfUnsupportedMinOpsetVersion(11)
7023    def test_masked_select(self):
7024        class MaskedSelectModel(torch.nn.Module):
7025            def forward(self, x):
7026                return torch.masked_select(x, x.ge(0.5))
7027
7028        x = torch.randn(3, 4, 5, requires_grad=True)
7029        self.run_test(MaskedSelectModel(), x)
7030
7031    @skipIfUnsupportedMinOpsetVersion(11)
7032    def test_index_put_to_masked_fill(self):
7033        class MaskedFillModel(torch.nn.Module):
7034            def forward(self, input_mask, some_const):
7035                mask = input_mask.clone()
7036                mask[mask != some_const] = 1
7037                mask[mask == some_const] = 0
7038                return mask
7039
7040        mask = torch.randn(2, 2, 2, requires_grad=True)
7041        constant = torch.tensor(5, dtype=torch.float)
7042        self.run_test(MaskedFillModel(), (mask, constant))
7043
7044    @skipIfUnsupportedMinOpsetVersion(11)
7045    def test_index_put_to_masked_scatter(self):
7046        class MaskedScatterModel(torch.nn.Module):
7047            def forward(self, input_mask, some_const):
7048                mask = input_mask.clone()
7049                mask[mask != some_const] = torch.ones(8)
7050                return mask
7051
7052        mask = torch.randn(2, 2, 2, requires_grad=True)
7053        constant = torch.tensor(5, dtype=torch.float)
7054        self.run_test(MaskedScatterModel(), (mask, constant))
7055
7056    @skipIfUnsupportedMinOpsetVersion(11)
7057    def test_index_put_with_1d_mask_to_masked_scatter(self):
7058        class MaskedScatterModel(torch.nn.Module):
7059            def forward(self, tensor, mask, some_const):
7060                tensor[mask] = some_const
7061                return tensor
7062
7063        mask = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool)
7064        tensor = torch.randn(8, 4, 5, requires_grad=True)
7065        some_const = torch.randn(4, 4, 5, dtype=torch.float)
7066        self.run_test(MaskedScatterModel(), (tensor, mask, some_const))
7067
7068    @skipIfUnsupportedMinOpsetVersion(9)
7069    def test_pixel_shuffle(self):
7070        class PixelShuffle(torch.nn.Module):
7071            def forward(self, x):
7072                return torch.pixel_shuffle(x, upscale_factor=2)
7073
7074        x = torch.randn(2, 16, 4, 3, requires_grad=True)
7075        y = torch.randn(4, 32, 8, 4, requires_grad=True)
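        # pixel_shuffle with upscale_factor=2 maps (2, 16, 4, 3) to
        # (2, 4, 8, 6): channels / 4, spatial dims * 2.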
7076        self.run_test(PixelShuffle(), x)
7077        self.run_test(
7078            PixelShuffle(),
7079            x,
7080            input_names=["x"],
7081            dynamic_axes={"x": [0, 1, 2, 3]},
7082            additional_test_inputs=[y],
7083        )
7084
7085    @skipIfUnsupportedMinOpsetVersion(9)
7086    def test_pixel_unshuffle(self):
7087        class PixelUnshuffle(torch.nn.Module):
7088            def forward(self, x):
7089                return torch.pixel_unshuffle(x, downscale_factor=2)
7090
7091        x = torch.randn(2, 16, 4, 6, requires_grad=True)
7092        y = torch.randn(4, 32, 8, 4, requires_grad=True)
7093        self.run_test(PixelUnshuffle(), x)
7094        self.run_test(
7095            PixelUnshuffle(),
7096            x,
7097            input_names=["x"],
7098            dynamic_axes={"x": [0, 1, 2, 3]},
7099            additional_test_inputs=[y],
7100        )
7101
7102    @skipIfUnsupportedMinOpsetVersion(9)
7103    def test_reciprocal(self):
7104        class ReciprocalModel(torch.nn.Module):
7105            def forward(self, x):
7106                return torch.reciprocal(x)
7107
7108        model = ReciprocalModel()
7109        x = torch.tensor([2, 4])
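        # torch.reciprocal promotes integer inputs to a floating result, so
        # the long input below also exercises the exporter's type handling.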
7110        self.run_test(model, x.to(torch.long))
7111        self.run_test(model, x.to(torch.float))
7112        self.run_test(model, x.to(torch.double))
7113
7114    @skipIfUnsupportedMinOpsetVersion(9)
7115    def test_scalar_type(self):
7116        class ArithmeticModel(torch.nn.Module):
7117            def forward(self, x):
7118                return x.size(0) * 2 * x, 2 - x
7119
7120        x = torch.ones(2, 3, dtype=torch.float32)
7121        self.run_test(ArithmeticModel(), x)
7122
7123        class ComparisonModel(torch.nn.Module):
7124            def forward(self, x, y):
7125                a = torch.tensor([12.0])
7126                return x.lt(1.5) & y.le(2) & x.le(1), x.gt(y), x.lt(y), a.ge(x.size(0))
7127
7128        x = torch.ones(2, 3, dtype=torch.int32)
7129        y = torch.ones(2, 3, dtype=torch.float32)
7130        self.run_test(ComparisonModel(), (x, y))
7131
7132        class MatMulModel(torch.nn.Module):
7133            def forward(self, x):
7134                return torch.mm(x, x) + x + torch.mm(x, x) + x
7135
7136        x = torch.ones(3, 3)
7137        self.run_test(MatMulModel(), x)
7138
7139        class AddMMModel(torch.nn.Module):
7140            def forward(self, x):
7141                return torch.mm(x, x) + x
7142
7143        x = torch.ones(3, 3)
7144        self.run_test(AddMMModel(), x)
7145
7146        class FullModel(torch.nn.Module):
7147            # add is used when exporting full with a tensor fill value
7148            def forward(self, x):
7149                return torch.full((3, 4), x)
7150
7151        x = torch.tensor(12.0)
7152        self.run_test(FullModel(), x)
7153
7154        class CatModel(torch.nn.Module):
7155            def forward(self, fp16, fp32):
7156                return torch.cat([fp16, fp32])
7157
7158        fp16 = Tensor([0.5])
7159        fp16 = fp16.half()
7160        fp32 = Tensor([1.5])
7161        self.run_test(CatModel(), (fp16, fp32))
7162
7163    @skipIfUnsupportedMinOpsetVersion(9)
7164    def test_scalar_type_does_not_trigger_upcast_type_promotion(self):
7165        class DoNotUpcastModel(torch.nn.Module):
7166            def forward(self, x):
7167                scale = x.size()[-1] ** -0.5
7168                # 'scale' is exported as onnx float32 rank 0 tensor.
7169                # The following 'Mul' should NOT be promoted to float32.
7170                return x * scale
7171
7172        x = torch.ones(2, 3, dtype=torch.float16)
7173        self.run_test(DoNotUpcastModel(), x)
7174
7175    @skipIfUnsupportedMinOpsetVersion(9)
7176    def test_scalar_type_promotion_onnx_where_two_prim_const(self):
7177        class TwoPrimConstCastWhereModel(torch.nn.Module):
7178            def forward(self, c):
7179                return torch.where(c, 0, 1.0)
7180
7181        c = torch.ones(8, dtype=torch.bool)
7182        self.run_test(TwoPrimConstCastWhereModel(), (c))
7183
7184    @skipIfUnsupportedMinOpsetVersion(9)
7185    def test_scalar_type_promotion_onnx_where_one_prim_const(self):
7186        class OnePrimConstCastWhereModel(torch.nn.Module):
7187            def forward(self, c, x):
7188                return torch.where(c, x, 1.0)
7189
7190        c = torch.ones(8, dtype=torch.bool)
7191        x = torch.ones(8, dtype=torch.float16)
7192        self.run_test(OnePrimConstCastWhereModel(), (c, x))
7193
7194    @skipIfUnsupportedMinOpsetVersion(9)
7195    def test_scalar_type_promotion_onnx_where_one_tensor_const(self):
7196        class OneTensorConstCastWhereModel(torch.nn.Module):
7197            def forward(self, c, x):
7198                return torch.where(c, x, torch.ones(size=(), dtype=torch.float64))
7199
7200        c = torch.ones(8, dtype=torch.bool)
7201        x = torch.ones(8, dtype=torch.float16)
7202        self.run_test(OneTensorConstCastWhereModel(), (c, x))
7203
7204    @skipIfUnsupportedMinOpsetVersion(9)
7205    def test_scalar_type_upcast_type_promotion_onnx_where_no_const(self):
7206        class OnnxWhereUpcastModel(torch.nn.Module):
7207            def forward(self, c, x, y):
7208                return torch.where(c, x, y)
7209
7210        c = torch.ones(8, dtype=torch.bool)
7211        x = torch.ones(8, dtype=torch.float16)
7212        y = torch.ones(8, dtype=torch.float32)
7213
7214        self.run_test(OnnxWhereUpcastModel(), (c, x, y))
7215
7216    @skipIfUnsupportedMinOpsetVersion(9)
7217    def test_full_like(self):
7218        class FullLikeModel(torch.nn.Module):
7219            def forward(self, x):
7220                return torch.full_like(x, 1.3, dtype=torch.int)
7221
7222        x = torch.tensor(12)
7223        self.run_test(FullLikeModel(), x)
7224
7225    @skipIfUnsupportedMinOpsetVersion(9)
7226    @skipDtypeChecking
7227    def test_full_like_value(self):
7228        class FullLikeModel(torch.nn.Module):
7229            def forward(self, x, y):
7230                out = y + 2
7231                return torch.full_like(x, out)
7232
7233        x = torch.tensor(12)
7234        y = torch.tensor(2)
7235        self.run_test(FullLikeModel(), (x, y))
7236
7237    def test_l1_norm(self):
7238        class NormModel(torch.nn.Module):
7239            def forward(self, x):
7240                return torch.norm(x, p=1, dim=-1, keepdim=False)
7241
7242        x = torch.randn(4, 2, 3, requires_grad=True)
7243        self.run_test(NormModel(), x)
7244
7245    def test_l2_norm(self):
7246        class NormModel(torch.nn.Module):
7247            def forward(self, x):
7248                return torch.norm(x, p=2, dim=-2, keepdim=False)
7249
7250        x = torch.randn(4, 2, 3, requires_grad=True)
7251        self.run_test(NormModel(), x)
7252
7253    def test_frobenius_norm(self):
7254        class NormModel(torch.nn.Module):
7255            def forward(self, x):
7256                return torch.norm(x, p="fro", dim=0, keepdim=False)
7257
7258        x = torch.randn(4, 2, 3, requires_grad=True)
7259        self.run_test(NormModel(), x)
7260
7261    def test_frobenius_norm_keepdim(self):
7262        class NormModel(torch.nn.Module):
7263            def forward(self, x):
7264                return torch.norm(x, p="fro", dim=(0, 1), keepdim=True)
7265
7266        x = torch.randn(4, 2, 3, requires_grad=True)
7267        self.run_test(NormModel(), x)
7268
7269    def test_unfold(self):
7270        class UnfoldModel(torch.nn.Module):
7271            def forward(self, x):
7272                return x.unfold(dimension=2, size=2, step=2)
7273
7274        x = torch.randn(4, 2, 3, requires_grad=True)
7275        y = torch.randn(2, 1, 3, requires_grad=True)
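        # For reference, (4, 2, 3) unfolded along dim 2 with size=2, step=2
        # yields a single window: shape (4, 2, 1, 2).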
7276        self.run_test(
7277            UnfoldModel(),
7278            x,
7279            dynamic_axes={"x": [0, 1]},
7280            input_names=["x"],
7281            additional_test_inputs=[y],
7282        )
7283
7284    def test_unfold_infer_shape(self):
7285        class UnfoldModule(torch.jit.ScriptModule):
7286            def __init__(self) -> None:
7287                super().__init__()
7288                self.conv = torch.nn.Conv1d(3, 1, 3, stride=2)
7289
7290            @torch.jit.script_method
7291            def forward(self, x):
7292                x = self.conv(x)
7293                return x.unfold(dimension=2, size=2, step=2)
7294
7295        x = torch.randn(32, 3, 64)
7296        self.run_test(UnfoldModule(), x)
7297
7298    @skipIfUnsupportedMinOpsetVersion(12)
7299    def test_unfold_dynamic_inputs(self):
7300        class UnfoldModel(torch.nn.Module):
7301            def forward(self, x):
7302                return x.unfold(dimension=2, size=x.shape[1], step=x.shape[1] - 1)
7303
7304        x = torch.randn(4, 2, 4, requires_grad=True)
7305        self.run_test(UnfoldModel(), x)
7306
7307        class UnfoldModel(torch.nn.Module):
7308            def forward(self, x):
7309                return x.unfold(dimension=2, size=x.shape[1], step=1)
7310
7311        x = torch.randn(4, 2, 4, requires_grad=True)
7312        self.run_test(UnfoldModel(), x)
7313
7314    @skipIfUnsupportedMinOpsetVersion(9)  # MatMul with long inputs is supported since ONNX opset 9.
7315    def test_mv(self):
7316        class MatmulModel(torch.nn.Module):
7317            def forward(self, input, other):
7318                return torch.mv(input, other)
7319
7320        x = torch.randn(4, 5, requires_grad=True)
7321        y = torch.randn(5, requires_grad=True)
7322        self.run_test(MatmulModel(), (x, y))
7323
7324        x = torch.randint(10, (4, 5))
7325        y = torch.randint(10, (5,))
7326        self.run_test(MatmulModel(), (x, y))
7327
7328    @skipIfUnsupportedMinOpsetVersion(9)  # MatMul with long inputs is supported since ONNX opset 9.
7329    def test_dot(self):
7330        class MatmulModel(torch.nn.Module):
7331            def forward(self, input, other):
7332                return torch.dot(input, other)
7333
7334        x = torch.randn(5, requires_grad=True)
7335        y = torch.randn(5, requires_grad=True)
7336        self.run_test(MatmulModel(), (x, y))
7337
7338        x = torch.randint(10, (5,))
7339        y = torch.randint(10, (5,))
7340        self.run_test(MatmulModel(), (x, y))
7341
7342    @skipScriptTest()  # SpectralNorm not TorchScript compatible.
7343    def test_spectral_norm(self):
7344        m = torch.nn.utils.spectral_norm(torch.nn.Linear(2, 4))
7345
7346        x = torch.randn(6, 2)
7347        self.run_test(m, (x,))
7348
7349    def test_prelu(self):
7350        class PReluModel(torch.nn.Module):
7351            def __init__(self) -> None:
7352                super().__init__()
7353                self.prelu = torch.nn.PReLU()
7354
7355            def forward(self, x):
7356                return self.prelu(x)
7357
7358        x = torch.randn(2, 3, 4)
7359        y = torch.randn(2, 4, 5)
7360        self.run_test(
7361            PReluModel(),
7362            x,
7363            input_names=["x"],
7364            dynamic_axes={"x": [1, 2]},
7365            additional_test_inputs=[y],
7366        )
7367
7368    def test_prelu_scalar(self):
7369        x = torch.scalar_tensor(1.0)
7370        self.run_test(torch.nn.PReLU(), x, input_names=["x"])
7371
7372    def test_relu6(self):
7373        class Relu6Model(torch.nn.Module):
7374            def __init__(self) -> None:
7375                super().__init__()
7376                self.relu6 = torch.nn.ReLU6()
7377
7378            def forward(self, x):
7379                return self.relu6(x)
7380
7381        x = torch.randn(2, 3, 4) * 100.0
7382        y = torch.randn(2, 4, 5) * 100.0
7383        self.run_test(
7384            Relu6Model(),
7385            x,
7386            input_names=["x"],
7387            dynamic_axes={"x": [1, 2]},
7388            additional_test_inputs=[y],
7389        )
7390
7391    def test_silu(self):
7392        class SiLUModel(torch.nn.Module):
7393            def __init__(self) -> None:
7394                super().__init__()
7395                self.silu = torch.nn.SiLU()
7396
7397            def forward(self, x):
7398                return self.silu(x)
7399
7400        x = torch.randn(2, 3, 4)
7401        self.run_test(SiLUModel(), x)
7402
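    # torch.tril/torch.triu map to the ONNX Trilu operator, which was introduced in opset 14.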
7403    @skipIfUnsupportedMinOpsetVersion(14)
7404    def test_tril(self):
7405        class TrilModel(torch.nn.Module):
7406            def forward(self, x):
7407                return torch.tril(x)
7408
7409        x = torch.randn(2, 3, 4)
7410        self.run_test(TrilModel(), x)
7411
7412        class TrilModelWithDiagonal(torch.nn.Module):
7413            def forward(self, x):
7414                return torch.tril(x, diagonal=1)
7415
7416        x = torch.randn(2, 3, 4)
7417        self.run_test(TrilModelWithDiagonal(), x)
7418
7419        class TrilModelWithNegDiagonal(torch.nn.Module):
7420            def forward(self, x):
7421                return torch.tril(x, diagonal=-1)
7422
7423        x = torch.randn(2, 3, 4)
7424        self.run_test(TrilModelWithNegDiagonal(), x)
7425
7426        class TrilModelWithDiagonalInput(torch.nn.Module):
7427            def forward(self, x, diagonal: int):
7428                return torch.tril(x, diagonal=diagonal)
7429
7430        x = torch.randn(2, 3, 4)
7431        self.run_test(TrilModelWithDiagonalInput(), (x, 5))
7432
7433    @skipIfUnsupportedMinOpsetVersion(14)
7434    def test_triu(self):
7435        class TriuModel(torch.nn.Module):
7436            def forward(self, x):
7437                return torch.triu(x)
7438
7439        x = torch.randn(2, 3, 4)
7440        self.run_test(TriuModel(), x)
7441
7442        class TriuModelWithDiagonal(torch.nn.Module):
7443            def forward(self, x):
7444                return torch.triu(x, diagonal=1)
7445
7446        x = torch.randn(2, 3, 4)
7447        self.run_test(TriuModelWithDiagonal(), x)
7448
7449        class TriuModelWithNegDiagonal(torch.nn.Module):
7450            def forward(self, x):
7451                return torch.triu(x, diagonal=-1)
7452
7453        x = torch.randn(2, 3, 4)
7454        self.run_test(TriuModelWithNegDiagonal(), x)
7455
7456        class TriuModelWithDiagonalInput(torch.nn.Module):
7457            def forward(self, x, diagonal: int):
7458                return torch.triu(x, diagonal=diagonal)
7459
7460        x = torch.randn(2, 3, 4)
7461        self.run_test(TriuModelWithDiagonalInput(), (x, 5))
7462
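    # Mish(x) = x * tanh(softplus(x)); ONNX has no dedicated Mish op, so the export
    # relies on a decomposition into these primitives.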
7463    def test_mish(self):
7464        class MishModel(torch.nn.Module):
7465            def __init__(self) -> None:
7466                super().__init__()
7467                self.mish = torch.nn.Mish()
7468
7469            def forward(self, x):
7470                return self.mish(x)
7471
7472        x = torch.randn(2, 3, 4)
7473        self.run_test(MishModel(), x)
7474
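    # torch.remainder follows Python modulo semantics: the result takes the sign of the divisor.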
7475    def test_remainder(self):
7476        class RemainderModel(torch.nn.Module):
7477            def forward(self, input, other):
7478                return torch.remainder(input, other)
7479
7480        x = torch.randn(4, 2, 3)
7481        y = torch.randn(1, 2, 1)
7482        self.run_test(RemainderModel(), (x, y))
7483
7484        x = torch.tensor([7, 6, -7, -6], dtype=torch.long)
7485        y = torch.tensor([2], dtype=torch.long)
7486        self.run_test(RemainderModel(), (x, y))
7487
7488        x = x.to(torch.float)
7489        self.run_test(RemainderModel(), (x, y))
7490
7491        y = y.to(torch.float)
7492        self.run_test(RemainderModel(), (x, y))
7493
7494        x = x.to(torch.int32)
7495        self.run_test(RemainderModel(), (x, y))
7496
7497    def test_remainder_scalar(self):
7498        class RemainderModel(torch.nn.Module):
7499            def __init__(self, scalar=2.55):
7500                super().__init__()
7501                self.scalar = scalar
7502
7503            def forward(self, input):
7504                return torch.remainder(input, self.scalar)
7505
7506        x = torch.randint(10, (2, 3))
7507        self.run_test(RemainderModel(), x)
7508
7509        x = torch.tensor([7, 6, -7, -6], dtype=torch.long)
7510        self.run_test(RemainderModel(2), x)
7511
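    # torch.fmod matches C's fmod (the result takes the sign of the dividend);
    # ONNX Mod with fmod=1 requires opset 10.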
7512    @skipIfUnsupportedMinOpsetVersion(10)
7513    def test_fmod(self):
7514        class FModModel(torch.nn.Module):
7515            def forward(self, input, other):
7516                return torch.fmod(input, other)
7517
7518        x = torch.randn(4, 2, 3)
7519        y = torch.randn(1, 2, 1)
7520        self.run_test(FModModel(), (x, y))
7521
7522    @skipIfUnsupportedMinOpsetVersion(10)
7523    def test_fmod_scalar(self):
7524        class FModModel(torch.nn.Module):
7525            def forward(self, input):
7526                return torch.fmod(input, 2.55)
7527
7528        x = torch.randint(10, (2, 3))
7529        self.run_test(FModModel(), x)
7530
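    # GLU halves the input along the last dimension and gates one half with the sigmoid of the other.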
7531    @skipIfUnsupportedMinOpsetVersion(9)
7532    def test_glu(self):
7533        class GluModel(torch.nn.Module):
7534            def forward(self, x):
7535                return torch.nn.functional.glu(x)
7536
7537        x = torch.randn(2, 4, 5, 6, requires_grad=True)
7538        self.run_test(GluModel(), x)
7539
7540    @skipIfUnsupportedMinOpsetVersion(9)
7541    def test_gelu(self):
7542        class GeluModel(torch.nn.Module):
7543            def forward(self, x):
7544                return torch.nn.functional.gelu(x, approximate="none")
7545
7546        x = torch.randn(2, 4, 5, 6, requires_grad=True)
7547        self.run_test(GeluModel(), x)
7548
7549    @skipIfUnsupportedMinOpsetVersion(9)
7550    def test_tanh_gelu(self):
7551        class GeluModel(torch.nn.Module):
7552            def forward(self, x):
7553                return torch.nn.functional.gelu(x, approximate="tanh")
7554
7555        x = torch.randn(2, 4, 5, 6, requires_grad=True)
7556        self.run_test(GeluModel(), x)
7557
7558    def test_add_inplace(self):
7559        class InplaceAddModel(torch.nn.Module):
7560            def forward(self, x):
7561                x += 12
7562                return x
7563
7564        x = torch.randn(4, 2, 3, requires_grad=True)
7565        self.run_test(InplaceAddModel(), x)
7566
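    # addcmul computes input + value * tensor1 * tensor2 elementwise, with broadcasting.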
7567    def test_addcmul(self):
7568        class AddcmulModel(torch.nn.Module):
7569            def forward(self, x, t1, t2):
7570                return torch.addcmul(x, t1, t2), torch.addcmul(x, t1, t2, value=2.2)
7571
7572        x = torch.randn(1, 3)
7573        t1 = torch.randn(3, 1)
7574        t2 = torch.randn(1, 3)
7575        self.run_test(AddcmulModel(), (x, t1, t2))
7576
7577    def test_rsqrt(self):
7578        class RsqrtModel(torch.nn.Module):
7579            def forward(self, x):
7580                return x.rsqrt()
7581
7582        x = torch.randn(4, 2, 3, requires_grad=True, dtype=torch.float64)
7583        self.run_test(RsqrtModel(), x)
7584
7585    def test_rsqrt_zeros(self):
7586        class RsqrtModel(torch.nn.Module):
7587            def forward(self, x):
7588                return x.rsqrt()
7589
7590        x = torch.zeros(4, 2, 3, requires_grad=True, dtype=torch.float64)
7591        self.run_test(RsqrtModel(), x)
7592
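    # The ONNX Unique operator was introduced in opset 11.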
7593    @skipIfUnsupportedMinOpsetVersion(11)
7594    def test_unique(self):
7595        class UniqueModel(torch.nn.Module):
7596            def forward(self, x):
7597                return torch.unique(
7598                    x, sorted=True, return_inverse=False, return_counts=True
7599                )
7600
7601        x = torch.tensor([1, 3, 2, 3], dtype=torch.long)
7602        self.run_test(UniqueModel(), x)
7603
7604    @skipIfUnsupportedMinOpsetVersion(11)
7605    def test_unique_along_dim(self):
7606        class UniqueModel(torch.nn.Module):
7607            def forward(self, x):
7608                return torch.unique(
7609                    x, dim=0, sorted=True, return_inverse=True, return_counts=False
7610                )
7611
7612        x = torch.tensor([1, 3, 2, 3], dtype=torch.long)
7613        self.run_test(UniqueModel(), x)
7614
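    # The ONNX CumSum operator was introduced in opset 11.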
7615    @skipIfUnsupportedMinOpsetVersion(11)
7616    def test_cumsum(self):
7617        class CumSum(torch.nn.Module):
7618            def forward(self, input):
7619                return torch.cumsum(input, dim=0)
7620
7621        x = torch.randn(2, 3, 4)
7622        model = CumSum()
7623        self.run_test(model, x)
7624
7625    @skipIfUnsupportedMinOpsetVersion(11)
7626    def test_cumsum_with_cast(self):
7627        class CumSum(torch.nn.Module):
7628            def forward(self, input):
7629                return torch.cumsum(input, dim=0, dtype=torch.float32)
7630
7631        model = CumSum()
7632        x = torch.tensor([2, 3, 4], dtype=torch.int32)
7633        self.run_test(model, x)
7634        x = torch.tensor([False, True, True])
7635        self.run_test(model, x)
7636
7637    @skipScriptTest()  # scripting errors when propagating assigned input shapes
7638    @skipIfUnsupportedMinOpsetVersion(10)
7639    def test_embedding_bag(self):
7640        model = torch.nn.EmbeddingBag(10, 5, mode="sum", scale_grad_by_freq=True)
7641        input = torch.randint(10, (7,))
7642        offset = torch.tensor([0, 2, 5, 6])
7643        self.run_test(model, (input, offset))
7644
7645        model = torch.nn.EmbeddingBag(10, 5, mode="sum", include_last_offset=True)
7646        input = torch.randint(10, (7,))
7647        offset = torch.tensor([0, 2, 5, 6])
7648        self.run_test(model, (input, offset))
7649
7650        model = torch.nn.EmbeddingBag(10, 5, mode="max")
7651        input = torch.randint(10, (7, 5))
7652        self.run_test(model, input)
7653
7654    @skipIfUnsupportedMinOpsetVersion(11)
7655    def test_embedding_bag_1d_per_sample_weights(self):
7656        class EmbeddingModel(torch.nn.Module):
7657            def forward(self, embedding_matrix, input, offset, weights):
7658                return torch.nn.functional.embedding_bag(
7659                    input,
7660                    embedding_matrix,
7661                    offsets=offset,
7662                    mode="sum",
7663                    per_sample_weights=weights,
7664                )
7665
7666        model = EmbeddingModel()
7667        x = torch.randint(7, (6,))
7668        w = torch.randn(
7669            6,
7670        )
7671        offset = torch.tensor([0, 2, 5])
7672        embedding_matrix = torch.rand(10, 15)
7673        self.run_test(model, (embedding_matrix, x, offset, w))
7674
7675    @skipIfUnsupportedMinOpsetVersion(11)
7676    @unittest.skip(
7677        "This test is broken with onnxruntime 1.17: "
7678        "running it under onnxruntime 1.17.0 fails with the following error: "
7679        "FAIL : Non-zero status code returned while running If node. "
7680        "Name:'/If' Status Message: if.cc:253 Compute "
7681        "If nodes condition input must have exactly one element. "
7682        "https://github.com/pytorch/pytorch/issues/119442"
7683    )
7684    def test_embedding_bag_2d_per_sample_weights(self):
7685        class EmbeddingModel(torch.nn.Module):
7686            def forward(self, embedding_matrix, input, weights):
7687                return torch.nn.functional.embedding_bag(
7688                    input, embedding_matrix, mode="sum", per_sample_weights=weights
7689                )
7690
7691        embedding_matrix = torch.rand(10, 15)
7692        model = EmbeddingModel()
7693        x = torch.randint(7, (2, 3))
7694        w = torch.randn(2, 3)
7695
7696        x2 = torch.randint(7, (4, 3))
7697        w2 = torch.randn(4, 3)
7698        self.run_test(
7699            model,
7700            (embedding_matrix, x, w),
7701            input_names=["embed", "x", "w"],
7702            dynamic_axes={"x": [0], "w": [0]},
7703            additional_test_inputs=[(embedding_matrix, x2, w2)],
7704        )
7705
7706    @skipScriptTest()  # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
7707    @skipIfUnsupportedMinOpsetVersion(11)
7708    @unittest.skip(
7709        "Due to ONNX Loop shape inference issue. "
7710        "https://msdata.visualstudio.com/Vienna/_workitems/edit/1352001"
7711    )
7712    def test_embedding_bag_dynamic_input(self):
7713        class EmbeddingModel1D(torch.nn.Module):
7714            def forward(self, embedding_matrix, input, weights, offsets):
7715                return torch.nn.functional.embedding_bag(
7716                    input,
7717                    embedding_matrix,
7718                    offsets=offsets,
7719                    mode="sum",
7720                    per_sample_weights=weights,
7721                )
7722
7723        model = EmbeddingModel1D()
7724        x = torch.randint(7, (6,))
7725        w = torch.randn(
7726            6,
7727        )
7728        offsets = torch.tensor([0, 2, 5], dtype=torch.long)
7729        embedding_matrix = torch.rand(10, 15)
7730        x2 = torch.randint(7, (2,))
7731        w2 = torch.randn(
7732            2,
7733        )
7734        embedding_matrix2 = torch.rand(12, 25)
7735        offsets2 = torch.tensor(
7736            [
7737                0,
7738            ],
7739            dtype=torch.long,
7740        )
7741        self.run_test(
7742            model,
7743            (embedding_matrix, x, w, offsets),
7744            additional_test_inputs=[(embedding_matrix2, x2, w2, offsets2)],
7745            input_names=["embedding_matrix", "x", "w", "offsets"],
7746            dynamic_axes={
7747                "embedding_matrix": [0, 1],
7748                "x": [0],
7749                "offsets": [0],
7750                "w": [0],
7751            },
7752        )
7753
7754        class EmbeddingModel2D(torch.nn.Module):
7755            def forward(self, embedding_matrix, input, weights):
7756                return torch.nn.functional.embedding_bag(
7757                    input, embedding_matrix, mode="sum", per_sample_weights=weights
7758                )
7759
7760        model = EmbeddingModel2D()
7761        x = torch.randint(7, (2, 3))
7762        w = torch.randn(2, 3)
7763        embedding_matrix = torch.rand(10, 15)
7764        x2 = torch.randint(7, (3, 5))
7765        w2 = torch.randn(3, 5)
7766        embedding_matrix2 = torch.rand(12, 25)
7767        self.run_test(
7768            model,
7769            (embedding_matrix, x, w),
7770            additional_test_inputs=[(embedding_matrix2, x2, w2)],
7771            input_names=["embedding_matrix", "x", "w"],
7772            dynamic_axes={"embedding_matrix": [0, 1], "x": [0, 1], "w": [0, 1]},
7773        )
7774
7775    @skipIfUnsupportedMinOpsetVersion(8)
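    # torch.meshgrid broadcasts the 1-D inputs into a common grid; when no indexing
    # argument is given it defaults to "ij".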
7776    def test_meshgrid(self):
7777        class Meshgrid(torch.nn.Module):
7778            def forward(self, x, y, z):
7779                output1, output2, output3 = torch.meshgrid(x, y, z)
7780                return output1, output2, output3
7781
7782        x = torch.randn(3, requires_grad=True)
7783        y = torch.zeros(4, requires_grad=True)
7784        z = torch.randn(5, requires_grad=True)
7785        self.run_test(Meshgrid(), (x, y, z))
7786
7787    @skipIfUnsupportedMinOpsetVersion(8)
7788    def test_meshgrid_indexing(self):
7789        class Meshgrid(torch.nn.Module):
7790            def __init__(self, indexing):
7791                super().__init__()
7792                self.indexing = indexing
7793
7794            def forward(self, x, y, z):
7795                output1, output2, output3 = torch.meshgrid(
7796                    x, y, z, indexing=self.indexing
7797                )
7798                return output1, output2, output3
7799
7800        x = torch.randn(5, requires_grad=True)
7801        y = torch.zeros(6, requires_grad=True)
7802        z = torch.randn(7, requires_grad=True)
7803        for indexing in ("xy", "ij"):
7804            self.run_test(Meshgrid(indexing), (x, y, z))
7805
7806    @skipIfUnsupportedMinOpsetVersion(8)
7807    def test_meshgrid_scalar(self):
7808        class Meshgrid(torch.nn.Module):
7809            def forward(self, x, y, z):
7810                output1, output2, output3 = torch.meshgrid(x, y, z)
7811                return output1, output2, output3
7812
7813        x = torch.ones(3, requires_grad=True)
7814        y = torch.zeros(4, requires_grad=True)
7815        z = torch.tensor(2.0)
7816        self.run_test(Meshgrid(), (x, y, z))
7817
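    # baddbmm computes beta * input + alpha * (batch1 @ batch2) over batched matrices.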
7818    def test_baddbmm(self):
7819        class MyModule(torch.nn.Module):
7820            def forward(self, input, batch1, batch2):
7821                return torch.baddbmm(
7822                    input, batch1, batch2, alpha=torch.tensor(5), beta=3.5
7823                )
7824
7825        x = torch.randn(10, 3, 5)
7826        batch1 = torch.randn(10, 3, 4)
7827        batch2 = torch.randn(10, 4, 5)
7828        model = MyModule()
7829        self.run_test(model, (x, batch1, batch2))
7830
7831    def test_baddbmm_dynamic(self):
7832        class MyModule(torch.nn.Module):
7833            def forward(self, input, batch1, batch2, alpha, beta):
7834                return torch.baddbmm(input, batch1, batch2, alpha=alpha, beta=beta)
7835
7836        x = torch.randn(10, 3, 5)
7837        batch1 = torch.randn(10, 3, 4)
7838        batch2 = torch.randn(10, 4, 5)
7839        alpha = torch.tensor(5)
7840        beta = torch.tensor(3.5)
7841        model = MyModule()
7842        self.run_test(model, (x, batch1, batch2, alpha, beta))
7843
7844    def test_numel(self):
7845        class MyModule(torch.nn.Module):
7846            def forward(self, input):
7847                return input.numel() * input
7848
7849        x = torch.randn(2, 3, 5)
7850        x2 = torch.randn(4, 5, 6)
7851        model = MyModule()
7852        self.run_test(
7853            model,
7854            (x,),
7855            input_names=["x"],
7856            dynamic_axes={"x": [0, 1, 2]},
7857            additional_test_inputs=[(x2,)],
7858        )
7859
7860    def test_numel_empty(self):
7861        class MyModule(torch.nn.Module):
7862            def forward(self, input):
7863                return input.numel() * input
7864
7865        x = torch.randn(0)
7866        x2 = torch.randn(4)
7867        model = MyModule()
7868        self.run_test(
7869            model,
7870            (x,),
7871            input_names=["x"],
7872            dynamic_axes={"x": [0]},
7873            additional_test_inputs=[(x2,)],
7874        )
7875
7876    def test_dtype(self):
7877        class MyModel(torch.jit.ScriptModule):
7878            @torch.jit.script_method
7879            def forward(self, input, other):
7880                return input.to(dtype=other.dtype) + other
7881
7882        x = torch.randn(2, 3)
7883        y = torch.randn(2, 3)
7884        self.run_test(MyModel(), (x, y))
7885
7886    def test_dtype_eq(self):
7887        class MyModel(torch.jit.ScriptModule):
7888            @torch.jit.script_method
7889            def forward(self, input, other):
7890                if input.dtype == other.dtype:
7891                    return input + other
7892                return input
7893
7894        x = torch.randn(2, 3)
7895        y = torch.randn(2, 3)
7896        self.run_test(MyModel(), (x, y))
7897
7898    def test_cast_to(self):
7899        class MyModule(torch.jit.ScriptModule):
7900            @torch.jit.script_method
7901            def forward(self, input, other):
7902                return input.to(other) + other
7903
7904        x = torch.randn(2, 3, 4)
7905        y = torch.tensor([1], dtype=torch.int64)
7906        model = MyModule()
7907        self.run_test(model, (x, y))
7908
7909    def test_cast_to_bool(self):
7910        class MyModule(torch.nn.Module):
7911            def forward(self, input, other):
7912                return torch.cat((input.to(other), other), 0)
7913
7914        x = torch.randn(2, 3, 4)
7915        y = torch.zeros([2, 3, 4], dtype=torch.bool)
7916        model = MyModule()
7917        self.run_test(model, (x, y))
7918
7919    # ONNX supports bfloat16 for opsets >= 13
7920    @skipIfUnsupportedMinOpsetVersion(13)
7921    def test_cast_type_as_with_bfloat16(self):
7922        class MyModule(torch.nn.Module):
7923            def forward(self, x):
7924                y = torch.ones((3, 4), dtype=torch.bfloat16)
7925                x = x.type_as(y)
7926                return x.to(dtype=torch.float16)
7927
7928        x = torch.ones(3, 4, dtype=torch.float16)
7929        model = MyModule()
7930        self.run_test(model, x)
7931
7932    @skipIfUnsupportedMinOpsetVersion(9)
7933    def test_type_as(self):
7934        class MyModule(torch.nn.Module):
7935            def forward(self, x):
7936                y = torch.tensor([1.0])
7937                return x.type_as(y)
7938
7939        a = torch.tensor([True, False], dtype=torch.bool)
7940        b = torch.randn(3, 4, dtype=torch.double)
7941        c = torch.ones((2, 2), dtype=torch.int64)
7942        model = MyModule()
7943        self.run_test(model, a)
7944        self.run_test(model, b)
7945        self.run_test(model, c)
7946
7947    @skipIfUnsupportedMinOpsetVersion(9)
7948    def test_ones_bool(self):
7949        class MyModule(torch.nn.Module):
7950            def forward(self, input):
7951                true = torch.ones(input.shape, dtype=torch.bool)
7952                return input.to(true) & true
7953
7954        x = torch.randn(2, 3, 4)
7955        model = MyModule()
7956        self.run_test(model, x)
7957
7958    def test_log(self):
7959        class Log(torch.nn.Module):
7960            def forward(self, input):
7961                return torch.log(input)
7962
7963        x = torch.rand(2, 3, 4)
7964        model = Log()
7965        self.run_test(model, x)
7966
7967    def test_log1p(self):
7968        class Log1p(torch.nn.Module):
7969            def forward(self, input):
7970                return torch.log1p(input)
7971
7972        x = torch.rand(2, 3, 4)
7973        model = Log1p()
7974        self.run_test(model, x)
7975
7976    def test_log10(self):
7977        class Log10(torch.nn.Module):
7978            def forward(self, input):
7979                return torch.log10(input)
7980
7981        x = torch.rand(2, 3, 4)
7982        model = Log10()
7983        self.run_test(model, x)
7984
7985    def test_log2(self):
7986        class Log2(torch.nn.Module):
7987            def forward(self, input):
7988                return torch.log2(input)
7989
7990        x = torch.tensor(1.0)
7991        model = Log2()
7992        self.run_test(model, x)
7993
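    # torch.round rounds half values to the nearest even integer, matching the
    # ONNX Round operator (opset 11).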
7994    @skipIfUnsupportedMinOpsetVersion(11)
7995    def test_round(self):
7996        class Round(torch.nn.Module):
7997            def forward(self, x):
7998                return torch.round(x)
7999
8000        x = torch.tensor([0.9920, -1.0362, -1.5000, 3.5000], requires_grad=True)
8001        self.run_test(Round(), x)
8002
8003        int_x = torch.tensor([9920, 1036, -1500, 35], dtype=torch.int32)
8004        self.run_test(Round(), int_x)
8005
8006    @skipIfUnsupportedMinOpsetVersion(11)
8007    def test_round_with_decimals(self):
8008        class Round(torch.nn.Module):
8009            def __init__(self, decimals):
8010                super().__init__()
8011                self.decimals = decimals
8012
8013            def forward(self, x):
8014                return torch.round(x, decimals=self.decimals)
8015
8016        x = torch.tensor([0.9920, -1234.0362, -1.58960, 3.5000])
8017        for decimals in (0, -2, 3):
8018            self.run_test(Round(decimals), x)
8019
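    # The torch.stft tests below rely on the ONNX STFT operator, which was introduced in opset 17.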
8020    @skipIfUnsupportedMinOpsetVersion(17)
8021    def test_stft_default(self):
8022        class STFT(torch.nn.Module):
8023            def forward(self, x):
8024                n_fft = 16
8025                return torch.stft(x, n_fft=n_fft, center=False, return_complex=False)
8026
8027        x = torch.randn((1, 32), requires_grad=True)
8028        self.run_test(STFT(), x, atol=1e-6)
8029
8030    @skipIfUnsupportedMinOpsetVersion(17)
8031    def test_stft_hop_length(self):
8032        class STFT(torch.nn.Module):
8033            def forward(self, x):
8034                n_fft = 16
8035                hop_length = 4
8036                return torch.stft(
8037                    x,
8038                    n_fft=n_fft,
8039                    center=False,
8040                    hop_length=hop_length,
8041                    return_complex=False,
8042                )
8043
8044        x = torch.randn((1, 32), requires_grad=True)
8045        self.run_test(STFT(), x, atol=1e-6)
8046
8047    @skipIfUnsupportedMinOpsetVersion(17)
8048    def test_stft_non_divisible_hop_length(self):
8049        class STFT(torch.nn.Module):
8050            def forward(self, x):
8051                n_fft = 16
8052                hop_length = 5
8053                return torch.stft(
8054                    x,
8055                    n_fft=n_fft,
8056                    center=False,
8057                    hop_length=hop_length,
8058                    return_complex=False,
8059                )
8060
8061        x = torch.randn((1, 32), requires_grad=True)
8062        self.run_test(STFT(), x, atol=1e-6)
8063
8064    @skipIfUnsupportedMinOpsetVersion(17)
8065    def test_stft_window_int_same_size(self):
8066        class STFT(torch.nn.Module):
8067            def forward(self, x):
8068                n_fft = 16
8069                win_length = 16
8070                return torch.stft(
8071                    x,
8072                    n_fft=n_fft,
8073                    center=False,
8074                    win_length=win_length,
8075                    return_complex=False,
8076                )
8077
8078        x = torch.randn((1, 32), requires_grad=True)
8079        self.run_test(STFT(), x, atol=1e-6)
8080
8081    @skipIfUnsupportedMinOpsetVersion(17)
8082    def test_stft_window_int_different_size(self):
8083        class STFT(torch.nn.Module):
8084            def forward(self, x):
8085                n_fft = 16
8086                win_length = 9
8087                return torch.stft(
8088                    x,
8089                    n_fft=n_fft,
8090                    center=False,
8091                    win_length=win_length,
8092                    return_complex=False,
8093                )
8094
8095        x = torch.randn((1, 32), requires_grad=True)
8096        self.run_test(STFT(), x, atol=1e-6)
8097
8098    @skipIfUnsupportedMinOpsetVersion(17)
8099    def test_stft_window_custom(self):
8100        class STFT(torch.nn.Module):
8101            def forward(self, x):
8102                n_fft = 16
8103                window = torch.hann_window(16)
8104                return torch.stft(
8105                    x,
8106                    n_fft=n_fft,
8107                    center=False,
8108                    window=window,
8109                    return_complex=False,
8110                )
8111
8112        x = torch.randn((1, 32), requires_grad=True)
8113        self.run_test(STFT(), x, atol=1e-6)
8114
8115    @skipIfUnsupportedMinOpsetVersion(17)
8116    def test_stft_wrong_custom_window_size(self):
8117        class STFT(torch.nn.Module):
8118            def forward(self, x):
8119                n_fft = 16
8120                window = torch.hann_window(10)
8121                return torch.stft(
8122                    x, n_fft=n_fft, window=window, center=False, return_complex=False
8123                )
8124
8125        x = torch.randn((1, 32), requires_grad=True)
8126        with self.assertRaises((AssertionError, RuntimeError)):
8127            self.run_test(STFT(), x)
8128
8129    @skipIfUnsupportedMinOpsetVersion(17)
8130    def test_stft_wrong_window_length(self):
8131        class STFT(torch.nn.Module):
8132            def forward(self, x):
8133                n_fft = 16
8134                win_len = 17
8135                return torch.stft(
8136                    x,
8137                    n_fft=n_fft,
8138                    win_length=win_len,
8139                    center=False,
8140                    return_complex=False,
8141                )
8142
8143        x = torch.randn((1, 32), requires_grad=True)
8144        with self.assertRaises(RuntimeError):
8145            self.run_test(STFT(), x)
8146
8147    @skipIfUnsupportedMinOpsetVersion(17)
8148    def test_stft_window_size_with_win_len(self):
8149        class STFT(torch.nn.Module):
8150            def forward(self, x):
8151                n_fft = 16
8152                window = torch.hann_window(10)
8153                win_len = 10
8154                return torch.stft(
8155                    x,
8156                    n_fft=n_fft,
8157                    window=window,
8158                    win_length=win_len,
8159                    center=False,
8160                    return_complex=False,
8161                )
8162
8163        x = torch.randn((1, 32), requires_grad=True)
8164        self.run_test(STFT(), x, atol=1e-6)
8165
8166    @skipIfUnsupportedMinOpsetVersion(17)
8167    def test_stft_one_dimension(self):
8168        class STFT(torch.nn.Module):
8169            def forward(self, x):
8170                n_fft = 16
8171                return torch.stft(
8172                    x,
8173                    n_fft=n_fft,
8174                    center=False,
8175                    return_complex=False,
8176                )
8177
8178        x = torch.randn(32, requires_grad=True)
8179        self.run_test(STFT(), x, atol=1e-6)
8180
8181    @skipIfUnsupportedMinOpsetVersion(17)
8182    def test_stft_wrong_input_size(self):
8183        class STFT(torch.nn.Module):
8184            def forward(self, x):
8185                n_fft = 16
8186                return torch.stft(x, n_fft=n_fft, center=False, return_complex=False)
8187
8188        x = torch.randn((1, 1, 32), requires_grad=True)
8189        with self.assertRaises(RuntimeError):
8190            self.run_test(STFT(), x)
8191
8192    @skipIfUnsupportedMinOpsetVersion(17)
8193    def test_stft_wrong_return_complex(self):
8194        class STFT(torch.nn.Module):
8195            def forward(self, x):
8196                n_fft = 16
8197                return torch.stft(x, n_fft=n_fft, center=False, return_complex=True)
8198
8199        x = torch.randn((1, 32), requires_grad=True)
8200        with self.assertRaises(errors.SymbolicValueError):
8201            self.run_test(STFT(), x)
8202
8203    @skipIfUnsupportedMinOpsetVersion(17)
8204    def test_stft_normalize(self):
8205        class STFT(torch.nn.Module):
8206            def forward(self, x):
8207                n_fft = 16
8208                return torch.stft(
8209                    x,
8210                    n_fft=n_fft,
8211                    center=False,
8212                    normalized=True,
8213                    return_complex=False,
8214                )
8215
8216        x = torch.randn(32, requires_grad=True)
8217        self.run_test(STFT(), x, atol=1e-6)
8218
8219    @skipIfUnsupportedMinOpsetVersion(17)
8220    def test_stft_not_onesided(self):
8221        class STFT(torch.nn.Module):
8222            def forward(self, x):
8223                n_fft = 16
8224                return torch.stft(
8225                    x,
8226                    n_fft=n_fft,
8227                    center=False,
8228                    onesided=False,
8229                    return_complex=False,
8230                )
8231
8232        x = torch.randn(32, requires_grad=True)
8233        self.run_test(STFT(), x, atol=1e-6)
8234
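    # ConstantPad1d/ConstantPad2d lower to the ONNX Pad operator in "constant" mode.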
8235    def test_constant_pad(self):
8236        model = torch.nn.ConstantPad1d(2, 3.5)
8237        x = torch.randn(2, 4, 4)
8238        self.run_test(model, x)
8239
8240        model = torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5)
8241        x = torch.randn(2, 2, 4, 4)
8242        self.run_test(model, x)
8243
8244    @common_utils.parametrize(
8245        "pad",
8246        [
8247            common_utils.subtest([2, 4], name="scalar_list"),
8248            common_utils.subtest(
8249                [
8250                    torch.tensor(2, dtype=torch.int64),
8251                    torch.tensor(4, dtype=torch.int64),
8252                ],
8253                name="scalar_tensor_list",
8254            ),
8255        ],
8256    )
8257    @skipIfUnsupportedMinOpsetVersion(11)  # Dynamic padding was added in opset 11
8258    def test_pad_types(self, pad):
8259        # Test for different pad integer types
8260        class Pad(torch.nn.Module):
8261            def forward(self, x, pad: List[int]):
8262                return torch.nn.functional.pad(x, pad)
8263
8264        x = torch.randn(2, 2, 4, 4)
8265        self.run_test(Pad(), (x, pad))
8266
8267    @skipIfUnsupportedMinOpsetVersion(11)
8268    def test_pad_circular(self):
8269        class PadModel(torch.nn.Module):
8270            def forward(self, x):
8271                out = torch.nn.functional.pad(x, (1, 2, 1, 2), mode="circular")
8272                return out
8273
8274        x = torch.randn(2, 3, 3, 4)
8275        self.run_test(PadModel(), x)
8276
8277    @skipIfUnsupportedMinOpsetVersion(11)
8278    def test_pad_circular_negative(self):
8279        # Test circular padding with negative pad amounts
8280        class PadModel(torch.nn.Module):
8281            def forward(self, x):
8282                out = torch.nn.functional.pad(x, (-1, -2), mode="circular")
8283                return out
8284
8285        x = torch.randn(2, 3, 6)
8286        self.run_test(PadModel(), x)
8287
8288    @skipIfUnsupportedMinOpsetVersion(11)
8289    def test_pad_circular_dynamic_axes(self):
8290        class PadModel(torch.nn.Module):
8291            def forward(self, x):
8292                out = torch.nn.functional.pad(x, (2, 1, 2, 1), mode="circular")
8293                return out
8294
8295        x = torch.randn(4, 3, 5, 6)
8296        self.run_test(
8297            PadModel(),
8298            x,
8299            input_names=["input_1"],
8300            dynamic_axes={"input_1": [0, 1, 2, 3]},
8301        )
8302
8303    @skipIfUnsupportedMaxOpsetVersion(10)
8304    @skipScriptTest()  # TODO: the logic in symbolic_opset9 doesn't handle scripted modules
8305    def test_unsupported_pad(self):
8306        class Pad(torch.nn.Module):
8307            def forward(self, x, pad: List[int]):
8308                return torch.nn.functional.pad(x, pad)
8309
8310        x = torch.randn(2, 2, 4, 4)
8311        y = [2, 4]
8312
8313        with self.assertRaisesRegex(
8314            RuntimeError,
8315            (
8316                "Unsupported: ONNX export of Pad.*"
8317                + "The sizes of the padding must be constant"
8318            ),
8319        ):
8320            self.run_test(Pad(), (x, y))
8321
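    # The branch conditions below depend only on statically known ranks, sizes, and dtypes,
    # so the exporter can fold them at export time instead of emitting ONNX If nodes.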
8322    @skipIfUnsupportedMinOpsetVersion(9)
8323    def test_if_fold(self):
8324        class IfFoldModel(torch.nn.Module):
8325            def forward(self, y):
8326                if y.dim() == 2:
8327                    y = y + 4
8328                    y = y + 2
8329                else:
8330                    y = y - 1
8331                return y
8332
8333        x = torch.ones((3, 4), dtype=torch.int)
8334        self.run_test(IfFoldModel(), x)
8335
8336        class IfFoldModel(torch.nn.Module):
8337            def forward(self, y):
8338                if y.numel() > 1:
8339                    y = y + 4
8340                else:
8341                    y = y + 2
8342                return y
8343
8344        x = torch.ones((3, 4), dtype=torch.int)
8345        self.run_test(IfFoldModel(), x)
8346
8347        class IfFoldModel(torch.nn.Module):
8348            def forward(self, y):
8349                if y.dim() != 3:
8350                    y = y + 4
8351                    y = y + 2
8352                else:
8353                    return y
8354                return y
8355
8356        x = torch.ones((3, 4), dtype=torch.int)
8357        self.run_test(IfFoldModel(), x)
8358
8359        class IfFoldModel(torch.nn.Module):
8360            def forward(self, y):
8361                if y.dim() >= 1:
8362                    y = y + 4
8363                else:
8364                    y = y - 1
8365                return y
8366
8367        x = torch.ones((3, 4), dtype=torch.int)
8368        self.run_test(IfFoldModel(), x)
8369
8370        class IfFoldModel(torch.nn.Module):
8371            def forward(self, y):
8372                if y.dim() <= 1:
8373                    y = y + 4
8374                else:
8375                    y = y + 2
8376                return y
8377
8378        x = torch.ones((3, 4), dtype=torch.int)
8379        self.run_test(IfFoldModel(), x)
8380
8381        class IfFoldModel(torch.nn.Module):
8382            def forward(self, y):
8383                if y.dim() < 3 and y.dtype == torch.int:
8384                    y = y + 4
8385                    y = y + 2
8386                else:
8387                    return y
8388                return y
8389
8390        x = torch.ones((3, 4), dtype=torch.int)
8391        self.run_test(IfFoldModel(), x)
8392
8393        class IfFoldModel(torch.nn.Module):
8394            def forward(self, y):
8395                if y.dim() == 3 and y.dtype == torch.int:
8396                    y = y + 4
8397                    y = y + 2
8398                else:
8399                    y = y + 1
8400                return y
8401
8402        x = torch.ones((3, 4), dtype=torch.int)
8403        self.run_test(IfFoldModel(), x)
8404
8405        class IfFoldModel(torch.nn.Module):
8406            def forward(self, y):
8407                if y.numel() != 0 and y.dim() == 2:
8408                    y = y + 4
8409                    y = y + 2
8410                else:
8411                    return y
8412                return y
8413
8414        x = torch.ones((3, 4), dtype=torch.int)
8415        self.run_test(IfFoldModel(), x)
8416
8417        class IfFoldModel(torch.nn.Module):
8418            def forward(self, x, y):
8419                if x.numel() == y.numel():
8420                    y = x + y
8421                else:
8422                    y = y - x
8423                return y
8424
8425        x = torch.ones((3, 4), dtype=torch.int)
8426        y = torch.ones((3, 4), dtype=torch.int)
8427        self.run_test(IfFoldModel(), (x, y))
8428
8429        class IfFoldModel(torch.nn.Module):
8430            def forward(self, x, y):
8431                if x.numel() != y.numel():
8432                    y = x + y
8433                else:
8434                    y = y - x
8435                return y
8436
8437        x = torch.ones((3, 4), dtype=torch.int)
8438        y = torch.ones((3, 4), dtype=torch.int)
8439        self.run_test(IfFoldModel(), (x, y))
8440
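    # An early return from a nested branch produces prim::Uninitialized values in the
    # TorchScript IR, which the exporter must handle when building the ONNX If subgraphs.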
8441    @skipIfUnsupportedMinOpsetVersion(11)
8442    def test_uninitialized(self):
8443        class UninitializedModel(torch.nn.Module):
8444            def forward(self, y):
8445                if y.shape[1] < 5:
8446                    if y.size(0) == 1:
8447                        y = y + 4
8448                    else:
8449                        return y
8450                return y
8451
8452        x = torch.ones((3, 4), dtype=torch.int)
8453        self.run_test(UninitializedModel(), x)
8454
8455    @skipIfUnsupportedMinOpsetVersion(11)
8456    def test_uninitialized_dynamic(self):
8457        class UninitializedModel(torch.nn.Module):
8458            def forward(self, y):
8459                if y.shape[1] < 5:
8460                    if y.size(0) == 1:
8461                        y = y + 4
8462                    else:
8463                        return y
8464                return y
8465
8466        x = torch.ones((3, 4), dtype=torch.int)
8467        y = torch.ones((6, 7), dtype=torch.int)
8468        self.run_test(
8469            UninitializedModel(),
8470            x,
8471            additional_test_inputs=[y],
8472            input_names=["input_1"],
8473            dynamic_axes={"input_1": [0, 1]},
8474        )
8475
8476    # onnx::Identity on sequences is supported for ONNX opset >= 14
8477    @skipIfUnsupportedMinOpsetVersion(14)
8478    def test_uninitialized_tensorList(self):
8479        class UninitializedTensorListModel(torch.nn.Module):
8480            def forward(self, x):
8481                if x[0].shape[0] < 5:
8482                    if x.size(0) == 1:
8483                        x = x + 4
8484                    else:
8485                        return [x]
8486                return [x]
8487
8488        x = torch.ones((3, 4), dtype=torch.int)
8489        self.run_test(torch.jit.script(UninitializedTensorListModel()), x)
8490
8491    # onnx::Identity on sequences is supported for ONNX opset >= 14
8492    @skipIfUnsupportedMinOpsetVersion(14)
8493    def test_uninitialized_tensorList_dynamic(self):
8494        class UninitializedTensorListModel(torch.nn.Module):
8495            def forward(self, x):
8496                if x[0].shape[0] < 5:
8497                    if x.size(0) == 1:
8498                        x += x
8499                    else:
8500                        return list(x)
8501                return list(x)
8502
8503        x = torch.ones((3, 4), dtype=torch.double)
8504        self.run_test(
8505            torch.jit.script(UninitializedTensorListModel()),
8506            x,
8507            input_names=["input_1"],
8508            dynamic_axes={"input_1": [0, 1]},
8509        )
8510
8511    # onnx::Identity on sequences is supported for ONNX opset >= 14
8512    @skipIfUnsupportedMinOpsetVersion(14)
8513    def test_uninitialized_intList(self):
8514        class UninitializedListModel(torch.nn.Module):
8515            def forward(self, x):
8516                y = list(range(x.size(0)))
8517                if y[0] < 5:
8518                    # if x.size(0) != 3, ORT will throw a type error.
8519                    if x.size(0) == 3:
8520                        y.append(10)
8521                    else:
8522                        return y
8523                return y
8524
8525        x = torch.ones((3, 4), dtype=torch.int)
8526        self.run_test(
8527            torch.jit.script(UninitializedListModel()),
8528            x,
8529            input_names=["input_1"],
8530            dynamic_axes={"input_1": [0, 1]},
8531        )
8532
8533    # onnx::Identity on sequences is supported for ONNX opset >= 14
8534    @skipIfUnsupportedMinOpsetVersion(14)
8535    def test_uninitialized_tensorList_shape(self):
8536        class UninitializedModel(torch.nn.Module):
8537            def forward(self, x):
8538                if x.shape[1] < 5:
8539                    if x.size(0) == 1:
8540                        x = x + 4
8541                    else:
8542                        x_list = list(x)
8543                        x_list.append(x)
8544                        return x_list
8545                return [x, x]
8546
8547        x = torch.ones((3, 4), dtype=torch.int)
8548        y = torch.ones((4, 6), dtype=torch.int)
8549        self.run_test(
8550            torch.jit.script(UninitializedModel()),
8551            x,
8552            additional_test_inputs=[y],
8553            input_names=["input_1"],
8554            dynamic_axes={"input_1": [0, 1]},
8555        )
8556
8557    # Sequence types as loop-carried dependencies are only supported for ONNX opset >= 13
8558    @skipIfUnsupportedMinOpsetVersion(13)
8559    def test_sequence_loopcarried(self):
8560        class SequenceLoopModel(torch.nn.Module):
8561            def forward(self, x):
8562                outputs = []
8563                for i in range(3):
8564                    outputs += [x]
8565                return torch.stack(outputs).transpose(0, 1)
8566
8567        x = torch.ones((3, 4), dtype=torch.int)
8568        self.run_test(torch.jit.script(SequenceLoopModel()), x)
8569
8570    def test_reflection_pad(self):
8571        model = torch.nn.ReflectionPad1d(2)
8572        x = torch.randn(2, 4, 4)
8573        self.run_test(model, x)
8574
8575        model = torch.nn.ReflectionPad2d((3, 0, 2, 1))
8576        x = torch.randn(2, 2, 4, 4)
8577        self.run_test(model, x)
8578
8579    def test_replication_pad(self):
8580        model = torch.nn.ReplicationPad1d(2)
8581        x = torch.randn(2, 4, 4)
8582        self.run_test(model, x)
8583
8584        model = torch.nn.ReplicationPad2d((3, 0, 2, 1))
8585        x = torch.randn(2, 2, 4, 4)
8586        self.run_test(model, x)
8587
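    # torch.nn.functional.unfold (im2col) extracts sliding kernel-sized patches and
    # flattens each patch into a column.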
8588    @skipIfUnsupportedMinOpsetVersion(11)
8589    def test_im2col(self):
8590        class Unfold(torch.nn.Module):
8591            def forward(self, input):
8592                return (
8593                    torch.nn.functional.unfold(
8594                        input, kernel_size=(10, 15), dilation=2, padding=5, stride=3
8595                    ),
8596                    torch.nn.functional.unfold(
8597                        input, kernel_size=(2, 2), dilation=1, padding=0, stride=3
8598                    ),
8599                    torch.nn.functional.unfold(
8600                        input, kernel_size=(1, 1), dilation=5, padding=2, stride=3
8601                    ),
8602                )
8603
8604        x = torch.rand(1, 1, 200, 100)
8605        self.run_test(Unfold(), x)
8606
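    # torch.linalg.det maps to the ONNX Det operator (opset 11); LAPACK is needed for
    # the eager reference result.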
8607    @skipIfNoLapack
8608    @skipIfUnsupportedMinOpsetVersion(11)
8609    def test_det(self):
8610        class Det(torch.nn.Module):
8611            def forward(self, x):
8612                return torch.linalg.det(x)
8613
8614        x = torch.randn(2, 3, 5, 5)
8615        self.run_test(Det(), x)
8616
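    # torch.linalg.norm dispatches to a vector norm when dim is an int and to a matrix
    # norm when dim is a 2-tuple.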
8617    def test_linalg_norm(self):
8618        class LinalgSingleDimModel(torch.nn.Module):
8619            def __init__(self, ord_val):
8620                super().__init__()
8621                self.ord = ord_val
8622
8623            def forward(self, x):
8624                return torch.linalg.norm(x, ord=self.ord, dim=1)
8625
8626        x = torch.randn(2, 3, 5, 5)
8627        self.run_test(LinalgSingleDimModel(None), x)
8628        self.run_test(LinalgSingleDimModel(2), x)
8629        self.run_test(LinalgSingleDimModel(float("inf")), x)
8630        self.run_test(LinalgSingleDimModel(-float("inf")), x)
8631        self.run_test(LinalgSingleDimModel(-4), x)
8632        self.run_test(LinalgSingleDimModel(1.5), x)
8633
8634        class LinalgMultiDimModel(torch.nn.Module):
8635            def __init__(self, ord_val):
8636                super().__init__()
8637                self.ord = ord_val
8638
8639            def forward(self, x):
8640                return torch.linalg.norm(x, ord=self.ord, dim=(0, 2))
8641
8642        x = torch.randn(2, 3, 5, 5)
8643        self.run_test(LinalgMultiDimModel("fro"), x)
8644        self.run_test(LinalgMultiDimModel(float("inf")), x)
8645        self.run_test(LinalgMultiDimModel(-float("inf")), x)
8646        self.run_test(LinalgMultiDimModel(1), x)
8647        self.run_test(LinalgMultiDimModel(-1), x)
8648
8649        class LinalgNoDimNoOrdModel(torch.nn.Module):
8650            def forward(self, x):
8651                return torch.linalg.norm(x)
8652
8653        x = torch.randn(2, 3, 5, 5)
8654        self.run_test(LinalgNoDimNoOrdModel(), x)
8655        y = torch.randn(2, 3)
8656        self.run_test(LinalgNoDimNoOrdModel(), y)
8657        z = torch.randn(2)
8658        self.run_test(LinalgNoDimNoOrdModel(), z)
8659
8660        class LinalgNoDim1DModel(torch.nn.Module):
8661            def __init__(self, ord_val):
8662                super().__init__()
8663                self.ord = ord_val
8664
8665            def forward(self, x):
8666                return torch.linalg.norm(x, ord=self.ord)
8667
8668        x = torch.randn(2)
8669        self.run_test(LinalgNoDim1DModel(None), x)
8670        self.run_test(LinalgNoDim1DModel(2), x)
8671        self.run_test(LinalgNoDim1DModel(float("inf")), x)
8672        self.run_test(LinalgNoDim1DModel(-float("inf")), x)
8673        self.run_test(LinalgNoDim1DModel(-4), x)
8674        self.run_test(LinalgNoDim1DModel(1.5), x)
8675
8676        class LinalgNoDim2DModel(torch.nn.Module):
8677            def __init__(self, ord_val):
8678                super().__init__()
8679                self.ord = ord_val
8680
8681            def forward(self, x):
8682                return torch.linalg.norm(x, ord=self.ord)
8683
8684        x = torch.randn(2, 3)
8685        self.run_test(LinalgNoDim2DModel("fro"), x)
8686        self.run_test(LinalgNoDim2DModel(float("inf")), x)
8687        self.run_test(LinalgNoDim2DModel(-float("inf")), x)
8688        self.run_test(LinalgNoDim2DModel(1), x)
8689        self.run_test(LinalgNoDim2DModel(-1), x)
8690
8691    @skipIfUnsupportedMinOpsetVersion(11)
8692    def test_linalg_vector_norm_zero(self):
8693        class LinalgVectorNormModel(torch.nn.Module):
8694            def __init__(self, ord_val):
8695                super().__init__()
8696                self.ord = ord_val
8697
8698            def forward(self, x):
8699                return torch.linalg.vector_norm(x, ord=self.ord)
8700
8701        x = torch.randn(2, 3, 5, 5)
8702        self.run_test(LinalgVectorNormModel(0), x)
8703
8704    def test_linalg_vector_norm(self):
8705        class LinalgVectorNormModel(torch.nn.Module):
8706            def __init__(self, ord_val, dim_info):
8707                super().__init__()
8708                self.ord = ord_val
8709                self.dim, self.keepdim = dim_info
8710
8711            def forward(self, x):
8712                return torch.linalg.vector_norm(
8713                    x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
8714                )
8715
8716        x = torch.randn(2, 3, 5, 5)
8717        ord_options = [2, float("inf"), -float("inf"), -4, 1.5]
8718        dim_options = [(None, False), (1, False), ((1, 2), False), ((1, 2), True)]
8719        for ord_val in ord_options:
8720            for dim_info in dim_options:
8721                self.run_test(LinalgVectorNormModel(ord_val, dim_info), x)
8722
8723    def test_linalg_matrix_norm(self):
8724        class LinalgMatrixNormModel(torch.nn.Module):
8725            def __init__(self, ord_val, dim_val=(-2, -1), keepdim_val=False):
8726                super().__init__()
8727                self.ord = ord_val
8728                self.dim = dim_val
8729                self.keepdim = keepdim_val
8730
8731            def forward(self, x):
8732                return torch.linalg.matrix_norm(
8733                    x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
8734                )
8735
8736        x = torch.randn(2, 3, 5, 5)
8737        ord_options = ["fro", float("inf"), -float("inf"), 1, -1]
8738        for ord_val in ord_options:
8739            self.run_test(LinalgMatrixNormModel(ord_val), x)
8740            self.run_test(LinalgMatrixNormModel(ord_val, (0, 2)), x)
8741            self.run_test(LinalgMatrixNormModel(ord_val, (0, 2), True), x)
8742
8743    @skipIfUnsupportedMinOpsetVersion(9)
8744    def test_linalg_cross(self):
8745        class Cross(torch.nn.Module):
8746            def forward(self, x, y):
8747                return torch.linalg.cross(x, y, dim=1), torch.linalg.cross(x, y)
8748
8749        x = torch.randn(5, 3, 2, 3)
8750        y = torch.randn(1, 3, 1, 3)
8751        self.run_test(Cross(), input_args=(x, y))
8752
8753    # This test checks that the output scalar type in the ONNX graph is not null
8754    # https://github.com/pytorch/pytorch/issues/28607
8755    @skipIfUnsupportedMinOpsetVersion(10)
8756    def test_trace_script(self):
8757        @torch.jit.script
8758        def center_slice_helper(input, h_offset):
8759            return input[:, h_offset:]
8760
8761        class CenterCrop(torch.nn.Module):
8762            def forward(self, input):
8763                return center_slice_helper(input, torch.tensor(input.shape[1] - 1))
8764
8765        x = torch.randn(3, 4)
8766        self.run_test(CenterCrop(), x)
8767
8768    @skipIfNoLapack
8769    @skipIfUnsupportedMinOpsetVersion(11)
8770    def test_logdet(self):
8771        class LogDet(torch.nn.Module):
8772            def forward(self, x):
8773                return torch.logdet(x)
8774
8775        x = torch.randn(2, 3, 5, 5)
8776        self.run_test(LogDet(), x)
8777
8778    def test_dim(self):
8779        class DimModel(torch.jit.ScriptModule):
8780            @torch.jit.script_method
8781            def forward(self, input):
8782                out = input * 2
8783                out *= out.dim()
8784                return out
8785
8786        empty_input = torch.randn(0, requires_grad=True)
8787        multi_dim_input = torch.randn(1, 2, 3, requires_grad=True)
8788        self.run_test(DimModel(), empty_input)
8789        self.run_test(DimModel(), multi_dim_input)
8790
8791    @skipIfUnsupportedMinOpsetVersion(11)
8792    def test_dim_1(self):
8793        class M(torch.jit.ScriptModule):
8794            @torch.jit.script_method
8795            def forward(self, poses):
8796                boxes = torch.zeros([poses.shape[0], 2, 4])
8797                batch_boxes = []
8798                for kp_boxes in boxes:
8799                    kp_boxes = torchvision.ops.clip_boxes_to_image(kp_boxes, (2, 3))
8800                    batch_boxes.append(kp_boxes)
8801                return batch_boxes
8802
8803        dummy_inputs = torch.rand(2, 2, 3)
8804        self.run_test(M(), (dummy_inputs,), input_names=["x"], dynamic_axes={"x": [0]})
8805
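    # torch.outer computes out[i][j] = x[i] * y[j]; dtype checking is skipped because the
    # exported graph may promote the mixed input dtypes differently than eager mode.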
8806    @skipIfUnsupportedMinOpsetVersion(12)
8807    @skipDtypeChecking
8808    def test_outer(self):
8809        class Outer(torch.nn.Module):
8810            def forward(self, x, y):
8811                return torch.outer(x, y)
8812
8813        x = torch.arange(1, 5)
8814        y = torch.arange(1, 4)
8815        self.run_test(Outer(), input_args=(x, y))
8816
8817        x = torch.arange(1, 6).to(dtype=torch.float32)
8818        y = torch.arange(1, 4).to(dtype=torch.long)
8819        self.run_test(Outer(), input_args=(x, y))
8820
8821        x = torch.arange(2, 5).to(dtype=torch.float32)
8822        y = torch.arange(2, 4).to(dtype=torch.float64)
8823        self.run_test(Outer(), input_args=(x, y))
8824
8825        x = torch.arange(3, 6).to(dtype=torch.int32)
8826        y = torch.arange(4, 7).to(dtype=torch.long)
8827        self.run_test(Outer(), input_args=(x, y))
8828
8829    @skipIfUnsupportedMinOpsetVersion(9)
8830    def test_movedim(self):
8831        class MovedimModel(torch.nn.Module):
8832            def forward(self, x):
8833                return (
8834                    x.movedim(1, 3),
8835                    x.movedim(2, 0),
8836                    x.movedim(1, 1),
8837                    x.movedim((1, 2, 3), (3, 0, 1)),
8838                    x.movedim((0, 1, 2), (1, 2, 3)),
8839                    x.movedim((1, 3, 2), (1, 3, 2)),
8840                )
8841
8842        x = torch.randn(5, 3, 4, 2)
8843
8844        self.run_test(MovedimModel(), x)
8845
8846    @skipIfUnsupportedMinOpsetVersion(9)
8847    def test_moveaxis(self):
8848        # moveaxis is an alias of movedim; thus, mostly copied from `test_movedim`.
8849        class MoveaxisModel(torch.nn.Module):
8850            def forward(self, x):
8851                return (
8852                    x.moveaxis(1, 3),
8853                    x.moveaxis(2, 0),
8854                    x.moveaxis(1, 1),
8855                    x.moveaxis((1, 2, 3), (3, 0, 1)),
8856                    x.moveaxis((0, 1, 2), (1, 2, 3)),
8857                    x.moveaxis((1, 3, 2), (1, 3, 2)),
8858                )
8859
8860        x = torch.randn(5, 3, 4, 2)
8861
8862        self.run_test(MoveaxisModel(), x)
8863
8864    @skipIfUnsupportedMinOpsetVersion(12)
8865    def test_einsum(self):
8866        class EinsumModelBatchDiagonal(torch.nn.Module):
8867            def forward(self, x):
8868                eqn = "...ii ->...i"
8869                return torch.einsum(eqn, x)
8870
8871        for x in [torch.randn(3, 5, 5), torch.randn(3, 5, 5).to(dtype=torch.bool)]:
8872            self.run_test(EinsumModelBatchDiagonal(), input_args=(x,))
8873
8874        class EinsumModelBatchMatmul(torch.nn.Module):
8875            def forward(self, x, y):
8876                eqn = "bij, bjk -> bik"
8877                return torch.einsum(eqn, x, y)
8878
8879        x = torch.randn(5, 2, 3)
8880        y = torch.randn(5, 3, 4)
8881        self.run_test(EinsumModelBatchMatmul(), input_args=(x, y))
8882
8883        class EinsumModelInnerProd(torch.nn.Module):
8884            def forward(self, x, y):
8885                eqn = "i,i"
8886                return torch.einsum(eqn, x, y)
8887
8888        x = torch.randn(5)
8889        y = torch.randn(5)
8890        self.run_test(EinsumModelInnerProd(), input_args=(x, y))
8891
8892        class EinsumModelTranspose(torch.nn.Module):
8893            def forward(self, x):
8894                eqn = "ij->ji"
8895                return torch.einsum(eqn, x)
8896
8897        for x in [torch.randn(3, 4), torch.randn(3, 4).to(dtype=torch.bool)]:
8898            self.run_test(EinsumModelTranspose(), input_args=(x,))
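        # From opset 12 these equations can map onto ONNX's Einsum op:
        # "...ii ->...i" takes batched diagonals, "bij, bjk -> bik" is a batched
        # matmul, "i,i" an inner product, and "ij->ji" a transpose. Illustrative:
        #   torch.einsum("bij,bjk->bik",
        #                torch.randn(5, 2, 3), torch.randn(5, 3, 4)).shape  # (5, 2, 4)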
8899
8900    @skipIfUnsupportedMinOpsetVersion(9)
8901    def test_cosine_similarity(self):
8902        x = torch.randn(5, 3, 2)
8903        y = torch.randn(5, 3, 2)
8904        self.run_test(torch.nn.CosineSimilarity(dim=2), input_args=(x, y))
8905
8906    @skipIfUnsupportedMinOpsetVersion(9)
8907    def test_pairwise_distance(self):
8908        x = torch.randn(5, 3, 2)
8909        y = torch.randn(5, 3, 2)
8910        self.run_test(torch.nn.PairwiseDistance(p=2.0), input_args=(x, y))
8911
8912    @skipIfUnsupportedMinOpsetVersion(9)
8913    def test_cross(self):
8914        class Cross(torch.nn.Module):
8915            def forward(self, x, y):
8916                return torch.cross(x, y, dim=3), torch.cross(x, y)
8917
8918        x = torch.randn(5, 3, 2, 3)
8919        y = torch.randn(5, 3, 2, 3)
8920        self.run_test(Cross(), input_args=(x, y))
8921
8922    @skipIfUnsupportedMinOpsetVersion(9)
8923    def test_cdist(self):
8924        class Cdist(torch.nn.Module):
8925            def forward(self, x, y):
8926                return torch.cdist(x, y)
8927
8928        x = torch.randn(5, 3, 3)
8929        y = torch.randn(5, 2, 3)
8930        self.run_test(Cdist(), input_args=(x, y))
8931
8932    @skipIfUnsupportedMinOpsetVersion(12)
8933    def test_crossentropyloss(self):
8934        for ignore_index in [-100, 1]:
8935            x = torch.randn(3, 5)
8936            y = torch.empty(3, dtype=torch.long).random_(5)
8937            y[y == 1] = ignore_index
8938
8939            self._crossentropyloss(x, y, ignore_index)
8940
8941            x = torch.randn(3, 5, 2)
8942            y = torch.empty(3, 2, dtype=torch.long).random_(5)
8943            y[y == 1] = ignore_index
8944            self._crossentropyloss(x, y, ignore_index)
8945
8946            x = torch.randn(3, 5, 2, 7)
8947            y = torch.empty(3, 2, 7, dtype=torch.long).random_(5)
8948            y[y == 1] = ignore_index
8949            self._crossentropyloss(x, y, ignore_index)
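        # Targets equal to ignore_index contribute neither to the per-element loss
        # nor to the "mean" denominator. Minimal illustration (3 samples, 5 classes):
        #   loss = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=1)
        #   loss(torch.randn(3, 5), torch.tensor([1, 0, 2]))  # first entry is 0.0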
8950
8951    def _crossentropyloss(self, x, y, ignore_index):
8952        class CrossEntropyLossNone(torch.nn.Module):
8953            def __init__(self, ignore_index):
8954                super().__init__()
8955                if ignore_index == -100:
8956                    self.loss = torch.nn.CrossEntropyLoss(reduction="none")
8957                else:
8958                    self.loss = torch.nn.CrossEntropyLoss(
8959                        reduction="none", ignore_index=ignore_index
8960                    )
8961
8962            def forward(self, input, target):
8963                return self.loss(input, target)
8964
8965        self.run_test(CrossEntropyLossNone(ignore_index), input_args=(x, y))
8966
8967        class CrossEntropyLossNoneWeight(torch.nn.Module):
8968            def __init__(self, ignore_index):
8969                super().__init__()
8970                if ignore_index == -100:
8971                    self.loss = torch.nn.CrossEntropyLoss(
8972                        reduction="none", weight=torch.randn(5)
8973                    )
8974                else:
8975                    self.loss = torch.nn.CrossEntropyLoss(
8976                        reduction="none",
8977                        weight=torch.randn(5),
8978                        ignore_index=ignore_index,
8979                    )
8980
8981            def forward(self, input, target):
8982                return self.loss(input, target)
8983
8984        self.run_test(CrossEntropyLossNoneWeight(ignore_index), input_args=(x, y))
8985
8986        class CrossEntropyLossSum(torch.nn.Module):
8987            def __init__(self, ignore_index):
8988                super().__init__()
8989                if ignore_index == -100:
8990                    self.loss = torch.nn.CrossEntropyLoss(reduction="sum")
8991                else:
8992                    self.loss = torch.nn.CrossEntropyLoss(
8993                        reduction="sum", ignore_index=ignore_index
8994                    )
8995
8996            def forward(self, input, target):
8997                return self.loss(input, target)
8998
8999        self.run_test(CrossEntropyLossSum(ignore_index), input_args=(x, y))
9000
9001        class CrossEntropyLossSumWeight(torch.nn.Module):
9002            def __init__(self, ignore_index):
9003                super().__init__()
9004                if ignore_index == -100:
9005                    self.loss = torch.nn.CrossEntropyLoss(
9006                        reduction="sum", weight=torch.randn(5)
9007                    )
9008                else:
9009                    self.loss = torch.nn.CrossEntropyLoss(
9010                        reduction="sum",
9011                        weight=torch.randn(5),
9012                        ignore_index=ignore_index,
9013                    )
9014
9015            def forward(self, input, target):
9016                return self.loss(input, target)
9017
9018        self.run_test(CrossEntropyLossSumWeight(ignore_index), input_args=(x, y))
9019
9020        class CrossEntropyLossMean(torch.nn.Module):
9021            def __init__(self, ignore_index):
9022                super().__init__()
9023                if ignore_index == -100:
9024                    self.loss = torch.nn.CrossEntropyLoss()
9025                else:
9026                    self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
9027
9028            def forward(self, input, target):
9029                return self.loss(input, target)
9030
9031        self.run_test(CrossEntropyLossMean(ignore_index), input_args=(x, y))
9032
9033        class CrossEntropyLossMeanWeight(torch.nn.Module):
9034            def __init__(self, ignore_index):
9035                super().__init__()
9036                if ignore_index == -100:
9037                    self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5))
9038                else:
9039                    self.loss = torch.nn.CrossEntropyLoss(
9040                        weight=torch.randn(5), ignore_index=ignore_index
9041                    )
9042
9043            def forward(self, input, target):
9044                return self.loss(input, target)
9045
9046        self.run_test(CrossEntropyLossMeanWeight(ignore_index), input_args=(x, y))
9047
9048    @skipIfUnsupportedMinOpsetVersion(9)
9049    def test_MSELoss(self):
9050        class MSELoss(torch.nn.Module):
9051            def __init__(self) -> None:
9052                super().__init__()
9053                self.loss1 = torch.nn.MSELoss(reduction="none")
9054                self.loss2 = torch.nn.MSELoss(reduction="sum")
9055                self.loss3 = torch.nn.MSELoss(reduction="mean")
9056
9057            def forward(self, input, target):
9058                return (
9059                    self.loss1(input, target),
9060                    self.loss2(input, target),
9061                    self.loss3(input, target),
9062                )
9063
9064        x = torch.randn(2, 3, 5)
9065        y = torch.randn(2, 3, 5)
9066        self.run_test(MSELoss(), input_args=(x, y))
9067
9068    @skipIfUnsupportedMinOpsetVersion(9)
9069    def test_kldiv_loss(self):
9070        x = torch.rand(5).log()
9071        y = torch.rand(5)
9072        self._kldiv_loss(x, y)
9073
9074        x = torch.rand(2, 3, 5).log()
9075        y = torch.rand(2, 3, 5)
9076        self._kldiv_loss(x, y)
9077
9078        x = torch.rand(2, 3, 5, 7).log()
9079        y = torch.rand(2, 3, 5, 7)
9080        self._kldiv_loss(x, y)
9081
9082    def _kldiv_loss(self, x, y):
9083        class KLDivLossNone(torch.nn.Module):
9084            def __init__(self) -> None:
9085                super().__init__()
9086                self.loss = torch.nn.KLDivLoss(reduction="none", log_target=True)
9087
9088            def forward(self, input, target):
9089                return self.loss(input, target.log())
9090
9091        self.run_test(KLDivLossNone(), input_args=(x, y))
9092
9093        class KLDivLossMean(torch.nn.Module):
9094            def __init__(self) -> None:
9095                super().__init__()
9096                self.loss = torch.nn.KLDivLoss(reduction="mean", log_target=False)
9097
9098            def forward(self, input, target):
9099                return self.loss(input, target)
9100
9101        self.run_test(KLDivLossMean(), input_args=(x, y))
9102
9103        class KLDivLossSum(torch.nn.Module):
9104            def __init__(self) -> None:
9105                super().__init__()
9106                self.loss = torch.nn.KLDivLoss(reduction="sum", log_target=True)
9107
9108            def forward(self, input, target):
9109                return self.loss(input, target.log())
9110
9111        self.run_test(KLDivLossSum(), input_args=(x, y))
9112
9113        class KLDivLossBatchMean(torch.nn.Module):
9114            def __init__(self) -> None:
9115                super().__init__()
9116                self.loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=False)
9117
9118            def forward(self, input, target):
9119                return self.loss(input, target)
9120
9121        self.run_test(KLDivLossBatchMean(), input_args=(x, y))
9122
9123        class KLDivLossMiniBatchMean(torch.nn.Module):
9124            def __init__(self) -> None:
9125                super().__init__()
9126                self.loss = torch.nn.KLDivLoss(
9127                    reduction="batchmean", size_average=False, log_target=True
9128                )
9129
9130            def forward(self, input, target):
9131                return self.loss(input, target.log())
9132
9133        self.run_test(KLDivLossMiniBatchMean(), input_args=(x, y))
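        # reduction="batchmean" divides the summed divergence by the batch size
        # (the mathematically correct KL value), whereas "mean" divides by the
        # total element count. Illustrative identity for the inputs above:
        #   KLDivLoss(reduction="batchmean")(x, y) == KLDivLoss(reduction="sum")(x, y) / x.size(0)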
9134
9135    @skipIfUnsupportedMinOpsetVersion(12)
9136    def test_nllloss(self):
9137        class NLLModel(torch.nn.Module):
9138            def __init__(self) -> None:
9139                super().__init__()
9140                self.loss = torch.nn.NLLLoss(reduction="none")
9141                self.m = torch.nn.LogSoftmax(dim=1)
9142
9143            def forward(self, input, target):
9144                output = self.loss(self.m(2 * input), target)
9145                return output
9146
9147        N, C = 5, 4
9148        input = torch.randn(N, 16)
9149        target = torch.empty(N, dtype=torch.long).random_(0, C)
9150
9151        # Use test data containing the default ignore_index (-100).
9152        target[target == 1] = -100
9153        self.run_test(NLLModel(), (input, target))
9154
9155    @skipIfUnsupportedMinOpsetVersion(12)
9156    def test_nllloss_2d_none(self):
9157        class NLLModel(torch.nn.Module):
9158            def __init__(self) -> None:
9159                super().__init__()
9160                self.loss = torch.nn.NLLLoss(reduction="none")
9161                self.conv = torch.nn.Conv2d(16, C, (3, 3))
9162                self.m = torch.nn.LogSoftmax(dim=1)
9163
9164            def forward(self, input, target):
9165                output = self.loss(self.m(self.conv(input)), target)
9166                return output
9167
9168        N, C = 5, 4
9169        input = torch.randn(N, 16, 10, 10)
9170        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
9171
9172        # Use test data containing the default ignore_index (-100).
9173        target[target == 1] = -100
9174        self.run_test(NLLModel(), (input, target))
9175
9176    @skipIfUnsupportedMinOpsetVersion(12)
9177    def test_nllloss_2d_mean(self):
9178        class NLLModel(torch.nn.Module):
9179            def __init__(self) -> None:
9180                super().__init__()
9181                self.loss = torch.nn.NLLLoss(reduction="mean")
9182                self.conv = torch.nn.Conv2d(16, C, (3, 3))
9183                self.m = torch.nn.LogSoftmax(dim=1)
9184
9185            def forward(self, input, target):
9186                output = self.loss(self.m(self.conv(input)), target)
9187                return output
9188
9189        N, C = 5, 4
9190        input = torch.randn(N, 16, 10, 10)
9191        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
9192
9193        # Use test data containing the default ignore_index (-100).
9194        target[target == 1] = -100
9195        self.run_test(NLLModel(), (input, target))
9196
9197    @skipIfUnsupportedMinOpsetVersion(12)
9198    def test_nllloss_2d_sum(self):
9199        class NLLModel(torch.nn.Module):
9200            def __init__(self) -> None:
9201                super().__init__()
9202                self.loss = torch.nn.NLLLoss(reduction="sum")
9203                self.conv = torch.nn.Conv2d(16, C, (3, 3))
9204                self.m = torch.nn.LogSoftmax(dim=1)
9205
9206            def forward(self, input, target):
9207                output = self.loss(self.m(self.conv(input)), target)
9208                return output
9209
9210        N, C = 5, 4
9211        input = torch.randn(N, 16, 10, 10)
9212        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
9213
9214        # Use test data containing the default ignore_index (-100).
9215        target[target == 1] = -100
9216        self.run_test(NLLModel(), (input, target))
9217
9218    @skipIfUnsupportedMinOpsetVersion(12)
9219    def test_nllloss_2d_mean_weights(self):
9220        class NLLModel(torch.nn.Module):
9221            def __init__(self) -> None:
9222                super().__init__()
9223                self.loss = torch.nn.NLLLoss(reduction="mean", weight=torch.randn(C))
9224                self.conv = torch.nn.Conv2d(16, C, (3, 3))
9225                self.m = torch.nn.LogSoftmax(dim=1)
9226
9227            def forward(self, input, target):
9228                output = self.loss(self.m(self.conv(input)), target)
9229                return output
9230
9231        N, C = 5, 4
9232        input = torch.randn(N, 16, 10, 10)
9233        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
9234
9235        # Use test data containing the default ignore_index (-100).
9236        target[target == 1] = -100
9237        self.run_test(NLLModel(), (input, target))
9238
9239    @skipIfUnsupportedMinOpsetVersion(12)
9240    def test_nllloss_2d_mean_ignore_index(self):
9241        class NLLModel(torch.nn.Module):
9242            def __init__(self) -> None:
9243                super().__init__()
9244                self.loss = torch.nn.NLLLoss(reduction="mean", ignore_index=1)
9245                self.conv = torch.nn.Conv2d(16, C, (3, 3))
9246                self.m = torch.nn.LogSoftmax(dim=1)
9247
9248            def forward(self, input, target):
9249                output = self.loss(self.m(self.conv(input)), target)
9250                return output
9251
9252        N, C = 5, 4
9253        input = torch.randn(N, 16, 10, 10)
9254        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
9255        self.run_test(NLLModel(), (input, target))
9256
9257    @skipIfUnsupportedMinOpsetVersion(12)
9258    def test_nllloss_dynamic_ignore_index(self):
9259        import torch.nn.functional as F
9260
9261        def linear_combination(x, y, epsilon):
9262            return epsilon * x + (1 - epsilon) * y
9263
9264        def reduce_loss(loss, reduction="mean"):
9265            return (
9266                loss.mean()
9267                if reduction == "mean"
9268                else loss.sum()
9269                if reduction == "sum"
9270                else loss
9271            )
9272
9273        class LabelSmoothingCrossEntropy(torch.nn.Module):
9274            def __init__(self, epsilon: float = 0.1, reduction="mean"):
9275                super().__init__()
9276                self.epsilon = epsilon
9277                self.reduction = reduction
9278
9279            def forward(self, preds, target, start_position):
9280                n = preds.size()[-1]
9281                log_preds = F.log_softmax(preds, dim=-1)
9282                ignore_index = start_position.size(1)
9283                nll = F.nll_loss(
9284                    log_preds,
9285                    target,
9286                    reduction=self.reduction,
9287                    ignore_index=ignore_index,
9288                )
9289                return nll + start_position.float()
9290
9291        N = 5
9292        preds = torch.randn(N, 16)
9293        target = torch.randint(5, (N,))
9294        start_position = torch.randint(10, (N, N))
9295        self.run_test(LabelSmoothingCrossEntropy(), (preds, target, start_position))
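        # Unlike the fixed ignore_index tests above, the index here is derived from
        # an input shape (start_position.size(1)), so it cannot be baked into the
        # export as a compile-time constant. Illustrative value at trace time:
        #   torch.randint(10, (5, 5)).size(1)  # 5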
9296
9297    @skipIfUnsupportedMinOpsetVersion(12)
9298    def test_nllloss_2d_mean_ignore_index_weights(self):
9299        class NLLModel(torch.nn.Module):
9300            def __init__(self) -> None:
9301                super().__init__()
9302                self.loss = torch.nn.NLLLoss(
9303                    reduction="mean", weight=torch.randn(C), ignore_index=1
9304                )
9305                self.conv = torch.nn.Conv2d(16, C, (3, 3))
9306                self.m = torch.nn.LogSoftmax(dim=1)
9307
9308            def forward(self, input, target):
9309                output = self.loss(self.m(self.conv(input)), target)
9310                return output
9311
9312        N, C = 5, 4
9313        input = torch.randn(N, 16, 10, 10)
9314        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
9315        self.run_test(NLLModel(), (input, target))
9316
9317    @skipIfUnsupportedMinOpsetVersion(12)
9318    def test_binary_cross_entropy_with_logits(self):
9319        x = torch.randn(5)
9320        y = torch.empty(5).random_(2)
9321        self._bce_logits(x, y)
9322
9323        x = torch.randn(3, 4)
9324        y = torch.empty(3, 4).random_(2)
9325        weight = torch.tensor([3])
9326        self._bce_logits_weight(x, y, weight)
9327
9328        x = torch.randn(3, 2, 4)
9329        y = torch.empty(3, 2, 4).random_(2)
9330        pos_weight = torch.empty([2, 4]).random_(2)
9331        self._bce_logits_posweight(x, y, pos_weight)
9332
9333        x = torch.randn(3, 3, 4)
9334        y = torch.empty(3, 3, 4).random_(2)
9335        weight = torch.tensor([3])
9336        pos_weight = torch.empty([3, 4]).random_(2)
9337        self._bce_logits_loss_weight_posweight(x, y, weight, pos_weight)
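        # weight rescales every element's loss, while pos_weight rescales only the
        # positive-class term (broadcast over the class dims). Illustrative no-op:
        #   torch.nn.functional.binary_cross_entropy_with_logits(
        #       x, y, pos_weight=torch.ones_like(y))  # equals the unweighted loss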
9338
9339    def _bce_logits(self, x, y):
9340        class BCEWithLogitsLossNone(torch.nn.Module):
9341            def forward(self, input, target):
9342                return torch.nn.functional.binary_cross_entropy_with_logits(
9343                    input, target, reduction="none"
9344                )
9345
9346        self.run_test(BCEWithLogitsLossNone(), input_args=(x, y))
9347
9348        class BCEWithLogitsLossMean(torch.nn.Module):
9349            def forward(self, input, target):
9350                return torch.nn.functional.binary_cross_entropy_with_logits(
9351                    input, target, reduction="mean"
9352                )
9353
9354        self.run_test(BCEWithLogitsLossMean(), input_args=(x, y))
9355
9356        class BCEWithLogitsLossSum(torch.nn.Module):
9357            def forward(self, input, target):
9358                return torch.nn.functional.binary_cross_entropy_with_logits(
9359                    input, target, reduction="sum"
9360                )
9361
9362        self.run_test(BCEWithLogitsLossSum(), input_args=(x, y))
9363
9364    def _bce_logits_weight(self, x, y, weight):
9365        class BCEWithLogitsLossWeightNone(torch.nn.Module):
9366            def forward(self, input, target, weight):
9367                return torch.nn.functional.binary_cross_entropy_with_logits(
9368                    input, target, weight=weight, reduction="none"
9369                )
9370
9371        self.run_test(BCEWithLogitsLossWeightNone(), input_args=(x, y, weight))
9372
9373        class BCEWithLogitsLossWeightMean(torch.nn.Module):
9374            def forward(self, input, target, weight):
9375                return torch.nn.functional.binary_cross_entropy_with_logits(
9376                    input, target, weight=weight, reduction="mean"
9377                )
9378
9379        self.run_test(BCEWithLogitsLossWeightMean(), input_args=(x, y, weight))
9380
9381        class BCEWithLogitsLossWeightSum(torch.nn.Module):
9382            def forward(self, input, target, weight):
9383                return torch.nn.functional.binary_cross_entropy_with_logits(
9384                    input, target, weight=weight, reduction="sum"
9385                )
9386
9387        self.run_test(BCEWithLogitsLossWeightSum(), input_args=(x, y, weight))
9388
9389    def _bce_logits_posweight(self, x, y, pos_weight):
9390        class BCEWithLogitsLossPosWeightNone(torch.nn.Module):
9391            def forward(self, input, target, pos_weight):
9392                return torch.nn.functional.binary_cross_entropy_with_logits(
9393                    input, target, pos_weight=pos_weight, reduction="none"
9394                )
9395
9396        self.run_test(BCEWithLogitsLossPosWeightNone(), input_args=(x, y, pos_weight))
9397
9398        class BCEWithLogitsLossPosWeightMean(torch.nn.Module):
9399            def forward(self, input, target, pos_weight):
9400                return torch.nn.functional.binary_cross_entropy_with_logits(
9401                    input, target, pos_weight=pos_weight, reduction="mean"
9402                )
9403
9404        self.run_test(BCEWithLogitsLossPosWeightMean(), input_args=(x, y, pos_weight))
9405
9406        class BCEWithLogitsLossPosWeightSum(torch.nn.Module):
9407            def forward(self, input, target, pos_weight):
9408                return torch.nn.functional.binary_cross_entropy_with_logits(
9409                    input, target, pos_weight=pos_weight, reduction="sum"
9410                )
9411
9412        self.run_test(BCEWithLogitsLossPosWeightSum(), input_args=(x, y, pos_weight))
9413
9414    def _bce_logits_loss_weight_posweight(self, x, y, weight, pos_weight):
9415        class BCEWithLogitsLossWeightPosweightNone(torch.nn.Module):
9416            def forward(self, input, target, weight, pos_weight):
9417                return torch.nn.functional.binary_cross_entropy_with_logits(
9418                    input,
9419                    target,
9420                    weight=weight,
9421                    pos_weight=pos_weight,
9422                    reduction="none",
9423                )
9424
9425        self.run_test(
9426            BCEWithLogitsLossWeightPosweightNone(),
9427            input_args=(x, y, weight, pos_weight),
9428        )
9429
9430        class BCEWithLogitsLossWeightPosweightMean(torch.nn.Module):
9431            def forward(self, input, target, weight, pos_weight):
9432                return torch.nn.functional.binary_cross_entropy_with_logits(
9433                    input,
9434                    target,
9435                    weight=weight,
9436                    pos_weight=pos_weight,
9437                    reduction="mean",
9438                )
9439
9440        self.run_test(
9441            BCEWithLogitsLossWeightPosweightMean(),
9442            input_args=(x, y, weight, pos_weight),
9443        )
9444
9445        class BCEWithLogitsLossWeightPosweightSum(torch.nn.Module):
9446            def forward(self, input, target, weight, pos_weight):
9447                return torch.nn.functional.binary_cross_entropy_with_logits(
9448                    input, target, weight=weight, pos_weight=pos_weight, reduction="sum"
9449                )
9450
9451        self.run_test(
9452            BCEWithLogitsLossWeightPosweightSum(), input_args=(x, y, weight, pos_weight)
9453        )
9454
9455    def test_torch_mm(self):
9456        class M(torch.nn.Module):
9457            def forward(self, mat1, mat2):
9458                mm = torch.mm(mat1, mat2)
9459                return mm
9460
9461        mat1 = torch.randn(2, 3)
9462        mat2 = torch.randn(3, 3)
9463        self.run_test(M(), input_args=(mat1, mat2))
9464
9465    @skipIfUnsupportedMinOpsetVersion(
9466        9
9467    )  # Because where op is not supported for opset < 9.
9468    def test_where_with_bool_tensor(self):
9469        class M(torch.nn.Module):
9470            def forward(self, mat1, mat2):
9471                out = torch.where(mat1 > 0, mat1, mat2)
9472                return out
9473
9474        mat1 = torch.randn(2, 3)
9475        mat2 = torch.ones(2, 3)
9476        self.run_test(M(), input_args=(mat1, mat2))
9477
9478    @skipIfUnsupportedMinOpsetVersion(
9479        9
9480    )  # Because where op is not supported for opset < 9.
9481    def test_where_with_byte_tensor(self):
9482        class M(torch.nn.Module):
9483            def forward(self, cond, mat1, mat2):
9484                out = torch.where(cond, mat1, mat2)
9485                return out
9486
9487        cond = torch.ones(2, 3, dtype=torch.uint8)
9488        cond[1, 2] = 0
9489        mat1 = torch.randn(2, 3)
9490        mat2 = torch.ones(2, 3)
9491        self.run_test(M(), input_args=(cond, mat1, mat2))
9492
9493    @skipIfUnsupportedMinOpsetVersion(10)  # ONNX IsInf op is added in opset 10.
9494    def test_isinf(self):
9495        class M(torch.nn.Module):
9496            def forward(self, x):
9497                return x.isinf()
9498
9499        x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), float("inf")]])
9500        self.run_test(M(), (x,))
9501
9502    @skipIfUnsupportedMinOpsetVersion(10)
9503    def test_isfinite(self):
9504        class M(torch.nn.Module):
9505            def forward(self, x):
9506                return x.isfinite()
9507
9508        x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), -float("inf")]])
9509        self.run_test(M(), (x,))
9510
9511    @skipIfUnsupportedMinOpsetVersion(9)  # ONNX IsNaN op is added in opset 9.
9512    def test_isnan(self):
9513        class M(torch.nn.Module):
9514            def forward(self, x):
9515                return x.isnan()
9516
9517        x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), float("inf")]])
9518        self.run_test(M(), (x,))
9519
9520    @skipIfUnsupportedMinOpsetVersion(
9521        10
9522    )  # ONNX IsNaN, IsInf op is added in opset 9, 10 respectively.
9523    def test_nan_to_num(self):
9524        class NoParams(torch.nn.Module):
9525            def forward(self, x):
9526                return x.nan_to_num()
9527
9528        x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), -float("inf")]])
9529        xint = torch.ones((2, 4), dtype=torch.int)
9530        xhalf = torch.ones((2, 4), dtype=torch.half)
9531        self.run_test(NoParams(), (x,))
9532        self.run_test(NoParams(), (xint,))
9533        self.run_test(NoParams(), (xhalf,))
9534
9535        class WithParams(torch.nn.Module):
9536            def forward(self, x):
9537                return x.nan_to_num(nan=2.3, posinf=4.5, neginf=6.7)
9538
9539        x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), -float("inf")]])
9540        self.run_test(WithParams(), (x,))
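        # nan_to_num replaces NaN, +inf, and -inf; with no arguments it substitutes
        # 0.0 and the dtype's finite extremes. Illustrative:
        #   torch.tensor([float("nan")]).nan_to_num()            # tensor([0.])
        #   torch.tensor([float("inf")]).nan_to_num(posinf=1.0)  # tensor([1.])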
9541
9542    @skipIfUnsupportedMinOpsetVersion(9)
9543    def test_maximum_minimum(self):
9544        class ModelWithNan(torch.nn.Module):
9545            def forward(self, x, y):
9546                return torch.maximum(x, y), torch.minimum(x, y)
9547
9548        x = torch.tensor([-2, -2, float("nan")])
9549        y = torch.rand(1, 3)
9550        self.run_test(ModelWithNan(), (x, y))
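        # torch.maximum/minimum propagate NaNs (a NaN in either input yields NaN),
        # so the NaN entry above checks ONNX Runtime against PyTorch's semantics:
        #   torch.maximum(torch.tensor([float("nan")]), torch.tensor([1.0]))  # tensor([nan])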
9551
9552    @skipIfUnsupportedMinOpsetVersion(12)
9553    def test_minimum_dtypes(self):
9554        class MinimumModel(torch.nn.Module):
9555            def forward(self, x, y):
9556                return torch.minimum(x, y)
9557
9558        x = torch.randn((5, 5), dtype=torch.float16)
9559        y = torch.randn((5, 5), dtype=torch.float)
9560        self.run_test(MinimumModel(), (x, y))
9561
9562        x = torch.randn((5, 5), dtype=torch.float16)
9563        y = torch.randint(10, (5, 5), dtype=torch.int16)
9564        self.run_test(MinimumModel(), (x, y))
9565
9566        x = torch.randint(10, (5, 5), dtype=torch.int16)
9567        y = torch.randint(10, (5, 5), dtype=torch.int32)
9568        self.run_test(MinimumModel(), (x, y))
9569
9570        x = torch.randint(10, (5, 5), dtype=torch.int)
9571        y = torch.full_like(x, True)
9572        self.run_test(MinimumModel(), (x, y))
9573
9574    @skipIfUnsupportedMinOpsetVersion(12)
9575    def test_maximum_dtypes(self):
9576        class MaximumModel(torch.nn.Module):
9577            def forward(self, x, y):
9578                return torch.maximum(x, y)
9579
9580        x = torch.randn((5, 5), dtype=torch.float16)
9581        y = torch.randn((5, 5), dtype=torch.float)
9582        self.run_test(MaximumModel(), (x, y))
9583
9584        x = torch.randn((5, 5), dtype=torch.float16)
9585        y = torch.randint(10, (5, 5), dtype=torch.int16)
9586        self.run_test(MaximumModel(), (x, y))
9587
9588        x = torch.randint(10, (5, 5), dtype=torch.int16)
9589        y = torch.randint(10, (5, 5), dtype=torch.int32)
9590        self.run_test(MaximumModel(), (x, y))
9591
9592        x = torch.randint(10, (5, 5), dtype=torch.int)
9593        y = torch.full_like(x, True)
9594        self.run_test(MaximumModel(), (x, y))
9595
9596    @skipIfUnsupportedMinOpsetVersion(9)
9597    def test_any(self):
9598        class M(torch.nn.Module):
9599            def forward(self, x):
9600                return x.any()
9601
9602        x = torch.tensor([[True, False], [False, False]])
9603        self.run_test(M(), (x,))
9604
9605        class MDim(torch.nn.Module):
9606            def forward(self, x):
9607                return x.any(dim=1)
9608
9609        x = torch.rand(3, 4).bool()
9610        self.run_test(MDim(), (x,))
9611
9612        class MKeepdim(torch.nn.Module):
9613            def forward(self, x):
9614                return x.any(dim=1, keepdim=True)
9615
9616        x = torch.rand(3, 4).bool()
9617        self.run_test(MKeepdim(), (x,))
9618
9619    @skipIfUnsupportedMinOpsetVersion(9)
9620    def test_all(self):
9621        class M(torch.nn.Module):
9622            def forward(self, x):
9623                return x.all()
9624
9625        x = torch.tensor([[True, False], [False, False]])
9626        self.run_test(M(), (x,))
9627
9628        class MDim(torch.nn.Module):
9629            def forward(self, x):
9630                return x.all(dim=1)
9631
9632        x = torch.rand(3, 4).bool()
9633        self.run_test(MDim(), (x,))
9634
9635        class MKeepdim(torch.nn.Module):
9636            def forward(self, x):
9637                return x.all(dim=1, keepdim=True)
9638
9639        x = torch.rand(3, 4).bool()
9640        self.run_test(MKeepdim(), (x,))
9641
9642    def test_dropout(self):
9643        class M(torch.nn.Module):
9644            def __init__(self) -> None:
9645                super().__init__()
9646                self.dropout = torch.nn.Dropout(0.3)
9647
9648            def forward(self, x):
9649                dropout = self.dropout(x)
9650                return dropout
9651
9652        x = torch.randn(10, 3, 53)
9653        self.run_test(M(), (x,))
9654
9655    def test_rrelu_eval(self):
9656        x = torch.tensor([0.5, -0.5])
9657        self.run_test(torch.nn.RReLU(0.1, 0.3).eval(), x)
9658
9659    def test_shape_constant_fold(self):
9660        class ShapeModule(torch.nn.Module):
9661            def __init__(self) -> None:
9662                super().__init__()
9663                self.weight = torch.nn.Buffer(torch.ones(5))
9664
9665            def forward(self, x):
9666                shape = self.weight.shape[0]
9667                return x + shape
9668
9669        x = torch.randn(2, 5)
9670        self.run_test(ShapeModule(), (x,), rtol=1e-3, atol=1e-5)
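        # The buffer's shape is fixed at export time, so self.weight.shape[0] should
        # constant-fold instead of emitting Shape/Gather nodes. Illustrative:
        #   torch.ones(5).shape[0]  # 5, already known before the graph runs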
9671
9672    @skipIfUnsupportedMinOpsetVersion(12)
9673    def test_celu(self):
9674        class Celu(torch.nn.Module):
9675            def __init__(self) -> None:
9676                super().__init__()
9677                self.celu = torch.nn.CELU(alpha=1.0)
9678
9679            def forward(self, input):
9680                return self.celu(input)
9681
9682        input = torch.randn(2)
9683        self.run_test(Celu(), (input,))
9684
9685    @skipIfUnsupportedMinOpsetVersion(12)
9686    def test_celu_default(self):
9687        class Celu(torch.nn.Module):
9688            def __init__(self) -> None:
9689                super().__init__()
9690                self.celu = torch.nn.CELU()
9691
9692            def forward(self, input):
9693                return self.celu(input)
9694
9695        input = torch.randn(2)
9696        self.run_test(Celu(), (input,))
9697
9698    @skipIfUnsupportedMinOpsetVersion(12)
9699    def test_celu_alpha(self):
9700        class Celu(torch.nn.Module):
9701            def __init__(self) -> None:
9702                super().__init__()
9703                self.celu = torch.nn.CELU(alpha=2.0)
9704
9705            def forward(self, input):
9706                return self.celu(input)
9707
9708        input = torch.randn(2)
9709        self.run_test(Celu(), (input,))
9710
9711    @skipIfUnsupportedMinOpsetVersion(12)
9712    def test_celu_cast(self):
9713        class Celu(torch.nn.Module):
9714            def __init__(self) -> None:
9715                super().__init__()
9716                self.celu = torch.nn.CELU()
9717
9718            def forward(self, input):
9719                return self.celu(input)
9720
9721        input = torch.randn(2, 5, 7, dtype=torch.float64)
9722        self.run_test(Celu(), (input,))
9723
9724    def test_lower_tuple(self):
9725        class TupleModule(torch.nn.Module):
9726            def forward(self, input1: Tensor, input2: Tensor, input3: Tensor) -> Tensor:
9727                a = (input1, input2)
9728                b = a
9729                c = (input1, input2, input3)
9730                for i in range(5):
9731                    d = a[0]
9732                    for j in range(2):
9733                        e, f = a
9734                        a = (d, f)
9735                        f = c[2]
9736                        if f.size(0) != input1.size(-1):
9737                            g = b[1]
9738                            b = (g, f)
9739                        else:
9740                            k = c[1:]
9741                            b = (f, k[0])
9742                    m, n = b
9743                    c = (input1, n, m)
9744                p, q, r = c
9745                return p + q + r
9746
9747        input1 = torch.randn(2)
9748        input2 = torch.randn(2)
9749        input3 = torch.randn(2)
9750        self.run_test(TupleModule(), (input1, input2, input3))
9751
9752    def test_lower_tuple_2(self):
9753        class TupleModule(torch.nn.Module):
9754            def forward(self, input1: Tensor, input2: Tensor) -> Tuple[Tensor, Tensor]:
9755                a = (input1, input2)
9756                for x in range(5):
9757                    c, d = a
9758                    a = (c, d)
9759                return a
9760
9761        input1 = torch.randn(2)
9762        input2 = torch.randn(2)
9763        self.run_test(TupleModule(), (input1, input2))
9764
9765    def test_lower_tuple_3(self):
9766        class TupleModule(torch.nn.Module):
9767            def forward(
9768                self,
9769                input1: Tuple[Tensor, Tensor],
9770                input2: Tuple[Tensor, Tensor],
9771            ) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
9772                a = input1
9773                b = input2
9774                for x in range(5):
9775                    c, d = a
9776                    e, f = b
9777                    if c.shape[0] == e.shape[0]:
9778                        e = e + c
9779                    else:
9780                        f = f + d
9781                    a = (e, f)
9782                    b = (c, d)
9783                return a, b
9784
9785        input1 = (torch.randn(2), torch.randn(2))
9786        input2 = (torch.randn(2), torch.randn(2))
9787        self.run_test(TupleModule(), (input1, input2))
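        # The test_lower_tuple variants exercise TorchScript tuple lowering: ONNX
        # graphs have no tuple type, so tuples built and unpacked across loops and
        # branches must be flattened into their elements before export, e.g.:
        #   a = (c, d); e, f = a  # becomes direct uses of c and d after lowering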
9788
9789    @skipIfUnsupportedMinOpsetVersion(9)
9790    def test_where(self):
9791        class Model(torch.nn.Module):
9792            def forward(self, cond, input, other):
9793                return torch.where(cond, input, other)
9794
9795        x = torch.randint(0, 1, (2, 3, 4), dtype=torch.bool)
9796        y = torch.randn(2, 1, 4)
9797        z = torch.ones(2, 3, 1)
9798        self.run_test(Model(), (x, y, z))
9799
9800    @skipIfUnsupportedMinOpsetVersion(9)
9801    @skipScriptTest()  # Scripting is covered for opsets >= 11 in test_where_condition_script.
9802    def test_where_condition(self):
9803        class Model1(torch.nn.Module):
9804            def forward(self, input):
9805                return torch.stack(torch.where(input > 0.5), dim=1)
9806
9807        x = torch.randint(0, 2, (2, 3, 4), dtype=bool)
9808        self.run_test(Model1(), (x))
9809        self.run_test(Model1(), (x,))
9810        class Model2(torch.nn.Module):
9811            def forward(self, input, other):
9812                return torch.stack(torch.where(input > other), dim=1)
9813
9814        x = torch.randint(0, 1, (2, 3, 4), dtype=bool)
9815        y = torch.randint(1, 2, (2, 3, 4), dtype=bool)
9816        self.run_test(Model2(), (x, y))
9817
9818    @skipIfUnsupportedOpsetVersion([13])
9819    @skipIfUnsupportedMinOpsetVersion(11)
9820    def test_where_condition_script(self):
9821        class Model1(torch.nn.Module):
9822            def forward(self, input):
9823                return torch.stack(torch.where(input > 0.5), dim=1)
9824
9825        x = torch.randint(0, 2, (2, 3, 4), dtype=bool)
9826        self.run_test(Model1(), (x))
9827        self.run_test(Model1(), (x,))
9828        class Model2(torch.nn.Module):
9829            def forward(self, input, other):
9830                return torch.stack(torch.where(input > other), dim=1)
9831
9832        x = torch.randint(0, 1, (2, 3, 4), dtype=bool)
9833        y = torch.randint(1, 2, (2, 3, 4), dtype=bool)
9834        self.run_test(Model2(), (x, y))
9835
9836    def test_empty_branch(self):
9837        class EmptyBranchModel(torch.jit.ScriptModule):
9838            @torch.jit.script_method
9839            def forward(self, input):
9840                out = input + 1
9841                if out.dim() > 2:
9842                    if out.dim() > 3:
9843                        out += 3
9844                    else:
9845                        pass
9846                else:
9847                    pass
9848                return out
9849
9850        x = torch.randn(1, 2, 3, requires_grad=True)
9851        self.run_test(EmptyBranchModel(), x)
9852
9853    @skipIfUnsupportedMinOpsetVersion(11)
9854    def test_derive_index_scripting(self):
9855        class MyModule(torch.nn.Module):
9856            def forward(self, x: Tensor):
9857                j = []
9858                for idx in range(len(x) - 1, -len(x), -2):
9859                    y = x[idx]
9860                    j += [x * y]
9861                return j
9862
9863        x = torch.randn(5, 13)
9864        self.run_test(MyModule(), x)
9865
9866        class MyModule(torch.nn.Module):
9867            def forward(self, x: Tensor):
9868                j = []
9869                for idx in range(-len(x), len(x) - 1, 2):
9870                    y = x[idx]
9871                    j += [x * y]
9872                return j
9873
9874        x = torch.randn(5, 13)
9875        self.run_test(MyModule(), x)
9876
9877        class MyModule(torch.nn.Module):
9878            def forward(self, x: Tensor):
9879                j = []
9880                for idx in range(len(x) - 1, -len(x), -3):
9881                    y = x[idx]
9882                    j += [x * y]
9883                return j
9884
9885        self.run_test(MyModule(), x)
9886
9887        class MyModule(torch.nn.Module):
9888            def forward(self, x: Tensor):
9889                j = []
9890                for idx in range(-len(x), len(x) - 1, 3):
9891                    y = x[idx]
9892                    j += [x * y]
9893                return j
9894
9895        self.run_test(MyModule(), x)
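        # The negative bounds and strides exercise the JIT's index derivation for
        # strided loops. For len(x) == 5 (illustrative):
        #   list(range(4, -5, -2))  # [4, 2, 0, -2, -4]
        #   list(range(-5, 4, 3))   # [-5, -2, 1]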
9896
9897    @skipScriptTest()  # Scripting fails for list concatenation for opsets < 11. Check test_derive_index_scripting.
9898    def test_derive_index(self):
9899        class MyModule(torch.nn.Module):
9900            def forward(self, x: Tensor):
9901                j = []
9902                for idx in range(len(x) - 1, -len(x), -2):
9903                    y = x[idx]
9904                    j += [x * y]
9905                return j
9906
9907        x = torch.randn(5, 13)
9908        self.run_test(MyModule(), x)
9909
9910        class MyModule(torch.nn.Module):
9911            def forward(self, x: Tensor):
9912                j = []
9913                for idx in range(-len(x), len(x) - 1, 2):
9914                    y = x[idx]
9915                    j += [x * y]
9916                return j
9917
9918        x = torch.randn(5, 13)
9919        self.run_test(MyModule(), x)
9920
9921        class MyModule(torch.nn.Module):
9922            def forward(self, x: Tensor):
9923                j = []
9924                for idx in range(len(x) - 1, -len(x), -3):
9925                    y = x[idx]
9926                    j += [x * y]
9927                return j
9928
9929        self.run_test(MyModule(), x)
9930
9931        class MyModule(torch.nn.Module):
9932            def forward(self, x: Tensor):
9933                j = []
9934                for idx in range(-len(x), len(x) - 1, 3):
9935                    y = x[idx]
9936                    j += [x * y]
9937                return j
9938
9939        self.run_test(MyModule(), x)
9940
9941    @skipIfUnsupportedMinOpsetVersion(11)
9942    def test_if_transpose(self):
9943        class IfModel(torch.nn.Module):
9944            def forward(self, x):
9945                x = x.transpose(0, 1)
9946                if x.size(0) == 2:
9947                    return x.transpose(0, 1)
9948                else:
9949                    return x
9950
9951        x = torch.randn(2, 3)
9952        self.run_test(
9953            torch.jit.script(IfModel()),
9954            x,
9955            output_names=["output_1"],
9956            dynamic_axes={"output_1": [0, 1]},
9957        )
9958
9959    @skipIfUnsupportedMinOpsetVersion(13)
9960    def test_if_list(self):
9961        class IfModel(torch.nn.Module):
9962            def forward(self, x, y, cond):
9963                res = []
9964                if cond:
9965                    res = res + [x]
9966                else:
9967                    res = res + [y]
9968                return res
9969
9970        x = torch.randn(2, 3)
9971        y = torch.randn(3, 3)
9972        cond = torch.tensor(1, dtype=torch.bool)
9973        self.run_test(torch.jit.script(IfModel()), (x, y, cond))
9974
9975    @skipIfUnsupportedMinOpsetVersion(13)
9976    def test_if_view(self):
9977        class IfModel(torch.nn.Module):
9978            def forward(self, x, y, cond):
9979                bs, seq = y.shape[:2]
9980                if cond:
9981                    res = x.view(bs, seq, -1)
9982                else:
9983                    res = y
9984                return res.transpose(1, 2)
9985
9986        x = torch.randn(2, 16, 2, 2)
9987        y = torch.randn(2, 16, 8)
9988        cond = torch.tensor(1, dtype=torch.bool)
9989        self.run_test(
9990            torch.jit.script(IfModel()),
9991            (x, y, cond),
9992            output_names=["output_1"],
9993            dynamic_axes={"output_1": [1]},
9994        )
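        # Both branches of a scripted `if` must produce the same number and type of
        # outputs for ONNX's If node; here each branch yields one rank-3 tensor but
        # with different sizes along dim 1 ((2, 2, 16) vs. (2, 8, 16)), hence the
        # dynamic axis on output_1.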
9995
9996    @skipScriptTest(
9997        skip_before_opset_version=11, reason="dynamic split support added in 11"
9998    )
9999    def test_split_tensor_scalar(self):
10000        class SplitModel(torch.nn.Module):
10001            def forward(self, x):
10002                return torch.split(x, x.size(1))
10003
10004        x = torch.randn(1, 2, 3, requires_grad=True)
10005        self.run_test(SplitModel(), x)
10006
10007    def test_split_tensor_multi(self):
10008        class SplitModel(torch.nn.Module):
10009            def forward(self, x):
10010                return torch.split(x, torch.ones(3))
10011
10012        x = torch.randn(1, 2, 3, requires_grad=True)
10013
10014        def run_model():
10015            SplitModel()(x)  # torch.split rejects a Tensor of section sizes
10016
10017        self.assertRaises(TypeError, run_model)
10018
10019    @skipIfUnsupportedMinOpsetVersion(9)
10020    def test_embedding(self):
10021        class EmbedModel(torch.nn.Module):
10022            def forward(self, input, emb):
10023                return torch.nn.functional.embedding(input, emb, padding_idx=1)
10024
10025        model = EmbedModel()
10026        x = torch.randint(4, (4,))
10027        x[2] = x[0] = 1
10028        embedding_matrix = torch.rand(10, 3)
10029        self.run_test(model, (x, embedding_matrix))
10030
10031        x = torch.randint(4, (4, 3, 2))
10032        x[2] = 1
10033        x[0][1] = 1
10034        self.run_test(model, (x, embedding_matrix))
10035        self.run_test(
10036            model, (x, embedding_matrix), training=torch.onnx.TrainingMode.TRAINING
10037        )
10038
10039        class EmbedModelWithoutPaddingIdx(torch.nn.Module):
10040            def forward(self, input, emb):
10041                return torch.nn.functional.embedding(input, emb)
10042
10043        model = EmbedModelWithoutPaddingIdx()
10044        x = torch.randint(4, (4, 3, 2))
10045        self.run_test(model, (x, embedding_matrix))
10046
10047    @skipIfUnsupportedMinOpsetVersion(9)
10048    def test_embedding_module(self):
10049        class EmbedModel(torch.nn.Module):
10050            def __init__(self) -> None:
10051                super().__init__()
10052                self.emb = torch.nn.Embedding(4, 3, padding_idx=1)
10053                self.emb2 = torch.nn.Embedding(4, 3, padding_idx=1)
10054                with torch.no_grad():
10055                    self.emb2.weight[1] = torch.ones(3)
10056
10057            def forward(self, input):
10058                return self.emb(input), self.emb2(input)
10059
10060        model = EmbedModel()
10061        x = torch.randint(4, (4,))
10062        x[2] = x[0] = 1
10063        self.run_test(model, (x,))
10064
10065        x = torch.randint(4, (4, 3, 2))
10066        x[2] = 1
10067        x[0][1] = 1
10068        self.run_test(model, (x,))
10069
10070        class EmbedModelWithoutPaddingIdx(torch.nn.Module):
10071            def __init__(self) -> None:
10072                super().__init__()
10073                self.emb = torch.nn.Embedding(4, 3)
10074
10075            def forward(self, input):
10076                return self.emb(input)
10077
10078        model = EmbedModelWithoutPaddingIdx()
10079        x = torch.randint(4, (4, 3, 2))
10080        self.run_test(model, (x,))
10081
10082    @skipIfUnsupportedMinOpsetVersion(11)
10083    def test_embedding_renorm(self):
10084        n, d = 7, 5
10085        embedding = torch.nn.Embedding(n, d, max_norm=0.2)
10086        idx = torch.tensor([2, 1])
10087        self.run_test(embedding, idx)
10088
10089        embedding = torch.nn.Embedding(n, d, max_norm=0.5, norm_type=1.0)
10090        idx = torch.tensor([4, 3, 4, 2])
10091        self.run_test(embedding, idx)
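        # max_norm renormalizes the selected rows in place when their p-norm
        # (norm_type, default 2.0) exceeds it. Illustrative:
        #   emb = torch.nn.Embedding(7, 5, max_norm=0.2)
        #   emb(torch.tensor([2])).norm(dim=1)  # every value is <= 0.2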
10092
10093    def _dispatch_rnn_test(self, name, *args, **kwargs):
10094        if name == "elman":
10095            self._elman_rnn_test(*args, **kwargs)
10096        if name == "lstm":
10097            self._lstm_test(*args, **kwargs)
10098        if name == "gru":
10099            self._gru_test(*args, **kwargs)
10100
10101    def _elman_rnn_test(
10102        self,
10103        layers,
10104        nonlinearity,
10105        bidirectional,
10106        initial_state,
10107        packed_sequence,
10108        dropout,
10109        **extra_kwargs,
10110    ):
10111        class ElmanWithStateModel(torch.nn.Module):
10112            def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first):
10113                super().__init__()
10114
10115                self.batch_first = batch_first
10116                self.inner_model = torch.nn.RNN(
10117                    RNN_INPUT_SIZE,
10118                    RNN_HIDDEN_SIZE,
10119                    layers,
10120                    nonlinearity=nonlinearity,
10121                    bidirectional=bidirect,
10122                    dropout=dropout,
10123                    batch_first=batch_first,
10124                )
10125
10126            def forward(self, input: rnn_utils.PackedSequence, hx=None):
10127                return self.inner_model(input, hx)
10128
10129        class ElmanWithoutStateModel(torch.nn.Module):
10130            def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first):
10131                super().__init__()
10132                self.batch_first = batch_first
10133                self.inner_model = torch.nn.RNN(
10134                    RNN_INPUT_SIZE,
10135                    RNN_HIDDEN_SIZE,
10136                    layers,
10137                    nonlinearity=nonlinearity,
10138                    bidirectional=bidirect,
10139                    dropout=dropout,
10140                    batch_first=batch_first,
10141                )
10142
10143            def forward(self, input: rnn_utils.PackedSequence):
10144                return self.inner_model(input)
10145
10146        batch_first = packed_sequence == 2
10147
10148        if initial_state:
10149            model = ElmanWithStateModel(
10150                layers=layers,
10151                bidirect=bidirectional,
10152                nonlinearity=nonlinearity,
10153                dropout=dropout,
10154                batch_first=batch_first,
10155            )
10156            if packed_sequence:
10157                model = (
10158                    rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState(
10159                        model, batch_first
10160                    )
10161                )
10162        else:
10163            model = ElmanWithoutStateModel(
10164                layers=layers,
10165                bidirect=bidirectional,
10166                nonlinearity=nonlinearity,
10167                dropout=dropout,
10168                batch_first=batch_first,
10169            )
10170            if packed_sequence:
10171                model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState(
10172                    model, batch_first
10173                )
10174
10175        def make_input(batch_size):
10176            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
10177            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
10178            inputs = [torch.randn(length, RNN_INPUT_SIZE) for length in seq_lengths]
10179            inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
10180            inputs = [inputs]
10181            input_names = ["input"]
10182
10183            directions = 2 if bidirectional else 1
10184
10185            if initial_state:
10186                h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
10187                inputs.append(h0)
10188                input_names.append("h0")
10189            if packed_sequence != 0:
10190                inputs.append(torch.IntTensor(seq_lengths))
10191                input_names.append("seq_lengths")
10192            if len(inputs) == 1:
10193                input = inputs[0]
10194            else:
10195                input = tuple(inputs)
10196            return input, input_names
10197
10198        input, input_names = make_input(RNN_BATCH_SIZE)
10199        dynamic_axes = {"input": [0, 1], "seq_lengths": [0]}
10200        if initial_state:
10201            dynamic_axes.update({"h0": [1]})
10202        export_options = {"input_names": input_names, "dynamic_axes": dynamic_axes}
10203
10204        # test that the model still runs with a different batch size
10205        other_input, _ = make_input(RNN_BATCH_SIZE + 1)
10206        self.run_test(
10207            model, input, additional_test_inputs=[other_input], **export_options
10208        )
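        # packed_sequence selects the input layout: 0 = padded tensor,
        # 1 = PackedSequence from a sequence-major tensor, and 2 = PackedSequence
        # from a batch-first tensor; hence `batch_first = packed_sequence == 2`.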
10209
10210    def _lstm_test(
10211        self,
10212        layers,
10213        bidirectional,
10214        initial_state,
10215        packed_sequence,
10216        dropout,
10217        **extra_kwargs,
10218    ):
10219        batch_first = packed_sequence == 2
10220
10221        if packed_sequence:
10222            model = lstm_flattening_result.LstmFlatteningResultWithSeqLength(
10223                RNN_INPUT_SIZE,
10224                RNN_HIDDEN_SIZE,
10225                layers,
10226                bidirectional,
10227                dropout,
10228                batch_first,
10229            )
10230            if initial_state:
10231                model = (
10232                    rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState(
10233                        model, batch_first
10234                    )
10235                )
10236            else:
10237                model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState(
10238                    model, batch_first
10239                )
10240        else:
10241            model = lstm_flattening_result.LstmFlatteningResultWithoutSeqLength(
10242                RNN_INPUT_SIZE,
10243                RNN_HIDDEN_SIZE,
10244                layers,
10245                bidirectional,
10246                dropout,
10247                batch_first,
10248            )
10249
10250        def make_input(batch_size):
10251            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
10252            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
10253            inputs = [torch.randn(length, RNN_INPUT_SIZE) for length in seq_lengths]
10254            inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
10255            inputs = [inputs]
10256            input_names = ["input"]
10257            directions = 2 if bidirectional else 1
10258
10259            if initial_state:
10260                h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
10261                c0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
10262                inputs.append((h0, c0))
10263                input_names.append("h0")
10264                input_names.append("c0")
10265            if packed_sequence != 0:
10266                inputs.append(torch.IntTensor(seq_lengths))
10267                input_names.append("seq_lengths")
10268            if len(inputs) == 1:
10269                input = inputs[0]
10270            else:
10271                input = tuple(inputs)
10272            return input, input_names
10273
10274        input, input_names = make_input(RNN_BATCH_SIZE)
10275        dynamic_axes = {"input": [0, 1], "seq_lengths": [0]}
10276        if initial_state:
10277            dynamic_axes.update({"h0": [1], "c0": [1]})
10278        export_options = {"input_names": input_names, "dynamic_axes": dynamic_axes}
10279
10280        # test that the model still runs with a different batch size
10281        other_input, _ = make_input(RNN_BATCH_SIZE + 1)
10282        self.run_test(
10283            model, input, additional_test_inputs=[other_input], **export_options
10284        )
10285
10286    def _gru_test(
10287        self,
10288        layers,
10289        bidirectional,
10290        initial_state,
10291        packed_sequence,
10292        dropout,
10293        **extra_kwargs,
10294    ):
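        """Export a GRU (optionally packed/bidirectional/stateful) and compare against ONNX Runtime.

        packed_sequence: 0 = padded tensor input, 1 = packed input (sequence-first),
        2 = packed input with batch_first=True.
        """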
10295        class GRUWithStateModel(torch.nn.Module):
10296            def __init__(self, layers, bidirect, dropout, batch_first):
10297                super().__init__()
10298
10299                self.batch_first = batch_first
10300                self.inner_model = torch.nn.GRU(
10301                    RNN_INPUT_SIZE,
10302                    RNN_HIDDEN_SIZE,
10303                    num_layers=layers,
                    bidirectional=bidirect,
10305                    dropout=dropout,
10306                    batch_first=batch_first,
10307                )
10308
10309            def forward(self, input: rnn_utils.PackedSequence, hx):
10310                return self.inner_model(input, hx)
10311
10312        class GRUWithoutStateModel(torch.nn.Module):
10313            def __init__(self, layers, bidirect, dropout, batch_first):
10314                super().__init__()
10315                self.batch_first = batch_first
10316                self.inner_model = torch.nn.GRU(
10317                    RNN_INPUT_SIZE,
10318                    RNN_HIDDEN_SIZE,
10319                    num_layers=layers,
                    bidirectional=bidirect,
10321                    dropout=dropout,
10322                    batch_first=batch_first,
10323                )
10324
10325            def forward(self, input: rnn_utils.PackedSequence):
10326                return self.inner_model(input)
10327
10328        class GRUNoSeqLengthWithoutStateModel(torch.nn.Module):
10329            def __init__(self, layers, bidirect, dropout, batch_first):
10330                super().__init__()
10331                self.batch_first = batch_first
10332                self.inner_model = torch.nn.GRU(
10333                    RNN_INPUT_SIZE,
10334                    RNN_HIDDEN_SIZE,
10335                    num_layers=layers,
                    bidirectional=bidirect,
10337                    dropout=dropout,
10338                    batch_first=batch_first,
10339                )
10340
10341            def forward(self, input):
10342                return self.inner_model(input)
10343
10344        class GRUNoSeqLengthWithStateModel(torch.nn.Module):
10345            def __init__(self, layers, bidirect, dropout, batch_first):
10346                super().__init__()
10347                self.batch_first = batch_first
10348                self.inner_model = torch.nn.GRU(
10349                    RNN_INPUT_SIZE,
10350                    RNN_HIDDEN_SIZE,
10351                    num_layers=layers,
                    bidirectional=bidirect,
10353                    dropout=dropout,
10354                    batch_first=batch_first,
10355                )
10356
10357            def forward(self, input, hx):
10358                return self.inner_model(input, hx)
10359
10360        batch_first = packed_sequence == 2
10361
10362        if packed_sequence:
10363            if initial_state:
10364                model = GRUWithStateModel(
10365                    layers=layers,
10366                    bidirect=bidirectional,
10367                    dropout=dropout,
10368                    batch_first=batch_first,
10369                )
10370                model = (
10371                    rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState(
10372                        model, batch_first
10373                    )
10374                )
10375            else:
10376                model = GRUWithoutStateModel(
10377                    layers=layers,
10378                    bidirect=bidirectional,
10379                    dropout=dropout,
10380                    batch_first=batch_first,
10381                )
10382                model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState(
10383                    model, batch_first
10384                )
10385        else:
10386            if initial_state:
10387                model = GRUNoSeqLengthWithStateModel(
10388                    layers=layers,
10389                    bidirect=bidirectional,
10390                    dropout=dropout,
10391                    batch_first=batch_first,
10392                )
10393            else:
10394                model = GRUNoSeqLengthWithoutStateModel(
10395                    layers=layers,
10396                    bidirect=bidirectional,
10397                    dropout=dropout,
10398                    batch_first=batch_first,
10399                )
10400
10401        def make_input(batch_size):
10402            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
10403            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
10404            inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
10405            inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
10406            inputs = [inputs]
10407            input_names = ["input"]
10408
10409            directions = 2 if bidirectional else 1
10410
10411            if initial_state:
10412                h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
10413                inputs.append(h0)
10414                input_names.append("h0")
10415            if packed_sequence != 0:
10416                inputs.append(torch.IntTensor(seq_lengths))
10417                input_names.append("seq_lengths")
10418            if len(inputs) == 1:
10419                input = inputs[0]
10420            else:
10421                input = tuple(inputs)
10422            return input, input_names
10423
10424        input, input_names = make_input(RNN_BATCH_SIZE)
10425        dynamic_axes = {"input": [0, 1], "seq_lengths": [0]}
10426        if initial_state:
10427            dynamic_axes.update({"h0": [1]})
10428        export_options = {"input_names": input_names, "dynamic_axes": dynamic_axes}
10429
10430        # test that the model still runs with a different batch size
10431        other_input, _ = make_input(RNN_BATCH_SIZE + 1)
10432        self.run_test(
10433            model, input, additional_test_inputs=[other_input], **export_options
10434        )
10435
10436    @skipIfUnsupportedMinOpsetVersion(10)
10437    def test_fake_quantize_per_tensor(self):
10438        class FakeQuantizePerTensorModel(torch.nn.Module):
10439            def forward(self, input):
10440                scale = 1.0 / 127
10441                zero_point = 0
10442                quant_min = -128
10443                quant_max = 127
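                # fake_quantize maps x to
                # (clamp(round(x / scale) + zero_point, quant_min, quant_max) - zero_point) * scale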
10444                return torch.fake_quantize_per_tensor_affine(
10445                    input, scale, zero_point, quant_min, quant_max
10446                )
10447
10448        x = torch.randn(6, 4, 3, 3)
        self.run_test(FakeQuantizePerTensorModel(), (x,))
10450
10451    @skipIfUnsupportedMinOpsetVersion(13)
10452    def test_fake_quantize_per_tensor_dynamic_scale_zeropoint(self):
10453        class FakeQuantizePerTensorModel(torch.nn.Module):
10454            def forward(self, input, scale, zero_point):
10455                quant_min = -128
10456                quant_max = 127
10457                return torch.fake_quantize_per_tensor_affine(
10458                    input, scale, zero_point, quant_min, quant_max
10459                )
10460
10461        x = torch.randn(6, 4, 3, 3)
10462        scale = torch.tensor(1.0 / 127)
10463        zero_point = torch.tensor(0)
10464        self.run_test(FakeQuantizePerTensorModel(), (x, scale, zero_point))
10465
10466    @skipIfUnsupportedMinOpsetVersion(13)
10467    def test_fake_quantize_per_channel(self):
10468        class FakeQuantizePerChannelModel(torch.nn.Module):
10469            def forward(self, input):
10470                amax = torch.ones(4)
10471                scale = amax / 127.0
10472                zero_point = torch.zeros_like(amax, dtype=torch.int)
                # Quantize twice to test different branches: unsigned (0, 255) and signed (-128, 127) ranges
10474                y = torch.fake_quantize_per_channel_affine(
10475                    input, scale, zero_point, 1, 0, 255
10476                )
10477                return torch.fake_quantize_per_channel_affine(
10478                    y, scale, zero_point, 1, -128, 127
10479                )
10480
10481        x = torch.randn(6, 4, 3, 3)
        self.run_test(FakeQuantizePerChannelModel(), (x,))
10483
10484    @skipIfUnsupportedMinOpsetVersion(13)
10485    # RuntimeError: Can't redefine method:
10486    # forward on class: __torch__.torch.nn.modules.linear.Linear
10487    @skipScriptTest()
10488    def test_fake_quantize_activation(self):
10489        from torch.ao import quantization
10490
10491        m = torch.nn.Linear(1, 1)
10492        m.qconfig = quantization.QConfig(
10493            activation=quantization.default_fake_quant,
10494            weight=quantization.default_per_channel_weight_fake_quant,
10495        )
10496        quantization.prepare_qat(m.train(), inplace=True)
10497        m.apply(quantization.enable_observer)
10498        m.apply(quantization.enable_fake_quant)
10499        for module in m.modules():
10500            if isinstance(module, quantization.FakeQuantize):
10501                module.calculate_qparams()
10502
10503        m.apply(quantization.disable_observer)
10504        m.eval()
10505
        # Fake quantize activation is a special case: it restricts the quantized range to (0, 127),
        # while the standard 8-bit quantization ranges are (-128, 127) and (0, 255).
        # Set fixed weight, bias and inputs to test whether ONNX handles the overflow correctly.
10509        m.weight = torch.nn.Parameter(torch.tensor([[1.0], [1.0], [1.0]]))
10510        m.bias = torch.nn.Parameter(torch.tensor([0.0]))
10511        x = torch.tensor([[150.0], [127.0], [-5.0]])
10512        self.run_test(m, x)
10513
10514    def test_batchnorm_training(self):
10515        class MyModule(torch.nn.Module):
10516            def __init__(self) -> None:
10517                super().__init__()
10518                self.bn1 = torch.nn.BatchNorm2d(3, affine=False)
10519                self.cv1 = torch.nn.Conv2d(3, 3, 10)
10520                self.bn2 = torch.nn.BatchNorm2d(3, affine=True)
10521                self.cv2 = torch.nn.Conv2d(3, 3, 10)
10522                self.bn3 = torch.nn.BatchNorm2d(3, affine=False)
10523
10524            def forward(self, x):
10525                x = self.bn1(x)
10526                x = self.cv1(x)
10527                x = self.bn2(x)
10528                x = self.cv2(x)
10529                x = self.bn3(x)
10530                return x
10531
10532        x = torch.randn(10, 3, 20, 20) * 2
10533        model_export = MyModule()
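        # TrainingMode.TRAINING exports in training mode regardless of the module's
        # current state; TrainingMode.PRESERVE keeps whatever mode the module is in,
        # hence the explicit model_export.train() before the second run.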
10534        self.run_test(
10535            model_export,
10536            (x,),
10537            training=torch.onnx.TrainingMode.TRAINING,
10538            rtol=1e-3,
10539            atol=1e-5,
10540        )
10541        model_export.train()
10542        self.run_test(
10543            model_export,
10544            (x,),
10545            training=torch.onnx.TrainingMode.PRESERVE,
10546            rtol=1e-3,
10547            atol=1e-5,
10548        )
10549
10550    def test_batchnorm_training_mode_fix_layer(self):
10551        class MyModule(torch.nn.Module):
10552            def __init__(self) -> None:
10553                super().__init__()
10554                self.bn1 = torch.nn.BatchNorm2d(3, affine=True)
10555                self.cv1 = torch.nn.Conv2d(3, 3, 10)
10556                self.bn2 = torch.nn.BatchNorm2d(3, affine=False)
10557                self.cv2 = torch.nn.Conv2d(3, 3, 10)
10558                self.bn3 = torch.nn.BatchNorm2d(3, affine=True)
10559                self.bn3.eval()
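                # bn3 is pinned to eval mode while the rest of the module trains.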
10560
10561            def forward(self, x):
10562                x = self.bn1(x)
10563                x = self.cv1(x)
10564                x = self.bn2(x)
10565                x = self.cv2(x)
10566                x = self.bn3(x)
10567                return x
10568
10569        x = torch.randn(10, 3, 128, 128)
10570        model_export = MyModule()
10571        self.run_test(
10572            model_export,
10573            (x,),
10574            training=torch.onnx.TrainingMode.TRAINING,
10575            rtol=1e-3,
10576            atol=1e-5,
10577        )
10578        model_export.train()
10579        self.run_test(
10580            model_export,
10581            (x,),
10582            training=torch.onnx.TrainingMode.PRESERVE,
10583            rtol=1e-3,
10584            atol=1e-5,
10585        )
10586
10587    def test_batchnorm_eval_mode_train_layer(self):
10588        class MyModule(torch.nn.Module):
10589            def __init__(self) -> None:
10590                super().__init__()
10591                self.bn1 = torch.nn.BatchNorm2d(3, affine=True)
10592                self.cv1 = torch.nn.Conv2d(3, 3, 10)
10593                self.bn2 = torch.nn.BatchNorm2d(3, affine=False)
10594                self.cv2 = torch.nn.Conv2d(3, 3, 10)
10595                self.bn3 = torch.nn.BatchNorm2d(3, affine=True)
10596                self.bn3.train()
10597
10598            def forward(self, x):
10599                x = self.bn1(x)
10600                x = self.cv1(x)
10601                x = self.bn2(x)
10602                x = self.cv2(x)
10603                x = self.bn3(x)
10604                return x
10605
10606        x = torch.randn(10, 3, 128, 128)
10607        model_export = MyModule()
10608        self.run_test(
10609            model_export,
10610            (x,),
10611            training=torch.onnx.TrainingMode.EVAL,
10612            rtol=1e-3,
10613            atol=1e-5,
10614        )
10615        model_export.eval()
10616        self.run_test(
10617            model_export,
10618            (x,),
10619            training=torch.onnx.TrainingMode.PRESERVE,
10620            rtol=1e-3,
10621            atol=1e-5,
10622        )
10623
10624    def test_instancenorm_training(self):
10625        class MyModule(torch.nn.Module):
10626            def __init__(self) -> None:
10627                super().__init__()
10628                self.in1 = torch.nn.InstanceNorm2d(3, affine=True)
10629                self.cv1 = torch.nn.Conv2d(3, 3, 10)
10630                self.in2 = torch.nn.InstanceNorm2d(3, affine=False)
10631                self.cv2 = torch.nn.Conv2d(3, 3, 10)
10632                self.in3 = torch.nn.InstanceNorm2d(3, affine=True)
10633
10634            def forward(self, x):
10635                x = self.in1(x)
10636                x = self.cv1(x)
10637                x = self.in2(x)
10638                x = self.cv2(x)
10639                x = self.in3(x)
10640                return x
10641
10642        x = torch.randn(10, 3, 128, 128)
10643        model_export = MyModule()
10644        self.run_test(
10645            model_export,
10646            (x,),
10647            training=torch.onnx.TrainingMode.TRAINING,
10648            rtol=1e-3,
10649            atol=1e-5,
10650        )
10651        model_export.train()
10652        self.run_test(
10653            model_export,
10654            (x,),
10655            training=torch.onnx.TrainingMode.PRESERVE,
10656            rtol=1e-3,
10657            atol=1e-5,
10658        )
10659
10660    def test_instancenorm_training_mode_fix_layer(self):
10661        class MyModule(torch.nn.Module):
10662            def __init__(self) -> None:
10663                super().__init__()
10664                self.in1 = torch.nn.InstanceNorm2d(3, affine=True)
10665                self.cv1 = torch.nn.Conv2d(3, 3, 10)
10666                self.in2 = torch.nn.InstanceNorm2d(3, affine=False)
10667                self.cv2 = torch.nn.Conv2d(3, 3, 10)
10668                self.in3 = torch.nn.InstanceNorm2d(3, affine=True)
10669                self.in3.eval()
10670
10671            def forward(self, x):
10672                x = self.in1(x)
10673                x = self.cv1(x)
10674                x = self.in2(x)
10675                x = self.cv2(x)
10676                x = self.in3(x)
10677                return x
10678
10679        x = torch.randn(10, 3, 128, 128)
10680        model_export = MyModule()
10681        self.run_test(
10682            model_export,
10683            (x,),
10684            training=torch.onnx.TrainingMode.TRAINING,
10685            rtol=1e-3,
10686            atol=1e-5,
10687        )
10688        model_export.train()
10689        self.run_test(
10690            model_export,
10691            (x,),
10692            training=torch.onnx.TrainingMode.PRESERVE,
10693            rtol=1e-3,
10694            atol=1e-5,
10695        )
10696
10697    def test_instancenorm_eval_mode_train_layer(self):
10698        class MyModule(torch.nn.Module):
10699            def __init__(self) -> None:
10700                super().__init__()
10701                self.in1 = torch.nn.InstanceNorm2d(8, affine=True)
10702                self.cv1 = torch.nn.Conv2d(8, 8, 10)
10703                self.in2 = torch.nn.InstanceNorm2d(8, affine=False)
10704                self.cv2 = torch.nn.Conv2d(8, 8, 10)
10705                self.in3 = torch.nn.InstanceNorm2d(8, affine=True)
10706                self.in3.train()
10707
10708            def forward(self, x):
10709                x = self.in1(x)
10710                x = self.cv1(x)
10711                x = self.in2(x)
10712                x = self.cv2(x)
10713                x = self.in3(x)
10714                return x
10715
10716        x = torch.randn(10, 8, 128, 128)
10717        model_export = MyModule()
10718        self.run_test(
10719            model_export,
10720            (x,),
10721            training=torch.onnx.TrainingMode.EVAL,
10722            rtol=1e-3,
10723            atol=1e-5,
10724        )
10725        model_export.eval()
10726        self.run_test(
10727            model_export,
10728            (x,),
10729            training=torch.onnx.TrainingMode.PRESERVE,
10730            rtol=1e-3,
10731            atol=1e-5,
10732        )
10733
10734    @skipIfUnsupportedMinOpsetVersion(12)
10735    def test_dropout_training(self):
10736        class MyModule(torch.nn.Module):
10737            def __init__(self) -> None:
10738                super().__init__()
10739                self.dropout = torch.nn.Dropout(0.4)
10740
10741            def forward(self, x):
10742                dropout = self.dropout(x)
10743                return dropout
10744
10745        model = MyModule()
10746        x = torch.randn(10)
10747        model.train()
10748
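        # Dropout in training mode must zero out some elements, so the ORT output
        # should not be identical to the input.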
10749        model_onnx = io.BytesIO()
10750        torch.onnx.export(
10751            model,
10752            x,
10753            model_onnx,
10754            opset_version=self.opset_version,
10755            do_constant_folding=False,
10756            training=torch.onnx.TrainingMode.TRAINING,
10757        )
10758        ort_sess = verification._ort_session(model_onnx)
10759        ort_outs = verification._run_onnx(ort_sess, (x,))
10760        assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))
10761
        # Repeat the check with a scripted model and a fresh ORT session.
        script_model = torch.jit.script(model)
        model_onnx = io.BytesIO()
        torch.onnx.export(
            script_model,
            x,
            model_onnx,
            opset_version=self.opset_version,
            do_constant_folding=False,
            training=torch.onnx.TrainingMode.TRAINING,
        )
        ort_sess = verification._ort_session(model_onnx)
        ort_outs = verification._run_onnx(ort_sess, (x,))
        assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))
10775
10776    @skipIfUnsupportedMinOpsetVersion(12)
10777    def test_dropout_training_zero(self):
10778        class MyModule(torch.nn.Module):
10779            def __init__(self) -> None:
10780                super().__init__()
10781                self.dropout = torch.nn.Dropout(0.5)
10782
10783            def forward(self, x):
10784                dropout = self.dropout(x)
10785                return dropout
10786
10787        model = MyModule()
10788
10789        # ensure there are no zeros in the input
10790        x = torch.randn(10, 3, 128, 128)
10791        y = x.numpy()
10792        y_mask = np.where(y == 0, 1, y)
10793        input = torch.from_numpy(y_mask)
10794        nb_elements = torch.numel(input)
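        # Dropout masks are random, so compare the fraction of surviving (nonzero)
        # elements between PyTorch and ORT instead of exact values.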
10795
10796        model.train()
10797        model_onnx = io.BytesIO()
10798        torch.onnx.export(
10799            model,
10800            x,
10801            model_onnx,
10802            opset_version=self.opset_version,
10803            do_constant_folding=False,
10804            training=torch.onnx.TrainingMode.TRAINING,
10805        )
10806        ort_sess = verification._ort_session(model_onnx)
10807        ort_outs = verification._run_onnx(ort_sess, (x,))
10808
10809        y = model(input)
10810        output = y.cpu().numpy()
10811        ort_mask = np.where(ort_outs[0] != 0, 1, 0)
10812        pyt_mask = np.where(output != 0, 1, 0)
10813
10814        ratio_pytorch = np.sum(pyt_mask) / nb_elements
10815        ratio_ort = np.sum(ort_mask) / nb_elements
10816
10817        np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)
10818
        # Repeat with a scripted model.
        script_model = torch.jit.script(model)
        y = script_model(input)
        output = y.cpu().numpy()
        model_onnx = io.BytesIO()
        torch.onnx.export(
            script_model,
10825            x,
10826            model_onnx,
10827            opset_version=self.opset_version,
10828            do_constant_folding=False,
10829            training=torch.onnx.TrainingMode.TRAINING,
10830        )
10831        ort_sess = verification._ort_session(model_onnx)
10832        ort_outs = verification._run_onnx(ort_sess, (x,))
10833        ort_mask = np.where(ort_outs[0] != 0, 1, 0)
10834        pyt_mask = np.where(output != 0, 1, 0)
10835
10836        ratio_pytorch = np.sum(pyt_mask) / nb_elements
10837        ratio_ort = np.sum(ort_mask) / nb_elements
10838
10839        np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)
10840
10841    def test_conv_bn(self):
10842        class MyModule(torch.nn.Module):
10843            def __init__(self) -> None:
10844                super().__init__()
10845                self.conv = torch.nn.Conv2d(
10846                    3, 16, kernel_size=1, stride=2, padding=3, bias=True
10847                )
10848                self.bn = torch.nn.BatchNorm2d(16, affine=True)
10849
10850            def forward(self, x):
10851                x = self.conv(x)
10852                bn = self.bn(x)
10853                return bn
10854
10855        model_export = MyModule()
10856        x = torch.randn(10, 3, 128, 128)
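        # TRAINING export makes BatchNorm use batch statistics, so tolerances are
        # looser than for the EVAL export.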
10857        self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL)
10858        self.run_test(
10859            model_export,
10860            (x,),
10861            training=torch.onnx.TrainingMode.TRAINING,
10862            rtol=1e-3,
10863            atol=1e-5,
10864        )
10865
10866    def test_multiple_conv_bn(self):
10867        class MyModule(torch.nn.Module):
10868            def __init__(self) -> None:
10869                super().__init__()
10870                self.conv1 = torch.nn.Conv2d(
10871                    3, 64, kernel_size=7, stride=2, padding=3, bias=False
10872                )
10873                self.conv2 = torch.nn.Conv2d(
10874                    64, 2, kernel_size=1, stride=1, padding=0, bias=False
10875                )
10876                self.conv3 = torch.nn.Conv2d(
10877                    2, 2, kernel_size=3, stride=1, padding=1, bias=False
10878                )
10879                self.bn = torch.nn.BatchNorm2d(64)
10880                self.bn2 = torch.nn.BatchNorm2d(2)
10881                self.relu = torch.nn.ReLU(inplace=True)
10882                self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
10883
10884            def forward(self, x):
10885                x = self.conv1(x)
10886                x = self.bn(x)
10887                x = self.relu(x)
10888                x = self.maxpool(x)
10889                x = self.conv2(x)
10890                x = self.bn2(x)
10891                x = self.relu(x)
10892                x = self.conv3(x)
10893                x = self.bn2(x)
10894                x = self.relu(x)
10895                return x
10896
10897        model_export = MyModule()
10898        x = torch.randn(2, 3, 224, 224)
10899        self.run_test(
10900            model_export,
10901            (x,),
10902            training=torch.onnx.TrainingMode.TRAINING,
10903            rtol=1e-3,
10904            atol=1e-5,
10905        )
10906        self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL)
10907
10908    @skipIfUnsupportedMinOpsetVersion(11)
10909    def test_nms(self):
10910        num_boxes = 100
10911        boxes = torch.rand(num_boxes, 4)
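        # Ensure x2 >= x1 and y2 >= y1 so the boxes are valid.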
10912        boxes[:, 2:] += boxes[:, :2]
10913        scores = torch.randn(num_boxes)
10914
10915        class Module(torch.nn.Module):
10916            def forward(self, boxes, scores):
10917                return torchvision.ops.nms(boxes, scores, 0.5)
10918
10919        self.run_test(Module(), (boxes, scores))
10920
10921    @skipIfUnsupportedMinOpsetVersion(11)
10922    def test_batched_nms(self):
10923        num_boxes = 100
10924        boxes = torch.rand(num_boxes, 4)
10925        boxes[:, 2:] += boxes[:, :2]
10926        scores = torch.randn(num_boxes)
10927        idxs = torch.randint(0, 5, size=(num_boxes,))
10928
10929        class Module(torch.nn.Module):
10930            def forward(self, boxes, scores, idxs):
10931                return torchvision.ops.batched_nms(boxes, scores, idxs, 0.5)
10932
10933        self.run_test(Module(), (boxes, scores, idxs))
10934
10935    @skipIfUnsupportedMinOpsetVersion(11)
10936    @skipScriptTest()
10937    def test_clip_boxes_to_image(self):
10938        boxes = torch.randn(5, 4) * 500
10939        boxes[:, 2:] += boxes[:, :2]
10940        size = torch.randn(200, 300)
10941
10942        size_2 = torch.randn(300, 400)
10943
10944        class Module(torch.nn.Module):
10945            def forward(self, boxes, size):
10946                shape = (size.shape[0], size.shape[1])
10947                return torchvision.ops.boxes.clip_boxes_to_image(boxes, shape)
10948
10949        self.run_test(
10950            Module(),
10951            (boxes, size),
10952            input_names=["boxes", "size"],
10953            dynamic_axes={"size": [0, 1]},
10954            additional_test_inputs=[(boxes, size), (boxes, size_2)],
10955        )
10956
10957    @skipScriptTest(
10958        reason="Conditioning on input type via prim::isinstance unsupported in ONNX"
10959    )
10960    @skipIfUnsupportedMinOpsetVersion(11)
10961    def test_roi_align(self):
10962        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
10963        single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
10964        model = torchvision.ops.RoIAlign((5, 5), 1.0, 2)
10965        self.run_test(model, (x, single_roi))
10966
10967    @skipScriptTest(
10968        reason="Conditioning on input type via prim::isinstance unsupported in ONNX"
10969    )
10970    @skipIfUnsupportedMinOpsetVersion(16)
10971    def test_roi_align_aligned(self):
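        # Exercise aligned=True with several output_size / spatial_scale /
        # sampling_ratio combinations.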
10972        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
10973        single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
10974        model1 = torchvision.ops.RoIAlign((5, 5), 1.0, 2, aligned=True)
10975        self.run_test(model1, (x, single_roi))
10976
10977        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
10978        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
10979        model2 = torchvision.ops.RoIAlign((5, 5), 0.5, 3, aligned=True)
10980        self.run_test(model2, (x, single_roi))
10981
10982        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
10983        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
10984        model3 = torchvision.ops.RoIAlign((5, 5), 1.8, 2, aligned=True)
10985        self.run_test(model3, (x, single_roi))
10986
10987        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
10988        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
10989        model4 = torchvision.ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
10990        self.run_test(model4, (x, single_roi))
10991
10992    @skipScriptTest(
10993        reason="Conditioning on input type via prim::isinstance unsupported in ONNX"
10994    )
10995    @skipIfUnsupportedMinOpsetVersion(11)
10996    def test_roi_pool(self):
10997        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
10998        rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
10999        pool_h = 5
11000        pool_w = 5
11001        model = torchvision.ops.RoIPool((pool_h, pool_w), 2.0)
11002        self.run_test(model, (x, rois))
11003
11004    @skipIfUnsupportedMinOpsetVersion(11)
11005    def test_resize_images(self):
11006        class TransformModule(torch.nn.Module):
11007            def __init__(self) -> None:
11008                super().__init__()
11009                self.transform = _init_test_generalized_rcnn_transform()
11010
11011            def forward(self, images):
11012                return self.transform.resize(images, None)[0]
11013
11014        input = torch.rand(3, 10, 20)
11015        input_test = torch.rand(3, 100, 150)
11016        self.run_test(
11017            TransformModule(),
11018            (input,),
11019            input_names=["input1"],
11020            dynamic_axes={"input1": [0, 1, 2]},
11021            additional_test_inputs=[(input,), (input_test,)],
11022        )
11023
11024    @skipIfUnsupportedMinOpsetVersion(11)
11025    @skipScriptTest()
11026    def test_transform_images(self):
11027        class TransformModule(torch.nn.Module):
11028            def __init__(self) -> None:
11029                super().__init__()
11030                self.transform = _init_test_generalized_rcnn_transform()
11031
11032            def forward(self, images: List[Tensor]):
11033                return self.transform(images)[0].tensors
11034
11035        input = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
11036        input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
11037        self.run_test(
11038            TransformModule(),
11039            (input,),
11040            additional_test_inputs=[(input,), (input_test,)],
11041        )
11042
11043    def get_features(self, images):
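        """Fabricate an FPN-style feature pyramid: five levels with strides 4 to 64,
        256 channels each, for a batch of 2."""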
11044        s0, s1 = images.shape[-2:]
11045        features = [
11046            ("0", torch.rand(2, 256, s0 // 4, s1 // 4)),
11047            ("1", torch.rand(2, 256, s0 // 8, s1 // 8)),
11048            ("2", torch.rand(2, 256, s0 // 16, s1 // 16)),
11049            ("3", torch.rand(2, 256, s0 // 32, s1 // 32)),
11050            ("4", torch.rand(2, 256, s0 // 64, s1 // 64)),
11051        ]
11052        features = OrderedDict(features)
11053        return features
11054
11055    @skipIfUnsupportedMinOpsetVersion(11)
11056    @skipScriptTest()
11057    def test_rpn(self):
11058        class RPNModule(torch.nn.Module):
11059            def __init__(self) -> None:
11060                super().__init__()
11061                self.rpn = _init_test_rpn()
11062
11063            def forward(self, images, features: Dict[str, Tensor]):
11064                images_m = torchvision.models.detection.image_list.ImageList(
11065                    images, [(i.shape[-1], i.shape[-2]) for i in images]
11066                )
11067                return self.rpn(images_m, features)
11068
11069        images = torch.rand(2, 3, 150, 150)
11070        features = self.get_features(images)
11071        images2 = torch.rand(2, 3, 80, 80)
11072        test_features = self.get_features(images2)
11073
11074        model = RPNModule()
11075        model.eval()
11076        model(images, features)
11077        self.run_test(
11078            model,
11079            (images, features),
11080            input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
11081            dynamic_axes={
11082                "input1": [0, 1, 2, 3],
11083                "input2": [0, 1, 2, 3],
11084                "input3": [0, 1, 2, 3],
11085                "input4": [0, 1, 2, 3],
11086                "input5": [0, 1, 2, 3],
11087                "input6": [0, 1, 2, 3],
11088            },
11089            additional_test_inputs=[(images, features), (images2, test_features)],
11090            # dict_check=False,
11091        )
11092
11093    @skipIfUnsupportedMaxOpsetVersion(15)  # TODO: Opset 16 RoiAlign result mismatch
11094    @skipIfUnsupportedMinOpsetVersion(11)
11095    @skipScriptTest()
11096    def test_multi_scale_roi_align(self):
11097        class TransformModule(torch.nn.Module):
11098            def __init__(self) -> None:
11099                super().__init__()
11100                self.model = torchvision.ops.MultiScaleRoIAlign(
11101                    ["feat1", "feat2"], 3, 2
11102                )
11103                self.image_sizes = [(512, 512)]
11104
11105            def forward(self, input: Dict[str, Tensor], boxes: List[Tensor]) -> Tensor:
11106                return self.model(input, boxes, self.image_sizes)
11107
11108        i = OrderedDict()
11109        i["feat1"] = torch.rand(1, 5, 64, 64)
11110        i["feat2"] = torch.rand(1, 5, 16, 16)
11111        boxes = torch.rand(6, 4) * 256
11112        boxes[:, 2:] += boxes[:, :2]
11113
11114        i1 = OrderedDict()
11115        i1["feat1"] = torch.rand(1, 5, 64, 64)
11116        i1["feat2"] = torch.rand(1, 5, 16, 16)
11117        boxes1 = torch.rand(6, 4) * 256
11118        boxes1[:, 2:] += boxes1[:, :2]
11119
11120        self.run_test(
11121            TransformModule(),
11122            (
11123                i,
11124                [boxes],
11125            ),
11126            additional_test_inputs=[
11127                (
11128                    i,
11129                    [boxes],
11130                ),
11131                (
11132                    i1,
11133                    [boxes1],
11134                ),
11135            ],
11136        )
11137
11138    def test_set_(self):
11139        class M(torch.nn.Module):
11140            def forward(self, x, y):
11141                x.set_(y)
11142                return x
11143
11144        x = torch.ones(2, 3)
11145        y = torch.randn(4, 6)
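        # x is only written via set_, so the exporter drops it from the graph;
        # y (input index 1) is the only remaining ONNX input.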
11146        self.run_test(M(), (x, y), remained_onnx_input_idx=[1])
11147
11148        y2 = torch.randn(5, 2)
11149        self.run_test(
11150            M(),
11151            (x, y),
11152            remained_onnx_input_idx=[1],
11153            input_names=["x", "y"],
11154            dynamic_axes={"x": [0, 1], "y": [0, 1]},
11155            additional_test_inputs=[(y, y2)],
11156        )
11157
11158    @skipIfUnsupportedMinOpsetVersion(9)
11159    def test_set_attr_modules(self):
11160        class InnerModule2(torch.nn.Module):
11161            def __init__(self, embedding_dim):
                super().__init__()
                self.embedding_dim = embedding_dim
                self.weights = InnerModule2.get_embedding(embedding_dim)
11164                self._float_tensor = torch.nn.Buffer(torch.FloatTensor(1))
11165                self.const = 2
11166
11167            @staticmethod
11168            def get_embedding(embedding_dim: int):
11169                emb = 4 / ((embedding_dim // 2) - 1)
11170                emb = torch.exp(
11171                    torch.arange((embedding_dim // 2), dtype=torch.float) * -emb
11172                )
11173                return emb
11174
11175            def forward(self, input, incremental_state: Optional[Tensor] = None):
11176                bsz, seq_len = input.shape[0], input.shape[1]
11177                self.const = 3
11178                if self.weights is None:
                    self.weights = InnerModule2.get_embedding(self.embedding_dim)
11180                self.weights = self.weights.to(self._float_tensor)
11181                self.weights = self.weights * self.const
11182                if incremental_state is not None:
11183                    pos = seq_len
11184                    return self.weights[1 + pos, :].expand(bsz, 1, -1)
11185                return self.weights.index_select(
11186                    0, torch.ones((bsz * seq_len), dtype=torch.int64)
11187                ).view(bsz, seq_len, -1)
11188
11189        class InnerModule(torch.nn.Module):
11190            def __init__(self, embedding_dim):
11191                super().__init__()
11192                self.weights = InnerModule.get_embedding(embedding_dim)
11193                self.module = InnerModule2(embedding_dim=8)
11194
11195            @staticmethod
11196            def get_embedding(embedding_dim: int):
11197                emb = 4 / ((embedding_dim // 2) - 1)
11198                emb = torch.exp(
11199                    torch.arange((embedding_dim // 2), dtype=torch.float) * -emb
11200                )
11201                return emb
11202
11203            def forward(self, x):
11204                return self.module(x) + self.weights
11205
11206        class Module(torch.nn.Module):
11207            def __init__(self) -> None:
11208                super().__init__()
11209                self.module = InnerModule(embedding_dim=8)
11210
11211            def forward(self, x):
11212                return self.module(x)
11213
11214        x = torch.randn(3, 256)
11215        self.run_test(Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]})
11216        self.run_test(Module(), (x,), remained_onnx_input_idx=[])
11217
11218    @skipIfUnsupportedMinOpsetVersion(9)
11219    def test_set_attr_modules_2(self):
11220        class InnerModule(torch.nn.Module):
11221            def __init__(self, embedding_dim):
11222                super().__init__()
11223                self.embedding_dim = embedding_dim
11224                self.const = 2.5
11225                self.weights = InnerModule.get_embedding(self.embedding_dim)
11226                self._float_tensor = torch.nn.Buffer(torch.FloatTensor(1))
11227
11228            @staticmethod
11229            def get_embedding(embedding_dim: int):
11230                emb = 4 / ((embedding_dim // 2) - 1)
11231                emb = torch.exp(
11232                    torch.arange((embedding_dim // 2), dtype=torch.float) * -emb
11233                )
11234                return emb
11235
11236            def forward(self, input, incremental_state: Optional[Tensor] = None):
11237                bsz, seq_len = input.shape[0], input.shape[1]
11238                self.const = 1.5
11239                self.weights = InnerModule.get_embedding(self.embedding_dim)
11240                return (
11241                    self.weights.index_select(
11242                        0, torch.ones((bsz * seq_len), dtype=torch.int64)
11243                    ).view(bsz, seq_len, -1)
11244                ) * self.const
11245
11246        class Module(torch.nn.Module):
11247            def __init__(self) -> None:
11248                super().__init__()
11249                self.module = InnerModule(embedding_dim=8)
11250
11251            def forward(self, x):
11252                return self.module(x)
11253
11254        x = torch.randn(3, 256)
11255        self.run_test(Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]})
11256        self.run_test(Module(), (x,), remained_onnx_input_idx=[])
11257
11258    def test_set_attr(self):
11259        class MyModule(torch.nn.Module):
11260            def __init__(self) -> None:
11261                super().__init__()
11262                self.conv = torch.nn.Conv1d(3, 10, 2)
11263                self.b = False
11264
11265            def forward(self, box_regression, weight):
11266                self.b = True
11267                self.conv.weight = weight
11268                w = torch.softmax(self.conv.weight, dim=0)
11269                self.conv.weight = w + w
11270                if self.b:
11271                    return box_regression + self.conv.weight
11272                else:
11273                    return box_regression - self.conv.weight
11274
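        # Script (rather than trace) so the attribute mutations in forward are
        # captured in the graph.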
11275        model = torch.jit.script(MyModule())
11276        weight = torch.ones(3, 2)
11277        box_regression = torch.randn(3, 2)
11278        self.run_test(model, (box_regression, weight))
11279
11280    @skipIfUnsupportedMinOpsetVersion(11)
11281    def test_set_attr_2(self):
11282        class MyModule(torch.nn.Module):
11283            def __init__(self) -> None:
11284                super().__init__()
11285                self.conv = torch.nn.Conv1d(10, 3, 3)
11286                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
11287
11288            def set_cell_anchors(self, anchors):
11289                if self.conv.bias is not None:
11290                    b = self.conv.bias
11291                    assert b is not None
11292                    self.conv.bias = anchors + b
11293                elif self.conv.weight is not None:
11294                    self.conv.weight = torch.randn(3, 10)
11295                    self.conv.bias = self.conv.weight[:]
11296
11297            def forward(self, anchors) -> Optional[Tensor]:
11298                self.set_cell_anchors(anchors)
11299                return self.conv.bias
11300
11301        model = torch.jit.script(MyModule())
11302        anchors = torch.ones(3, 10, 3)
        self.run_test(model, (anchors,))
11304
11305    @skipIfUnsupportedMinOpsetVersion(11)
11306    def test_set_attr_3(self):
11307        class MyModule(torch.nn.Module):
11308            def __init__(self) -> None:
11309                super().__init__()
11310                self.conv = torch.nn.Conv1d(10, 3, 3)
11311                self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
11312                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
11313
11314            def set_cell_anchors(self, anchors, boxes):
11315                self.conv.weight = torch.ones(3, 10)
11316                if self.conv.bias is not None:
11317                    self.conv.bias = torch.randn(3, 10, 3)
11318                    self.conv.weight = anchors + self.conv.weight
11319                    boxes[:] = torch.zeros(2, 3)
11320
11321            def forward(self, anchors) -> Tuple[Tensor, Tensor]:
11322                boxes = torch.ones(2, 2, 3)
11323                self.set_cell_anchors(anchors, boxes)
11324                if self.conv.bias is not None:
11325                    return self.conv.weight, boxes
11326                return anchors, boxes
11327
11328        model = torch.jit.script(MyModule())
11329        anchors = torch.rand(3, 10)
        self.run_test(model, (anchors,))
11331
11332    @skipIfUnsupportedMinOpsetVersion(11)
11333    def test_set_attr_4(self):
11334        class MyModule(torch.nn.Module):
11335            def __init__(self) -> None:
11336                super().__init__()
11337                self.conv = torch.nn.Conv1d(10, 3, 3)
11338                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
11339
11340            def set_cell_anchors(self, anchors):
11341                self.conv.weight = torch.zeros(10, 3)
11342                if self.conv.bias is not None:
11343                    w = self.conv.bias
11344                    assert w is not None
11345                    self.conv.bias = anchors + w
11346                else:
11347                    self.conv.bias = torch.ones(3, 10, 3)
11348
11349            def forward(self, feature_maps, anchors) -> Tuple[Tensor, Tensor]:
11350                self.set_cell_anchors(anchors)
11351                result = []
11352                if self.conv.bias is not None:
11353                    a = self.conv.bias
11354                    assert a is not None
11355                    result += [a]
11356                result += [feature_maps]
11357                return result[0], result[1]
11358
11359        model = torch.jit.script(MyModule())
11360        x = torch.rand(5, 11, 30)
11361        anchors = torch.ones(3, 10, 3)
11362        self.run_test(model, (x, anchors))
11363
11364    @skipIfUnsupportedMinOpsetVersion(11)
11365    def test_set_attr_5(self):
11366        class MyModule(torch.nn.Module):
11367            def __init__(self) -> None:
11368                super().__init__()
11369                self.conv = torch.nn.Conv1d(10, 3, 3)
11370                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
11371
11372            def set_cell_anchors(self, anchors):
11373                self.conv.weight = torch.arange(10)
11374                for i in range(10):
11375                    if i == 3:
11376                        for j in range(10):
11377                            w = self.conv.weight
11378                            self.conv.weight = torch.arange(10) + w
11379
11380                    self.conv.weight = self.conv.weight + torch.arange(10)
                    # NOTE: the `is not None` check and `assert` are needed to satisfy TorchScript.
11382                    if self.conv.bias is not None:
11383                        a = self.conv.bias
11384                        assert a is not None
11385                        self.conv.bias = anchors + a
11386
11387            def forward(self, anchors):
11388                self.set_cell_anchors(anchors)
11389                return self.conv.weight, self.conv.bias
11390
11391        model = torch.jit.script(MyModule())
11392        anchors = torch.ones(3, 10, 3)
        self.run_test(model, (anchors,))
11394
11395    @skipIfUnsupportedMinOpsetVersion(11)
11396    def test_set_attr_in_loop(self):
11397        class MyModule(torch.nn.Module):
11398            def __init__(self) -> None:
11399                super().__init__()
11400                self.conv = torch.nn.Conv1d(10, 3, 3)
11401                self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
11402                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
11403
11404            def set_cell_anchors(self, anchors, boxes):
11405                self.conv.weight = torch.randn(3, 10)
11406                for i in range(self.conv.weight.size(0)):
11407                    for j in range(10):
11408                        self.conv.bias = torch.randn(3, 10, 3)
11409                        self.conv.weight = anchors * i
11410                        boxes[j] += torch.ones(3, 3)
11411
11412            def forward(self, anchors) -> Tuple[Tensor, Tensor]:
11413                boxes = torch.ones(10, 3, 3)
11414                self.set_cell_anchors(anchors, boxes)
11415                if self.conv.bias is not None:
11416                    return self.conv.weight, boxes
11417                return anchors, boxes
11418
11419        model = torch.jit.script(MyModule())
11420        anchors = torch.rand(10)
11421        self.run_test(model, anchors)
11422
11423    @skipIfUnsupportedMinOpsetVersion(13)
11424    def test_set_attr_in_loop_with_list(self):
11425        class MyModule(torch.nn.Module):
11426            def __init__(self) -> None:
11427                super().__init__()
11428                self.conv = torch.nn.Conv1d(10, 3, 3)
11429                self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
11430                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
11431                self.boxes: List[Tensor] = [
11432                    torch.ones(1)
11433                ]  # Workaround placeholder for TorchScript
11434
11435            def set_cell_anchors(self, anchors):
11436                self.conv.weight = torch.randn(3, 10)
11437                for i in range(self.conv.weight.size(0)):
11438                    for j in range(10):
11439                        self.conv.bias = torch.randn(3, 10, 3)
11440                        self.conv.weight = anchors * i
11441                        self.boxes.append(torch.ones(3, 3))
11442
11443            def forward(self, anchors) -> Tuple[Tensor, List[Tensor]]:
11444                self.boxes = []
11445                self.set_cell_anchors(anchors)
11446                if self.conv.bias is not None:
11447                    return self.conv.weight, self.boxes
11448                return anchors, self.boxes
11449
11450        model = torch.jit.script(MyModule())
11451        anchors = torch.rand(10)
11452        self.run_test(model, anchors)
11453
11454    @skipIfUnsupportedMinOpsetVersion(11)
11455    def test_index_put_if(self):
11456        @torch.jit.script
11457        def check_init(
11458            input_data: Tensor, hidden_size: int, prev_state: Tensor
11459        ) -> Tuple[Tensor, Tensor]:
11460            batch_size = input_data.size(0)
11461            spatial_size_0 = input_data.size(2)
11462            spatial_size_1 = input_data.size(3)
            # generate an empty prev_state if none is provided
11464            state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
11465            state = torch.zeros(state_size, device=input_data.device)
11466            state_copy = torch.zeros(state_size, device=input_data.device)
11467            if prev_state.size(0) == 0:
11468                state[:] = (
11469                    torch.zeros(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11470                    + state[:]
11471                )
11472                state_copy[:] = (
11473                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11474                    * 2
11475                )
11476                state_copy[:] = (
11477                    torch.zeros(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11478                    * 2
11479                )
11480            else:
11481                state[:] = (
11482                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11483                    * 4
11484                )
11485            return state, state_copy
11486
11487        class Example(torch.nn.Module):
11488            def __init__(self, hidden_size):
11489                super().__init__()
11490                self.hidden_size = hidden_size
11491
11492            def forward(self, input_data, prev_state):
11493                prev_state = check_init(input_data, self.hidden_size, prev_state)
11494                return prev_state[0], prev_state[1]
11495
11496        model = Example(10)
11497        random_data = torch.rand((1, 5, 30, 30))
11498        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
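        # A tensor with size(0) == 0 stands in for "no previous state" and selects
        # the index_put branch.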
11499        self.run_test(
11500            model,
11501            (random_data, empty_tensor),
11502            input_names=["random_data", "empty_tensor"],
11503            dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]},
11504        )
11505        self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[])
11506
11507    @skipIfUnsupportedMinOpsetVersion(11)
11508    def test_index_put_if_2(self):
11509        @torch.jit.script
11510        def check_init(
11511            input_data: Tensor, hidden_size: int, prev_state: Tensor
11512        ) -> Tuple[Tensor, Tensor]:
11513            batch_size = input_data.size(0)
11514            spatial_size_0 = input_data.size(2)
11515            spatial_size_1 = input_data.size(3)
            # generate an empty prev_state if none is provided
11517            state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
11518            state = torch.zeros(state_size, device=input_data.device)
11519            state_copy = torch.zeros(state_size, device=input_data.device)
11520            if prev_state.size(0) == 0:
11521                for i in range(2):
11522                    state[:] = (
11523                        torch.ones(
11524                            batch_size, hidden_size, spatial_size_0, spatial_size_1
11525                        )
11526                        * i
11527                    )
11528                    state_copy[:] = (
11529                        torch.ones(
11530                            batch_size, hidden_size, spatial_size_0, spatial_size_1
11531                        )
11532                        * i
11533                    )
11534            elif prev_state.size(0) == 1:
11535                s = state[:]
11536                state[:] = prev_state + s
11537            elif prev_state.size(0) == 2:
11538                state[:] = (
11539                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11540                    * 4
11541                )
11542            return state, state_copy
11543
11544        class Example(torch.nn.Module):
11545            def __init__(self, hidden_size):
11546                super().__init__()
11547                self.hidden_size = hidden_size
11548
11549            def forward(self, input_data, prev_state):
11550                prev_state = check_init(input_data, self.hidden_size, prev_state)
11551                return prev_state[0], prev_state[1]
11552
11553        model = Example(10)
11554        random_data = torch.rand((1, 5, 30, 30))
11555        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
11556        random_state = torch.rand((1, 1, 10, 30, 30))
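        # random_state has size(0) == 1, exercising the elif branch at runtime.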
11557        self.run_test(
11558            model,
11559            (random_data, empty_tensor),
11560            input_names=["data", "state"],
11561            dynamic_axes={"data": [0, 1, 2], "state": [0, 1, 2, 3, 4]},
11562            additional_test_inputs=[(random_data, random_state)],
11563        )
11564        self.run_test(
11565            model,
11566            (random_data, empty_tensor),
11567            input_names=["data", "state"],
11568            dynamic_axes={"state": [0, 1, 2, 3, 4]},
11569            additional_test_inputs=[(random_data, random_state)],
11570            remained_onnx_input_idx=[1],
11571        )
11572        self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[])
11573
11574    @skipIfUnsupportedMinOpsetVersion(11)
11575    def test_index_put_if_3(self):
11576        @torch.jit.script
11577        def check_init(
11578            input_data: Tensor, hidden_size: int, prev_state: Tensor
11579        ) -> Tensor:
11580            batch_size = input_data.size(0)
11581            spatial_size_0 = input_data.size(2)
11582            spatial_size_1 = input_data.size(3)
            # generate an empty prev_state if none is provided
11584            state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
11585            state = torch.zeros(state_size, device=input_data.device)
11586            if prev_state.size(0) < 2:
11587                state = state * 3
11588                if prev_state.size(0) == 0:
11589                    state[:] = (
11590                        torch.ones(
11591                            batch_size, hidden_size, spatial_size_0, spatial_size_1
11592                        )
11593                        * 3
11594                    )
11595                else:
11596                    state = state + 2
11597
11598            return state
11599
11600        class Example(torch.nn.Module):
11601            def __init__(self, hidden_size):
11602                super().__init__()
11603                self.hidden_size = hidden_size
11604
11605            def forward(self, input_data, prev_state):
11606                prev_state = check_init(input_data, self.hidden_size, prev_state)
11607                return prev_state
11608
11609        model = Example(4)
11610        random_data = torch.rand((1, 5, 4, 4))
11611        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
11612        self.run_test(
11613            model,
11614            (random_data, empty_tensor),
11615            input_names=["random_data", "empty_tensor"],
11616            dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]},
11617        )
11618        self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[])
11619
11620    @skipIfUnsupportedMinOpsetVersion(11)
11621    def test_index_put_if_4(self):
11622        @torch.jit.script
11623        def check_init(
11624            input_data: Tensor, hidden_size: int, prev_state: Tensor
11625        ) -> Tensor:
11626            batch_size = input_data.size(0)
11627            spatial_size_0 = input_data.size(2)
11628            spatial_size_1 = input_data.size(3)
            # generate an empty prev_state if none is provided
11630            state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
11631            state = torch.zeros(state_size, device=input_data.device)
11632            if prev_state.size(0) == 0:
11633                state = state + 3
11634                state[:] = (
11635                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11636                    * 3
11637                )
11638                state = state + 3
11639                state[:] = (
11640                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11641                    * 4
11642                )
11643            else:
11644                state = state + 2
11645            return state
11646
11647        class Example(torch.nn.Module):
11648            def __init__(self, hidden_size):
11649                super().__init__()
11650                self.hidden_size = hidden_size
11651
11652            def forward(self, input_data, prev_state):
11653                prev_state = check_init(input_data, self.hidden_size, prev_state)
11654                return prev_state
11655
11656        model = Example(4)
11657        random_data = torch.rand((1, 5, 4, 4))
11658        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
11659        self.run_test(
11660            model,
11661            (random_data, empty_tensor),
11662            input_names=["random_data", "empty_tensor"],
11663            dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]},
11664        )
11665        self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[])
11666
11667    @skipIfUnsupportedMinOpsetVersion(11)
11668    def test_index_put_if_5(self):
11669        @torch.jit.script
11670        def check_init(
11671            input_data: Tensor, hidden_size: int, prev_state: Tensor
11672        ) -> Tuple[Tensor, Tensor]:
11673            batch_size = input_data.size(0)
11674            spatial_size_0 = input_data.size(2)
11675            spatial_size_1 = input_data.size(3)
            # Create a zeroed state; the branches below handle an empty prev_state.
11677            state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
11678            state = torch.zeros(state_size, device=input_data.device)
11679            state_ref = state
11680            if prev_state.size(0) == 0:
11681                state[:] = (
11682                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11683                    * 3
11684                )
11685                state = state + 3
11686                state[:] = (
11687                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11688                    * 4
11689                )
11690            else:
11691                state = state + 2
11692            return state, state_ref
11693
11694        class Example(torch.nn.Module):
11695            def __init__(self, hidden_size):
11696                super().__init__()
11697                self.hidden_size = hidden_size
11698
11699            def forward(self, input_data, prev_state):
11700                prev_state, state_ref = check_init(
11701                    input_data, self.hidden_size, prev_state
11702                )
11703                return prev_state, state_ref
11704
11705        model = Example(4)
11706        random_data = torch.rand((1, 5, 4, 4))
11707        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
11708        self.run_test(
11709            model,
11710            (random_data, empty_tensor),
11711            input_names=["random_data", "empty_tensor"],
11712            dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]},
11713        )
11714        self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[])
11715
11716    @skipIfUnsupportedMinOpsetVersion(11)
11717    def test_list_append_in_block(self):
11718        class ListModel(torch.nn.Module):
11719            def forward(self, x, y):
11720                res = []
11721                for i in range(x.size(0)):
11722                    res.append(torch.matmul(x[i], y))
11723                return res
11724
11725        model = torch.jit.script(ListModel())
11726        x = torch.randn(16, 3, 4)
11727        y = torch.randn(4, 5)
11728        self.run_test(model, (x, y))
11729
11730    @skipIfUnsupportedMinOpsetVersion(13)
11731    def test_list_append_in_nested_block(self):
11732        class ListModel(torch.nn.Module):
11733            def forward(self, x, y):
11734                res = []
11735                for i in range(x.size(0)):
11736                    for j in range(x.size(1)):
11737                        res.append(torch.matmul(x[i][j], y))
11738                return res
11739
11740        model = torch.jit.script(ListModel())
11741        x = torch.randn(4, 4, 3, 4)
11742        y = torch.randn(4, 5)
11743        self.run_test(model, (x, y))
11744
11745    @skipIfUnsupportedMinOpsetVersion(13)
11746    def test_list_pop_in_block(self):
11747        class ListModel(torch.nn.Module):
11748            def forward(self, x, y):
11749                res = []
11750                elem = torch.matmul(x[0], y)
11751                for i in range(x.size(0)):
11752                    res.append(torch.matmul(x[i], y))
11753                for i in range(x.size(0)):
11754                    elem = res.pop()
11755                for i in range(x.size(0)):
11756                    res.append(torch.matmul(x[i], y))
11757                    elem = res.pop()
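                # Note: in TorchScript, List.append returns the list itself,
                # so this returns res (in eager Python it would return None).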
11758                return res.append(elem)
11759
11760        model = torch.jit.script(ListModel())
11761        x = torch.randn(16, 3, 4)
11762        y = torch.randn(4, 5)
11763        self.run_test(model, (x, y))
11764
11765    @skipIfUnsupportedMinOpsetVersion(13)
11766    def test_list_del_in_block(self):
11767        class ListModel(torch.nn.Module):
11768            def forward(self, x, y):
11769                res = []
11770                elem = torch.matmul(x[0], y)
11771                for i in range(x.size(0)):
11772                    res.append(torch.matmul(x[i], y))
11773                for i in range(x.size(0)):
11774                    del res[0]
11775                for i in range(x.size(0)):
11776                    res.append(torch.matmul(x[i], y))
11777                    del res[0]
11778                return res.append(elem)
11779
11780        model = torch.jit.script(ListModel())
11781        x = torch.randn(16, 3, 4)
11782        y = torch.randn(4, 5)
11783        self.run_test(model, (x, y))
11784
11785    @skipIfUnsupportedMinOpsetVersion(11)
11786    def test_list_unpack(self):
11787        class ListModel(torch.nn.Module):
11788            def forward(self, x, y):
11789                res = []
11790                elem = torch.matmul(x[0], y)
11791                for i in range(x.size(0)):
11792                    res.append(torch.matmul(x[i], y))
11793                a, b, c = res
11794                return a, b
11795
11796        model = torch.jit.script(ListModel())
11797        x = torch.randn(3, 3, 4)
11798        y = torch.randn(4, 5)
11799        self.run_test(model, (x, y))
11800
11801    @skipIfUnsupportedMinOpsetVersion(11)
11802    def test_index_put_inplace_ops(self):
11803        @torch.jit.script
11804        def check_init(input_data: Tensor, hidden_size: int) -> Tensor:
11805            batch_size = input_data.size(0)
11806            spatial_size_0 = input_data.size(2)
11807            spatial_size_1 = input_data.size(3)
            # Initialize a zeroed state buffer.
11809            state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
11810            state = torch.zeros(state_size, device=input_data.device)
11811            if input_data.size(0) == 1:
11812                state[1] += (
11813                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11814                    * 2
11815                )
11816                state[1] /= (
11817                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11818                    * 3
11819                )
11820            for i in range(input_data.size(0)):
11821                state[1] += torch.ones(
11822                    batch_size, hidden_size, spatial_size_0, spatial_size_1
11823                )
11824                state[1] /= (
11825                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11826                    * i
11827                )
11828            return state
11829
11830        class Example(torch.nn.Module):
11831            def __init__(self, hidden_size):
11832                super().__init__()
11833                self.hidden_size = hidden_size
11834
11835            def forward(self, input_data):
11836                state = check_init(input_data, self.hidden_size)
11837                return state
11838
11839        model = Example(10)
11840        random_data = torch.rand((1, 5, 30, 30))
11841        self.run_test(
11842            model,
            (random_data,),
11844            input_names=["random_data"],
11845            dynamic_axes={"random_data": [0, 1, 2, 3]},
11846        )
        self.run_test(model, (random_data,), remained_onnx_input_idx=[])
11848
11849    @skipIfUnsupportedMinOpsetVersion(11)
11850    def test_input_mask_model(self):
11851        class InputMaskModel(torch.nn.Module):
11852            def __init__(self, output_size):
11853                super().__init__()
11854                self.bias = torch.nn.Parameter(
11855                    torch.empty(output_size, dtype=torch.float)
11856                )
11857                with torch.no_grad():
11858                    self.bias.zero_()
11859
11860            def forward(self, model_input, y):
11861                input_mask = (model_input <= 0) | (model_input > 25)
11862                y[input_mask, :] = 0.0
11863                output = y + self.bias
11864                return output
11865
11866        output_size = 4
11867        m = InputMaskModel(output_size)
11868        x = torch.tensor([0, 4, 24, 25], dtype=torch.int64)
11869        y = torch.tensor(
11870            [
11871                [0.1, 0.2, 0.3, 0.4],
11872                [0.1, 0.2, 0.3, 0.4],
11873                [0.1, 0.2, 0.3, 0.4],
11874                [0.1, 0.2, 0.3, 0.4],
11875            ],
11876            dtype=torch.float,
11877        )
11878        self.run_test(m, (x, y))
11879
11880        class InputMaskModel(torch.nn.Module):
11881            def __init__(self, output_size):
11882                super().__init__()
11883
11884            def forward(self, model_input_1, model_input_2, y):
11885                input_mask_1 = (model_input_1 <= 0) | (model_input_1 > 25)
11886                input_mask_2 = (model_input_2 < 1) | (model_input_2 >= 12)
11887                y[input_mask_1, input_mask_2] = 0.0
11888                return y
11889
11890        output_size = 4
11891        m = InputMaskModel(output_size)
11892        x1 = torch.tensor([0, 4, 24, 25], dtype=torch.int64)
11893        x2 = torch.tensor([0, 3, 12, 15], dtype=torch.int64)
11894        y = torch.tensor(
11895            [
11896                [0.1, 0.2, 0.3, 0.4],
11897                [0.1, 0.2, 0.3, 0.4],
11898                [0.1, 0.2, 0.3, 0.4],
11899                [0.1, 0.2, 0.3, 0.4],
11900            ],
11901            dtype=torch.float,
11902        )
11903        self.run_test(m, (x1, x2, y))
11904
11905    @skipScriptTest()
11906    def test_unsafe_chunk(self):
11907        class ChunkModel(torch.nn.Module):
11908            def forward(self, x):
11909                return torch.unsafe_chunk(x, 3, dim=1)
11910
11911        model = ChunkModel()
11912        model.eval()
11913        x = torch.randn(1, 18)
11914        self.run_test(model, x, input_names=["x"])
11915
11916    def test_symbolic_shape_inference(self):
11917        # ConstantOfShape is tested in test_embedding_bag
11918        # Tile is tested in test_repeat
11919        # test Shape, Reshape, Transpose, Gather
11920        class ShapeModel(torch.nn.Module):
11921            def forward(self, x, y):
                shape = x.size()[:3] + (-1,)  # 4-element shape: ("batch", 3, 4, -1)
11923                y = y.reshape(shape)  # batch, 3, 4, 10/batch
11924                return y.transpose(1, 2)
11925
11926        model = ShapeModel()
11927        model.eval()
11928        x = torch.ones(2, 3, 4, 5)
11929        y = torch.ones(3, 4, 5, 2)
11930        self.run_test(
11931            model,
11932            (x, y),
11933            input_names=["x", "y"],
11934            dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]},
11935        )
11936        self.run_test(model, (x, y), remained_onnx_input_idx=[1])
11937
11938        class ViewModel(torch.nn.Module):
11939            def forward(self, x):
11940                return x.view(-1)
11941
11942        model = ViewModel()
11943        model.eval()
11944        x = torch.tensor(2.0)
11945        self.run_test(model, (x,))
11946
11947        # test prim::ListConstruct for Reshape input 1
11948        class ViewModel_2(torch.nn.Module):
11949            def forward(self, x):
11950                N, C, H, W = x.shape[0], x.shape[2], x.shape[3], x.shape[4]
11951                x1 = x.view(N, -1, C, H, W)
11952                x2 = x1.permute(0, 3, 4, 1, 2)
11953                return x2.reshape(N, -1, C)
11954
11955        model = ViewModel_2()
11956        model.eval()
11957        x = torch.ones(2, 3, 4, 5, 6)
11958        self.run_test(model, x)
11959
11960    @skipIfUnsupportedMinOpsetVersion(9)
11961    def test_symbolic_shape_inference_arange(self):
11962        # test Range
11963        class ArangeModel(torch.nn.Module):
11964            def forward(self, signal):
11965                frame_step = 2
11966                outer_dimensions = signal.size()[:-2]
11967                frames, frame_length = signal.size()[-2:]
11968
11969                subframe_length = signal.size()[0]
11970                subframe_step = frame_step // subframe_length
11971                subframes_per_frame = frame_length // subframe_length
11972                output_size = frame_step * (frames - 1) + frame_length
11973                output_subframes = output_size // subframe_length
11974
11975                frame = torch.arange(0, output_subframes)
11976                return frame
11977
11978        model = ArangeModel()
11979        model.eval()
11980        M, C, K, N = 1, 2, 3, 4
11981        x = torch.randint(5, (M, C, K, N))
11982        y = torch.randint(5, (M, C + 1, K + 1, N + 1))
11983        self.run_test(model, x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]})
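        # With static input shapes, the arange output is fully determined by
        # shape arithmetic, so the exported graph is expected to fold to a
        # constant and keep no inputs.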
11984        self.run_test(model, x, remained_onnx_input_idx=[])
11985        self.run_test(
11986            model,
11987            x,
11988            input_names=["x"],
11989            dynamic_axes={"x": [0, 1, 2, 3]},
11990            additional_test_inputs=[(x,), (y,)],
11991        )
11992
11993    @skipIfUnsupportedMinOpsetVersion(11)
11994    def test_symbolic_shape_inference_box(self):
11995        # test NonZero
11996        class BoxModel(torch.nn.Module):
11997            def forward(self, boxes):
11998                min_size = 1e-2
11999                ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
12000                keep = (ws >= min_size) & (hs >= min_size)
12001                keep = torch.where(keep)[0]
12002                return keep
12003
12004        model = BoxModel()
12005        model.eval()
12006        x = torch.ones(2, 4)
12007        y = torch.ones(3, 5)
12008        self.run_test(model, x)
12009        self.run_test(
12010            model,
12011            x,
12012            input_names=["x"],
12013            dynamic_axes={"x": [0, 1]},
12014            additional_test_inputs=[(x,), (y,)],
12015        )
12016
12017    @skipIfUnsupportedMinOpsetVersion(11)
12018    def test_symbolic_shape_inference_box_if(self):
12019        # test If
12020        class BoxIfModel(torch.nn.Module):
12021            def forward(self, boxes, scores):
12022                score_thresh = 0.0
12023                inds = torch.where(scores > score_thresh)[0]
12024                boxes_1 = boxes[inds]
12025                if boxes_1.numel() > 3:
12026                    return boxes_1
12027                else:
12028                    return boxes_1 * 2
12029
12030        model = BoxIfModel()
12031        model.eval()
12032        boxes = torch.ones(2, 4)
12033        scores = torch.ones(1, 4)
12034        self.run_test(model, (boxes, scores))
12035
12036    @skipIfUnsupportedMinOpsetVersion(11)
12037    @skipDtypeChecking
12038    def test_symbolic_shape_inference_arange_2(self):
12039        # test Range
12040        class ArangeModel(torch.nn.Module):
12041            def forward(self, start):
12042                return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.int64)
12043
12044        x = torch.randn(2, 3, 4)
12045        self.run_test(
12046            ArangeModel(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
12047        )
12048        self.run_test(ArangeModel(), (x,), remained_onnx_input_idx=[])
12049
12050        class ArangeModel2(torch.nn.Module):
12051            def forward(self, start):
12052                return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.double)
12053
12054        x = torch.randn(2, 3, 4)
12055        self.run_test(
12056            ArangeModel2(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
12057        )
12058        self.run_test(ArangeModel2(), (x,), remained_onnx_input_idx=[])
12059
12060    @skipIfUnsupportedMinOpsetVersion(9)
12061    def test_symbolic_shape_inference_nonzero(self):
12062        class OneLikeModel(torch.nn.Module):
12063            def forward(self, x):
12064                ones = torch.ones_like(
12065                    x,
12066                    dtype=torch.float,
12067                    layout=torch.strided,
12068                    device=torch.device("cpu"),
12069                )
12070                return torch.nonzero(ones)
12071
12072        x = torch.randn(2)
12073        self.run_test(OneLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0]})
12074        self.run_test(OneLikeModel(), x, remained_onnx_input_idx=[])
12075        x = torch.randn(2, 3, 4)
12076        self.run_test(
12077            OneLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
12078        )
12079        self.run_test(OneLikeModel(), x, remained_onnx_input_idx=[])
12080
12081        class ZeroLikeModel(torch.nn.Module):
12082            def forward(self, x):
12083                zeros = torch.zeros_like(
12084                    x,
12085                    dtype=torch.float,
12086                    layout=torch.strided,
12087                    device=torch.device("cpu"),
12088                )
12089                return torch.nonzero(zeros)
12090
12091        x = torch.randn(2)
12092        self.run_test(ZeroLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0]})
12093        self.run_test(ZeroLikeModel(), x, remained_onnx_input_idx=[])
12094        x = torch.randn(2, 3, 4)
12095        self.run_test(
12096            ZeroLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
12097        )
12098        self.run_test(ZeroLikeModel(), x, remained_onnx_input_idx=[])
12099
12100    @skipIfUnsupportedMinOpsetVersion(9)
12101    def test_symbolic_shape_inference_expand_1(self):
12102        class ExpandModel(torch.nn.Module):
12103            def forward(self, x):
12104                return x.expand(4, 6, 2)
12105
12106        x = torch.randn(6, 1, requires_grad=True)
12107        self.run_test(ExpandModel(), (x,))
12108
12109    @skipIfUnsupportedMinOpsetVersion(9)
12110    def test_symbolic_shape_inference_expand_2(self):
12111        class M(torch.nn.Module):
12112            def forward(self, x):
12113                input_shape = x.size()
12114                batch_size, seq_length = input_shape
12115                seq_ids = torch.arange(seq_length)
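                # causal_mask[b, q, k] == (k <= q): lower-triangular, so each
                # position attends only to itself and earlier positions.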
12116                causal_mask = (
12117                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
12118                    <= seq_ids[None, :, None]
12119                )
12120                return causal_mask.transpose(0, 1)
12121
12122        x = torch.randn(3, 16)
12123        self.run_test(M(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]})
12124        self.run_test(M(), (x,), remained_onnx_input_idx=[])
12125
12126    @skipIfUnsupportedMinOpsetVersion(10)
12127    def test_symbolic_shape_inference_slice(self):
12128        class M(torch.nn.Module):
12129            def forward(self, x, position_bias):
12130                input_shape = x.size()
12131                batch_size, seq_length = input_shape
12132                position_bias = position_bias[:, :, -seq_length:, :]
12133                return position_bias.transpose(0, 1)
12134
12135        x = torch.randn(3, 16)
12136        position_bias = torch.randn(1, 3, 20, 8)
12137        self.run_test(
12138            M(),
12139            (x, position_bias),
12140            input_names=["x", "position_bias"],
12141            dynamic_axes={"x": [0, 1], "position_bias": [0, 1, 2, 3]},
12142        )
12143        self.run_test(M(), (x, position_bias), remained_onnx_input_idx=[1])
12144
12145    def test_symbolic_shape_inference_slice_2(self):
12146        class M(torch.nn.Module):
12147            def forward(self, position_bias):
12148                position_bias = position_bias[:, :, -2:, :]
12149                return position_bias.transpose(0, 1)
12150
12151        position_bias = torch.randn(1, 3, 20, 8)
12152        self.run_test(M(), (position_bias,))
12153
12154    @skipIfUnsupportedMinOpsetVersion(9)
12155    @skipScriptTest()
12156    def test_symbolic_shape_inference_time(self):
12157        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
12158        h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
12159        c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
12160        model_lstm = torch.nn.LSTM(
12161            RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
12162        )
12163        self.run_test(
12164            model_lstm,
12165            (input, (h0, c0)),
12166            input_names=["x", "y"],
12167            dynamic_axes={"x": [0, 1]},
12168        )
12169        model_gru = torch.nn.GRU(
12170            RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False, bias=False
12171        )
12172        self.run_test(
12173            model_gru, (input, h0), input_names=["x", "y"], dynamic_axes={"x": [0, 1]}
12174        )
12175        model_rnn = torch.nn.RNN(
12176            RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False, bias=False
12177        )
12178        self.run_test(
12179            model_rnn, (input, h0), input_names=["x", "y"], dynamic_axes={"x": [0, 1]}
12180        )
12181
12182    def test_symbolic_shape_inference_dynamic_axes(self):
12183        class M(torch.nn.Module):
12184            def forward(self, input_ids):
12185                input_shape = input_ids.size()
12186                input_ids = input_ids.view(-1, input_shape[-1])
12187                return input_ids.transpose(0, 1)
12188
12189        x = torch.randn(3, 16)
12190        self.run_test(
12191            M(),
12192            (x,),
12193            input_names=["input_ids"],
12194            dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}},
12195        )
12196
12197    @skipIfUnsupportedMinOpsetVersion(9)
12198    def test_hann_window_periodic(self):
12199        class HannWindowModule_Periodic(torch.nn.Module):
12200            def __init__(self) -> None:
12201                super().__init__()
12202                self.window_length = 0
12203
12204            def forward(self, x, window_length: int):
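                # periodic=True computes w[n] = 0.5 * (1 - cos(2 * pi * n / N))
                # for n = 0..N-1 (denominator N rather than N - 1).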
12205                self.window_length = window_length
12206                return torch.add(
12207                    x,
12208                    torch.hann_window(
12209                        self.window_length, periodic=True, dtype=torch.float
12210                    ),
12211                )
12212
12213        win_length = 100
12214        x = torch.randn(win_length)
12215
12216        module = HannWindowModule_Periodic()
12217        self.run_test(module, (x, win_length))
12218
12219    @skipIfUnsupportedMinOpsetVersion(9)
12220    def test_hann_window_not_periodic(self):
12221        class HannWindowModule_NotPeriodic(torch.nn.Module):
12222            def __init__(self) -> None:
12223                super().__init__()
12224                self.window_length = 0
12225
12226            def forward(self, x, window_length: int):
12227                self.window_length = window_length
12228                return torch.add(
12229                    x,
12230                    torch.hann_window(
12231                        self.window_length, periodic=False, dtype=torch.float
12232                    ),
12233                )
12234
12235        win_length = 100
12236        x = torch.randn(win_length)
12237
12238        module = HannWindowModule_NotPeriodic()
12239        self.run_test(module, (x, win_length))
12240
12241    @skipIfUnsupportedMinOpsetVersion(9)
12242    @skipScriptTest()
12243    def test_hann_window_default_values(self):
12244        class HannWindowModule(torch.nn.Module):
12245            def __init__(self) -> None:
12246                super().__init__()
12247                self.window_length = 0
12248
12249            def forward(self, x, window_length: int):
12250                import torch.nn.functional as F
12251
12252                self.window_length = window_length
12253                return torch.add(x, F.relu(torch.hann_window(self.window_length)))
12254
12255        win_length = 100
12256        x = torch.randn(win_length, dtype=torch.float)
12257        module = HannWindowModule()
12258
12259        output = module(x, win_length)
12260        self.run_test(module, (x, win_length))
12261
12262    @skipIfUnsupportedMinOpsetVersion(12)
12263    def test_tensordot_dim_count(self):
12264        class M(torch.nn.Module):
12265            def forward(self, x, y):
12266                output = torch.tensordot(x, y, 2)
12267                return output
12268
12269        x = torch.randint(6, (7, 5, 3, 4))
12270        y = torch.randint(6, (3, 4, 9, 2))
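        # tensordot(x, y, 2) contracts the last two dims of x with the first
        # two dims of y: (7, 5, 3, 4) x (3, 4, 9, 2) -> (7, 5, 9, 2).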
12271
12272        self.run_test(M(), (x, y))
12273
12274    @skipIfUnsupportedMinOpsetVersion(12)
12275    def test_tensordot_dim_list(self):
12276        class M(torch.nn.Module):
12277            def forward(self, x, y):
12278                output = torch.tensordot(x, y, ([1, -2, -1], [1, 0, 3]))
12279                return output
12280
12281        x = torch.randint(6, (7, 4, 3, 5, 2))
12282        y = torch.randint(6, (5, 4, 4, 2, 6))
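        # Contracts x dims (1, -2, -1), sizes (4, 5, 2), with y dims (1, 0, 3),
        # leaving (7, 3) from x and (4, 6) from y: result shape (7, 3, 4, 6).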
12283
12284        self.run_test(M(), (x, y))
12285
12286    @skipIfUnsupportedMinOpsetVersion(12)
12287    def test_tensordot_dynamic_dim(self):
12288        class M(torch.nn.Module):
12289            def forward(self, x, y):
12290                output = torch.tensordot(x, y, 2)
12291                return output
12292
12293        x = torch.randint(6, (7, 5, 3, 4))
12294        y = torch.randint(6, (3, 4, 9, 2))
12295
12296        new_x = torch.randint(6, (8, 6, 2, 5))
12297        new_y = torch.randint(6, (2, 5, 3, 4))
12298
12299        self.run_test(
12300            M(),
12301            (x, y),
12302            additional_test_inputs=[(new_x, new_y)],
12303            input_names=["input_x", "input_y"],
12304            dynamic_axes={"input_x": [0, 1, 2, 3], "input_y": [0, 1, 2, 3]},
12305        )
12306
12307    @skipIfUnsupportedMinOpsetVersion(9)
12308    def test_to_device(self):
12309        class M_ToDevice(torch.nn.Module):
12310            def forward(self, x, y):
12311                return x.to(y.device), y
12312
12313        class M_ToDeviceDtype(torch.nn.Module):
12314            def forward(self, x, y):
12315                return x.to(y.device, dtype=torch.long), y
12316
12317        x = torch.randn(6)
12318        y = torch.randn(6)
12319
12320        self.run_test(M_ToDevice(), (x, y))
12321        self.run_test(M_ToDeviceDtype(), (x, y))
12322
12323    @skipIfUnsupportedMinOpsetVersion(9)
12324    def test_fill(self):
12325        class FillModule(torch.nn.Module):
12326            def forward(self, x, filled_value: int):
12327                return x.fill_(filled_value)
12328
12329        x = torch.randn((4, 5, 6))
12330        filled_value = 7
12331        self.run_test(FillModule(), (x, filled_value))
12332
12333        class FillFloatModule(torch.nn.Module):
12334            def forward(self, x, filled_value: float):
12335                return x.fill_(filled_value)
12336
12337        x = torch.randn((4, 5, 6))
12338        filled_value = 7.5
12339        self.run_test(FillFloatModule(), (x, filled_value))
12340
12341        class FillScalarModule(torch.nn.Module):
12342            def forward(self, x):
12343                res = x + 2
12344                res.fill_(2.5)
12345                return res, x
12346
12347        x = torch.ones(2, 3, 4, dtype=torch.long)
12348        self.run_test(FillScalarModule(), x)
12349
12350    @skipIfUnsupportedMinOpsetVersion(9)
12351    def test_index_add_normal(self):
12352        class M(torch.nn.Module):
12353            def __init__(self, dim, index, updates):
12354                super().__init__()
12355                self.dim = dim
12356                self.index = index
12357                self.updates = updates
12358
12359            def forward(self, x):
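                # index_add_(dim, index, updates) accumulates along dim:
                # x.select(dim, index[i]) += updates.select(dim, i).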
12360                x.index_add_(self.dim, self.index, self.updates)
12361                return x
12362
12363        x = torch.ones(5, 1)
12364        updates = torch.tensor([[1], [4], [7], [3], [2]], dtype=torch.float)
12365        index = torch.tensor([0, 2, 3, 1, 4])
12366        self.run_test(M(0, index, updates), (x,))
12367
12368        x = torch.ones(1, 4, 3)
12369        updates = torch.tensor(
12370            [[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
12371        )
12372        index = torch.tensor([0, 2, 3, 1])
12373        self.run_test(M(1, index, updates), (x,))
12374
12375        updates = torch.tensor(
12376            [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [2, 3, 4]]], dtype=torch.float
12377        )
12378        index = torch.tensor([0, 2, 1])
12379        self.run_test(M(2, index, updates), (x,))
12380
12381    @skipIfUnsupportedMinOpsetVersion(9)
12382    def test_index_add_dim_size_differ(self):
12383        class M(torch.nn.Module):
12384            def __init__(self, dim, index, updates):
12385                super().__init__()
12386                self.dim = dim
12387                self.index = index
12388                self.updates = updates
12389
12390            def forward(self, x):
12391                x.index_add_(self.dim, self.index, self.updates)
12392                return x
12393
12394        x = torch.ones(1, 4, 3)
12395        updates = torch.tensor([[[1, 5, 7], [2, 4, 5], [5, 5, 6]]], dtype=torch.float)
12396        index = torch.tensor([0, 2, 1])
12397        self.run_test(M(1, index, updates), (x,))
12398
12399    @skipIfUnsupportedMinOpsetVersion(9)
12400    def test_index_add_in_loop(self):
12401        class M(torch.nn.Module):
12402            def __init__(self, dim, index, updates, loop_count):
12403                super().__init__()
12404                self.dim = dim
12405                self.index = index
12406                self.updates = updates
12407                self.loop_count = loop_count
12408
12409            def forward(self, x):
12410                for i in range(self.loop_count):
12411                    x.index_add_(self.dim, self.index, self.updates)
12412                return x
12413
12414        x = torch.ones(1, 4, 3)
12415        updates = torch.tensor(
12416            [[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
12417        )
12418        index = torch.tensor([0, 2, 3, 1])
12419        loop_count = torch.randint(20, (1,))[0].item()
12420        self.run_test(M(1, index, updates, loop_count), (x,))
12421
12422    @skipIfUnsupportedMinOpsetVersion(9)
12423    def test_index_add_if(self):
12424        class M(torch.nn.Module):
12425            def __init__(self, dim, updates, index_true, index_false):
12426                super().__init__()
12427                self.dim = dim
12428                self.updates = updates
12429                self.index_true = index_true
12430                self.index_false = index_false
12431
12432            def forward(self, x, cond):
12433                if cond:
12434                    x.index_add_(self.dim, self.index_true, self.updates)
12435                else:
12436                    x.index_add_(self.dim, self.index_false, self.updates)
12437                return x
12438
12439        x = torch.ones(1, 4, 3)
12440        updates = torch.tensor(
12441            [[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
12442        )
12443        index_true = torch.tensor([0, 2, 3, 1])
12444        index_false = torch.tensor([1, 0, 2, 3])
12445        cond = torch.tensor(1, dtype=torch.bool)
12446        self.run_test(
12447            torch.jit.script(M(1, updates, index_true, index_false)), (x, cond)
12448        )
12449
12450    @skipIfUnsupportedMinOpsetVersion(9)
12451    def test_index_add_dynamic_axes(self):
12452        class M(torch.nn.Module):
12453            def __init__(self, dim, index, updates):
12454                super().__init__()
12455                self.dim = dim
12456                self.index = index
12457                self.updates = updates
12458
12459            def forward(self, x):
12460                x.index_add_(self.dim, self.index, self.updates)
12461                return x
12462
12463        x = torch.ones(1, 4, 3)
12464        updates = torch.tensor(
12465            [[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
12466        )
12467        index = torch.tensor([0, 2, 3, 1])
12468
12469        self.run_test(
12470            M(1, index, updates),
12471            (x,),
12472            input_names=["input_1"],
12473            dynamic_axes={"input_1": [0, 1]},
12474        )
12475
12476    def test_roll(self):
12477        class M(torch.nn.Module):
12478            def __init__(self, shifts, dims):
12479                super().__init__()
12480                self.shifts = shifts
12481                self.dims = dims
12482
12483            def forward(self, x):
12484                return torch.roll(x, self.shifts, self.dims)
12485
12486        x = torch.randn(2, 3, 4)
12487        self.run_test(M([1, 1], [1, 0]), (x,))
12488        self.run_test(M([0, 1, 2], [1, 0, 2]), (x,))
12489        self.run_test(M(2, 1), (x,))
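        # Negative shifts roll toward lower indices; negative dims index from
        # the end.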
12490        self.run_test(M([-1, 3], [-2, -1]), (x,))
12491
12492    def test_sum(self):
12493        class M(torch.nn.Module):
12494            def forward(self, x):
12495                return torch.sum(x)
12496
12497        x = torch.ones(12, 3)
12498        self.run_test(M(), (x,), input_names=["x"], dynamic_axes={"x": [0]})
12499
12500    @skipShapeChecking
12501    def test_sum_empty_tensor(self):
12502        class M(torch.nn.Module):
12503            def forward(self, x):
12504                return x[0:0].sum(), x.sum()
12505
12506        x = torch.ones(12)
12507        self.run_test(M(), (x,))
12508
12509        x = torch.ones(2, 0, 3)
12510        self.run_test(M(), (x,))
12511
12512        x = torch.ones(0)
12513        self.run_test(M(), (x,))
12514
12515    @skipIfUnsupportedMinOpsetVersion(11)
    def test_broadcast_tensors(self):
12517        class M(torch.nn.Module):
12518            def forward(self, x, y):
12519                m = torch.broadcast_tensors(x, y)
12520                return m
12521
12522        x = torch.randint(5, (1,))
12523        y = torch.randint(5, (5,))
12524
12525        self.run_test(M(), (x, y))
12526
12527        x = torch.randint(5, (4, 2, 1, 4))
12528        y = torch.randint(5, (2, 3, 1))
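        # Broadcasting aligns trailing dims: (4, 2, 1, 4) with (2, 3, 1)
        # expands both to (4, 2, 3, 4).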
12529
12530        self.run_test(M(), (x, y))
12531
12532        x = torch.randn(2, 1, 4)
12533        y = torch.randn(5, 2, 3, 1)
12534
12535        self.run_test(M(), (x, y))
12536
12537    @skipIfUnsupportedMinOpsetVersion(14)
12538    def test_scaled_dot_product_attention(self):
12539        class M(torch.nn.Module):
12540            def forward(self, q, k, v):
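                # With no mask or dropout this is equivalent to
                # softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v.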
12541                return torch.nn.functional.scaled_dot_product_attention(
12542                    q, k, v, scale=1.0
12543                )
12544
12545        # Parameters
12546        batch_size = 2  # Number of samples in the batch
12547        num_heads = 4  # Number of attention heads
12548        seq_length = 5  # Sequence length
12549        head_dim = 8  # Dimensionality of each head
12550
12551        # Create random query, key, and value tensors
12552        q = torch.randn(batch_size, num_heads, seq_length, head_dim)
12553        k = torch.randn(batch_size, num_heads, seq_length, head_dim)
12554        v = torch.randn(batch_size, num_heads, seq_length, head_dim)
12555
12556        self.run_test(M(), (q, k, v))
12557
12558    @skipScriptTest()
12559    @skipIfUnsupportedMinOpsetVersion(11)
12560    def test_dist_normal(self):
12561        class M(torch.nn.Module):
12562            def forward(self, x, y):
12563                return torch.distributions.Normal(x, y).sample().size(0), x, y
12564
12565        self.run_test(M(), (torch.tensor([0.0]), torch.tensor([[1.0], [2.0]])))
12566        self.run_test(M(), (torch.tensor([0.0]), torch.tensor([1.0])))
12567
12568        self.run_test(
12569            M(),
12570            (
12571                torch.tensor([[[0.0], [10.0]], [[2.0], [8.0]], [[2.0], [8.0]]]),
12572                torch.tensor([[1.0], [3.0]]),
12573            ),
12574        )
12575
12576    @skipScriptTest()
12577    @skipIfUnsupportedMinOpsetVersion(11)
12578    def test_dist_normal_correctness(self):
12579        class M(torch.nn.Module):
12580            def forward(self, x, y):
12581                return torch.distributions.Normal(x, y).sample([20000])
12582
12583        expected_mean = 5.0
12584        expected_std = 10.0
12585
12586        model_export = M()
12587        dummy_input = (torch.tensor([expected_mean]), torch.tensor([expected_std]))
12588        model_onnx = io.BytesIO()
12589        torch.onnx.export(
12590            model_export, dummy_input, model_onnx, opset_version=self.opset_version
12591        )
12592        ort_sess = verification._ort_session(model_onnx)
12593        ort_out = verification._run_onnx(ort_sess, inputs=dummy_input)
12594
12595        actual_std = np.std(ort_out)
12596        actual_mean = np.mean(ort_out)
12597
        assert (
            abs(abs(actual_mean) - expected_mean) <= expected_mean * 0.1
        ), "the mean of the ORT outputs deviates too much from the expected mean."
        assert (
            abs(abs(actual_std) - expected_std) <= expected_std * 0.1
        ), "the standard deviation of the ORT outputs deviates too much from the expected one."
12604
12605    @skipScriptTest()
12606    @skipIfUnsupportedMinOpsetVersion(11)
12607    def test_nn_init_normal_correctness(self):
12608        expected_mean = 5.0
12609        expected_std = 10.0
12610
12611        class M(torch.nn.Module):
12612            def forward(self):
12613                x = torch.ones([]).new_empty(1, 400, 50)
12614                torch.nn.init.normal_(x, expected_mean, expected_std)
12615                return x
12616
12617        model_export = M()
12618        model_onnx = io.BytesIO()
12619        test_inputs = ()
12620        torch.onnx.export(
12621            model_export, test_inputs, model_onnx, opset_version=self.opset_version
12622        )
12623        ort_sess = verification._ort_session(model_onnx)
12624        ort_out = verification._run_onnx(ort_sess, inputs=test_inputs)
12625
12626        actual_std = np.std(ort_out)
12627        actual_mean = np.mean(ort_out)
12628
        assert (
            abs(abs(actual_mean) - expected_mean) <= expected_mean * 0.1
        ), "the mean of the ORT outputs deviates too much from the expected mean."
        assert (
            abs(abs(actual_std) - expected_std) <= expected_std * 0.1
        ), "the standard deviation of the ORT outputs deviates too much from the expected one."
12635
12636    @skipScriptTest()
12637    @skipIfUnsupportedMinOpsetVersion(11)
12638    def test_dist_uniform(self):
12639        class M(torch.nn.Module):
12640            def forward(self, x, y):
12641                return torch.distributions.Uniform(x, y).sample().size(0), x, y
12642
12643        self.run_test(M(), (torch.tensor([0.0]), torch.tensor([10.0])))
12644        self.run_test(M(), (torch.tensor([[0.0], [6.0]]), torch.tensor([[1.0], [7.0]])))
12645        self.run_test(
12646            M(), (torch.tensor([1.0]), torch.tensor([[10.0], [7.0], [9.0], [20.0]]))
12647        )
12648
12649    @skipScriptTest()
12650    @skipIfUnsupportedMinOpsetVersion(11)
12651    def test_dist_uniform_correctness(self):
12652        class M(torch.nn.Module):
12653            def forward(self, x, y):
12654                return torch.distributions.Uniform(x, y).sample([10000])
12655
12656        expected_min = 5.0
12657        expected_max = 10.0
12658        expected_mean = (expected_min + expected_max) / 2
12659
12660        model_export = M()
12661        dummy_input = (torch.tensor([expected_min]), torch.tensor([expected_max]))
12662        model_onnx = io.BytesIO()
12663        torch.onnx.export(
12664            model_export, dummy_input, model_onnx, opset_version=self.opset_version
12665        )
12666        ort_sess = verification._ort_session(model_onnx)
12667
12668        ort_out = verification._run_onnx(ort_sess, inputs=dummy_input)
12669        actual_min = np.min(ort_out)
12670        actual_max = np.max(ort_out)
12671        actual_mean = np.mean(ort_out)
12672
        assert (
            actual_min >= expected_min
        ), "the minimum of the ORT outputs is below the expected lower bound."
        assert (
            actual_max <= expected_max
        ), "the maximum of the ORT outputs is above the expected upper bound."
        assert (
            abs(actual_mean - expected_mean) <= expected_mean * 0.05
        ), "the mean of the ORT outputs deviates too much from the expected mean."
12682
12683    @skipIfUnsupportedMinOpsetVersion(13)
12684    def test_sequence_to_int(self):
12685        class M(torch.nn.Module):
12686            def forward(self, x):
12687                result = torch.tensor([2 for i in range(x.size()[0])], dtype=torch.int)
12688                return x, result
12689
12690        x = torch.randn(10, 5)
12691        self.run_test(M(), (x,))
12692
12693    @skipIfUnsupportedMinOpsetVersion(13)
12694    def test_sequence_to_float(self):
12695        class M(torch.nn.Module):
12696            def forward(self, x):
12697                result = torch.tensor(
12698                    [1.1 for i in range(x.size()[0])], dtype=torch.float
12699                )
12700                return x, result
12701
12702        x = torch.randn(10, 5)
12703        self.run_test(M(), (x,))
12704
12705    @skipIfUnsupportedMinOpsetVersion(13)
12706    def test_sequence_to_bool(self):
12707        class M(torch.nn.Module):
12708            def forward(self, x):
12709                result = torch.tensor(
12710                    [False for i in range(x.size()[0])], dtype=torch.bool
12711                )
12712                return x, result
12713
12714        x = torch.randn(10, 5)
12715        self.run_test(M(), (x,))
12716
12717    def test_tuple_output_from_if_with_raised_exception(self):
12718        class M(torch.nn.Module):
12719            def forward(self, t: Tensor) -> Tuple[Tensor, Tensor]:
12720                if float(t) < 0:
12721                    raise Exception("Negative input")  # noqa: TRY002
12722                else:
12723                    return torch.zeros(5), torch.zeros(5)
12724
12725        x = torch.zeros(1)
12726        self.run_test(torch.jit.script(M()), (x,))
12727
12728    # NOTE: For quantization tests, choose scale and zero point carefully
12729    #       such that inputs and outputs do not always overflow/underflow.
12730    #       Otherwise test results could be inaccurate.
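    #       For example, scale=0.5 with zero_point=128 maps a quint8 value q
    #       to the real number (q - 128) * 0.5, a representable range of about
    #       [-64.0, 63.5]; test data should stay well inside that window.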
12731    @skipIfUnsupportedMinOpsetVersion(10)
12732    def test_quantized_linear(self):
12733        model = torch.ao.nn.quantized.Linear(4, 8)
12734        # Set fixed weight to avoid flaky test.
12735        weight = torch.quantize_per_tensor(
12736            torch.arange(32, dtype=torch.float).view(8, 4), 0.5, 0, torch.qint8
12737        )
12738        # Set non-zero bias.
12739        bias = torch.arange(8, dtype=torch.float)
12740        model.set_weight_bias(weight, bias)
12741        # Set fixed input to avoid flaky test.
        input = torch.arange(16, dtype=torch.float).view(4, 4) - 8
12744        input_tensor = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12745        self.run_test(model, input_tensor)
12746
12747    @skipIfUnsupportedMinOpsetVersion(10)
12748    def test_quantized_conv1d(self):
12749        model = torch.ao.nn.quantized.Conv1d(16, 33, 3, stride=2)
12750        # Manually initialize model weight and bias to random numbers.
12751        # By default all zeros.
12752        q_weight = torch.quantize_per_tensor(
12753            torch.randn(33, 16, 3), 0.5, 0, torch.qint8
12754        )
12755        bias = torch.arange(33).to(torch.float) - 16
12756        model.set_weight_bias(q_weight, bias)
12757        input = torch.randn(3, 16, 32)
12758        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12759        self.run_test(model, q_input)
12760
12761    @skipIfUnsupportedMinOpsetVersion(10)
12762    def test_quantized_conv2d(self):
12763        model = torch.ao.nn.quantized.Conv2d(16, 33, 3, stride=2)
12764        # Manually initialize model weight and bias to random numbers.
12765        # By default all zeros.
12766        q_weight = torch.quantize_per_tensor(
12767            torch.randn(33, 16, 3, 3), 0.5, 0, torch.qint8
12768        )
12769        bias = torch.arange(33).to(torch.float) - 16
12770        model.set_weight_bias(q_weight, bias)
12771        input = torch.randn(3, 16, 32, 32)
12772        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12773        self.run_test(model, q_input)
12774
12775    @skipIfUnsupportedMinOpsetVersion(10)
12776    @skipIfQuantizationBackendQNNPack
12777    def test_quantized_conv3d(self):
12778        model = torch.ao.nn.quantized.Conv3d(16, 33, [2, 3, 4], stride=[3, 1, 2])
12779        # Manually initialize model weight and bias to random numbers.
12780        # By default all zeros.
12781        q_weight = torch.quantize_per_tensor(
12782            torch.randn(33, 16, 2, 3, 4), 0.5, 0, torch.qint8
12783        )
12784        bias = torch.arange(33).to(torch.float) - 16
12785        model.set_weight_bias(q_weight, bias)
12786        input = torch.randn(3, 16, 8, 8, 8)
12787        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12788        self.run_test(model, q_input)
12789
12790    @skipIfUnsupportedMinOpsetVersion(10)
12791    def test_quantized_adaptive_avg_pool2d(self):
12792        model = torch.nn.AdaptiveAvgPool2d((5, 7))
12793        input = torch.randn(4, 3, 10, 14)
12794        q_input = torch.quantize_per_tensor(input, 0.2, 128, torch.quint8)
12795        self.run_test(model, q_input)
12796
12797    @skipIfUnsupportedMinOpsetVersion(10)
12798    def test_quantized_conv1d_relu(self):
12799        model = torch.ao.nn.intrinsic.quantized.ConvReLU1d(16, 33, 3, stride=2)
12800        # Manually initialize model weight and bias to random numbers.
12801        # By default all zeros.
12802        q_weight = torch.quantize_per_tensor(
12803            torch.randn(33, 16, 3), 0.5, 0, torch.qint8
12804        )
12805        bias = torch.arange(33).to(torch.float) - 16
12806        model.set_weight_bias(q_weight, bias)
12807        input = torch.randn(3, 16, 32)
12808        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12809        self.run_test(model, q_input)
12810
12811    @skipIfUnsupportedMinOpsetVersion(10)
12812    def test_quantized_conv2d_relu(self):
12813        model = torch.ao.nn.intrinsic.quantized.ConvReLU2d(16, 33, 3, stride=2)
12814        # Manually initialize model weight and bias to random numbers.
12815        # By default all zeros.
12816        q_weight = torch.quantize_per_tensor(
12817            torch.randn(33, 16, 3, 3), 0.5, 0, torch.qint8
12818        )
12819        bias = torch.arange(33).to(torch.float) - 16
12820        model.set_weight_bias(q_weight, bias)
12821        input = torch.randn(3, 16, 32, 32)
12822        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12823        self.run_test(model, q_input)
12824
12825    @skipIfUnsupportedMinOpsetVersion(10)
12826    @skipIfQuantizationBackendQNNPack
12827    def test_quantized_conv3d_relu(self):
12828        model = torch.ao.nn.intrinsic.quantized.ConvReLU3d(
12829            16, 33, [2, 3, 4], stride=[3, 1, 2]
12830        )
12831        # Manually initialize model weight and bias to random numbers.
12832        # By default all zeros.
12833        q_weight = torch.quantize_per_tensor(
12834            torch.randn(33, 16, 2, 3, 4), 0.5, 0, torch.qint8
12835        )
12836        bias = torch.arange(33).to(torch.float) - 16
12837        model.set_weight_bias(q_weight, bias)
12838        input = torch.randn(3, 16, 8, 8, 8)
12839        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12840        self.run_test(model, q_input)
12841
12842    @skipIfUnsupportedMinOpsetVersion(10)
12843    def test_quantized_conv_transpose1d(self):
12844        model = torch.ao.nn.quantized.ConvTranspose1d(
12845            16, 33, 3, output_padding=1, stride=2
12846        )
12847        # Manually initialize model weight and bias to random numbers.
12848        # By default all zeros.
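        # Note the ConvTranspose1d weight layout: (in_channels, out_channels, kernel),
        # the reverse of Conv1d's (out_channels, in_channels, kernel).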
12849        q_weight = torch.quantize_per_tensor(
12850            torch.randn(16, 33, 3), 0.5, 0, torch.qint8
12851        )
12852        bias = torch.arange(33).to(torch.float) - 16
12853        model.set_weight_bias(q_weight, bias)
12854        input = torch.randn(3, 16, 32)
12855        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12856        self.run_test(model, q_input)
12857
12858    @skipIfUnsupportedMinOpsetVersion(10)
12859    def test_quantized_conv_transpose2d(self):
12860        model = torch.ao.nn.quantized.ConvTranspose2d(
12861            16, 33, 3, output_padding=(0, 1), stride=2
12862        )
12863        # Manually initialize model weight and bias to random numbers.
12864        # By default all zeros.
12865        q_weight = torch.quantize_per_tensor(
12866            torch.randn(16, 33, 3, 3), 0.5, 0, torch.qint8
12867        )
12868        bias = torch.arange(33).to(torch.float) - 16
12869        model.set_weight_bias(q_weight, bias)
12870        input = torch.randn(3, 16, 32, 32)
12871        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12872        self.run_test(model, q_input)
12873
12874    @skipIfUnsupportedMinOpsetVersion(10)
12875    @skipIfQuantizationBackendQNNPack
12876    def test_quantized_conv_transpose3d(self):
12877        model = torch.ao.nn.quantized.ConvTranspose3d(
12878            16, 33, [2, 3, 4], output_padding=(0, 1, 2), stride=[3, 1, 2]
12879        )
12880        # Manually initialize model weight and bias to random numbers.
12881        # By default all zeros.
12882        q_weight = torch.quantize_per_tensor(
12883            torch.randn(16, 33, 2, 3, 4), 0.5, 0, torch.qint8
12884        )
12885        bias = torch.arange(33).to(torch.float) - 16
12886        model.set_weight_bias(q_weight, bias)
12887        input = torch.randn(3, 16, 8, 8, 8)
12888        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12889        self.run_test(model, q_input)
12890
12891    @common_utils.parametrize(
12892        "function_or_module",
12893        [
12894            common_utils.subtest(
12895                torch.nn.ReLU(),
12896                name="relu",
12897            ),
12898            common_utils.subtest(
12899                torch.nn.LeakyReLU(),
12900                name="leaky_relu",
12901            ),
12902            common_utils.subtest(
12903                torch.ao.nn.quantized.LeakyReLU(2.0, 1),
12904                name="quantized_leaky_relu",
12905            ),
12906            common_utils.subtest(
12907                torch.ao.nn.quantized.Hardswish(2.0, 1),
12908                name="quantized_hardswish",
12909            ),
12910            common_utils.subtest(
12911                torch.nn.Sigmoid(),
12912                name="sigmoid",
12913            ),
12914            common_utils.subtest(
12915                torch.ao.nn.quantized.Sigmoid(2.0, 1),
12916                name="quantized_sigmoid",
12917            ),
12918            common_utils.subtest(
12919                torch.nn.Hardsigmoid(),
12920                name="hardsigmoid",
12921            ),
12922            common_utils.subtest(
12923                torch.nn.Tanh(),
12924                name="tanh",
12925            ),
12926            common_utils.subtest(
12927                torch.nn.Hardtanh(),
12928                name="hardtanh",
12929            ),
12930            common_utils.subtest(
12931                lambda x: torch.transpose(x, 0, 1),
12932                name="transpose",
12933            ),
12934            common_utils.subtest(
12935                lambda x: x.expand(2, 4, 2, 3),
12936                name="expand",
12937            ),
12938            common_utils.subtest(
12939                lambda x: x.view(1, 4, 6),
12940                name="view",
12941            ),
12942            common_utils.subtest(
12943                lambda x: x.select(1, 1),
12944                name="select",
12945            ),
12946            common_utils.subtest(
12947                torch.ao.nn.quantized.LayerNorm(
12948                    [4, 2, 3],
12949                    torch.nn.Parameter(torch.ones([4, 2, 3])),
12950                    torch.nn.Parameter(torch.zeros([4, 2, 3])),
12951                    2.0,
12952                    1,
12953                ),
12954                name="layer_norm",
12955            ),
12956            common_utils.subtest(
12957                torch.ao.nn.quantized.InstanceNorm1d(
12958                    2,
12959                    torch.nn.Parameter(torch.ones(4)),
12960                    torch.nn.Parameter(torch.zeros(4)),
12961                    2.0,
12962                    1,
12963                ),
12964                name="instance_norm",
12965            ),
12966            common_utils.subtest(
12967                torch.ao.nn.quantized.GroupNorm(
12968                    2,
12969                    4,
12970                    torch.nn.Parameter(torch.zeros(4)),
12971                    torch.nn.Parameter(torch.zeros(4)),
12972                    2.0,
12973                    1,
12974                ),
12975                name="group_norm",
12976            ),
12977            common_utils.subtest(
12978                lambda x: torch.as_strided(x, (2, 2), (1, 2)),
12979                name="as_strided",
12980            ),
12981        ],
12982    )
12983    @skipScriptTest()
12984    @skipIfUnsupportedMinOpsetVersion(10)
12985    def test_quantized_unary_ops(self, function_or_module):
12986        input = torch.randn(1, 4, 2, 3)
12987        q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)

        class Model(torch.nn.Module):
            def __init__(self, function_or_module):
                super().__init__()
                self.function_or_module = function_or_module

            def forward(self, x):
                return self.function_or_module(x)

        self.run_test(Model(function_or_module), q_input)

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_quantized_flatten(self):
        class FlattenModel(torch.nn.Module):
            def forward(self, input):
                return torch.flatten(input)

        x = torch.quantize_per_tensor(torch.randn(1, 2, 3, 4), 1, 0, torch.quint8)
        self.run_test(FlattenModel(), x)

    @skipIfUnsupportedMinOpsetVersion(10)
    @skipScriptTest()  # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
    def test_quantized_cat_when_concatenating_the_same_tensor(self):
        class QuantizedSelfConcatenationModel(torch.nn.Module):
            def forward(self, x):
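                # QFunctional carries its own output scale/zero_point (defaults
                # 1.0 and 0), so cat requantizes its inputs to those parameters.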
                return torch.ao.nn.quantized.QFunctional().cat((x, x), dim=1)

        q_input = torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 128, torch.quint8)
        self.run_test(QuantizedSelfConcatenationModel(), q_input)

    @common_utils.parametrize(
        "x, y",
        [
            common_utils.subtest(
                [
                    torch.quantize_per_tensor(
                        torch.ones(2, 3), 0.26, 128, torch.quint8
                    ),
                    torch.quantize_per_tensor(
                        torch.zeros(1, 3), 0.26, 128, torch.quint8
                    ),
                ],
                name="different_shape",
            ),
            common_utils.subtest(
                [
                    torch.quantize_per_tensor(
                        torch.ones(2, 3), 0.26, 128, torch.quint8
                    ),
                    torch.quantize_per_tensor(torch.ones(2, 3), 42, 1, torch.quint8),
                ],
                name="different_scale",
            ),
            common_utils.subtest(
                [
                    torch.quantize_per_tensor(
                        torch.ones(2, 3), 0.26, 128, torch.quint8
                    ),
                    torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 63, torch.quint8),
                ],
                name="different_zero_point",
            ),
            common_utils.subtest(
                [
                    torch.quantize_per_tensor(
                        torch.ones(2, 3), 0.26, 128, torch.quint8
                    ),
                    torch.quantize_per_tensor(torch.ones(2, 3), 0.1, 63, torch.quint8),
                ],
                name="different_zero_point_and_scale",
            ),
        ],
    )
    @skipIfUnsupportedMinOpsetVersion(10)
    @skipScriptTest()  # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
    def test_quantized_cat(self, x: torch.Tensor, y: torch.Tensor):
        class QuantizedConcatenationModel(torch.nn.Module):
            def forward(self, x, y):
                return torch.ao.nn.quantized.QFunctional().cat((x, y), dim=0)

        self.run_test(QuantizedConcatenationModel(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(10)
    # torch.jit.frontend.FrontendError:
    # Cannot instantiate class 'QFunctional' in a script function
    @skipScriptTest()
    def test_quantized_arithmetic_qfunctional(self):
        x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
        y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)

        class ArithmeticModel(torch.nn.Module):
            def forward(self, x, y):
                o = torch.ao.nn.quantized.QFunctional().add(x, y)
                o = torch.ao.nn.quantized.QFunctional().mul(o, x)
                return o

        self.run_test(ArithmeticModel(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_quantized_arithmetic(self):
        x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
        y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)

        class ArithmeticModel2(torch.nn.Module):
            def forward(self, x, y):
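                # torch.ops.quantized.add/mul take an explicit output scale and
                # zero_point (0.4 and 100 below) instead of a QFunctional module.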
                o = torch.ops.quantized.add(x, y, 0.4, 100)
                o = torch.ops.quantized.mul(o, x, 0.4, 100)
                return o

        self.run_test(ArithmeticModel2(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_quantize_per_tensor(self):
        class Module(torch.nn.Module):
            def forward(self, x):
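                # qint8 is signed (range [-128, 127], zero_point 0 here), while
                # quint8 is unsigned (range [0, 255], zero_point 128 centers zero).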
                return (
                    torch.quantize_per_tensor(x, 0.2, 0, torch.qint8),
                    torch.quantize_per_tensor(x, 0.2, 128, torch.quint8),
                )

        x = torch.randn(4, 6)
        self.run_test(Module(), x)

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_dequantize(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                return torch.dequantize(x)

        x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 0, torch.qint8)
        self.run_test(Module(), x)

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_linear_per_channel(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.linear = torch.nn.Linear(4, 3)
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.linear(x)
                x = self.dequant(x)
                return x

        model = M()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
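        # Standard eager-mode QAT flow: prepare_qat inserts fake-quantize
        # observers, and convert (below) swaps the modules for quantized versions.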
        model = torch.ao.quantization.prepare_qat(model)
        # Set fixed weight and bias to avoid flaky test.
        model.linear.weight = torch.nn.Parameter(
            _construct_tensor_for_quantization_test((3, 4))
        )
        model.linear.bias = torch.nn.Parameter(torch.arange(3, dtype=torch.float))
        model = torch.ao.quantization.convert(model)

        # Set fixed input to avoid flaky test.
        input = _construct_tensor_for_quantization_test((4, 4), offset=-8)
        self.run_test(model, input)

    @unittest.skip(
        "ORT fails with Validating no unexpected access using an invalid node_index on torch converted model"
    )
    @skipIfUnsupportedMinOpsetVersion(13)
    def test_quantized_list_of_inputs_with_cat(self):
        class TestModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = torch.cat([x, x], 1)
                x = self.dequant(x)
                return x

        model = TestModel()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        model = torch.ao.quantization.prepare_qat(model)
        model = torch.ao.quantization.convert(model)
        x = torch.randn(2, 4, 6)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_relu(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.relu = torch.nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

        model = M()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        model = torch.ao.quantization.prepare_qat(model)
        model = torch.ao.quantization.convert(model)
        input = torch.randn(8, 4)
        self.run_test(model, input)

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_conv2d(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.dequant(x)
                return x

        model = M()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        model = torch.ao.quantization.prepare_qat(model)
        # Set fixed weight and bias to avoid flaky test.
        model.conv.weight = torch.nn.Parameter(
            _construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2)
        )
        model.conv.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
        model = torch.ao.quantization.convert(model)

        # Set fixed input to avoid flaky test.
        input = _construct_tensor_for_quantization_test(
            (3, 4, 8, 8), offset=-384, max_val=12
        )
        self.run_test(model, input)

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_conv2d_relu(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
                self.relu = torch.nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

        model = M()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        model = torch.ao.quantization.prepare_qat(model)
        # Set fixed weight and bias to avoid flaky test.
        model.conv.weight = torch.nn.Parameter(
            _construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2)
        )
        model.conv.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
        model = torch.ao.quantization.convert(model)

        # Set fixed input to avoid flaky test.
        input = _construct_tensor_for_quantization_test(
            (3, 4, 8, 8), offset=-384, max_val=12
        )
        self.run_test(model, input)

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_conv2d_relu_fused(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
                self.relu = torch.nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

        model = M()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
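        # fuse_modules requires eval mode while prepare_qat requires train mode,
        # hence the eval()/train() toggling; fusing conv+relu lets convert emit
        # a single quantized ConvReLU2d module instead of separate conv and relu.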
        model = torch.ao.quantization.fuse_modules(model.eval(), [["conv", "relu"]])
        model = torch.ao.quantization.prepare_qat(model.train())
        # Set fixed weight and bias to avoid flaky test.
        model.conv.weight = torch.nn.Parameter(
            _construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2)
        )
        model.conv.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
        model = torch.ao.quantization.convert(model)

        # Set fixed input to avoid flaky test.
        input = _construct_tensor_for_quantization_test(
            (3, 4, 8, 8), offset=-384, max_val=12
        )
        self.run_test(model, input)

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_linear_relu_fused(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.linear = torch.nn.Linear(4, 2)
                self.relu = torch.nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.linear(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

        model = M()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        model = torch.ao.quantization.fuse_modules(model.eval(), [["linear", "relu"]])
        model = torch.ao.quantization.prepare_qat(model.train())
        # Set fixed weight and bias to avoid flaky test.
        model.linear.weight = torch.nn.Parameter(
            _construct_tensor_for_quantization_test((2, 4), max_val=2)
        )
        model.linear.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
        model = torch.ao.quantization.convert(model)

        # Set fixed input to avoid flaky test.
        input = _construct_tensor_for_quantization_test((3, 4), offset=-384, max_val=12)
        self.run_test(model, input)

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_qat_maxpool2d(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.pool(x)
                x = self.dequant(x)
                return x

        model = M()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        model = torch.ao.quantization.prepare_qat(model.train())
        model = torch.ao.quantization.convert(model)

        # Set fixed input to avoid flaky test.
        input = _construct_tensor_for_quantization_test((4, 4, 3, 2))
        self.run_test(model, input)

    @skipIfUnsupportedMinOpsetVersion(10)
    @skipScriptTest()  # Scale and Zero-point must be a scalar in ORT:optimization
    def test_qat_avg_pool2d(self):
        model = torch.nn.Sequential(
            torch.ao.quantization.QuantStub(),
            torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
            torch.ao.quantization.DeQuantStub(),
        )
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        model = torch.ao.quantization.prepare_qat(model.train())
        model = torch.ao.quantization.convert(model)
        input = _construct_tensor_for_quantization_test((4, 4, 3, 2))
        self.run_test(model, input)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_qat_upsample_nearest2d(self):
        model = torch.nn.Sequential(
            torch.ao.quantization.QuantStub(),
            torch.nn.UpsamplingNearest2d(scale_factor=1.5),
            torch.ao.quantization.DeQuantStub(),
        )
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        model = torch.ao.quantization.prepare_qat(model.train())
        model = torch.ao.quantization.convert(model)
        input = _construct_tensor_for_quantization_test((4, 3, 2, 2))
        self.run_test(model, input)

    def test_0d_tensor_broadcast(self):
        class fn(torch.nn.Module):
            def forward(self, x, y):
                a = torch.add(x, y)
                b = torch.mul(y, y)
                return a + b

        x = torch.ones(0)
        y = torch.ones(1)
        self.run_test(fn(), (x, y), input_names=["x", "y"], output_names=["output"])

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_convolution_allow_tf32(self):
        class Module(torch.nn.Module):
            def __init__(self, allow_tf32):
                super().__init__()

                self.allow_tf32 = allow_tf32
                weight = torch.rand(32, 3, 3, 3)
                self.weight = torch.nn.Parameter(weight)

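            # The two torch._convolution calls below differ only in the trailing
            # allow_tf32 flag: the 13-argument overload passes it explicitly,
            # while the legacy 12-argument overload omits it.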
            def forward(self, x):
                if self.allow_tf32:
                    return torch._convolution(
                        x,
                        self.weight,
                        None,
                        [2, 2],
                        [0, 0],
                        [1, 1],
                        False,
                        [0, 0],
                        1,
                        False,
                        False,
                        True,
                        True,
                    )
                else:
                    return torch._convolution(
                        x,
                        self.weight,
                        None,
                        [2, 2],
                        [0, 0],
                        [1, 1],
                        False,
                        [0, 0],
                        1,
                        False,
                        False,
                        True,
                    )

        x = torch.randn(1, 3, 224, 224)
        self.run_test(Module(False), x, rtol=1e-3, atol=1e-6)
        self.run_test(Module(True), x, rtol=1e-3, atol=1e-6)

    class AffineGridModule(torch.nn.Module):
        def __init__(self, align_corners) -> None:
            super().__init__()
            self.align_corners = align_corners

        def forward(self, theta, size):
            return torch.nn.functional.affine_grid(theta, size, self.align_corners)

    @skipIfUnsupportedMinOpsetVersion(20)
    @skipScriptTest()
    @common_utils.parametrize(
        "align_corners",
        (True, False),
    )
    @common_utils.parametrize(
        "theta_params",
        (
            (
                10,
                np.array([0.3, -0.5]),
                np.array([1.5, 0.5]),
            ),
            (
                60,
                np.array([-0.5, -0.5]),
                np.array([3.0, 5.5]),
            ),
        ),
    )
    @common_utils.parametrize(
        "size",
        ([1, 1, 3, 2], [2, 10, 2, 3]),
    )
    def test_affine_grid_2d(self, align_corners, theta_params, size):
        angle, translation, scale = theta_params
        theta = np.array([], dtype=np.float32)
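        # Build one 2x3 affine matrix per batch element:
        # [[cos*sx, -sin, tx], [sin, cos*sy, ty]], which affine_grid uses to map
        # normalized ([-1, 1]) output coordinates back to input coordinates.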
        for _ in range(size[0]):
            angle_radian = (angle / 180.0) * np.pi
            theta = np.append(
                theta,
                [
                    np.cos(angle_radian) * scale[0],
                    -np.sin(angle_radian),
                    translation[0],
                    np.sin(angle_radian),
                    np.cos(angle_radian) * scale[1],
                    translation[1],
                ],
            )
        theta = theta.reshape(size[0], 2, 3)
        theta = torch.Tensor(theta)
        self.run_test(TestONNXRuntime.AffineGridModule(align_corners), (theta, size))

    @skipIfUnsupportedMinOpsetVersion(20)
    @skipScriptTest()
    @common_utils.parametrize(
        "align_corners",
        (True, False),
    )
    @common_utils.parametrize(
        "theta_params",
        (
            (
                [10, 20],
                np.array([0.3, -0.5, 1.8]),
                np.array([1.5, 2.0, 0.5]),
            ),
            (
                [60, -30],
                np.array([-0.5, -0.5, 0.3]),
                np.array([0.3, 3.0, 5.5]),
            ),
        ),
    )
    @common_utils.parametrize(
        "size",
        ([1, 1, 3, 2, 2], [2, 10, 2, 2, 3]),
    )
    def test_affine_grid_3d(self, align_corners, theta_params, size):
        angle, translation, scale = theta_params
        theta = np.array([], dtype=np.float32)
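        # Build one 3x4 affine matrix per batch element: a rotation about x then
        # y (Rx @ Ry), scaled per row, with the translation appended as the
        # fourth column.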
        for _ in range(size[0]):
            angle_radian_x = (angle[0] / 180.0) * np.pi
            angle_radian_y = (angle[1] / 180.0) * np.pi
            rot_matrix_x = np.array(
                [
                    [1, 0, 0],
                    [0, np.cos(angle_radian_x), -np.sin(angle_radian_x)],
                    [0, np.sin(angle_radian_x), np.cos(angle_radian_x)],
                ]
            )
            rot_matrix_y = np.array(
                [
                    [np.cos(angle_radian_y), 0, np.sin(angle_radian_y)],
                    [0, 1, 0],
                    [-np.sin(angle_radian_y), 0, np.cos(angle_radian_y)],
                ]
            )
            rot_matrix = np.matmul(rot_matrix_x, rot_matrix_y)
            rot_matrix = rot_matrix * scale.reshape(3, 1)
            rot_matrix = np.append(rot_matrix, np.reshape(translation, (3, 1)), axis=1)
            theta = np.append(theta, rot_matrix.flatten())

        theta = theta.reshape(size[0], 3, 4)
        theta = torch.Tensor(theta)
        self.run_test(TestONNXRuntime.AffineGridModule(align_corners), (theta, size))

    @skipIfUnsupportedMinOpsetVersion(16)
    @common_utils.parametrize(
        "mode",
        ("bilinear", "nearest", "bicubic"),
    )
    @common_utils.parametrize(
        "padding_mode",
        ("zeros", "border", "reflection"),
    )
    @common_utils.parametrize(
        "align_corners",
        (True, False),
        name_fn=lambda align_corners: str(align_corners),
    )
    def test_grid_sample(self, mode, padding_mode, align_corners):
        n, c, d_in, h_in, w_in, d_out, h_out, w_out = 1, 1, 2, 3, 2, 3, 2, 4

        atol_rtol = {}
        if (mode, padding_mode) == ("bicubic", "border"):
            if align_corners:
                atol_rtol.update({"atol": 0.3, "rtol": 0.4})
            else:
                atol_rtol.update({"atol": 0.02, "rtol": 0.02})
        input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2)
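        # grid_sample expects grid values normalized to [-1, 1]; randn also
        # produces out-of-range points, which exercises padding_mode.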

        class GridSampleModule(torch.nn.Module):
            def __init__(self, mode, padding_mode, align_corners) -> None:
                super().__init__()
                self.mode, self.padding_mode, self.align_corners = (
                    mode,
                    padding_mode,
                    align_corners,
                )

            def forward(self, input, grid):
                return torch.nn.functional.grid_sample(
                    input, grid, self.mode, self.padding_mode, self.align_corners
                )

        self.run_test(
            GridSampleModule(mode, padding_mode, align_corners),
            (input, grid),
            **atol_rtol,
        )

        # ONNX Opset 16 GridSample with 5D volumetric input is not supported.
        volumetric_input_tensor = torch.randn(n, c, d_in, h_in, w_in)
        volumetric_grid_tensor = torch.randn(n, d_out, h_out, w_out, 3)
        for mode, padding_mode, align_corners in itertools.product(
            (
                "bilinear",
                "nearest",
            ),  # PyTorch grid_sample "bicubic" mode does not support 5D volumetric input.
            (
                "zeros",
                "border",
                "reflection",
            ),
            (
                True,
                False,
            ),
        ):
            if self.opset_version < 20:
                with self.assertRaises(
                    torch.onnx.OnnxExporterError,
                ):
                    self.run_test(
                        GridSampleModule(mode, padding_mode, align_corners),
                        (volumetric_input_tensor, volumetric_grid_tensor),
                        **atol_rtol,
                    )
            else:
                self.run_test(
                    GridSampleModule(mode, padding_mode, align_corners),
                    (volumetric_input_tensor, volumetric_grid_tensor),
                    **atol_rtol,
                )

    class IfNoneInput(torch.nn.Module):
        def forward(self, x) -> Optional[Tensor]:
            y: Optional[Tensor] = None
            if x.size(0) > 1:
                y = x
            return y

    class IfNoneOutput(torch.nn.Module):
        def forward(self, x) -> Optional[Tensor]:
            y: Optional[Tensor] = x
            if x.size(0) > 1:
                y = None
            return y

    class LoopNoneInput(torch.nn.Module):
        def forward(self, x) -> Optional[Tensor]:
            y: Optional[Tensor] = None
            for _ in range(x.size(0)):
                y = x
            return y

    class LoopNoneOutput(torch.nn.Module):
        def forward(self, x) -> Optional[Tensor]:
            y: Optional[Tensor] = x
            for _ in range(x.size(0)):
                y = None
            return y

    @common_utils.parametrize(
        "module_class",
        (IfNoneOutput, IfNoneInput, LoopNoneOutput, LoopNoneInput),
        name_fn=lambda module_class: module_class.__name__,
    )
    @common_utils.parametrize("x_size", (0, 1), name_fn=lambda x_size: str(x_size))
    @skipTraceTest()
    @skipIfUnsupportedMinOpsetVersion(16)
    def test_optional_output(self, module_class: Type[torch.nn.Module], x_size: int):
        # Need scripting to preserve control flow for this test to be
        # meaningful.
        model = torch.jit.script(module_class())
        f = io.BytesIO()
        x = torch.ones(x_size)
        dynamic_axis_name = "condition"
        torch.onnx.export(
            model,
            x,
            f,
            opset_version=self.opset_version,
            # Ensure condition is not constant
            dynamic_axes={"x": {0: dynamic_axis_name}},
            input_names=["x"],
        )
        exported = onnx.load_from_string(f.getvalue())
        expected_elem_type = torch.onnx.JitScalarType.from_value(x).onnx_type()
        expected_output_type = onnx.helper.make_optional_type_proto(
            onnx.helper.make_tensor_type_proto(expected_elem_type, (dynamic_axis_name,))
        )
        self.assertEqual(expected_output_type, exported.graph.output[0].type)
        for node in exported.graph.node:
            # Both branches' output types should match.
            if node.op_type == "If":
                for attr in node.attribute:
                    if attr.name in ("then_branch", "else_branch"):
                        self.assertEqual(expected_output_type, attr.g.output[0].type)

        self.run_test(
            module_class(),
            x,
            # Ensure condition is not constant
            dynamic_axes={"x": {0: dynamic_axis_name}},
            input_names=["x"],
        )

    @skipTraceTest()
    @skipIfUnsupportedMinOpsetVersion(16)
    def test_uninitialized_optional(self):
        class Module(torch.nn.Module):
            def forward(self, y: Optional[Tensor]) -> Optional[Tensor]:
                if y is not None:
                    if y.shape[1] < 5:
                        if y.size(0) == 1:
                            y = y + 4
                        else:
                            return y
                return y

        self.run_test(
            Module(),
            torch.ones((3, 4), dtype=torch.int),
            dynamic_axes={"y": {0: "y0", 1: "y1"}},
            input_names=["y"],
        )

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_device_eq(self):
        class M(torch.nn.Module):
            def forward(self, a):
                # exercise both Tensor.device (prim::device)
                # and torch.device (prim::Constant).
                if a.device != torch.device("cpu"):
                    return a
                return torch.zeros_like(a)

        mod = torch.jit.script(M())  # preserve control flow

        self.run_test(
            mod,
            # For the ONNX model's behavior to match the torch model's, we need
            # to construct an input on the same device that forward() checks
            # for. ONNX has no notion of device, so the if condition is always
            # false.
            torch.randn(3, 3, device="cpu"),
            # Force dynamic axes so that the output shape depends on the input.
            # Otherwise the entire model will just return a constant and not have
            # any inputs.
            input_names=["a"],
            dynamic_axes={"a": {0: "a0"}},
        )

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_lerp(self):
        class LerpModel(torch.nn.Module):
            def forward(self, x):
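                # lerp(end, weight) computes x + weight * (end - x); end and
                # weight may each be a scalar or a broadcastable tensor.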
                return (
                    x.lerp(torch.full_like(x, 10), 0.4),
                    x.lerp(torch.full_like(x, 20), 0.7),
                    x.lerp(torch.full_like(x, 30), torch.tensor(0.4)),
                    x.lerp(torch.full_like(x, 40), x / 10.0),
                    x.lerp(torch.tensor(10.0), x / 10.0),
                    x.lerp(torch.tensor(10.0), 0.4),
                    x.lerp(torch.tensor(10.0), torch.tensor(0.4)),
                )

        self.run_test(LerpModel(), torch.rand(5, 4, 3))

    @common_utils.parametrize("input_dtype", [torch.cfloat, torch.float])
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_print_tensor_within_torch_nn_module(self, input_dtype: torch.dtype):
        class PrintTensorOnMyModel(torch.nn.Module):
            def forward(self, x):
                # 'print' has the side effect of calling 'resolve_conj' and 'resolve_neg'.
                x_firsts = x[:, 0]
                print(f"x_firsts: {x_firsts}")
                # 'tolist' has the side effect of calling 'resolve_conj' and 'resolve_neg'.
                # The annotation is needed to satisfy TorchScript.
                _: List[float] = x.tolist()
                return x_firsts

        m = PrintTensorOnMyModel()
        x = torch.randn(10, 5, dtype=input_dtype)
        if input_dtype == torch.cfloat:
            with self.assertRaises(RuntimeError):
                self.run_test(
                    m,
                    x,
                )
        else:
            self.run_test(
                m,
                x,
            )

    @skipScriptTest()
    @skipIfUnsupportedMinOpsetVersion(16)
    @unittest.skipIf(
        not torch.hub._check_module_exists("torch_geometric"),
        "torch_geometric not installed.",
    )
    def test_sage_conv(self):
        from torch_geometric import nn as torch_geometric_nn

        # Input
        coords0 = torch.randn(1, 6)
        coords1 = torch.randn(1, 6)
        coords = torch.transpose(torch.cat((coords0, coords1), dim=0), 0, 1)
        adj = torch_geometric_nn.knn_graph(coords, k=2, batch=None, loop=True)
        edge_from = adj[0:1, :]
        edge_to = adj[1:, :]
        inputs = (coords0, coords1, edge_from, edge_to)

        class MySAGEConv(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.SAGEConvBlock1 = torch_geometric_nn.SAGEConv(
                    2, 512, normalize=True
                )
                self.bano1 = torch_geometric_nn.BatchNorm(512)
                self.relu = torch.nn.ReLU()
                self.dense1 = torch.nn.Sequential(torch.nn.Linear(512, 1))
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, coords0, coords1, edge_from, edge_to):
                adj = torch.cat((edge_from, edge_to), dim=0)
                gra = torch.transpose(torch.cat((coords0, coords1), dim=0), 0, 1)
                x1 = self.SAGEConvBlock1(gra, edge_index=adj)
                x = torch.unsqueeze(torch.sum(x1), dim=0)
                return x

        input_names = ["coords0", "coords1", "edge_from", "edge_to"]
        output_names = ["outputs"]
        dynamic_axes = {
            "coords0": {0: "batch_size", 1: "features"},
            "coords1": {0: "batch_size", 1: "features"},
            "edge_from": {0: "batch_size", 1: "features"},
            "edge_to": {0: "batch_size", 1: "features"},
            "outputs": {0: "batch_size"},
        }
        self.run_test(
            MySAGEConv(),
            inputs,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

    # Cannot export with older opsets because of the "ConstantFill" op.
    # ConstantFill was a temporary op removed in opset 8 and is no longer
    # supported by onnxruntime.
    # Some issues still prevent us from enabling script tests for these scenarios:
    # test_gru_*:
    #   Operator aten::as_tensor is not supported by the exporter yet.
    #       - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055382
    #   Operator aten::_pack_padded_sequence is not supported by the exporter yet.
    #       - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055384
    # test_elman_*:
    #   Compiling in script mode fails with errors like:
    #   torch.jit.frontend.UnsupportedNodeError: annotated assignments
    #   without assigned value aren't supported
    #       - https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723
    # test_lstm_*:
    #   Compiling in script mode fails with errors like:
    #   RuntimeError: Arguments for call are not valid.
    #       - https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723
    @skipScriptTest()
    @skipIfUnsupportedMinOpsetVersion(9)
    @common_utils.parametrize(
        "name, nonlinearity",
        [
            ("elman", "relu"),
            ("elman", "tanh"),
            ("lstm", None),
            ("gru", None),
        ],
    )
    @common_utils.parametrize(**_parametrize_rnn_args("layers"))
    @common_utils.parametrize(**_parametrize_rnn_args("bidirectional"))
    @common_utils.parametrize(**_parametrize_rnn_args("initial_state"))
    @common_utils.parametrize(**_parametrize_rnn_args("packed_sequence"))
    @common_utils.parametrize(**_parametrize_rnn_args("dropout"))
    def test_rnn(self, *args, **kwargs):
        self._dispatch_rnn_test(*args, **kwargs)


if __name__ == "__main__":
    common_utils.TestCase._default_dtype_check_enabled = True
    common_utils.run_tests()
