1# Owner(s): ["module: onnx"] 2 3from __future__ import annotations 4 5import functools 6import io 7import itertools 8import os 9import unittest 10from collections import OrderedDict 11from typing import Dict, List, Optional, Tuple, Type, Union 12 13import numpy as np 14 15import onnx 16import onnx_test_common 17import parameterized 18import torchvision 19from model_defs import ( 20 lstm_flattening_result, 21 rnn_model_with_packed_sequence, 22 word_language_model, 23) 24from pytorch_test_common import ( 25 BATCH_SIZE, 26 RNN_BATCH_SIZE, 27 RNN_HIDDEN_SIZE, 28 RNN_INPUT_SIZE, 29 RNN_SEQUENCE_LENGTH, 30 skipDtypeChecking, 31 skipIfQuantizationBackendQNNPack, 32 skipIfUnsupportedMaxOpsetVersion, 33 skipIfUnsupportedMinOpsetVersion, 34 skipIfUnsupportedOpsetVersion, 35 skipScriptTest, 36 skipShapeChecking, 37 skipTraceTest, 38) 39 40import torch 41from torch import Tensor 42from torch.nn.utils import rnn as rnn_utils 43from torch.onnx import errors, verification 44from torch.testing._internal import common_utils 45from torch.testing._internal.common_utils import skipIfNoLapack 46 47 48def _init_test_generalized_rcnn_transform(): 49 min_size = 100 50 max_size = 200 51 image_mean = [0.485, 0.456, 0.406] 52 image_std = [0.229, 0.224, 0.225] 53 transform = torchvision.models.detection.transform.GeneralizedRCNNTransform( 54 min_size, max_size, image_mean, image_std 55 ) 56 return transform 57 58 59def _init_test_rpn(): 60 anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) 61 aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) 62 rpn_anchor_generator = torchvision.models.detection.rpn.AnchorGenerator( 63 anchor_sizes, aspect_ratios 64 ) 65 out_channels = 256 66 rpn_head = torchvision.models.detection.rpn.RPNHead( 67 out_channels, rpn_anchor_generator.num_anchors_per_location()[0] 68 ) 69 rpn_fg_iou_thresh = 0.7 70 rpn_bg_iou_thresh = 0.3 71 rpn_batch_size_per_image = 256 72 rpn_positive_fraction = 0.5 73 rpn_pre_nms_top_n = dict(training=2000, testing=1000) 74 
rpn_post_nms_top_n = dict(training=2000, testing=1000) 75 rpn_nms_thresh = 0.7 76 rpn_score_thresh = 0.0 77 78 rpn = torchvision.models.detection.rpn.RegionProposalNetwork( 79 rpn_anchor_generator, 80 rpn_head, 81 rpn_fg_iou_thresh, 82 rpn_bg_iou_thresh, 83 rpn_batch_size_per_image, 84 rpn_positive_fraction, 85 rpn_pre_nms_top_n, 86 rpn_post_nms_top_n, 87 rpn_nms_thresh, 88 score_thresh=rpn_score_thresh, 89 ) 90 return rpn 91 92 93def _construct_tensor_for_quantization_test( 94 shape: Tuple[int, ...], 95 offset: Optional[Union[int, float]] = None, 96 max_val: Optional[Union[int, float]] = None, 97) -> Tensor: 98 """Helper function to generate weights and test inputs in a deterministic way. 99 100 Due to difference in implementation details between PyTorch and ONNXRuntime, randomly generated 101 test data for quantization tests can be flaky. To help stablize the test, this helper function is 102 used to generate weights and test inputs in a deterministic way. 103 104 Args: 105 shape (Tuple[int]): Shape for tensor to construct. 106 offset (Optional[Union[int, float]]): Offset to be added to the generated tensor. 107 max_val (Optional[Union[int, float]]): If any element within tensor has a larger absolute value than 108 max_val, the tensor will be scaled by max_val / tensor.abs().max(). This step is done after 109 applying offset. 110 """ 111 tensor = torch.arange(np.prod(shape), dtype=torch.float).view(shape) 112 if offset is not None: 113 tensor = tensor + offset 114 if max_val is not None and tensor.abs().max() > max_val: 115 tensor = tensor * max_val / tensor.abs().max() 116 return tensor 117 118 119def _parameterized_class_attrs_and_values( 120 min_opset_version: int, max_opset_version: int 121): 122 attrs = ("opset_version", "is_script", "keep_initializers_as_inputs") 123 input_values = [] 124 input_values.extend(itertools.product((7, 8), (True, False), (True,))) 125 # Valid opset versions are defined in torch/onnx/_constants.py. 
126 # Versions are intentionally set statically, to not be affected by changes elsewhere. 127 if min_opset_version < 9: 128 raise ValueError("min_opset_version must be >= 9") 129 input_values.extend( 130 itertools.product( 131 range(min_opset_version, max_opset_version + 1), 132 (True, False), 133 (True, False), 134 ) 135 ) 136 return {"attrs": attrs, "input_values": input_values} 137 138 139def _parametrize_rnn_args(arg_name): 140 options = { 141 "layers": {1: "unilayer", 3: "trilayer"}, 142 "bidirectional": {True: "bidirectional", False: "forward"}, 143 "initial_state": {True: "with_initial_state", False: "no_initial_state"}, 144 "packed_sequence": { 145 0: "without_sequence_lengths", 146 1: "with_variable_length_sequences", 147 2: "with_batch_first_sequence_lengths", 148 }, 149 "dropout": {0.2: "with_dropout", 0.0: "without_dropout"}, 150 } 151 152 return { 153 "arg_str": arg_name, 154 "arg_values": options[arg_name].keys(), 155 "name_fn": lambda val: options[arg_name][val], 156 } 157 158 159@parameterized.parameterized_class( 160 **_parameterized_class_attrs_and_values( 161 onnx_test_common.MIN_ONNX_OPSET_VERSION, onnx_test_common.MAX_ONNX_OPSET_VERSION 162 ), 163 class_name_func=onnx_test_common.parameterize_class_name, 164) 165@common_utils.instantiate_parametrized_tests 166class TestONNXRuntime(onnx_test_common._TestONNXRuntime): 167 def test_fuse_conv_bn1d(self): 168 class Fuse(torch.nn.Module): 169 def __init__(self) -> None: 170 super().__init__() 171 self.conv = torch.nn.Conv1d(16, 33, 3, stride=2) 172 self.bn = torch.nn.BatchNorm1d(33) 173 174 def forward(self, x): 175 out = self.conv(x) 176 return self.bn(out) 177 178 model = Fuse() 179 x = torch.randn(20, 16, 50, requires_grad=True) 180 self.run_test(model, (x,)) 181 182 def test_fuse_conv_bn2d(self): 183 class Fuse(torch.nn.Module): 184 def __init__(self) -> None: 185 super().__init__() 186 self.conv = torch.nn.Conv2d( 187 3, 2, kernel_size=1, stride=2, padding=3, bias=False 188 ) 189 self.bn = 
torch.nn.BatchNorm2d(2) 190 191 def forward(self, x): 192 out = self.conv(x) 193 return self.bn(out) 194 195 model = Fuse() 196 x = torch.randn(2, 3, 2, 2, requires_grad=True) 197 self.run_test(model, (x,)) 198 199 def test_fuse_conv_bn3d(self): 200 class Fuse(torch.nn.Module): 201 def __init__(self) -> None: 202 super().__init__() 203 self.conv = torch.nn.Conv3d( 204 3, 2, (3, 5, 2), stride=(2, 1, 1), padding=(3, 2, 0), bias=False 205 ) 206 self.bn = torch.nn.BatchNorm3d(2) 207 208 def forward(self, x): 209 out = self.conv(x) 210 return self.bn(out) 211 212 model = Fuse() 213 x = torch.randn(2, 3, 10, 50, 100, requires_grad=True) 214 self.run_test(model, (x,), rtol=1e-3, atol=1e-6) 215 216 def test_fuse_conv_in_block(self): 217 class Fuse(torch.nn.Module): 218 def __init__(self) -> None: 219 super().__init__() 220 self.conv = torch.nn.Conv1d( 221 in_channels=5, 222 out_channels=5, 223 kernel_size=3, 224 stride=1, 225 padding=2, 226 dilation=1, 227 ) 228 self.bn = torch.nn.BatchNorm1d(5) 229 230 def forward(self, x): 231 results_available = True 232 233 if x.sum() > -1: 234 results_available = False 235 236 if results_available: 237 x = self.conv(x) 238 x = self.bn(x) 239 240 return x 241 242 model = Fuse() 243 x = torch.randn(2, 5, 9, requires_grad=True) 244 self.run_test( 245 torch.jit.script(model), 246 (x,), 247 input_names=["x"], 248 dynamic_axes={"x": [0, 2]}, 249 rtol=1e-3, 250 atol=1e-6, 251 ) 252 253 def test_conv_tbc(self): 254 from torch.nn.modules.utils import _single 255 256 class ConvTBC(torch.nn.Module): 257 def __init__(self, in_channels, out_channels, kernel_size, padding=0): 258 super().__init__() 259 self.in_channels = in_channels 260 self.out_channels = out_channels 261 self.kernel_size = _single(kernel_size) 262 self.padding = _single(padding) 263 264 self.weight = torch.nn.Parameter( 265 Tensor(self.kernel_size[0], in_channels, out_channels) 266 ) 267 self.bias = torch.nn.Parameter(Tensor(out_channels)) 268 self.reset_parameters() 269 270 def 
reset_parameters(self): 271 torch.nn.init.xavier_normal_(self.weight) 272 torch.nn.init.zeros_(self.bias) 273 274 def conv_tbc(self, input): 275 return torch.conv_tbc( 276 input.contiguous(), self.weight, self.bias, self.padding[0] 277 ) 278 279 def forward(self, input): 280 return self.conv_tbc(input) 281 282 in_channels = 3 283 out_channels = 5 284 kernel_size = 5 285 model = ConvTBC(in_channels, out_channels, kernel_size, padding=0) 286 x = torch.randn(10, 7, in_channels, requires_grad=True) 287 self.run_test(model, (x,), atol=1e-5) 288 289 def test_reshape_constant_fold(self): 290 class Reshape(torch.nn.Module): 291 def __init__( 292 self, 293 ): 294 super().__init__() 295 self.weight = torch.nn.Buffer(torch.ones(5)) 296 297 def forward(self, x): 298 scale_1 = self.weight.reshape(1, -1, 1, 1) 299 return x * scale_1 300 301 x = torch.randn(4, 5) 302 self.run_test(Reshape(), (x,), rtol=1e-3, atol=1e-5) 303 304 def run_word_language_model(self, model_name): 305 ntokens = 50 306 emsize = 5 307 nhid = 5 308 nlayers = 5 309 dropout = 0.2 310 tied = False 311 batchsize = 5 312 if model_name == "GRU": 313 model = word_language_model.RNNModelWithTensorHidden( 314 model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize 315 ) 316 elif model_name == "LSTM": 317 model = word_language_model.RNNModelWithTupleHidden( 318 model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize 319 ) 320 else: 321 model = word_language_model.RNNModel( 322 model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize 323 ) 324 x = torch.arange(0, ntokens).long().view(-1, batchsize) 325 # Only support CPU version, since tracer is not working in GPU RNN. 
326 self.run_test(model, (x, model.hidden)) 327 328 def get_image(self, rel_path: str, size: Tuple[int, int]) -> Tensor: 329 from PIL import Image 330 from torchvision import transforms 331 332 data_dir = os.path.join(os.path.dirname(__file__), "assets") 333 path = os.path.join(data_dir, *rel_path.split("/")) 334 image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR) 335 336 return transforms.ToTensor()(image) 337 338 def get_test_images(self) -> Tuple[List[Tensor], List[Tensor]]: 339 return ( 340 [self.get_image("grace_hopper_517x606.jpg", (100, 320))], 341 [self.get_image("rgb_pytorch.png", (250, 380))], 342 ) 343 344 def test_paste_mask_in_image(self): 345 masks = torch.rand(10, 1, 26, 26) 346 boxes = torch.rand(10, 4) 347 boxes[:, 2:] += torch.rand(10, 2) 348 boxes *= 50 349 o_im_s = (100, 100) 350 from torchvision.models.detection.roi_heads import paste_masks_in_image 351 352 out = paste_masks_in_image(masks, boxes, o_im_s) 353 jit_trace = torch.jit.trace( 354 paste_masks_in_image, 355 (masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])]), 356 ) 357 out_trace = jit_trace( 358 masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])] 359 ) 360 361 assert torch.all(out.eq(out_trace)) 362 363 masks2 = torch.rand(20, 1, 26, 26) 364 boxes2 = torch.rand(20, 4) 365 boxes2[:, 2:] += torch.rand(20, 2) 366 boxes2 *= 100 367 o_im_s2 = (200, 200) 368 from torchvision.models.detection.roi_heads import paste_masks_in_image 369 370 out2 = paste_masks_in_image(masks2, boxes2, o_im_s2) 371 out_trace2 = jit_trace( 372 masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])] 373 ) 374 375 assert torch.all(out2.eq(out_trace2)) 376 377 def test_heatmaps_to_keypoints(self): 378 maps = torch.rand(10, 1, 26, 26) 379 rois = torch.rand(10, 4) 380 from torchvision.models.detection.roi_heads import heatmaps_to_keypoints 381 382 out = heatmaps_to_keypoints(maps, rois) 383 jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois)) 384 
out_trace = jit_trace(maps, rois) 385 386 assert torch.all(out[0].eq(out_trace[0])) 387 assert torch.all(out[1].eq(out_trace[1])) 388 389 maps2 = torch.rand(20, 2, 21, 21) 390 rois2 = torch.rand(20, 4) 391 from torchvision.models.detection.roi_heads import heatmaps_to_keypoints 392 393 out2 = heatmaps_to_keypoints(maps2, rois2) 394 out_trace2 = jit_trace(maps2, rois2) 395 396 assert torch.all(out2[0].eq(out_trace2[0])) 397 assert torch.all(out2[1].eq(out_trace2[1])) 398 399 def test_word_language_model_RNN_TANH(self): 400 self.run_word_language_model("RNN_TANH") 401 402 def test_word_language_model_RNN_RELU(self): 403 self.run_word_language_model("RNN_RELU") 404 405 @skipScriptTest() # scripting prim::unchecked_cast prim::setattr 406 def test_word_language_model_LSTM(self): 407 self.run_word_language_model("LSTM") 408 409 def test_word_language_model_GRU(self): 410 self.run_word_language_model("GRU") 411 412 def test_index_1d(self): 413 class MyModel(torch.nn.Module): 414 def forward(self, input): 415 return input[0] 416 417 m1 = torch.randn(3, 4, 5, 6, 7) 418 self.run_test(MyModel(), m1) 419 420 def test_index_2d_1dimslice(self): 421 class MyModel(torch.nn.Module): 422 def forward(self, input): 423 return input[0:1, :] 424 425 m1 = torch.randn(3, 4, 5, 6, 7) 426 self.run_test(MyModel(), m1) 427 428 def test_index_2d_sliceint(self): 429 class MyModel(torch.nn.Module): 430 def forward(self, input): 431 return input[1, :] 432 433 m1 = torch.randn(3, 4, 5, 6, 7) 434 self.run_test(MyModel(), m1) 435 436 def test_index_2d_neg_slice(self): 437 class MyModel(torch.nn.Module): 438 def forward(self, input): 439 return input[0:-1, :] 440 441 m1 = torch.randn(3, 4, 5, 6, 7) 442 self.run_test(MyModel(), m1) 443 444 @skipIfUnsupportedMinOpsetVersion(9) 445 def test_index_mask(self): 446 class MyModel(torch.nn.Module): 447 def forward(self, input): 448 return input[torch.tensor([0, 1, 0], dtype=torch.uint8)] 449 450 m1 = torch.randn(3, 4, 5, 6, 7) 451 self.run_test(MyModel(), 
m1) 452 453 class MyModel(torch.nn.Module): 454 def forward(self, input): 455 return input[torch.tensor([0, 1, 0], dtype=torch.bool)] 456 457 m1 = torch.randn(3, 4, 5, 6, 7) 458 self.run_test(MyModel(), m1) 459 460 @skipIfUnsupportedMinOpsetVersion(9) 461 def test_data(self): 462 class Data(torch.jit.ScriptModule): 463 @torch.jit.script_method 464 def forward(self, x): 465 return x.new_zeros(x.data.size()) 466 467 x = torch.randn(3, 4) 468 self.run_test(Data(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) 469 self.run_test(Data(), x, remained_onnx_input_idx=[]) 470 471 @skipIfUnsupportedMinOpsetVersion(11) 472 def test_index_mask_nd(self): 473 class MyModel(torch.nn.Module): 474 def forward(self, input): 475 return input[input > 0] 476 477 m1 = torch.randn(3, 4, 5, 6, 7) 478 self.run_test(MyModel(), m1) 479 480 @skipScriptTest() 481 def test_dict(self): 482 class MyModel(torch.nn.Module): 483 def forward(self, x_in): 484 x_out = {} 485 x_out["test_key_out"] = torch.add( 486 x_in[list(x_in.keys())[0]], # noqa: RUF015 487 list(x_in.keys())[0], # noqa: RUF015 488 ) 489 return x_out 490 491 x = {torch.tensor(1.0): torch.randn(1, 2, 3)} 492 self.run_test(MyModel(), (x,)) 493 494 @skipScriptTest() 495 def test_dict_str(self): 496 class MyModel(torch.nn.Module): 497 def forward(self, x_in): 498 x_out = {} 499 x_out["test_key_out"] = torch.add(x_in["test_key_in"], 2.0) 500 return x_out 501 502 x = {"test_key_in": torch.randn(1, 2, 3)} 503 self.run_test(MyModel(), (x,)) 504 505 @skipScriptTest() # User-defined class not supported 506 def test_dict_output(self): 507 class DictModelOutput(OrderedDict): 508 tensor_out: Tensor 509 tuple_out: Optional[Tuple[Tensor]] = None 510 list_out: Optional[List[Tensor]] = None 511 512 class MyModel(torch.nn.Module): 513 def forward(self, a, b, c, d): 514 return DictModelOutput( 515 tensor_out=a, 516 tuple_out=(b, c), 517 list_out=[d], 518 ) 519 520 a = torch.randn(2, 3) 521 b = torch.randn(2, 3) 522 c = torch.randn(2, 3) 523 d = 
torch.randn(2, 3) 524 self.run_test(MyModel(), (a, b, c, d)) 525 526 def test_tuple_output(self): 527 class MyModel(torch.nn.Module): 528 def forward(self, a, b, c, d): 529 return a, (b, c), d 530 531 a = torch.randn(2, 3) 532 b = torch.randn(2, 3) 533 c = torch.randn(2, 3) 534 d = torch.randn(2, 3) 535 self.run_test(MyModel(), (a, b, c, d)) 536 537 def test_nested_tuple_output(self): 538 class MyModel(torch.nn.Module): 539 def forward(self, a, b, c, d): 540 return a, ((b,), (c, d)) 541 542 a = torch.randn(2, 3) 543 b = torch.randn(2, 3) 544 c = torch.randn(2, 3) 545 d = torch.randn(2, 3) 546 self.run_test(MyModel(), (a, b, c, d)) 547 548 def test_tuple_input(self): 549 class TupleModel(torch.nn.Module): 550 def forward(self, a: Tuple[Tensor, Tensor]): 551 return a 552 553 x = (torch.randn(3, 4), torch.randn(4, 3)) 554 self.run_test(TupleModel(), input_args=(x,)) 555 556 def test_tuple_primitive_input(self): 557 class TupleModel(torch.nn.Module): 558 def forward(self, a: Tuple[int, Tensor], b): 559 return a[0], a[1] + b 560 561 x = (3, torch.randn(4, 3)) 562 y = torch.randn(4, 3) 563 self.run_test(TupleModel(), input_args=(x, y)) 564 565 def test_nested_tuple_input(self): 566 class NestedTupleModel(torch.nn.Module): 567 def forward(self, a, b: Tuple[Tensor, Tuple[Tensor, Tensor]]): 568 return a + b[0] + b[1][0] + b[1][1] 569 570 x = torch.randn(4, 5) 571 y = (torch.randn(4, 5), (torch.randn(1, 5), torch.randn(4, 1))) 572 self.run_test(NestedTupleModel(), input_args=(x, y)) 573 574 @skipScriptTest() # Needs https://github.com/pytorch/rfcs/pull/21 575 @skipIfUnsupportedMinOpsetVersion(15) 576 def test_mixed_optional_default_none(self): 577 class Model(torch.nn.Module): 578 def forward( 579 self, 580 x, 581 y: Optional[Tensor] = None, 582 z: Optional[Tensor] = None, 583 ): 584 if y is not None: 585 return x + y 586 if z is not None: 587 return x + z 588 return x 589 590 x = torch.randn(2, 3) 591 y = torch.randn(2, 3) 592 z = torch.randn(2, 3) 593 model = Model() 594 # 
Without kwargs dict. 595 self.run_test(model, (x, y, None)) 596 self.run_test(model, (x, None, z)) 597 # With kwargs dict. 598 self.run_test(model, (x,), {"y": y, "z": None}) 599 self.run_test(model, (x,), {"y": None, "z": z}) 600 self.run_test(model, (x,), {"z": z}) 601 self.run_test(model, (x,), {"y": y}) 602 603 @skipScriptTest() # tracing eliminates None inputs so it works differently. See _script version below. 604 @skipIfUnsupportedMinOpsetVersion(15) 605 def test_mixed_optional_default_tensor(self): 606 class Model(torch.nn.Module): 607 def forward( 608 self, 609 x, 610 y: Optional[Tensor] = torch.ones(2, 3), 611 z: Optional[Tensor] = torch.zeros(2, 3), 612 ): 613 if y is not None: 614 return x + y 615 if z is not None: 616 return x + z 617 return x 618 619 x = torch.randn(2, 3) 620 y = torch.randn(2, 3) 621 z = torch.randn(2, 3) 622 model = Model() 623 624 self.run_test(model, (x, y, None)) 625 self.run_test(model, (x, None, z)) 626 627 @skipTraceTest() # tracing is verified with different set of inputs. See above. 628 @skipIfUnsupportedMinOpsetVersion(15) 629 def test_mixed_optional_default_tensor_script(self): 630 class Model(torch.nn.Module): 631 def forward( 632 self, 633 x, 634 y: Optional[Tensor] = torch.ones(2, 3), 635 z: Optional[Tensor] = torch.zeros(2, 3), 636 ): 637 if y is not None: 638 return x + y 639 if z is not None: 640 return x + z 641 return x 642 643 x = torch.randn(2, 3) 644 y = torch.randn(2, 3) 645 z = torch.randn(2, 3) 646 model = torch.jit.script(Model()) 647 648 self.run_test(model, (x, y, z), input_names=("x", "y", "z")) 649 self.run_test(model, (x,), {"y": y, "z": z}, input_names=("x", "y", "z")) 650 self.run_test(model, (x,), {"y": y}, input_names=("x", "y")) 651 652 for example_inputs, example_kwargs in ( 653 ((x, y, None), {}), 654 ((x, None, z), {}), 655 ((x,), {"y": y, "z": None}), 656 ((x,), {"y": None, "z": z}), 657 ): 658 with self.assertRaisesRegex( 659 ValueError, "args contained 1 None's after flattening." 
660 ): 661 self.run_test( 662 model, example_inputs, example_kwargs, input_names=("x", "y", "z") 663 ) 664 665 @skipScriptTest() # Needs https://github.com/pytorch/rfcs/pull/21 666 @skipIfUnsupportedMinOpsetVersion(15) 667 def test_all_optional_default_none(self): 668 class Model(torch.nn.Module): 669 def forward(self, x: Optional[Tensor] = None, y: Optional[Tensor] = None): 670 if x is not None: 671 return x 672 if y is not None: 673 return y 674 else: 675 return torch.tensor(-1.0) 676 677 x = torch.randn(2, 3) 678 model = Model() 679 self.run_test(model, (x, None)) 680 self.run_test( 681 model, 682 (), 683 {"x": x, "y": None}, 684 # y disappears in tracing. 685 input_names=("x",), 686 ) 687 688 @skipScriptTest() # tracing eliminates None inputs so it works differently. See _script version below. 689 @skipIfUnsupportedMinOpsetVersion(15) 690 def test_all_optional_default_tensor(self): 691 class Model(torch.nn.Module): 692 def forward( 693 self, 694 x: Optional[Tensor] = torch.ones(2, 3), 695 y: Optional[Tensor] = torch.zeros(2, 3), 696 ): 697 if x is not None: 698 return x 699 elif y is not None: 700 return y 701 else: 702 return torch.tensor(-1.0) 703 704 x = torch.randn(2, 3) 705 y = torch.randn(2, 3) 706 model = Model() 707 self.run_test(model, (x, None)) 708 self.run_test(model, (None, y)) 709 # tracing means y is never used so it's removed from the exported model inputs, 710 # and we fail when trying to run ORT. 711 with self.assertRaisesRegex(ValueError, "got too many positional inputs"): 712 self.run_test(model, (x, y)) 713 714 @skipTraceTest() # tracing is verified with different set of inputs. See above. 
715 @skipIfUnsupportedMinOpsetVersion(15) 716 def test_all_optional_default_tensor_script(self): 717 class Model(torch.nn.Module): 718 def forward( 719 self, 720 x: Optional[Tensor] = torch.ones(2, 3), 721 y: Optional[Tensor] = torch.zeros(2, 3), 722 ): 723 if x is not None: 724 return x 725 elif y is not None: 726 return y 727 else: 728 return torch.tensor(-1.0) 729 730 x = torch.randn(2, 3) 731 y = torch.randn(2, 3) 732 model = torch.jit.script(Model()) 733 734 # Optional supports None inputs 735 self.run_test(model, (x,)) 736 # NOTE: default value is not supported on ONNX, so torch and ONNX has 737 # different behavior 738 with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"): 739 self.run_test(model, (), {"y": y}, input_names=["y"]) 740 741 self.run_test(model, (x, y)) 742 self.run_test(model, (), {"x": x, "y": y}, input_names=("x", "y")) 743 744 @skipIfUnsupportedMinOpsetVersion(9) 745 def test_logit(self): 746 class Logit(torch.nn.Module): 747 def __init__(self, eps): 748 super().__init__() 749 self.eps = eps 750 751 def forward(self, x): 752 return x.logit(self.eps) 753 754 model = Logit(eps=1e-6) 755 self.run_test(model, torch.randn(1, 3, 640, 640)) 756 757 class Atleast1d(torch.nn.Module): 758 def forward(self, t, w, x, y, z): 759 return torch.atleast_1d((t, w, x, y, z)) 760 761 class Atleast2d(torch.nn.Module): 762 def forward(self, t, w, x, y, z): 763 return torch.atleast_2d((t, w, x, y, z)) 764 765 class Atleast3d(torch.nn.Module): 766 def forward(self, t, w, x, y, z): 767 return torch.atleast_3d((t, w, x, y, z)) 768 769 class Atleast1dTensor(torch.nn.Module): 770 def forward(self, x): 771 return torch.atleast_1d(x) 772 773 class Atleast2dTensor(torch.nn.Module): 774 def forward(self, x): 775 return torch.atleast_2d(x) 776 777 class Atleast3dTensor(torch.nn.Module): 778 def forward(self, x): 779 return torch.atleast_3d(x) 780 781 @skipScriptTest() # tracing uses prim::ListUnpack to avoid onnx::SequenceConstruct 782 
@skipIfUnsupportedMinOpsetVersion(11) 783 @common_utils.parametrize("module_class", (Atleast1d, Atleast2d, Atleast3d)) 784 def test_atleast_nd_list_input(self, module_class: torch.nn.Module): 785 inputs = ( 786 torch.tensor(1.0), 787 torch.randn(2), 788 torch.randn(2, 3), 789 torch.randn(2, 3, 4), 790 torch.randn(2, 3, 4, 5), 791 ) 792 self.run_test(module_class(), inputs) 793 794 @skipScriptTest() # tracing uses prim::ListUnpack to avoid onnx::SequenceConstruct 795 @skipIfUnsupportedMinOpsetVersion(11) 796 @common_utils.parametrize( 797 "module_class", (Atleast1dTensor, Atleast2dTensor, Atleast3dTensor) 798 ) 799 @common_utils.parametrize( 800 "inputs", 801 [ 802 torch.tensor(1.0), 803 torch.randn(2), 804 torch.randn(2, 3), 805 torch.randn(2, 3, 4), 806 torch.randn(2, 3, 4, 5), 807 ], 808 ) 809 def test_atleast_nd_single_tensor_input( 810 self, module_class: torch.nn.Module, inputs: torch.Tensor 811 ): 812 self.run_test(module_class(), inputs) 813 814 @skipScriptTest() # Needs https://github.com/pytorch/rfcs/pull/21 815 @skipIfUnsupportedMinOpsetVersion(15) 816 def test_mixed_optional(self): 817 class Model(torch.nn.Module): 818 def forward(self, x, y: Optional[Tensor]): 819 if y is not None: 820 return x + y 821 return x 822 823 x = torch.randn(2, 3) 824 model = Model() 825 self.run_test(model, (x, None)) 826 self.run_test(model, (x, x)) 827 828 @skipScriptTest() # Needs https://github.com/pytorch/rfcs/pull/21 829 @skipIfUnsupportedMinOpsetVersion(15) 830 def test_tuple_of_optional(self): 831 class Model(torch.nn.Module): 832 def forward(self, x, y: Tuple[Optional[Tensor], Optional[Tensor]]): 833 if y[0] is not None: 834 return x + y[0] 835 if y[1] is not None: 836 return x + y[1] 837 return x 838 839 x = torch.randn(2, 3) 840 y1 = torch.randn(2, 3) 841 self.run_test(Model(), (x, (None, y1))) 842 843 @skipScriptTest() # tracing eliminates None inputs so it works differently. See _script version below. 
844 @skipIfUnsupportedMinOpsetVersion(15) 845 def test_tuple_of_optional_default_tensor(self): 846 class Model(torch.nn.Module): 847 def forward( 848 self, 849 x, 850 y: Tuple[Optional[Tensor], Optional[Tensor]] = ( 851 torch.zeros(2, 3), 852 torch.zeros(2, 3), 853 ), 854 ): 855 y0, y1 = y 856 if y0 is not None: 857 return x + y0 858 if y1 is not None: 859 return x + y1 860 return x 861 862 x = torch.randn(2, 3) 863 y1 = torch.randn(2, 3) 864 self.run_test(Model(), (x, (None, y1))) 865 866 @skipTraceTest() # tracing is verified with different set of inputs. See above. 867 @skipIfUnsupportedMinOpsetVersion(15) 868 def test_tuple_of_optional_default_tensor_script(self): 869 class Model(torch.nn.Module): 870 def forward( 871 self, 872 x, 873 y: Tuple[Optional[Tensor], Optional[Tensor]] = ( 874 torch.zeros(2, 3), 875 torch.zeros(2, 3), 876 ), 877 ): 878 y0, y1 = y 879 if y0 is not None: 880 return x + y0 881 if y1 is not None: 882 return x + y1 883 return x 884 885 x = torch.randn(2, 3) 886 y0 = torch.randn(2, 3) 887 y1 = torch.randn(2, 3) 888 model = torch.jit.script(Model()) 889 with self.assertRaisesRegex( 890 ValueError, "args contained 1 None's after flattening." 891 ): 892 self.run_test(model, (x, (None, y1))) 893 self.run_test(model, (x, (y0, y1))) 894 # export succeeds, but running ORT through run_test would fail because the exported model 895 # has the inputs flattened into 3 inputs. 
896 torch.onnx.export( 897 model, (x, {"y": (y0, y1)}), io.BytesIO(), opset_version=self.opset_version 898 ) 899 900 def test_primitive_input_integer(self): 901 class Model(torch.nn.Module): 902 def forward(self, x: int, y): 903 return x + y 904 905 x = 3 906 y = torch.randint(10, (2, 3, 4)) 907 self.run_test(Model(), (x, y)) 908 909 @skipDtypeChecking 910 def test_primitive_input_floating(self): 911 class Model(torch.nn.Module): 912 def forward(self, x: float, y): 913 return x + y 914 915 x = 3.0 916 y = torch.randn(2, 3, 4) 917 self.run_test(Model(), (x, y)) 918 919 def test_primitive_input_bool(self): 920 class Model(torch.nn.Module): 921 def forward(self, flag: bool, x, y): 922 if flag: 923 return x 924 else: 925 return y 926 927 flag = True 928 x = torch.randn(2, 3, 4) 929 y = torch.randn(2, 3, 4) 930 self.run_test(torch.jit.script(Model()), (flag, x, y)) 931 932 @skipIfUnsupportedMinOpsetVersion(9) 933 def test_cste_script(self): 934 class MyModel(torch.jit.ScriptModule): 935 @torch.jit.script_method 936 def forward(self, x): 937 return torch.zeros(x.size(0)), torch.ones( 938 (x.size(1), x.size(0)), dtype=torch.int64 939 ) 940 941 x = torch.randn(3, 4) 942 self.run_test(MyModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) 943 self.run_test(MyModel(), x, remained_onnx_input_idx=[]) 944 945 def test_scalar_tensor(self): 946 class test(torch.nn.Module): 947 def forward(self, input): 948 return torch.scalar_tensor(input.size(0)), torch.scalar_tensor( 949 input.size(1), dtype=torch.int64 950 ) 951 952 x = torch.randn(2, 3, 4) 953 y = torch.randn(7, 8, 9) 954 model = test() 955 self.run_test( 956 model, 957 x, 958 additional_test_inputs=[y], 959 input_names=["input_1"], 960 dynamic_axes={"input_1": [0, 1, 2]}, 961 ) 962 963 def test_tensor(self): 964 class ScalarInputModel(torch.jit.ScriptModule): 965 @torch.jit.script_method 966 def forward(self, input): 967 return torch.tensor(input.shape[1]) 968 969 x = torch.randn(3, 4) 970 self.run_test( 971 
ScalarInputModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]} 972 ) 973 self.run_test(ScalarInputModel(), x, remained_onnx_input_idx=[]) 974 975 class TensorInputModel(torch.jit.ScriptModule): 976 @torch.jit.script_method 977 def forward(self, input): 978 return torch.tensor([input.shape[0], input.shape[1]]) 979 980 x = torch.randn(3, 4) 981 self.run_test( 982 TensorInputModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]} 983 ) 984 self.run_test(TensorInputModel(), x, remained_onnx_input_idx=[]) 985 986 class FloatInputModel(torch.jit.ScriptModule): 987 @torch.jit.script_method 988 def forward(self, input): 989 return torch.tensor([float(input)]) 990 991 x = torch.randn(1) 992 self.run_test(FloatInputModel(), x) 993 994 class InputWithDtypeModel(torch.jit.ScriptModule): 995 @torch.jit.script_method 996 def forward(self, input): 997 return torch.tensor(input.shape[1], dtype=torch.long) 998 999 x = torch.randn(3, 4) 1000 self.run_test( 1001 InputWithDtypeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]} 1002 ) 1003 self.run_test(InputWithDtypeModel(), x, remained_onnx_input_idx=[]) 1004 1005 class MixedInputModel(torch.jit.ScriptModule): 1006 @torch.jit.script_method 1007 def forward(self, input): 1008 return torch.tensor([input.shape[0], int(input)]) 1009 1010 x = torch.randn(1) 1011 self.run_test(MixedInputModel(), x) 1012 1013 def test_hardtanh(self): 1014 model = torch.nn.Hardtanh(-1.5, 2.5) 1015 x = torch.arange(-5, 5).to(dtype=torch.float32) 1016 self.run_test(model, x) 1017 1018 def test_hardtanh_script_with_default_values(self): 1019 class MyModel(torch.jit.ScriptModule): 1020 @torch.jit.script_method 1021 def forward(self, x): 1022 return torch.nn.functional.hardtanh(x) 1023 1024 x = torch.arange(-5, 5).to(dtype=torch.float32) 1025 self.run_test(MyModel(), x) 1026 1027 def test_hardswish(self): 1028 model = torch.nn.Hardswish() 1029 1030 x = torch.rand(3, 3).to(dtype=torch.float32) 1031 self.run_test(model, x) 1032 1033 # Testing edge 
cases 1034 x = torch.tensor(3).to(dtype=torch.float32) 1035 self.run_test(model, x) 1036 x = torch.tensor(-3).to(dtype=torch.float32) 1037 self.run_test(model, x) 1038 1039 def test_hardswish_script(self): 1040 class MyModel(torch.jit.ScriptModule): 1041 @torch.jit.script_method 1042 def forward(self, x): 1043 return torch.nn.functional.hardswish(x) 1044 1045 x = torch.rand(3, 3).to(dtype=torch.float32) 1046 self.run_test(MyModel(), x) 1047 1048 def test_hardsigmoid(self): 1049 model = torch.nn.Hardsigmoid() 1050 1051 x = torch.rand(3, 3).to(dtype=torch.float32) 1052 self.run_test(model, x) 1053 1054 # corner cases 1055 x = torch.tensor(3).to(dtype=torch.float32) 1056 self.run_test(model, x) 1057 x = torch.tensor(-3).to(dtype=torch.float32) 1058 self.run_test(model, x) 1059 1060 def test_tanhshrink(self): 1061 model = torch.nn.Tanhshrink() 1062 1063 x = torch.rand(3, 3).to(dtype=torch.float32) 1064 self.run_test(model, x) 1065 1066 @skipIfUnsupportedMinOpsetVersion(9) 1067 def test_hardshrink(self): 1068 model = torch.nn.Hardshrink() 1069 1070 x = torch.rand(3, 3).to(dtype=torch.float32) 1071 self.run_test(model, x) 1072 1073 # Testing edge cases 1074 x = torch.tensor(0.5).to(dtype=torch.float32) 1075 self.run_test(model, x) 1076 x = torch.tensor(-0.5).to(dtype=torch.float32) 1077 self.run_test(model, x) 1078 1079 @skipIfUnsupportedMinOpsetVersion(9) 1080 def test_hardshrink_dtype(self): 1081 x = torch.rand(3, 3).to(dtype=torch.float64) 1082 self.run_test(torch.nn.Hardshrink(), x) 1083 1084 @skipIfUnsupportedMinOpsetVersion(9) 1085 def test_softshrink(self): 1086 model = torch.nn.Softshrink() 1087 1088 x = torch.rand(3, 3).to(dtype=torch.float32) 1089 self.run_test(model, x) 1090 1091 # Testing edge cases 1092 x = torch.tensor(0.5).to(dtype=torch.float32) 1093 self.run_test(model, x) 1094 x = torch.tensor(-0.5).to(dtype=torch.float32) 1095 self.run_test(model, x) 1096 1097 @skipIfUnsupportedMinOpsetVersion(9) 1098 def test_softshrink_dtype(self): 1099 x = 
torch.rand(3, 3).to(dtype=torch.float64) 1100 self.run_test(torch.nn.Softshrink(), x) 1101 1102 def test_clamp(self): 1103 class ClampModel(torch.nn.Module): 1104 def forward(self, x): 1105 return x.clamp(-0.5, 0.5) 1106 1107 x = torch.randn(3, 4) 1108 self.run_test(ClampModel(), x) 1109 1110 class ClampMinModel(torch.nn.Module): 1111 def forward(self, x): 1112 return x.clamp(min=-0.5) 1113 1114 x = torch.randn(3, 4) 1115 self.run_test(ClampMinModel(), x) 1116 1117 class ClampMaxModel(torch.nn.Module): 1118 def forward(self, x): 1119 return x.clamp(max=0.5) 1120 1121 x = torch.randn(3, 4) 1122 self.run_test(ClampMaxModel(), x) 1123 1124 @skipIfUnsupportedMinOpsetVersion(8) 1125 def test_clamp_dyn(self): 1126 class ClampMaxModel(torch.jit.ScriptModule): 1127 @torch.jit.script_method 1128 def forward(self, x): 1129 return x.clamp(None, x.size(0)) 1130 1131 x = torch.arange(16).view(4, 4).float() 1132 self.run_test(ClampMaxModel(), x) 1133 1134 class ClampMinModel(torch.jit.ScriptModule): 1135 @torch.jit.script_method 1136 def forward(self, x): 1137 return x.clamp(x.size(0), None) 1138 1139 x = torch.arange(16).view(4, 4).float() 1140 self.run_test(ClampMinModel(), x) 1141 1142 class ClampMinMaxModel(torch.jit.ScriptModule): 1143 @torch.jit.script_method 1144 def forward(self, x): 1145 return x.clamp(x.size(0), x.size(1)) 1146 1147 x = torch.arange(16).view(2, 8).float() 1148 self.run_test(ClampMinMaxModel(), x) 1149 1150 class ClampTensorModel(torch.nn.Module): 1151 def forward(self, x, min, max): 1152 return x.clamp(min, max) 1153 1154 x = torch.randn(3, 4) 1155 y = torch.randn(3, 4) 1156 z = torch.randn(3, 4) 1157 self.run_test(ClampTensorModel(), (x, y, z)) 1158 1159 class ClampTensorMinModel(torch.nn.Module): 1160 def forward(self, x, min): 1161 return x.clamp(min=min) 1162 1163 self.run_test(ClampTensorMinModel(), (x, y)) 1164 1165 class ClampTensorMaxModel(torch.nn.Module): 1166 def forward(self, x, max): 1167 return x.clamp(max=max) 1168 1169 
self.run_test(ClampTensorMaxModel(), (x, z)) 1170 1171 @skipIfUnsupportedMinOpsetVersion(9) 1172 def test_full_trace(self): 1173 class FullModel(torch.nn.Module): 1174 def forward(self, x): 1175 return torch.full((3, 4), x, dtype=torch.long) 1176 1177 x = torch.tensor(12) 1178 self.run_test(FullModel(), x) 1179 1180 @skipIfUnsupportedMinOpsetVersion(9) 1181 def test_full_script(self): 1182 class FullModelScripting(torch.jit.ScriptModule): 1183 @torch.jit.script_method 1184 def forward(self, x): 1185 return torch.full((3, 4), x, dtype=torch.long) 1186 1187 x = torch.tensor(12) 1188 self.run_test(FullModelScripting(), x) 1189 1190 def test_fuse_addmm(self): 1191 class AddmmModel(torch.nn.Module): 1192 def forward(self, x): 1193 return torch.mm(x, x) + x 1194 1195 x = torch.ones(3, 3) 1196 self.run_test(AddmmModel(), x) 1197 1198 def test_maxpool(self): 1199 model = torch.nn.MaxPool1d(2, stride=1) 1200 x = torch.randn(20, 16, 50) 1201 self.run_test(model, x) 1202 1203 def test_conv(self): 1204 class TraceModel(torch.nn.Module): 1205 def __init__(self) -> None: 1206 super().__init__() 1207 self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2) 1208 self.conv2 = torch.nn.Conv2d( 1209 16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1) 1210 ) 1211 self.conv3 = torch.nn.Conv3d( 1212 16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0) 1213 ) 1214 1215 def forward(self, input1, input2, input3): 1216 return self.conv1(input1), self.conv2(input2), self.conv3(input3) 1217 1218 x1 = torch.randn(20, 16, 50) 1219 x2 = torch.randn(20, 16, 50, 50) 1220 x3 = torch.randn(20, 16, 10, 50, 50) 1221 1222 self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5) 1223 1224 def test_conv_str_padding(self): 1225 class TraceModel(torch.nn.Module): 1226 def __init__(self) -> None: 1227 super().__init__() 1228 self.conv1 = torch.nn.Conv1d(16, 33, 3, padding="valid") 1229 self.conv2 = torch.nn.Conv2d( 1230 16, 33, (3, 5), stride=1, padding="valid", dilation=(3, 1) 1231 ) 1232 self.conv3 = 
torch.nn.Conv3d( 1233 16, 33, (3, 5, 2), stride=1, padding="same" 1234 ) 1235 1236 def forward(self, input1, input2, input3): 1237 return self.conv1(input1), self.conv2(input2), self.conv3(input3) 1238 1239 x1 = torch.randn(20, 16, 50) 1240 x2 = torch.randn(20, 16, 50, 50) 1241 x3 = torch.randn(20, 16, 10, 50, 50) 1242 1243 self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5) 1244 1245 def test_conv_shape_inference(self): 1246 class Model(torch.nn.Module): 1247 def __init__(self) -> None: 1248 super().__init__() 1249 self.conv2 = torch.nn.Conv2d( 1250 16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1) 1251 ) 1252 1253 def forward(self, input): 1254 return self.conv2(input) + 2 1255 1256 x = torch.randn(20, 16, 50, 100) 1257 self.run_test( 1258 Model(), x, atol=10e-5, input_names=["x"], dynamic_axes={"x": [0]} 1259 ) 1260 1261 def test_conv_transpose(self): 1262 class TraceModel(torch.nn.Module): 1263 def __init__(self) -> None: 1264 super().__init__() 1265 self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2) 1266 self.conv2 = torch.nn.ConvTranspose2d( 1267 16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1) 1268 ) 1269 self.conv3 = torch.nn.ConvTranspose3d( 1270 16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0) 1271 ) 1272 1273 def forward(self, input1, input2, input3): 1274 return self.conv1(input1), self.conv2(input2), self.conv3(input3) 1275 1276 x1 = torch.randn(20, 16, 10) 1277 x2 = torch.randn(20, 16, 10, 10) 1278 x3 = torch.randn(20, 16, 10, 10, 10) 1279 1280 self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5) 1281 1282 def test_numpy_T(self): 1283 class NumpyTranspose(torch.nn.Module): 1284 def forward(self, x): 1285 return x.T 1286 1287 self.run_test(NumpyTranspose(), torch.randn(4, 7)) 1288 1289 # Conversion of Transpose depends on input shape to be known. 1290 # The following test only works when onnx shape inference is enabled. 
def test_transpose_infer_shape(self):
    class TransposeModule(torch.jit.ScriptModule):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)

        @torch.jit.script_method
        def forward(self, x):
            x = self.conv(x)
            return x.transpose(0, 1)

    # Dims 0 and 2 of the input are dynamic; y changes both of them.
    x = torch.randn(32, 3, 64, 64)
    y = torch.randn(16, 3, 8, 64)
    self.run_test(
        TransposeModule(),
        x,
        input_names=["x"],
        dynamic_axes={"x": [0, 2]},
        additional_test_inputs=[y],
    )

# Shared driver for the squeeze tests below. Exports Squeeze(d) on x1
# (squeezing every size-1 dim when d is None). When x2 is given, all three
# input axes are marked dynamic and x2 is also run through the exported graph.
def squeeze_model_tests(self, d, x1, x2):
    class Squeeze(torch.nn.Module):
        def __init__(self, d):
            super().__init__()
            self.d = d

        def forward(self, x):
            if self.d is not None:
                return torch.squeeze(x, dim=self.d)
            else:
                return torch.squeeze(x)

    x2 = [] if x2 is None else [x2]
    if len(x2) > 0:
        self.run_test(
            Squeeze(d),
            x1,
            input_names=["input"],
            dynamic_axes={"input": {0: "0", 1: "1", 2: "2"}},
            additional_test_inputs=x2,
        )
    else:
        self.run_test(Squeeze(d), x1)

def test_squeeze_without_no_op(self):
    x = torch.randn(2, 1, 4)
    self.squeeze_model_tests(1, x, None)

@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_dynamic(self):
    # x_squeeze has size 1 at dim 1; x_noop does not, so squeeze is a no-op.
    x_squeeze = torch.randn(2, 1, 4)
    x_noop = torch.randn(2, 2, 3)
    self.squeeze_model_tests(1, x_squeeze, x_noop)

def test_squeeze_neg_without_no_op(self):
    x = torch.randn(2, 1, 4)
    self.squeeze_model_tests(-2, x, None)

@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_neg(self):
    x_squeeze = torch.randn(2, 1, 4)
    x_noop = torch.randn(2, 2, 3)
    self.squeeze_model_tests(-2, x_squeeze, x_noop)

def test_squeeze_all_dims(self):
    x_squeeze = torch.randn(2, 1, 4)
    x_noop = torch.randn(2, 2, 3)
    self.squeeze_model_tests(None, x_squeeze, x_noop)

@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_no_op(self):
    # Reverse of test_squeeze_dynamic: the exported (first) input is the no-op
    # case (dim 2 has size 4) and the additional input actually squeezes.
    x_noop = torch.randn(2, 1, 4)
    x_squeeze = torch.randn(2, 2, 1)
    self.squeeze_model_tests(2, x_noop, x_squeeze)

@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_runtime_dim(self):
    # The size of the squeezed dim comes from input *values*, so whether the
    # squeeze is a no-op is only known at runtime.
    class Squeeze(torch.nn.Module):
        def forward(self, d1, d2):
            t = torch.zeros(d1[0], d2[0])
            return t.squeeze(0)

    d1 = torch.tensor([1])
    d3 = torch.tensor([3])
    d4 = torch.tensor([4])
    self.run_test(Squeeze(), (d1, d4), additional_test_inputs=[(d3, d4)])
    self.run_test(Squeeze(), (d3, d4), additional_test_inputs=[(d1, d3)])

def test_squeeze(self):
    class Squeeze(torch.nn.Module):
        def forward(self, x):
            return torch.squeeze(x, dim=-2)

    x = torch.randn(2, 1, 4)
    self.run_test(Squeeze(), x)

@skipIfUnsupportedMinOpsetVersion(13)
def test_squeeze_dynamic_dim(self):
    class Squeeze(torch.nn.Module):
        def forward(self, x, dim: int):
            return torch.squeeze(x, dim)

    x = torch.randn(2, 1, 4)
    dim = 1
    self.run_test(Squeeze(), (x, dim))

def test_unsqueeze(self):
    class Unsqueeze(torch.nn.Module):
        def forward(self, x):
            return torch.unsqueeze(x, dim=-2)

    x = torch.randn(2, 3, 4)
    self.run_test(Unsqueeze(), x)

@skipIfUnsupportedMinOpsetVersion(13)
def test_unsqueeze_dynamic_dim(self):
    class Unsqueeze(torch.nn.Module):
        def forward(self, x, dim: int):
            return torch.unsqueeze(x, dim)

    x = torch.randn(2, 1, 4)
    dim = -1
    self.run_test(Unsqueeze(), (x, dim))

def test_maxpool_default_stride(self):
    class MaxPoolModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.max_pool2d(x, 2)

    model = MaxPoolModel()
    x = torch.randn(10, 20, 16, 50)
    self.run_test(model, x)

@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_adaptive(self):
    model = torch.nn.AdaptiveMaxPool1d(5, return_indices=False)
    x = torch.randn(20, 16, 50, requires_grad=True)
    y = torch.randn(32, 16, 50, requires_grad=True)
    self.run_test(
        model,
        x,
        input_names=["x"],
        dynamic_axes={"x": [0]},
        additional_test_inputs=[y],
    )

def test_maxpool_2d(self):
    model = torch.nn.MaxPool2d(5, padding=(1, 2))
    x = torch.randn(1, 20, 16, 50, requires_grad=True)
    self.run_test(model, x)

def test_maxpool_1d_ceil(self):
    model = torch.nn.MaxPool1d(3, 2, ceil_mode=True)
    x = torch.randn(20, 16, 50)
    self.run_test(model, x)

def test_maxpool_2d_ceil(self):
    model = torch.nn.MaxPool2d(3, 2, ceil_mode=True)
    x = torch.randn(20, 16, 50, 32)
    self.run_test(model, x)

def test_maxpool_3d_ceil(self):
    model = torch.nn.MaxPool3d(3, 2, ceil_mode=True)
    x = torch.randn(20, 16, 50, 44, 31)
    self.run_test(model, x)

@skipIfUnsupportedMinOpsetVersion(10)
def test_maxpool_dynamic(self):
    # MaxPool2d(ceil_mode=True) -> 1x1 Conv -> BatchNorm, with H and W
    # (dims 2 and 3) dynamic on both the input and the output.
    class MaxPoolConvNorm(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            norm_layer = functools.partial(torch.nn.BatchNorm2d, eps=0.0009)
            self.maxpool = torch.nn.MaxPool2d((2, 2), stride=2, ceil_mode=True)
            self.conv = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, bias=False
            )
            self.norm = norm_layer(out_channels)

        def forward(self, x):
            return self.norm(self.conv(self.maxpool(x)))

    model = MaxPoolConvNorm(8, 16)
    inputs = torch.randn(2, 8, 64, 64)
    self.run_test(
        model,
        inputs,
        input_names=["input_0"],
        dynamic_axes={"input_0": {3: "x", 2: "y"}, "output_0": {3: "x", 2: "y"}},
        output_names=["output_0"],
    )

# TODO: Enable maxpool-ceil family after ONNX 1.15.1+ is bumped
@skipIfUnsupportedMaxOpsetVersion(9)
def test_maxpool_1d_ceil_corner(self):
    model = torch.nn.MaxPool1d(
        kernel_size=1, dilation=1, stride=2, ceil_mode=True, return_indices=False
    )
    x = torch.randn(1, 3, 32)
    self.run_test(model, x)

@skipIfUnsupportedMaxOpsetVersion(9)
def test_maxpool_2d_ceil_corner(self):
    model = torch.nn.MaxPool2d(
        kernel_size=[1, 1],
        dilation=[1, 1],
        stride=[2, 2],
        ceil_mode=True,
        return_indices=False,
    )
    x = torch.randn(1, 3, 32, 32)
    self.run_test(model, x)

@skipIfUnsupportedMaxOpsetVersion(9)
def test_maxpool_3d_ceil_corner(self):
    model = torch.nn.MaxPool3d(
        kernel_size=[7, 8, 4],
        dilation=[1, 1, 1],
        stride=[10, 11, 3],
        padding=[2, 2, 2],
        ceil_mode=True,
        return_indices=False,
    )
    x = torch.randn(1, 3, 51, 52, 45)
    self.run_test(model, x)

@skipIfUnsupportedMaxOpsetVersion(9)
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_1d_ceil_corner_with_indices(self):
    model = torch.nn.MaxPool1d(
        kernel_size=1, dilation=1, stride=2, ceil_mode=True, return_indices=True
    )
    x = torch.randn(1, 3, 32)
    self.run_test(model, x)

@skipIfUnsupportedMaxOpsetVersion(9)
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_2d_ceil_corner_with_indices(self):
    model = torch.nn.MaxPool2d(
        kernel_size=[1, 1],
        dilation=[1, 1],
        stride=[2, 2],
        ceil_mode=True,
        return_indices=True,
    )
    x = torch.randn(1, 3, 32, 32)
    self.run_test(model, x)

@skipIfUnsupportedMaxOpsetVersion(9)
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_3d_ceil_corner_with_indices(self):
    model = torch.nn.MaxPool3d(
        kernel_size=[7, 8, 4],
        dilation=[1, 1, 1],
        stride=[10, 11, 3],
        padding=[2, 2, 2],
        ceil_mode=True,
        return_indices=True,
    )
    x = torch.randn(1, 3, 51, 52, 45)
    self.run_test(model, x)

@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_with_indices(self):
    model = torch.nn.MaxPool1d(2, stride=1, return_indices=True)
    x = torch.randn(20, 16, 50)
    self.run_test(model, x)

@skipIfUnsupportedMinOpsetVersion(10)
def test_maxpool_dilation(self):
    model = torch.nn.MaxPool1d(2, stride=1, dilation=2)
    x = torch.randn(20, 16, 50)
    self.run_test(model, x)

def test_avgpool_default_stride(self):
    class AvgPoolModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.avg_pool2d(x, 2)

    model = AvgPoolModel()
    x = torch.randn(10, 20, 16, 50)
    self.run_test(model, x)

def test_avgpool(self):
    model = torch.nn.AvgPool1d(2, stride=1)
    x = torch.randn(20, 16, 50)
    self.run_test(model, x)

def test_avgpool_1d_ceil(self):
    model = torch.nn.AvgPool1d(3, 2, ceil_mode=True)
    x = torch.randn(1, 1, 7)
    self.run_test(model, x)

# TODO: ceil_mode is not covered by this test because of
# https://github.com/microsoft/onnxruntime/issues/16203:
# ORT and PyTorch compute the last ceil_mode value differently.
@common_utils.parametrize(
    "padding",
    (0, 1),
)
@common_utils.parametrize(
    "count_include_pad",
    (True, False),
)
def test_avgpool_2d(self, padding, count_include_pad):
    model = torch.nn.AvgPool2d(
        3,
        3,
        padding=padding,
        count_include_pad=count_include_pad,
    )
    x = torch.randn(20, 16, 50, 32)
    self.run_test(model, x)

# ORT and PyTorch compute the last ceil_mode value differently
# (https://github.com/microsoft/onnxruntime/issues/16203). Per the linked
# issues, the fix needs ONNX opset 21 (https://github.com/onnx/onnx/issues/5711)
# plus an ORT-side fix, so the ceil_mode test below only runs from opset 21 on.
@skipIfUnsupportedMinOpsetVersion(21)
def test_avgpool_3d_ceil(self):
    # Dims 0 and 1 of x are dynamic; y changes both of them.
    model = torch.nn.AvgPool3d(3, 2, ceil_mode=True)
    x = torch.randn(20, 16, 50, 44, 31)
    y = torch.randn(32, 8, 50, 44, 31)
    self.run_test(
        model,
        x,
        input_names=["x"],
        dynamic_axes={"x": [0, 1]},
        additional_test_inputs=[y],
    )

@skipIfUnsupportedMinOpsetVersion(10)
def test_avgpool_dynamic(self):
    # AvgPool2d(ceil_mode=True, count_include_pad=False) -> 1x1 Conv ->
    # BatchNorm, with H and W (dims 2 and 3) dynamic on input and output.
    class test(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            norm_layer = functools.partial(torch.nn.BatchNorm2d, eps=0.0009)
            self.avgpool = torch.nn.AvgPool2d(
                (2, 2), stride=2, ceil_mode=True, count_include_pad=False
            )
            self.conv = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, bias=False
            )
            self.norm = norm_layer(out_channels)

        def forward(self, x):
            return self.norm(self.conv(self.avgpool(x)))

    model = test(8, 16)
    inputs = torch.randn(2, 8, 64, 64)
    self.run_test(
        model,
        inputs,
        input_names=["input_0"],
        dynamic_axes={"input_0": {3: "x", 2: "y"}, "output_0": {3: "x", 2: "y"}},
        output_names=["output_0"],
    )

@skipIfUnsupportedMinOpsetVersion(9)
def test_floating_point(self):
    # Both branches return the same value, so the result never depends on
    # which branch runs; what varies is the exported is_floating_point() check.
    class FloatingPoint(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            if x.is_floating_point():
                return x.new_zeros(x.shape)
            return x.new_zeros(x.shape)

    x = torch.randn(2, 3, 4)
    self.run_test(
        FloatingPoint(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
    )
    # NOTE(review): the output depends only on x's shape; remained_onnx_input_idx=[]
    # presumably asserts that no graph input survives a static-shape export.
    # Confirm against run_test.
    self.run_test(FloatingPoint(), x, remained_onnx_input_idx=[])

    # Nested variant: the dtype check is on an intermediate (a = x + 2) inside
    # a shape-dependent branch.
    class FloatingPoint(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            if x.size(0) > 1:
                a = x + 2
                if a.is_floating_point():
                    return x + 1
                return x + 1
            return x

    x = torch.randn(2, 3, 4)
    self.run_test(FloatingPoint(), x)

# Operator rank mismatch between outputs of two branches for opsets below 11.
@skipIfUnsupportedMinOpsetVersion(11)
def test_floating_point_infer_dtype(self):
    # The inner branches return different ranks (x.shape[1:] vs x.shape).
    class FloatingPoint(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            if x.size(0) > 1:
                a = x + 2
                if a.is_floating_point():
                    return x.new_zeros(x.shape[1:])
                return x.new_zeros(x.shape)
            return x

    x = torch.randn(2, 3, 4)
    self.run_test(
        FloatingPoint(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
    )
    self.run_test(FloatingPoint(), x, remained_onnx_input_idx=[])

    class FloatingPoint(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            if x.size(0) > 1:
                a = x + 2
                if a.is_floating_point():
                    return x + 1
                return x
            return x

    # int32 input: is_floating_point() is False at runtime, so x is returned.
    x = torch.randn(2, 3, 4).to(torch.int32)
    self.run_test(FloatingPoint(), x)

@skipIfUnsupportedMinOpsetVersion(12)
def test_prim_min(self):
    @torch.jit.script
    def list_append(boxes: List[Tensor]):
        temp = []
        for i, b in enumerate(
            boxes
        ):  # enumerate is creating a prim::min op in torch graph
            temp.append(torch.full_like(b[:, 1], i))
        return temp[0]

    class Min(torch.nn.Module):
        def forward(self, x):
            boxes = [x for _ in range(3)]
            return list_append(boxes)

    x = torch.rand(5, 5)
    self.run_test(Min(), (x,))

    # Builtin min() of a tensor element and a Python int inside a script method.
    class M(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            i = 3
            return min(x[i], i)

    x = torch.arange(6, dtype=torch.int64)
    self.run_test(M(), (x,))

def test_arithmetic(self):
    # Add/sub/mul/div of a tensor with Python int constants.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x):
            x = x + 2
            x = x - 4
            x = x * 6
            x = x / 8
            return x

    x = torch.randn(2, 3, 4)
    self.run_test(ArithmeticModule(), x)
def test_arithmetic_prim_long(self):
    # Arithmetic between a tensor and a Python int model input.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x, y: int):
            x = x + y
            x = x - y
            x = x * (y * 3)
            x = x / (y * 4)
            return x

    x = torch.randn(2, 3, 4)
    y = 2
    self.run_test(ArithmeticModule(), (x, y))

    # The output is a Python int derived only from x's shape.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x):
            x = x + 2
            x = x - 3
            return x.shape[0]

    x = torch.randn(2, 3, 4)
    self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[])

@skipDtypeChecking
def test_arithmetic_prim_float(self):
    # Same as test_arithmetic_prim_long, but with a Python float input.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x, y: float):
            x = x + y
            x = x - y
            x = x * (y * 3)
            x = x / (y * 4)
            return x

    x = torch.randn(2, 3, 4)
    y = 2.5
    self.run_test(ArithmeticModule(), (x, y))

    # The output is a Python float derived only from x's shape.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x):
            x = x + 2
            x = x - 3
            return x.shape[1] / 2

    x = torch.randn(2, 3, 4)
    self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[])

@skipDtypeChecking
def test_arithmetic_prim_bool(self):
    # Mixed int/bool/float scalar inputs; the bool is also returned as an output.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x, y: int, z: bool, t: float):
            x = x + y
            x = x - y
            if z:
                x = x * (y * 3)
                x = x / (y * 4)
            return x / t, z

    x = torch.randn(2, 3, 4)
    y = 2
    z = False
    t = 2.5
    self.run_test(ArithmeticModule(), (x, y, z, t))

    # Comparison of two Python ints only, with no tensor inputs.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x: int, y: int):
            return x == y

    x = 3
    y = 2
    self.run_test(ArithmeticModule(), (x, y))

# Implicit literal concatenation keeps the skip reason free of the stray
# indentation that a backslash-newline inside the literal would embed.
@skipScriptTest(
    15,
    reason=(
        "In trace: Outputs that are always None are removed. "
        "In script: Outputs that are always None are removed before opset 15. "
        "After opset 15, we replace the None in output with Optional node."
    ),
)
def test_tuple_with_none_outputs(self):
    class TupleModel(torch.nn.Module):
        def forward(self, x):
            return (x, (x, None, (x, None)))

    x = torch.randn(3, 4)
    self.run_test(TupleModel(), (x,))

# In scripting, the first transpose node does not carry shape and dtype info.
# The following test only works when ONNX shape inference is enabled.
def test_arithmetic_infer_dtype(self):
    class ArithmeticModule(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            x = x.t()
            x = x + 2
            x = x - 4
            x = x * 6
            x = x / 8
            return x

    x = torch.randn(2, 3)
    self.run_test(ArithmeticModule(), x)

@unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
def test_floor_div(self):
    # Floor division across int/float tensor and scalar operand combinations.
    class FloorDivModule(torch.nn.Module):
        def forward(self, x, y):
            return (
                x // 3,
                x // 2.0,
                x.to(dtype=torch.float64) // 3,
                x.to(dtype=torch.float64) // 2.0,
                x.to(dtype=torch.int64) // 3,
                x.to(dtype=torch.int64) // 2.0,
                x // (y + 1.0).to(dtype=torch.int64),
                x // y,
                x.to(dtype=torch.float64) // y.to(dtype=torch.int64),
                x.to(dtype=torch.float64) // y.to(dtype=torch.float64),
                x.to(dtype=torch.int64) // y.to(dtype=torch.int64),
                x.to(dtype=torch.int64) // y,
            )

    # x includes negatives (-2..3) to exercise floor vs. truncation semantics.
    x = torch.arange(-2, 4).reshape(2, 3, 1)
    y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
    self.run_test(FloorDivModule(), (x, y))

@unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
def test_floor_div_script(self):
    class FloorDivModule(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, y):
            return x // 3, x // 2.0, x // y

    x = torch.arange(-2, 4).reshape(2, 3, 1)
    y = torch.randn(2, 3, 4)
    self.run_test(FloorDivModule(), (x, y))
@unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
@skipIfUnsupportedMinOpsetVersion(9)
def test_floordiv(self):
    # Integer floor division of two dynamic sizes, used as a new tensor shape.
    class FloordivModule(torch.nn.Module):
        def forward(self, x):
            return x.new_zeros(x.size(2) // x.size(1))

    x = torch.randn(2, 3, 4)
    self.run_test(
        FloordivModule(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
    )
    self.run_test(FloordivModule(), (x,), remained_onnx_input_idx=[])

def test_div(self):
    # True division of int tensors, then of the same values as float tensors.
    class DivModule(torch.nn.Module):
        def forward(self, x, y):
            return x / y, torch.true_divide(x, y)

    x = torch.randn(2, 3, 4).to(torch.int)
    # y starts at 1, so there is no division by zero.
    y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
    self.run_test(DivModule(), (x, y))
    self.run_test(DivModule(), (x.float(), y.float()))

# Note: div cannot (generally) be exported via scripting
# since its type promotion logic is dependent on knowing the scalar types
# of the input tensors. That is, the ONNX graph is dependent on the
# data type of the inputs. This makes it appropriate for tracing only.
def test_div_promotion_trace(self):
    class DivModule(torch.nn.Module):
        def forward(self, x, y):
            return x / y, torch.true_divide(x, y)

    x = torch.randn(2, 3, 4).to(torch.int)
    y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)

    # The same int/int division is traced and checked under both float and
    # double default dtypes.
    with common_utils.set_default_dtype(torch.float):
        self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))

    with common_utils.set_default_dtype(torch.double):
        self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))

# In scripting x, y do not carry shape and dtype info.
# The following test only works when onnx shape inference is enabled.
def test_div_promotion_script(self):
    class DivModule(torch.nn.Module):
        def forward(self, x, y):
            # Add transpose to hide shape/type information.
            # Otherwise shape and type are still available from the input.
            x = x.transpose(1, 2)
            y = y.transpose(1, 2)
            return x / y, torch.true_divide(x, y)

    x = torch.randn(2, 3, 4).to(torch.int)
    y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)

    # 1. x,y are int, and output is float.
    # This can be handled by the default case, where both are cast to float.
    # It works even if type of x, y are unknown.
    with common_utils.set_default_dtype(torch.float):
        self.run_test(torch.jit.script(DivModule()), (x, y))

    # 2. x,y are int, and output is double.
    # This can be handled by the default case, where both are cast to double.
    # It works even if type of x, y are unknown.
    with common_utils.set_default_dtype(torch.double):
        self.run_test(torch.jit.script(DivModule()), (x, y))

    # 3. x is int, y is double, and output is double.
    # This can only be handled when both type of x and y are known.
    x = torch.randn(2, 3, 4).to(torch.int)
    y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
    self.run_test(torch.jit.script(DivModule()), (x, y))

@skipDtypeChecking
def test_div_rounding_mode(self):
    # Each module exercises one rounding_mode via both the method and the
    # functional form of div.
    class TrueDivModule(torch.nn.Module):
        def forward(self, x, y):
            return (
                x.div(y, rounding_mode=None),
                torch.div(x, y, rounding_mode=None),
            )

    class TruncDivModule(torch.nn.Module):
        def forward(self, x, y):
            return (
                x.div(y, rounding_mode="trunc"),
                torch.div(x, y, rounding_mode="trunc"),
            )

    class FloorDivModule(torch.nn.Module):
        def forward(self, x, y):
            return (
                x.div(y, rounding_mode="floor"),
                torch.div(x, y, rounding_mode="floor"),
            )

    modules = [TrueDivModule(), TruncDivModule(), FloorDivModule()]

    # Integer inputs; y starts at 1, so there is no division by zero.
    x = (torch.randn(2, 3, 4) * 100).to(torch.int)
    y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)

    # Each module is exported three ways: as-is, pre-traced, and pre-scripted.
    for module in modules:
        self.run_test(module, (x, y))
        self.run_test(torch.jit.trace(module, (x, y)), (x, y))
        self.run_test(torch.jit.script(module), (x, y))

    # Float inputs; y is kept >= 0.1 to avoid division by (near) zero.
    x = torch.randn(2, 3, 4)
    y = torch.rand(2, 3, 4) * 10.0 + 0.1

    for module in modules:
        self.run_test(module, (x, y))
        self.run_test(torch.jit.trace(module, (x, y)), (x, y))
        self.run_test(torch.jit.script(module), (x, y))

def test_slice_trace(self):
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x[0:1]

    x = torch.randn(3)
    self.run_test(MyModule(), x)

def test_slice_neg(self):
    class NegSlice(torch.nn.Module):
        def forward(self, x):
            return x[-1:]

    x = torch.randn(3, 4, 5)
    self.run_test(NegSlice(), x)

def test_slice_neg_large(self):
    # Negative slice on dim 2 combined with a negative integer index on dim 4.
    class NegSlice(torch.nn.Module):
        def forward(self, x):
            return x[:, :, -3:-1, :, -1]

    x = torch.randn(3, 4, 5, 6, 7)
    self.run_test(NegSlice(), x)

def test_slice_neg_large_negone(self):
    class NegSlice(torch.nn.Module):
        def forward(self, x):
            return x[:, :, :, :, -1]

    x = torch.randn(3, 4, 5, 6, 7)
    self.run_test(NegSlice(), x)

@skipIfUnsupportedMinOpsetVersion(11)
def test_slice_with_input_index(self):
    # In-place slice assignment whose end comes from another input's size.
    class InputIndexSlice(torch.nn.Module):
        def forward(self, x, y):
            x[: y.size(0), 0, :] = y
            return x

    x = torch.zeros((56, 6, 256))
    y = torch.rand((22, 256))
    self.run_test(InputIndexSlice(), (x, y))

@skipIfUnsupportedMinOpsetVersion(11)
@skipScriptTest()  # Torchscript doesn't support 1d index.
def test_slice_with_1d_input_index(self):
    # y (a 1-element int64 tensor) is both the slice end and the assigned value.
    class InputIndexSlice(torch.nn.Module):
        def forward(self, x, y):
            x[:y, 0, :] = y
            return x

    x = torch.zeros((56, 6, 256))
    y = torch.tensor([5], dtype=torch.int64)
    self.run_test(InputIndexSlice(), (x, y))

@skipIfUnsupportedMinOpsetVersion(11)
def test_slice_with_input_step_size(self):
    # Slice end (y) and step (z) are both 0-d int64 tensor inputs.
    class InputIndexSlice(torch.nn.Module):
        def forward(self, x, y, z):
            x[:y:z, 0::z, :] = 1
            return x

    x = torch.zeros((56, 6, 256))
    y = torch.tensor(5, dtype=torch.int64)
    z = torch.tensor(2, dtype=torch.int64)
    self.run_test(InputIndexSlice(), (x, y, z))

@skipIfUnsupportedMinOpsetVersion(10)
@skipScriptTest()  # scripting tuple/list append
def test_slice_dynamic(self):
    # Slice bounds derived from input sizes; y checks all three dynamic axes.
    class DynamicSliceExportMod(torch.nn.Module):
        def forward(self, x):
            results = []
            for i in range(4):
                results.append(x[: x.size(0) - i, i : x.size(2), i:3])
            return tuple(results)

    x = torch.rand(5, 5, 5)
    y = torch.randn(6, 7, 8)
    self.run_test(
        DynamicSliceExportMod(),
        x,
        additional_test_inputs=[y],
        input_names=["input_1"],
        output_names=["output_1"],
        dynamic_axes={"input_1": [0, 1, 2], "output_1": [0, 1, 2]},
    )
2085 2086 @skipIfUnsupportedMinOpsetVersion(10) 2087 def test_slice_dynamic_script(self): 2088 class DynamicSliceModel(torch.jit.ScriptModule): 2089 @torch.jit.script_method 2090 def forward(self, x): 2091 return x[1 : x.size(1)] 2092 2093 x = torch.rand(1, 2) 2094 self.run_test(DynamicSliceModel(), x) 2095 2096 @skipIfUnsupportedMinOpsetVersion(10) 2097 def test_slice_dynamic_shape_script(self): 2098 class DynamicSliceModel(torch.nn.Module): 2099 def forward(self, x): 2100 return x.new_zeros(x.shape[1 : x.size(2)]) 2101 2102 x = torch.rand(1, 2, 3, 4) 2103 self.run_test( 2104 DynamicSliceModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]} 2105 ) 2106 self.run_test(DynamicSliceModel(), x, remained_onnx_input_idx=[]) 2107 2108 @skipIfUnsupportedMinOpsetVersion(10) 2109 @skipScriptTest() # scripting tuple/list append 2110 def test_slice_dynamic_to_end(self): 2111 class DynamicSliceExportMod(torch.nn.Module): 2112 def forward(self, x): 2113 results = [] 2114 for i in range(4): 2115 results.append(x[:, i:, x.size(2) - 5]) 2116 return tuple(results) 2117 2118 x = torch.rand(5, 5, 5) 2119 self.run_test( 2120 DynamicSliceExportMod(), 2121 x, 2122 dynamic_axes={"input_1": [0, 1, 2], "output_1": [0, 1, 2]}, 2123 ) 2124 2125 def test_square(self): 2126 class Square(torch.nn.Module): 2127 def forward(self, x): 2128 return torch.square(x) 2129 2130 x = torch.randn(2, 3, 4) 2131 self.run_test(Square(), x) 2132 2133 @skipIfUnsupportedMinOpsetVersion(9) 2134 def test_arange_dynamic(self): 2135 class ArangeModel(torch.nn.Module): 2136 def forward(self, input): 2137 return ( 2138 torch.arange(input.shape[0]), 2139 torch.arange(12), 2140 torch.arange(start=input.shape[0], end=input.shape[0] + 5), 2141 ) 2142 2143 x = torch.randn(5, 3, 2) 2144 y = torch.randn(8, 3, 2) 2145 self.run_test( 2146 ArangeModel(), 2147 x, 2148 additional_test_inputs=[y], 2149 input_names=["input_1"], 2150 output_names=["output_1", "output_2", "output_3"], 2151 dynamic_axes={"input_1": [0], 
"output_1": [0]}, 2152 ) 2153 self.run_test( 2154 torch.jit.script(ArangeModel()), 2155 x, 2156 additional_test_inputs=[y], 2157 input_names=["input_1"], 2158 output_names=["output_1", "output_2", "output_3"], 2159 dynamic_axes={"input_1": [0], "output_1": [0]}, 2160 ) 2161 2162 @skipIfUnsupportedMinOpsetVersion(9) 2163 def test_dynamic_arange_out(self): 2164 class ArangeOutModel(torch.nn.Module): 2165 def forward(self, end): 2166 out_t = torch.tensor([1], dtype=torch.int64) 2167 return torch.arange(end, out=out_t) 2168 2169 x = torch.tensor(8) 2170 self.run_test(ArangeOutModel(), (x)) 2171 2172 @skipIfUnsupportedMinOpsetVersion(9) 2173 def test_dynamic_arange_start_out(self): 2174 class ArangeStartOutModel(torch.nn.Module): 2175 def forward(self, start, end): 2176 out_t = torch.tensor([1], dtype=torch.int64) 2177 return torch.arange(start.size(0), end, out=out_t) 2178 2179 x = torch.randn(2, 3, 4) 2180 y = torch.tensor(8) 2181 self.run_test( 2182 ArangeStartOutModel(), 2183 (x, y), 2184 input_names=["x", "y"], 2185 dynamic_axes={"x": [0, 1, 2]}, 2186 ) 2187 self.run_test(ArangeStartOutModel(), (x, y), remained_onnx_input_idx=[1]) 2188 2189 @skipIfUnsupportedMinOpsetVersion(9) 2190 def test_linspace(self): 2191 class LinspaceModel(torch.nn.Module): 2192 def forward(self, start, end, steps): 2193 return torch.linspace(start, end, steps) 2194 2195 x = torch.tensor(3, dtype=torch.float) 2196 y = torch.tensor(10, dtype=torch.float) 2197 z = torch.tensor(5, dtype=torch.int) 2198 self.run_test(LinspaceModel(), (x, y, z)) 2199 2200 @skipIfUnsupportedMinOpsetVersion(9) 2201 def test_linspace_negative_start(self): 2202 class LinspaceModel(torch.nn.Module): 2203 def forward(self, start, end, steps): 2204 return torch.linspace(start, end, steps) 2205 2206 x = torch.tensor(-1, dtype=torch.float) 2207 y = torch.tensor(1, dtype=torch.float) 2208 z = torch.tensor(6, dtype=torch.int) 2209 self.run_test(LinspaceModel(), (x, y, z)) 2210 2211 @skipIfUnsupportedMinOpsetVersion(9) 2212 
def test_arange_with_floats_out(self): 2213 class ArangeModelEnd(torch.nn.Module): 2214 def forward(self, end): 2215 out_t = torch.tensor([1], dtype=torch.float) 2216 return torch.arange(end, out=out_t) 2217 2218 y = torch.tensor(8.5, dtype=torch.float) 2219 self.run_test(ArangeModelEnd(), (y)) 2220 2221 class ArangeModelStep(torch.nn.Module): 2222 def forward(self, start, end): 2223 out_t = torch.tensor([1], dtype=torch.float) 2224 return torch.arange(start.size(0), end, 1.5, out=out_t) 2225 2226 x = torch.randn(2, 3, 4) 2227 y = torch.tensor(8.5, dtype=torch.float) 2228 self.run_test( 2229 ArangeModelStep(), 2230 (x, y), 2231 input_names=["x", "y"], 2232 dynamic_axes={"x": [0, 1, 2]}, 2233 ) 2234 self.run_test(ArangeModelStep(), (x, y), remained_onnx_input_idx=[1]) 2235 2236 @skipIfUnsupportedMinOpsetVersion(9) 2237 def test_arange_with_floats(self): 2238 class ArangeModelEnd(torch.nn.Module): 2239 def forward(self, end): 2240 return torch.arange(end) 2241 2242 y = torch.tensor(8.5, dtype=torch.float) 2243 self.run_test(ArangeModelEnd(), (y)) 2244 2245 class ArangeModelStep(torch.nn.Module): 2246 def forward(self, start, end): 2247 return torch.arange(start.size(0), end, 1.5) 2248 2249 x = torch.randn(2, 3, 4) 2250 y = torch.tensor(8.5, dtype=torch.float) 2251 self.run_test( 2252 ArangeModelStep(), 2253 (x, y), 2254 input_names=["x", "y"], 2255 dynamic_axes={"x": [0, 1, 2]}, 2256 ) 2257 self.run_test(ArangeModelStep(), (x, y), remained_onnx_input_idx=[1]) 2258 2259 class ArangeModelStepNeg(torch.nn.Module): 2260 def forward(self, start, end): 2261 return torch.arange(end, start.size(0), -1.5) 2262 2263 x = torch.randn(2, 3, 4) 2264 y = torch.tensor(8.5, dtype=torch.float) 2265 self.run_test( 2266 ArangeModelStepNeg(), 2267 (x, y), 2268 input_names=["x", "y"], 2269 dynamic_axes={"x": [0, 1, 2]}, 2270 ) 2271 self.run_test(ArangeModelStepNeg(), (x, y), remained_onnx_input_idx=[1]) 2272 2273 class ArangeModelStart(torch.nn.Module): 2274 def forward(self, start, end): 
2275 return torch.arange(start.size(0), end) 2276 2277 x = torch.randn(2, 3, 4) 2278 y = torch.tensor(8.5, dtype=torch.float) 2279 self.run_test( 2280 ArangeModelStart(), 2281 (x, y), 2282 input_names=["x", "y"], 2283 dynamic_axes={"x": [0, 1, 2]}, 2284 ) 2285 self.run_test(ArangeModelStart(), (x, y), remained_onnx_input_idx=[1]) 2286 2287 @skipIfUnsupportedMinOpsetVersion(9) 2288 def test_arange_with_floats_override(self): 2289 class ArangeModelEnd(torch.nn.Module): 2290 def forward(self, end): 2291 return torch.arange(end, dtype=torch.int64) 2292 2293 y = torch.tensor(8.5, dtype=torch.float) 2294 self.run_test(ArangeModelEnd(), (y)) 2295 2296 class ArangeModelStep(torch.nn.Module): 2297 def forward(self, start, end): 2298 return torch.arange(start.size(0), end, 1.5, dtype=torch.int64) 2299 2300 x = torch.randn(2, 3, 4) 2301 y = torch.tensor(8.5, dtype=torch.float) 2302 self.run_test( 2303 ArangeModelStep(), 2304 (x, y), 2305 input_names=["x", "y"], 2306 dynamic_axes={"x": [0, 1, 2]}, 2307 ) 2308 self.run_test(ArangeModelStep(), (x, y), remained_onnx_input_idx=[1]) 2309 2310 @skipIfUnsupportedMinOpsetVersion(11) 2311 def test_arange_out(self): 2312 class ArangeOutModel(torch.nn.Module): 2313 def forward(self, end): 2314 out_t = torch.tensor([1], dtype=torch.float) 2315 return torch.arange(end, out=out_t) 2316 2317 x = torch.tensor(8.5, dtype=torch.float) 2318 self.run_test(ArangeOutModel(), (x)) 2319 2320 @skipIfUnsupportedMinOpsetVersion(11) 2321 def test_arange_start_out(self): 2322 class ArangeStartOutModel(torch.nn.Module): 2323 def forward(self, start, end): 2324 out_t = torch.tensor([1], dtype=torch.float) 2325 return torch.arange(start.size(0), end, out=out_t) 2326 2327 x = torch.randn(2, 3, 4) 2328 y = torch.tensor(8.5, dtype=torch.float) 2329 self.run_test( 2330 ArangeStartOutModel(), 2331 (x, y), 2332 input_names=["x", "y"], 2333 dynamic_axes={"x": [0, 1, 2]}, 2334 ) 2335 self.run_test(ArangeStartOutModel(), (x, y), remained_onnx_input_idx=[1]) 2336 2337 
@skipIfUnsupportedMinOpsetVersion(11) 2338 def test_arange_no_type(self): 2339 class ArangeModel(torch.nn.Module): 2340 def forward(self, end): 2341 return torch.arange(end), torch.arange(0, end) 2342 2343 x = torch.tensor(6.2, dtype=torch.float) 2344 self.run_test(ArangeModel(), x) 2345 2346 @skipIfUnsupportedMinOpsetVersion(9) 2347 def test_size(self): 2348 class SizeModel(torch.nn.Module): 2349 def forward(self, input): 2350 return ( 2351 torch.arange(input.size(0)), 2352 torch.arange(input.size(-1)), 2353 torch.ones(input.shape), 2354 ) 2355 2356 x = torch.randn(5, 3, 2) 2357 self.run_test(SizeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) 2358 self.run_test(SizeModel(), x, remained_onnx_input_idx=[]) 2359 2360 @skipIfUnsupportedMinOpsetVersion(9) 2361 @skipScriptTest() # x.stride() not scriptable 2362 def test_as_strided(self): 2363 class Model(torch.nn.Module): 2364 def forward(self, x): 2365 chunk_size = list(x.size()) 2366 chunk_size[1] = chunk_size[1] * 2 - 1 2367 chunk_stride = list(x.stride()) 2368 chunk_stride[1] = chunk_stride[1] // 2 2369 return x.as_strided( 2370 (3, 3, 3), (1, 4, 2), storage_offset=2 2371 ), x.as_strided(chunk_size, chunk_stride) 2372 2373 x = torch.randn(5, 8, 7) 2374 self.run_test(Model(), x) 2375 2376 @skipScriptTest() # Ellipses followed by tensor indexing not scriptable 2377 def test_tensor_index_advanced_indexing_ellipsis(self): 2378 class MyModel(torch.nn.Module): 2379 def forward(self, input): 2380 return input[..., torch.tensor([2, 1]), torch.tensor([0, 3])] 2381 2382 m1 = torch.randn(3, 4, 5, 6, 7) 2383 self.run_test(MyModel(), (m1,)) 2384 2385 def test_tensor_index_advanced_indexing(self): 2386 class MyModel(torch.nn.Module): 2387 def forward(self, input): 2388 return input[ 2389 :, 2390 torch.tensor([[0, 2], [1, 1]]), 2391 :, 2392 torch.tensor([2, 1]), 2393 torch.tensor([0, 3]), 2394 ] 2395 2396 m1 = torch.randn(3, 4, 5, 6, 7) 2397 self.run_test(MyModel(), (m1,)) 2398 2399 class MyModel(torch.nn.Module): 
2400 def forward(self, input): 2401 return input[ 2402 :, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]]) 2403 ] 2404 2405 self.run_test(MyModel(), (m1,)) 2406 2407 class MyModel(torch.nn.Module): 2408 def forward(self, input): 2409 return input[ 2410 :, 2411 torch.tensor([0, 2]), 2412 torch.tensor([1]), 2413 2:4, 2414 torch.tensor([[1], [4]]), 2415 ] 2416 2417 self.run_test(MyModel(), (m1,)) 2418 2419 def test_tensor_index_advanced_indexing_consecutive(self): 2420 class MyModel(torch.nn.Module): 2421 def forward(self, input): 2422 return input[ 2423 :, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None 2424 ] 2425 2426 m1 = torch.randn(3, 4, 5, 6, 7) 2427 self.run_test(MyModel(), (m1,)) 2428 2429 @skipIfUnsupportedMinOpsetVersion(11) 2430 def test_index_put(self): 2431 class IndexPutModel(torch.nn.Module): 2432 def forward(self, x, ind, update): 2433 x[ind] = update 2434 return x 2435 2436 x = torch.randn(3, 4) 2437 ind = torch.tensor([1], dtype=torch.long) 2438 update = torch.ones(4) 2439 self.run_test(IndexPutModel(), (x, ind, update)) 2440 2441 @skipIfUnsupportedMinOpsetVersion(11) 2442 def test_index_put_singular(self): 2443 class IndexPutBoolModel(torch.nn.Module): 2444 def forward(self, mask, indices): 2445 mask[indices] = True 2446 return mask 2447 2448 mask = torch.zeros(100, dtype=torch.bool) 2449 indices = (torch.rand(25) * mask.shape[0]).to(torch.int64) 2450 self.run_test(IndexPutBoolModel(), (mask, indices)) 2451 2452 class IndexPutFloatModel(torch.nn.Module): 2453 def forward(self, mask, indices): 2454 mask[indices] = torch.tensor(5.5) 2455 return mask 2456 2457 mask = torch.rand(100, dtype=torch.float) 2458 indices = (torch.rand(50) * mask.shape[0]).to(torch.int64) 2459 self.run_test(IndexPutFloatModel(), (mask, indices)) 2460 2461 @skipIfUnsupportedMinOpsetVersion(11) 2462 def test_index_put_accumulate(self): 2463 class IndexPutModel(torch.nn.Module): 2464 def forward(self, x, ind, update): 2465 return x.index_put((ind,), 
update, accumulate=True) 2466 2467 x = torch.randn(3, 4) 2468 ind = torch.tensor([2], dtype=torch.long) 2469 update = torch.ones(4) 2470 self.run_test(IndexPutModel(), (x, ind, update)) 2471 2472 @skipIfUnsupportedMinOpsetVersion(11) 2473 def test_index_put_slice_index(self): 2474 class IndexPutModel(torch.nn.Module): 2475 def forward(self, x, update): 2476 x[1:2, 1:3, torch.tensor([1])] += update 2477 return x 2478 2479 x = torch.randn(3, 4, 5) 2480 update = torch.tensor([10, 15]).view(1, 2, 1) 2481 self.run_test(IndexPutModel(), (x, update)) 2482 2483 class IndexPutModel2(torch.nn.Module): 2484 def forward(self, x, update): 2485 x[torch.tensor([0, 2]), torch.tensor([1, 2])] += update 2486 return x 2487 2488 x = torch.randn(3, 4, 5) 2489 update = torch.randn(2, 5) 2490 self.run_test(IndexPutModel2(), (x, update)) 2491 2492 class IndexPutModel3(torch.nn.Module): 2493 def forward(self, x, update): 2494 x[torch.tensor([0, 2]), 1:2] += update 2495 return x 2496 2497 x = torch.randn(3, 4, 5) 2498 update = torch.tensor([10, 15]).view(2, 1, 1) 2499 self.run_test(IndexPutModel3(), (x, update)) 2500 2501 class IndexPutModel4(torch.nn.Module): 2502 def forward(self, x, update): 2503 x[torch.tensor([0, 2]), 2] += update 2504 return x 2505 2506 x = torch.randn(3, 4, 5) 2507 update = torch.tensor([10, 15]).view(2, 1) 2508 self.run_test(IndexPutModel4(), (x, update)) 2509 2510 class IndexPutModel5(torch.nn.Module): 2511 def forward(self, x, update): 2512 x[1:3, torch.tensor([0, 2]), 2] += update 2513 return x 2514 2515 x = torch.randn(3, 4, 5) 2516 update = torch.tensor([10, 15]).view(2, 1) 2517 self.run_test(IndexPutModel5(), (x, update)) 2518 2519 class IndexPutModel6(torch.nn.Module): 2520 def forward(self, x, update): 2521 x[1:3, 0] = update 2522 return x 2523 2524 x = torch.randn(3, 4, 5) 2525 update = torch.arange(2 * 5).to(torch.float).view(2, 5) 2526 self.run_test(IndexPutModel6(), (x, update)) 2527 2528 class IndexPutModel7(torch.nn.Module): 2529 def forward(self, x, 
update): 2530 x[1:, 0] = update 2531 return x 2532 2533 x = torch.randn(3, 4, 5) 2534 update = torch.arange(2 * 5).to(torch.float).view(2, 5) 2535 self.run_test(IndexPutModel7(), (x, update)) 2536 2537 class IndexPutModel8(torch.nn.Module): 2538 def forward(self, x, update): 2539 x[:3, 0] = update 2540 return x 2541 2542 x = torch.randn(3, 4, 5) 2543 update = torch.arange(3 * 5).to(torch.float).view(3, 5) 2544 self.run_test(IndexPutModel8(), (x, update)) 2545 2546 class IndexPutModel9(torch.nn.Module): 2547 def forward(self, poses): 2548 w = 32 2549 x = poses[:, :, 0] - (w - 1) // 2 2550 boxes = torch.zeros([poses.shape[0], 17, 4]) 2551 boxes[:, :, 0] = x 2552 return boxes 2553 2554 x = torch.zeros([2, 17, 3], dtype=torch.int64) 2555 self.run_test(IndexPutModel9(), (x,)) 2556 2557 class IndexPutModel10(torch.nn.Module): 2558 def forward(self, x, ind, update): 2559 x[ind, 1:3] = update.view(1, 1, 1, 5).expand(2, 2, 2, 5) 2560 return x 2561 2562 x = torch.randn(3, 4, 5) 2563 ind = torch.tensor([[0, 2], [1, 1]]) 2564 update = torch.randn(5) 2565 self.run_test(IndexPutModel10(), (x, ind, update)) 2566 2567 @skipIfUnsupportedMinOpsetVersion(11) 2568 @skipScriptTest() # Ellipses followed by tensor indexing not scriptable 2569 def test_index_put_ellipsis(self): 2570 class IndexPutModel(torch.nn.Module): 2571 def forward(self, x, update): 2572 x[..., torch.tensor([2, 1, 3]), 2:4] += update 2573 return x 2574 2575 x = torch.randn(3, 4, 5, 6, 7) 2576 update = torch.randn(3, 1, 1, 3, 2) 2577 self.run_test(IndexPutModel(), (x, update)) 2578 2579 class IndexPutModel2(torch.nn.Module): 2580 def forward(self, x, update): 2581 x[2, ..., torch.tensor([2, 1, 3]), 2:4] += update 2582 return x 2583 2584 x = torch.randn(3, 4, 5, 6, 7) 2585 update = torch.randn(4, 1, 3, 2) 2586 self.run_test(IndexPutModel2(), (x, update)) 2587 2588 @unittest.skip( 2589 "regression in 1.18: https://github.com/microsoft/onnxruntime/issues/20855" 2590 ) 2591 @skipIfUnsupportedMinOpsetVersion(11) 2592 def 
test_index_put_loop(self): 2593 @torch.jit.script 2594 def ngram_attention_bias( 2595 sequence_length: int, ngram: int, device: torch.device, dtype: torch.dtype 2596 ): 2597 bias = torch.ones( 2598 (ngram, sequence_length), device=device, dtype=dtype 2599 ) * float("-inf") 2600 for stream_idx in range(ngram): 2601 for i in range(sequence_length): 2602 bias = bias * 2 2603 bias[stream_idx, i] = 5 2604 bias = bias * 5 2605 bias[0, 0] = 5 2606 2607 for stream_idx in range(ngram): 2608 for i in range(sequence_length): 2609 bias[stream_idx, i] = 5 2610 bias[0, i] = 5 2611 return bias 2612 2613 class ScriptModel(torch.nn.Module): 2614 def __init__(self) -> None: 2615 super().__init__() 2616 self.ngram = 2 2617 self.max_target_positions = 512 2618 2619 def forward(self, hidden_states): 2620 seq_length, batch_size = hidden_states.shape[:2] 2621 predict_causal_mask = ngram_attention_bias( 2622 self.max_target_positions, 2623 self.ngram, 2624 hidden_states.device, 2625 hidden_states.dtype, 2626 ) 2627 predict_causal_mask = predict_causal_mask[:, :seq_length] 2628 return predict_causal_mask 2629 2630 x = torch.randn(6, 2) 2631 y = torch.randn(4, 1) 2632 self.run_test( 2633 ScriptModel(), 2634 x, 2635 input_names=["x"], 2636 dynamic_axes={"x": {0: "seq_length", 1: "batch_size"}}, 2637 additional_test_inputs=[y], 2638 ) 2639 2640 @skipIfUnsupportedMinOpsetVersion(11) 2641 def test_copy_(self): 2642 class CopyModel(torch.nn.Module): 2643 def forward(self, x, data): 2644 x[1:3] = data 2645 return x 2646 2647 x = torch.randn(3, 4) 2648 update = torch.randn(2, 4) 2649 self.run_test(CopyModel(), (x, update)) 2650 2651 # mixed slice and select 2652 class CopyModel2(torch.nn.Module): 2653 def forward(self, x, data): 2654 x[1:3, 0] = data 2655 return x 2656 2657 x = torch.randn(3, 4) 2658 update = torch.tensor([0], dtype=torch.float32) 2659 self.run_test(CopyModel2(), (x, update)) 2660 2661 update = torch.tensor([2, 3], dtype=torch.float32) 2662 self.run_test(CopyModel2(), (x, update)) 
2663 2664 update = torch.randn(2) 2665 self.run_test(CopyModel2(), (x, update)) 2666 2667 class CopyModel3(torch.nn.Module): 2668 def forward(self, x, data): 2669 x[1, 1:3] = data 2670 return x 2671 2672 x = torch.randn(3, 4) 2673 update = torch.tensor([0], dtype=torch.float32) 2674 self.run_test(CopyModel3(), (x, update)) 2675 2676 update = torch.tensor([2, 3], dtype=torch.float32) 2677 self.run_test(CopyModel3(), (x, update)) 2678 2679 update = torch.randn(2) 2680 self.run_test(CopyModel3(), (x, update)) 2681 2682 class CopyModel4(torch.nn.Module): 2683 def forward(self, x, ind, data): 2684 x[ind] = data 2685 return x 2686 2687 x = torch.randn(3, 4) 2688 ind = torch.tensor(2) 2689 data = torch.randn(4) 2690 self.run_test(CopyModel4(), (x, ind, data)) 2691 2692 class CopyModel5(torch.nn.Module): 2693 def forward(self, x, mask): 2694 if mask is not None: 2695 x.copy_(mask) 2696 return x 2697 2698 x = torch.randn(3, 4) 2699 mask = torch.randn(3, 1) 2700 self.run_test(CopyModel5(), (x, mask)) 2701 2702 @skipIfUnsupportedMinOpsetVersion(11) 2703 @skipScriptTest() # Model not scriptable (output with shape doesn't match the broadcast shape) 2704 def test_copy_tracing(self): 2705 class CopyModel(torch.nn.Module): 2706 def forward(self, x, data): 2707 x[1, 1:3] = data 2708 return x 2709 2710 x = torch.randn(3, 4) 2711 update = torch.randn(1, 2) 2712 self.run_test(CopyModel(), (x, update)) 2713 2714 @skipIfUnsupportedMinOpsetVersion(11) 2715 def test_copy_ellipsis(self): 2716 class CopyModel(torch.nn.Module): 2717 def forward(self, x, update): 2718 x[..., 1] = update 2719 return x 2720 2721 x = torch.randn(2, 3, 4) 2722 update = torch.ones(1) 2723 self.run_test(CopyModel(), (x, update)) 2724 2725 x = torch.randn(2, 3, 4, 5, 6) 2726 update = torch.ones(1) 2727 self.run_test(CopyModel(), (x, update)) 2728 2729 @skipIfUnsupportedMinOpsetVersion(11) 2730 def test_copy_ellipsis_script(self): 2731 class CopyModel(torch.nn.Module): 2732 def forward(self, x, update): 2733 # Insert 
reshape node to ensure no shape/type info for 2734 # x in scripting, without onnx shape inference. 2735 x = x.reshape(4, 3, 5, 6) 2736 x[2, ..., 1:3] = update 2737 return x 2738 2739 x = torch.randn(3, 4, 5, 6) 2740 2741 update = torch.ones(1) 2742 self.run_test(CopyModel(), (x, update)) 2743 2744 @skipIfUnsupportedMinOpsetVersion(10) 2745 def test_flip(self): 2746 class MyModule(torch.nn.Module): 2747 def forward(self, x): 2748 return torch.flip(x, dims=[0]) 2749 2750 x = torch.tensor(np.arange(6.0).reshape(2, 3)) 2751 self.run_test(MyModule(), x) 2752 2753 @skipIfUnsupportedMinOpsetVersion(9) 2754 def test_randint(self): 2755 class RandInt(torch.nn.Module): 2756 def forward(self, x): 2757 randint = torch.randint(1, 10, x.shape) 2758 x = 0 * randint + x 2759 return x 2760 2761 x = torch.randn(2, 3, 4) 2762 self.run_test(RandInt(), x) 2763 2764 @skipIfUnsupportedMinOpsetVersion(9) 2765 def test_randint_value(self): 2766 class RandInt(torch.nn.Module): 2767 def forward(self, x): 2768 # This randint call always returns 3 2769 return torch.randint(3, 4, x.shape) + x 2770 2771 x = torch.randn(2, 3, 4) 2772 self.run_test(RandInt(), x) 2773 2774 @skipIfUnsupportedMinOpsetVersion(9) 2775 def test_randint_like(self): 2776 class RandInt(torch.nn.Module): 2777 def forward(self, x): 2778 # This randint call always returns 3 2779 return torch.randint_like(x, 3, 4) + x 2780 2781 x = torch.randn(2, 3, 4) 2782 self.run_test(RandInt(), x) 2783 2784 def test_randn(self): 2785 class RandN(torch.nn.Module): 2786 def forward(self, x): 2787 return torch.mul(x, (torch.randn(2, 3, 4) + x).size(0)) 2788 2789 x = torch.randn(2, 3, 4) 2790 self.run_test(RandN(), x) 2791 2792 def test_rand(self): 2793 class Rand(torch.nn.Module): 2794 def forward(self, x): 2795 return torch.mul(x, (torch.rand(2, 3, 4) + x).size(0)) 2796 2797 x = torch.randn(2, 3, 4) 2798 self.run_test(Rand(), x) 2799 2800 def test_randn_dtype(self): 2801 class RandN(torch.nn.Module): 2802 def forward(self, x): 2803 # The 
resulting node's dtype should be double. 2804 return ( 2805 x.to(torch.float32) 2806 * torch.randn(2, 3, 4, dtype=torch.double) 2807 * torch.tensor(0, dtype=torch.float32) 2808 ) 2809 2810 x = torch.randn(2, 3, 4) 2811 self.run_test(RandN(), x) 2812 2813 def test_rand_dtype(self): 2814 class Rand(torch.nn.Module): 2815 def forward(self, x): 2816 # The resulting node's dtype should be double. 2817 return ( 2818 x.to(torch.float32) 2819 * torch.rand(2, 3, 4, dtype=torch.double) 2820 * torch.tensor(0, dtype=torch.float32) 2821 ) 2822 2823 x = torch.randn(2, 3, 4) 2824 self.run_test(Rand(), x) 2825 2826 @skipIfUnsupportedMinOpsetVersion(9) 2827 def test_randn_dynamic_size(self): 2828 class RandN(torch.nn.Module): 2829 def forward(self, x): 2830 return torch.mul(x, torch.randn(x.size()).size(1)) 2831 2832 x = torch.randn(2, 3, 4) 2833 self.run_test(RandN(), x) 2834 2835 @skipIfUnsupportedMinOpsetVersion(9) 2836 def test_rand_dynamic_size(self): 2837 class Rand(torch.nn.Module): 2838 def forward(self, x): 2839 return torch.mul(x, torch.rand(x.size()).size(1)) 2840 2841 x = torch.randn(2, 3, 4) 2842 self.run_test(Rand(), x) 2843 2844 def test_randn_like(self): 2845 class RandNLike(torch.nn.Module): 2846 def forward(self, x): 2847 return torch.mul(x, torch.randn_like(x).size(0)) 2848 2849 x = torch.randn(2, 3, 4) 2850 self.run_test(RandNLike(), x) 2851 self.run_test(torch.jit.script(RandNLike()), x) 2852 2853 def test_rand_like(self): 2854 class RandLike(torch.nn.Module): 2855 def forward(self, x): 2856 return torch.mul(x, torch.rand_like(x).size(0)) 2857 2858 x = torch.randn(2, 3, 4) 2859 self.run_test(RandLike(), x) 2860 self.run_test(torch.jit.script(RandLike()), x) 2861 2862 def test_randn_like_dtype(self): 2863 class RandNLike(torch.nn.Module): 2864 def forward(self, x): 2865 # The resulting node's dtype should be double. 
2866 return ( 2867 x.to(torch.float32) 2868 * torch.randn_like(x, dtype=torch.double) 2869 * torch.tensor(0, dtype=torch.float32) 2870 ) 2871 2872 x = torch.randn(2, 3, 4) 2873 self.run_test(RandNLike(), x) 2874 2875 def test_rand_like_dtype(self): 2876 class RandLike(torch.nn.Module): 2877 def forward(self, x): 2878 # The resulting node's dtype should be double. 2879 return ( 2880 x.to(torch.float32) 2881 * torch.rand_like(x, dtype=torch.double) 2882 * torch.tensor(0, dtype=torch.float32) 2883 ) 2884 2885 x = torch.randn(2, 3, 4) 2886 self.run_test(RandLike(), x) 2887 2888 def test_bernoulli(self): 2889 class Bernoulli(torch.nn.Module): 2890 def forward(self, x): 2891 return torch.mul(x, torch.bernoulli(x).size(0)) 2892 2893 x = torch.empty(3, 3).uniform_(0, 1) 2894 self.run_test(Bernoulli(), x) 2895 2896 x = torch.empty(2, 3, 3, dtype=torch.double).uniform_(0, 1) 2897 self.run_test(Bernoulli(), x) 2898 2899 def test_bernoulli_p(self): 2900 class Bernoulli_float(torch.nn.Module): 2901 def forward(self, x): 2902 return torch.mul(x, torch.bernoulli(x, 0.2).size(0)) 2903 2904 class Bernoulli_tensor(torch.nn.Module): 2905 def forward(self, x): 2906 return torch.mul(x, torch.rand_like(x).bernoulli_(x).size(0)) 2907 2908 x = torch.rand(3, 3) 2909 self.run_test(Bernoulli_float(), x) 2910 self.run_test(Bernoulli_tensor(), x) 2911 2912 x = torch.rand(2, 3, 3, dtype=torch.double) 2913 self.run_test(Bernoulli_float(), x) 2914 self.run_test(Bernoulli_tensor(), x) 2915 2916 @unittest.skip("Bug in ORT, skip test until rel-1.11.") 2917 @skipIfUnsupportedMinOpsetVersion(14) 2918 def test_reshape_allowzero(self): 2919 class ReshapeModel(torch.nn.Module): 2920 def forward(self, x): 2921 x = x.reshape(3, 4, 0) 2922 return x 2923 2924 x = torch.randn(0, 3, 4) 2925 self.run_test(ReshapeModel(), x) 2926 2927 def test_reshape_different_rank(self): 2928 class ReshapeModel(torch.nn.Module): 2929 def forward(self, x): 2930 x = x.reshape(-1, 2, 4, 4, 5, 5) 2931 return x 2932 2933 x = 
torch.randn(1, 32, 5, 5) 2934 self.run_test(ReshapeModel(), x) 2935 2936 def _interpolate(self, x, mode, use_size, is_upsample, align_corners=False): 2937 class MyModel(torch.nn.Module): 2938 __constants__ = [ 2939 "mode", 2940 "use_size", 2941 "is_upsample", 2942 "size", 2943 "scale", 2944 "size_array", 2945 "scale_array", 2946 "align_corners", 2947 ] 2948 2949 def __init__(self, mode, use_size, is_upsample, align_corners): 2950 super().__init__() 2951 self.mode = mode 2952 self.use_size = use_size 2953 self.is_upsample = is_upsample 2954 self.align_corners = align_corners 2955 self.scale = 2.0 if self.is_upsample else 0.5 2956 self.size = 24 if self.is_upsample else 2 2957 if x.dim() == 3: 2958 self.scale_array = [2.3] 2959 self.size_array = [16] 2960 elif x.dim() == 4: 2961 self.scale_array = [2.3, 3.1] 2962 self.size_array = [16, 32] 2963 else: 2964 self.scale_array = [2.3, 3.1, 4.6] 2965 self.size_array = [16, 32, 64] 2966 2967 def forward(self, x): 2968 if self.use_size: 2969 if self.align_corners: 2970 return torch.nn.functional.interpolate( 2971 x, mode=self.mode, size=self.size, align_corners=True 2972 ), torch.nn.functional.interpolate( 2973 x, mode=self.mode, size=self.size_array, align_corners=True 2974 ) 2975 return torch.nn.functional.interpolate( 2976 x, mode=self.mode, size=self.size 2977 ), torch.nn.functional.interpolate( 2978 x, mode=self.mode, size=self.size_array 2979 ) 2980 if self.align_corners: 2981 return torch.nn.functional.interpolate( 2982 x, 2983 mode=self.mode, 2984 scale_factor=self.scale, 2985 recompute_scale_factor=False, 2986 ), torch.nn.functional.interpolate( 2987 x, 2988 mode=self.mode, 2989 scale_factor=self.scale_array, 2990 recompute_scale_factor=False, 2991 ) 2992 return torch.nn.functional.interpolate( 2993 x, 2994 mode=self.mode, 2995 scale_factor=self.scale, 2996 recompute_scale_factor=False, 2997 ), torch.nn.functional.interpolate( 2998 x, 2999 mode=self.mode, 3000 scale_factor=self.scale_array, 3001 
recompute_scale_factor=False, 3002 ) 3003 3004 model = MyModel(mode, use_size, is_upsample, align_corners) 3005 self.run_test(model, x, atol=1e-6) 3006 3007 def _interpolate_tests(self, is_upsample): 3008 # - cubic mode is not supported for opsets below 11; 3009 # - linear mode does not match for opsets below 11; 3010 modes = ["nearest", "linear", "bicubic"] 3011 if self.opset_version < 11: 3012 modes = ["nearest"] 3013 x = [ 3014 torch.randn(1, 2, 6, requires_grad=True), 3015 torch.randn(1, 2, 4, 6, requires_grad=True), 3016 torch.randn(1, 2, 4, 4, 6, requires_grad=True), 3017 ] 3018 3019 for mode in modes: 3020 for xi in x: 3021 mode_i = mode 3022 # TODO: enable bicubic downsample when ORT precision loss fixed 3023 if mode == "bicubic" and xi.dim() != 4: 3024 continue 3025 elif mode == "linear": 3026 if xi.dim() == 3: 3027 # TODO : enable when linear mode is implemented for 1d inputs in ORT 3028 continue 3029 elif xi.dim() == 4: 3030 mode_i = "bilinear" 3031 elif xi.dim() == 5: 3032 # TODO : enable when linear mode is implemented for 3d inputs in ORT 3033 mode_i = "trilinear" 3034 continue 3035 self._interpolate(xi, mode_i, True, is_upsample) 3036 # test with align_corners if supported 3037 if mode != "nearest": 3038 self._interpolate(xi, mode_i, True, is_upsample, True) 3039 # the following cases, require dynamic sizes/scales, 3040 # which which is not supported for opset_version < 9 3041 if self.opset_version >= 9: 3042 self._interpolate(xi, mode_i, True, is_upsample) 3043 # test with align_corners if supported 3044 if mode != "nearest": 3045 self._interpolate(xi, mode_i, False, is_upsample, True) 3046 self._interpolate(xi, mode_i, False, is_upsample) 3047 3048 # ONNX export failed on interpolate scripting because dynamic size not supported for opsets below 9. 
3049 @skipIfUnsupportedMinOpsetVersion(9) 3050 def test_interpolate_upsample(self): 3051 self._interpolate_tests(True) 3052 3053 @skipIfUnsupportedMaxOpsetVersion(8) 3054 @skipScriptTest() # Scripting supported for opsets > 8. See test_interpolate_upsample 3055 def test_interpolate_upsample_trace(self): 3056 self._interpolate_tests(True) 3057 3058 @skipIfUnsupportedMinOpsetVersion(9) 3059 def test_interpolate_function_substitution(self): 3060 class ScriptModel(torch.jit.ScriptModule): 3061 @torch.jit.script_method 3062 def forward(self, x): 3063 return torch.nn.functional.interpolate( 3064 x, mode="nearest", scale_factor=2.0 3065 ) 3066 3067 class ScriptModule(torch.jit.ScriptModule): 3068 def __init__(self) -> None: 3069 super().__init__() 3070 self.submodule = ScriptModel() 3071 3072 @torch.jit.script_method 3073 def forward(self, input): 3074 return self.submodule(input) 3075 3076 x = torch.randn(1, 2, 4, 4, 6) 3077 self.run_test(ScriptModule(), (x,)) 3078 3079 @torch.jit.script 3080 def script_method(x): 3081 return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2.0) 3082 3083 class TracingModule(torch.nn.Module): 3084 def forward(self, x): 3085 return script_method(x) 3086 3087 self.run_test(TracingModule(), (x,)) 3088 3089 @skipIfUnsupportedMinOpsetVersion(10) 3090 def test_interpolate_downsample(self): 3091 self._interpolate_tests(False) 3092 3093 @skipIfUnsupportedMinOpsetVersion(11) 3094 def test_interpolate_half_pixel(self): 3095 # testing whether it uses "half_pixel" or "pytorch_half_pixel" 3096 # see https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize 3097 3098 class MyModel(torch.nn.Module): 3099 def __init__(self, mode, size): 3100 super().__init__() 3101 self.mode = mode 3102 self.size = size 3103 3104 def forward(self, x): 3105 return torch.nn.functional.interpolate( 3106 x, mode=self.mode, size=self.size 3107 ) 3108 3109 modes = ["linear", "bicubic"] 3110 x = [ 3111 torch.randn(1, 2, 6, requires_grad=True), 3112 
torch.randn(1, 2, 4, 6, requires_grad=True), 3113 torch.randn(1, 2, 4, 4, 6, requires_grad=True), 3114 ] 3115 for mode in modes: 3116 for xi in x: 3117 mode_i = mode 3118 if mode == "bicubic" and xi.dim() != 4: 3119 continue 3120 elif mode == "linear": 3121 if xi.dim() == 4: 3122 mode_i = "bilinear" 3123 elif xi.dim() == 5: 3124 mode_i = "trilinear" 3125 for i in range(xi.dim() - 2): 3126 size = list(xi.shape[2:]) 3127 size[i] = 1 3128 self.run_test(MyModel(mode_i, size), xi) 3129 3130 @skipIfUnsupportedMinOpsetVersion(11) 3131 def test_interpolate_no_shape(self): 3132 class MyModel(torch.jit.ScriptModule): 3133 @torch.jit.script_method 3134 def forward(self, x, y): 3135 x = torch.add(x, x) 3136 out1 = torch.nn.functional.interpolate( 3137 x, mode="bilinear", size=(16, 16), align_corners=False 3138 ) 3139 out2 = torch.nn.functional.interpolate( 3140 x, mode="nearest", size=(int(y.size(0)), int(y.size(1))) 3141 ) 3142 return out1, out2 3143 3144 x = torch.randn(1, 2, 4, 4, requires_grad=True) 3145 y = torch.randn(16, 16, requires_grad=True) 3146 self.run_test( 3147 MyModel(), 3148 (x, y), 3149 input_names=["x", "y"], 3150 dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1]}, 3151 ) 3152 self.run_test(MyModel(), (x, y), remained_onnx_input_idx=[0]) 3153 3154 @skipScriptTest() # scripting raises OnnxRuntimeError 3155 def test_interpolate_adaptive_pooling_error(self): 3156 x = torch.randn(1, 2, 6, requires_grad=True) 3157 with self.assertRaises(RuntimeError) as cm: 3158 self._interpolate(x, "area", True, True) 3159 3160 with self.assertRaises(RuntimeError) as cm: 3161 self._interpolate(x, "area", False, True) 3162 3163 def test_groupnorm(self): 3164 model = torch.nn.GroupNorm(3, 6, 0.002) 3165 x = torch.randn(4, 6, 36, 36, 18) 3166 self.run_test(model, x) 3167 3168 model = torch.nn.GroupNorm(1, 6, 0.002) 3169 x = torch.randn(4, 6, 180, 180) 3170 self.run_test(model, x) 3171 3172 model = torch.nn.GroupNorm(6, 6, 0.002) 3173 x = torch.randn(4, 6, 180, 180) 3174 
self.run_test(model, x) 3175 3176 def test_groupnorm_noaffine(self): 3177 model = torch.nn.GroupNorm(4, 8, 0.002, affine=False) 3178 x = torch.randn(3, 8, 224, 224) 3179 self.run_test(model, x) 3180 3181 model = torch.nn.GroupNorm(1, 6, 0.002, affine=False) 3182 x = torch.randn(4, 6, 180, 180) 3183 self.run_test(model, x) 3184 3185 model = torch.nn.GroupNorm(6, 6, 0.002, affine=False) 3186 x = torch.randn(4, 6, 180, 180) 3187 self.run_test(model, x) 3188 3189 @skipIfUnsupportedMinOpsetVersion(9) 3190 def test_list_unpack_scripted(self): 3191 class ListUnpack(torch.nn.Module): 3192 def forward(self, x): 3193 a, b = x.shape 3194 return x.new_zeros((a, b)) 3195 3196 x = torch.randn(2, 3) 3197 self.run_test( 3198 torch.jit.script(ListUnpack()), 3199 x, 3200 input_names=["x"], 3201 dynamic_axes={"x": [0, 1]}, 3202 ) 3203 self.run_test(torch.jit.script(ListUnpack()), x, remained_onnx_input_idx=[]) 3204 3205 @skipIfUnsupportedMinOpsetVersion(9) 3206 def test_list_unpack_scripted_runs_without_error_with_constructed_list_as_input( 3207 self, 3208 ): 3209 class PackUnpack(torch.nn.Module): 3210 """Create and unpack a list of tensors. 
3211 3212 When scripted, it should produce a graph similar to 3213 3214 ``` 3215 graph(%self : __torch__.PackUnpack, 3216 %a.1 : Tensor, 3217 %b.1 : Tensor): 3218 %packed.1 : Tensor[] = prim::ListConstruct(%a.1, %b.1) 3219 %c.1 : Tensor, %8 : Tensor = prim::ListUnpack(%packed.1) 3220 return (%c.1) 3221 ``` 3222 """ 3223 3224 def forward(self, a, b): 3225 packed = [a, b] 3226 c, _ = packed 3227 return c 3228 3229 self.run_test( 3230 torch.jit.script(PackUnpack()), 3231 (torch.tensor(0), torch.tensor([42])), 3232 remained_onnx_input_idx=[0], 3233 ) 3234 3235 @skipIfUnsupportedMinOpsetVersion(9) 3236 def test_list_unpack_slice_scripted(self): 3237 class ListUnpackSlice(torch.nn.Module): 3238 def forward(self, x): 3239 a, b = x.shape[2:] 3240 return x.new_zeros((a, b)) 3241 3242 x = torch.randn(2, 3, 4, 5) 3243 self.run_test( 3244 torch.jit.script(ListUnpackSlice()), 3245 x, 3246 input_names=["x"], 3247 dynamic_axes={"x": [0, 1, 2, 3]}, 3248 ) 3249 self.run_test( 3250 torch.jit.script(ListUnpackSlice()), x, remained_onnx_input_idx=[] 3251 ) 3252 3253 @skipDtypeChecking 3254 def test_pow(self): 3255 class PowModule(torch.nn.Module): 3256 def forward(self, x, y): 3257 return x.pow(y) 3258 3259 x = torch.randn(2, 3, 4) 3260 y = torch.randn(2, 3, 4) 3261 self.run_test(PowModule(), (x, y)) 3262 3263 x = torch.randint(10, (2, 3, 4)) 3264 y = torch.randint(10, (2, 3, 4)).to(dtype=torch.int32) 3265 self.run_test(PowModule(), (x, y)) 3266 3267 x = torch.randint(10, (2, 3, 4)) 3268 y = torch.randint(10, (2, 3, 4)) 3269 self.run_test(PowModule(), (x, y)) 3270 3271 x = torch.randn(2, 3, 4).to(dtype=torch.float64) 3272 y = torch.randint(10, (2, 3, 4)) 3273 self.run_test(PowModule(), (x, y)) 3274 3275 class PowModule2(torch.nn.Module): 3276 def forward(self, x): 3277 return torch.pow(2, x) 3278 3279 x = torch.randn(1, 10) 3280 self.run_test(PowModule2(), (x,)) 3281 3282 x = torch.randint(10, (2, 3, 4)) 3283 self.run_test(PowModule2(), (x,)) 3284 3285 x = torch.randn(1, 
10).to(dtype=torch.float64) 3286 self.run_test(PowModule2(), (x,)) 3287 3288 class PowModule3(torch.nn.Module): 3289 def forward(self, x, y): 3290 return y[torch.pow(2, x)] 3291 3292 x = torch.randint(5, (2, 3, 4)) 3293 y = torch.rand(100) 3294 self.run_test(PowModule3(), (x, y)) 3295 3296 # the arithmeticOps(Add\Sub\Mul\Div\Gemm\Pow\Mod) with low precision include unit8 will be failed in ORT 3297 # add to(dtype=torch.long) to avoid ORT output type does not match expected type. 3298 # will be fixed in ONNX version 14. 3299 @skipIfUnsupportedMaxOpsetVersion(13) 3300 @skipDtypeChecking 3301 def test_arithmeticOps_with_low_precision(self): 3302 class AddModule(torch.nn.Module): 3303 def forward(self, x, y): 3304 return x + y 3305 3306 class SubModule(torch.nn.Module): 3307 def forward(self, x, y): 3308 return x - y 3309 3310 class MulModule(torch.nn.Module): 3311 def forward(self, x, y): 3312 return x * y 3313 3314 class DivModule(torch.nn.Module): 3315 def forward(self, x, y): 3316 return x / y 3317 3318 class PowModule(torch.nn.Module): 3319 def forward(self, x, y): 3320 return x.pow(y) 3321 3322 x = torch.tensor([2, 3, 5], dtype=torch.uint8) 3323 y = torch.tensor([2, 3, 5], dtype=torch.uint8) 3324 z = torch.tensor([1], dtype=torch.uint8) 3325 self.run_test(AddModule(), (x, y)) 3326 self.run_test(SubModule(), (x, y)) 3327 self.run_test(MulModule(), (x, y)) 3328 self.run_test(DivModule(), (x, y)) 3329 self.run_test(PowModule(), (x, z)) 3330 3331 x = torch.tensor([2, 3, 5], dtype=torch.int8) 3332 y = torch.tensor([2, 3, 5], dtype=torch.int8) 3333 z = torch.tensor([1], dtype=torch.int8) 3334 self.run_test(AddModule(), (x, y)) 3335 self.run_test(SubModule(), (x, y)) 3336 self.run_test(MulModule(), (x, y)) 3337 self.run_test(DivModule(), (x, y)) 3338 self.run_test(PowModule(), (x, z)) 3339 3340 x = torch.tensor([2, 3, 5], dtype=torch.int16) 3341 y = torch.tensor([2, 3, 5], dtype=torch.int16) 3342 z = torch.tensor([1], dtype=torch.int16) 3343 self.run_test(AddModule(), 
(x, y)) 3344 self.run_test(SubModule(), (x, y)) 3345 self.run_test(MulModule(), (x, y)) 3346 self.run_test(DivModule(), (x, y)) 3347 self.run_test(PowModule(), (x, z)) 3348 3349 x = torch.tensor([2, 3, 5], dtype=torch.uint8) 3350 y = torch.tensor([2, 3, 5], dtype=torch.float32) 3351 z = torch.tensor([1], dtype=torch.float64) 3352 self.run_test(AddModule(), (x, y)) 3353 self.run_test(SubModule(), (x, y)) 3354 self.run_test(MulModule(), (x, y)) 3355 self.run_test(DivModule(), (x, y)) 3356 self.run_test(PowModule(), (x, z)) 3357 3358 x = torch.tensor([2, 3, 5], dtype=torch.uint8) 3359 y = torch.tensor([2, 3, 5], dtype=torch.int64) 3360 z = torch.tensor([1], dtype=torch.int32) 3361 self.run_test(AddModule(), (x, y)) 3362 self.run_test(SubModule(), (x, y)) 3363 self.run_test(MulModule(), (x, y)) 3364 self.run_test(DivModule(), (x, y)) 3365 self.run_test(PowModule(), (x, z)) 3366 3367 def test_mul_bool(self): 3368 class MyModel(torch.nn.Module): 3369 def forward(self, x, y): 3370 return torch.mul(x, y) 3371 3372 x_t = torch.tensor([True, False, True, False]) 3373 y_t = torch.tensor([True, True, False, False]) 3374 z_t = torch.tensor([1.0, 2.0, 3.0, 0.0]) 3375 self.run_test(MyModel(), (x_t, y_t)) 3376 self.run_test(MyModel(), (x_t, z_t)) 3377 self.run_test(MyModel(), (z_t, y_t)) 3378 3379 # fmod was added in version 10 3380 @skipIfUnsupportedMinOpsetVersion(10) 3381 @skipIfUnsupportedMaxOpsetVersion(13) 3382 def test_mod_with_low_precision(self): 3383 class ModModule(torch.nn.Module): 3384 def forward(self, x, y): 3385 return torch.fmod(x, y).to(dtype=torch.long) 3386 3387 x = torch.tensor([2, 3, 5], dtype=torch.uint8) 3388 y = torch.tensor([2, 3, 5], dtype=torch.uint8) 3389 self.run_test(ModModule(), (x, y)) 3390 3391 x = torch.tensor([2, 3, 5], dtype=torch.int8) 3392 y = torch.tensor([2, 3, 5], dtype=torch.int8) 3393 self.run_test(ModModule(), (x, y)) 3394 3395 x = torch.tensor([2, 3, 5], dtype=torch.int16) 3396 y = torch.tensor([2, 3, 5], dtype=torch.int16) 3397 
self.run_test(ModModule(), (x, y)) 3398 3399 x = torch.tensor([2, 3, 5], dtype=torch.uint8) 3400 y = torch.tensor([2, 3, 5], dtype=torch.int32) 3401 self.run_test(ModModule(), (x, y)) 3402 3403 x = torch.tensor([2, 3, 5], dtype=torch.uint8) 3404 y = torch.tensor([2, 3, 5], dtype=torch.float64) 3405 self.run_test(ModModule(), (x, y)) 3406 3407 @skipIfUnsupportedMinOpsetVersion(9) 3408 def test_empty_constant_shape(self): 3409 class Zeros(torch.nn.Module): 3410 def forward(self, x): 3411 y = torch.zeros(()) 3412 y += x 3413 return y 3414 3415 x = torch.tensor(42.0) 3416 self.run_test(Zeros(), x) 3417 3418 class Ones(torch.nn.Module): 3419 def forward(self, x): 3420 y = torch.ones(()) 3421 y += x 3422 return y 3423 3424 x = torch.tensor(42.0) 3425 self.run_test(Ones(), x) 3426 3427 class Full(torch.nn.Module): 3428 def forward(self, x): 3429 y = torch.full((), 1.0) 3430 y += x 3431 return y 3432 3433 x = torch.tensor(42.0) 3434 self.run_test(Full(), x) 3435 3436 class Empty(torch.nn.Module): 3437 def forward(self, x): 3438 y = torch.empty(()).fill_(0) 3439 y += x 3440 return y 3441 3442 x = torch.tensor(42.0) 3443 self.run_test(Empty(), x) 3444 3445 def test_std(self): 3446 class StandardDeviation(torch.nn.Module): 3447 def forward(self, input): 3448 return torch.std(input, unbiased=False) 3449 3450 x = torch.randn(2, 3, 4) 3451 model = StandardDeviation() 3452 self.run_test(model, x) 3453 3454 class StandardDeviationUnbiased(torch.nn.Module): 3455 def forward(self, input): 3456 return torch.std(input, unbiased=True) 3457 3458 model = StandardDeviationUnbiased() 3459 self.run_test(model, x) 3460 3461 def test_std_along_dims(self): 3462 class StandardDeviation(torch.nn.Module): 3463 def forward(self, input): 3464 return torch.std(input, dim=(0, 1), unbiased=False) 3465 3466 x = torch.randn(2, 3, 4) 3467 model = StandardDeviation() 3468 self.run_test(model, x) 3469 3470 class StandardDeviationUnbiased(torch.nn.Module): 3471 def forward(self, input): 3472 return 
torch.std(input, dim=(0, 1), unbiased=True) 3473 3474 x = torch.randn(2, 3, 4) 3475 model = StandardDeviationUnbiased() 3476 self.run_test(model, x) 3477 3478 def test_std_keepdim(self): 3479 class StandardDeviation(torch.nn.Module): 3480 def forward(self, input): 3481 return torch.std(input, dim=(0, 1), unbiased=False, keepdim=True) 3482 3483 x = torch.randn(2, 3, 4) 3484 model = StandardDeviation() 3485 self.run_test(model, x) 3486 3487 class StandardDeviationUnbiased(torch.nn.Module): 3488 def forward(self, input): 3489 return torch.std(input, dim=(0, 1), unbiased=True, keepdim=True) 3490 3491 x = torch.randn(2, 3, 4) 3492 model = StandardDeviationUnbiased() 3493 self.run_test(model, x) 3494 3495 def test_std_correction(self): 3496 class StandardDeviation(torch.nn.Module): 3497 def forward(self, input): 3498 return torch.std(input, dim=(0, 1), correction=3, keepdim=True) 3499 3500 x = torch.randn(2, 3, 4) 3501 model = StandardDeviation() 3502 self.run_test(model, x) 3503 3504 def test_var(self): 3505 class Variance(torch.nn.Module): 3506 def forward(self, input): 3507 return torch.var(input, unbiased=False) 3508 3509 x = torch.randn(2, 3, 4) 3510 model = Variance() 3511 self.run_test(model, x) 3512 3513 class VarianceUnbiased(torch.nn.Module): 3514 def forward(self, input): 3515 return torch.var(input, unbiased=True) 3516 3517 model = VarianceUnbiased() 3518 self.run_test(model, x) 3519 3520 class VarianceSqrt(torch.nn.Module): 3521 def forward(self, input): 3522 y = torch.var(input, 1) 3523 return torch.sqrt(y + 1e-8) 3524 3525 x = torch.randn(1, 2, 3, 300, 300) 3526 model = VarianceSqrt() 3527 self.run_test(model, x) 3528 3529 def test_var_along_dims(self): 3530 class Variance(torch.nn.Module): 3531 def forward(self, input): 3532 return torch.var(input, dim=(0, 1), unbiased=False) 3533 3534 x = torch.randn(2, 3, 4) 3535 model = Variance() 3536 self.run_test(model, x) 3537 3538 class VarianceUnbiased(torch.nn.Module): 3539 def forward(self, input): 3540 return 
torch.var(input, dim=(0, 1), unbiased=True) 3541 3542 x = torch.randn(2, 3, 4) 3543 model = VarianceUnbiased() 3544 self.run_test(model, x) 3545 3546 def test_var_keepdim(self): 3547 class Variance(torch.nn.Module): 3548 def forward(self, input): 3549 return torch.var(input, dim=(0, 1), unbiased=False, keepdim=True) 3550 3551 x = torch.randn(2, 3, 4) 3552 model = Variance() 3553 self.run_test(model, x) 3554 3555 class VarianceUnbiased(torch.nn.Module): 3556 def forward(self, input): 3557 return torch.var(input, dim=(0, 1), unbiased=True, keepdim=True) 3558 3559 x = torch.randn(2, 3, 4) 3560 model = VarianceUnbiased() 3561 self.run_test(model, x) 3562 3563 def test_var_correction(self): 3564 class Variance(torch.nn.Module): 3565 def forward(self, input): 3566 return torch.var(input, dim=(0, 1), correction=3, keepdim=True) 3567 3568 x = torch.randn(2, 3, 4) 3569 model = Variance() 3570 self.run_test(model, x) 3571 3572 def test_var_mean(self): 3573 class Variance(torch.nn.Module): 3574 def forward(self, input): 3575 return torch.var_mean(input, unbiased=False) 3576 3577 x = torch.randn(2, 3, 4) 3578 model = Variance() 3579 self.run_test(model, x) 3580 3581 class VarianceUnbiased(torch.nn.Module): 3582 def forward(self, input): 3583 return torch.var_mean(input, unbiased=True) 3584 3585 model = VarianceUnbiased() 3586 self.run_test(model, x) 3587 3588 def test_var_mean_along_dims(self): 3589 class Variance(torch.nn.Module): 3590 def forward(self, input): 3591 return torch.var_mean(input, dim=(0, 1), unbiased=False) 3592 3593 x = torch.randn(2, 3, 4) 3594 model = Variance() 3595 self.run_test(model, x) 3596 3597 class VarianceUnbiased(torch.nn.Module): 3598 def forward(self, input): 3599 return torch.var_mean(input, dim=(0, 1), unbiased=True) 3600 3601 x = torch.randn(2, 3, 4) 3602 model = VarianceUnbiased() 3603 self.run_test(model, x) 3604 3605 def test_var_mean_mixed_dims(self): 3606 class ReverseDims(torch.nn.Module): 3607 def forward(self, input): 3608 return 
torch.var_mean(input, dim=(2, 1), unbiased=False) 3609 3610 x = torch.randn(2, 3, 4) 3611 model = ReverseDims() 3612 self.run_test(model, x) 3613 3614 class SkipDims(torch.nn.Module): 3615 def forward(self, input): 3616 return torch.var_mean(input, dim=(0, 2), unbiased=False) 3617 3618 x = torch.randn(2, 3, 4) 3619 model = SkipDims() 3620 self.run_test(model, x) 3621 3622 class NonZeroDims(torch.nn.Module): 3623 def forward(self, input): 3624 return torch.var_mean(input, dim=(1, 2), unbiased=False) 3625 3626 x = torch.randn(2, 3, 4) 3627 model = NonZeroDims() 3628 self.run_test(model, x) 3629 3630 def test_var_mean_keepdim(self): 3631 class Variance(torch.nn.Module): 3632 def forward(self, input): 3633 return torch.var_mean(input, dim=(0, 1), unbiased=False, keepdim=True) 3634 3635 x = torch.randn(2, 3, 4) 3636 model = Variance() 3637 self.run_test(model, x) 3638 3639 class VarianceUnbiased(torch.nn.Module): 3640 def forward(self, input): 3641 return torch.var_mean(input, dim=(0, 1), unbiased=True, keepdim=True) 3642 3643 x = torch.randn(2, 3, 4) 3644 model = VarianceUnbiased() 3645 self.run_test(model, x) 3646 3647 def test_var_mean_correction(self): 3648 class Variance(torch.nn.Module): 3649 def forward(self, input): 3650 return torch.var_mean(input, dim=(0, 1), correction=3, keepdim=True) 3651 3652 x = torch.randn(2, 3, 4) 3653 model = Variance() 3654 self.run_test(model, x) 3655 3656 def test_std_mean(self): 3657 class StandardDeviation(torch.nn.Module): 3658 def forward(self, input): 3659 return torch.std_mean(input, unbiased=False) 3660 3661 x = torch.randn(2, 3, 4) 3662 model = StandardDeviation() 3663 self.run_test(model, x) 3664 3665 class StandardDeviationUnbiased(torch.nn.Module): 3666 def forward(self, input): 3667 return torch.std_mean(input, unbiased=True) 3668 3669 model = StandardDeviationUnbiased() 3670 self.run_test(model, x) 3671 3672 def test_std_mean_along_dims(self): 3673 class StandardDeviation(torch.nn.Module): 3674 def forward(self, input): 
3675 return torch.std_mean(input, dim=(0, 1), unbiased=False) 3676 3677 x = torch.randn(2, 3, 4) 3678 model = StandardDeviation() 3679 self.run_test(model, x) 3680 3681 class VarianceUnbiased(torch.nn.Module): 3682 def forward(self, input): 3683 return torch.std_mean(input, dim=(0, 1), unbiased=True) 3684 3685 x = torch.randn(2, 3, 4) 3686 model = VarianceUnbiased() 3687 self.run_test(model, x) 3688 3689 def test_std_mean_keepdim(self): 3690 class StandardDeviation(torch.nn.Module): 3691 def forward(self, input): 3692 return torch.std_mean(input, dim=(0, 1), unbiased=False, keepdim=True) 3693 3694 x = torch.randn(2, 3, 4) 3695 model = StandardDeviation() 3696 self.run_test(model, x) 3697 3698 class StandardDeviationUnbiased(torch.nn.Module): 3699 def forward(self, input): 3700 return torch.std_mean(input, dim=(0, 1), unbiased=True, keepdim=True) 3701 3702 x = torch.randn(2, 3, 4) 3703 model = StandardDeviationUnbiased() 3704 self.run_test(model, x) 3705 3706 def test_std_mean_correction(self): 3707 class StandardDeviation(torch.nn.Module): 3708 def forward(self, input): 3709 return torch.var_mean(input, dim=(0, 1), correction=3, keepdim=True) 3710 3711 x = torch.randn(2, 3, 4) 3712 model = StandardDeviation() 3713 self.run_test(model, x) 3714 3715 def test_bitshift(self): 3716 class BitshiftModel(torch.nn.Module): 3717 def forward(self, input): 3718 return ( 3719 input >> 1, 3720 input << 3, 3721 input >> torch.tensor([1, 2]), 3722 input << 4, 3723 ) 3724 3725 input = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2) 3726 self.run_test(BitshiftModel(), input) 3727 3728 @skipIfUnsupportedMinOpsetVersion(18) 3729 def test_bitwise_and(self): 3730 class BitwiseAndModel(torch.nn.Module): 3731 def forward(self, input, other): 3732 return ( 3733 input & 20, 3734 torch.bitwise_and(input, other), 3735 other & torch.tensor([1, 2], dtype=torch.int32), 3736 ) 3737 3738 input = torch.randint(0, 255, (3, 4, 2), dtype=torch.uint8) 3739 other = torch.randint(-128, 127, (3, 4, 
2), dtype=torch.int8) 3740 self.run_test(BitwiseAndModel(), (input, other)) 3741 3742 # uint8 not implemented in ORT for Mul used in 3743 # exporting bitshift for opset_version < 10 3744 @skipIfUnsupportedMinOpsetVersion(11) 3745 def test_bitshift_uint8(self): 3746 class BitshiftModel(torch.nn.Module): 3747 def forward(self, input, input2): 3748 return ( 3749 input >> 1, 3750 input << 3, 3751 input2 >> torch.tensor([1, 2], dtype=torch.uint8), 3752 input2 << 4, 3753 ) 3754 3755 input = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2) 3756 input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2) 3757 self.run_test(BitshiftModel(), (input, input2)) 3758 3759 def test_narrow(self): 3760 class NarrowModel(torch.nn.Module): 3761 def forward(self, input): 3762 return torch.narrow(input, 0, 0, 2) 3763 3764 x = torch.randn(3, 3, requires_grad=True) 3765 self.run_test(NarrowModel(), x) 3766 3767 @skipIfUnsupportedMinOpsetVersion(11) 3768 def test_narrow_dynamic(self): 3769 class NarrowModel(torch.nn.Module): 3770 def forward(self, input): 3771 return torch.narrow(input, 0, 0, input.shape[0] - 1) 3772 3773 x = torch.randn(3, 3, requires_grad=True) 3774 self.run_test(NarrowModel(), x) 3775 3776 @skipIfUnsupportedMinOpsetVersion(9) 3777 def test_index_fill(self): 3778 class IndexFillModel(torch.nn.Module): 3779 def forward(self, input): 3780 index = torch.tensor([2, 0]) 3781 return input.index_fill(2, index, -1) 3782 3783 x = torch.randn(3, 4, 5, requires_grad=True) 3784 self.run_test(IndexFillModel(), x) 3785 3786 @skipIfUnsupportedMinOpsetVersion(9) 3787 def test_index_copy(self): 3788 class IndexCopyModel(torch.nn.Module): 3789 def __init__(self, dim): 3790 super().__init__() 3791 self.dim = dim 3792 3793 def forward(self, input): 3794 index = torch.tensor([2, 0]) 3795 source = torch.ones(3, 2, 5) 3796 return input.index_copy(self.dim, index, source) 3797 3798 x = torch.randn(3, 4, 5, requires_grad=True) 3799 for dim in (1, -2): 3800 
self.run_test(IndexCopyModel(dim), x) 3801 3802 def test_select(self): 3803 class Select(torch.nn.Module): 3804 def forward(self, x): 3805 return x[:, 1] 3806 3807 x = torch.randn(3, 4) 3808 self.run_test(Select(), x) 3809 3810 def test_select_negative_index(self): 3811 class Select(torch.nn.Module): 3812 def forward(self, x): 3813 return x[:, -1] 3814 3815 x = torch.randn(3, 4) 3816 self.run_test(Select(), x) 3817 3818 def test_index_select_constant_scaler_index(self): 3819 class IndexSelectScalerIndexModel(torch.nn.Module): 3820 def forward(self, x): 3821 index = 2 3822 return torch.index_select(x, 1, torch.tensor(index)) 3823 3824 x = torch.randn(3, 4) 3825 self.run_test(IndexSelectScalerIndexModel(), x) 3826 3827 def test_index_select_scaler_index(self): 3828 class IndexSelectScalerIndexModel(torch.nn.Module): 3829 def __init__(self, index_base): 3830 super().__init__() 3831 self.index_base = torch.tensor(index_base) 3832 3833 def forward(self, x, index_offset): 3834 index = self.index_base + index_offset 3835 return torch.index_select(x, 1, index) 3836 3837 x = torch.randn(3, 4) 3838 offset = 2 3839 index_offset = torch.tensor(offset) 3840 base = 1 3841 self.run_test(IndexSelectScalerIndexModel(base), (x, index_offset)) 3842 3843 def test_take(self): 3844 class TakeModel(torch.nn.Module): 3845 def forward(self, x, y): 3846 return torch.take(x, y) 3847 3848 x = torch.randn(6, 4, 3, 3) 3849 y = torch.tensor([4, 1, 7, 15, 63]) 3850 self.run_test(TakeModel(), (x, y)) 3851 3852 def test_topk(self): 3853 class MyModule(torch.nn.Module): 3854 def forward(self, x): 3855 return torch.topk(x, 3) 3856 3857 x = torch.arange(1.0, 6.0, requires_grad=True) 3858 self.run_test(MyModule(), x) 3859 3860 @skipIfUnsupportedMinOpsetVersion(10) 3861 def test_topk_int32_k(self): 3862 class Model(torch.nn.Module): 3863 def forward(self, x, k): 3864 return torch.topk(x, k) 3865 3866 x = torch.arange(1.0, 6.0) 3867 k = torch.tensor(3, dtype=torch.int32) 3868 self.run_test(Model(), (x, 
k)) 3869 3870 @skipIfUnsupportedMinOpsetVersion(11) 3871 def test_topk_smallest_unsorted(self): 3872 class MyModule(torch.nn.Module): 3873 def forward(self, x, k): 3874 # When sorted=False, order of elements in the outout tensors 3875 # are not expected to match between PyTorch and ORT 3876 topk_unsorted = torch.topk(x, k, largest=False, sorted=False) 3877 topk_sorted = torch.topk(x, k, largest=False, sorted=True) 3878 return topk_sorted, torch.sort(topk_unsorted.values).values 3879 3880 x = torch.arange(1.0, 6.0, requires_grad=True) 3881 k = torch.tensor(3) 3882 self.run_test(MyModule(), (x, k)) 3883 3884 @skipIfUnsupportedMinOpsetVersion(10) 3885 def test_topk_script(self): 3886 class MyModuleDynamic(torch.jit.ScriptModule): 3887 @torch.jit.script_method 3888 def forward(self, x, k): 3889 return torch.topk(x, k) 3890 3891 x = torch.arange(1.0, 6.0, requires_grad=True) 3892 k = torch.tensor(3) 3893 self.run_test(MyModuleDynamic(), (x, k)) 3894 3895 @skipScriptTest() # Python builtin apply of FunctionMeta object is currently not supported in Torchscript. 3896 @skipIfUnsupportedMinOpsetVersion(11) # Clip op min is an input since opset 11. 
3897 def test_auto_grad(self): 3898 class MyClip(torch.autograd.Function): 3899 @staticmethod 3900 def forward(ctx, input, scalar): 3901 ctx.save_for_backward(input) 3902 return input.clamp(min=scalar) 3903 3904 class MyRelu(torch.autograd.Function): 3905 @staticmethod 3906 def forward(ctx, input): 3907 ctx.save_for_backward(input) 3908 return input.clamp(min=0) 3909 3910 def symbolic_python_op(g, *args, **kwargs): 3911 name = kwargs["name"] 3912 if name == "MyClip": 3913 return g.op("Clip", args[0], args[1]) 3914 elif name == "MyRelu": 3915 return g.op("Relu", args[0]) 3916 else: 3917 # TODO(justinchuby): Remove reference to internal names in symbolic_helper 3918 return torch.onnx.symbolic_helper._unimplemented( 3919 "prim::PythonOp", "unknown node kind: " + name 3920 ) 3921 3922 torch.onnx.register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1) 3923 self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "prim::PythonOp", 1) 3924 3925 class MyClipModule(torch.nn.Module): 3926 def forward(self, x, min): 3927 return MyClip.apply(x, min) 3928 3929 x = torch.randn(3, 3) 3930 min = torch.tensor([0.0]) 3931 self.run_test(MyClipModule(), (x, min)) 3932 3933 class MyReluModule(torch.nn.Module): 3934 def forward(self, x): 3935 return MyRelu.apply(x) 3936 3937 x = torch.randn(3, 3) 3938 self.run_test(MyReluModule(), x) 3939 3940 def test_clip_int(self): 3941 class MyClipInt(torch.nn.Module): 3942 def forward(self, x): 3943 return torch.clamp(x, 0, 1) 3944 3945 self.run_test(MyClipInt(), torch.randn(3, 3).to(torch.int64)) 3946 3947 def test_relu_int(self): 3948 self.run_test(torch.nn.ReLU(), torch.randn(3, 3).to(torch.int32)) 3949 3950 def test_pad_int(self): 3951 class MyPadInt(torch.nn.Module): 3952 def forward(self, x): 3953 return torch.nn.functional.pad(x, (1, 1)) 3954 3955 self.run_test(MyPadInt(), torch.randn(3, 3).to(torch.int32)) 3956 3957 def test_min_int(self): 3958 class MyMinInt(torch.nn.Module): 3959 def forward(self, x): 3960 return 
torch.min(x, x + 1) 3961 3962 self.run_test(MyMinInt(), torch.randn(3, 3).to(torch.int32)) 3963 3964 def test_max_int(self): 3965 class MyMaxnInt(torch.nn.Module): 3966 def forward(self, x): 3967 return torch.max(x, x + 1) 3968 3969 self.run_test(MyMaxnInt(), torch.randn(3, 3).to(torch.int32)) 3970 3971 @skipIfUnsupportedOpsetVersion([7]) 3972 def test_normalize(self): 3973 class Model(torch.nn.Module): 3974 def forward(self, x): 3975 return torch.nn.functional.normalize(x) 3976 3977 x = torch.randn(3, 3) 3978 self.run_test(Model(), x) 3979 3980 def test_norm_with_dtype(self): 3981 class Model(torch.nn.Module): 3982 def forward(self, x): 3983 # TODO(bowbao): There is a slight gap in today's test infrastructure 3984 # to directly test aten ops. OpInfo `torch.norm`` in `common_methods_invocations.py` 3985 # will not decompose to below aten op. 3986 return torch.ops.aten.norm( 3987 x, p=2, dim=[1], keepdim=True, dtype=torch.float64 3988 ) 3989 3990 x = torch.randn(3, 3) 3991 self.run_test(Model(), x) 3992 3993 def test_layer_norm(self): 3994 # As layer_norm works on the last D dimension, please keep 3995 # this test case at least three dimension to prevent the 3996 # situation of axis=2 mapping to the same axis as axis=-2 3997 for elementwise_affine in (True, False): 3998 for bias in (True, False): 3999 model = torch.nn.LayerNorm( 4000 [10, 10, 10], elementwise_affine=elementwise_affine, bias=bias 4001 ) 4002 x = torch.randn(20, 5, 10, 10, 10) 4003 self.run_test(model, x) 4004 4005 def test_batchnorm1d(self): 4006 x = torch.randn(10, 10) 4007 model = torch.nn.BatchNorm1d(10, affine=True) 4008 self.run_test(model, x) 4009 4010 x = torch.randn(10, 10, 128) 4011 self.run_test(model, x) 4012 4013 def test_batchnorm1d_noaffine(self): 4014 x = torch.randn(10, 10) 4015 model = torch.nn.BatchNorm1d(10, affine=False) 4016 self.run_test(model, x) 4017 4018 x = torch.randn(10, 10, 128) 4019 self.run_test(model, x) 4020 4021 def test_batchnorm1d_norunningstats(self): 4022 x = 
torch.randn(10, 10) 4023 model = torch.nn.BatchNorm1d(10, track_running_stats=False) 4024 self.run_test(model, x) 4025 4026 x = torch.randn(10, 10, 128) 4027 self.run_test(model, x) 4028 4029 def test_batchnorm2d(self): 4030 x = torch.randn(10, 3, 128, 128) 4031 model = torch.nn.BatchNorm2d(3, affine=True) 4032 self.run_test(model, x) 4033 4034 def test_batchnorm2d_noaffine(self): 4035 x = torch.randn(10, 3, 128, 128) 4036 model = torch.nn.BatchNorm2d(3, affine=False) 4037 self.run_test(model, x) 4038 4039 def test_batchnorm2d_norunningstats(self): 4040 x = torch.randn(10, 3, 128, 128) 4041 model = torch.nn.BatchNorm2d(3, track_running_stats=False) 4042 self.run_test(model, x) 4043 4044 def test_batchnorm3d(self): 4045 x = torch.randn(10, 3, 64, 64, 64) 4046 model = torch.nn.BatchNorm3d(3, affine=True) 4047 self.run_test(model, x) 4048 4049 def test_batchnorm3d_noaffine(self): 4050 x = torch.randn(10, 3, 64, 64, 64) 4051 model = torch.nn.BatchNorm3d(3, affine=False) 4052 self.run_test(model, x) 4053 4054 @skipIfUnsupportedMinOpsetVersion( 4055 9 4056 ) # Because ConstantOfShape op is not supported for opset < 9 4057 def test_instancenorm1d_runningstats(self): 4058 x = torch.randn(10, 5, 128) 4059 model = torch.nn.InstanceNorm1d(5, affine=True, track_running_stats=True) 4060 self.run_test(model, x) 4061 4062 model = torch.nn.InstanceNorm1d(5, affine=False, track_running_stats=True) 4063 self.run_test(model, x) 4064 4065 def test_instancenorm1d_norunningstats(self): 4066 x = torch.randn(10, 5, 128) 4067 model = torch.nn.InstanceNorm1d(5, affine=True, track_running_stats=False) 4068 self.run_test(model, x) 4069 4070 model = torch.nn.InstanceNorm1d(5, affine=False, track_running_stats=False) 4071 self.run_test(model, x) 4072 4073 @skipIfUnsupportedMinOpsetVersion( 4074 9 4075 ) # Because ConstantOfShape op is not supported for opset < 9 4076 def test_instancenorm2d_runningstats(self): 4077 x = torch.randn(10, 3, 128, 128) 4078 model = torch.nn.InstanceNorm2d(3, 
affine=True, track_running_stats=True) 4079 self.run_test(model, x) 4080 4081 model = torch.nn.InstanceNorm2d(3, affine=False, track_running_stats=True) 4082 self.run_test(model, x) 4083 4084 def test_instancenorm2d_norunningstats(self): 4085 x = torch.randn(10, 3, 128, 128) 4086 model = torch.nn.InstanceNorm2d(3, affine=True, track_running_stats=False) 4087 self.run_test(model, x) 4088 4089 model = torch.nn.InstanceNorm2d(3, affine=False, track_running_stats=False) 4090 self.run_test(model, x) 4091 4092 @skipIfUnsupportedMinOpsetVersion( 4093 9 4094 ) # Because ConstantOfShape op is not supported for opset < 9 4095 def test_instancenorm3d_runningstats(self): 4096 x = torch.randn(10, 3, 64, 64, 64) 4097 model = torch.nn.InstanceNorm3d(3, affine=True, track_running_stats=True) 4098 self.run_test(model, x) 4099 4100 model = torch.nn.InstanceNorm3d(3, affine=False, track_running_stats=True) 4101 self.run_test(model, x) 4102 4103 def test_instancenorm3d_norunningstats(self): 4104 x = torch.randn(10, 3, 64, 64, 64) 4105 model = torch.nn.InstanceNorm3d(3, affine=True, track_running_stats=False) 4106 self.run_test(model, x) 4107 4108 model = torch.nn.InstanceNorm3d(3, affine=False, track_running_stats=False) 4109 self.run_test(model, x) 4110 4111 @skipIfUnsupportedMinOpsetVersion(9) 4112 def test_scatter_with_scalar(self): 4113 class ScatterModel(torch.nn.Module): 4114 def forward(self, input, indices): 4115 values = 1.0 4116 return input.scatter(1, indices, values) 4117 4118 input = torch.tensor( 4119 [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float64 4120 ) 4121 indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64) 4122 self.run_test(ScatterModel(), input_args=(input, indices)) 4123 4124 @skipIfUnsupportedMinOpsetVersion(9) 4125 def test_scatter_with_scalar_different_types(self): 4126 # Tests the case when scalar src (updates values) type is different 4127 # from self type. 
Happens only with scalar src - PyTorch does not 4128 # allow this when src is a tensor. 4129 class ScatterModel(torch.nn.Module): 4130 def forward(self, input, indices): 4131 values = 1.0 4132 return input.scatter(1, indices, values) 4133 4134 input = torch.tensor( 4135 [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float32 4136 ) 4137 indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64) 4138 self.run_test(ScatterModel(), input_args=(input, indices)) 4139 4140 @skipIfUnsupportedMinOpsetVersion(9) 4141 def test_scatter(self): 4142 class ScatterModel(torch.nn.Module): 4143 def forward(self, input, indices, values): 4144 return input.scatter(1, indices, values) 4145 4146 input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) 4147 indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64) 4148 values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]) 4149 self.run_test(ScatterModel(), input_args=(input, indices, values)) 4150 4151 input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) 4152 indices = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64) 4153 values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]) 4154 self.run_test(ScatterModel(), (input, indices, values)) 4155 4156 input = torch.zeros(3, 4, 5, 6) 4157 indices = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64) 4158 indices = indices.view(3, 2, 1, 1).expand(3, 2, 5, 6) 4159 values = torch.arange(3 * 2 * 5 * 6, dtype=torch.float32).view(3, 2, 5, 6) 4160 self.run_test(ScatterModel(), (input, indices, values)) 4161 4162 input = torch.zeros(3, 4, 2) 4163 indices = torch.tensor([[[1, 0], [0, 2]], [[1, 1], [0, 1]], [[2, 1], [2, 2]]]) 4164 values = torch.arange(3 * 2 * 2, dtype=torch.float32).view(3, 2, 2) 4165 self.run_test(ScatterModel(), (input, indices, values)) 4166 4167 @skipIfUnsupportedMinOpsetVersion(9) 4168 def test_scatter_add(self): 4169 class ScatterModel(torch.nn.Module): 4170 def forward(self, 
input, indices, values): 4171 return input.scatter_add(1, indices, values) 4172 4173 input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) 4174 indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64) 4175 values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]) 4176 self.run_test(ScatterModel(), input_args=(input, indices, values)) 4177 4178 @torch.jit.script 4179 def scatter_sum(src: Tensor, index: Tensor): 4180 size = src.size() 4181 out = torch.zeros(size, dtype=src.dtype) 4182 return out.scatter_add_(1, index, src) 4183 4184 class ScatterModel(torch.nn.Module): 4185 def forward(self, src, index): 4186 return scatter_sum(src, index) 4187 4188 src = torch.rand(3, 2) 4189 index = torch.tensor([[0, 1], [0, 1], [0, 1]], dtype=torch.int64) 4190 self.run_test(ScatterModel(), (src, index)) 4191 4192 @skipIfUnsupportedMinOpsetVersion(16) 4193 def test_scatter_add_index_not_unique(self): 4194 class ScatterModel(torch.nn.Module): 4195 def forward(self, input, indices, values): 4196 return input.scatter_add(1, indices, values) 4197 4198 input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) 4199 indices = torch.tensor([[0, 0], [1, 1], [2, 2]], dtype=torch.int64) 4200 values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]) 4201 self.run_test(ScatterModel(), input_args=(input, indices, values)) 4202 4203 @torch.jit.script 4204 def scatter_sum(src: Tensor, index: Tensor): 4205 size = src.size() 4206 out = torch.zeros(size, dtype=src.dtype) 4207 return out.scatter_add_(1, index, src) 4208 4209 class ScatterModel(torch.nn.Module): 4210 def forward(self, src, index): 4211 return scatter_sum(src, index) 4212 4213 src = torch.rand(3, 2) 4214 index = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64) 4215 self.run_test(ScatterModel(), (src, index)) 4216 4217 @skipIfUnsupportedMinOpsetVersion(16) 4218 def test_scatter_add_different_size_index_src(self): 4219 class ScatterModel(torch.nn.Module): 4220 def forward(self, 
input, indices, src): 4221 return input.scatter_add(0, indices, src) 4222 4223 src = torch.ones((2, 5)) 4224 input = torch.zeros(3, 5, dtype=src.dtype) 4225 indices = torch.tensor([[0, 1, 2, 0, 0]]) 4226 self.run_test(ScatterModel(), input_args=(input, indices, src)) 4227 4228 @common_utils.parametrize( 4229 "src, indices", 4230 [ 4231 common_utils.subtest( 4232 [torch.ones((1, 5)), torch.tensor([[0, 1, 2, 0, 0]])], 4233 name="src_indices_dynamic_combination1", 4234 ), 4235 common_utils.subtest( 4236 [torch.ones((2, 5)), torch.tensor([[0, 1, 2, 0, 0], [1, 0, 2, 1, 2]])], 4237 name="src_indices_dynamic_combination2", 4238 ), 4239 common_utils.subtest( 4240 [torch.ones((3, 5)), torch.tensor([[0, 1, 2, 0, 0], [1, 0, 2, 1, 2]])], 4241 name="src_indices_dynamic_combination3", 4242 ), 4243 common_utils.subtest( 4244 [torch.ones((3, 5)), torch.tensor([[0, 1, 2, 0], [1, 0, 2, 1]])], 4245 name="src_indices_dynamic_combination4", 4246 ), 4247 ], 4248 ) 4249 @skipIfUnsupportedMinOpsetVersion(16) 4250 def test_scatter_add_dynamic_index(self, src, indices): 4251 class ScatterModel(torch.nn.Module): 4252 def forward(self, input, indices, src): 4253 return input.scatter_add(0, indices, src) 4254 4255 input = torch.zeros(3, 5, dtype=src.dtype) 4256 self.run_test( 4257 ScatterModel(), 4258 input_args=(input, indices, src), 4259 input_names=["input", "indices", "src"], 4260 dynamic_axes={"indices": {0: "a", 1: "b"}, "src": {0: "c", 1: "d"}}, 4261 ) 4262 4263 @skipIfUnsupportedMinOpsetVersion(16) 4264 def test_scatter_reduce(self): 4265 class Model(torch.nn.Module): 4266 def __init__(self) -> None: 4267 super().__init__() 4268 4269 def forward(self, x, index, input): 4270 y_max = input.scatter_reduce(0, index, x, reduce="amax") 4271 y_sum = input.scatter_reduce(0, index, x, reduce="sum") 4272 y_min = input.scatter_reduce(0, index, x, reduce="amin") 4273 y_mul = input.scatter_reduce(0, index, x, reduce="prod") 4274 return y_max, y_sum, y_min, y_mul 4275 4276 model = Model() 4277 
model.eval() 4278 4279 src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 4280 index = torch.tensor([0, 1, 0, 1, 2, 1]) 4281 input = torch.tensor([1.0, 2.0, 3.0, 8.0]) 4282 4283 self.run_test(model, (src, index, input)) 4284 4285 @skipIfUnsupportedMinOpsetVersion(16) 4286 def test_scatter_reduce_self_rank_zero(self): 4287 class Model(torch.nn.Module): 4288 def __init__(self) -> None: 4289 super().__init__() 4290 4291 def forward(self, x, index, input): 4292 y_max = input.scatter_reduce(0, index, x, reduce="amax") 4293 y_sum = input.scatter_reduce(0, index, x, reduce="sum") 4294 y_min = input.scatter_reduce(0, index, x, reduce="amin") 4295 y_mul = input.scatter_reduce(0, index, x, reduce="prod") 4296 return y_max, y_sum, y_min, y_mul 4297 4298 model = Model() 4299 model.eval() 4300 4301 empty_tensor = torch.tensor([]) 4302 empty_idx = torch.tensor([], dtype=torch.int64) 4303 4304 self.run_test(model, (empty_tensor, empty_idx, empty_tensor)) 4305 4306 @skipIfUnsupportedMinOpsetVersion(9) 4307 def test_bucketize(self): 4308 class BucketModel(torch.nn.Module): 4309 def forward(self, input, boundaries): 4310 return torch.bucketize(input, boundaries), torch.bucketize( 4311 input, boundaries, right=True 4312 ) 4313 4314 input = torch.tensor([[2, 5, 10], [6, 8, 3]]) 4315 boundaries = torch.tensor([1, 5, 7, 8, 10]) 4316 self.run_test(BucketModel(), (input, boundaries)) 4317 4318 @skipIfUnsupportedMinOpsetVersion(9) 4319 def test_one_hot(self): 4320 class OneHot(torch.nn.Module): 4321 def __init__(self, num_classes): 4322 super().__init__() 4323 self.num_classes = num_classes 4324 4325 def forward(self, x): 4326 return torch.nn.functional.one_hot(x, self.num_classes) 4327 4328 x = torch.arange(10) 4329 self.run_test(OneHot(15), (x)) 4330 4331 class OneHot(torch.nn.Module): 4332 def forward(self, x, num_classes): 4333 num_classes = num_classes.to(torch.int32) 4334 return torch.nn.functional.one_hot(x, num_classes[0]) 4335 4336 x = torch.arange(10) 4337 num_classes = 15 * 
torch.ones(1) 4338 self.run_test(OneHot(), (x, num_classes)) 4339 4340 @skipIfUnsupportedMinOpsetVersion(9) 4341 def test_gather(self): 4342 class GatherModel(torch.nn.Module): 4343 def forward(self, input, indices): 4344 return input.gather(1, indices) 4345 4346 input = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) 4347 indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64) 4348 self.run_test(GatherModel(), input_args=(input, indices)) 4349 4350 @skipScriptTest() # Scripting error: Cannot instantiate nn module 4351 def test_gather_constant_fold(self): 4352 class GatherModule(torch.nn.Module): 4353 def __init__(self) -> None: 4354 super().__init__() 4355 self.weight = torch.nn.Buffer(torch.ones(5)) 4356 # torch.nn.Embedding is converted to ONNX::Gather. 4357 # Constant folding will be triggerred for constant inputs. 4358 # This pattern is common for constant mask inputs in transformer models. 4359 self.embed = torch.nn.Embedding(8, 3) 4360 4361 def forward(self, x): 4362 # shape is of rank 0 4363 shape = self.weight.shape[0] 4364 m = 5 - shape 4365 y = torch.ones(1, 4, dtype=torch.long) 4366 return x.clamp(min=m), self.embed(y) 4367 4368 x = torch.randn(1) 4369 self.run_test(GatherModule(), (x,)) 4370 4371 class GatherModule(torch.nn.Module): 4372 def __init__(self) -> None: 4373 super().__init__() 4374 self.weight = torch.nn.Buffer(torch.ones(2)) 4375 4376 def forward(self, x): 4377 # shape is of rank 0 4378 shape = self.weight.shape[0] 4379 pad = [1, shape, shape, shape] 4380 zero_pad = torch.nn.ZeroPad2d(pad) 4381 return zero_pad(x) 4382 4383 x = torch.randn(1, 3, 2) 4384 self.run_test(GatherModule(), (x,)) 4385 4386 class GatherModule(torch.nn.Module): 4387 def __init__(self) -> None: 4388 super().__init__() 4389 self.rb = torch.nn.Buffer(torch.randn(1, 1, 3, 1, 1)) 4390 4391 def forward(self, x): 4392 x += self.rb[0] 4393 return x 4394 4395 x = torch.randn(1, 3, 224, 224) 4396 self.run_test( 4397 GatherModule(), 4398 (x,), 4399 
dynamic_axes={ 4400 "input": {0: "batch", 2: "height", 3: "width"}, 4401 "output": {0: "batch", 1: "class", 2: "height", 3: "width"}, 4402 }, 4403 input_names=["input"], 4404 output_names=["output"], 4405 ) 4406 4407 @skipIfUnsupportedOpsetVersion([13]) 4408 @skipIfUnsupportedMinOpsetVersion(9) 4409 def test_expand(self): 4410 class ExpandModel(torch.nn.Module): 4411 def forward(self, input): 4412 return input.expand(2, 3, -1) 4413 4414 input = torch.randn(2, 1, 4) 4415 self.run_test(ExpandModel(), input_args=(input)) 4416 4417 class ExpandInferDimModel(torch.nn.Module): 4418 def forward(self, input): 4419 return input.expand(-1, input.size(0)) 4420 4421 input = torch.randn(3, 1) 4422 self.run_test(ExpandInferDimModel(), input_args=(input)) 4423 4424 class ExpandTensorSizeModel(torch.nn.Module): 4425 def forward(self, input, size): 4426 return input.expand(size) 4427 4428 input = torch.randn( 4429 3, 4430 ) 4431 size = torch.tensor(-1) 4432 self.run_test(ExpandTensorSizeModel(), input_args=(input, size)) 4433 4434 @skipIfUnsupportedMinOpsetVersion(11) # index_put is supported in opsets >= 11 4435 def test_dynamic_expand_as(self): 4436 class Model(torch.nn.Module): 4437 def forward(self, x): 4438 x[:, x.size(0) :] = 0 4439 return x 4440 4441 x = torch.ones(2, 5) 4442 x2 = torch.randn(3, 4) 4443 self.run_test( 4444 Model(), 4445 (x,), 4446 input_names=["x"], 4447 dynamic_axes={"x": [0, 1]}, 4448 additional_test_inputs=[x2], 4449 ) 4450 4451 class Model(torch.nn.Module): 4452 def forward(self, x): 4453 x[:, x.size(0) :] = torch.tensor([1, 2, 3]) 4454 return x 4455 4456 x = torch.ones(2, 5, 3) 4457 x2 = torch.randn(3, 4, 3) 4458 self.run_test( 4459 Model(), 4460 (x,), 4461 input_names=["x"], 4462 dynamic_axes={"x": [0, 1, 2]}, 4463 additional_test_inputs=[x2], 4464 ) 4465 4466 class Model(torch.nn.Module): 4467 def forward(self, x): 4468 aa = torch.tensor([[0], [1], [2]]) 4469 return aa.expand_as(x) 4470 4471 x = torch.ones(3, 2) 4472 x2 = torch.randn(3, 5) 4473 
self.run_test( 4474 Model(), 4475 (x,), 4476 input_names=["x"], 4477 dynamic_axes={"x": [0, 1]}, 4478 additional_test_inputs=[x2], 4479 ) 4480 4481 def test_multinomial(self): 4482 class Multinomial(torch.nn.Module): 4483 def forward(self, weight): 4484 return torch.multinomial(weight, 3, replacement=True) 4485 4486 class MultinomialNoReplacement(torch.nn.Module): 4487 def forward(self, weight): 4488 return torch.multinomial(weight, 1) 4489 4490 weight = torch.tensor([[0, 10, 0, 0], [0, 0, 100, 0]], dtype=torch.float) 4491 self.run_test(Multinomial(), (weight,)) 4492 self.run_test(MultinomialNoReplacement(), (weight,)) 4493 4494 def _test_reduced_ops(self, op): 4495 class ReducedOpModule(torch.nn.Module): 4496 def forward(self, input): 4497 return op(input, dim=-1) 4498 4499 if op != torch.mean: # torch.mean only supports float types 4500 x = torch.randint(10, (4, 4), dtype=torch.uint8) 4501 self.run_test(ReducedOpModule(), x) 4502 4503 x = torch.randint(10, (4, 4), dtype=torch.int8) 4504 self.run_test(ReducedOpModule(), x) 4505 4506 x = torch.randint(10, (4, 4), dtype=torch.int16) 4507 self.run_test(ReducedOpModule(), x) 4508 4509 x = torch.randint(10, (4, 4), dtype=torch.int32) 4510 self.run_test(ReducedOpModule(), x) 4511 4512 x = torch.randint(10, (4, 4), dtype=torch.int64) 4513 self.run_test(ReducedOpModule(), x) 4514 4515 # torch.mean only supports float types 4516 # ORT does not support double ReduceProd for double 4517 if op != torch.prod and op != torch.mean: 4518 x = torch.randn(4, 5, dtype=torch.double) 4519 self.run_test(ReducedOpModule(), x) 4520 4521 if op != torch.prod: # torch.prod not implemented for Half 4522 x = torch.randn(4, 4, dtype=torch.half) 4523 self.run_test(ReducedOpModule(), x) 4524 4525 x = torch.randn(4, 5, dtype=torch.float) 4526 self.run_test(ReducedOpModule(), x) 4527 4528 def test_reduced_sum(self): 4529 return self._test_reduced_ops(op=torch.sum) 4530 4531 def test_reduced_mean(self): 4532 return 
self._test_reduced_ops(op=torch.mean) 4533 4534 def test_reduced_prod(self): 4535 return self._test_reduced_ops(op=torch.prod) 4536 4537 def test_reduced_sum_dtypes(self): 4538 class NoDimModel(torch.nn.Module): 4539 def forward(self, input): 4540 return input.sum(dtype=torch.float) 4541 4542 class DimModel(torch.nn.Module): 4543 def forward(self, input): 4544 return input.sum(dim=-1, dtype=torch.float) 4545 4546 input = torch.randn((4, 4), dtype=torch.half) 4547 self.run_test(NoDimModel(), input) 4548 self.run_test(DimModel(), input) 4549 4550 def test_reduced_min_max(self): 4551 class ReducedMinMaxModule(torch.nn.Module): 4552 def forward(self, input): 4553 return torch.min(input, dim=-1)[0], torch.max(input, dim=0)[0] 4554 4555 x = torch.randint(10, (4, 4), dtype=torch.int32) 4556 self.run_test(ReducedMinMaxModule(), x) 4557 4558 x = torch.randint(10, (4, 4), dtype=torch.int64) 4559 self.run_test(ReducedMinMaxModule(), x) 4560 4561 x = torch.randn(4, 5, dtype=torch.float) 4562 self.run_test(ReducedMinMaxModule(), x) 4563 4564 def test_reduce_log_sum_exp(self): 4565 class ReduceLogSumExpModel(torch.nn.Module): 4566 def forward(self, input): 4567 a = torch.logsumexp(input, dim=0) 4568 b = torch.logsumexp(input, dim=(0, 1)) 4569 return a + b 4570 4571 x = torch.randn(4, 4, requires_grad=True) 4572 self.run_test(ReduceLogSumExpModel(), x) 4573 4574 def test_softmax(self): 4575 for i in range(-4, 3): 4576 model = torch.nn.Softmax(dim=i) 4577 input = torch.randn(3, 4, 5, 6) 4578 self.run_test(model, input) 4579 4580 class SoftmaxUnknownRank(torch.nn.Module): 4581 def __init__(self, i): 4582 super().__init__() 4583 self.softmax = torch.nn.Softmax(dim=i) 4584 4585 def forward(self, x): 4586 return self.softmax(x.reshape(3, 4, 5, 6)) 4587 4588 model = torch.jit.script(SoftmaxUnknownRank(i)) 4589 self.run_test(model, input) 4590 4591 def test_softmax_large_values(self): 4592 input = torch.tensor( 4593 [[-1e12, -1e12, -1e12], [1e12, 0.0, -5.0], [3.0, 4.0, 5.0]] 4594 ) 4595 
for i in range(-2, 1): 4596 model = torch.nn.Softmax(dim=i) 4597 self.run_test(model, input) 4598 4599 class SoftmaxUnknownRank(torch.nn.Module): 4600 def __init__(self, i): 4601 super().__init__() 4602 self.softmax = torch.nn.Softmax(dim=i) 4603 4604 def forward(self, x): 4605 return self.softmax(x.reshape(3, 3)) 4606 4607 model = torch.jit.script(SoftmaxUnknownRank(i)) 4608 self.run_test(model, input) 4609 4610 def test_logsoftmax(self): 4611 for i in range(7)[2:]: 4612 model = torch.nn.LogSoftmax(dim=i - 1) 4613 dims = [2] * (i - 2) + [3, 4] 4614 input = torch.ones(*dims, requires_grad=True) 4615 self.run_test(model, input) 4616 4617 def test_logsoftmax_dim(self): 4618 for i in range(-4, 3): 4619 model = torch.nn.LogSoftmax(dim=i) 4620 input = torch.randn(3, 4, 5, 6) 4621 self.run_test(model, input) 4622 4623 def test_logsoftmax_dtype(self): 4624 class Model(torch.nn.Module): 4625 def forward(self, x): 4626 return torch.nn.functional.log_softmax(x, dim=1, dtype=torch.float64) 4627 4628 x = torch.randn(3, 4, 5, requires_grad=True) 4629 self.run_test(Model(), x) 4630 4631 def test_softplus(self): 4632 class BetaOneModel(torch.nn.Module): 4633 def forward(self, x): 4634 return torch.nn.functional.softplus(x) 4635 4636 x = torch.randn(3, 4, 5, requires_grad=True) 4637 self.run_test(BetaOneModel(), x) 4638 4639 class BetaModel(torch.nn.Module): 4640 def forward(self, x): 4641 return torch.nn.functional.softplus(x, beta=2) 4642 4643 x = torch.randn(3, 4, 5, requires_grad=True) 4644 self.run_test(BetaModel(), x) 4645 4646 class BetaFloatModel(torch.nn.Module): 4647 def forward(self, x): 4648 return torch.nn.functional.softplus(x, beta=1.7) 4649 4650 x = torch.randn(3, 4, 5, requires_grad=True) 4651 self.run_test(BetaFloatModel(), x) 4652 4653 @skipIfUnsupportedMinOpsetVersion(9) 4654 def test_lstm_no_hidden(self): 4655 class LSTMModel(torch.nn.Module): 4656 def __init__(self) -> None: 4657 super().__init__() 4658 self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16) 
4659 4660 def forward(self, x): 4661 return self.rnn(x) 4662 4663 input = torch.randn((10, 16, 16)) 4664 self.run_test(LSTMModel(), (input,)) 4665 4666 @skipIfUnsupportedMinOpsetVersion(9) 4667 def test_lstm_proj_no_hidden(self): 4668 class LSTMModel(torch.nn.Module): 4669 def __init__(self) -> None: 4670 super().__init__() 4671 self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16, proj_size=8) 4672 4673 def forward(self, x): 4674 return self.rnn(x) 4675 4676 input = torch.randn((10, 16, 16)) 4677 with self.assertRaises(RuntimeError): 4678 self.run_test(LSTMModel(), (input,)) 4679 4680 @skipIfUnsupportedMinOpsetVersion(9) 4681 def test_lstm(self): 4682 class LSTMModel(torch.nn.Module): 4683 def __init__(self) -> None: 4684 super().__init__() 4685 self.rnn = torch.nn.LSTM( 4686 RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False 4687 ) 4688 4689 def forward(self, x, h0, c0): 4690 return self.rnn(x, (h0, c0)) 4691 4692 input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) 4693 h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE) 4694 c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE) 4695 self.run_test(LSTMModel(), (input, h0, c0)) 4696 4697 @skipIfUnsupportedMinOpsetVersion(9) 4698 def test_lstm_cell(self): 4699 class LSTMCellModel(torch.nn.Module): 4700 def __init__(self, bias): 4701 super().__init__() 4702 self.lstm_cell = torch.nn.LSTMCell( 4703 RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, bias=bias 4704 ) 4705 4706 def forward(self, x, h0, c0): 4707 return self.lstm_cell(x, (h0, c0)) 4708 4709 input = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE) 4710 h0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE) 4711 c0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE) 4712 for bias in [True, False]: 4713 self.run_test(LSTMCellModel(bias), (input, h0, c0)) 4714 4715 @skipIfUnsupportedMinOpsetVersion(9) 4716 def test_lstm_default_init_state(self): 4717 class LSTMModel(torch.nn.Module): 4718 def __init__(self) -> None: 4719 super().__init__() 4720 self.rnn = torch.nn.LSTM( 4721 
RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False 4722 ) 4723 4724 def forward(self, x): 4725 return self.rnn(x) 4726 4727 input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) 4728 self.run_test(LSTMModel(), input) 4729 4730 @skipIfUnsupportedMinOpsetVersion(9) 4731 def test_lstm_fixed_batch_size(self): 4732 class LSTMModel(torch.nn.Module): 4733 def __init__(self) -> None: 4734 super().__init__() 4735 self.lstm = torch.nn.LSTM( 4736 RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False 4737 ) 4738 self.RNN_HIDDEN_SIZE = RNN_HIDDEN_SIZE 4739 4740 def forward(self, input): 4741 batch_size = input.size()[1] 4742 h0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE]) 4743 c0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE]) 4744 return self.lstm(input, (h0, c0)) 4745 4746 input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) 4747 # verify with different input of same batch size 4748 input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) 4749 self.run_test( 4750 LSTMModel(), input, fixed_batch_size=True, additional_test_inputs=[input2] 4751 ) 4752 4753 @skipIfUnsupportedMinOpsetVersion(9) 4754 def test_lstm_post_fix_init_state(self): 4755 class LSTMModel(torch.nn.Module): 4756 def __init__(self) -> None: 4757 super().__init__() 4758 self.lstm = torch.nn.LSTM( 4759 RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False 4760 ) 4761 self.RNN_HIDDEN_SIZE = RNN_HIDDEN_SIZE 4762 4763 def forward(self, input): 4764 batch_size = input.size()[1] 4765 h0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE]) 4766 c0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE]) 4767 return self.lstm(input, (h0, c0)) 4768 4769 model = LSTMModel() 4770 input = torch.randn(RNN_SEQUENCE_LENGTH, 1, RNN_INPUT_SIZE) 4771 # verify with different input of different batch size 4772 input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) 4773 self.run_test( 4774 model, 4775 input, 4776 input_names=["input.1"], 4777 
dynamic_axes={"input.1": {0: "seq", 1: "batch"}}, 4778 additional_test_inputs=[input2], 4779 ) 4780 4781 def test_lstm_constant_folding(self): 4782 class LstmNet(torch.nn.Module): 4783 def __init__(self, input_size, hidden_size, num_layers, bidirectional): 4784 super().__init__() 4785 self.lstm = torch.nn.LSTM( 4786 input_size, hidden_size, num_layers, bidirectional=bidirectional 4787 ) 4788 4789 def forward(self, input, initial_state: Tuple[Tensor, Tensor]): 4790 return self.lstm(input, initial_state) 4791 4792 def get_LstmNet_model_and_inputs( 4793 input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional 4794 ): 4795 num_directions = 2 if bidirectional else 1 4796 model = LstmNet(input_size, hidden_size, num_layers, bidirectional) 4797 input = torch.randn(seq_len, batch_size, input_size) 4798 h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size) 4799 c0 = torch.randn(num_layers * num_directions, batch_size, hidden_size) 4800 return model, (input, (h0, c0)) 4801 4802 batch_size1 = 3 4803 model1, input1 = get_LstmNet_model_and_inputs(7, 3, 2, batch_size1, 5, True) 4804 self.run_test(model1, input1, do_constant_folding=True) 4805 4806 batch_size2 = 4 4807 model2, input2 = get_LstmNet_model_and_inputs(5, 4, 3, batch_size2, 7, False) 4808 self.run_test(model2, input2, do_constant_folding=True) 4809 4810 @skipIfUnsupportedMinOpsetVersion(9) 4811 def test_lstm_no_bias(self): 4812 class LstmNet(torch.nn.Module): 4813 def __init__(self, num_layers, bidirectional): 4814 super().__init__() 4815 self.lstm = torch.nn.LSTM( 4816 RNN_INPUT_SIZE, 4817 RNN_HIDDEN_SIZE, 4818 num_layers, 4819 bias=False, 4820 bidirectional=bidirectional, 4821 ) 4822 4823 def forward(self, input, initial_state: Tuple[Tensor, Tensor]): 4824 return self.lstm(input, initial_state) 4825 4826 def get_LstmNet_model_and_inputs(num_layers, bidirectional): 4827 input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) 4828 num_directions = 2 if bidirectional else 1 
4829 model = LstmNet(num_layers, bidirectional) 4830 h0 = torch.randn(num_layers * num_directions, BATCH_SIZE, RNN_HIDDEN_SIZE) 4831 c0 = torch.randn(num_layers * num_directions, BATCH_SIZE, RNN_HIDDEN_SIZE) 4832 return model, (input, (h0, c0)) 4833 4834 num_layers = [1, 1, 2, 3] 4835 bidirectional = [True, False, True, False] 4836 models_and_inputs = [ 4837 get_LstmNet_model_and_inputs(n, b) 4838 for n, b in zip(num_layers, bidirectional) 4839 ] 4840 for model, input in models_and_inputs: 4841 self.run_test(model, input) 4842 4843 @skipIfUnsupportedMinOpsetVersion(9) 4844 def test_lstm_sequence(self): 4845 class LstmNet(torch.nn.Module): 4846 def __init__(self) -> None: 4847 super().__init__() 4848 self.rnn1 = torch.nn.LSTM(8, 8, bidirectional=True, batch_first=True) 4849 self.linear1 = torch.nn.Linear(8 * 2, 8) 4850 self.rnn2 = torch.nn.LSTM(8, 8, bidirectional=True, batch_first=True) 4851 self.linear2 = torch.nn.Linear(8 * 2, 8) 4852 4853 def forward(self, input): 4854 rnn_output1, _ = self.rnn1(input) 4855 linear_output1 = self.linear1(rnn_output1) 4856 rnn_output2, _ = self.rnn2(linear_output1) 4857 linear_output2 = self.linear2(rnn_output2) 4858 return linear_output2 4859 4860 input = torch.zeros((1, 100, 8), dtype=torch.float32) 4861 self.run_test( 4862 LstmNet(), 4863 input, 4864 input_names=["input"], 4865 output_names=["output"], 4866 dynamic_axes={ 4867 "input": {0: "batch_size", 1: "w", 2: "h"}, 4868 "output": {0: "batch_size", 1: "w", 2: "h"}, 4869 }, 4870 ) 4871 4872 @skipScriptTest() 4873 def test_rnn_no_bias(self): 4874 def make_model(layers, packed_sequence): 4875 batch_first = True if packed_sequence == 2 else False 4876 model = torch.nn.RNN( 4877 RNN_INPUT_SIZE, 4878 RNN_HIDDEN_SIZE, 4879 layers, 4880 bidirectional=False, 4881 batch_first=batch_first, 4882 bias=False, 4883 ) 4884 4885 if packed_sequence == 1: 4886 model = rnn_model_with_packed_sequence.RnnModelWithPackedSequence( 4887 model, False 4888 ) 4889 if packed_sequence == 2: 4890 model = 
rnn_model_with_packed_sequence.RnnModelWithPackedSequence( 4891 model, True 4892 ) 4893 return model 4894 4895 def make_input(batch_size, layers, packed_sequence): 4896 batch_first = True if packed_sequence == 2 else False 4897 seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size) 4898 seq_lengths = sorted(map(int, seq_lengths), reverse=True) 4899 inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths] 4900 inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first) 4901 inputs = [inputs] 4902 4903 h0 = torch.randn(layers, batch_size, RNN_HIDDEN_SIZE) 4904 inputs.append(h0) 4905 if packed_sequence != 0: 4906 inputs.append(torch.IntTensor(seq_lengths)) 4907 if len(inputs) == 1: 4908 input = inputs[0] 4909 else: 4910 input = tuple(inputs) 4911 return input 4912 4913 layers = [1, 3, 1, 3, 1, 3] 4914 packed_sequence = [0, 0, 1, 1, 2, 2] 4915 models = [make_model(l, p) for l, p in zip(layers, packed_sequence)] 4916 inputs = [ 4917 make_input(RNN_BATCH_SIZE, l, p) for l, p in zip(layers, packed_sequence) 4918 ] 4919 4920 for model, input in zip(models, inputs): 4921 self.run_test(model, input) 4922 4923 def test_gru_no_bias(self): 4924 class GruNet(torch.nn.Module): 4925 def __init__(self, input_size, hidden_size, num_layers, bidirectional): 4926 super().__init__() 4927 self.mygru = torch.nn.GRU( 4928 input_size, 4929 hidden_size, 4930 num_layers, 4931 bidirectional=bidirectional, 4932 bias=False, 4933 ) 4934 4935 def forward(self, input, initial_state): 4936 out = self.mygru(input, initial_state) 4937 return out 4938 4939 def get_GruNet_model_and_inputs( 4940 input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional 4941 ): 4942 num_directions = 2 if bidirectional else 1 4943 model = GruNet(input_size, hidden_size, num_layers, bidirectional) 4944 input = torch.randn(seq_len, batch_size, input_size) 4945 h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size) 4946 return model, (input, h0) 4947 4948 
input_size = [7, 5] 4949 hidden_size = [3, 4] 4950 num_layers = [2, 3] 4951 batch_size = [3, 4] 4952 seq_len = [5, 7] 4953 bidirectional = [True, False] 4954 models_and_inputs = [ 4955 get_GruNet_model_and_inputs(i, h, n, b, s, bi) 4956 for i, h, n, b, s, bi in zip( 4957 input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional 4958 ) 4959 ] 4960 for model, input in models_and_inputs: 4961 self.run_test(model, input, do_constant_folding=True) 4962 4963 def test_gru_constant_folding(self): 4964 class GruNet(torch.nn.Module): 4965 def __init__(self, input_size, hidden_size, num_layers, bidirectional): 4966 super().__init__() 4967 self.mygru = torch.nn.GRU( 4968 input_size, hidden_size, num_layers, bidirectional=bidirectional 4969 ) 4970 4971 def forward(self, input, initial_state): 4972 out = self.mygru(input, initial_state) 4973 return out 4974 4975 def get_GruNet_model_and_inputs( 4976 input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional 4977 ): 4978 num_directions = 2 if bidirectional else 1 4979 model = GruNet(input_size, hidden_size, num_layers, bidirectional) 4980 input = torch.randn(seq_len, batch_size, input_size) 4981 h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size) 4982 return model, (input, h0) 4983 4984 batch_size1 = 3 4985 model1, input1 = get_GruNet_model_and_inputs(7, 3, 2, batch_size1, 5, True) 4986 self.run_test(model1, input1, do_constant_folding=True) 4987 4988 batch_size2 = 4 4989 model2, input2 = get_GruNet_model_and_inputs(5, 4, 3, batch_size2, 7, False) 4990 self.run_test(model2, input2, do_constant_folding=True) 4991 4992 @skipIfUnsupportedMinOpsetVersion(8) 4993 def test_max_tensors(self): 4994 class MaxModel(torch.nn.Module): 4995 def forward(self, input, other): 4996 return torch.max(input, other) 4997 4998 model = MaxModel() 4999 x = torch.randn(4, 4, requires_grad=True) 5000 y = torch.randn(4, 1, requires_grad=True) 5001 self.run_test(model, (x, y)) 5002 5003 def test_amax_amin(self): 
5004 class Model(torch.nn.Module): 5005 def forward(self, x): 5006 return torch.amax(x, dim=0, keepdim=True), torch.amin( 5007 x, dim=[0, 1], keepdim=False 5008 ) 5009 5010 model = Model() 5011 x = torch.randn(4, 4) 5012 self.run_test(model, x) 5013 5014 def test_aminmax(self): 5015 class Model(torch.nn.Module): 5016 def forward(self, x): 5017 return torch.aminmax(x, dim=1, keepdim=True), torch.aminmax( 5018 x, keepdim=False 5019 ) 5020 5021 model = Model() 5022 x = torch.randn(3, 4) 5023 self.run_test(model, x) 5024 5025 @skipIfUnsupportedMinOpsetVersion(9) 5026 def test_arange_end(self): 5027 class ArangeScript(torch.jit.ScriptModule): 5028 @torch.jit.script_method 5029 def forward(self, a): 5030 return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a 5031 5032 x = torch.randn(3, 4, requires_grad=True) 5033 outputs = ArangeScript()(x) 5034 self.run_test(ArangeScript(), x) 5035 5036 class ArangeModel(torch.nn.Module): 5037 def forward(self, a): 5038 return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a 5039 5040 self.run_test(ArangeModel(), x) 5041 5042 @skipIfUnsupportedMinOpsetVersion(11) 5043 def test_arange_end_notype(self): 5044 class ArangeScript(torch.jit.ScriptModule): 5045 @torch.jit.script_method 5046 def forward(self, a): 5047 return torch.arange(a.size(0)) 5048 5049 x = torch.randn(3, 4, requires_grad=True) 5050 outputs = ArangeScript()(x) 5051 self.run_test(ArangeScript(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) 5052 self.run_test(ArangeScript(), x, remained_onnx_input_idx=[]) 5053 5054 class ArangeModel(torch.nn.Module): 5055 def forward(self, a): 5056 return torch.arange(a.size(0)) 5057 5058 self.run_test(ArangeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) 5059 self.run_test(ArangeModel(), x, remained_onnx_input_idx=[]) 5060 5061 @skipIfUnsupportedMinOpsetVersion(9) 5062 def test_arange_start_end(self): 5063 class ArangeScript(torch.jit.ScriptModule): 5064 @torch.jit.script_method 5065 def forward(self, 
a): 5066 return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a 5067 5068 x = torch.randn(3, 4, requires_grad=True) 5069 self.run_test(ArangeScript(), x) 5070 5071 class ArangeModel(torch.nn.Module): 5072 def forward(self, a): 5073 return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a 5074 5075 self.run_test(ArangeModel(), x) 5076 5077 @skipIfUnsupportedMinOpsetVersion(11) 5078 def test_arange_start_end_notype(self): 5079 class ArangeScript(torch.jit.ScriptModule): 5080 @torch.jit.script_method 5081 def forward(self, a): 5082 return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a 5083 5084 x = torch.randn(3, 4, requires_grad=True) 5085 self.run_test(ArangeScript(), x) 5086 5087 class ArangeModel(torch.nn.Module): 5088 def forward(self, a): 5089 return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a 5090 5091 self.run_test(ArangeModel(), x) 5092 5093 @skipIfUnsupportedMinOpsetVersion(9) 5094 def test_arange_start_end_step(self): 5095 class ArangeScript(torch.jit.ScriptModule): 5096 @torch.jit.script_method 5097 def forward(self, a): 5098 return ( 5099 torch.arange( 5100 2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float 5101 ).view(-1, 1) 5102 + a 5103 ) 5104 5105 x = torch.randn(3, 4, requires_grad=True) 5106 self.run_test(ArangeScript(), x) 5107 5108 class ArangeModel(torch.nn.Module): 5109 def forward(self, a): 5110 return ( 5111 torch.arange( 5112 2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float 5113 ).view(-1, 1) 5114 + a 5115 ) 5116 5117 self.run_test(ArangeModel(), x) 5118 5119 @skipIfUnsupportedMinOpsetVersion(11) 5120 def test_arange_start_end_step_notype(self): 5121 class ArangeScript(torch.jit.ScriptModule): 5122 @torch.jit.script_method 5123 def forward(self, a): 5124 return ( 5125 torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1) 5126 + a 5127 ) 5128 5129 x = torch.randn(3, 4, requires_grad=True) 5130 self.run_test(ArangeScript(), x) 5131 5132 class ArangeModel(torch.nn.Module): 
5133 def forward(self, a): 5134 return ( 5135 torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1) 5136 + a 5137 ) 5138 5139 self.run_test(ArangeModel(), x) 5140 5141 @skipIfUnsupportedMinOpsetVersion(9) 5142 def test__dim_arange(self): 5143 class DimArange(torch.nn.Module): 5144 def forward(self, input): 5145 return torch._dim_arange(input, 1) 5146 5147 x = torch.ones(5, 6) 5148 self.run_test(DimArange(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) 5149 remained_onnx_input_idx = None if self.opset_version < 11 else [] 5150 self.run_test(DimArange(), x, remained_onnx_input_idx=remained_onnx_input_idx) 5151 5152 def _test_compare_ops(self, model, num_inputs): 5153 x_float = torch.randn(1, 2, 3, 4, requires_grad=True) 5154 x_int = torch.randint(10, (3, 4), dtype=torch.int32) 5155 if num_inputs > 1: 5156 y_float = torch.randn(1, 2, 3, 4, requires_grad=True) 5157 y_int = torch.randint(10, (3, 4), dtype=torch.int32) 5158 self.run_test(model, (x_float, y_float)) 5159 self.run_test(model, (x_float, y_int)) 5160 self.run_test(model, (x_int, y_float)) 5161 self.run_test(model, (x_int, y_int)) 5162 else: 5163 self.run_test(model, x_float) 5164 self.run_test(model, x_int) 5165 5166 @skipIfUnsupportedMinOpsetVersion(9) 5167 def test_and_or_xor(self): 5168 class MyModel(torch.nn.Module): 5169 def forward(self, x, y): 5170 return x ^ y, x | y, x & y, ~x 5171 5172 x = torch.randint(0, 2, (5, 5), dtype=torch.bool) 5173 y = torch.randint(0, 2, (5, 5), dtype=torch.bool) 5174 self.run_test(MyModel(), input_args=(x, y)) 5175 5176 @skipIfUnsupportedMinOpsetVersion(9) 5177 def test_logical_and(self): 5178 class AndModel(torch.nn.Module): 5179 def forward(self, x, y): 5180 return torch.logical_and(x, y) 5181 5182 x = torch.randint(0, 2, (5, 5), dtype=torch.bool) 5183 y = torch.randint(0, 2, (5, 5), dtype=torch.bool) 5184 self.run_test(AndModel(), input_args=(x, y)) 5185 5186 x = torch.randint(10, (5, 5), dtype=torch.int32) 5187 y = torch.randint(10, (5, 5), 
dtype=torch.int32) 5188 self.run_test(AndModel(), input_args=(x, y)) 5189 5190 x = torch.randint(10, (5, 5), dtype=torch.double) 5191 y = torch.randint(10, (5, 5), dtype=torch.double) 5192 self.run_test(AndModel(), input_args=(x, y)) 5193 5194 x = torch.randint(10, (2, 3, 5), dtype=torch.float32) 5195 y = torch.randint(10, (2, 3, 5), dtype=torch.long) 5196 self.run_test(AndModel(), input_args=(x, y)) 5197 5198 @skipIfUnsupportedMinOpsetVersion(9) 5199 def test_logical_or(self): 5200 class OrModel(torch.nn.Module): 5201 def forward(self, x, y): 5202 return torch.logical_or(x, y) 5203 5204 x = torch.randint(0, 2, (5, 5), dtype=torch.bool) 5205 y = torch.randint(0, 2, (5, 5), dtype=torch.bool) 5206 self.run_test(OrModel(), input_args=(x, y)) 5207 5208 x = torch.randint(10, (5, 5), dtype=torch.int32) 5209 y = torch.randint(10, (5, 5), dtype=torch.int32) 5210 self.run_test(OrModel(), input_args=(x, y)) 5211 5212 x = torch.randint(10, (5, 5), dtype=torch.double) 5213 y = torch.randint(10, (5, 5), dtype=torch.double) 5214 self.run_test(OrModel(), input_args=(x, y)) 5215 5216 x = torch.randint(10, (2, 3, 5), dtype=torch.float32) 5217 y = torch.randint(10, (2, 3, 5), dtype=torch.long) 5218 self.run_test(OrModel(), input_args=(x, y)) 5219 5220 @skipIfUnsupportedMinOpsetVersion(9) 5221 def test_logical_xor(self): 5222 class XorModel(torch.nn.Module): 5223 def forward(self, x, y): 5224 return torch.logical_xor(x, y) 5225 5226 x = torch.randint(0, 2, (5, 5), dtype=torch.bool) 5227 y = torch.randint(0, 2, (5, 5), dtype=torch.bool) 5228 self.run_test(XorModel(), input_args=(x, y)) 5229 5230 x = torch.randint(10, (5, 5), dtype=torch.int32) 5231 y = torch.randint(10, (5, 5), dtype=torch.int32) 5232 self.run_test(XorModel(), input_args=(x, y)) 5233 5234 x = torch.randint(10, (5, 5), dtype=torch.double) 5235 y = torch.randint(10, (5, 5), dtype=torch.double) 5236 self.run_test(XorModel(), input_args=(x, y)) 5237 5238 x = torch.randint(10, (2, 3, 5), dtype=torch.float32) 5239 y = 
torch.randint(10, (2, 3, 5), dtype=torch.long) 5240 self.run_test(XorModel(), input_args=(x, y)) 5241 5242 @skipIfUnsupportedMinOpsetVersion(9) 5243 def test_logical_not(self): 5244 class NotModel(torch.nn.Module): 5245 def forward(self, x): 5246 return torch.logical_not(x) 5247 5248 x = torch.randint(0, 2, (5, 5), dtype=torch.bool) 5249 self.run_test(NotModel(), input_args=(x,)) 5250 5251 x = torch.randint(10, (5, 5), dtype=torch.int32) 5252 self.run_test(NotModel(), input_args=(x,)) 5253 5254 x = torch.randint(10, (5, 5), dtype=torch.double) 5255 self.run_test(NotModel(), input_args=(x,)) 5256 5257 x = torch.randint(10, (2, 3, 5), dtype=torch.float32) 5258 self.run_test(NotModel(), input_args=(x,)) 5259 5260 @skipIfUnsupportedMinOpsetVersion(11) # float equal added after opset 11 5261 def test_eq(self): 5262 class EqualModel(torch.nn.Module): 5263 def forward(self, input, other): 5264 return input == other 5265 5266 self._test_compare_ops(EqualModel(), 2) 5267 5268 def test_gt(self): 5269 class GreaterModel(torch.nn.Module): 5270 def forward(self, input, other): 5271 return input > other 5272 5273 self._test_compare_ops(GreaterModel(), 2) 5274 5275 @skipIfUnsupportedMinOpsetVersion(9) 5276 def test_ge(self): 5277 class GreaterOrEqualModel(torch.nn.Module): 5278 def forward(self, input, other): 5279 return input >= other 5280 5281 self._test_compare_ops(GreaterOrEqualModel(), 2) 5282 5283 def test_gt_scalar(self): 5284 class GreaterModel(torch.nn.Module): 5285 def forward(self, input): 5286 return input > 1 5287 5288 self._test_compare_ops(GreaterModel(), 1) 5289 5290 def test_gt_primitive(self): 5291 class GreaterModel(torch.nn.Module): 5292 def __init__(self) -> None: 5293 super().__init__() 5294 self.y: int = 2 5295 5296 def forward(self, x: int): 5297 return self.y > x 5298 5299 x = 3 5300 self.run_test(GreaterModel(), (x,)) 5301 5302 @skipIfUnsupportedMinOpsetVersion(9) 5303 def test_ge_scalar(self): 5304 class GreaterOrEqualModel(torch.nn.Module): 5305 def 
forward(self, input): 5306 return input >= 1 5307 5308 self._test_compare_ops(GreaterOrEqualModel(), 1) 5309 5310 def test_lt(self): 5311 class LessModel(torch.nn.Module): 5312 def forward(self, input, other): 5313 return input > other 5314 5315 self._test_compare_ops(LessModel(), 2) 5316 5317 @skipIfUnsupportedMinOpsetVersion(9) 5318 def test_le(self): 5319 class LessOrEqualModel(torch.nn.Module): 5320 def forward(self, input, other): 5321 return input <= other 5322 5323 self._test_compare_ops(LessOrEqualModel(), 2) 5324 5325 def test_lt_scalar(self): 5326 class LessModel(torch.nn.Module): 5327 def forward(self, input): 5328 return input < 1 5329 5330 self._test_compare_ops(LessModel(), 1) 5331 5332 @skipIfUnsupportedMinOpsetVersion(9) 5333 def test_le_scalar(self): 5334 class LessOrEqualModel(torch.nn.Module): 5335 def forward(self, input): 5336 return input <= 1 5337 5338 self._test_compare_ops(LessOrEqualModel(), 1) 5339 5340 def test_matmul(self): 5341 class MatmulModel(torch.nn.Module): 5342 def forward(self, input, other): 5343 return torch.matmul(input, other) 5344 5345 x = torch.randn(3, 4, requires_grad=True) 5346 y = torch.randn(4, 5, requires_grad=True) 5347 self.run_test(MatmulModel(), (x, y)) 5348 5349 x = torch.randint(10, (3, 4)) 5350 y = torch.randint(10, (4, 5)) 5351 self.run_test(MatmulModel(), (x, y)) 5352 5353 def test_matmul_batch(self): 5354 class MatmulModel(torch.nn.Module): 5355 def forward(self, input, other): 5356 return torch.matmul(input, other) 5357 5358 x = torch.randn(2, 3, 4, requires_grad=True) 5359 y = torch.randn(2, 4, 5, requires_grad=True) 5360 self.run_test(MatmulModel(), (x, y)) 5361 5362 x = torch.randint(10, (2, 3, 4)) 5363 y = torch.randint(10, (2, 4, 5)) 5364 self.run_test(MatmulModel(), (x, y)) 5365 5366 def _argmin_argmax_model(self, input): 5367 class ArgminArgmaxModel(torch.nn.Module): 5368 def forward(self, input): 5369 return ( 5370 torch.argmin(input), 5371 torch.argmax(input), 5372 torch.argmin(input, 
keepdim=True), 5373 torch.argmax(input, keepdim=True), 5374 torch.argmin(input, dim=0, keepdim=True), 5375 torch.argmax(input, dim=1, keepdim=True), 5376 ) 5377 5378 self.run_test(ArgminArgmaxModel(), input) 5379 5380 @skipIfUnsupportedMinOpsetVersion(9) 5381 def test_argmin_argmax(self): 5382 input = torch.randn(7, 3, 5) 5383 self._argmin_argmax_model(input) 5384 5385 # Argmin and Argmax with "select_last_index" is not supprted before opset 12 5386 # "select_last_index" was added in opset 12 to deal with corner case where the 5387 # same value appears multiple times in the tensor 5388 @skipIfUnsupportedMinOpsetVersion(12) 5389 def test_argmin_argmax_select_last_index(self): 5390 input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]]) 5391 self._argmin_argmax_model(input) 5392 5393 input = torch.ones(7, 3, 5) 5394 self._argmin_argmax_model(input) 5395 5396 def test_repeat(self): 5397 class RepeatModel(torch.nn.Module): 5398 def forward(self, x, y): 5399 x2 = x.repeat(y.shape[0], 1) 5400 y1 = y.view(-1, 1) 5401 return x2 + y1 5402 5403 x = torch.tensor([1, 2, 3]) 5404 y = torch.tensor([4, 5, 8, 9]) 5405 self.run_test(RepeatModel(), (x, y)) 5406 5407 @skipIfUnsupportedMinOpsetVersion(9) 5408 def test_repeat_interleave(self): 5409 class FlattenModel(torch.nn.Module): 5410 def forward(self, x): 5411 return x.repeat_interleave(2) 5412 5413 for shape in ([3], [3, 4], [2, 3, 4]): 5414 x = torch.randn(shape) 5415 self.run_test(FlattenModel(), (x,)) 5416 5417 class DimsModel(torch.nn.Module): 5418 def forward(self, x): 5419 return x.repeat_interleave(4, dim=1) 5420 5421 x = torch.tensor([[1, 2], [3, 4]]) 5422 self.run_test(DimsModel(), (x,)) 5423 5424 class DimsModel2(torch.nn.Module): 5425 def forward(self, x): 5426 repeats = torch.tensor([4]) 5427 return torch.repeat_interleave(x, repeats, dim=1) 5428 5429 x = torch.tensor([[1, 2], [3, 4]]) 5430 self.run_test(DimsModel2(), (x,)) 5431 5432 class RepeatsDimsModel(torch.nn.Module): 5433 def forward(self, x): 5434 repeats = 
torch.tensor([1, 2]) 5435 return torch.repeat_interleave(x, repeats, dim=0) 5436 5437 x = torch.tensor([[1, 2], [3, 4]]) 5438 self.run_test(RepeatsDimsModel(), (x,)) 5439 5440 class RepeatsDimsModel2(torch.nn.Module): 5441 def forward(self, x): 5442 repeats = torch.tensor([1, 2]) 5443 return torch.repeat_interleave(x, repeats, dim=1) 5444 5445 x = torch.tensor([[1, 2], [3, 4]]) 5446 self.run_test(RepeatsDimsModel2(), (x,)) 5447 5448 @skipIfUnsupportedMinOpsetVersion(9) 5449 def test_repeat_interleave_noop(self): 5450 class Model(torch.nn.Module): 5451 def forward(self, x): 5452 return x.repeat_interleave(1, dim=1) 5453 5454 x = torch.randn(4, 1, 8) 5455 self.run_test(Model(), (x,)) 5456 5457 @skipIfUnsupportedMinOpsetVersion(13) 5458 def test_dynamic_repeat_interleave(self): 5459 class SingleDynamicModel(torch.nn.Module): 5460 def forward(self, x): 5461 repeats = torch.tensor(4) 5462 return torch.repeat_interleave(x, repeats, dim=1) 5463 5464 x = torch.tensor([[1, 2, 4], [3, 4, 7]]) 5465 another_x = torch.tensor([[7, 8], [5, 6]]) 5466 self.run_test( 5467 SingleDynamicModel(), 5468 x, 5469 additional_test_inputs=[another_x], 5470 input_names=["input_1"], 5471 dynamic_axes={"input_1": {1: "w"}}, 5472 ) 5473 5474 class NegDynamicModel(torch.nn.Module): 5475 def forward(self, x): 5476 repeats = torch.tensor(4) 5477 return torch.repeat_interleave(x, repeats, dim=-1) 5478 5479 x = torch.tensor([[1, 2, 4], [3, 4, 7]]) 5480 another_x = torch.tensor([[7, 8], [5, 6]]) 5481 self.run_test( 5482 NegDynamicModel(), 5483 x, 5484 additional_test_inputs=[another_x], 5485 input_names=["input_1"], 5486 dynamic_axes={"input_1": {1: "w"}}, 5487 ) 5488 5489 class SingleDynamicModelFloat(torch.nn.Module): 5490 def forward(self, x): 5491 repeats = torch.tensor([4]) 5492 return torch.repeat_interleave(x, repeats, dim=0) 5493 5494 x = torch.tensor([[1.1, 2.1], [3.1, 4.1]]) 5495 another_x = torch.tensor([[7.1, 8.1], [5.1, 6.1]]) 5496 self.run_test( 5497 SingleDynamicModelFloat(), 5498 x, 
5499 additional_test_inputs=[another_x], 5500 input_names=["input_1"], 5501 dynamic_axes={"input_1": {0: "h"}}, 5502 ) 5503 5504 class DynamicRepeatsModel(torch.nn.Module): 5505 def forward(self, x, repeats): 5506 return torch.repeat_interleave(x, repeats, dim=1) 5507 5508 x = torch.tensor([[1, 2, 4], [3, 4, 7]]) 5509 another_x = torch.tensor([[7, 8], [5, 6]]) 5510 repeats = torch.tensor([2]) 5511 another_repeats = torch.tensor([4]) 5512 self.run_test( 5513 DynamicRepeatsModel(), 5514 (x, repeats), 5515 additional_test_inputs=[(another_x, another_repeats)], 5516 input_names=["input_1", "repeats_1"], 5517 dynamic_axes={"input_1": {1: "w"}, "repeats_1": {0: "r"}}, 5518 ) 5519 5520 class DynamicRepeatsModel2(torch.nn.Module): 5521 def forward(self, x, repeats): 5522 return torch.repeat_interleave(x, repeats, dim=1) 5523 5524 x = torch.tensor([[1, 2, 4], [3, 4, 7]]) 5525 repeats = torch.tensor([2]) 5526 another_repeats = torch.tensor([4]) 5527 self.run_test( 5528 DynamicRepeatsModel2(), 5529 (x, repeats), 5530 additional_test_inputs=[(x, another_repeats)], 5531 input_names=["input_1", "repeats_1"], 5532 dynamic_axes={"repeats_1": {0: "r"}}, 5533 ) 5534 5535 class DynamicFlattenModel(torch.nn.Module): 5536 def forward(self, x): 5537 return x.repeat_interleave(2) 5538 5539 x = torch.tensor([1, 2, 3]) 5540 self.run_test( 5541 DynamicFlattenModel(), 5542 x, 5543 input_names=["input_1"], 5544 dynamic_axes={"input_1": {0: "w"}}, 5545 ) 5546 5547 @skipIfUnsupportedMinOpsetVersion(13) 5548 def test_multiple_dynamic_repeat_interleave(self): 5549 class DynamicRepeatsModel(torch.nn.Module): 5550 def forward(self, x, repeats): 5551 return torch.repeat_interleave(x, repeats, dim=1) 5552 5553 x = torch.tensor([[1, 2, 4], [3, 4, 7]]) 5554 repeats = torch.tensor([2, 3, 4]) 5555 another_repeats = torch.tensor([4, 3, 2]) 5556 self.run_test( 5557 DynamicRepeatsModel(), 5558 (x, repeats), 5559 additional_test_inputs=[(x, another_repeats)], 5560 input_names=["input_1", "repeats_1"], 5561 
dynamic_axes={"repeats_1": {0: "r"}}, 5562 ) 5563 5564 class DynamicRepeatsModel2(torch.nn.Module): 5565 def forward(self, x, repeats): 5566 return torch.repeat_interleave(x, repeats, dim=0) 5567 5568 x = torch.tensor([[1, 2, 4], [3, 4, 7]]) 5569 repeats = torch.tensor([2, 3]) 5570 another_repeats = torch.tensor([4, 3]) 5571 self.run_test( 5572 DynamicRepeatsModel2(), 5573 (x, repeats), 5574 additional_test_inputs=[(x, another_repeats)], 5575 input_names=["input_1", "repeats_1"], 5576 dynamic_axes={"repeats_1": {0: "r"}}, 5577 ) 5578 5579 def test_view(self): 5580 class ViewModel(torch.nn.Module): 5581 def forward(self, input): 5582 return input.view(4, 24) 5583 5584 x = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32) 5585 self.run_test(ViewModel(), x) 5586 5587 def test_view_dynamic(self): 5588 class ViewModel(torch.nn.Module): 5589 def forward(self, input, other): 5590 return input.view(other.shape) 5591 5592 x = torch.randn(2, 3, 4) 5593 shape = torch.randn(6, 4) 5594 self.run_test( 5595 ViewModel(), 5596 (x, shape), 5597 input_names=["x", "shape"], 5598 dynamic_axes={"x": [0, 1, 2], "shape": [0, 1]}, 5599 ) 5600 self.run_test(ViewModel(), (x, shape), remained_onnx_input_idx=[0]) 5601 5602 def test_view_dynamic_zero_dim(self): 5603 class ViewModel(torch.nn.Module): 5604 def forward(self, input): 5605 input = input.view(-1, 2) 5606 return input.view(1, -1) 5607 5608 x = torch.ones(2) 5609 another_x = torch.empty((0,)) 5610 self.run_test( 5611 ViewModel(), 5612 x, 5613 additional_test_inputs=[another_x], 5614 input_names=["input_1"], 5615 dynamic_axes={ 5616 "input_1": [ 5617 0, 5618 ] 5619 }, 5620 ) 5621 5622 def test_view_as(self): 5623 class ViewModel(torch.nn.Module): 5624 def forward(self, input, other): 5625 return input.view_as(other) 5626 5627 x = torch.randn(2, 3, 4) 5628 y = torch.randn(6, 4) 5629 self.run_test(ViewModel(), (x, y)) 5630 5631 def test_linear(self): 5632 class LinearModel(torch.nn.Module): 5633 def __init__(self) -> None: 5634 
super().__init__() 5635 self.fc = torch.nn.Linear(16, 16) 5636 5637 def forward(self, x): 5638 out = self.fc(x) 5639 out = self.fc(out) 5640 return out 5641 5642 x = torch.randn(3, 16) 5643 self.run_test(LinearModel(), (x,)) 5644 5645 class LinearModel(torch.nn.Module): 5646 def forward(self, input, weight, bias): 5647 return torch.nn.functional.linear(input, weight, bias) 5648 5649 # input of rank 2 5650 x = torch.randn(2, 2) 5651 y = torch.randn(2, 2) 5652 z = torch.randn(1) 5653 self.run_test(LinearModel(), (x, y, z)) 5654 5655 # input of rank 3 5656 x = torch.randn(3, 3, 3) 5657 y = torch.randn(3, 3) 5658 z = torch.randn(1) 5659 self.run_test(LinearModel(), (x, y, z)) 5660 5661 @skipScriptTest() 5662 def test_weight_norm(self): 5663 # addmm for 3-d inputs converts to onnx::MatMul 5664 model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1) 5665 x = torch.randn(3, 4, 5, requires_grad=True) 5666 self.run_test(model, x) 5667 5668 # addmm for 2-d inputs converts to onnx::Gemm 5669 model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1) 5670 x = torch.randn(4, 5, requires_grad=True) 5671 self.run_test(model, x) 5672 5673 model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3)) 5674 x = torch.randn(1, 1, 5, requires_grad=True) 5675 self.run_test(model, x) 5676 5677 model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3), dim=-2) 5678 x = torch.randn(1, 1, 5, requires_grad=True) 5679 self.run_test(model, x) 5680 5681 model = torch.nn.utils.weight_norm(torch.nn.Conv1d(3, 6, 3), name="weight") 5682 x = torch.randn(3, 3, 5, requires_grad=True) 5683 self.run_test(model, x) 5684 5685 @skipScriptTest() 5686 def test_weight_norm_nodim(self): 5687 # addmm for 3-d inputs converts to onnx::MatMul 5688 model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None) 5689 x = torch.randn(3, 4, 5, requires_grad=True) 5690 self.run_test(model, x) 5691 5692 # addmm for 2-d inputs converts to onnx::Gemm 5693 model = 
torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None) 5694 x = torch.randn(4, 5, requires_grad=True) 5695 self.run_test(model, x) 5696 5697 def test_flatten(self): 5698 class FlattenModel(torch.nn.Module): 5699 def forward(self, input): 5700 return torch.flatten(input) 5701 5702 model = FlattenModel() 5703 5704 # flatten with 4d input 5705 x = torch.randint(10, (1, 2, 3, 4)) 5706 self.run_test(model, x) 5707 5708 # flatten with 0d input 5709 x = torch.randn([]) 5710 self.run_test(model, x) 5711 5712 # flatten with 1d input 5713 x = torch.randn(4) 5714 self.run_test(model, x) 5715 5716 def test_flatten2d(self): 5717 class FlattenModel(torch.nn.Module): 5718 def forward(self, input): 5719 return torch.flatten(input, 1) 5720 5721 x = torch.randint(10, (1, 2, 3, 4)) 5722 self.run_test(FlattenModel(), x) 5723 5724 def test_flatten2d_neg(self): 5725 class FlattenModel(torch.nn.Module): 5726 def forward(self, x): 5727 return ( 5728 torch.flatten(x, 1, -1), 5729 torch.flatten(x, 0, -2), 5730 torch.flatten(x, 1, -2), 5731 ) 5732 5733 x = torch.randint(10, (1, 2, 3, 4)) 5734 self.run_test(FlattenModel(), x) 5735 5736 @skipIfUnsupportedMinOpsetVersion(9) 5737 def test_flatten_dynamic_axes(self): 5738 class MyModule(torch.nn.Module): 5739 def forward(self, x): 5740 return torch.flatten(x, start_dim=2, end_dim=3) 5741 5742 batch_size = 3 5743 x = torch.randn(batch_size, 5, 4, 5) 5744 y = torch.randn(5, 5, 4, 5) 5745 model = MyModule() 5746 self.run_test( 5747 model, 5748 x, 5749 additional_test_inputs=[y], 5750 input_names=["input"], 5751 output_names=["output"], 5752 dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, 5753 ) 5754 5755 @skipIfUnsupportedMinOpsetVersion(11) 5756 def test_getitem(self): 5757 class GetItemModel(torch.jit.ScriptModule): 5758 @torch.jit.script_method 5759 def forward(self, x, y, z, ind): 5760 # this will create prim::ListConstruct(x, y, z) + aten::__getitem__ 5761 arr = [x, y, z] 5762 return arr[ind] 5763 5764 x = 
torch.randn(3, 4, 5) 5765 y = torch.randn(1, 4, 5) 5766 z = torch.randn(2, 4, 5) 5767 ind = torch.tensor(1, dtype=torch.long) 5768 self.run_test(GetItemModel(), (x, y, z, ind)) 5769 5770 ind = torch.tensor(-2, dtype=torch.long) 5771 self.run_test(GetItemModel(), (x, y, z, ind)) 5772 5773 @skipDtypeChecking 5774 def test_item(self): 5775 class M(torch.nn.Module): 5776 def forward(self, x, y, i: int): 5777 return int(x[y[i]].item()) 5778 5779 x = torch.arange(6, dtype=torch.float) 5780 y = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) 5781 i = 3 5782 self.run_test(torch.jit.script(M()), (x, y, i)) 5783 5784 @skipScriptTest() # torch.nonzero(x, as_tuple=True) is not scriptable. 5785 @skipIfUnsupportedMinOpsetVersion(9) 5786 def test_nonzero(self): 5787 class NonzeroModel(torch.nn.Module): 5788 def forward(self, x): 5789 return x.nonzero(), x.nonzero(as_tuple=True) 5790 5791 x = torch.randn(60).index_fill_(0, torch.randint(0, 60, (20,)), 0).view(3, 4, 5) 5792 self.run_test(NonzeroModel(), (x,)) 5793 5794 def test_unbind(self): 5795 class UnbindModel(torch.nn.Module): 5796 def forward(self, input): 5797 _, out, _ = input.unbind() 5798 return out 5799 5800 x = torch.randn(3, 4, 5) 5801 self.run_test(UnbindModel(), x) 5802 5803 class UnbindModel2(torch.nn.Module): 5804 def forward(self, input): 5805 _, out, _, _ = input.unbind(1) 5806 return out 5807 5808 x = torch.randn(3, 4, 5) 5809 self.run_test(UnbindModel2(), x) 5810 5811 class UnbindModel3(torch.nn.Module): 5812 def forward(self, input): 5813 _, out, _, _ = input.unbind(-2) 5814 return out 5815 5816 x = torch.randn(3, 4, 5) 5817 self.run_test(UnbindModel3(), x) 5818 5819 @skipIfUnsupportedMinOpsetVersion(11) 5820 def test_len(self): 5821 class LenModel(torch.jit.ScriptModule): 5822 @torch.jit.script_method 5823 def forward(self, input): 5824 return len(input.unbind()) + input 5825 5826 x = torch.randn(4, 5) 5827 self.run_test( 5828 LenModel(), 5829 x, 5830 input_names=["input"], 5831 dynamic_axes={"input": {0: 
"seq"}}, 5832 additional_test_inputs=(torch.randn(5, 5),), 5833 ) 5834 5835 @skipIfUnsupportedMinOpsetVersion(9) 5836 def test_len_list(self): 5837 class LenListModel(torch.jit.ScriptModule): 5838 @torch.jit.script_method 5839 def forward(self, input): 5840 return torch.ones(len(input.shape)) 5841 5842 x = torch.randn(4, 5) 5843 self.run_test(LenListModel(), x, remained_onnx_input_idx=[]) 5844 5845 @skipIfUnsupportedMinOpsetVersion(11) 5846 def test_unbind_dynamic(self): 5847 class UnbindModel(torch.jit.ScriptModule): 5848 @torch.jit.script_method 5849 def forward(self, input): 5850 return input.unbind()[1] 5851 5852 x = torch.randn(3, 4, 5) 5853 self.run_test(UnbindModel(), x) 5854 5855 class UnbindModel2(torch.jit.ScriptModule): 5856 @torch.jit.script_method 5857 def forward(self, input): 5858 return input.unbind(-1)[1] 5859 5860 x = torch.randn(3, 4, 5) 5861 self.run_test(UnbindModel2(), x) 5862 5863 @skipScriptTest() # scripting tests run for opsets > 11. See: test_split_script 5864 def test_split(self): 5865 class SplitModel(torch.nn.Module): 5866 def forward(self, input): 5867 return input.split([2, 1, 2]), input.split([3, 2])[0] 5868 5869 x = torch.randn(5, 4, 3) 5870 self.run_test(SplitModel(), x) 5871 5872 class SplitModel2(torch.nn.Module): 5873 def forward(self, input): 5874 return input.split([2, 1, 1], -2), input.split([2, 2], -2)[-1] 5875 5876 x = torch.randn(5, 4, 3) 5877 self.run_test(SplitModel2(), x) 5878 5879 class SplitModel3(torch.nn.Module): 5880 def forward(self, input): 5881 return input.split([2, 1, 2]) 5882 5883 x = torch.randn(5, 4, 3) 5884 self.run_test(SplitModel3(), x) 5885 5886 @skipIfUnsupportedMinOpsetVersion(11) 5887 def test_split_script(self): 5888 class SplitModel(torch.nn.Module): 5889 def forward(self, input): 5890 return input.split([2, 1, 2]), input.split([3, 2])[0] 5891 5892 x = torch.randn(5, 4, 3) 5893 self.run_test(SplitModel(), x) 5894 5895 class SplitModel2(torch.nn.Module): 5896 def forward(self, input): 5897 return 
input.split([2, 1, 1], -2), input.split([2, 2], -2)[-1] 5898 5899 x = torch.randn(5, 4, 3) 5900 self.run_test(SplitModel2(), x) 5901 5902 class SplitModel3(torch.nn.Module): 5903 def forward(self, input): 5904 return input.split([2, 1, 2]) 5905 5906 x = torch.randn(5, 4, 3) 5907 self.run_test(SplitModel3(), x) 5908 5909 @skipIfUnsupportedMinOpsetVersion(11) 5910 @skipScriptTest() 5911 def test_split_size_as_list(self): 5912 class SplitModel(torch.nn.Module): 5913 def forward(self, input, split_sizes: List[int]): 5914 out = [] 5915 split_list: List[Tensor] = input.split(split_sizes) 5916 5917 for ob in split_list: 5918 out.append(ob) # noqa: PERF402 5919 return torch.cat(out, dim=0) 5920 5921 x = torch.randn(6, 4, 3) 5922 split_sizes = [torch.tensor(2), torch.tensor(4)] 5923 self.run_test(SplitModel(), (x, split_sizes)) 5924 5925 @skipIfUnsupportedMinOpsetVersion(11) 5926 def test_split_size_with_slice(self): 5927 class SplitModule(torch.nn.Module): 5928 def forward(self, x, y, t): 5929 splits = (x.size(1), y.size(1)) 5930 out, out2 = torch.split(t, splits, dim=1) 5931 return out, out2 5932 5933 x = torch.randn(2, 3) 5934 y = torch.randn(2, 4) 5935 t = torch.randn(2, 7) 5936 self.run_test( 5937 SplitModule(), 5938 (x, y, t), 5939 input_names=["x", "y", "t"], 5940 dynamic_axes={"x": [0, 1], "y": [0, 1], "t": [0, 1]}, 5941 ) 5942 self.run_test(SplitModule(), (x, y, t), remained_onnx_input_idx=[2]) 5943 5944 @skipIfUnsupportedMinOpsetVersion(11) 5945 def test_split_dynamic(self): 5946 class SplitModel(torch.jit.ScriptModule): 5947 @torch.jit.script_method 5948 def forward(self, input): 5949 return input.split(2)[1] 5950 5951 x = torch.randn(5, 4, 3) 5952 self.run_test(SplitModel(), x) 5953 5954 class SplitModel2(torch.jit.ScriptModule): 5955 @torch.jit.script_method 5956 def forward(self, input): 5957 return input.split(2, -3)[1] 5958 5959 x = torch.randn(5, 4, 3) 5960 self.run_test(SplitModel2(), x) 5961 5962 @skipIfUnsupportedMinOpsetVersion(11) 5963 def 
test_split_dynamic_axes(self): 5964 class Split(torch.nn.Module): 5965 def forward(self, x): 5966 return x.split(1, dim=-1) 5967 5968 x = torch.randn(4, 384, 2) 5969 input_names = ["logits"] 5970 self.run_test( 5971 Split(), 5972 x, 5973 input_names=input_names, 5974 dynamic_axes={input_names[0]: {0: "batch"}}, 5975 ) 5976 5977 @skipIfUnsupportedMinOpsetVersion(11) 5978 def test_chunk(self): 5979 class ChunkModel(torch.nn.Module): 5980 def __init__(self, dim=1): 5981 super().__init__() 5982 self.dim = dim 5983 5984 def forward(self, x): 5985 return torch.chunk(x, 3, dim=self.dim) 5986 5987 model = ChunkModel() 5988 model.eval() 5989 model_neg_dim = ChunkModel(-1) 5990 model_neg_dim.eval() 5991 x = torch.randn(1, 18) 5992 5993 for dim_size_ in range(13, 16): 5994 y = torch.randn(1, dim_size_) 5995 self.run_test( 5996 model, 5997 x, 5998 additional_test_inputs=[y], 5999 input_names=["x"], 6000 dynamic_axes={"x": {0: "batch_size", 1: "dims"}}, 6001 ) 6002 6003 self.run_test( 6004 model_neg_dim, 6005 x, 6006 additional_test_inputs=[y], 6007 input_names=["x"], 6008 dynamic_axes={"x": {0: "batch_size", 1: "dims"}}, 6009 ) 6010 6011 @skipIfUnsupportedMinOpsetVersion(11) 6012 def test_dynamic_chunk(self): 6013 class ChunkModel(torch.nn.Module): 6014 def __init__(self, dim=1): 6015 super().__init__() 6016 self.dim = dim 6017 6018 def forward(self, x): 6019 return torch.chunk(x, x.size(0), dim=self.dim) 6020 6021 model = ChunkModel() 6022 model.eval() 6023 model_neg_dim = ChunkModel(-1) 6024 model_neg_dim.eval() 6025 x = torch.randn(3, 18) 6026 6027 for dim_size_ in range(13, 16): 6028 y = torch.randn(3, dim_size_) 6029 self.run_test( 6030 model, 6031 x, 6032 additional_test_inputs=[y], 6033 input_names=["x"], 6034 dynamic_axes={"x": {0: "batch_size", 1: "dims"}}, 6035 ) 6036 6037 self.run_test( 6038 model_neg_dim, 6039 x, 6040 additional_test_inputs=[y], 6041 input_names=["x"], 6042 dynamic_axes={"x": {0: "batch_size", 1: "dims"}}, 6043 ) 6044 6045 def test_concat(self): 
6046 class ConcatModel(torch.nn.Module): 6047 def forward(self, x, y, z): 6048 return torch.cat((x, y, z)) 6049 6050 x = torch.randn(3, 4, 5) 6051 y = torch.randn(1, 4, 5) 6052 z = torch.randn(2, 4, 5) 6053 self.run_test(ConcatModel(), (x, y, z)) 6054 6055 @skipIfUnsupportedMinOpsetVersion(11) 6056 def test_concat_dynamic(self): 6057 class ConcatDynamicModel(torch.jit.ScriptModule): 6058 @torch.jit.script_method 6059 def forward(self, x): 6060 return torch.cat(x.unbind()) 6061 6062 x = torch.randn(4, 5, 6) 6063 self.run_test(ConcatDynamicModel(), x) 6064 6065 def test_stack(self): 6066 class StackModel(torch.nn.Module): 6067 def forward(self, x, y, z): 6068 return torch.stack((x, y, z), 1) 6069 6070 x = torch.randn(3, 4, 5) 6071 y = torch.randn(3, 4, 5) 6072 z = torch.randn(3, 4, 5) 6073 self.run_test(StackModel(), (x, y, z)) 6074 6075 @skipIfUnsupportedMinOpsetVersion(11) 6076 def test_stack_dynamic(self): 6077 class StackDynamicModel(torch.jit.ScriptModule): 6078 @torch.jit.script_method 6079 def forward(self, x): 6080 return torch.stack(x.unbind(), 1) 6081 6082 x = torch.randn(4, 5, 6) 6083 self.run_test(StackDynamicModel(), x) 6084 6085 def test_loop_dynamic(self): 6086 class LoopModel(torch.jit.ScriptModule): 6087 @torch.jit.script_method 6088 def forward(self, x): 6089 for i in range(x.size(2)): 6090 x = x + i 6091 return x 6092 6093 model = LoopModel() 6094 inputs = torch.zeros(1, 2, 3, dtype=torch.long) 6095 self.run_test(model, inputs) 6096 6097 @skipIfUnsupportedMinOpsetVersion(9) 6098 def test_loop_nested(self): 6099 class NestedLoopsModel(torch.jit.ScriptModule): 6100 @torch.jit.script_method 6101 def forward(self, x): 6102 for i in range(5): 6103 a = 0 6104 while a < 4: 6105 a += 1 6106 x = x + a 6107 return x 6108 6109 model = NestedLoopsModel() 6110 inputs = torch.zeros(1, 2, 3, dtype=torch.long) 6111 self.run_test(model, inputs) 6112 6113 @skipIfUnsupportedMinOpsetVersion(11) 6114 def test_loop_with_list(self): 6115 class 
ListLoopModel(torch.jit.ScriptModule): 6116 @torch.jit.script_method 6117 def forward(self, x): 6118 res = [] 6119 res1 = [] 6120 arr = x.split([3, 4, 1, 1, 2, 3, 2], 0) 6121 res2 = torch.zeros(3, 4, dtype=torch.long) 6122 res3 = [] 6123 res4 = [] 6124 for i in range(len(arr)): 6125 res.append(arr[i].sum(0, False)) 6126 res1.append(arr[-1 - i].sum(0, False)) 6127 res2 += 1 6128 res3 = res3 + [arr[i].sum(0, False)] 6129 res4 += [arr[-1 - i].sum(0, False)] 6130 return res, res1, res2, torch.stack(res3), torch.stack(res4) 6131 6132 model = ListLoopModel() 6133 inputs = torch.randn(16) 6134 self.run_test(model, inputs) 6135 6136 @skipIfUnsupportedMinOpsetVersion(11) 6137 def test_loop_transpose(self): 6138 class LoopModel(torch.nn.Module): 6139 def forward(self, x): 6140 res = torch.zeros_like(x[0]) 6141 for i in range(x.size(0)): 6142 res += x[0].transpose(0, 1) 6143 return res 6144 6145 model = torch.jit.script(LoopModel()) 6146 x = torch.randn(5, 3, 3) 6147 self.run_test(model, x) 6148 6149 @skipIfUnsupportedMinOpsetVersion(11) 6150 def test_loop_multi_dim(self): 6151 class LoopMultiDimModel(torch.jit.ScriptModule): 6152 @torch.jit.script_method 6153 def forward(self, x, y): 6154 for x_ in torch.flip(x.narrow(0, 0, 7), [0]): 6155 y = x_[0][y] 6156 return y 6157 6158 model = LoopMultiDimModel() 6159 x = torch.randint(0, 5, (8, 1, 17), dtype=torch.long) 6160 y = torch.ones(1, dtype=torch.long) 6161 self.run_test(model, (x, y)) 6162 6163 @skipIfUnsupportedMinOpsetVersion(11) 6164 def test_list(self): 6165 class ListModel(torch.jit.ScriptModule): 6166 @torch.jit.script_method 6167 def forward(self, x): 6168 tensors = x.unbind() 6169 res = [] 6170 res.append(tensors[0]) 6171 res.append(tensors[1]) 6172 res.pop(1) 6173 6174 res.insert(0, tensors[1]) 6175 res.append(tensors[2]) 6176 res += [tensors[3], tensors[4]] 6177 res = res + [tensors[5]] 6178 return torch.ones(len(res)) 6179 6180 model = ListModel() 6181 inputs = torch.randn(16, 1) 6182 self.run_test(model, inputs) 
6183 6184 @skipIfUnsupportedMinOpsetVersion(11) 6185 def test_list_append(self): 6186 class ListModel(torch.nn.Module): 6187 def forward(self, x, y): 6188 res = [] 6189 for i in range(x.size(0)): 6190 res += [torch.matmul(x[i], y)] 6191 return res 6192 6193 model = torch.jit.script(ListModel()) 6194 x = torch.randn(16, 3, 4) 6195 y = torch.randn(4, 5) 6196 self.run_test(model, (x, y)) 6197 6198 @skipIfUnsupportedMinOpsetVersion(13) 6199 def test_list_append_nested(self): 6200 class ListModel(torch.nn.Module): 6201 def forward(self, x, y): 6202 res = [] 6203 for i in range(x.size(0)): 6204 for j in range(x.size(1)): 6205 res += [torch.matmul(x[i][j], y)] 6206 return res 6207 6208 model = torch.jit.script(ListModel()) 6209 x = torch.randn(4, 4, 3, 4) 6210 y = torch.randn(4, 5) 6211 self.run_test(model, (x, y)) 6212 6213 @skipIfUnsupportedMinOpsetVersion(14) # Need onnx::Identity of sequence in opset 14 6214 def test_list_append_nested_2(self): 6215 class ListModel(torch.nn.Module): 6216 def forward(self, x): 6217 res = [] 6218 res_replicate = [] 6219 for i in range(x.size(0)): 6220 if len(res) > 2: 6221 for j in range(x.size(1)): 6222 res.append(x[i][j]) 6223 res_replicate.append(res[-1]) 6224 res.append(res_replicate[-1]) 6225 return res, res_replicate 6226 6227 model = torch.jit.script(ListModel()) 6228 x = torch.randn(4, 4, 3, 4) 6229 self.run_test(model, (x,)) 6230 6231 @skipIfUnsupportedMinOpsetVersion(13) 6232 def test_list_append_nested_mixed_dtype(self): 6233 class ListModel(torch.nn.Module): 6234 def forward(self, x, y): 6235 res = [] 6236 for i in range(x.size(0)): 6237 for j in range(x.size(1)): 6238 if i == j: 6239 res.append(x == y) 6240 else: 6241 res.append(x != y) 6242 return res 6243 6244 model = torch.jit.script(ListModel()) 6245 x = torch.randn(4, 4, 3, 4) 6246 y = torch.randn(3, 4) 6247 self.run_test(model, (x, y)) 6248 6249 @skipIfUnsupportedMinOpsetVersion(11) 6250 def test_list_pop(self): 6251 class ListModel(torch.nn.Module): 6252 def 
forward(self, x, y): 6253 res = [] 6254 for i in range(x.size(0)): 6255 res += [torch.matmul(x[i], y)] 6256 res.pop() 6257 return res 6258 6259 model = torch.jit.script(ListModel()) 6260 x = torch.randn(16, 3, 4) 6261 y = torch.randn(4, 5) 6262 self.run_test(model, (x, y)) 6263 6264 @skipIfUnsupportedMinOpsetVersion(13) 6265 def test_list_pop_nested(self): 6266 class ListModel(torch.nn.Module): 6267 def forward(self, x, y): 6268 res = [] 6269 for i in range(x.size(0)): 6270 for j in range(x.size(1)): 6271 res += [torch.matmul(x[i][j], y)] 6272 res.pop() 6273 res += [torch.matmul(x[i][0], y)] 6274 return res 6275 6276 model = torch.jit.script(ListModel()) 6277 x = torch.randn(4, 4, 3, 4) 6278 y = torch.randn(4, 5) 6279 self.run_test(model, (x, y)) 6280 6281 @skipIfUnsupportedMinOpsetVersion(11) 6282 def test_list_del(self): 6283 class ListModel(torch.nn.Module): 6284 def forward(self, x, y): 6285 res = [] 6286 for i in range(x.size(0)): 6287 res += [torch.matmul(x[i], y)] 6288 del res[2] 6289 return res 6290 6291 model = torch.jit.script(ListModel()) 6292 x = torch.randn(16, 3, 4) 6293 y = torch.randn(4, 5) 6294 self.run_test(model, (x, y)) 6295 6296 @skipIfUnsupportedMinOpsetVersion(13) 6297 def test_list_del_nested(self): 6298 class ListModel(torch.nn.Module): 6299 def forward(self, x, y): 6300 res = [] 6301 for i in range(x.size(0)): 6302 for j in range(x.size(1)): 6303 res += [torch.matmul(x[i][j], y)] 6304 del res[i] 6305 res += [torch.matmul(x[i][0], y)] 6306 return res 6307 6308 model = torch.jit.script(ListModel()) 6309 x = torch.randn(4, 4, 3, 4) 6310 y = torch.randn(4, 5) 6311 self.run_test(model, (x, y)) 6312 6313 @skipIfUnsupportedMinOpsetVersion(11) 6314 def test_list_set(self): 6315 class ListModel(torch.nn.Module): 6316 def forward(self, x, y): 6317 res = [] 6318 for i in range(x.size(0)): 6319 res.append(x[i]) 6320 res[y] = x[y] 6321 return res 6322 6323 model = torch.jit.script(ListModel()) 6324 x = torch.randn(12, 4) 6325 y = torch.tensor(2, 
dtype=torch.long) 6326 self.run_test(model, (x, y)) 6327 6328 @skipIfUnsupportedMinOpsetVersion(13) 6329 def test_list_idx_sum(self): 6330 class ListModel(torch.nn.Module): 6331 def forward(self, x, y): 6332 indices = torch.arange(x.size(0)) 6333 res = [] 6334 for i in range(x.size(0)): 6335 res.append(x[i]) 6336 return res[torch.sum(indices[:y])] 6337 6338 model = torch.jit.script(ListModel()) 6339 x = torch.randn(12, 4) 6340 y = torch.tensor(2, dtype=torch.long) 6341 self.run_test(model, (x, y)) 6342 6343 @skipIfUnsupportedMinOpsetVersion(9) 6344 def test_tensor_factories(self): 6345 class TensorFactory(torch.nn.Module): 6346 def forward(self, x): 6347 return torch.zeros(x.size()) + torch.ones(x.size()) 6348 6349 x = torch.randn(2, 3, 4) 6350 self.run_test( 6351 TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]} 6352 ) 6353 self.run_test(TensorFactory(), x, remained_onnx_input_idx=[]) 6354 6355 @skipIfUnsupportedMinOpsetVersion(9) 6356 def test_tensor_factories_script(self): 6357 class TensorFactory(torch.jit.ScriptModule): 6358 @torch.jit.script_method 6359 def forward(self, x): 6360 return torch.zeros(x.shape, dtype=torch.float) + torch.ones( 6361 x.shape, dtype=torch.float 6362 ) 6363 6364 x = torch.randn(2, 3, 4) 6365 self.run_test( 6366 TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]} 6367 ) 6368 self.run_test(TensorFactory(), x, remained_onnx_input_idx=[]) 6369 6370 @skipIfUnsupportedMinOpsetVersion(9) 6371 def test_tensor_like_factories_script(self): 6372 class TensorFactory(torch.jit.ScriptModule): 6373 @torch.jit.script_method 6374 def forward(self, x): 6375 zeros = torch.zeros_like( 6376 x, 6377 dtype=torch.float, 6378 layout=torch.strided, 6379 device=torch.device("cpu"), 6380 ) 6381 ones = torch.ones_like( 6382 x, 6383 dtype=torch.float, 6384 layout=torch.strided, 6385 device=torch.device("cpu"), 6386 ) 6387 return zeros + ones 6388 6389 x = torch.randn(2, 3, 4) 6390 self.run_test( 6391 TensorFactory(), x, 
input_names=["x"], dynamic_axes={"x": [0, 1, 2]} 6392 ) 6393 self.run_test(TensorFactory(), x, remained_onnx_input_idx=[]) 6394 6395 @skipIfUnsupportedMinOpsetVersion(13) 6396 def test_tensor_split(self): 6397 class TensorSplitModel(torch.nn.Module): 6398 def forward(self, input): 6399 return ( 6400 input.tensor_split([1, 3]), 6401 # test with output indexing. 6402 input.tensor_split([2, 4])[0], 6403 # test split on specific dim. 6404 input.tensor_split([1, 3, 4], dim=-2), 6405 # test split on specific dim and output indexing. 6406 input.tensor_split([0, 2], dim=-2)[-1], 6407 # test with out of bound end index (5). 6408 input.tensor_split([2, 3, 5]), 6409 ) 6410 6411 self.run_test(TensorSplitModel(), torch.randn(5, 4, 3)) 6412 6413 @skipIfUnsupportedMinOpsetVersion(13) 6414 def test_tensor_split_scalar(self): 6415 class TensorSplitModel(torch.nn.Module): 6416 def forward(self, x): 6417 return torch.tensor_split(x, x.size(1)) 6418 6419 self.run_test(TensorSplitModel(), torch.randn(1, 2, 3)) 6420 6421 @skipIfUnsupportedMinOpsetVersion(13) 6422 def test_tensor_split_dynamic_axes(self): 6423 class TensorSplitModel(torch.nn.Module): 6424 def forward(self, x): 6425 return x.tensor_split(1, dim=-1) 6426 6427 x = torch.randn(4, 384, 2) 6428 input_names = ["logits"] 6429 self.run_test( 6430 TensorSplitModel(), 6431 x, 6432 input_names=input_names, 6433 dynamic_axes={input_names[0]: {0: "batch"}}, 6434 ) 6435 6436 @skipIfUnsupportedMinOpsetVersion(9) 6437 def test_eye(self): 6438 class TensorFactory(torch.nn.Module): 6439 def forward(self, x): 6440 return ( 6441 torch.eye(x.size()[1], 3), 6442 torch.eye(4, 4, dtype=torch.long), 6443 torch.eye(x.size()[1], 2, dtype=torch.long), 6444 torch.eye(x.shape[0]), 6445 torch.eye(x.shape[0], dtype=torch.float64), 6446 ) 6447 6448 x = torch.randn(2, 3, 4) 6449 another_x = torch.randn(5, 6, 7) 6450 self.run_test( 6451 TensorFactory(), 6452 x, 6453 additional_test_inputs=[another_x], 6454 input_names=["input_1"], 6455 
dynamic_axes={"input_1": [0, 1, 2]}, 6456 ) 6457 6458 @skipIfUnsupportedMinOpsetVersion(13) 6459 def test_diagonal(self): 6460 class DiagonalModel(torch.nn.Module): 6461 def forward(self, x): 6462 return torch.diagonal(x) 6463 6464 x = torch.randn(2, 4, 5, 2) 6465 # Other test inputs to test dynamic behavior 6466 another_x = torch.randn(5, 6, 7, 8) 6467 self.run_test( 6468 DiagonalModel(), 6469 x, 6470 additional_test_inputs=[another_x], 6471 input_names=["input_1"], 6472 dynamic_axes={"input_1": [0, 1, 2, 3]}, 6473 ) 6474 6475 class DiagonalModelNegOffset(torch.nn.Module): 6476 def forward(self, x): 6477 return torch.diagonal(x, offset=-1) 6478 6479 x = torch.randn(2, 4, 5, 2) 6480 # Other test inputs to test dynamic behavior 6481 another_x = torch.randn(5, 6, 7, 8) 6482 self.run_test( 6483 DiagonalModelNegOffset(), 6484 x, 6485 additional_test_inputs=[another_x], 6486 input_names=["input_1"], 6487 dynamic_axes={"input_1": [0, 1, 2, 3]}, 6488 ) 6489 6490 class DiagonalModelPosOffset(torch.nn.Module): 6491 def forward(self, x): 6492 return torch.diagonal(x, offset=1) 6493 6494 x = torch.randn(2, 4, 5, 2) 6495 # Other test inputs to test dynamic behavior 6496 another_x = torch.randn(5, 6, 7, 8) 6497 self.run_test( 6498 DiagonalModelPosOffset(), 6499 x, 6500 additional_test_inputs=[another_x], 6501 input_names=["input_1"], 6502 dynamic_axes={"input_1": [0, 1, 2, 3]}, 6503 ) 6504 6505 class DiagonalModelWithDims(torch.nn.Module): 6506 def forward(self, x): 6507 return torch.diagonal(x, offset=-1, dim1=1, dim2=2) 6508 6509 x = torch.randn(2, 4, 5, 2) 6510 # Other test inputs to test dynamic behavior 6511 another_x = torch.randn(5, 6, 7, 8) 6512 self.run_test( 6513 DiagonalModelWithDims(), 6514 x, 6515 additional_test_inputs=[another_x], 6516 input_names=["input_1"], 6517 dynamic_axes={"input_1": [0, 1, 2, 3]}, 6518 ) 6519 6520 class DiagonalModelWithNegativeDims(torch.nn.Module): 6521 def forward(self, x): 6522 return torch.diagonal(x, offset=0, dim1=-2, dim2=-1) 6523 
6524 x = torch.randn(2, 4, 5, 2) 6525 # Other test inputs to test dynamic behavior 6526 another_x = torch.randn(5, 6, 7, 8) 6527 self.run_test( 6528 DiagonalModelWithNegativeDims(), 6529 x, 6530 additional_test_inputs=[another_x], 6531 input_names=["input_1"], 6532 dynamic_axes={"input_1": [0, 1, 2, 3]}, 6533 ) 6534 6535 class DiagonalModelOffsetOverrun(torch.nn.Module): 6536 def forward(self, x): 6537 return torch.diagonal(x, offset=-2), torch.diagonal(x, offset=5) 6538 6539 x = torch.randn(2, 4, 5, 2) 6540 # Other test inputs to test dynamic behavior 6541 another_x = torch.randn(5, 6, 7, 8) 6542 self.run_test( 6543 DiagonalModelOffsetOverrun(), 6544 x, 6545 additional_test_inputs=[another_x], 6546 input_names=["input_1"], 6547 dynamic_axes={"input_1": [0, 1, 2, 3]}, 6548 ) 6549 6550 @skipIfUnsupportedMinOpsetVersion(9) 6551 def test_inplace_zero(self): 6552 class Zero_(torch.nn.Module): 6553 def forward(self, x): 6554 return x.zero_(), x 6555 6556 x = torch.randn(2, 3, 4) 6557 self.run_test(Zero_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) 6558 self.run_test(Zero_(), x, remained_onnx_input_idx=[]) 6559 6560 @skipIfUnsupportedMinOpsetVersion(11) 6561 def test_inplace_zero_qkv(self): 6562 class Zero_(torch.nn.Module): 6563 def forward(self, x): 6564 return x[2:4].zero_() 6565 6566 x = torch.randn(24, 3, 4) 6567 self.run_test(Zero_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) 6568 6569 @skipIfUnsupportedMinOpsetVersion(9) 6570 def test_new_zeros(self): 6571 class Zero_(torch.nn.Module): 6572 def forward(self, x): 6573 return x.new_zeros(x.shape[1:2]), x.new_zeros( 6574 x.shape[2:], dtype=torch.long 6575 ) 6576 6577 x = torch.randn(2, 3, 4) 6578 self.run_test(Zero_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) 6579 self.run_test(Zero_(), x, remained_onnx_input_idx=[]) 6580 6581 @skipIfUnsupportedMinOpsetVersion(9) 6582 def test_new_zeros_with_dtype(self): 6583 class MyModel(torch.nn.Module): 6584 def __init__(self) -> None: 6585 
super().__init__() 6586 self.emb = torch.nn.Embedding(50, 64) 6587 6588 def forward(self, x): 6589 inp = x.new_zeros(x.shape) 6590 return self.emb(inp) 6591 6592 model = MyModel() 6593 x = torch.Tensor([[2, 5, 6], [3, 2, 5]]).to(torch.int64) 6594 self.run_test(model, x, input_names=["x"], dynamic_axes={"x": [0, 1]}) 6595 6596 @skipIfUnsupportedMinOpsetVersion(9) 6597 def test_new_ones(self): 6598 class OnesModel(torch.nn.Module): 6599 def forward(self, x): 6600 return x.new_ones(x.shape[1:2]), x.new_ones( 6601 x.shape[2:], dtype=torch.long 6602 ) 6603 6604 x = torch.randn(2, 3, 4) 6605 self.run_test(OnesModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) 6606 self.run_test(OnesModel(), x, remained_onnx_input_idx=[]) 6607 6608 @skipIfUnsupportedMinOpsetVersion(9) 6609 @skipScriptTest() # torch.zeros/torch.ones with size tensor of dim != 0 not scriptable. 6610 def test_zeros_ones_with_tensor_input(self): 6611 class ZeroAndOnes(torch.nn.Module): 6612 def forward(self, x): 6613 return torch.zeros(x, 1), torch.ones(x, 1) 6614 6615 x = torch.tensor([2]) 6616 self.run_test(ZeroAndOnes(), (x,)) 6617 6618 @skipIfUnsupportedMinOpsetVersion(9) 6619 @skipShapeChecking 6620 def test_tolist(self): 6621 class List(torch.jit.ScriptModule): 6622 @torch.jit.script_method 6623 def forward(self, input): 6624 res: List[int] = input.tolist() 6625 return res 6626 6627 self.run_test(List(), (torch.randint(100, (1,)),)) 6628 6629 @skipIfUnsupportedMinOpsetVersion(9) 6630 def test_list_pass(self): 6631 class Slice(torch.nn.Module): 6632 def forward(self, x, y): 6633 return x.new_zeros(x.shape[2:] + y.shape[1:]) 6634 6635 x = torch.randn(2, 3, 4, 5) 6636 y = torch.randn(1, 2, 3, 4) 6637 self.run_test( 6638 Slice(), 6639 (x, y), 6640 input_names=["x", "y"], 6641 dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]}, 6642 ) 6643 self.run_test(Slice(), (x, y), remained_onnx_input_idx=[]) 6644 6645 class Size(torch.nn.Module): 6646 def forward(self, x, y): 6647 return x.new_zeros(x.shape 
+ y.shape) 6648 6649 x = torch.randn(2, 3, 4) 6650 y = torch.randn(1, 2, 3) 6651 self.run_test( 6652 Size(), 6653 (x, y), 6654 input_names=["x", "y"], 6655 dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]}, 6656 ) 6657 self.run_test(Size(), (x, y), remained_onnx_input_idx=[]) 6658 6659 class Array(torch.nn.Module): 6660 def forward(self, x, y): 6661 arr1 = [x.shape[0], x.shape[1], 2] 6662 arr2 = [y.shape[0], y.shape[1]] 6663 return x.new_zeros(arr1 + arr2) 6664 6665 x = torch.randn(2, 3, 4) 6666 y = torch.randn(1, 2, 3) 6667 self.run_test( 6668 Array(), 6669 (x, y), 6670 input_names=["x", "y"], 6671 dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]}, 6672 ) 6673 self.run_test(Array(), (x, y), remained_onnx_input_idx=[]) 6674 6675 class List(torch.nn.Module): 6676 def forward(self, x, y): 6677 l1 = list(x.shape) 6678 l2 = list(y.shape) 6679 return x.new_zeros(l1 + l2) 6680 6681 x = torch.randn(2, 3, 4) 6682 y = torch.randn(1, 2, 3) 6683 self.run_test( 6684 List(), 6685 (x, y), 6686 input_names=["x", "y"], 6687 dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]}, 6688 ) 6689 self.run_test(List(), (x, y), remained_onnx_input_idx=[]) 6690 6691 @skipIfUnsupportedMinOpsetVersion(9) 6692 def test_new_empty(self): 6693 class Emtpy(torch.nn.Module): 6694 def forward(self, x): 6695 return ( 6696 x.new_empty(x.shape[0]).fill_(0), 6697 x.new_empty(x.shape[0], dtype=torch.long) * 0, 6698 ) 6699 6700 x = torch.randn(2, 3, 4) 6701 self.run_test(Emtpy(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) 6702 self.run_test(Emtpy(), x, remained_onnx_input_idx=[]) 6703 6704 @skipIfUnsupportedMinOpsetVersion(9) 6705 def test_new_full(self): 6706 class Full(torch.nn.Module): 6707 def forward(self, x): 6708 return x.new_full(x.shape[1:2], 5), x.new_full( 6709 x.shape[0:1], 1.3, dtype=torch.long 6710 ) 6711 6712 x = torch.randn(2, 3, 4) 6713 self.run_test(Full(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) 6714 self.run_test(Full(), x, remained_onnx_input_idx=[]) 6715 6716 
@skipIfUnsupportedMinOpsetVersion(9) 6717 def test_inplace_list(self): 6718 class Arithmetic(torch.jit.ScriptModule): 6719 @torch.jit.script_method 6720 def forward(self, x, y): 6721 return torch.cat([x.add_(3), y.fill_(0)]) 6722 6723 x = torch.randn(2, 3) 6724 y = torch.randn(2, 3) 6725 self.run_test( 6726 Arithmetic(), 6727 (x, y), 6728 input_names=["x", "y"], 6729 dynamic_axes={"x": [0, 1], "y": [0, 1]}, 6730 ) 6731 self.run_test(Arithmetic(), (x, y), remained_onnx_input_idx=[0]) 6732 6733 @skipIfUnsupportedMinOpsetVersion(9) 6734 def test_inplace_fill(self): 6735 class Fill_(torch.nn.Module): 6736 def forward(self, x): 6737 return x.fill_(3), x 6738 6739 x = torch.randn(2, 3, 4) 6740 self.run_test(Fill_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) 6741 self.run_test(Fill_(), x, remained_onnx_input_idx=[]) 6742 6743 def test_inplace_arithmetic(self): 6744 class Arithmetic(torch.jit.ScriptModule): 6745 @torch.jit.script_method 6746 def forward(self, x, y): 6747 x.add_(3) 6748 y.mul_(x) 6749 return x, y 6750 6751 x = torch.randn(2, 3, 4) 6752 y = torch.randn(2, 3, 4) 6753 self.run_test(Arithmetic(), (x, y)) 6754 6755 def test_inplace_arithmetic_half(self): 6756 class InplaceAddModel(torch.nn.Module): 6757 def forward(self, x, y): 6758 return x.add_(y) 6759 6760 class InplaceMulModel(torch.nn.Module): 6761 def forward(self, x, y): 6762 return x.mul_(y) 6763 6764 x = torch.randn(2, 2, dtype=torch.half) 6765 y = torch.randn(2, 2, dtype=torch.float) 6766 self.run_test(InplaceAddModel(), (x, y), rtol=1e-2, atol=1e-2) 6767 self.run_test(InplaceMulModel(), (x, y), rtol=1e-2, atol=1e-2) 6768 6769 @skipIfUnsupportedMinOpsetVersion(9) 6770 def test_inplace_with_loop(self): 6771 class M(torch.nn.Module): 6772 def forward(self, x): 6773 a = torch.ones( 6774 12, 6775 ) 6776 for i in range(10): 6777 a.add_( 6778 torch.ones( 6779 12, 6780 ) 6781 ) 6782 return a + x 6783 6784 m = M() 6785 x = torch.randn( 6786 12, 6787 ) 6788 self.run_test(torch.jit.script(M()), (x)) 
6789 6790 @skipIfUnsupportedMinOpsetVersion(9) 6791 def test_inplace_with_loop_2(self): 6792 class M(torch.nn.Module): 6793 def forward(self, x): 6794 _bias = torch.ones( 6795 12, 6796 ) 6797 a = torch.ones( 6798 12, 6799 ) # used in loop, altered. 6800 a_ref = a # not used in loop, should be altered. 6801 b = x.clone() # used in loop, not be altered. 6802 b_ref = b # not used in loop, should not be altered. 6803 for i in range(10): 6804 if i == 3: 6805 for j in range(5): 6806 a += _bias 6807 _bias.add_( 6808 torch.ones( 6809 12, 6810 ) 6811 ) 6812 b = b + torch.ones( 6813 12, 6814 ) 6815 6816 _bias.add_( 6817 torch.ones( 6818 12, 6819 ) 6820 ) 6821 a += _bias 6822 # TODO: value for a_ref is incorrect. 6823 # a_ref += torch.ones(12,) 6824 b_ref += torch.ones( 6825 12, 6826 ) 6827 return _bias + x, a, b, b_ref 6828 6829 m = M() 6830 x = torch.zeros( 6831 12, 6832 ) 6833 self.run_test(torch.jit.script(M()), (x)) 6834 6835 @skipIfUnsupportedMinOpsetVersion(11) 6836 def test_inplace_attr_with_loop(self): 6837 class M(torch.nn.Module): 6838 def __init__(self) -> None: 6839 super().__init__() 6840 self._bias = torch.arange( 6841 12, 6842 ) 6843 6844 def forward(self, x): 6845 self._bias = torch.arange( 6846 12, 6847 ) 6848 for i in range(10): 6849 if i == 3: 6850 for j in range(5): 6851 self._bias += torch.arange( 6852 12, 6853 ) 6854 return self._bias + x 6855 6856 m = M() 6857 x = torch.zeros( 6858 12, 6859 ) 6860 self.run_test(torch.jit.script(M()), (x)) 6861 6862 @skipIfUnsupportedMinOpsetVersion(11) 6863 def test_inplace_attr_copy_with_loop(self): 6864 class M(torch.nn.Module): 6865 def __init__(self) -> None: 6866 super().__init__() 6867 self._bias = torch.arange( 6868 12, 6869 ) 6870 6871 def forward(self, x): 6872 self._bias = torch.arange( 6873 12, 6874 ) 6875 for i in range(10): 6876 if i == 3: 6877 for j in range(5): 6878 self._bias.copy_( 6879 torch.arange( 6880 12, 6881 ) 6882 ) 6883 self._bias.copy_( 6884 self._bias 6885 + torch.arange( 6886 12, 6887 ) 6888 
) 6889 6890 self._bias.copy_( 6891 self._bias 6892 + torch.arange( 6893 12, 6894 ) 6895 ) 6896 return self._bias + x 6897 6898 m = M() 6899 x = torch.zeros( 6900 12, 6901 ) 6902 self.run_test(torch.jit.script(M()), (x)) 6903 6904 @skipIfUnsupportedMinOpsetVersion(14) # Need onnx::Identity of sequence in opset 14 6905 def test_inplace_sequence_with_loop(self): 6906 class M(torch.nn.Module): 6907 def process(self, beam_hyps: List[Tensor], done: Tensor, x): 6908 batch_size = x.shape[0] 6909 for i in range(batch_size): 6910 if done[i]: 6911 continue 6912 6913 beam_idx = 0 6914 for _, token in enumerate(x[i]): 6915 beam_hyps.append(token) 6916 beam_idx += 1 6917 6918 if beam_idx == 6: 6919 break 6920 6921 done[i] = len(beam_hyps) > 4 6922 6923 return beam_hyps, done 6924 6925 def forward(self, x): 6926 beam_hyps: List[Tensor] = [] 6927 batch_size = x.shape[0] 6928 cur_len = 0 6929 max_len = x.shape[1] 6930 done = torch.zeros(batch_size, dtype=torch.bool) 6931 while cur_len < max_len: 6932 beam_hyps, done = self.process(beam_hyps, done, x[:, 0, :]) 6933 cur_len = cur_len + 1 6934 6935 return beam_hyps 6936 6937 m = torch.jit.script(M()) 6938 x = torch.randn(8, 4, 3) 6939 self.run_test(torch.jit.script(M()), (x)) 6940 6941 @skipScriptTest() # Sort with dynamic dim not supported in ONNX 6942 def test_sort(self): 6943 class SortModel(torch.nn.Module): 6944 def forward(self, x): 6945 out = [] 6946 for i in range(-2, 2): 6947 out.append(torch.sort(x, dim=i, descending=True)) 6948 return out 6949 6950 x = torch.randn(3, 4) 6951 self.run_test(SortModel(), x) 6952 6953 @skipIfUnsupportedMinOpsetVersion(11) 6954 @skipScriptTest() # Sort with dynamic dim not supported in ONNX 6955 def test_sort_ascending(self): 6956 class SortModel(torch.nn.Module): 6957 def forward(self, x): 6958 out = [] 6959 for i in range(-2, 2): 6960 out.append(torch.sort(x, dim=i, descending=False)) 6961 return out 6962 6963 x = torch.randn(3, 4) 6964 self.run_test(SortModel(), x) 6965 6966 
@skipIfUnsupportedMinOpsetVersion(11) 6967 def test_argsort(self): 6968 class ArgSortModel(torch.nn.Module): 6969 def forward(self, x): 6970 return torch.argsort(x, dim=1, descending=False) 6971 6972 x = torch.randn(3, 4) 6973 self.run_test(ArgSortModel(), x) 6974 6975 @skipIfUnsupportedMinOpsetVersion(9) 6976 def test_masked_fill(self): 6977 class MaskedFillModel(torch.nn.Module): 6978 def forward(self, x): 6979 mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.bool) 6980 return x.masked_fill(mask, 2) 6981 6982 x = torch.zeros(4, 2, 3, requires_grad=True) 6983 self.run_test(MaskedFillModel(), x) 6984 6985 class MaskedFillModel2(torch.nn.Module): 6986 def forward(self, x): 6987 return x.masked_fill(x > 3, -1) 6988 6989 x = torch.arange(16).view(2, 2, 4).to(torch.float32) 6990 self.run_test(MaskedFillModel2(), x) 6991 6992 @skipIfUnsupportedMinOpsetVersion(9) 6993 def test_masked_fill_inplace(self): 6994 class MaskedFillModel(torch.jit.ScriptModule): 6995 @torch.jit.script_method 6996 def forward(self, x): 6997 mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.bool) 6998 x.masked_fill_(mask, 2) 6999 return x 7000 7001 x = torch.zeros(4, 2, 3, requires_grad=True) 7002 self.run_test(MaskedFillModel(), x) 7003 7004 class MaskedFillModel2(torch.jit.ScriptModule): 7005 @torch.jit.script_method 7006 def forward(self, x): 7007 x.masked_fill_(x > 3, -1) 7008 return x 7009 7010 x = torch.arange(16).view(2, 2, 4).to(torch.float32) 7011 self.run_test(MaskedFillModel2(), x) 7012 7013 @skipIfUnsupportedMinOpsetVersion(11) 7014 def test_masked_scatter(self): 7015 class MaskedScatterModel(torch.nn.Module): 7016 def forward(self, x): 7017 return torch.masked_scatter(x, x.ge(0.5), torch.ones(100, 100) * 5) 7018 7019 x = torch.randn(3, 4, 5, requires_grad=True) 7020 self.run_test(MaskedScatterModel(), x) 7021 7022 @skipIfUnsupportedMinOpsetVersion(11) 7023 def test_masked_select(self): 7024 class MaskedSelectModel(torch.nn.Module): 7025 def forward(self, x): 7026 return 
torch.masked_select(x, x.ge(0.5)) 7027 7028 x = torch.randn(3, 4, 5, requires_grad=True) 7029 self.run_test(MaskedSelectModel(), x) 7030 7031 @skipIfUnsupportedMinOpsetVersion(11) 7032 def test_index_put_to_masked_fill(self): 7033 class MaskedFillModel(torch.nn.Module): 7034 def forward(self, input_mask, some_const): 7035 mask = input_mask.clone() 7036 mask[mask != some_const] = 1 7037 mask[mask == some_const] = 0 7038 return mask 7039 7040 mask = torch.randn(2, 2, 2, requires_grad=True) 7041 constant = torch.tensor(5, dtype=torch.float) 7042 self.run_test(MaskedFillModel(), (mask, constant)) 7043 7044 @skipIfUnsupportedMinOpsetVersion(11) 7045 def test_index_put_to_masked_scatter(self): 7046 class MaskedScatterModel(torch.nn.Module): 7047 def forward(self, input_mask, some_const): 7048 mask = input_mask.clone() 7049 mask[mask != some_const] = torch.ones(8) 7050 return mask 7051 7052 mask = torch.randn(2, 2, 2, requires_grad=True) 7053 constant = torch.tensor(5, dtype=torch.float) 7054 self.run_test(MaskedScatterModel(), (mask, constant)) 7055 7056 @skipIfUnsupportedMinOpsetVersion(11) 7057 def test_index_put_with_1d_mask_to_masked_scatter(self): 7058 class MaskedScatterModel(torch.nn.Module): 7059 def forward(self, tensor, mask, some_const): 7060 tensor[mask] = some_const 7061 return tensor 7062 7063 mask = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool) 7064 tensor = torch.randn(8, 4, 5, requires_grad=True) 7065 some_const = torch.randn(4, 4, 5, dtype=torch.float) 7066 self.run_test(MaskedScatterModel(), (tensor, mask, some_const)) 7067 7068 @skipIfUnsupportedMinOpsetVersion(9) 7069 def test_pixel_shuffle(self): 7070 class PixelShuffle(torch.nn.Module): 7071 def forward(self, x): 7072 return torch.pixel_shuffle(x, upscale_factor=2) 7073 7074 x = torch.randn(2, 16, 4, 3, requires_grad=True) 7075 y = torch.randn(4, 32, 8, 4, requires_grad=True) 7076 self.run_test(PixelShuffle(), x) 7077 self.run_test( 7078 PixelShuffle(), 7079 x, 7080 input_names=["x"], 
7081 dynamic_axes={"x": [0, 1, 2, 3]}, 7082 additional_test_inputs=[y], 7083 ) 7084 7085 @skipIfUnsupportedMinOpsetVersion(9) 7086 def test_pixel_unshuffle(self): 7087 class PixelUnshuffle(torch.nn.Module): 7088 def forward(self, x): 7089 return torch.pixel_unshuffle(x, downscale_factor=2) 7090 7091 x = torch.randn(2, 16, 4, 6, requires_grad=True) 7092 y = torch.randn(4, 32, 8, 4, requires_grad=True) 7093 self.run_test(PixelUnshuffle(), x) 7094 self.run_test( 7095 PixelUnshuffle(), 7096 x, 7097 input_names=["x"], 7098 dynamic_axes={"x": [0, 1, 2, 3]}, 7099 additional_test_inputs=[y], 7100 ) 7101 7102 @skipIfUnsupportedMinOpsetVersion(9) 7103 def test_reciprocal(self): 7104 class ReciprocalModel(torch.nn.Module): 7105 def forward(self, x): 7106 return torch.reciprocal(x) 7107 7108 model = ReciprocalModel() 7109 x = torch.tensor([2, 4]) 7110 self.run_test(model, x.to(torch.long)) 7111 self.run_test(model, x.to(torch.float)) 7112 self.run_test(model, x.to(torch.double)) 7113 7114 @skipIfUnsupportedMinOpsetVersion(9) 7115 def test_scalar_type(self): 7116 class ArithmeticModel(torch.nn.Module): 7117 def forward(self, x): 7118 return x.size(0) * 2 * x, 2 - x 7119 7120 x = torch.ones(2, 3, dtype=torch.float32) 7121 self.run_test(ArithmeticModel(), x) 7122 7123 class ComparisonModel(torch.nn.Module): 7124 def forward(self, x, y): 7125 a = torch.tensor([12.0]) 7126 return x.lt(1.5) & y.le(2) & x.le(1), x.gt(y), x.lt(y), a.ge(x.size(0)) 7127 7128 x = torch.ones(2, 3, dtype=torch.int32) 7129 y = torch.ones(2, 3, dtype=torch.float32) 7130 self.run_test(ComparisonModel(), (x, y)) 7131 7132 class MatMulModel(torch.nn.Module): 7133 def forward(self, x): 7134 return torch.mm(x, x) + x + torch.mm(x, x) + x 7135 7136 x = torch.ones(3, 3) 7137 self.run_test(MatMulModel(), x) 7138 7139 class AddMMModel(torch.nn.Module): 7140 def forward(self, x): 7141 return torch.mm(x, x) + x 7142 7143 x = torch.ones(3, 3) 7144 self.run_test(AddMMModel(), x) 7145 7146 class 
FullModel(torch.nn.Module): 7147 # add is used for exporting full 7148 def forward(self, x): 7149 return torch.full((3, 4), x) 7150 7151 x = torch.tensor(12.0) 7152 self.run_test(FullModel(), x) 7153 7154 class CatModel(torch.nn.Module): 7155 def forward(self, fp16, fp32): 7156 return torch.cat([fp16, fp32]) 7157 7158 fp16 = Tensor([0.5]) 7159 fp16 = fp16.half() 7160 fp32 = Tensor([1.5]) 7161 self.run_test(CatModel(), (fp16, fp32)) 7162 7163 @skipIfUnsupportedMinOpsetVersion(9) 7164 def test_scalar_type_does_not_trigger_upcast_type_promotion(self): 7165 class DoNotUpcastModel(torch.nn.Module): 7166 def forward(self, x): 7167 scale = x.size()[-1] ** -0.5 7168 # 'scale' is exported as onnx float32 rank 0 tensor. 7169 # The following 'Mul' should NOT be promoted to float32. 7170 return x * scale 7171 7172 x = torch.ones(2, 3, dtype=torch.float16) 7173 self.run_test(DoNotUpcastModel(), x) 7174 7175 @skipIfUnsupportedMinOpsetVersion(9) 7176 def test_scalar_type_promotion_onnx_where_two_prim_const(self): 7177 class TwoPrimConstCastWhereModel(torch.nn.Module): 7178 def forward(self, c): 7179 return torch.where(c, 0, 1.0) 7180 7181 c = torch.ones(8, dtype=torch.bool) 7182 self.run_test(TwoPrimConstCastWhereModel(), (c)) 7183 7184 @skipIfUnsupportedMinOpsetVersion(9) 7185 def test_scalar_type_promotion_onnx_where_one_prim_const(self): 7186 class OnePrimConstCastWhereModel(torch.nn.Module): 7187 def forward(self, c, x): 7188 return torch.where(c, x, 1.0) 7189 7190 c = torch.ones(8, dtype=torch.bool) 7191 x = torch.ones(8, dtype=torch.float16) 7192 self.run_test(OnePrimConstCastWhereModel(), (c, x)) 7193 7194 @skipIfUnsupportedMinOpsetVersion(9) 7195 def test_scalar_type_promotion_onnx_where_one_tensor_const(self): 7196 class OneTensorConstCastWhereModel(torch.nn.Module): 7197 def forward(self, c, x): 7198 return torch.where(c, x, torch.ones(size=(), dtype=torch.float64)) 7199 7200 c = torch.ones(8, dtype=torch.bool) 7201 x = torch.ones(8, dtype=torch.float16) 7202 
self.run_test(OneTensorConstCastWhereModel(), (c, x)) 7203 7204 @skipIfUnsupportedMinOpsetVersion(9) 7205 def test_scalar_type_upcast_type_promotion_onnx_where_no_const(self): 7206 class OnnxWhereUpcastModel(torch.nn.Module): 7207 def forward(self, c, x, y): 7208 return torch.where(c, x, y) 7209 7210 c = torch.ones(8, dtype=torch.bool) 7211 x = torch.ones(8, dtype=torch.float16) 7212 y = torch.ones(8, dtype=torch.float32) 7213 7214 self.run_test(OnnxWhereUpcastModel(), (c, x, y)) 7215 7216 @skipIfUnsupportedMinOpsetVersion(9) 7217 def test_full_like(self): 7218 class FullLikeModel(torch.nn.Module): 7219 def forward(self, x): 7220 return torch.full_like(x, 1.3, dtype=torch.int) 7221 7222 x = torch.tensor(12) 7223 self.run_test(FullLikeModel(), x) 7224 7225 @skipIfUnsupportedMinOpsetVersion(9) 7226 @skipDtypeChecking 7227 def test_full_like_value(self): 7228 class FullLikeModel(torch.nn.Module): 7229 def forward(self, x, y): 7230 out = y + 2 7231 return torch.full_like(x, out) 7232 7233 x = torch.tensor(12) 7234 y = torch.tensor(2) 7235 self.run_test(FullLikeModel(), (x, y)) 7236 7237 def test_l1_norm(self): 7238 class NormModel(torch.nn.Module): 7239 def forward(self, x): 7240 return torch.norm(x, p=1, dim=-1, keepdim=False) 7241 7242 x = torch.randn(4, 2, 3, requires_grad=True) 7243 self.run_test(NormModel(), x) 7244 7245 def test_l2_norm(self): 7246 class NormModel(torch.nn.Module): 7247 def forward(self, x): 7248 return torch.norm(x, p=2, dim=-2, keepdim=False) 7249 7250 x = torch.randn(4, 2, 3, requires_grad=True) 7251 self.run_test(NormModel(), x) 7252 7253 def test_frobenius_norm(self): 7254 class NormModel(torch.nn.Module): 7255 def forward(self, x): 7256 return torch.norm(x, p="fro", dim=0, keepdim=False) 7257 7258 x = torch.randn(4, 2, 3, requires_grad=True) 7259 self.run_test(NormModel(), x) 7260 7261 def test_frobenius_norm_keepdim(self): 7262 class NormModel(torch.nn.Module): 7263 def forward(self, x): 7264 return torch.norm(x, p="fro", dim=(0, 1), 
keepdim=True) 7265 7266 x = torch.randn(4, 2, 3, requires_grad=True) 7267 self.run_test(NormModel(), x) 7268 7269 def test_unfold(self): 7270 class UnfoldModel(torch.nn.Module): 7271 def forward(self, x): 7272 return x.unfold(dimension=2, size=2, step=2) 7273 7274 x = torch.randn(4, 2, 3, requires_grad=True) 7275 y = torch.randn(2, 1, 3, requires_grad=True) 7276 self.run_test( 7277 UnfoldModel(), 7278 x, 7279 dynamic_axes={"x": [0, 1]}, 7280 input_names=["x"], 7281 additional_test_inputs=[y], 7282 ) 7283 7284 def test_unfold_infer_shape(self): 7285 class UnfoldModule(torch.jit.ScriptModule): 7286 def __init__(self) -> None: 7287 super().__init__() 7288 self.conv = torch.nn.Conv1d(3, 1, 3, stride=2) 7289 7290 @torch.jit.script_method 7291 def forward(self, x): 7292 x = self.conv(x) 7293 return x.unfold(dimension=2, size=2, step=2) 7294 7295 x = torch.randn(32, 3, 64) 7296 self.run_test(UnfoldModule(), x) 7297 7298 @skipIfUnsupportedMinOpsetVersion(12) 7299 def test_unfold_dynamic_inputs(self): 7300 class UnfoldModel(torch.nn.Module): 7301 def forward(self, x): 7302 return x.unfold(dimension=2, size=x.shape[1], step=x.shape[1] - 1) 7303 7304 x = torch.randn(4, 2, 4, requires_grad=True) 7305 self.run_test(UnfoldModel(), x) 7306 7307 class UnfoldModel(torch.nn.Module): 7308 def forward(self, x): 7309 return x.unfold(dimension=2, size=x.shape[1], step=1) 7310 7311 x = torch.randn(4, 2, 4, requires_grad=True) 7312 self.run_test(UnfoldModel(), x) 7313 7314 @skipIfUnsupportedMinOpsetVersion(9) # MatMul long inputs is added in ONNX opset 9. 
7315 def test_mv(self): 7316 class MatmulModel(torch.nn.Module): 7317 def forward(self, input, other): 7318 return torch.mv(input, other) 7319 7320 x = torch.randn(4, 5, requires_grad=True) 7321 y = torch.randn(5, requires_grad=True) 7322 self.run_test(MatmulModel(), (x, y)) 7323 7324 x = torch.randint(10, (4, 5)) 7325 y = torch.randint(10, (5,)) 7326 self.run_test(MatmulModel(), (x, y)) 7327 7328 @skipIfUnsupportedMinOpsetVersion(9) # MatMul long inputs is added in ONNX opset 9. 7329 def test_dot(self): 7330 class MatmulModel(torch.nn.Module): 7331 def forward(self, input, other): 7332 return torch.dot(input, other) 7333 7334 x = torch.randn(5, requires_grad=True) 7335 y = torch.randn(5, requires_grad=True) 7336 self.run_test(MatmulModel(), (x, y)) 7337 7338 x = torch.randint(10, (5,)) 7339 y = torch.randint(10, (5,)) 7340 self.run_test(MatmulModel(), (x, y)) 7341 7342 @skipScriptTest() # SpectralNorm not TorchScript compatible. 7343 def test_spectral_norm(self): 7344 m = torch.nn.utils.spectral_norm(torch.nn.Linear(2, 4)) 7345 7346 x = torch.randn(6, 2) 7347 self.run_test(m, (x,)) 7348 7349 def test_prelu(self): 7350 class PReluModel(torch.nn.Module): 7351 def __init__(self) -> None: 7352 super().__init__() 7353 self.prelu = torch.nn.PReLU() 7354 7355 def forward(self, x): 7356 return self.prelu(x) 7357 7358 x = torch.randn(2, 3, 4) 7359 y = torch.randn(2, 4, 5) 7360 self.run_test( 7361 PReluModel(), 7362 x, 7363 input_names=["x"], 7364 dynamic_axes={"x": [1, 2]}, 7365 additional_test_inputs=[y], 7366 ) 7367 7368 def test_prelu_scalar(self): 7369 x = torch.scalar_tensor(1.0) 7370 self.run_test(torch.nn.PReLU(), x, input_names=["x"]) 7371 7372 def test_relu6(self): 7373 class Relu6Model(torch.nn.Module): 7374 def __init__(self) -> None: 7375 super().__init__() 7376 self.relu6 = torch.nn.ReLU6() 7377 7378 def forward(self, x): 7379 return self.relu6(x) 7380 7381 x = torch.randn(2, 3, 4) * 100.0 7382 y = torch.randn(2, 4, 5) * 100.0 7383 self.run_test( 7384 
Relu6Model(), 7385 x, 7386 input_names=["x"], 7387 dynamic_axes={"x": [1, 2]}, 7388 additional_test_inputs=[y], 7389 ) 7390 7391 def test_silu(self): 7392 class SiLUModel(torch.nn.Module): 7393 def __init__(self) -> None: 7394 super().__init__() 7395 self.silu = torch.nn.SiLU() 7396 7397 def forward(self, x): 7398 return self.silu(x) 7399 7400 x = torch.randn(2, 3, 4) 7401 self.run_test(SiLUModel(), (x)) 7402 7403 @skipIfUnsupportedMinOpsetVersion(14) 7404 def test_tril(self): 7405 class trilModel(torch.nn.Module): 7406 def forward(self, x): 7407 return torch.tril(x) 7408 7409 x = torch.randn(2, 3, 4) 7410 self.run_test(trilModel(), (x)) 7411 7412 class trilModelwithDiagonal(torch.nn.Module): 7413 def forward(self, x): 7414 return torch.tril(x, diagonal=1) 7415 7416 x = torch.randn(2, 3, 4) 7417 self.run_test(trilModelwithDiagonal(), (x)) 7418 7419 class trilModelwithNegDiagonal(torch.nn.Module): 7420 def forward(self, x): 7421 return torch.tril(x, diagonal=-1) 7422 7423 x = torch.randn(2, 3, 4) 7424 self.run_test(trilModelwithNegDiagonal(), (x)) 7425 7426 class trilModelWithDiagonalInput(torch.nn.Module): 7427 def forward(self, x, diagnonal: int): 7428 return torch.tril(x, diagonal=diagnonal) 7429 7430 x = torch.randn(2, 3, 4) 7431 self.run_test(trilModelWithDiagonalInput(), (x, 5)) 7432 7433 @skipIfUnsupportedMinOpsetVersion(14) 7434 def test_triu(self): 7435 class triuModel(torch.nn.Module): 7436 def forward(self, x): 7437 return torch.triu(x) 7438 7439 x = torch.randn(2, 3, 4) 7440 self.run_test(triuModel(), (x)) 7441 7442 class triuModelwithDiagonal(torch.nn.Module): 7443 def forward(self, x): 7444 return torch.triu(x, diagonal=1) 7445 7446 x = torch.randn(2, 3, 4) 7447 self.run_test(triuModelwithDiagonal(), (x)) 7448 7449 class triuModelwithNegDiagonal(torch.nn.Module): 7450 def forward(self, x): 7451 return torch.triu(x, diagonal=-1) 7452 7453 x = torch.randn(2, 3, 4) 7454 self.run_test(triuModelwithNegDiagonal(), (x)) 7455 7456 class 
triuModelWithDiagonalInput(torch.nn.Module): 7457 def forward(self, x, diagnonal: int): 7458 return torch.triu(x, diagonal=diagnonal) 7459 7460 x = torch.randn(2, 3, 4) 7461 self.run_test(triuModelWithDiagonalInput(), (x, 5)) 7462 7463 def test_mish(self): 7464 class MishModel(torch.nn.Module): 7465 def __init__(self) -> None: 7466 super().__init__() 7467 self.mish = torch.nn.Mish() 7468 7469 def forward(self, x): 7470 return self.mish(x) 7471 7472 x = torch.randn(2, 3, 4) 7473 self.run_test(MishModel(), (x)) 7474 7475 def test_remainder(self): 7476 class RemainderModel(torch.nn.Module): 7477 def forward(self, input, other): 7478 return torch.remainder(input, other) 7479 7480 x = torch.randn(4, 2, 3) 7481 y = torch.randn(1, 2, 1) 7482 self.run_test(RemainderModel(), (x, y)) 7483 7484 x = torch.tensor([7, 6, -7, -6], dtype=torch.long) 7485 y = torch.tensor([2], dtype=torch.long) 7486 self.run_test(RemainderModel(), (x, y)) 7487 7488 x = x.to(torch.float) 7489 self.run_test(RemainderModel(), (x, y)) 7490 7491 y = y.to(torch.float) 7492 self.run_test(RemainderModel(), (x, y)) 7493 7494 x = x.to(torch.int32) 7495 self.run_test(RemainderModel(), (x, y)) 7496 7497 def test_remainder_scalar(self): 7498 class RemainderModel(torch.nn.Module): 7499 def __init__(self, scalar=2.55): 7500 super().__init__() 7501 self.scalar = scalar 7502 7503 def forward(self, input): 7504 return torch.remainder(input, self.scalar) 7505 7506 x = torch.randint(10, (2, 3)) 7507 self.run_test(RemainderModel(), x) 7508 7509 x = torch.tensor([7, 6, -7, -6], dtype=torch.long) 7510 self.run_test(RemainderModel(2), x) 7511 7512 @skipIfUnsupportedMinOpsetVersion(10) 7513 def test_fmod(self): 7514 class FModModel(torch.nn.Module): 7515 def forward(self, input, other): 7516 return torch.fmod(input, other) 7517 7518 x = torch.randn(4, 2, 3) 7519 y = torch.randn(1, 2, 1) 7520 self.run_test(FModModel(), (x, y)) 7521 7522 @skipIfUnsupportedMinOpsetVersion(10) 7523 def test_fmod_scalar(self): 7524 class 
FModModel(torch.nn.Module): 7525 def forward(self, input): 7526 return torch.fmod(input, 2.55) 7527 7528 x = torch.randint(10, (2, 3)) 7529 self.run_test(FModModel(), x) 7530 7531 @skipIfUnsupportedMinOpsetVersion(9) 7532 def test_glu(self): 7533 class GluModel(torch.nn.Module): 7534 def forward(self, x): 7535 return torch.nn.functional.glu(x) 7536 7537 x = torch.randn(2, 4, 5, 6, requires_grad=True) 7538 self.run_test(GluModel(), x) 7539 7540 @skipIfUnsupportedMinOpsetVersion(9) 7541 def test_gelu(self): 7542 class GeluModel(torch.nn.Module): 7543 def forward(self, x): 7544 return torch.nn.functional.gelu(x, approximate="none") 7545 7546 x = torch.randn(2, 4, 5, 6, requires_grad=True) 7547 self.run_test(GeluModel(), x) 7548 7549 @skipIfUnsupportedMinOpsetVersion(9) 7550 def test_tanh_gelu(self): 7551 class GeluModel(torch.nn.Module): 7552 def forward(self, x): 7553 return torch.nn.functional.gelu(x, approximate="tanh") 7554 7555 x = torch.randn(2, 4, 5, 6, requires_grad=True) 7556 self.run_test(GeluModel(), x) 7557 7558 def test_add_inplace(self): 7559 class InplaceAddModel(torch.nn.Module): 7560 def forward(self, x): 7561 x += 12 7562 return x 7563 7564 x = torch.randn(4, 2, 3, requires_grad=True) 7565 self.run_test(InplaceAddModel(), x) 7566 7567 def test_addcmul(self): 7568 class AddcmulModel(torch.nn.Module): 7569 def forward(self, x, t1, t2): 7570 return torch.addcmul(x, t1, t2), torch.addcmul(x, t1, t2, value=2.2) 7571 7572 x = torch.randn(1, 3) 7573 t1 = torch.randn(3, 1) 7574 t2 = torch.randn(1, 3) 7575 self.run_test(AddcmulModel(), (x, t1, t2)) 7576 7577 def test_rsqrt(self): 7578 class RsqrtModel(torch.nn.Module): 7579 def forward(self, x): 7580 return x.rsqrt() 7581 7582 x = torch.randn(4, 2, 3, requires_grad=True, dtype=torch.float64) 7583 self.run_test(RsqrtModel(), x) 7584 7585 def test_rsqrt_zeros(self): 7586 class RsqrtModel(torch.nn.Module): 7587 def forward(self, x): 7588 return x.rsqrt() 7589 7590 x = torch.zeros(4, 2, 3, requires_grad=True, 
dtype=torch.float64) 7591 self.run_test(RsqrtModel(), x) 7592 7593 @skipIfUnsupportedMinOpsetVersion(11) 7594 def test_unique(self): 7595 class UniqueModel(torch.nn.Module): 7596 def forward(self, x): 7597 return torch.unique( 7598 x, sorted=True, return_inverse=False, return_counts=True 7599 ) 7600 7601 x = torch.tensor([1, 3, 2, 3], dtype=torch.long) 7602 self.run_test(UniqueModel(), x) 7603 7604 @skipIfUnsupportedMinOpsetVersion(11) 7605 def test_unique_along_dim(self): 7606 class UniqueModel(torch.nn.Module): 7607 def forward(self, x): 7608 return torch.unique( 7609 x, dim=0, sorted=True, return_inverse=True, return_counts=False 7610 ) 7611 7612 x = torch.tensor([1, 3, 2, 3], dtype=torch.long) 7613 self.run_test(UniqueModel(), x) 7614 7615 @skipIfUnsupportedMinOpsetVersion(11) 7616 def test_cumsum(self): 7617 class CumSum(torch.nn.Module): 7618 def forward(self, input): 7619 return torch.cumsum(input, dim=0) 7620 7621 x = torch.randn(2, 3, 4) 7622 model = CumSum() 7623 self.run_test(model, x) 7624 7625 @skipIfUnsupportedMinOpsetVersion(11) 7626 def test_cumsum_with_cast(self): 7627 class CumSum(torch.nn.Module): 7628 def forward(self, input): 7629 return torch.cumsum(input, dim=0, dtype=torch.float32) 7630 7631 model = CumSum() 7632 x = torch.tensor([2, 3, 4], dtype=torch.int32) 7633 self.run_test(model, x) 7634 x = torch.tensor([False, True, True]) 7635 self.run_test(model, x) 7636 7637 @skipScriptTest() # error in propagate as assign input shape 7638 @skipIfUnsupportedMinOpsetVersion(10) 7639 def test_embedding_bag(self): 7640 model = torch.nn.EmbeddingBag(10, 5, mode="sum", scale_grad_by_freq=True) 7641 input = torch.randint(10, (7,)) 7642 offset = torch.tensor([0, 2, 5, 6]) 7643 self.run_test(model, (input, offset)) 7644 7645 model = torch.nn.EmbeddingBag(10, 5, mode="sum", include_last_offset=True) 7646 input = torch.randint(10, (7,)) 7647 offset = torch.tensor([0, 2, 5, 6]) 7648 self.run_test(model, (input, offset)) 7649 7650 model = 
torch.nn.EmbeddingBag(10, 5, mode="max") 7651 input = torch.randint(10, (7, 5)) 7652 self.run_test(model, (input)) 7653 7654 @skipIfUnsupportedMinOpsetVersion(11) 7655 def test_embedding_bag_1d_per_sample_weights(self): 7656 class EmbeddingModel(torch.nn.Module): 7657 def forward(self, embedding_matrix, input, offset, weights): 7658 return torch.nn.functional.embedding_bag( 7659 input, 7660 embedding_matrix, 7661 offsets=offset, 7662 mode="sum", 7663 per_sample_weights=weights, 7664 ) 7665 7666 model = EmbeddingModel() 7667 x = torch.randint(7, (6,)) 7668 w = torch.randn( 7669 6, 7670 ) 7671 offset = torch.tensor([0, 2, 5]) 7672 embedding_matrix = torch.rand(10, 15) 7673 self.run_test(model, (embedding_matrix, x, offset, w)) 7674 7675 @skipIfUnsupportedMinOpsetVersion(11) 7676 @unittest.skip( 7677 "This test is broken with ONNXRuntime(17): " 7678 "when running with onnxruntime 1.17.0 this test fails with the following error:" 7679 "FAIL : Non-zero status code returned while running If node. 
" 7680 "Name:'/If' Status Message: if.cc:253 Compute " 7681 "If nodes condition input must have exactly one element" 7682 "https://github.com/pytorch/pytorch/issues/119442" 7683 ) 7684 def test_embedding_bag_2d_per_sample_weights(self): 7685 class EmbeddingModel(torch.nn.Module): 7686 def forward(self, embedding_matrix, input, weights): 7687 return torch.nn.functional.embedding_bag( 7688 input, embedding_matrix, mode="sum", per_sample_weights=weights 7689 ) 7690 7691 embedding_matrix = torch.rand(10, 15) 7692 model = EmbeddingModel() 7693 x = torch.randint(7, (2, 3)) 7694 w = torch.randn(2, 3) 7695 7696 x2 = torch.randint(7, (4, 3)) 7697 w2 = torch.randn(4, 3) 7698 self.run_test( 7699 model, 7700 (embedding_matrix, x, w), 7701 input_names=["embed", "x", "w"], 7702 dynamic_axes={"x": [0], "w": [0]}, 7703 additional_test_inputs=[(embedding_matrix, x2, w2)], 7704 ) 7705 7706 @skipScriptTest() # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast 7707 @skipIfUnsupportedMinOpsetVersion(11) 7708 @unittest.skip( 7709 "Due to ONNX Loop shape inference issue. 
" 7710 "https://msdata.visualstudio.com/Vienna/_workitems/edit/1352001" 7711 ) 7712 def test_embedding_bag_dynamic_input(self): 7713 class EmbeddingModel1D(torch.nn.Module): 7714 def forward(self, embedding_matrix, input, weights, offsets): 7715 return torch.nn.functional.embedding_bag( 7716 input, 7717 embedding_matrix, 7718 offsets=offsets, 7719 mode="sum", 7720 per_sample_weights=weights, 7721 ) 7722 7723 model = EmbeddingModel1D() 7724 x = torch.randint(7, (6,)) 7725 w = torch.randn( 7726 6, 7727 ) 7728 offsets = torch.tensor([0, 2, 5], dtype=torch.long) 7729 embedding_matrix = torch.rand(10, 15) 7730 x2 = torch.randint(7, (2,)) 7731 w2 = torch.randn( 7732 2, 7733 ) 7734 embedding_matrix2 = torch.rand(12, 25) 7735 offsets2 = torch.tensor( 7736 [ 7737 0, 7738 ], 7739 dtype=torch.long, 7740 ) 7741 self.run_test( 7742 model, 7743 (embedding_matrix, x, w, offsets), 7744 additional_test_inputs=[(embedding_matrix2, x2, w2, offsets2)], 7745 input_names=["embedding_matrix", "x", "offsets", "w"], 7746 dynamic_axes={ 7747 "embedding_matrix": [0, 1], 7748 "x": [0], 7749 "offsets": [0], 7750 "w": [0], 7751 }, 7752 ) 7753 7754 class EmbeddingModel2D(torch.nn.Module): 7755 def forward(self, embedding_matrix, input, weights): 7756 return torch.nn.functional.embedding_bag( 7757 input, embedding_matrix, mode="sum", per_sample_weights=weights 7758 ) 7759 7760 model = EmbeddingModel2D() 7761 x = torch.randint(7, (2, 3)) 7762 w = torch.randn(2, 3) 7763 embedding_matrix = torch.rand(10, 15) 7764 x2 = torch.randint(7, (3, 5)) 7765 w2 = torch.randn(3, 5) 7766 embedding_matrix2 = torch.rand(12, 25) 7767 self.run_test( 7768 model, 7769 (embedding_matrix, x, w), 7770 additional_test_inputs=[(embedding_matrix2, x2, w2)], 7771 input_names=["embedding_matrix", "x", "w"], 7772 dynamic_axes={"embedding_matrix": [0, 1], "x": [0, 1], "w": [0, 1]}, 7773 ) 7774 7775 @skipIfUnsupportedMinOpsetVersion(8) 7776 def test_meshgrid(self): 7777 class Meshgrid(torch.nn.Module): 7778 def forward(self, x, 
y, z): 7779 output1, output2, output3 = torch.meshgrid(x, y, z) 7780 return output1, output2, output3 7781 7782 x = torch.randn(3, requires_grad=True) 7783 y = torch.zeros(4, requires_grad=True) 7784 z = torch.randn(5, requires_grad=True) 7785 self.run_test(Meshgrid(), (x, y, z)) 7786 7787 @skipIfUnsupportedMinOpsetVersion(8) 7788 def test_meshgrid_indexing(self): 7789 class Meshgrid(torch.nn.Module): 7790 def __init__(self, indexing): 7791 super().__init__() 7792 self.indexing = indexing 7793 7794 def forward(self, x, y, z): 7795 output1, output2, output3 = torch.meshgrid( 7796 x, y, z, indexing=self.indexing 7797 ) 7798 return output1, output2, output3 7799 7800 x = torch.randn(5, requires_grad=True) 7801 y = torch.zeros(6, requires_grad=True) 7802 z = torch.randn(7, requires_grad=True) 7803 for indexing in ("xy", "ij"): 7804 self.run_test(Meshgrid(indexing), (x, y, z)) 7805 7806 @skipIfUnsupportedMinOpsetVersion(8) 7807 def test_meshgrid_scalar(self): 7808 class Meshgrid(torch.nn.Module): 7809 def forward(self, x, y, z): 7810 output1, output2, output3 = torch.meshgrid(x, y, z) 7811 return output1, output2, output3 7812 7813 x = torch.ones(3, requires_grad=True) 7814 y = torch.zeros(4, requires_grad=True) 7815 z = torch.tensor(2.0) 7816 self.run_test(Meshgrid(), (x, y, z)) 7817 7818 def test_baddbmm(self): 7819 class MyModule(torch.nn.Module): 7820 def forward(self, input, batch1, batch2): 7821 return torch.baddbmm( 7822 input, batch1, batch2, alpha=torch.tensor(5), beta=3.5 7823 ) 7824 7825 x = torch.randn(10, 3, 5) 7826 batch1 = torch.randn(10, 3, 4) 7827 batch2 = torch.randn(10, 4, 5) 7828 model = MyModule() 7829 self.run_test(model, (x, batch1, batch2)) 7830 7831 def test_baddbmm_dynamic(self): 7832 class MyModule(torch.nn.Module): 7833 def forward(self, input, batch1, batch2, alpha, beta): 7834 return torch.baddbmm(input, batch1, batch2, alpha=alpha, beta=beta) 7835 7836 x = torch.randn(10, 3, 5) 7837 batch1 = torch.randn(10, 3, 4) 7838 batch2 = 
torch.randn(10, 4, 5) 7839 alpha = torch.tensor(5) 7840 beta = torch.tensor(3.5) 7841 model = MyModule() 7842 self.run_test(model, (x, batch1, batch2, alpha, beta)) 7843 7844 def test_numel(self): 7845 class MyModule(torch.nn.Module): 7846 def forward(self, input): 7847 return input.numel() * input 7848 7849 x = torch.randn(2, 3, 5) 7850 x2 = torch.randn(4, 5, 6) 7851 model = MyModule() 7852 self.run_test( 7853 model, 7854 (x,), 7855 input_names=["x"], 7856 dynamic_axes={"x": [0, 1, 2]}, 7857 additional_test_inputs=[(x2,)], 7858 ) 7859 7860 def test_numel_empty(self): 7861 class MyModule(torch.nn.Module): 7862 def forward(self, input): 7863 return input.numel() * input 7864 7865 x = torch.randn(0) 7866 x2 = torch.randn(4) 7867 model = MyModule() 7868 self.run_test( 7869 model, 7870 (x,), 7871 input_names=["x"], 7872 dynamic_axes={"x": [0]}, 7873 additional_test_inputs=[(x2,)], 7874 ) 7875 7876 def test_dtype(self): 7877 class MyModel(torch.jit.ScriptModule): 7878 @torch.jit.script_method 7879 def forward(self, input, other): 7880 return input.to(dtype=other.dtype) + other 7881 7882 x = torch.randn(2, 3) 7883 y = torch.randn(2, 3) 7884 self.run_test(MyModel(), (x, y)) 7885 7886 def test_dtype_eq(self): 7887 class MyModel(torch.jit.ScriptModule): 7888 @torch.jit.script_method 7889 def forward(self, input, other): 7890 if input.dtype == other.dtype: 7891 return input + other 7892 return input 7893 7894 x = torch.randn(2, 3) 7895 y = torch.randn(2, 3) 7896 self.run_test(MyModel(), (x, y)) 7897 7898 def test_cast_to(self): 7899 class MyModule(torch.jit.ScriptModule): 7900 @torch.jit.script_method 7901 def forward(self, input, other): 7902 return input.to(other) + other 7903 7904 x = torch.randn(2, 3, 4) 7905 y = torch.tensor([1], dtype=torch.int64) 7906 model = MyModule() 7907 self.run_test(model, (x, y)) 7908 7909 def test_cast_to_bool(self): 7910 class MyModule(torch.nn.Module): 7911 def forward(self, input, other): 7912 return torch.cat((input.to(other), other), 0) 
7913 7914 x = torch.randn(2, 3, 4) 7915 y = torch.zeros([2, 3, 4], dtype=torch.bool) 7916 model = MyModule() 7917 self.run_test(model, (x, y)) 7918 7919 # ONNX supports bfloat16 for opsets >= 13 7920 @skipIfUnsupportedMinOpsetVersion(13) 7921 def test_cast_type_as_with_bfloat16(self): 7922 class MyModule(torch.nn.Module): 7923 def forward(self, x): 7924 y = torch.ones((3, 4), dtype=torch.bfloat16) 7925 x = x.type_as(y) 7926 return x.to(dtype=torch.float16) 7927 7928 x = torch.ones(3, 4, dtype=torch.float16) 7929 model = MyModule() 7930 self.run_test(model, x) 7931 7932 @skipIfUnsupportedMinOpsetVersion(9) 7933 def test_type_as(self): 7934 class MyModule(torch.nn.Module): 7935 def forward(self, x): 7936 y = torch.tensor([1.0]) 7937 return x.type_as(y) 7938 7939 a = torch.tensor([True, False], dtype=torch.bool) 7940 b = torch.randn(3, 4, dtype=torch.double) 7941 c = torch.ones((2, 2), dtype=torch.int64) 7942 model = MyModule() 7943 self.run_test(model, a) 7944 self.run_test(model, b) 7945 self.run_test(model, c) 7946 7947 @skipIfUnsupportedMinOpsetVersion(9) 7948 def test_ones_bool(self): 7949 class MyModule(torch.nn.Module): 7950 def forward(self, input): 7951 true = torch.ones(input.shape, dtype=torch.bool) 7952 return input.to(true) & true 7953 7954 x = torch.randn(2, 3, 4) 7955 model = MyModule() 7956 self.run_test(model, x) 7957 7958 def test_log(self): 7959 class Log(torch.nn.Module): 7960 def forward(self, input): 7961 return torch.log(input) 7962 7963 x = torch.rand(2, 3, 4) 7964 model = Log() 7965 self.run_test(model, x) 7966 7967 def test_log1p(self): 7968 class Log1p(torch.nn.Module): 7969 def forward(self, input): 7970 return torch.log1p(input) 7971 7972 x = torch.rand(2, 3, 4) 7973 model = Log1p() 7974 self.run_test(model, x) 7975 7976 def test_log10(self): 7977 class Log10(torch.nn.Module): 7978 def forward(self, input): 7979 return torch.log10(input) 7980 7981 x = torch.rand(2, 3, 4) 7982 model = Log10() 7983 self.run_test(model, x) 7984 7985 def 
test_log2(self): 7986 class Log2(torch.nn.Module): 7987 def forward(self, input): 7988 return torch.log2(input) 7989 7990 x = torch.tensor(1.0) 7991 model = Log2() 7992 self.run_test(model, x) 7993 7994 @skipIfUnsupportedMinOpsetVersion(11) 7995 def test_round(self): 7996 class Round(torch.nn.Module): 7997 def forward(self, x): 7998 return torch.round(x) 7999 8000 x = torch.tensor([0.9920, -1.0362, -1.5000, 3.5000], requires_grad=True) 8001 self.run_test(Round(), x) 8002 8003 int_x = torch.tensor([9920, 1036, -1500, 35], dtype=torch.int32) 8004 self.run_test(Round(), int_x) 8005 8006 @skipIfUnsupportedMinOpsetVersion(11) 8007 def test_round_with_decimals(self): 8008 class Round(torch.nn.Module): 8009 def __init__(self, decimals): 8010 super().__init__() 8011 self.decimals = decimals 8012 8013 def forward(self, x): 8014 return torch.round(x, decimals=self.decimals) 8015 8016 x = torch.tensor([0.9920, -1234.0362, -1.58960, 3.5000]) 8017 for decimals in (0, -2, 3): 8018 self.run_test(Round(decimals), x) 8019 8020 @skipIfUnsupportedMinOpsetVersion(17) 8021 def test_stft_default(self): 8022 class STFT(torch.nn.Module): 8023 def forward(self, x): 8024 n_fft = 16 8025 return torch.stft(x, n_fft=n_fft, center=False, return_complex=False) 8026 8027 x = torch.randn((1, 32), requires_grad=True) 8028 self.run_test(STFT(), x, atol=1e-6) 8029 8030 @skipIfUnsupportedMinOpsetVersion(17) 8031 def test_stft_hop_length(self): 8032 class STFT(torch.nn.Module): 8033 def forward(self, x): 8034 n_fft = 16 8035 hop_length = 4 8036 return torch.stft( 8037 x, 8038 n_fft=n_fft, 8039 center=False, 8040 hop_length=hop_length, 8041 return_complex=False, 8042 ) 8043 8044 x = torch.randn((1, 32), requires_grad=True) 8045 self.run_test(STFT(), x, atol=1e-6) 8046 8047 @skipIfUnsupportedMinOpsetVersion(17) 8048 def test_stft_non_divisible_hop_length(self): 8049 class STFT(torch.nn.Module): 8050 def forward(self, x): 8051 n_fft = 16 8052 hop_length = 5 8053 return torch.stft( 8054 x, 8055 
n_fft=n_fft, 8056 center=False, 8057 hop_length=hop_length, 8058 return_complex=False, 8059 ) 8060 8061 x = torch.randn((1, 32), requires_grad=True) 8062 self.run_test(STFT(), x, atol=1e-6) 8063 8064 @skipIfUnsupportedMinOpsetVersion(17) 8065 def test_stft_window_int_same_size(self): 8066 class STFT(torch.nn.Module): 8067 def forward(self, x): 8068 n_fft = 16 8069 win_length = 16 8070 return torch.stft( 8071 x, 8072 n_fft=n_fft, 8073 center=False, 8074 win_length=win_length, 8075 return_complex=False, 8076 ) 8077 8078 x = torch.randn((1, 32), requires_grad=True) 8079 self.run_test(STFT(), x, atol=1e-6) 8080 8081 @skipIfUnsupportedMinOpsetVersion(17) 8082 def test_stft_window_int_different_size(self): 8083 class STFT(torch.nn.Module): 8084 def forward(self, x): 8085 n_fft = 16 8086 win_length = 9 8087 return torch.stft( 8088 x, 8089 n_fft=n_fft, 8090 center=False, 8091 win_length=win_length, 8092 return_complex=False, 8093 ) 8094 8095 x = torch.randn((1, 32), requires_grad=True) 8096 self.run_test(STFT(), x, atol=1e-6) 8097 8098 @skipIfUnsupportedMinOpsetVersion(17) 8099 def test_stft_window_custom(self): 8100 class STFT(torch.nn.Module): 8101 def forward(self, x): 8102 n_fft = 16 8103 window = torch.hann_window(16) 8104 return torch.stft( 8105 x, 8106 n_fft=n_fft, 8107 center=False, 8108 window=window, 8109 return_complex=False, 8110 ) 8111 8112 x = torch.randn((1, 32), requires_grad=True) 8113 self.run_test(STFT(), x, atol=1e-6) 8114 8115 @skipIfUnsupportedMinOpsetVersion(17) 8116 def test_stft_wrong_custom_window_size(self): 8117 class STFT(torch.nn.Module): 8118 def forward(self, x): 8119 n_fft = 16 8120 window = torch.hann_window(10) 8121 return torch.stft( 8122 x, n_fft=n_fft, window=window, center=False, return_complex=False 8123 ) 8124 8125 x = torch.randn((1, 32), requires_grad=True) 8126 with self.assertRaises((AssertionError, RuntimeError)): 8127 self.run_test(STFT(), x) 8128 8129 @skipIfUnsupportedMinOpsetVersion(17) 8130 def 
test_stft_wrong_window_length(self): 8131 class STFT(torch.nn.Module): 8132 def forward(self, x): 8133 n_fft = 16 8134 win_len = 17 8135 return torch.stft( 8136 x, 8137 n_fft=n_fft, 8138 win_length=win_len, 8139 center=False, 8140 return_complex=False, 8141 ) 8142 8143 x = torch.randn((1, 32), requires_grad=True) 8144 with self.assertRaises(RuntimeError): 8145 self.run_test(STFT(), x) 8146 8147 @skipIfUnsupportedMinOpsetVersion(17) 8148 def test_stft_window_size_with_win_len(self): 8149 class STFT(torch.nn.Module): 8150 def forward(self, x): 8151 n_fft = 16 8152 window = torch.hann_window(10) 8153 win_len = 10 8154 return torch.stft( 8155 x, 8156 n_fft=n_fft, 8157 window=window, 8158 win_length=win_len, 8159 center=False, 8160 return_complex=False, 8161 ) 8162 8163 x = torch.randn((1, 32), requires_grad=True) 8164 self.run_test(STFT(), x, atol=1e-6) 8165 8166 @skipIfUnsupportedMinOpsetVersion(17) 8167 def test_stft_one_dimension(self): 8168 class STFT(torch.nn.Module): 8169 def forward(self, x): 8170 n_fft = 16 8171 return torch.stft( 8172 x, 8173 n_fft=n_fft, 8174 center=False, 8175 return_complex=False, 8176 ) 8177 8178 x = torch.randn((32), requires_grad=True) 8179 self.run_test(STFT(), x, atol=1e-6) 8180 8181 @skipIfUnsupportedMinOpsetVersion(17) 8182 def test_stft_wrong_input_size(self): 8183 class STFT(torch.nn.Module): 8184 def forward(self, x): 8185 n_fft = 16 8186 return torch.stft(x, n_fft=n_fft, center=False, return_complex=False) 8187 8188 x = torch.randn((1, 1, 32), requires_grad=True) 8189 with self.assertRaises(RuntimeError): 8190 self.run_test(STFT(), x) 8191 8192 @skipIfUnsupportedMinOpsetVersion(17) 8193 def test_stft_wrong_return_complex(self): 8194 class STFT(torch.nn.Module): 8195 def forward(self, x): 8196 n_fft = 16 8197 return torch.stft(x, n_fft=n_fft, center=False, return_complex=True) 8198 8199 x = torch.randn((1, 32), requires_grad=True) 8200 with self.assertRaises(errors.SymbolicValueError): 8201 self.run_test(STFT(), x) 8202 8203 
@skipIfUnsupportedMinOpsetVersion(17)
def test_stft_normalize(self):
    """Export ``torch.stft`` with ``normalized=True`` on a 1-D, 32-sample signal."""

    class STFT(torch.nn.Module):
        def forward(self, x):
            return torch.stft(
                x,
                n_fft=16,
                center=False,
                normalized=True,
                return_complex=False,
            )

    signal = torch.randn(32, requires_grad=True)
    self.run_test(STFT(), signal, atol=1e-6)

@skipIfUnsupportedMinOpsetVersion(17)
def test_stft_not_onesided(self):
    """Export ``torch.stft`` with ``onesided=False`` on a 1-D, 32-sample signal."""

    class STFT(torch.nn.Module):
        def forward(self, x):
            return torch.stft(
                x,
                n_fft=16,
                center=False,
                onesided=False,
                return_complex=False,
            )

    signal = torch.randn(32, requires_grad=True)
    self.run_test(STFT(), signal, atol=1e-6)

def test_constant_pad(self):
    """Export ``ConstantPad1d`` / ``ConstantPad2d`` with a non-zero fill value (3.5)."""
    # (module, input shape) pairs; inputs are drawn in the same order as the
    # modules are exercised so the random stream is consumed identically.
    cases = (
        (torch.nn.ConstantPad1d(2, 3.5), (2, 4, 4)),
        (torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5), (2, 2, 4, 4)),
    )
    for pad_module, input_shape in cases:
        self.run_test(pad_module, torch.randn(*input_shape))

@common_utils.parametrize(
    "pad",
    [
        common_utils.subtest([2, 4], name="scalar_list"),
        common_utils.subtest(
            [
                torch.tensor(2, dtype=torch.int64),
                torch.tensor(4, dtype=torch.int64),
            ],
            name="scalar_tensor_list",
        ),
    ],
)
@skipIfUnsupportedMinOpsetVersion(11)  # Dynamic padding is added in opset 11
def test_pad_types(self, pad):
    """Export ``F.pad`` whose pad sizes are a list of ints or of 0-d int64 tensors."""

    class Pad(torch.nn.Module):
        def forward(self, x, pad: List[int]):
            return torch.nn.functional.pad(x, pad)

    model_inputs = (torch.randn(2, 2, 4, 4), pad)
    self.run_test(Pad(), model_inputs)

@skipIfUnsupportedMinOpsetVersion(11)
def test_pad_circular(self):
    """Export circular ``F.pad`` over the last two dimensions of a 4-D input."""

    class PadModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.pad(x, (1, 2, 1, 2), mode="circular")

    self.run_test(PadModel(), torch.randn(2, 3, 3, 4))
@skipIfUnsupportedMinOpsetVersion(11)
def test_pad_circular_negative(self):
    """Export circular ``F.pad`` with negative pad amounts on the last dimension."""
    # Negative pad amounts remove elements rather than add them, so this
    # exercises the cropping side of circular padding (size 6 -> 3 here).
    class PadModel(torch.nn.Module):
        def forward(self, x):
            out = torch.nn.functional.pad(x, (-1, -2), mode="circular")
            return out

    x = torch.randn(2, 3, 6)
    self.run_test(PadModel(), x)

@skipIfUnsupportedMinOpsetVersion(11)
def test_pad_circular_dynamic_axes(self):
    """Export circular ``F.pad`` with every input axis marked dynamic."""

    class PadModel(torch.nn.Module):
        def forward(self, x):
            out = torch.nn.functional.pad(x, (2, 1, 2, 1), mode="circular")
            return out

    x = torch.randn(4, 3, 5, 6)
    self.run_test(
        PadModel(),
        x,
        input_names=["input_1"],
        dynamic_axes={"input_1": [0, 1, 2, 3]},
    )

@skipIfUnsupportedMaxOpsetVersion(10)
@skipScriptTest()  # TODO: the logic in symbolic_opset9 doesn't handle script
def test_unsupported_pad(self):
    """Pad sizes given as a model input must be rejected on the older opsets.

    Only runs for the opsets allowed by ``skipIfUnsupportedMaxOpsetVersion(10)``;
    export is expected to raise a ``RuntimeError`` about non-constant padding.
    """

    class Pad(torch.nn.Module):
        def forward(self, x, pad: List[int]):
            return torch.nn.functional.pad(x, pad)

    x = torch.randn(2, 2, 4, 4)
    y = [2, 4]

    with self.assertRaisesRegex(
        RuntimeError,
        (
            "Unsupported: ONNX export of Pad.*"
            + "The sizes of the padding must be constant"
        ),
    ):
        self.run_test(Pad(), (x, y))

@skipIfUnsupportedMinOpsetVersion(9)
def test_if_fold(self):
    """Export ``if`` statements whose conditions depend only on rank/numel/dtype.

    Each variant branches on ``dim()``, ``numel()`` or ``dtype`` of the inputs,
    covering ==, !=, <, <=, >=, ``and`` combinations and early returns; the
    test name suggests the exporter is expected to fold these branches.
    """

    class IfFoldModel(torch.nn.Module):
        def forward(self, y):
            if y.dim() == 2:
                y = y + 4
                y = y + 2
            else:
                y = y - 1
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    self.run_test(IfFoldModel(), x)

    class IfFoldModel(torch.nn.Module):
        def forward(self, y):
            if y.numel() > 1:
                y = y + 4
            else:
                y = y + 2
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    self.run_test(IfFoldModel(), x)

    class IfFoldModel(torch.nn.Module):
        def forward(self, y):
            if y.dim() != 3:
                y = y + 4
                y = y + 2
            else:
                return y
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    self.run_test(IfFoldModel(), x)

    class IfFoldModel(torch.nn.Module):
        def forward(self, y):
            if y.dim() >= 1:
                y = y + 4
            else:
                y = y - 1
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    self.run_test(IfFoldModel(), x)

    class IfFoldModel(torch.nn.Module):
        def forward(self, y):
            if y.dim() <= 1:
                y = y + 4
            else:
                y = y + 2
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    self.run_test(IfFoldModel(), x)

    class IfFoldModel(torch.nn.Module):
        def forward(self, y):
            if y.dim() < 3 and y.dtype == torch.int:
                y = y + 4
                y = y + 2
            else:
                return y
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    self.run_test(IfFoldModel(), x)

    class IfFoldModel(torch.nn.Module):
        def forward(self, y):
            if y.dim() == 3 and y.dtype == torch.int:
                y = y + 4
                y = y + 2
            else:
                y = y + 1
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    self.run_test(IfFoldModel(), x)

    class IfFoldModel(torch.nn.Module):
        def forward(self, y):
            if y.numel() != 0 and y.dim() == 2:
                y = y + 4
                y = y + 2
            else:
                return y
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    self.run_test(IfFoldModel(), x)

    class IfFoldModel(torch.nn.Module):
        def forward(self, x, y):
            if x.numel() == y.numel():
                y = x + y
            else:
                y = y - x
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    y = torch.ones((3, 4), dtype=torch.int)
    self.run_test(IfFoldModel(), (x, y))

    class IfFoldModel(torch.nn.Module):
        def forward(self, x, y):
            if x.numel() != y.numel():
                y = x + y
            else:
                y = y - x
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    y = torch.ones((3, 4), dtype=torch.int)
    self.run_test(IfFoldModel(), (x, y))

@skipIfUnsupportedMinOpsetVersion(11)
def test_uninitialized(self):
    """Nested ``if`` whose inner ``else`` returns early (static input shape)."""

    class UninitializedModel(torch.nn.Module):
        def forward(self, y):
            if y.shape[1] < 5:
                if y.size(0) == 1:
                    y = y + 4
                else:
                    return y
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    self.run_test(UninitializedModel(), x)

@skipIfUnsupportedMinOpsetVersion(11)
def test_uninitialized_dynamic(self):
    """Same early-return model as ``test_uninitialized`` with dynamic axes."""

    class UninitializedModel(torch.nn.Module):
        def forward(self, y):
            if y.shape[1] < 5:
                if y.size(0) == 1:
                    y = y + 4
                else:
                    return y
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    y = torch.ones((6, 7), dtype=torch.int)
    self.run_test(
        UninitializedModel(),
        x,
        additional_test_inputs=[y],
        input_names=["input_1"],
        dynamic_axes={"input_1": [0, 1]},
    )

# onnx::Identity of sequence supported for ONNX opset >= 14
@skipIfUnsupportedMinOpsetVersion(14)
def test_uninitialized_tensorList(self):
    """Early-return model that yields a single-element tensor list (scripted)."""

    class UninitializedTensorListModel(torch.nn.Module):
        def forward(self, x):
            if x[0].shape[0] < 5:
                if x.size(0) == 1:
                    x = x + 4
                else:
                    return [x]
            return [x]

    x = torch.ones((3, 4), dtype=torch.int)
    self.run_test(torch.jit.script(UninitializedTensorListModel()), x)

# onnx::Identity of sequence supported for ONNX opset >= 14
@skipIfUnsupportedMinOpsetVersion(14)
def test_uninitialized_tensorList_dynamic(self):
    """Early-return model that unbinds the input via ``list(x)``, dynamic axes."""

    class UninitializedTensorListModel(torch.nn.Module):
        def forward(self, x):
            if x[0].shape[0] < 5:
                if x.size(0) == 1:
                    x += x
                else:
                    return list(x)
            return list(x)

    x = torch.ones((3, 4), dtype=torch.double)
    self.run_test(
        torch.jit.script(UninitializedTensorListModel()),
        x,
        input_names=["input_1"],
        dynamic_axes={"input_1": [0, 1]},
    )

# onnx::Identity of sequence supported for ONNX opset >= 14
@skipIfUnsupportedMinOpsetVersion(14)
def test_uninitialized_intList(self):
    """Early-return model producing a ``List[int]`` built from ``range``."""

    class UninitializedListModel(torch.nn.Module):
        def forward(self, x):
            y = list(range(x.size(0)))
            if y[0] < 5:
                # if x.size(0) != 3, ORT will throw type error.
                if x.size(0) == 3:
                    y.append(10)
                else:
                    return y
            return y

    x = torch.ones((3, 4), dtype=torch.int)
    self.run_test(
        torch.jit.script(UninitializedListModel()),
        x,
        input_names=["input_1"],
        dynamic_axes={"input_1": [0, 1]},
    )

# onnx::Identity of sequence supported for ONNX opset >= 14
@skipIfUnsupportedMinOpsetVersion(14)
def test_uninitialized_tensorList_shape(self):
    """Early-return branches returning tensor lists of different lengths."""

    class UninitializedModel(torch.nn.Module):
        def forward(self, x):
            if x.shape[1] < 5:
                if x.size(0) == 1:
                    x = x + 4
                else:
                    x_list = list(x)
                    x_list.append(x)
                    return x_list
            return [x, x]

    x = torch.ones((3, 4), dtype=torch.int)
    y = torch.ones((4, 6), dtype=torch.int)
    self.run_test(
        torch.jit.script(UninitializedModel()),
        x,
        additional_test_inputs=[y],
        input_names=["input_1"],
        dynamic_axes={"input_1": [0, 1]},
    )

# Sequence type as loop-carried dependencies only supported for ONNX opset >= 13
@skipIfUnsupportedMinOpsetVersion(13)
def test_sequance_loopcarried(self):
    """A tensor list grown inside a loop (loop-carried sequence), then stacked.

    The misspelled method name is kept so existing test IDs stay stable.
    """

    class SequenceLoopModel(torch.nn.Module):
        def forward(self, x):
            outputs = []
            for i in range(3):
                outputs += [x]
            return torch.stack(outputs).transpose(0, 1)

    x = torch.ones((3, 4), dtype=torch.int)
    self.run_test(torch.jit.script(SequenceLoopModel()), x)

def test_reflection_pad(self):
    """Export ``ReflectionPad1d`` and ``ReflectionPad2d``."""
    model = torch.nn.ReflectionPad1d(2)
    x = torch.randn(2, 4, 4)
    self.run_test(model, x)

    model = torch.nn.ReflectionPad2d((3, 0, 2, 1))
    x = torch.randn(2, 2, 4, 4)
    self.run_test(model, x)
def test_replication_pad(self):
    """Export ``ReplicationPad1d`` and ``ReplicationPad2d``."""
    for pad_module, input_shape in (
        (torch.nn.ReplicationPad1d(2), (2, 4, 4)),
        (torch.nn.ReplicationPad2d((3, 0, 2, 1)), (2, 2, 4, 4)),
    ):
        self.run_test(pad_module, torch.randn(*input_shape))

@skipIfUnsupportedMinOpsetVersion(11)
def test_im2col(self):
    """Export ``F.unfold`` with three kernel/dilation/padding/stride combos."""

    class Unfold(torch.nn.Module):
        def forward(self, input):
            return (
                torch.nn.functional.unfold(
                    input, kernel_size=(10, 15), dilation=2, padding=5, stride=3
                ),
                torch.nn.functional.unfold(
                    input, kernel_size=(2, 2), dilation=1, padding=0, stride=3
                ),
                torch.nn.functional.unfold(
                    input, kernel_size=(1, 1), dilation=5, padding=2, stride=3
                ),
            )

    image = torch.rand(1, 1, 200, 100)
    self.run_test(Unfold(), image)

@skipIfNoLapack
@skipIfUnsupportedMinOpsetVersion(11)
def test_det(self):
    """Export ``torch.linalg.det`` on a batch of 5x5 matrices."""

    class Det(torch.nn.Module):
        def forward(self, x):
            return torch.linalg.det(x)

    self.run_test(Det(), torch.randn(2, 3, 5, 5))

def test_linalg_norm(self):
    """Export ``torch.linalg.norm`` across ``ord`` values, with and without ``dim``."""

    class LinalgSingleDimModel(torch.nn.Module):
        def __init__(self, ord_val):
            super().__init__()
            self.ord = ord_val

        def forward(self, x):
            return torch.linalg.norm(x, ord=self.ord, dim=1)

    class LinalgMultiDimModel(torch.nn.Module):
        def __init__(self, ord_val):
            super().__init__()
            self.ord = ord_val

        def forward(self, x):
            return torch.linalg.norm(x, ord=self.ord, dim=(0, 2))

    class LinalgNoDimNoOrdModel(torch.nn.Module):
        def forward(self, x):
            return torch.linalg.norm(x)

    class LinalgNoDim1DModel(torch.nn.Module):
        def __init__(self, ord_val):
            super().__init__()
            self.ord = ord_val

        def forward(self, x):
            return torch.linalg.norm(x, ord=self.ord)

    class LinalgNoDim2DModel(torch.nn.Module):
        def __init__(self, ord_val):
            super().__init__()
            self.ord = ord_val

        def forward(self, x):
            return torch.linalg.norm(x, ord=self.ord)

    inf = float("inf")
    # ord values paired with a single reduced dim / 1-D input.
    single_dim_ords = (None, 2, inf, -inf, -4, 1.5)
    # ord values paired with two reduced dims / 2-D input.
    two_dim_ords = ("fro", inf, -inf, 1, -1)

    # Inputs are drawn right before each group so the random stream is
    # consumed in the same order as the exports run.
    x = torch.randn(2, 3, 5, 5)
    for ord_val in single_dim_ords:
        self.run_test(LinalgSingleDimModel(ord_val), x)

    x = torch.randn(2, 3, 5, 5)
    for ord_val in two_dim_ords:
        self.run_test(LinalgMultiDimModel(ord_val), x)

    for shape in ((2, 3, 5, 5), (2, 3), (2,)):
        self.run_test(LinalgNoDimNoOrdModel(), torch.randn(*shape))

    x = torch.randn(2)
    for ord_val in single_dim_ords:
        self.run_test(LinalgNoDim1DModel(ord_val), x)

    x = torch.randn(2, 3)
    for ord_val in two_dim_ords:
        self.run_test(LinalgNoDim2DModel(ord_val), x)

@skipIfUnsupportedMinOpsetVersion(11)
def test_linalg_vector_norm_zero(self):
    """Export ``torch.linalg.vector_norm`` with ``ord=0`` over a 4-D input."""

    class LinalgVectorNormModel(torch.nn.Module):
        def __init__(self, ord_val):
            super().__init__()
            self.ord = ord_val

        def forward(self, x):
            return torch.linalg.vector_norm(x, ord=self.ord)

    self.run_test(LinalgVectorNormModel(0), torch.randn(2, 3, 5, 5))
def test_linalg_vector_norm(self):
    # Every combination of `ord` and (dim, keepdim) for vector_norm.
    class LinalgVectorNormModel(torch.nn.Module):
        def __init__(self, ord_val, dim_info):
            super().__init__()
            self.ord = ord_val
            self.dim, self.keepdim = dim_info

        def forward(self, x):
            return torch.linalg.vector_norm(
                x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
            )

    x = torch.randn(2, 3, 5, 5)
    ord_options = [2, float("inf"), -float("inf"), -4, 1.5]
    dim_options = [(None, False), (1, False), ((1, 2), False), ((1, 2), True)]
    for ord_val in ord_options:
        for dim_info in dim_options:
            self.run_test(LinalgVectorNormModel(ord_val, dim_info), x)


def test_linalg_matrix_norm(self):
    # Each `ord` with default dims, dims (0, 2), and dims (0, 2) + keepdim.
    class LinalgMatrixNormModel(torch.nn.Module):
        def __init__(self, ord_val, dim_val=(-2, -1), keepdim_val=False):
            super().__init__()
            self.ord = ord_val
            self.dim = dim_val
            self.keepdim = keepdim_val

        def forward(self, x):
            return torch.linalg.matrix_norm(
                x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
            )

    x = torch.randn(2, 3, 5, 5)
    ord_options = ["fro", float("inf"), -float("inf"), 1, -1]
    for ord_val in ord_options:
        self.run_test(LinalgMatrixNormModel(ord_val), x)
        self.run_test(LinalgMatrixNormModel(ord_val, (0, 2)), x)
        self.run_test(LinalgMatrixNormModel(ord_val, (0, 2), True), x)


@skipIfUnsupportedMinOpsetVersion(9)
def test_linalg_cross(self):
    # Explicit dim=1 and default dim; y broadcasts against x.
    class Cross(torch.nn.Module):
        def forward(self, x, y):
            return torch.linalg.cross(x, y, dim=1), torch.linalg.cross(x, y)

    x = torch.randn(5, 3, 2, 3)
    y = torch.randn(1, 3, 1, 3)
    self.run_test(Cross(), input_args=(x, y))


# This test checks output scalar type in the ONNX graph should not be null
# https://github.com/pytorch/pytorch/issues/28607
@skipIfUnsupportedMinOpsetVersion(10)
def test_trace_script(self):
    @torch.jit.script
    def center_slice_helper(input, h_offset):
        return input[:, h_offset:]

    class CenterCrop(torch.nn.Module):
        def forward(self, input):
            return center_slice_helper(input, torch.tensor(input.shape[1] - 1))

    x = torch.randn(3, 4)
    self.run_test(CenterCrop(), x)


@skipIfNoLapack
@skipIfUnsupportedMinOpsetVersion(11)
def test_logdet(self):
    class LogDet(torch.nn.Module):
        def forward(self, x):
            return torch.logdet(x)

    x = torch.randn(2, 3, 5, 5)
    self.run_test(LogDet(), x)


def test_dim(self):
    # Scripted use of Tensor.dim(), including on an empty 1-D input.
    class DimModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            out = input * 2
            out *= out.dim()
            return out

    empty_input = torch.randn(0, requires_grad=True)
    multi_dim_input = torch.randn(1, 2, 3, requires_grad=True)
    self.run_test(DimModel(), empty_input)
    self.run_test(DimModel(), multi_dim_input)


@skipIfUnsupportedMinOpsetVersion(11)
def test_dim_1(self):
    # Iterating over a tensor in script and collecting results into a list,
    # with a dynamic batch axis on the input.
    class M(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, poses):
            boxes = torch.zeros([poses.shape[0], 2, 4])
            batch_boxes = []
            for kp_boxes in boxes:
                kp_boxes = torchvision.ops.clip_boxes_to_image(kp_boxes, (2, 3))
                batch_boxes.append(kp_boxes)
            return batch_boxes

    dummy_inputs = torch.rand(2, 2, 3)
    self.run_test(M(), (dummy_inputs,), input_names=["x"], dynamic_axes={"x": [0]})


@skipIfUnsupportedMinOpsetVersion(12)
@skipDtypeChecking
def test_outer(self):
    # torch.outer with same and mixed input dtypes.
    class Outer(torch.nn.Module):
        def forward(self, x, y):
            return torch.outer(x, y)

    x = torch.arange(1, 5)
    y = torch.arange(1, 4)
    self.run_test(Outer(), input_args=(x, y))

    x = torch.arange(1, 6).to(dtype=torch.float32)
    y = torch.arange(1, 4).to(dtype=torch.long)
    self.run_test(Outer(), input_args=(x, y))

    x = torch.arange(2, 5).to(dtype=torch.float32)
    y = torch.arange(2, 4).to(dtype=torch.float64)
    self.run_test(Outer(), input_args=(x, y))

    x = torch.arange(3, 6).to(dtype=torch.int32)
    y = torch.arange(4, 7).to(dtype=torch.long)
    self.run_test(Outer(), input_args=(x, y))


@skipIfUnsupportedMinOpsetVersion(9)
def test_movedim(self):
    class MovedimModel(torch.nn.Module):
        def forward(self, x):
            return (
                x.movedim(1, 3),
                x.movedim(2, 0),
                x.movedim(1, 1),
                x.movedim((1, 2, 3), (3, 0, 1)),
                x.movedim((0, 1, 2), (1, 2, 3)),
                x.movedim((1, 3, 2), (1, 3, 2)),
            )

    x = torch.randn(5, 3, 4, 2)

    self.run_test(MovedimModel(), x)


@skipIfUnsupportedMinOpsetVersion(9)
def test_moveaxis(self):
    # moveaxis is an alias of movedim; thus, mostly copied from `test_movedim`.
    class MoveaxisModel(torch.nn.Module):
        def forward(self, x):
            return (
                x.moveaxis(1, 3),
                x.moveaxis(2, 0),
                x.moveaxis(1, 1),
                x.moveaxis((1, 2, 3), (3, 0, 1)),
                x.moveaxis((0, 1, 2), (1, 2, 3)),
                x.moveaxis((1, 3, 2), (1, 3, 2)),
            )

    x = torch.randn(5, 3, 4, 2)

    self.run_test(MoveaxisModel(), x)


@skipIfUnsupportedMinOpsetVersion(12)
def test_einsum(self):
    # Batch diagonal, batch matmul, inner product and transpose equations;
    # the unary cases also run on bool inputs.
    class EinsumModelBatchDiagonal(torch.nn.Module):
        def forward(self, x):
            eqn = "...ii ->...i"
            return torch.einsum(eqn, x)

    for x in [torch.randn(3, 5, 5), torch.randn(3, 5, 5).to(dtype=torch.bool)]:
        self.run_test(EinsumModelBatchDiagonal(), input_args=(x,))

    class EinsumModelBatchMatmul(torch.nn.Module):
        def forward(self, x, y):
            eqn = "bij, bjk -> bik"
            return torch.einsum(eqn, x, y)

    x = torch.randn(5, 2, 3)
    y = torch.randn(5, 3, 4)
    self.run_test(EinsumModelBatchMatmul(), input_args=(x, y))

    class EinsumModelInnerProd(torch.nn.Module):
        def forward(self, x, y):
            eqn = "i,i"
            return torch.einsum(eqn, x, y)

    x = torch.randn(5)
    y = torch.randn(5)
    self.run_test(EinsumModelInnerProd(), input_args=(x, y))

    class EinsumModelTranspose(torch.nn.Module):
        def forward(self, x):
            eqn = "ij->ji"
            return torch.einsum(eqn, x)

    for x in [torch.randn(3, 4), torch.randn(3, 4).to(dtype=torch.bool)]:
        self.run_test(EinsumModelTranspose(), input_args=(x,))


@skipIfUnsupportedMinOpsetVersion(9)
def test_cosine_similarity(self):
    x = torch.randn(5, 3, 2)
    y = torch.randn(5, 3, 2)
    self.run_test(torch.nn.CosineSimilarity(dim=2), input_args=(x, y))


@skipIfUnsupportedMinOpsetVersion(9)
def test_pairwise_distance(self):
    x = torch.randn(5, 3, 2)
    y = torch.randn(5, 3, 2)
    self.run_test(torch.nn.PairwiseDistance(p=2.0), input_args=(x, y))


@skipIfUnsupportedMinOpsetVersion(9)
def test_cross(self):
    # torch.cross with an explicit dim and with the dim left implicit.
    class Cross(torch.nn.Module):
        def forward(self, x, y):
            return torch.cross(x, y, dim=3), torch.cross(x, y)

    x = torch.randn(5, 3, 2, 3)
    y = torch.randn(5, 3, 2, 3)
    self.run_test(Cross(), input_args=(x, y))


@skipIfUnsupportedMinOpsetVersion(9)
def test_cdist(self):
    class Cdist(torch.nn.Module):
        def forward(self, x, y):
            return torch.cdist(x, y)

    x = torch.randn(5, 3, 3)
    y = torch.randn(5, 2, 3)
    self.run_test(Cdist(), input_args=(x, y))


@skipIfUnsupportedMinOpsetVersion(12)
def test_crossentropyloss(self):
    # Targets equal to 1 are replaced with ignore_index. -100 is the loss's
    # default; with ignore_index=1 the replacement is a no-op, but those
    # entries are still ignored by the loss.
    for ignore_index in [-100, 1]:
        x = torch.randn(3, 5)
        y = torch.empty(3, dtype=torch.long).random_(5)
        y[y == 1] = ignore_index

        self._crossentropyloss(x, y, ignore_index)

        x = torch.randn(3, 5, 2)
        y = torch.empty(3, 2, dtype=torch.long).random_(5)
        y[y == 1] = ignore_index
        self._crossentropyloss(x, y, ignore_index)

        x = torch.randn(3, 5, 2, 7)
        y = torch.empty(3, 2, 7, dtype=torch.long).random_(5)
        y[y == 1] = ignore_index
        self._crossentropyloss(x, y, ignore_index)


def _crossentropyloss(self, x, y, ignore_index):
    """Export CrossEntropyLoss with none/sum/mean reductions, each with and
    without random class weights (5 classes), on logits ``x`` and targets ``y``.

    When ``ignore_index`` is -100 (the loss's default) it is not passed
    explicitly, so the default code path is exercised.
    """

    class CrossEntropyLossNone(torch.nn.Module):
        def __init__(self, ignore_index):
            super().__init__()
            if ignore_index == -100:
                self.loss = torch.nn.CrossEntropyLoss(reduction="none")
            else:
                self.loss = torch.nn.CrossEntropyLoss(
                    reduction="none", ignore_index=ignore_index
                )

        def forward(self, input, target):
            return self.loss(input, target)

    self.run_test(CrossEntropyLossNone(ignore_index), input_args=(x, y))

    class CrossEntropyLossNoneWeight(torch.nn.Module):
        def __init__(self, ignore_index):
            super().__init__()
            if ignore_index == -100:
                self.loss = torch.nn.CrossEntropyLoss(
                    reduction="none", weight=torch.randn(5)
                )
            else:
                self.loss = torch.nn.CrossEntropyLoss(
                    reduction="none",
                    weight=torch.randn(5),
                    ignore_index=ignore_index,
                )

        def forward(self, input, target):
            return self.loss(input, target)

    self.run_test(CrossEntropyLossNoneWeight(ignore_index), input_args=(x, y))

    class CrossEntropyLossSum(torch.nn.Module):
        def __init__(self, ignore_index):
            super().__init__()
            if ignore_index == -100:
                self.loss = torch.nn.CrossEntropyLoss(reduction="sum")
            else:
                self.loss = torch.nn.CrossEntropyLoss(
                    reduction="sum", ignore_index=ignore_index
                )

        def forward(self, input, target):
            return self.loss(input, target)

    self.run_test(CrossEntropyLossSum(ignore_index), input_args=(x, y))

    class CrossEntropyLossSumWeight(torch.nn.Module):
        def __init__(self, ignore_index):
            super().__init__()
            if ignore_index == -100:
                self.loss = torch.nn.CrossEntropyLoss(
                    reduction="sum", weight=torch.randn(5)
                )
            else:
                self.loss = torch.nn.CrossEntropyLoss(
                    reduction="sum",
                    weight=torch.randn(5),
                    ignore_index=ignore_index,
                )

        def forward(self, input, target):
            return self.loss(input, target)

    self.run_test(CrossEntropyLossSumWeight(ignore_index), input_args=(x, y))

    class CrossEntropyLossMean(torch.nn.Module):
        def __init__(self, ignore_index):
            super().__init__()
            if ignore_index == -100:
                self.loss = torch.nn.CrossEntropyLoss()
            else:
                self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)

        def forward(self, input, target):
            return self.loss(input, target)

    self.run_test(CrossEntropyLossMean(ignore_index), input_args=(x, y))

    class CrossEntropyLossMeanWeight(torch.nn.Module):
        def __init__(self, ignore_index):
            super().__init__()
            if ignore_index == -100:
                self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5))
            else:
                self.loss = torch.nn.CrossEntropyLoss(
                    weight=torch.randn(5), ignore_index=ignore_index
                )

        def forward(self, input, target):
            return self.loss(input, target)

    self.run_test(CrossEntropyLossMeanWeight(ignore_index), input_args=(x, y))


@skipIfUnsupportedMinOpsetVersion(9)
def test_MSELoss(self):
    # All three reductions exported from one model.
    class MSELoss(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.loss1 = torch.nn.MSELoss(reduction="none")
            self.loss2 = torch.nn.MSELoss(reduction="sum")
            self.loss3 = torch.nn.MSELoss(reduction="mean")

        def forward(self, input, target):
            return (
                self.loss1(input, target),
                self.loss2(input, target),
                self.loss3(input, target),
            )

    x = torch.randn(2, 3, 5)
    y = torch.randn(2, 3, 5)
    self.run_test(MSELoss(), input_args=(x, y))


@skipIfUnsupportedMinOpsetVersion(9)
def test_kldiv_loss(self):
    # Inputs are log-probabilities (rand().log()); targets are in [0, 1).
    x = torch.rand(5).log()
    y = torch.rand(5)
    self._kldiv_loss(x, y)

    x = torch.rand(2, 3, 5).log()
    y = torch.rand(2, 3, 5)
    self._kldiv_loss(x, y)

    x = torch.rand(2, 3, 5, 7).log()
    y = torch.rand(2, 3, 5, 7)
    self._kldiv_loss(x, y)


def _kldiv_loss(self, x, y):
    """Export KLDivLoss across reductions and log_target settings.

    Models built with ``log_target=True`` take ``target.log()`` so that all
    variants see the same underlying target distribution.
    """

    class KLDivLossNone(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.loss = torch.nn.KLDivLoss(reduction="none", log_target=True)

        def forward(self, input, target):
            return self.loss(input, target.log())

    self.run_test(KLDivLossNone(), input_args=(x, y))

    class KLDivLossMean(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.loss = torch.nn.KLDivLoss(reduction="mean", log_target=False)

        def forward(self, input, target):
            return self.loss(input, target)

    self.run_test(KLDivLossMean(), input_args=(x, y))

    class KLDivLossSum(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.loss = torch.nn.KLDivLoss(reduction="sum", log_target=True)

        def forward(self, input, target):
            return self.loss(input, target.log())

    self.run_test(KLDivLossSum(), input_args=(x, y))

    class KLDivLossBatchMean(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=False)

        def forward(self, input, target):
            return self.loss(input, target)

    self.run_test(KLDivLossBatchMean(), input_args=(x, y))

    class KLDivLossMiniBatchMean(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # NOTE(review): the deprecated `size_average` argument overrides
            # `reduction` per the KLDivLoss docs, so this most likely exports
            # a "sum" reduction rather than "batchmean" — confirm intent.
            self.loss = torch.nn.KLDivLoss(
                reduction="batchmean", size_average=False, log_target=True
            )

        def forward(self, input, target):
            return self.loss(input, target.log())

    self.run_test(KLDivLossMiniBatchMean(), input_args=(x, y))


@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss(self):
    class NLLModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="none")
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            output = self.loss(self.m(2 * input), target)
            return output

    N, C = 5, 4
    input = torch.randn(N, 16)
    target = torch.empty(N, dtype=torch.long).random_(0, C)

    # using test data containing default ignore_index=-100
    target[target == 1] = -100
    self.run_test(NLLModel(), (input, target))


@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_2d_none(self):
    # `C` is resolved from this method's scope when NLLModel() is constructed.
    class NLLModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="none")
            self.conv = torch.nn.Conv2d(16, C, (3, 3))
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            output = self.loss(self.m(self.conv(input)), target)
            return output

    N, C = 5, 4
    input = torch.randn(N, 16, 10, 10)
    # A 3x3 conv without padding maps 10x10 to 8x8, matching the target.
    target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)

    # using test data containing default ignore_index=-100
    target[target == 1] = -100
    self.run_test(NLLModel(), (input, target))


@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_2d_mean(self):
    class NLLModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="mean")
            self.conv = torch.nn.Conv2d(16, C, (3, 3))
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            output = self.loss(self.m(self.conv(input)), target)
            return output

    N, C = 5, 4
    input = torch.randn(N, 16, 10, 10)
    target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)

    # using test data containing default ignore_index=-100
    target[target == 1] = -100
    self.run_test(NLLModel(), (input, target))


@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_2d_sum(self):
    class NLLModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="sum")
            self.conv = torch.nn.Conv2d(16, C, (3, 3))
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            output = self.loss(self.m(self.conv(input)), target)
            return output

    N, C = 5, 4
    input = torch.randn(N, 16, 10, 10)
    target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)

    # using test data containing default ignore_index=-100
    target[target == 1] = -100
    self.run_test(NLLModel(), (input, target))


@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_2d_mean_weights(self):
    class NLLModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="mean", weight=torch.randn(C))
            self.conv = torch.nn.Conv2d(16, C, (3, 3))
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            output = self.loss(self.m(self.conv(input)), target)
            return output

    N, C = 5, 4
    input = torch.randn(N, 16, 10, 10)
    target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)

    # using test data containing default ignore_index=-100
    target[target == 1] = -100
    self.run_test(NLLModel(), (input, target))


@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_2d_mean_ignore_index(self):
    class NLLModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="mean", ignore_index=1)
            self.conv = torch.nn.Conv2d(16, C, (3, 3))
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            output = self.loss(self.m(self.conv(input)), target)
            return output

    N, C = 5, 4
    input = torch.randn(N, 16, 10, 10)
    target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
    self.run_test(NLLModel(), (input, target))


@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_dynamic_ignore_index(self):
    # ignore_index is computed from an input's shape inside forward(), so it
    # is not a Python constant at export time.
    import torch.nn.functional as F

    # NOTE(review): despite its name, this model applies no label smoothing;
    # `epsilon` is stored but never read.
    class LabelSmoothingCrossEntropy(torch.nn.Module):
        def __init__(self, epsilon: float = 0.1, reduction="mean"):
            super().__init__()
            self.epsilon = epsilon
            self.reduction = reduction

        def forward(self, preds, target, start_position):
            # `n` is unused; left untouched so the exported model is unchanged.
            n = preds.size()[-1]
            log_preds = F.log_softmax(preds, dim=-1)
            ignore_index = start_position.size(1)
            nll = F.nll_loss(
                log_preds,
                target,
                reduction=self.reduction,
                ignore_index=ignore_index,
            )
            return nll + start_position.float()

    N = 5
    preds = torch.randn(N, 16)
    target = torch.randint(5, (N,))
    start_position = torch.randint(10, (N, N))
    self.run_test(LabelSmoothingCrossEntropy(), (preds, target, start_position))


@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_2d_mean_ignore_index_weights(self):
    class NLLModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.loss = torch.nn.NLLLoss(
                reduction="mean", weight=torch.randn(C), ignore_index=1
            )
            self.conv = torch.nn.Conv2d(16, C, (3, 3))
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            output = self.loss(self.m(self.conv(input)), target)
            return output

    N, C = 5, 4
    input = torch.randn(N, 16, 10, 10)
    target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
    self.run_test(NLLModel(), (input, target))


@skipIfUnsupportedMinOpsetVersion(12)
def test_binary_cross_entropy_with_logits(self):
    # Targets are 0/1 floats drawn with random_(2).
    x = torch.randn(5)
    y = torch.empty(5).random_(2)
    self._bce_logits(x, y)

    x = torch.randn(3, 4)
    y = torch.empty(3, 4).random_(2)
    weight = torch.tensor([3])
    self._bce_logits_wegiht(x, y, weight)

    x = torch.randn(3, 2, 4)
    y = torch.empty(3, 2, 4).random_(2)
    pos_weight = torch.empty([2, 4]).random_(2)
    self._bce_logits_posweight(x, y, pos_weight)

    x = torch.randn(3, 3, 4)
    y = torch.empty(3, 3, 4).random_(2)
    weight = torch.tensor([3])
    pos_weight = torch.empty([3, 4]).random_(2)
    self._bce_logits_loss_weight_posweight(x, y, weight, pos_weight)


def _bce_logits(self, x, y):
    """Export binary_cross_entropy_with_logits with each reduction."""

    class BCEWithLogitsLossNone(torch.nn.Module):
        def forward(self, input, target):
            return torch.nn.functional.binary_cross_entropy_with_logits(
                input, target, reduction="none"
            )

    self.run_test(BCEWithLogitsLossNone(), input_args=(x, y))

    class BCEWithLogitsLossMean(torch.nn.Module):
        def forward(self, input, target):
            return torch.nn.functional.binary_cross_entropy_with_logits(
                input, target, reduction="mean"
            )

    self.run_test(BCEWithLogitsLossMean(), input_args=(x, y))

    class BCEWithLogitsLossSum(torch.nn.Module):
        def forward(self, input, target):
            return torch.nn.functional.binary_cross_entropy_with_logits(
                input, target, reduction="sum"
            )

    self.run_test(BCEWithLogitsLossSum(), input_args=(x, y))


def _bce_logits_wegiht(self, x, y, weight):
    """Same as ``_bce_logits`` with ``weight`` passed as a graph input.

    (The misspelled name is kept for compatibility with existing callers.)
    """

    class BCEWithLogitsLossWegihtNone(torch.nn.Module):
        def forward(self, input, target, weight):
            return torch.nn.functional.binary_cross_entropy_with_logits(
                input, target, weight=weight, reduction="none"
            )

    self.run_test(BCEWithLogitsLossWegihtNone(), input_args=(x, y, weight))

    class BCEWithLogitsLossWegihtMean(torch.nn.Module):
        def forward(self, input, target, weight):
            return torch.nn.functional.binary_cross_entropy_with_logits(
                input, target, weight=weight, reduction="mean"
            )

    self.run_test(BCEWithLogitsLossWegihtMean(), input_args=(x, y, weight))

    class BCEWithLogitsLossWegihtSum(torch.nn.Module):
        def forward(self, input, target, weight):
            return torch.nn.functional.binary_cross_entropy_with_logits(
                input, target, weight=weight, reduction="sum"
            )

    self.run_test(BCEWithLogitsLossWegihtSum(), input_args=(x, y, weight))


def _bce_logits_posweight(self, x, y, pos_weight):
    """Same as ``_bce_logits`` with ``pos_weight`` passed as a graph input."""

    class BCEWithLogitsLossPosWegihtNone(torch.nn.Module):
        def forward(self, input, target, pos_weight):
            return torch.nn.functional.binary_cross_entropy_with_logits(
                input, target, pos_weight=pos_weight, reduction="none"
            )

    self.run_test(BCEWithLogitsLossPosWegihtNone(), input_args=(x, y, pos_weight))

    class BCEWithLogitsLossPosWegihtMean(torch.nn.Module):
        def forward(self, input, target, pos_weight):
            return torch.nn.functional.binary_cross_entropy_with_logits(
                input, target, pos_weight=pos_weight, reduction="mean"
            )

    self.run_test(BCEWithLogitsLossPosWegihtMean(), input_args=(x, y, pos_weight))

    class BCEWithLogitsLossPosWegihtSum(torch.nn.Module):
        def forward(self, input, target, pos_weight):
            return torch.nn.functional.binary_cross_entropy_with_logits(
                input, target, pos_weight=pos_weight, reduction="sum"
            )

    self.run_test(BCEWithLogitsLossPosWegihtSum(), input_args=(x, y, pos_weight))


def _bce_logits_loss_weight_posweight(self, x, y, weight, pos_weight):
    """Same as ``_bce_logits`` with both ``weight`` and ``pos_weight`` inputs."""

    class BCEWithLogitsLossWeightPosweightNone(torch.nn.Module):
        def forward(self, input, target, weight, pos_weight):
            return torch.nn.functional.binary_cross_entropy_with_logits(
                input,
                target,
                weight=weight,
                pos_weight=pos_weight,
                reduction="none",
            )

    self.run_test(
        BCEWithLogitsLossWeightPosweightNone(),
        input_args=(x, y, weight, pos_weight),
    )

    class BCEWithLogitsLossWeightPosweightMean(torch.nn.Module):
        def forward(self, input, target, weight, pos_weight):
            return torch.nn.functional.binary_cross_entropy_with_logits(
                input,
                target,
                weight=weight,
                pos_weight=pos_weight,
                reduction="mean",
            )

    self.run_test(
        BCEWithLogitsLossWeightPosweightMean(),
        input_args=(x, y, weight, pos_weight),
    )

    class BCEWithLogitsLossWeightPosweightSum(torch.nn.Module):
        def forward(self, input, target, weight, pos_weight):
            return torch.nn.functional.binary_cross_entropy_with_logits(
                input, target, weight=weight, pos_weight=pos_weight, reduction="sum"
            )

    self.run_test(
        BCEWithLogitsLossWeightPosweightSum(), input_args=(x, y, weight, pos_weight)
    )


def test_torch_mm(self):
    class M(torch.nn.Module):
        def forward(self, mat1, mat2):
            mm = torch.mm(mat1, mat2)
            return mm

    mat1 = torch.randn(2, 3)
    mat2 = torch.randn(3, 3)
    self.run_test(M(), input_args=(mat1, mat2))


# Because where op is not supported for opset < 9.
@skipIfUnsupportedMinOpsetVersion(9)
def test_where_with_bool_tensor(self):
    class M(torch.nn.Module):
        def forward(self, mat1, mat2):
            out = torch.where(mat1 > 0, mat1, mat2)
            return out

    mat1 = torch.randn(2, 3)
    mat2 = torch.ones(2, 3)
    self.run_test(M(), input_args=(mat1, mat2))


# Because where op is not supported for opset < 9.
@skipIfUnsupportedMinOpsetVersion(9)
def test_where_with_byte_tensor(self):
    class M(torch.nn.Module):
        def forward(self, cond, mat1, mat2):
            out = torch.where(cond, mat1, mat2)
            return out

    # uint8 condition with a single zero so both branches are selected.
    cond = torch.ones(2, 3, dtype=torch.uint8)
    cond[1, 2] = 0
    mat1 = torch.randn(2, 3)
    mat2 = torch.ones(2, 3)
    self.run_test(M(), input_args=(cond, mat1, mat2))


@skipIfUnsupportedMinOpsetVersion(10)  # ONNX IsInf op is added in opset 10.
def test_isinf(self):
    class M(torch.nn.Module):
        def forward(self, x):
            return x.isinf()

    x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), float("inf")]])
    self.run_test(M(), (x,))


@skipIfUnsupportedMinOpsetVersion(10)
def test_isfinite(self):
    class M(torch.nn.Module):
        def forward(self, x):
            return x.isfinite()

    x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), -float("inf")]])
    self.run_test(M(), (x,))


@skipIfUnsupportedMinOpsetVersion(9)  # ONNX IsNaN op is added in opset 9.
def test_isnan(self):
    class M(torch.nn.Module):
        def forward(self, x):
            return x.isnan()

    x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), float("inf")]])
    self.run_test(M(), (x,))


# ONNX IsNaN, IsInf op is added in opset 9, 10 respectively.
@skipIfUnsupportedMinOpsetVersion(10)
def test_nan_to_num(self):
    # Default replacement values on float/int/half inputs, then explicit ones.
    class NoParams(torch.nn.Module):
        def forward(self, x):
            return x.nan_to_num()

    x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), -float("inf")]])
    xint = torch.ones((2, 4), dtype=torch.int)
    xhalf = torch.ones((2, 4), dtype=torch.half)
    self.run_test(NoParams(), (x,))
    self.run_test(NoParams(), (xint,))
    self.run_test(NoParams(), (xhalf,))

    class WithParams(torch.nn.Module):
        def forward(self, x):
            return x.nan_to_num(nan=2.3, posinf=4.5, neginf=6.7)

    x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), -float("inf")]])
    self.run_test(WithParams(), (x,))


@skipIfUnsupportedMinOpsetVersion(9)
def test_maximum_minimum(self):
    # x contains a NaN; y broadcasts from shape (1, 3).
    class ModelWithNan(torch.nn.Module):
        def forward(self, x, y):
            return torch.maximum(x, y), torch.minimum(x, y)

    x = torch.tensor([-2, -2, float("nan")])
    y = torch.rand(1, 3)
    self.run_test(ModelWithNan(), (x, y))


@skipIfUnsupportedMinOpsetVersion(12)
def test_minimum_dtypes(self):
    # Mixed-dtype operand pairs that rely on type promotion.
    class MinimumModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.minimum(x, y)

    x = torch.randn((5, 5), dtype=torch.float16)
    y = torch.randn((5, 5), dtype=torch.float)
    self.run_test(MinimumModel(), (x, y))

    x = torch.randn((5, 5), dtype=torch.float16)
    y = torch.randint(10, (5, 5), dtype=torch.int16)
    self.run_test(MinimumModel(), (x, y))

    x = torch.randint(10, (5, 5), dtype=torch.int16)
    y = torch.randint(10, (5, 5), dtype=torch.int32)
    self.run_test(MinimumModel(), (x, y))

    x = torch.randint(10, (5, 5), dtype=torch.int)
    y = torch.full_like(x, True)
    self.run_test(MinimumModel(), (x, y))


@skipIfUnsupportedMinOpsetVersion(12)
def test_maximum_dtypes(self):
    # Mixed-dtype operand pairs that rely on type promotion.
    class MaximumModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.maximum(x, y)

    x = torch.randn((5, 5), dtype=torch.float16)
    y = torch.randn((5, 5), dtype=torch.float)
    self.run_test(MaximumModel(), (x, y))

    x = torch.randn((5, 5), dtype=torch.float16)
    y = torch.randint(10, (5, 5), dtype=torch.int16)
    self.run_test(MaximumModel(), (x, y))

    x = torch.randint(10, (5, 5), dtype=torch.int16)
    y = torch.randint(10, (5, 5), dtype=torch.int32)
    self.run_test(MaximumModel(), (x, y))

    x = torch.randint(10, (5, 5), dtype=torch.int)
    y = torch.full_like(x, True)
    self.run_test(MaximumModel(), (x, y))


@skipIfUnsupportedMinOpsetVersion(9)
def test_any(self):
    # Full reduction, reduction over dim=1, and the same with keepdim.
    class M(torch.nn.Module):
        def forward(self, x):
            return x.any()

    x = torch.tensor([[True, False], [False, False]])
    self.run_test(M(), (x,))

    class MDim(torch.nn.Module):
        def forward(self, x):
            return x.any(dim=1)

    x = torch.rand(3, 4).bool()
    self.run_test(MDim(), (x,))

    class MKeepdim(torch.nn.Module):
        def forward(self, x):
            return x.any(dim=1, keepdim=True)

    x = torch.rand(3, 4).bool()
    self.run_test(MKeepdim(), (x,))


@skipIfUnsupportedMinOpsetVersion(9)
def test_all(self):
    # Full reduction, reduction over dim=1, and the same with keepdim.
    class M(torch.nn.Module):
        def forward(self, x):
            return x.all()

    x = torch.tensor([[True, False], [False, False]])
    self.run_test(M(), (x,))

    class MDim(torch.nn.Module):
        def forward(self, x):
            return x.all(dim=1)

    x = torch.rand(3, 4).bool()
    self.run_test(MDim(), (x,))

    class MKeepdim(torch.nn.Module):
        def forward(self, x):
            return x.all(dim=1, keepdim=True)

    x = torch.rand(3, 4).bool()
    self.run_test(MKeepdim(), (x,))


def test_dropout(self):
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.dropout = torch.nn.Dropout(0.3)

        def forward(self, x):
            dropout = self.dropout(x)
            return dropout

    x = torch.randn(10, 3, 53)
    # Pass a real one-element tuple; `(x)` was just `x` in parentheses.
    self.run_test(M(), (x,))


def test_rrelu_eval(self):
    x = torch.tensor([0.5, -0.5])
    self.run_test(torch.nn.RReLU(0.1, 0.3).eval(), x)


def test_shape_constant_fold(self):
    # A buffer's static shape is used as a Python scalar inside forward().
    class ShapeModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = torch.nn.Buffer(torch.ones(5))

        def forward(self, x):
            shape = self.weight.shape[0]
            return x + shape

    x = torch.randn(2, 5)
    self.run_test(ShapeModule(), (x,), rtol=1e-3, atol=1e-5)


@skipIfUnsupportedMinOpsetVersion(12)
def test_celu(self):
    class Celu(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.celu = torch.nn.CELU(alpha=1.0)

        def forward(self, input):
            return self.celu(input)

    input = torch.randn(2)
    self.run_test(Celu(), (input,))


@skipIfUnsupportedMinOpsetVersion(12)
def test_celu_default(self):
    class Celu(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.celu = torch.nn.CELU()

        def forward(self, input):
            return self.celu(input)

    input = torch.randn(2)
    self.run_test(Celu(), (input,))


@skipIfUnsupportedMinOpsetVersion(12)
def test_celu_alpha(self):
    class Celu(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.celu = torch.nn.CELU(alpha=2.0)

        def forward(self, input):
            return self.celu(input)

    input = torch.randn(2)
    self.run_test(Celu(), (input,))


@skipIfUnsupportedMinOpsetVersion(12)
def test_celu_cast(self):
    # Same CELU model on a float64 input.
    class Celu(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.celu = torch.nn.CELU()

        def forward(self, input):
            return self.celu(input)

    input = torch.randn(2, 5, 7, dtype=torch.float64)
    self.run_test(Celu(), (input,))


def test_lower_tuple(self):
    # Tuples rebound inside nested loops and branches; the unused loop
    # variables are part of the pattern under test, so keep them as-is.
    class TupleModule(torch.nn.Module):
        def forward(self, input1: Tensor, input2: Tensor, input3: Tensor) -> Tensor:
            a = (input1, input2)
            b = a
            c = (input1, input2, input3)
            for i in range(5):
                d = a[0]
                for j in range(2):
                    e, f = a
                    a = (d, f)
                    f = c[2]
                    if f.size(0) != input1.size(-1):
                        g = b[1]
                        b = (g, f)
                    else:
                        k = c[1:]
                        b = (f, k[0])
                m, n = b
                c = (input1, n, m)
            p, q, r = c
            return p + q + r

    input1 = torch.randn(2)
    input2 = torch.randn(2)
    input3 = torch.randn(2)
    self.run_test(TupleModule(), (input1, input2, input3))


def test_lower_tuple_2(self):
    class TupleModule(torch.nn.Module):
        def forward(self, input1: Tensor, input2: Tensor) -> Tuple[Tensor, Tensor]:
            a = (input1, input2)
            for x in range(5):
                c, d = a
                a = (c, d)
            return a

    input1 = torch.randn(2)
    input2 = torch.randn(2)
    self.run_test(TupleModule(), (input1, input2))


def test_lower_tuple_3(self):
    # Tuple-typed inputs and a nested-tuple output.
    class TupleModule(torch.nn.Module):
        def forward(
            self,
            input1: Tuple[Tensor, Tensor],
            input2: Tuple[Tensor, Tensor],
        ) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
            a = input1
            b = input2
            for x in range(5):
                c, d = a
                e, f = b
                if c.shape[0] == e.shape[0]:
                    e = e + c
                else:
                    f = f + d
                a = (e, f)
                b = (c, d)
            return a, b

    input1 = (torch.randn(2), torch.randn(2))
    input2 = (torch.randn(2), torch.randn(2))
    self.run_test(TupleModule(), (input1, input2))


@skipIfUnsupportedMinOpsetVersion(9)
def test_where(self):
    # Three-argument where with broadcasting across all operands.
    class Model(torch.nn.Module):
        def forward(self, cond, input, other):
            return torch.where(cond, input, other)

    # randint's upper bound is exclusive: (0, 2) yields a mix of True/False.
    # The previous (0, 1) always produced an all-False condition, so the
    # `input` branch was never exercised.
    x = torch.randint(0, 2, (2, 3, 4), dtype=torch.bool)
    y = torch.randn(2, 1, 4)
    z = torch.ones(2, 3, 1)
    self.run_test(Model(), (x, y, z))
@skipIfUnsupportedMinOpsetVersion(9) 9801 @skipScriptTest() # scripting tests run for opsets > 11. See: test_where_condition_script 9802 def test_where_condition(self): 9803 class Model1(torch.nn.Module): 9804 def forward(self, input): 9805 return torch.stack(torch.where(input > 0.5), dim=1) 9806 9807 x = torch.randint(0, 2, (2, 3, 4), dtype=bool) 9808 self.run_test(Model1(), (x)) 9809 9810 class Model2(torch.nn.Module): 9811 def forward(self, input, other): 9812 return torch.stack(torch.where(input > other), dim=1) 9813 9814 x = torch.randint(0, 1, (2, 3, 4), dtype=bool) 9815 y = torch.randint(1, 2, (2, 3, 4), dtype=bool) 9816 self.run_test(Model2(), (x, y)) 9817 9818 @skipIfUnsupportedOpsetVersion([13]) 9819 @skipIfUnsupportedMinOpsetVersion(11) 9820 def test_where_condition_script(self): 9821 class Model1(torch.nn.Module): 9822 def forward(self, input): 9823 return torch.stack(torch.where(input > 0.5), dim=1) 9824 9825 x = torch.randint(0, 2, (2, 3, 4), dtype=bool) 9826 self.run_test(Model1(), (x)) 9827 9828 class Model2(torch.nn.Module): 9829 def forward(self, input, other): 9830 return torch.stack(torch.where(input > other), dim=1) 9831 9832 x = torch.randint(0, 1, (2, 3, 4), dtype=bool) 9833 y = torch.randint(1, 2, (2, 3, 4), dtype=bool) 9834 self.run_test(Model2(), (x, y)) 9835 9836 def test_empty_branch(self): 9837 class EmptyBranchModel(torch.jit.ScriptModule): 9838 @torch.jit.script_method 9839 def forward(self, input): 9840 out = input + 1 9841 if out.dim() > 2: 9842 if out.dim() > 3: 9843 out += 3 9844 else: 9845 pass 9846 else: 9847 pass 9848 return out 9849 9850 x = torch.randn(1, 2, 3, requires_grad=True) 9851 self.run_test(EmptyBranchModel(), x) 9852 9853 @skipIfUnsupportedMinOpsetVersion(11) 9854 def test_derive_index_scripting(self): 9855 class MyModule(torch.nn.Module): 9856 def forward(self, x: Tensor): 9857 j = [] 9858 for idx in range(len(x) - 1, -len(x), -2): 9859 y = x[idx] 9860 j += [x * y] 9861 return j 9862 9863 x = torch.randn(5, 13) 
9864 self.run_test(MyModule(), x) 9865 9866 class MyModule(torch.nn.Module): 9867 def forward(self, x: Tensor): 9868 j = [] 9869 for idx in range(-len(x), len(x) - 1, 2): 9870 y = x[idx] 9871 j += [x * y] 9872 return j 9873 9874 x = torch.randn(5, 13) 9875 self.run_test(MyModule(), x) 9876 9877 class MyModule(torch.nn.Module): 9878 def forward(self, x: Tensor): 9879 j = [] 9880 for idx in range(len(x) - 1, -len(x), -3): 9881 y = x[idx] 9882 j += [x * y] 9883 return j 9884 9885 self.run_test(MyModule(), x) 9886 9887 class MyModule(torch.nn.Module): 9888 def forward(self, x: Tensor): 9889 j = [] 9890 for idx in range(-len(x), len(x) - 1, 3): 9891 y = x[idx] 9892 j += [x * y] 9893 return j 9894 9895 self.run_test(MyModule(), x) 9896 9897 @skipScriptTest() # Scripting fails for add lists for opsets < 11. Chek test_derive_index_scripting 9898 def test_derive_index(self): 9899 class MyModule(torch.nn.Module): 9900 def forward(self, x: Tensor): 9901 j = [] 9902 for idx in range(len(x) - 1, -len(x), -2): 9903 y = x[idx] 9904 j += [x * y] 9905 return j 9906 9907 x = torch.randn(5, 13) 9908 self.run_test(MyModule(), x) 9909 9910 class MyModule(torch.nn.Module): 9911 def forward(self, x: Tensor): 9912 j = [] 9913 for idx in range(-len(x), len(x) - 1, 2): 9914 y = x[idx] 9915 j += [x * y] 9916 return j 9917 9918 x = torch.randn(5, 13) 9919 self.run_test(MyModule(), x) 9920 9921 class MyModule(torch.nn.Module): 9922 def forward(self, x: Tensor): 9923 j = [] 9924 for idx in range(len(x) - 1, -len(x), -3): 9925 y = x[idx] 9926 j += [x * y] 9927 return j 9928 9929 self.run_test(MyModule(), x) 9930 9931 class MyModule(torch.nn.Module): 9932 def forward(self, x: Tensor): 9933 j = [] 9934 for idx in range(-len(x), len(x) - 1, 3): 9935 y = x[idx] 9936 j += [x * y] 9937 return j 9938 9939 self.run_test(MyModule(), x) 9940 9941 @skipIfUnsupportedMinOpsetVersion(11) 9942 def test_if_transpose(self): 9943 class IfModel(torch.nn.Module): 9944 def forward(self, x): 9945 x = x.transpose(0, 
1) 9946 if x.size(0) == 2: 9947 return x.transpose(0, 1) 9948 else: 9949 return x 9950 9951 x = torch.randn(2, 3) 9952 self.run_test( 9953 torch.jit.script(IfModel()), 9954 x, 9955 output_names=["output_1"], 9956 dynamic_axes={"output_1": [0, 1]}, 9957 ) 9958 9959 @skipIfUnsupportedMinOpsetVersion(13) 9960 def test_if_list(self): 9961 class IfModel(torch.nn.Module): 9962 def forward(self, x, y, cond): 9963 res = [] 9964 if cond: 9965 res = res + [x] 9966 else: 9967 res = res + [y] 9968 return res 9969 9970 x = torch.randn(2, 3) 9971 y = torch.randn(3, 3) 9972 cond = torch.tensor(1, dtype=torch.bool) 9973 self.run_test(torch.jit.script(IfModel()), (x, y, cond)) 9974 9975 @skipIfUnsupportedMinOpsetVersion(13) 9976 def test_if_view(self): 9977 class IfModel(torch.nn.Module): 9978 def forward(self, x, y, cond): 9979 bs, seq = y.shape[:2] 9980 if cond: 9981 res = x.view(bs, seq, -1) 9982 else: 9983 res = y 9984 return res.transpose(1, 2) 9985 9986 x = torch.randn(2, 16, 2, 2) 9987 y = torch.randn(2, 16, 8) 9988 cond = torch.tensor(1, dtype=torch.bool) 9989 self.run_test( 9990 torch.jit.script(IfModel()), 9991 (x, y, cond), 9992 output_names=["output_1"], 9993 dynamic_axes={"output_1": [1]}, 9994 ) 9995 9996 @skipScriptTest( 9997 skip_before_opset_version=11, reason="dynamic split support added in 11" 9998 ) 9999 def test_split_tensor_scalar(self): 10000 class SplitModel(torch.nn.Module): 10001 def forward(self, x): 10002 return torch.split(x, x.size(1)) 10003 10004 x = torch.randn(1, 2, 3, requires_grad=True) 10005 self.run_test(SplitModel(), x) 10006 10007 def test_split_tensor_multi(self): 10008 class SplitModel(torch.nn.Module): 10009 def forward(self, x): 10010 return torch.split(x, torch.ones(3)) 10011 10012 x = torch.randn(1, 2, 3, requires_grad=True) 10013 10014 def run_model(): 10015 SplitModel(x) 10016 10017 self.assertRaises(TypeError, run_model) 10018 10019 @skipIfUnsupportedMinOpsetVersion(9) 10020 def test_embedding(self): 10021 class 
EmbedModel(torch.nn.Module): 10022 def forward(self, input, emb): 10023 return torch.nn.functional.embedding(input, emb, padding_idx=1) 10024 10025 model = EmbedModel() 10026 x = torch.randint(4, (4,)) 10027 x[2] = x[0] = 1 10028 embedding_matrix = torch.rand(10, 3) 10029 self.run_test(model, (x, embedding_matrix)) 10030 10031 x = torch.randint(4, (4, 3, 2)) 10032 x[2] = 1 10033 x[0][1] = 1 10034 self.run_test(model, (x, embedding_matrix)) 10035 self.run_test( 10036 model, (x, embedding_matrix), training=torch.onnx.TrainingMode.TRAINING 10037 ) 10038 10039 class EmbedModelWithoutPaddingIdx(torch.nn.Module): 10040 def forward(self, input, emb): 10041 return torch.nn.functional.embedding(input, emb) 10042 10043 model = EmbedModelWithoutPaddingIdx() 10044 x = torch.randint(4, (4, 3, 2)) 10045 self.run_test(model, (x, embedding_matrix)) 10046 10047 @skipIfUnsupportedMinOpsetVersion(9) 10048 def test_embedding_module(self): 10049 class EmbedModel(torch.nn.Module): 10050 def __init__(self) -> None: 10051 super().__init__() 10052 self.emb = torch.nn.Embedding(4, 3, padding_idx=1) 10053 self.emb2 = torch.nn.Embedding(4, 3, padding_idx=1) 10054 with torch.no_grad(): 10055 self.emb2.weight[1] = torch.ones(3) 10056 10057 def forward(self, input): 10058 return self.emb(input), self.emb2(input) 10059 10060 model = EmbedModel() 10061 x = torch.randint(4, (4,)) 10062 x[2] = x[0] = 1 10063 self.run_test(model, (x,)) 10064 10065 x = torch.randint(4, (4, 3, 2)) 10066 x[2] = 1 10067 x[0][1] = 1 10068 self.run_test(model, (x,)) 10069 10070 class EmbedModelWithoutPaddingIdx(torch.nn.Module): 10071 def __init__(self) -> None: 10072 super().__init__() 10073 self.emb = torch.nn.Embedding(4, 3) 10074 10075 def forward(self, input): 10076 return self.emb(input) 10077 10078 model = EmbedModelWithoutPaddingIdx() 10079 x = torch.randint(4, (4, 3, 2)) 10080 self.run_test(model, (x,)) 10081 10082 @skipIfUnsupportedMinOpsetVersion(11) 10083 def test_embedding_renorm(self): 10084 n, d = 7, 5 10085 
embedding = torch.nn.Embedding(n, d, max_norm=0.2) 10086 idx = torch.tensor([2, 1]) 10087 self.run_test(embedding, idx) 10088 10089 embedding = torch.nn.Embedding(n, d, max_norm=0.5, norm_type=1.0) 10090 idx = torch.tensor([4, 3, 4, 2]) 10091 self.run_test(embedding, idx) 10092 10093 def _dispatch_rnn_test(self, name, *args, **kwargs): 10094 if name == "elman": 10095 self._elman_rnn_test(*args, **kwargs) 10096 if name == "lstm": 10097 self._lstm_test(*args, **kwargs) 10098 if name == "gru": 10099 self._gru_test(*args, **kwargs) 10100 10101 def _elman_rnn_test( 10102 self, 10103 layers, 10104 nonlinearity, 10105 bidirectional, 10106 initial_state, 10107 packed_sequence, 10108 dropout, 10109 **extra_kwargs, 10110 ): 10111 class ElmanWithStateModel(torch.nn.Module): 10112 def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first): 10113 super().__init__() 10114 10115 self.batch_first = batch_first 10116 self.inner_model = torch.nn.RNN( 10117 RNN_INPUT_SIZE, 10118 RNN_HIDDEN_SIZE, 10119 layers, 10120 nonlinearity=nonlinearity, 10121 bidirectional=bidirectional, 10122 dropout=dropout, 10123 batch_first=batch_first, 10124 ) 10125 10126 def forward(self, input: rnn_utils.PackedSequence, hx=None): 10127 return self.inner_model(input, hx) 10128 10129 class ElmanWithoutStateModel(torch.nn.Module): 10130 def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first): 10131 super().__init__() 10132 self.batch_first = batch_first 10133 self.inner_model = torch.nn.RNN( 10134 RNN_INPUT_SIZE, 10135 RNN_HIDDEN_SIZE, 10136 layers, 10137 nonlinearity=nonlinearity, 10138 bidirectional=bidirectional, 10139 dropout=dropout, 10140 batch_first=batch_first, 10141 ) 10142 10143 def forward(self, input: rnn_utils.PackedSequence): 10144 return self.inner_model(input) 10145 10146 batch_first = packed_sequence == 2 10147 10148 if initial_state: 10149 model = ElmanWithStateModel( 10150 layers=layers, 10151 bidirect=bidirectional, 10152 nonlinearity=nonlinearity, 10153 
dropout=dropout, 10154 batch_first=batch_first, 10155 ) 10156 if packed_sequence: 10157 model = ( 10158 rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState( 10159 model, batch_first 10160 ) 10161 ) 10162 else: 10163 model = ElmanWithoutStateModel( 10164 layers=layers, 10165 bidirect=bidirectional, 10166 nonlinearity=nonlinearity, 10167 dropout=dropout, 10168 batch_first=batch_first, 10169 ) 10170 if packed_sequence: 10171 model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState( 10172 model, batch_first 10173 ) 10174 10175 def make_input(batch_size): 10176 seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size) 10177 seq_lengths = sorted(map(int, seq_lengths), reverse=True) 10178 inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths] 10179 inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first) 10180 inputs = [inputs] 10181 input_names = ["input"] 10182 10183 directions = 2 if bidirectional else 1 10184 10185 if initial_state: 10186 h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE) 10187 inputs.append(h0) 10188 input_names.append("h0") 10189 if packed_sequence != 0: 10190 inputs.append(torch.IntTensor(seq_lengths)) 10191 input_names.append("seq_lengths") 10192 if len(inputs) == 1: 10193 input = inputs[0] 10194 else: 10195 input = tuple(inputs) 10196 return input, input_names 10197 10198 input, input_names = make_input(RNN_BATCH_SIZE) 10199 dynamic_axes = {"input": [0, 1], "seq_lengths": [0]} 10200 if initial_state: 10201 dynamic_axes.update({"h0": [1]}) 10202 export_options = {"input_names": input_names, "dynamic_axes": dynamic_axes} 10203 10204 # test that the model still runs with a different batch size 10205 other_input, _ = make_input(RNN_BATCH_SIZE + 1) 10206 self.run_test( 10207 model, input, additional_test_inputs=[other_input], **export_options 10208 ) 10209 10210 def _lstm_test( 10211 self, 10212 layers, 10213 bidirectional, 10214 initial_state, 10215 
packed_sequence, 10216 dropout, 10217 **extra_kwargs, 10218 ): 10219 batch_first = packed_sequence == 2 10220 10221 if packed_sequence: 10222 model = lstm_flattening_result.LstmFlatteningResultWithSeqLength( 10223 RNN_INPUT_SIZE, 10224 RNN_HIDDEN_SIZE, 10225 layers, 10226 bidirectional, 10227 dropout, 10228 batch_first, 10229 ) 10230 if initial_state: 10231 model = ( 10232 rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState( 10233 model, batch_first 10234 ) 10235 ) 10236 else: 10237 model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState( 10238 model, batch_first 10239 ) 10240 else: 10241 model = lstm_flattening_result.LstmFlatteningResultWithoutSeqLength( 10242 RNN_INPUT_SIZE, 10243 RNN_HIDDEN_SIZE, 10244 layers, 10245 bidirectional, 10246 dropout, 10247 batch_first, 10248 ) 10249 10250 def make_input(batch_size): 10251 seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size) 10252 seq_lengths = sorted(map(int, seq_lengths), reverse=True) 10253 inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths] 10254 inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first) 10255 inputs = [inputs] 10256 input_names = ["input"] 10257 directions = 2 if bidirectional else 1 10258 10259 if initial_state: 10260 h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE) 10261 c0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE) 10262 inputs.append((h0, c0)) 10263 input_names.append("h0") 10264 input_names.append("c0") 10265 if packed_sequence != 0: 10266 inputs.append(torch.IntTensor(seq_lengths)) 10267 input_names.append("seq_lengths") 10268 if len(inputs) == 1: 10269 input = inputs[0] 10270 else: 10271 input = tuple(inputs) 10272 return input, input_names 10273 10274 input, input_names = make_input(RNN_BATCH_SIZE) 10275 dynamic_axes = {"input": [0, 1], "seq_lengths": [0]} 10276 if initial_state: 10277 dynamic_axes.update({"h0": [1], "c0": [1]}) 10278 export_options = {"input_names": 
input_names, "dynamic_axes": dynamic_axes} 10279 10280 # test that the model still runs with a different batch size 10281 other_input, _ = make_input(RNN_BATCH_SIZE + 1) 10282 self.run_test( 10283 model, input, additional_test_inputs=[other_input], **export_options 10284 ) 10285 10286 def _gru_test( 10287 self, 10288 layers, 10289 bidirectional, 10290 initial_state, 10291 packed_sequence, 10292 dropout, 10293 **extra_kwargs, 10294 ): 10295 class GRUWithStateModel(torch.nn.Module): 10296 def __init__(self, layers, bidirect, dropout, batch_first): 10297 super().__init__() 10298 10299 self.batch_first = batch_first 10300 self.inner_model = torch.nn.GRU( 10301 RNN_INPUT_SIZE, 10302 RNN_HIDDEN_SIZE, 10303 num_layers=layers, 10304 bidirectional=bidirectional, 10305 dropout=dropout, 10306 batch_first=batch_first, 10307 ) 10308 10309 def forward(self, input: rnn_utils.PackedSequence, hx): 10310 return self.inner_model(input, hx) 10311 10312 class GRUWithoutStateModel(torch.nn.Module): 10313 def __init__(self, layers, bidirect, dropout, batch_first): 10314 super().__init__() 10315 self.batch_first = batch_first 10316 self.inner_model = torch.nn.GRU( 10317 RNN_INPUT_SIZE, 10318 RNN_HIDDEN_SIZE, 10319 num_layers=layers, 10320 bidirectional=bidirectional, 10321 dropout=dropout, 10322 batch_first=batch_first, 10323 ) 10324 10325 def forward(self, input: rnn_utils.PackedSequence): 10326 return self.inner_model(input) 10327 10328 class GRUNoSeqLengthWithoutStateModel(torch.nn.Module): 10329 def __init__(self, layers, bidirect, dropout, batch_first): 10330 super().__init__() 10331 self.batch_first = batch_first 10332 self.inner_model = torch.nn.GRU( 10333 RNN_INPUT_SIZE, 10334 RNN_HIDDEN_SIZE, 10335 num_layers=layers, 10336 bidirectional=bidirectional, 10337 dropout=dropout, 10338 batch_first=batch_first, 10339 ) 10340 10341 def forward(self, input): 10342 return self.inner_model(input) 10343 10344 class GRUNoSeqLengthWithStateModel(torch.nn.Module): 10345 def __init__(self, 
layers, bidirect, dropout, batch_first): 10346 super().__init__() 10347 self.batch_first = batch_first 10348 self.inner_model = torch.nn.GRU( 10349 RNN_INPUT_SIZE, 10350 RNN_HIDDEN_SIZE, 10351 num_layers=layers, 10352 bidirectional=bidirectional, 10353 dropout=dropout, 10354 batch_first=batch_first, 10355 ) 10356 10357 def forward(self, input, hx): 10358 return self.inner_model(input, hx) 10359 10360 batch_first = packed_sequence == 2 10361 10362 if packed_sequence: 10363 if initial_state: 10364 model = GRUWithStateModel( 10365 layers=layers, 10366 bidirect=bidirectional, 10367 dropout=dropout, 10368 batch_first=batch_first, 10369 ) 10370 model = ( 10371 rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState( 10372 model, batch_first 10373 ) 10374 ) 10375 else: 10376 model = GRUWithoutStateModel( 10377 layers=layers, 10378 bidirect=bidirectional, 10379 dropout=dropout, 10380 batch_first=batch_first, 10381 ) 10382 model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState( 10383 model, batch_first 10384 ) 10385 else: 10386 if initial_state: 10387 model = GRUNoSeqLengthWithStateModel( 10388 layers=layers, 10389 bidirect=bidirectional, 10390 dropout=dropout, 10391 batch_first=batch_first, 10392 ) 10393 else: 10394 model = GRUNoSeqLengthWithoutStateModel( 10395 layers=layers, 10396 bidirect=bidirectional, 10397 dropout=dropout, 10398 batch_first=batch_first, 10399 ) 10400 10401 def make_input(batch_size): 10402 seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size) 10403 seq_lengths = sorted(map(int, seq_lengths), reverse=True) 10404 inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths] 10405 inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first) 10406 inputs = [inputs] 10407 input_names = ["input"] 10408 10409 directions = 2 if bidirectional else 1 10410 10411 if initial_state: 10412 h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE) 10413 inputs.append(h0) 10414 
input_names.append("h0") 10415 if packed_sequence != 0: 10416 inputs.append(torch.IntTensor(seq_lengths)) 10417 input_names.append("seq_lengths") 10418 if len(inputs) == 1: 10419 input = inputs[0] 10420 else: 10421 input = tuple(inputs) 10422 return input, input_names 10423 10424 input, input_names = make_input(RNN_BATCH_SIZE) 10425 dynamic_axes = {"input": [0, 1], "seq_lengths": [0]} 10426 if initial_state: 10427 dynamic_axes.update({"h0": [1]}) 10428 export_options = {"input_names": input_names, "dynamic_axes": dynamic_axes} 10429 10430 # test that the model still runs with a different batch size 10431 other_input, _ = make_input(RNN_BATCH_SIZE + 1) 10432 self.run_test( 10433 model, input, additional_test_inputs=[other_input], **export_options 10434 ) 10435 10436 @skipIfUnsupportedMinOpsetVersion(10) 10437 def test_fake_quantize_per_tensor(self): 10438 class FakeQuantizePerTensorModel(torch.nn.Module): 10439 def forward(self, input): 10440 scale = 1.0 / 127 10441 zero_point = 0 10442 quant_min = -128 10443 quant_max = 127 10444 return torch.fake_quantize_per_tensor_affine( 10445 input, scale, zero_point, quant_min, quant_max 10446 ) 10447 10448 x = torch.randn(6, 4, 3, 3) 10449 self.run_test(FakeQuantizePerTensorModel(), (x)) 10450 10451 @skipIfUnsupportedMinOpsetVersion(13) 10452 def test_fake_quantize_per_tensor_dynamic_scale_zeropoint(self): 10453 class FakeQuantizePerTensorModel(torch.nn.Module): 10454 def forward(self, input, scale, zero_point): 10455 quant_min = -128 10456 quant_max = 127 10457 return torch.fake_quantize_per_tensor_affine( 10458 input, scale, zero_point, quant_min, quant_max 10459 ) 10460 10461 x = torch.randn(6, 4, 3, 3) 10462 scale = torch.tensor(1.0 / 127) 10463 zero_point = torch.tensor(0) 10464 self.run_test(FakeQuantizePerTensorModel(), (x, scale, zero_point)) 10465 10466 @skipIfUnsupportedMinOpsetVersion(13) 10467 def test_fake_quantize_per_channel(self): 10468 class FakeQuantizePerChannelModel(torch.nn.Module): 10469 def 
forward(self, input): 10470 amax = torch.ones(4) 10471 scale = amax / 127.0 10472 zero_point = torch.zeros_like(amax, dtype=torch.int) 10473 # Quantize twice to test differnet branches 10474 y = torch.fake_quantize_per_channel_affine( 10475 input, scale, zero_point, 1, 0, 255 10476 ) 10477 return torch.fake_quantize_per_channel_affine( 10478 y, scale, zero_point, 1, -128, 127 10479 ) 10480 10481 x = torch.randn(6, 4, 3, 3) 10482 self.run_test(FakeQuantizePerChannelModel(), (x)) 10483 10484 @skipIfUnsupportedMinOpsetVersion(13) 10485 # RuntimeError: Can't redefine method: 10486 # forward on class: __torch__.torch.nn.modules.linear.Linear 10487 @skipScriptTest() 10488 def test_fake_quantize_activation(self): 10489 from torch.ao import quantization 10490 10491 m = torch.nn.Linear(1, 1) 10492 m.qconfig = quantization.QConfig( 10493 activation=quantization.default_fake_quant, 10494 weight=quantization.default_per_channel_weight_fake_quant, 10495 ) 10496 quantization.prepare_qat(m.train(), inplace=True) 10497 m.apply(quantization.enable_observer) 10498 m.apply(quantization.enable_fake_quant) 10499 for module in m.modules(): 10500 if isinstance(module, quantization.FakeQuantize): 10501 module.calculate_qparams() 10502 10503 m.apply(quantization.disable_observer) 10504 m.eval() 10505 10506 # Fake quantize activation is a special case, as it restricts quantized range to be (0, 127), 10507 # while standard 8bit quantization range is (-128, 127) or (0, 255). 10508 # Set fixed weight, bias and inputs to test if ONNX handles the overflow correctly. 
10509 m.weight = torch.nn.Parameter(torch.tensor([[1.0], [1.0], [1.0]])) 10510 m.bias = torch.nn.Parameter(torch.tensor([0.0])) 10511 x = torch.tensor([[150.0], [127.0], [-5.0]]) 10512 self.run_test(m, x) 10513 10514 def test_batchnorm_training(self): 10515 class MyModule(torch.nn.Module): 10516 def __init__(self) -> None: 10517 super().__init__() 10518 self.bn1 = torch.nn.BatchNorm2d(3, affine=False) 10519 self.cv1 = torch.nn.Conv2d(3, 3, 10) 10520 self.bn2 = torch.nn.BatchNorm2d(3, affine=True) 10521 self.cv2 = torch.nn.Conv2d(3, 3, 10) 10522 self.bn3 = torch.nn.BatchNorm2d(3, affine=False) 10523 10524 def forward(self, x): 10525 x = self.bn1(x) 10526 x = self.cv1(x) 10527 x = self.bn2(x) 10528 x = self.cv2(x) 10529 x = self.bn3(x) 10530 return x 10531 10532 x = torch.randn(10, 3, 20, 20) * 2 10533 model_export = MyModule() 10534 self.run_test( 10535 model_export, 10536 (x,), 10537 training=torch.onnx.TrainingMode.TRAINING, 10538 rtol=1e-3, 10539 atol=1e-5, 10540 ) 10541 model_export.train() 10542 self.run_test( 10543 model_export, 10544 (x,), 10545 training=torch.onnx.TrainingMode.PRESERVE, 10546 rtol=1e-3, 10547 atol=1e-5, 10548 ) 10549 10550 def test_batchnorm_training_mode_fix_layer(self): 10551 class MyModule(torch.nn.Module): 10552 def __init__(self) -> None: 10553 super().__init__() 10554 self.bn1 = torch.nn.BatchNorm2d(3, affine=True) 10555 self.cv1 = torch.nn.Conv2d(3, 3, 10) 10556 self.bn2 = torch.nn.BatchNorm2d(3, affine=False) 10557 self.cv2 = torch.nn.Conv2d(3, 3, 10) 10558 self.bn3 = torch.nn.BatchNorm2d(3, affine=True) 10559 self.bn3.eval() 10560 10561 def forward(self, x): 10562 x = self.bn1(x) 10563 x = self.cv1(x) 10564 x = self.bn2(x) 10565 x = self.cv2(x) 10566 x = self.bn3(x) 10567 return x 10568 10569 x = torch.randn(10, 3, 128, 128) 10570 model_export = MyModule() 10571 self.run_test( 10572 model_export, 10573 (x,), 10574 training=torch.onnx.TrainingMode.TRAINING, 10575 rtol=1e-3, 10576 atol=1e-5, 10577 ) 10578 model_export.train() 10579 
self.run_test( 10580 model_export, 10581 (x,), 10582 training=torch.onnx.TrainingMode.PRESERVE, 10583 rtol=1e-3, 10584 atol=1e-5, 10585 ) 10586 10587 def test_batchnorm_eval_mode_train_layer(self): 10588 class MyModule(torch.nn.Module): 10589 def __init__(self) -> None: 10590 super().__init__() 10591 self.bn1 = torch.nn.BatchNorm2d(3, affine=True) 10592 self.cv1 = torch.nn.Conv2d(3, 3, 10) 10593 self.bn2 = torch.nn.BatchNorm2d(3, affine=False) 10594 self.cv2 = torch.nn.Conv2d(3, 3, 10) 10595 self.bn3 = torch.nn.BatchNorm2d(3, affine=True) 10596 self.bn3.train() 10597 10598 def forward(self, x): 10599 x = self.bn1(x) 10600 x = self.cv1(x) 10601 x = self.bn2(x) 10602 x = self.cv2(x) 10603 x = self.bn3(x) 10604 return x 10605 10606 x = torch.randn(10, 3, 128, 128) 10607 model_export = MyModule() 10608 self.run_test( 10609 model_export, 10610 (x,), 10611 training=torch.onnx.TrainingMode.EVAL, 10612 rtol=1e-3, 10613 atol=1e-5, 10614 ) 10615 model_export.eval() 10616 self.run_test( 10617 model_export, 10618 (x,), 10619 training=torch.onnx.TrainingMode.PRESERVE, 10620 rtol=1e-3, 10621 atol=1e-5, 10622 ) 10623 10624 def test_instancenorm_training(self): 10625 class MyModule(torch.nn.Module): 10626 def __init__(self) -> None: 10627 super().__init__() 10628 self.in1 = torch.nn.InstanceNorm2d(3, affine=True) 10629 self.cv1 = torch.nn.Conv2d(3, 3, 10) 10630 self.in2 = torch.nn.InstanceNorm2d(3, affine=False) 10631 self.cv2 = torch.nn.Conv2d(3, 3, 10) 10632 self.in3 = torch.nn.InstanceNorm2d(3, affine=True) 10633 10634 def forward(self, x): 10635 x = self.in1(x) 10636 x = self.cv1(x) 10637 x = self.in2(x) 10638 x = self.cv2(x) 10639 x = self.in3(x) 10640 return x 10641 10642 x = torch.randn(10, 3, 128, 128) 10643 model_export = MyModule() 10644 self.run_test( 10645 model_export, 10646 (x,), 10647 training=torch.onnx.TrainingMode.TRAINING, 10648 rtol=1e-3, 10649 atol=1e-5, 10650 ) 10651 model_export.train() 10652 self.run_test( 10653 model_export, 10654 (x,), 10655 
training=torch.onnx.TrainingMode.PRESERVE, 10656 rtol=1e-3, 10657 atol=1e-5, 10658 ) 10659 10660 def test_instancenorm_training_mode_fix_layer(self): 10661 class MyModule(torch.nn.Module): 10662 def __init__(self) -> None: 10663 super().__init__() 10664 self.in1 = torch.nn.InstanceNorm2d(3, affine=True) 10665 self.cv1 = torch.nn.Conv2d(3, 3, 10) 10666 self.in2 = torch.nn.InstanceNorm2d(3, affine=False) 10667 self.cv2 = torch.nn.Conv2d(3, 3, 10) 10668 self.in3 = torch.nn.InstanceNorm2d(3, affine=True) 10669 self.in3.eval() 10670 10671 def forward(self, x): 10672 x = self.in1(x) 10673 x = self.cv1(x) 10674 x = self.in2(x) 10675 x = self.cv2(x) 10676 x = self.in3(x) 10677 return x 10678 10679 x = torch.randn(10, 3, 128, 128) 10680 model_export = MyModule() 10681 self.run_test( 10682 model_export, 10683 (x,), 10684 training=torch.onnx.TrainingMode.TRAINING, 10685 rtol=1e-3, 10686 atol=1e-5, 10687 ) 10688 model_export.train() 10689 self.run_test( 10690 model_export, 10691 (x,), 10692 training=torch.onnx.TrainingMode.PRESERVE, 10693 rtol=1e-3, 10694 atol=1e-5, 10695 ) 10696 10697 def test_instancenorm_eval_mode_train_layer(self): 10698 class MyModule(torch.nn.Module): 10699 def __init__(self) -> None: 10700 super().__init__() 10701 self.in1 = torch.nn.InstanceNorm2d(8, affine=True) 10702 self.cv1 = torch.nn.Conv2d(8, 8, 10) 10703 self.in2 = torch.nn.InstanceNorm2d(8, affine=False) 10704 self.cv2 = torch.nn.Conv2d(8, 8, 10) 10705 self.in3 = torch.nn.InstanceNorm2d(8, affine=True) 10706 self.in3.train() 10707 10708 def forward(self, x): 10709 x = self.in1(x) 10710 x = self.cv1(x) 10711 x = self.in2(x) 10712 x = self.cv2(x) 10713 x = self.in3(x) 10714 return x 10715 10716 x = torch.randn(10, 8, 128, 128) 10717 model_export = MyModule() 10718 self.run_test( 10719 model_export, 10720 (x,), 10721 training=torch.onnx.TrainingMode.EVAL, 10722 rtol=1e-3, 10723 atol=1e-5, 10724 ) 10725 model_export.eval() 10726 self.run_test( 10727 model_export, 10728 (x,), 10729 
training=torch.onnx.TrainingMode.PRESERVE, 10730 rtol=1e-3, 10731 atol=1e-5, 10732 ) 10733 10734 @skipIfUnsupportedMinOpsetVersion(12) 10735 def test_dropout_training(self): 10736 class MyModule(torch.nn.Module): 10737 def __init__(self) -> None: 10738 super().__init__() 10739 self.dropout = torch.nn.Dropout(0.4) 10740 10741 def forward(self, x): 10742 dropout = self.dropout(x) 10743 return dropout 10744 10745 model = MyModule() 10746 x = torch.randn(10) 10747 model.train() 10748 10749 model_onnx = io.BytesIO() 10750 torch.onnx.export( 10751 model, 10752 x, 10753 model_onnx, 10754 opset_version=self.opset_version, 10755 do_constant_folding=False, 10756 training=torch.onnx.TrainingMode.TRAINING, 10757 ) 10758 ort_sess = verification._ort_session(model_onnx) 10759 ort_outs = verification._run_onnx(ort_sess, (x,)) 10760 assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0]))) 10761 10762 script_model = torch.jit.script(model) 10763 output = model(x) 10764 model_onnx = io.BytesIO() 10765 torch.onnx.export( 10766 model, 10767 x, 10768 model_onnx, 10769 opset_version=self.opset_version, 10770 do_constant_folding=False, 10771 training=torch.onnx.TrainingMode.TRAINING, 10772 ) 10773 ort_outs = verification._run_onnx(ort_sess, (x,)) 10774 assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0]))) 10775 10776 @skipIfUnsupportedMinOpsetVersion(12) 10777 def test_dropout_training_zero(self): 10778 class MyModule(torch.nn.Module): 10779 def __init__(self) -> None: 10780 super().__init__() 10781 self.dropout = torch.nn.Dropout(0.5) 10782 10783 def forward(self, x): 10784 dropout = self.dropout(x) 10785 return dropout 10786 10787 model = MyModule() 10788 10789 # ensure there are no zeros in the input 10790 x = torch.randn(10, 3, 128, 128) 10791 y = x.numpy() 10792 y_mask = np.where(y == 0, 1, y) 10793 input = torch.from_numpy(y_mask) 10794 nb_elements = torch.numel(input) 10795 10796 model.train() 10797 model_onnx = io.BytesIO() 10798 torch.onnx.export( 10799 model, 
10800 x, 10801 model_onnx, 10802 opset_version=self.opset_version, 10803 do_constant_folding=False, 10804 training=torch.onnx.TrainingMode.TRAINING, 10805 ) 10806 ort_sess = verification._ort_session(model_onnx) 10807 ort_outs = verification._run_onnx(ort_sess, (x,)) 10808 10809 y = model(input) 10810 output = y.cpu().numpy() 10811 ort_mask = np.where(ort_outs[0] != 0, 1, 0) 10812 pyt_mask = np.where(output != 0, 1, 0) 10813 10814 ratio_pytorch = np.sum(pyt_mask) / nb_elements 10815 ratio_ort = np.sum(ort_mask) / nb_elements 10816 10817 np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01) 10818 10819 script_model = torch.jit.script(model) 10820 y = model(input) 10821 output = y.cpu().numpy() 10822 model_onnx = io.BytesIO() 10823 torch.onnx.export( 10824 model, 10825 x, 10826 model_onnx, 10827 opset_version=self.opset_version, 10828 do_constant_folding=False, 10829 training=torch.onnx.TrainingMode.TRAINING, 10830 ) 10831 ort_sess = verification._ort_session(model_onnx) 10832 ort_outs = verification._run_onnx(ort_sess, (x,)) 10833 ort_mask = np.where(ort_outs[0] != 0, 1, 0) 10834 pyt_mask = np.where(output != 0, 1, 0) 10835 10836 ratio_pytorch = np.sum(pyt_mask) / nb_elements 10837 ratio_ort = np.sum(ort_mask) / nb_elements 10838 10839 np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01) 10840 10841 def test_conv_bn(self): 10842 class MyModule(torch.nn.Module): 10843 def __init__(self) -> None: 10844 super().__init__() 10845 self.conv = torch.nn.Conv2d( 10846 3, 16, kernel_size=1, stride=2, padding=3, bias=True 10847 ) 10848 self.bn = torch.nn.BatchNorm2d(16, affine=True) 10849 10850 def forward(self, x): 10851 x = self.conv(x) 10852 bn = self.bn(x) 10853 return bn 10854 10855 model_export = MyModule() 10856 x = torch.randn(10, 3, 128, 128) 10857 self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL) 10858 self.run_test( 10859 model_export, 10860 (x,), 10861 training=torch.onnx.TrainingMode.TRAINING, 10862 
rtol=1e-3, 10863 atol=1e-5, 10864 ) 10865 10866 def test_multiple_conv_bn(self): 10867 class MyModule(torch.nn.Module): 10868 def __init__(self) -> None: 10869 super().__init__() 10870 self.conv1 = torch.nn.Conv2d( 10871 3, 64, kernel_size=7, stride=2, padding=3, bias=False 10872 ) 10873 self.conv2 = torch.nn.Conv2d( 10874 64, 2, kernel_size=1, stride=1, padding=0, bias=False 10875 ) 10876 self.conv3 = torch.nn.Conv2d( 10877 2, 2, kernel_size=3, stride=1, padding=1, bias=False 10878 ) 10879 self.bn = torch.nn.BatchNorm2d(64) 10880 self.bn2 = torch.nn.BatchNorm2d(2) 10881 self.relu = torch.nn.ReLU(inplace=True) 10882 self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 10883 10884 def forward(self, x): 10885 x = self.conv1(x) 10886 x = self.bn(x) 10887 x = self.relu(x) 10888 x = self.maxpool(x) 10889 x = self.conv2(x) 10890 x = self.bn2(x) 10891 x = self.relu(x) 10892 x = self.conv3(x) 10893 x = self.bn2(x) 10894 x = self.relu(x) 10895 return x 10896 10897 model_export = MyModule() 10898 x = torch.randn(2, 3, 224, 224) 10899 self.run_test( 10900 model_export, 10901 (x,), 10902 training=torch.onnx.TrainingMode.TRAINING, 10903 rtol=1e-3, 10904 atol=1e-5, 10905 ) 10906 self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL) 10907 10908 @skipIfUnsupportedMinOpsetVersion(11) 10909 def test_nms(self): 10910 num_boxes = 100 10911 boxes = torch.rand(num_boxes, 4) 10912 boxes[:, 2:] += boxes[:, :2] 10913 scores = torch.randn(num_boxes) 10914 10915 class Module(torch.nn.Module): 10916 def forward(self, boxes, scores): 10917 return torchvision.ops.nms(boxes, scores, 0.5) 10918 10919 self.run_test(Module(), (boxes, scores)) 10920 10921 @skipIfUnsupportedMinOpsetVersion(11) 10922 def test_batched_nms(self): 10923 num_boxes = 100 10924 boxes = torch.rand(num_boxes, 4) 10925 boxes[:, 2:] += boxes[:, :2] 10926 scores = torch.randn(num_boxes) 10927 idxs = torch.randint(0, 5, size=(num_boxes,)) 10928 10929 class Module(torch.nn.Module): 10930 def 
forward(self, boxes, scores, idxs): 10931 return torchvision.ops.batched_nms(boxes, scores, idxs, 0.5) 10932 10933 self.run_test(Module(), (boxes, scores, idxs)) 10934 10935 @skipIfUnsupportedMinOpsetVersion(11) 10936 @skipScriptTest() 10937 def test_clip_boxes_to_image(self): 10938 boxes = torch.randn(5, 4) * 500 10939 boxes[:, 2:] += boxes[:, :2] 10940 size = torch.randn(200, 300) 10941 10942 size_2 = torch.randn(300, 400) 10943 10944 class Module(torch.nn.Module): 10945 def forward(self, boxes, size): 10946 shape = (size.shape[0], size.shape[1]) 10947 return torchvision.ops.boxes.clip_boxes_to_image(boxes, shape) 10948 10949 self.run_test( 10950 Module(), 10951 (boxes, size), 10952 input_names=["boxes", "size"], 10953 dynamic_axes={"size": [0, 1]}, 10954 additional_test_inputs=[(boxes, size), (boxes, size_2)], 10955 ) 10956 10957 @skipScriptTest( 10958 reason="Conditioning on input type via prim::isinstance unsupported in ONNX" 10959 ) 10960 @skipIfUnsupportedMinOpsetVersion(11) 10961 def test_roi_align(self): 10962 x = torch.rand(1, 1, 10, 10, dtype=torch.float32) 10963 single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32) 10964 model = torchvision.ops.RoIAlign((5, 5), 1.0, 2) 10965 self.run_test(model, (x, single_roi)) 10966 10967 @skipScriptTest( 10968 reason="Conditioning on input type via prim::isinstance unsupported in ONNX" 10969 ) 10970 @skipIfUnsupportedMinOpsetVersion(16) 10971 def test_roi_align_aligned(self): 10972 x = torch.rand(1, 1, 10, 10, dtype=torch.float32) 10973 single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32) 10974 model1 = torchvision.ops.RoIAlign((5, 5), 1.0, 2, aligned=True) 10975 self.run_test(model1, (x, single_roi)) 10976 10977 x = torch.rand(1, 1, 10, 10, dtype=torch.float32) 10978 single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) 10979 model2 = torchvision.ops.RoIAlign((5, 5), 0.5, 3, aligned=True) 10980 self.run_test(model2, (x, single_roi)) 10981 10982 x = torch.rand(1, 1, 
10, 10, dtype=torch.float32) 10983 single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) 10984 model3 = torchvision.ops.RoIAlign((5, 5), 1.8, 2, aligned=True) 10985 self.run_test(model3, (x, single_roi)) 10986 10987 x = torch.rand(1, 1, 10, 10, dtype=torch.float32) 10988 single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) 10989 model4 = torchvision.ops.RoIAlign((2, 2), 2.5, 0, aligned=True) 10990 self.run_test(model4, (x, single_roi)) 10991 10992 @skipScriptTest( 10993 reason="Conditioning on input type via prim::isinstance unsupported in ONNX" 10994 ) 10995 @skipIfUnsupportedMinOpsetVersion(11) 10996 def test_roi_pool(self): 10997 x = torch.rand(1, 1, 10, 10, dtype=torch.float32) 10998 rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32) 10999 pool_h = 5 11000 pool_w = 5 11001 model = torchvision.ops.RoIPool((pool_h, pool_w), 2.0) 11002 self.run_test(model, (x, rois)) 11003 11004 @skipIfUnsupportedMinOpsetVersion(11) 11005 def test_resize_images(self): 11006 class TransformModule(torch.nn.Module): 11007 def __init__(self) -> None: 11008 super().__init__() 11009 self.transform = _init_test_generalized_rcnn_transform() 11010 11011 def forward(self, images): 11012 return self.transform.resize(images, None)[0] 11013 11014 input = torch.rand(3, 10, 20) 11015 input_test = torch.rand(3, 100, 150) 11016 self.run_test( 11017 TransformModule(), 11018 (input,), 11019 input_names=["input1"], 11020 dynamic_axes={"input1": [0, 1, 2]}, 11021 additional_test_inputs=[(input,), (input_test,)], 11022 ) 11023 11024 @skipIfUnsupportedMinOpsetVersion(11) 11025 @skipScriptTest() 11026 def test_transform_images(self): 11027 class TransformModule(torch.nn.Module): 11028 def __init__(self) -> None: 11029 super().__init__() 11030 self.transform = _init_test_generalized_rcnn_transform() 11031 11032 def forward(self, images: List[Tensor]): 11033 return self.transform(images)[0].tensors 11034 11035 input = torch.rand(3, 100, 200), torch.rand(3, 
200, 200) 11036 input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200) 11037 self.run_test( 11038 TransformModule(), 11039 (input,), 11040 additional_test_inputs=[(input,), (input_test,)], 11041 ) 11042 11043 def get_features(self, images): 11044 s0, s1 = images.shape[-2:] 11045 features = [ 11046 ("0", torch.rand(2, 256, s0 // 4, s1 // 4)), 11047 ("1", torch.rand(2, 256, s0 // 8, s1 // 8)), 11048 ("2", torch.rand(2, 256, s0 // 16, s1 // 16)), 11049 ("3", torch.rand(2, 256, s0 // 32, s1 // 32)), 11050 ("4", torch.rand(2, 256, s0 // 64, s1 // 64)), 11051 ] 11052 features = OrderedDict(features) 11053 return features 11054 11055 @skipIfUnsupportedMinOpsetVersion(11) 11056 @skipScriptTest() 11057 def test_rpn(self): 11058 class RPNModule(torch.nn.Module): 11059 def __init__(self) -> None: 11060 super().__init__() 11061 self.rpn = _init_test_rpn() 11062 11063 def forward(self, images, features: Dict[str, Tensor]): 11064 images_m = torchvision.models.detection.image_list.ImageList( 11065 images, [(i.shape[-1], i.shape[-2]) for i in images] 11066 ) 11067 return self.rpn(images_m, features) 11068 11069 images = torch.rand(2, 3, 150, 150) 11070 features = self.get_features(images) 11071 images2 = torch.rand(2, 3, 80, 80) 11072 test_features = self.get_features(images2) 11073 11074 model = RPNModule() 11075 model.eval() 11076 model(images, features) 11077 self.run_test( 11078 model, 11079 (images, features), 11080 input_names=["input1", "input2", "input3", "input4", "input5", "input6"], 11081 dynamic_axes={ 11082 "input1": [0, 1, 2, 3], 11083 "input2": [0, 1, 2, 3], 11084 "input3": [0, 1, 2, 3], 11085 "input4": [0, 1, 2, 3], 11086 "input5": [0, 1, 2, 3], 11087 "input6": [0, 1, 2, 3], 11088 }, 11089 additional_test_inputs=[(images, features), (images2, test_features)], 11090 # dict_check=False, 11091 ) 11092 11093 @skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch 11094 @skipIfUnsupportedMinOpsetVersion(11) 11095 @skipScriptTest() 11096 
def test_multi_scale_roi_align(self): 11097 class TransformModule(torch.nn.Module): 11098 def __init__(self) -> None: 11099 super().__init__() 11100 self.model = torchvision.ops.MultiScaleRoIAlign( 11101 ["feat1", "feat2"], 3, 2 11102 ) 11103 self.image_sizes = [(512, 512)] 11104 11105 def forward(self, input: Dict[str, Tensor], boxes: List[Tensor]) -> Tensor: 11106 return self.model(input, boxes, self.image_sizes) 11107 11108 i = OrderedDict() 11109 i["feat1"] = torch.rand(1, 5, 64, 64) 11110 i["feat2"] = torch.rand(1, 5, 16, 16) 11111 boxes = torch.rand(6, 4) * 256 11112 boxes[:, 2:] += boxes[:, :2] 11113 11114 i1 = OrderedDict() 11115 i1["feat1"] = torch.rand(1, 5, 64, 64) 11116 i1["feat2"] = torch.rand(1, 5, 16, 16) 11117 boxes1 = torch.rand(6, 4) * 256 11118 boxes1[:, 2:] += boxes1[:, :2] 11119 11120 self.run_test( 11121 TransformModule(), 11122 ( 11123 i, 11124 [boxes], 11125 ), 11126 additional_test_inputs=[ 11127 ( 11128 i, 11129 [boxes], 11130 ), 11131 ( 11132 i1, 11133 [boxes1], 11134 ), 11135 ], 11136 ) 11137 11138 def test_set_(self): 11139 class M(torch.nn.Module): 11140 def forward(self, x, y): 11141 x.set_(y) 11142 return x 11143 11144 x = torch.ones(2, 3) 11145 y = torch.randn(4, 6) 11146 self.run_test(M(), (x, y), remained_onnx_input_idx=[1]) 11147 11148 y2 = torch.randn(5, 2) 11149 self.run_test( 11150 M(), 11151 (x, y), 11152 remained_onnx_input_idx=[1], 11153 input_names=["x", "y"], 11154 dynamic_axes={"x": [0, 1], "y": [0, 1]}, 11155 additional_test_inputs=[(y, y2)], 11156 ) 11157 11158 @skipIfUnsupportedMinOpsetVersion(9) 11159 def test_set_attr_modules(self): 11160 class InnerModule2(torch.nn.Module): 11161 def __init__(self, embedding_dim): 11162 super().__init__() 11163 self.weights = InnerModule2.get_embedding(embedding_dim) 11164 self._float_tensor = torch.nn.Buffer(torch.FloatTensor(1)) 11165 self.const = 2 11166 11167 @staticmethod 11168 def get_embedding(embedding_dim: int): 11169 emb = 4 / ((embedding_dim // 2) - 1) 11170 emb = 
torch.exp( 11171 torch.arange((embedding_dim // 2), dtype=torch.float) * -emb 11172 ) 11173 return emb 11174 11175 def forward(self, input, incremental_state: Optional[Tensor] = None): 11176 bsz, seq_len = input.shape[0], input.shape[1] 11177 self.const = 3 11178 if self.weights is None: 11179 self.weights = InnerModule.get_embedding(self.embedding_dim) 11180 self.weights = self.weights.to(self._float_tensor) 11181 self.weights = self.weights * self.const 11182 if incremental_state is not None: 11183 pos = seq_len 11184 return self.weights[1 + pos, :].expand(bsz, 1, -1) 11185 return self.weights.index_select( 11186 0, torch.ones((bsz * seq_len), dtype=torch.int64) 11187 ).view(bsz, seq_len, -1) 11188 11189 class InnerModule(torch.nn.Module): 11190 def __init__(self, embedding_dim): 11191 super().__init__() 11192 self.weights = InnerModule.get_embedding(embedding_dim) 11193 self.module = InnerModule2(embedding_dim=8) 11194 11195 @staticmethod 11196 def get_embedding(embedding_dim: int): 11197 emb = 4 / ((embedding_dim // 2) - 1) 11198 emb = torch.exp( 11199 torch.arange((embedding_dim // 2), dtype=torch.float) * -emb 11200 ) 11201 return emb 11202 11203 def forward(self, x): 11204 return self.module(x) + self.weights 11205 11206 class Module(torch.nn.Module): 11207 def __init__(self) -> None: 11208 super().__init__() 11209 self.module = InnerModule(embedding_dim=8) 11210 11211 def forward(self, x): 11212 return self.module(x) 11213 11214 x = torch.randn(3, 256) 11215 self.run_test(Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}) 11216 self.run_test(Module(), (x,), remained_onnx_input_idx=[]) 11217 11218 @skipIfUnsupportedMinOpsetVersion(9) 11219 def test_set_attr_modules_2(self): 11220 class InnerModule(torch.nn.Module): 11221 def __init__(self, embedding_dim): 11222 super().__init__() 11223 self.embedding_dim = embedding_dim 11224 self.const = 2.5 11225 self.weights = InnerModule.get_embedding(self.embedding_dim) 11226 self._float_tensor = 
torch.nn.Buffer(torch.FloatTensor(1)) 11227 11228 @staticmethod 11229 def get_embedding(embedding_dim: int): 11230 emb = 4 / ((embedding_dim // 2) - 1) 11231 emb = torch.exp( 11232 torch.arange((embedding_dim // 2), dtype=torch.float) * -emb 11233 ) 11234 return emb 11235 11236 def forward(self, input, incremental_state: Optional[Tensor] = None): 11237 bsz, seq_len = input.shape[0], input.shape[1] 11238 self.const = 1.5 11239 self.weights = InnerModule.get_embedding(self.embedding_dim) 11240 return ( 11241 self.weights.index_select( 11242 0, torch.ones((bsz * seq_len), dtype=torch.int64) 11243 ).view(bsz, seq_len, -1) 11244 ) * self.const 11245 11246 class Module(torch.nn.Module): 11247 def __init__(self) -> None: 11248 super().__init__() 11249 self.module = InnerModule(embedding_dim=8) 11250 11251 def forward(self, x): 11252 return self.module(x) 11253 11254 x = torch.randn(3, 256) 11255 self.run_test(Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}) 11256 self.run_test(Module(), (x,), remained_onnx_input_idx=[]) 11257 11258 def test_set_attr(self): 11259 class MyModule(torch.nn.Module): 11260 def __init__(self) -> None: 11261 super().__init__() 11262 self.conv = torch.nn.Conv1d(3, 10, 2) 11263 self.b = False 11264 11265 def forward(self, box_regression, weight): 11266 self.b = True 11267 self.conv.weight = weight 11268 w = torch.softmax(self.conv.weight, dim=0) 11269 self.conv.weight = w + w 11270 if self.b: 11271 return box_regression + self.conv.weight 11272 else: 11273 return box_regression - self.conv.weight 11274 11275 model = torch.jit.script(MyModule()) 11276 weight = torch.ones(3, 2) 11277 box_regression = torch.randn(3, 2) 11278 self.run_test(model, (box_regression, weight)) 11279 11280 @skipIfUnsupportedMinOpsetVersion(11) 11281 def test_set_attr_2(self): 11282 class MyModule(torch.nn.Module): 11283 def __init__(self) -> None: 11284 super().__init__() 11285 self.conv = torch.nn.Conv1d(10, 3, 3) 11286 self.conv.bias = 
torch.nn.Parameter(torch.zeros(3, 10, 3)) 11287 11288 def set_cell_anchors(self, anchors): 11289 if self.conv.bias is not None: 11290 b = self.conv.bias 11291 assert b is not None 11292 self.conv.bias = anchors + b 11293 elif self.conv.weight is not None: 11294 self.conv.weight = torch.randn(3, 10) 11295 self.conv.bias = self.conv.weight[:] 11296 11297 def forward(self, anchors) -> Optional[Tensor]: 11298 self.set_cell_anchors(anchors) 11299 return self.conv.bias 11300 11301 model = torch.jit.script(MyModule()) 11302 anchors = torch.ones(3, 10, 3) 11303 self.run_test(model, (anchors)) 11304 11305 @skipIfUnsupportedMinOpsetVersion(11) 11306 def test_set_attr_3(self): 11307 class MyModule(torch.nn.Module): 11308 def __init__(self) -> None: 11309 super().__init__() 11310 self.conv = torch.nn.Conv1d(10, 3, 3) 11311 self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10)) 11312 self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3)) 11313 11314 def set_cell_anchors(self, anchors, boxes): 11315 self.conv.weight = torch.ones(3, 10) 11316 if self.conv.bias is not None: 11317 self.conv.bias = torch.randn(3, 10, 3) 11318 self.conv.weight = anchors + self.conv.weight 11319 boxes[:] = torch.zeros(2, 3) 11320 11321 def forward(self, anchors) -> Tuple[Tensor, Tensor]: 11322 boxes = torch.ones(2, 2, 3) 11323 self.set_cell_anchors(anchors, boxes) 11324 if self.conv.bias is not None: 11325 return self.conv.weight, boxes 11326 return anchors, boxes 11327 11328 model = torch.jit.script(MyModule()) 11329 anchors = torch.rand(3, 10) 11330 self.run_test(model, (anchors)) 11331 11332 @skipIfUnsupportedMinOpsetVersion(11) 11333 def test_set_attr_4(self): 11334 class MyModule(torch.nn.Module): 11335 def __init__(self) -> None: 11336 super().__init__() 11337 self.conv = torch.nn.Conv1d(10, 3, 3) 11338 self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3)) 11339 11340 def set_cell_anchors(self, anchors): 11341 self.conv.weight = torch.zeros(10, 3) 11342 if self.conv.bias is not None: 
11343 w = self.conv.bias 11344 assert w is not None 11345 self.conv.bias = anchors + w 11346 else: 11347 self.conv.bias = torch.ones(3, 10, 3) 11348 11349 def forward(self, feature_maps, anchors) -> Tuple[Tensor, Tensor]: 11350 self.set_cell_anchors(anchors) 11351 result = [] 11352 if self.conv.bias is not None: 11353 a = self.conv.bias 11354 assert a is not None 11355 result += [a] 11356 result += [feature_maps] 11357 return result[0], result[1] 11358 11359 model = torch.jit.script(MyModule()) 11360 x = torch.rand(5, 11, 30) 11361 anchors = torch.ones(3, 10, 3) 11362 self.run_test(model, (x, anchors)) 11363 11364 @skipIfUnsupportedMinOpsetVersion(11) 11365 def test_set_attr_5(self): 11366 class MyModule(torch.nn.Module): 11367 def __init__(self) -> None: 11368 super().__init__() 11369 self.conv = torch.nn.Conv1d(10, 3, 3) 11370 self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3)) 11371 11372 def set_cell_anchors(self, anchors): 11373 self.conv.weight = torch.arange(10) 11374 for i in range(10): 11375 if i == 3: 11376 for j in range(10): 11377 w = self.conv.weight 11378 self.conv.weight = torch.arange(10) + w 11379 11380 self.conv.weight = self.conv.weight + torch.arange(10) 11381 # NOTE: `is not None` and `assert` is for passing torchscript. 
11382 if self.conv.bias is not None: 11383 a = self.conv.bias 11384 assert a is not None 11385 self.conv.bias = anchors + a 11386 11387 def forward(self, anchors): 11388 self.set_cell_anchors(anchors) 11389 return self.conv.weight, self.conv.bias 11390 11391 model = torch.jit.script(MyModule()) 11392 anchors = torch.ones(3, 10, 3) 11393 self.run_test(model, (anchors)) 11394 11395 @skipIfUnsupportedMinOpsetVersion(11) 11396 def test_set_attr_in_loop(self): 11397 class MyModule(torch.nn.Module): 11398 def __init__(self) -> None: 11399 super().__init__() 11400 self.conv = torch.nn.Conv1d(10, 3, 3) 11401 self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10)) 11402 self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3)) 11403 11404 def set_cell_anchors(self, anchors, boxes): 11405 self.conv.weight = torch.randn(3, 10) 11406 for i in range(self.conv.weight.size(0)): 11407 for j in range(10): 11408 self.conv.bias = torch.randn(3, 10, 3) 11409 self.conv.weight = anchors * i 11410 boxes[j] += torch.ones(3, 3) 11411 11412 def forward(self, anchors) -> Tuple[Tensor, Tensor]: 11413 boxes = torch.ones(10, 3, 3) 11414 self.set_cell_anchors(anchors, boxes) 11415 if self.conv.bias is not None: 11416 return self.conv.weight, boxes 11417 return anchors, boxes 11418 11419 model = torch.jit.script(MyModule()) 11420 anchors = torch.rand(10) 11421 self.run_test(model, anchors) 11422 11423 @skipIfUnsupportedMinOpsetVersion(13) 11424 def test_set_attr_in_loop_with_list(self): 11425 class MyModule(torch.nn.Module): 11426 def __init__(self) -> None: 11427 super().__init__() 11428 self.conv = torch.nn.Conv1d(10, 3, 3) 11429 self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10)) 11430 self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3)) 11431 self.boxes: List[Tensor] = [ 11432 torch.ones(1) 11433 ] # Workaround placeholder for TorchScript 11434 11435 def set_cell_anchors(self, anchors): 11436 self.conv.weight = torch.randn(3, 10) 11437 for i in range(self.conv.weight.size(0)): 
11438 for j in range(10): 11439 self.conv.bias = torch.randn(3, 10, 3) 11440 self.conv.weight = anchors * i 11441 self.boxes.append(torch.ones(3, 3)) 11442 11443 def forward(self, anchors) -> Tuple[Tensor, List[Tensor]]: 11444 self.boxes = [] 11445 self.set_cell_anchors(anchors) 11446 if self.conv.bias is not None: 11447 return self.conv.weight, self.boxes 11448 return anchors, self.boxes 11449 11450 model = torch.jit.script(MyModule()) 11451 anchors = torch.rand(10) 11452 self.run_test(model, anchors) 11453 11454 @skipIfUnsupportedMinOpsetVersion(11) 11455 def test_index_put_if(self): 11456 @torch.jit.script 11457 def check_init( 11458 input_data: Tensor, hidden_size: int, prev_state: Tensor 11459 ) -> Tuple[Tensor, Tensor]: 11460 batch_size = input_data.size(0) 11461 spatial_size_0 = input_data.size(2) 11462 spatial_size_1 = input_data.size(3) 11463 # generate empty prev_state, if None is provided 11464 state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1) 11465 state = torch.zeros(state_size, device=input_data.device) 11466 state_copy = torch.zeros(state_size, device=input_data.device) 11467 if prev_state.size(0) == 0: 11468 state[:] = ( 11469 torch.zeros(batch_size, hidden_size, spatial_size_0, spatial_size_1) 11470 + state[:] 11471 ) 11472 state_copy[:] = ( 11473 torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) 11474 * 2 11475 ) 11476 state_copy[:] = ( 11477 torch.zeros(batch_size, hidden_size, spatial_size_0, spatial_size_1) 11478 * 2 11479 ) 11480 else: 11481 state[:] = ( 11482 torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) 11483 * 4 11484 ) 11485 return state, state_copy 11486 11487 class Example(torch.nn.Module): 11488 def __init__(self, hidden_size): 11489 super().__init__() 11490 self.hidden_size = hidden_size 11491 11492 def forward(self, input_data, prev_state): 11493 prev_state = check_init(input_data, self.hidden_size, prev_state) 11494 return prev_state[0], prev_state[1] 11495 11496 model = 
Example(10) 11497 random_data = torch.rand((1, 5, 30, 30)) 11498 empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0) 11499 self.run_test( 11500 model, 11501 (random_data, empty_tensor), 11502 input_names=["random_data", "empty_tensor"], 11503 dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]}, 11504 ) 11505 self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[]) 11506 11507 @skipIfUnsupportedMinOpsetVersion(11) 11508 def test_index_put_if_2(self): 11509 @torch.jit.script 11510 def check_init( 11511 input_data: Tensor, hidden_size: int, prev_state: Tensor 11512 ) -> Tuple[Tensor, Tensor]: 11513 batch_size = input_data.size(0) 11514 spatial_size_0 = input_data.size(2) 11515 spatial_size_1 = input_data.size(3) 11516 # generate empty prev_state, if None is provided 11517 state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1) 11518 state = torch.zeros(state_size, device=input_data.device) 11519 state_copy = torch.zeros(state_size, device=input_data.device) 11520 if prev_state.size(0) == 0: 11521 for i in range(2): 11522 state[:] = ( 11523 torch.ones( 11524 batch_size, hidden_size, spatial_size_0, spatial_size_1 11525 ) 11526 * i 11527 ) 11528 state_copy[:] = ( 11529 torch.ones( 11530 batch_size, hidden_size, spatial_size_0, spatial_size_1 11531 ) 11532 * i 11533 ) 11534 elif prev_state.size(0) == 1: 11535 s = state[:] 11536 state[:] = prev_state + s 11537 elif prev_state.size(0) == 2: 11538 state[:] = ( 11539 torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) 11540 * 4 11541 ) 11542 return state, state_copy 11543 11544 class Example(torch.nn.Module): 11545 def __init__(self, hidden_size): 11546 super().__init__() 11547 self.hidden_size = hidden_size 11548 11549 def forward(self, input_data, prev_state): 11550 prev_state = check_init(input_data, self.hidden_size, prev_state) 11551 return prev_state[0], prev_state[1] 11552 11553 model = Example(10) 11554 random_data = 
torch.rand((1, 5, 30, 30)) 11555 empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0) 11556 random_state = torch.rand((1, 1, 10, 30, 30)) 11557 self.run_test( 11558 model, 11559 (random_data, empty_tensor), 11560 input_names=["data", "state"], 11561 dynamic_axes={"data": [0, 1, 2], "state": [0, 1, 2, 3, 4]}, 11562 additional_test_inputs=[(random_data, random_state)], 11563 ) 11564 self.run_test( 11565 model, 11566 (random_data, empty_tensor), 11567 input_names=["data", "state"], 11568 dynamic_axes={"state": [0, 1, 2, 3, 4]}, 11569 additional_test_inputs=[(random_data, random_state)], 11570 remained_onnx_input_idx=[1], 11571 ) 11572 self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[]) 11573 11574 @skipIfUnsupportedMinOpsetVersion(11) 11575 def test_index_put_if_3(self): 11576 @torch.jit.script 11577 def check_init( 11578 input_data: Tensor, hidden_size: int, prev_state: Tensor 11579 ) -> Tensor: 11580 batch_size = input_data.size(0) 11581 spatial_size_0 = input_data.size(2) 11582 spatial_size_1 = input_data.size(3) 11583 # generate empty prev_state, if None is provided 11584 state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1) 11585 state = torch.zeros(state_size, device=input_data.device) 11586 if prev_state.size(0) < 2: 11587 state = state * 3 11588 if prev_state.size(0) == 0: 11589 state[:] = ( 11590 torch.ones( 11591 batch_size, hidden_size, spatial_size_0, spatial_size_1 11592 ) 11593 * 3 11594 ) 11595 else: 11596 state = state + 2 11597 11598 return state 11599 11600 class Example(torch.nn.Module): 11601 def __init__(self, hidden_size): 11602 super().__init__() 11603 self.hidden_size = hidden_size 11604 11605 def forward(self, input_data, prev_state): 11606 prev_state = check_init(input_data, self.hidden_size, prev_state) 11607 return prev_state 11608 11609 model = Example(4) 11610 random_data = torch.rand((1, 5, 4, 4)) 11611 empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0) 
11612 self.run_test( 11613 model, 11614 (random_data, empty_tensor), 11615 input_names=["random_data", "empty_tensor"], 11616 dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]}, 11617 ) 11618 self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[]) 11619 11620 @skipIfUnsupportedMinOpsetVersion(11) 11621 def test_index_put_if_4(self): 11622 @torch.jit.script 11623 def check_init( 11624 input_data: Tensor, hidden_size: int, prev_state: Tensor 11625 ) -> Tensor: 11626 batch_size = input_data.size(0) 11627 spatial_size_0 = input_data.size(2) 11628 spatial_size_1 = input_data.size(3) 11629 # generate empty prev_state, if None is provided 11630 state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1) 11631 state = torch.zeros(state_size, device=input_data.device) 11632 if prev_state.size(0) == 0: 11633 state = state + 3 11634 state[:] = ( 11635 torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) 11636 * 3 11637 ) 11638 state = state + 3 11639 state[:] = ( 11640 torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) 11641 * 4 11642 ) 11643 else: 11644 state = state + 2 11645 return state 11646 11647 class Example(torch.nn.Module): 11648 def __init__(self, hidden_size): 11649 super().__init__() 11650 self.hidden_size = hidden_size 11651 11652 def forward(self, input_data, prev_state): 11653 prev_state = check_init(input_data, self.hidden_size, prev_state) 11654 return prev_state 11655 11656 model = Example(4) 11657 random_data = torch.rand((1, 5, 4, 4)) 11658 empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0) 11659 self.run_test( 11660 model, 11661 (random_data, empty_tensor), 11662 input_names=["random_data", "empty_tensor"], 11663 dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]}, 11664 ) 11665 self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[]) 11666 11667 @skipIfUnsupportedMinOpsetVersion(11) 11668 def 
test_index_put_if_5(self): 11669 @torch.jit.script 11670 def check_init( 11671 input_data: Tensor, hidden_size: int, prev_state: Tensor 11672 ) -> Tuple[Tensor, Tensor]: 11673 batch_size = input_data.size(0) 11674 spatial_size_0 = input_data.size(2) 11675 spatial_size_1 = input_data.size(3) 11676 # generate empty prev_state, if None is provided 11677 state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1) 11678 state = torch.zeros(state_size, device=input_data.device) 11679 state_ref = state 11680 if prev_state.size(0) == 0: 11681 state[:] = ( 11682 torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) 11683 * 3 11684 ) 11685 state = state + 3 11686 state[:] = ( 11687 torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) 11688 * 4 11689 ) 11690 else: 11691 state = state + 2 11692 return state, state_ref 11693 11694 class Example(torch.nn.Module): 11695 def __init__(self, hidden_size): 11696 super().__init__() 11697 self.hidden_size = hidden_size 11698 11699 def forward(self, input_data, prev_state): 11700 prev_state, state_ref = check_init( 11701 input_data, self.hidden_size, prev_state 11702 ) 11703 return prev_state, state_ref 11704 11705 model = Example(4) 11706 random_data = torch.rand((1, 5, 4, 4)) 11707 empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0) 11708 self.run_test( 11709 model, 11710 (random_data, empty_tensor), 11711 input_names=["random_data", "empty_tensor"], 11712 dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]}, 11713 ) 11714 self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[]) 11715 11716 @skipIfUnsupportedMinOpsetVersion(11) 11717 def test_list_append_in_block(self): 11718 class ListModel(torch.nn.Module): 11719 def forward(self, x, y): 11720 res = [] 11721 for i in range(x.size(0)): 11722 res.append(torch.matmul(x[i], y)) 11723 return res 11724 11725 model = torch.jit.script(ListModel()) 11726 x = torch.randn(16, 3, 4) 11727 y 
= torch.randn(4, 5) 11728 self.run_test(model, (x, y)) 11729 11730 @skipIfUnsupportedMinOpsetVersion(13) 11731 def test_list_append_in_nested_block(self): 11732 class ListModel(torch.nn.Module): 11733 def forward(self, x, y): 11734 res = [] 11735 for i in range(x.size(0)): 11736 for j in range(x.size(1)): 11737 res.append(torch.matmul(x[i][j], y)) 11738 return res 11739 11740 model = torch.jit.script(ListModel()) 11741 x = torch.randn(4, 4, 3, 4) 11742 y = torch.randn(4, 5) 11743 self.run_test(model, (x, y)) 11744 11745 @skipIfUnsupportedMinOpsetVersion(13) 11746 def test_list_pop_in_block(self): 11747 class ListModel(torch.nn.Module): 11748 def forward(self, x, y): 11749 res = [] 11750 elem = torch.matmul(x[0], y) 11751 for i in range(x.size(0)): 11752 res.append(torch.matmul(x[i], y)) 11753 for i in range(x.size(0)): 11754 elem = res.pop() 11755 for i in range(x.size(0)): 11756 res.append(torch.matmul(x[i], y)) 11757 elem = res.pop() 11758 return res.append(elem) 11759 11760 model = torch.jit.script(ListModel()) 11761 x = torch.randn(16, 3, 4) 11762 y = torch.randn(4, 5) 11763 self.run_test(model, (x, y)) 11764 11765 @skipIfUnsupportedMinOpsetVersion(13) 11766 def test_list_del_in_block(self): 11767 class ListModel(torch.nn.Module): 11768 def forward(self, x, y): 11769 res = [] 11770 elem = torch.matmul(x[0], y) 11771 for i in range(x.size(0)): 11772 res.append(torch.matmul(x[i], y)) 11773 for i in range(x.size(0)): 11774 del res[0] 11775 for i in range(x.size(0)): 11776 res.append(torch.matmul(x[i], y)) 11777 del res[0] 11778 return res.append(elem) 11779 11780 model = torch.jit.script(ListModel()) 11781 x = torch.randn(16, 3, 4) 11782 y = torch.randn(4, 5) 11783 self.run_test(model, (x, y)) 11784 11785 @skipIfUnsupportedMinOpsetVersion(11) 11786 def test_list_unpack(self): 11787 class ListModel(torch.nn.Module): 11788 def forward(self, x, y): 11789 res = [] 11790 elem = torch.matmul(x[0], y) 11791 for i in range(x.size(0)): 11792 res.append(torch.matmul(x[i], 
y)) 11793 a, b, c = res 11794 return a, b 11795 11796 model = torch.jit.script(ListModel()) 11797 x = torch.randn(3, 3, 4) 11798 y = torch.randn(4, 5) 11799 self.run_test(model, (x, y)) 11800 11801 @skipIfUnsupportedMinOpsetVersion(11) 11802 def test_index_put_inplace_ops(self): 11803 @torch.jit.script 11804 def check_init(input_data: Tensor, hidden_size: int) -> Tensor: 11805 batch_size = input_data.size(0) 11806 spatial_size_0 = input_data.size(2) 11807 spatial_size_1 = input_data.size(3) 11808 # generate empty prev_state, if None is provided 11809 state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1) 11810 state = torch.zeros(state_size, device=input_data.device) 11811 if input_data.size(0) == 1: 11812 state[1] += ( 11813 torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) 11814 * 2 11815 ) 11816 state[1] /= ( 11817 torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) 11818 * 3 11819 ) 11820 for i in range(input_data.size(0)): 11821 state[1] += torch.ones( 11822 batch_size, hidden_size, spatial_size_0, spatial_size_1 11823 ) 11824 state[1] /= ( 11825 torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) 11826 * i 11827 ) 11828 return state 11829 11830 class Example(torch.nn.Module): 11831 def __init__(self, hidden_size): 11832 super().__init__() 11833 self.hidden_size = hidden_size 11834 11835 def forward(self, input_data): 11836 state = check_init(input_data, self.hidden_size) 11837 return state 11838 11839 model = Example(10) 11840 random_data = torch.rand((1, 5, 30, 30)) 11841 self.run_test( 11842 model, 11843 (random_data), 11844 input_names=["random_data"], 11845 dynamic_axes={"random_data": [0, 1, 2, 3]}, 11846 ) 11847 self.run_test(model, (random_data), remained_onnx_input_idx=[]) 11848 11849 @skipIfUnsupportedMinOpsetVersion(11) 11850 def test_input_mask_model(self): 11851 class InputMaskModel(torch.nn.Module): 11852 def __init__(self, output_size): 11853 super().__init__() 11854 self.bias 
= torch.nn.Parameter( 11855 torch.empty(output_size, dtype=torch.float) 11856 ) 11857 with torch.no_grad(): 11858 self.bias.zero_() 11859 11860 def forward(self, model_input, y): 11861 input_mask = (model_input <= 0) | (model_input > 25) 11862 y[input_mask, :] = 0.0 11863 output = y + self.bias 11864 return output 11865 11866 output_size = 4 11867 m = InputMaskModel(output_size) 11868 x = torch.tensor([0, 4, 24, 25], dtype=torch.int64) 11869 y = torch.tensor( 11870 [ 11871 [0.1, 0.2, 0.3, 0.4], 11872 [0.1, 0.2, 0.3, 0.4], 11873 [0.1, 0.2, 0.3, 0.4], 11874 [0.1, 0.2, 0.3, 0.4], 11875 ], 11876 dtype=torch.float, 11877 ) 11878 self.run_test(m, (x, y)) 11879 11880 class InputMaskModel(torch.nn.Module): 11881 def __init__(self, output_size): 11882 super().__init__() 11883 11884 def forward(self, model_input_1, model_input_2, y): 11885 input_mask_1 = (model_input_1 <= 0) | (model_input_1 > 25) 11886 input_mask_2 = (model_input_2 < 1) | (model_input_2 >= 12) 11887 y[input_mask_1, input_mask_2] = 0.0 11888 return y 11889 11890 output_size = 4 11891 m = InputMaskModel(output_size) 11892 x1 = torch.tensor([0, 4, 24, 25], dtype=torch.int64) 11893 x2 = torch.tensor([0, 3, 12, 15], dtype=torch.int64) 11894 y = torch.tensor( 11895 [ 11896 [0.1, 0.2, 0.3, 0.4], 11897 [0.1, 0.2, 0.3, 0.4], 11898 [0.1, 0.2, 0.3, 0.4], 11899 [0.1, 0.2, 0.3, 0.4], 11900 ], 11901 dtype=torch.float, 11902 ) 11903 self.run_test(m, (x1, x2, y)) 11904 11905 @skipScriptTest() 11906 def test_unsafe_chunk(self): 11907 class ChunkModel(torch.nn.Module): 11908 def forward(self, x): 11909 return torch.unsafe_chunk(x, 3, dim=1) 11910 11911 model = ChunkModel() 11912 model.eval() 11913 x = torch.randn(1, 18) 11914 self.run_test(model, x, input_names=["x"]) 11915 11916 def test_symbolic_shape_inference(self): 11917 # ConstantOfShape is tested in test_embedding_bag 11918 # Tile is tested in test_repeat 11919 # test Shape, Reshape, Transpose, Gather 11920 class ShapeModel(torch.nn.Module): 11921 def forward(self, x, 
y): 11922 shape = x.size()[:3] + (-1,) # shape [4], ("batch", 3, 4, -1) 11923 y = y.reshape(shape) # batch, 3, 4, 10/batch 11924 return y.transpose(1, 2) 11925 11926 model = ShapeModel() 11927 model.eval() 11928 x = torch.ones(2, 3, 4, 5) 11929 y = torch.ones(3, 4, 5, 2) 11930 self.run_test( 11931 model, 11932 (x, y), 11933 input_names=["x", "y"], 11934 dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]}, 11935 ) 11936 self.run_test(model, (x, y), remained_onnx_input_idx=[1]) 11937 11938 class ViewModel(torch.nn.Module): 11939 def forward(self, x): 11940 return x.view(-1) 11941 11942 model = ViewModel() 11943 model.eval() 11944 x = torch.tensor(2.0) 11945 self.run_test(model, (x,)) 11946 11947 # test prim::ListConstruct for Reshape input 1 11948 class ViewModel_2(torch.nn.Module): 11949 def forward(self, x): 11950 N, C, H, W = x.shape[0], x.shape[2], x.shape[3], x.shape[4] 11951 x1 = x.view(N, -1, C, H, W) 11952 x2 = x1.permute(0, 3, 4, 1, 2) 11953 return x2.reshape(N, -1, C) 11954 11955 model = ViewModel_2() 11956 model.eval() 11957 x = torch.ones(2, 3, 4, 5, 6) 11958 self.run_test(model, x) 11959 11960 @skipIfUnsupportedMinOpsetVersion(9) 11961 def test_symbolic_shape_inference_arange(self): 11962 # test Range 11963 class ArangeModel(torch.nn.Module): 11964 def forward(self, signal): 11965 frame_step = 2 11966 outer_dimensions = signal.size()[:-2] 11967 frames, frame_length = signal.size()[-2:] 11968 11969 subframe_length = signal.size()[0] 11970 subframe_step = frame_step // subframe_length 11971 subframes_per_frame = frame_length // subframe_length 11972 output_size = frame_step * (frames - 1) + frame_length 11973 output_subframes = output_size // subframe_length 11974 11975 frame = torch.arange(0, output_subframes) 11976 return frame 11977 11978 model = ArangeModel() 11979 model.eval() 11980 M, C, K, N = 1, 2, 3, 4 11981 x = torch.randint(5, (M, C, K, N)) 11982 y = torch.randint(5, (M, C + 1, K + 1, N + 1)) 11983 self.run_test(model, x, input_names=["x"], 
dynamic_axes={"x": [0, 1, 2, 3]}) 11984 self.run_test(model, x, remained_onnx_input_idx=[]) 11985 self.run_test( 11986 model, 11987 x, 11988 input_names=["x"], 11989 dynamic_axes={"x": [0, 1, 2, 3]}, 11990 additional_test_inputs=[(x,), (y,)], 11991 ) 11992 11993 @skipIfUnsupportedMinOpsetVersion(11) 11994 def test_symbolic_shape_inference_box(self): 11995 # test NonZero 11996 class BoxModel(torch.nn.Module): 11997 def forward(self, boxes): 11998 min_size = 1e-2 11999 ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] 12000 keep = (ws >= min_size) & (hs >= min_size) 12001 keep = torch.where(keep)[0] 12002 return keep 12003 12004 model = BoxModel() 12005 model.eval() 12006 x = torch.ones(2, 4) 12007 y = torch.ones(3, 5) 12008 self.run_test(model, x) 12009 self.run_test( 12010 model, 12011 x, 12012 input_names=["x"], 12013 dynamic_axes={"x": [0, 1]}, 12014 additional_test_inputs=[(x,), (y,)], 12015 ) 12016 12017 @skipIfUnsupportedMinOpsetVersion(11) 12018 def test_symbolic_shape_inference_box_if(self): 12019 # test If 12020 class BoxIfModel(torch.nn.Module): 12021 def forward(self, boxes, scores): 12022 score_thresh = 0.0 12023 inds = torch.where(scores > score_thresh)[0] 12024 boxes_1 = boxes[inds] 12025 if boxes_1.numel() > 3: 12026 return boxes_1 12027 else: 12028 return boxes_1 * 2 12029 12030 model = BoxIfModel() 12031 model.eval() 12032 boxes = torch.ones(2, 4) 12033 scores = torch.ones(1, 4) 12034 self.run_test(model, (boxes, scores)) 12035 12036 @skipIfUnsupportedMinOpsetVersion(11) 12037 @skipDtypeChecking 12038 def test_symbolic_shape_inference_arange_2(self): 12039 # test Range 12040 class ArangeModel(torch.nn.Module): 12041 def forward(self, start): 12042 return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.int64) 12043 12044 x = torch.randn(2, 3, 4) 12045 self.run_test( 12046 ArangeModel(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]} 12047 ) 12048 self.run_test(ArangeModel(), (x,), remained_onnx_input_idx=[]) 12049 12050 class 
ArangeModel2(torch.nn.Module): 12051 def forward(self, start): 12052 return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.double) 12053 12054 x = torch.randn(2, 3, 4) 12055 self.run_test( 12056 ArangeModel2(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]} 12057 ) 12058 self.run_test(ArangeModel2(), (x,), remained_onnx_input_idx=[]) 12059 12060 @skipIfUnsupportedMinOpsetVersion(9) 12061 def test_symbolic_shape_inference_nonzero(self): 12062 class OneLikeModel(torch.nn.Module): 12063 def forward(self, x): 12064 ones = torch.ones_like( 12065 x, 12066 dtype=torch.float, 12067 layout=torch.strided, 12068 device=torch.device("cpu"), 12069 ) 12070 return torch.nonzero(ones) 12071 12072 x = torch.randn(2) 12073 self.run_test(OneLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0]}) 12074 self.run_test(OneLikeModel(), x, remained_onnx_input_idx=[]) 12075 x = torch.randn(2, 3, 4) 12076 self.run_test( 12077 OneLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]} 12078 ) 12079 self.run_test(OneLikeModel(), x, remained_onnx_input_idx=[]) 12080 12081 class ZeroLikeModel(torch.nn.Module): 12082 def forward(self, x): 12083 zeros = torch.zeros_like( 12084 x, 12085 dtype=torch.float, 12086 layout=torch.strided, 12087 device=torch.device("cpu"), 12088 ) 12089 return torch.nonzero(zeros) 12090 12091 x = torch.randn(2) 12092 self.run_test(ZeroLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0]}) 12093 self.run_test(ZeroLikeModel(), x, remained_onnx_input_idx=[]) 12094 x = torch.randn(2, 3, 4) 12095 self.run_test( 12096 ZeroLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]} 12097 ) 12098 self.run_test(ZeroLikeModel(), x, remained_onnx_input_idx=[]) 12099 12100 @skipIfUnsupportedMinOpsetVersion(9) 12101 def test_symbolic_shape_inference_expand_1(self): 12102 class ExpandModel(torch.nn.Module): 12103 def forward(self, x): 12104 return x.expand(4, 6, 2) 12105 12106 x = torch.randn(6, 1, requires_grad=True) 12107 self.run_test(ExpandModel(), 
(x,)) 12108 12109 @skipIfUnsupportedMinOpsetVersion(9) 12110 def test_symbolic_shape_inference_expand_2(self): 12111 class M(torch.nn.Module): 12112 def forward(self, x): 12113 input_shape = x.size() 12114 batch_size, seq_length = input_shape 12115 seq_ids = torch.arange(seq_length) 12116 causal_mask = ( 12117 seq_ids[None, None, :].repeat(batch_size, seq_length, 1) 12118 <= seq_ids[None, :, None] 12119 ) 12120 return causal_mask.transpose(0, 1) 12121 12122 x = torch.randn(3, 16) 12123 self.run_test(M(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}) 12124 self.run_test(M(), (x,), remained_onnx_input_idx=[]) 12125 12126 @skipIfUnsupportedMinOpsetVersion(10) 12127 def test_symbolic_shape_inference_slice(self): 12128 class M(torch.nn.Module): 12129 def forward(self, x, position_bias): 12130 input_shape = x.size() 12131 batch_size, seq_length = input_shape 12132 position_bias = position_bias[:, :, -seq_length:, :] 12133 return position_bias.transpose(0, 1) 12134 12135 x = torch.randn(3, 16) 12136 position_bias = torch.randn(1, 3, 20, 8) 12137 self.run_test( 12138 M(), 12139 (x, position_bias), 12140 input_names=["x", "position_bias"], 12141 dynamic_axes={"x": [0, 1], "position_bias": [0, 1, 2, 3]}, 12142 ) 12143 self.run_test(M(), (x, position_bias), remained_onnx_input_idx=[1]) 12144 12145 def test_symbolic_shape_inference_slice_2(self): 12146 class M(torch.nn.Module): 12147 def forward(self, position_bias): 12148 position_bias = position_bias[:, :, -2:, :] 12149 return position_bias.transpose(0, 1) 12150 12151 position_bias = torch.randn(1, 3, 20, 8) 12152 self.run_test(M(), (position_bias,)) 12153 12154 @skipIfUnsupportedMinOpsetVersion(9) 12155 @skipScriptTest() 12156 def test_symbolic_shape_inference_time(self): 12157 input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) 12158 h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE) 12159 c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE) 12160 model_lstm = torch.nn.LSTM( 12161 RNN_INPUT_SIZE, 
RNN_HIDDEN_SIZE, 1, bidirectional=False 12162 ) 12163 self.run_test( 12164 model_lstm, 12165 (input, (h0, c0)), 12166 input_names=["x", "y"], 12167 dynamic_axes={"x": [0, 1]}, 12168 ) 12169 model_gru = torch.nn.GRU( 12170 RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False, bias=False 12171 ) 12172 self.run_test( 12173 model_gru, (input, h0), input_names=["x", "y"], dynamic_axes={"x": [0, 1]} 12174 ) 12175 model_rnn = torch.nn.RNN( 12176 RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False, bias=False 12177 ) 12178 self.run_test( 12179 model_rnn, (input, h0), input_names=["x", "y"], dynamic_axes={"x": [0, 1]} 12180 ) 12181 12182 def test_symbolic_shape_inference_dynamic_axes(self): 12183 class M(torch.nn.Module): 12184 def forward(self, input_ids): 12185 input_shape = input_ids.size() 12186 input_ids = input_ids.view(-1, input_shape[-1]) 12187 return input_ids.transpose(0, 1) 12188 12189 x = torch.randn(3, 16) 12190 self.run_test( 12191 M(), 12192 (x,), 12193 input_names=["input_ids"], 12194 dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}}, 12195 ) 12196 12197 @skipIfUnsupportedMinOpsetVersion(9) 12198 def test_hann_window_periodic(self): 12199 class HannWindowModule_Periodic(torch.nn.Module): 12200 def __init__(self) -> None: 12201 super().__init__() 12202 self.window_length = 0 12203 12204 def forward(self, x, window_length: int): 12205 self.window_length = window_length 12206 return torch.add( 12207 x, 12208 torch.hann_window( 12209 self.window_length, periodic=True, dtype=torch.float 12210 ), 12211 ) 12212 12213 win_length = 100 12214 x = torch.randn(win_length) 12215 12216 module = HannWindowModule_Periodic() 12217 self.run_test(module, (x, win_length)) 12218 12219 @skipIfUnsupportedMinOpsetVersion(9) 12220 def test_hann_window_not_periodic(self): 12221 class HannWindowModule_NotPeriodic(torch.nn.Module): 12222 def __init__(self) -> None: 12223 super().__init__() 12224 self.window_length = 0 12225 12226 def forward(self, x, window_length: int): 
12227 self.window_length = window_length 12228 return torch.add( 12229 x, 12230 torch.hann_window( 12231 self.window_length, periodic=False, dtype=torch.float 12232 ), 12233 ) 12234 12235 win_length = 100 12236 x = torch.randn(win_length) 12237 12238 module = HannWindowModule_NotPeriodic() 12239 self.run_test(module, (x, win_length)) 12240 12241 @skipIfUnsupportedMinOpsetVersion(9) 12242 @skipScriptTest() 12243 def test_hann_window_default_values(self): 12244 class HannWindowModule(torch.nn.Module): 12245 def __init__(self) -> None: 12246 super().__init__() 12247 self.window_length = 0 12248 12249 def forward(self, x, window_length: int): 12250 import torch.nn.functional as F 12251 12252 self.window_length = window_length 12253 return torch.add(x, F.relu(torch.hann_window(self.window_length))) 12254 12255 win_length = 100 12256 x = torch.randn(win_length, dtype=torch.float) 12257 module = HannWindowModule() 12258 12259 output = module(x, win_length) 12260 self.run_test(module, (x, win_length)) 12261 12262 @skipIfUnsupportedMinOpsetVersion(12) 12263 def test_tensordot_dim_count(self): 12264 class M(torch.nn.Module): 12265 def forward(self, x, y): 12266 output = torch.tensordot(x, y, 2) 12267 return output 12268 12269 x = torch.randint(6, (7, 5, 3, 4)) 12270 y = torch.randint(6, (3, 4, 9, 2)) 12271 12272 self.run_test(M(), (x, y)) 12273 12274 @skipIfUnsupportedMinOpsetVersion(12) 12275 def test_tensordot_dim_list(self): 12276 class M(torch.nn.Module): 12277 def forward(self, x, y): 12278 output = torch.tensordot(x, y, ([1, -2, -1], [1, 0, 3])) 12279 return output 12280 12281 x = torch.randint(6, (7, 4, 3, 5, 2)) 12282 y = torch.randint(6, (5, 4, 4, 2, 6)) 12283 12284 self.run_test(M(), (x, y)) 12285 12286 @skipIfUnsupportedMinOpsetVersion(12) 12287 def test_tensordot_dynamic_dim(self): 12288 class M(torch.nn.Module): 12289 def forward(self, x, y): 12290 output = torch.tensordot(x, y, 2) 12291 return output 12292 12293 x = torch.randint(6, (7, 5, 3, 4)) 12294 y = 
torch.randint(6, (3, 4, 9, 2)) 12295 12296 new_x = torch.randint(6, (8, 6, 2, 5)) 12297 new_y = torch.randint(6, (2, 5, 3, 4)) 12298 12299 self.run_test( 12300 M(), 12301 (x, y), 12302 additional_test_inputs=[(new_x, new_y)], 12303 input_names=["input_x", "input_y"], 12304 dynamic_axes={"input_x": [0, 1, 2, 3], "input_y": [0, 1, 2, 3]}, 12305 ) 12306 12307 @skipIfUnsupportedMinOpsetVersion(9) 12308 def test_to_device(self): 12309 class M_ToDevice(torch.nn.Module): 12310 def forward(self, x, y): 12311 return x.to(y.device), y 12312 12313 class M_ToDeviceDtype(torch.nn.Module): 12314 def forward(self, x, y): 12315 return x.to(y.device, dtype=torch.long), y 12316 12317 x = torch.randn(6) 12318 y = torch.randn(6) 12319 12320 self.run_test(M_ToDevice(), (x, y)) 12321 self.run_test(M_ToDeviceDtype(), (x, y)) 12322 12323 @skipIfUnsupportedMinOpsetVersion(9) 12324 def test_fill(self): 12325 class FillModule(torch.nn.Module): 12326 def forward(self, x, filled_value: int): 12327 return x.fill_(filled_value) 12328 12329 x = torch.randn((4, 5, 6)) 12330 filled_value = 7 12331 self.run_test(FillModule(), (x, filled_value)) 12332 12333 class FillFloatModule(torch.nn.Module): 12334 def forward(self, x, filled_value: float): 12335 return x.fill_(filled_value) 12336 12337 x = torch.randn((4, 5, 6)) 12338 filled_value = 7.5 12339 self.run_test(FillFloatModule(), (x, filled_value)) 12340 12341 class FillScalarModule(torch.nn.Module): 12342 def forward(self, x): 12343 res = x + 2 12344 res.fill_(2.5) 12345 return res, x 12346 12347 x = torch.ones(2, 3, 4, dtype=torch.long) 12348 self.run_test(FillScalarModule(), x) 12349 12350 @skipIfUnsupportedMinOpsetVersion(9) 12351 def test_index_add_normal(self): 12352 class M(torch.nn.Module): 12353 def __init__(self, dim, index, updates): 12354 super().__init__() 12355 self.dim = dim 12356 self.index = index 12357 self.updates = updates 12358 12359 def forward(self, x): 12360 x.index_add_(self.dim, self.index, self.updates) 12361 return x 12362 
12363 x = torch.ones(5, 1) 12364 updates = torch.tensor([[1], [4], [7], [3], [2]], dtype=torch.float) 12365 index = torch.tensor([0, 2, 3, 1, 4]) 12366 self.run_test(M(0, index, updates), (x,)) 12367 12368 x = torch.ones(1, 4, 3) 12369 updates = torch.tensor( 12370 [[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float 12371 ) 12372 index = torch.tensor([0, 2, 3, 1]) 12373 self.run_test(M(1, index, updates), (x,)) 12374 12375 updates = torch.tensor( 12376 [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [2, 3, 4]]], dtype=torch.float 12377 ) 12378 index = torch.tensor([0, 2, 1]) 12379 self.run_test(M(2, index, updates), (x,)) 12380 12381 @skipIfUnsupportedMinOpsetVersion(9) 12382 def test_index_add_dim_size_differ(self): 12383 class M(torch.nn.Module): 12384 def __init__(self, dim, index, updates): 12385 super().__init__() 12386 self.dim = dim 12387 self.index = index 12388 self.updates = updates 12389 12390 def forward(self, x): 12391 x.index_add_(self.dim, self.index, self.updates) 12392 return x 12393 12394 x = torch.ones(1, 4, 3) 12395 updates = torch.tensor([[[1, 5, 7], [2, 4, 5], [5, 5, 6]]], dtype=torch.float) 12396 index = torch.tensor([0, 2, 1]) 12397 self.run_test(M(1, index, updates), (x,)) 12398 12399 @skipIfUnsupportedMinOpsetVersion(9) 12400 def test_index_add_in_loop(self): 12401 class M(torch.nn.Module): 12402 def __init__(self, dim, index, updates, loop_count): 12403 super().__init__() 12404 self.dim = dim 12405 self.index = index 12406 self.updates = updates 12407 self.loop_count = loop_count 12408 12409 def forward(self, x): 12410 for i in range(self.loop_count): 12411 x.index_add_(self.dim, self.index, self.updates) 12412 return x 12413 12414 x = torch.ones(1, 4, 3) 12415 updates = torch.tensor( 12416 [[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float 12417 ) 12418 index = torch.tensor([0, 2, 3, 1]) 12419 loop_count = torch.randint(20, (1,))[0].item() 12420 self.run_test(M(1, index, updates, loop_count), (x,)) 12421 12422 
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_add_if(self):
    """Export ``index_add_`` inside a scripted if/else whose branches use different index tensors."""

    class M(torch.nn.Module):
        def __init__(self, dim, updates, index_true, index_false):
            super().__init__()
            self.dim = dim
            self.updates = updates
            self.index_true = index_true
            self.index_false = index_false

        def forward(self, x, cond):
            if cond:
                x.index_add_(self.dim, self.index_true, self.updates)
            else:
                x.index_add_(self.dim, self.index_false, self.updates)
            return x

    x = torch.ones(1, 4, 3)
    updates = torch.tensor(
        [[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
    )
    index_true = torch.tensor([0, 2, 3, 1])
    index_false = torch.tensor([1, 0, 2, 3])
    cond = torch.tensor(1, dtype=torch.bool)
    # Scripted explicitly so the data-dependent branch is kept as control flow.
    self.run_test(
        torch.jit.script(M(1, updates, index_true, index_false)), (x, cond)
    )

@skipIfUnsupportedMinOpsetVersion(9)
def test_index_add_dynamic_axes(self):
    """Export ``index_add_`` along dim 1 with axes 0 and 1 of the input marked dynamic."""

    class M(torch.nn.Module):
        def __init__(self, dim, index, updates):
            super().__init__()
            self.dim = dim
            self.index = index
            self.updates = updates

        def forward(self, x):
            x.index_add_(self.dim, self.index, self.updates)
            return x

    x = torch.ones(1, 4, 3)
    updates = torch.tensor(
        [[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
    )
    index = torch.tensor([0, 2, 3, 1])

    self.run_test(
        M(1, index, updates),
        (x,),
        input_names=["input_1"],
        dynamic_axes={"input_1": [0, 1]},
    )

def test_roll(self):
    """Export ``torch.roll`` with list and scalar shifts/dims, including negative values."""

    class M(torch.nn.Module):
        def __init__(self, shifts, dims):
            super().__init__()
            self.shifts = shifts
            self.dims = dims

        def forward(self, x):
            return torch.roll(x, self.shifts, self.dims)

    x = torch.randn(2, 3, 4)
    self.run_test(M([1, 1], [1, 0]), (x,))
    self.run_test(M([0, 1, 2], [1, 0, 2]), (x,))
    self.run_test(M(2, 1), (x,))
    self.run_test(M([-1, 3], [-2, -1]), (x,))

def test_sum(self):
    """Export a full ``torch.sum`` reduction with a dynamic leading axis."""

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x)

    x = torch.ones(12, 3)
    self.run_test(M(), (x,), input_names=["x"], dynamic_axes={"x": [0]})

@skipShapeChecking
def test_sum_empty_tensor(self):
    """Export sums over an empty slice and over tensors containing a zero-sized dim."""

    class M(torch.nn.Module):
        def forward(self, x):
            return x[0:0].sum(), x.sum()

    x = torch.ones(12)
    self.run_test(M(), (x,))

    x = torch.ones(2, 0, 3)
    self.run_test(M(), (x,))

    x = torch.ones(0)
    self.run_test(M(), (x,))

@skipIfUnsupportedMinOpsetVersion(11)
def test_broad_cast_tensors(self):
    """Export ``torch.broadcast_tensors`` for int and float inputs of differing rank."""

    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.broadcast_tensors(x, y)

    x = torch.randint(5, (1,))
    y = torch.randint(5, (5,))
    self.run_test(M(), (x, y))

    x = torch.randint(5, (4, 2, 1, 4))
    y = torch.randint(5, (2, 3, 1))
    self.run_test(M(), (x, y))

    x = torch.randn(2, 1, 4)
    y = torch.randn(5, 2, 3, 1)
    self.run_test(M(), (x, y))

@skipIfUnsupportedMinOpsetVersion(14)
def test_scaled_dot_product_attention(self):
    """Export ``scaled_dot_product_attention`` with an explicit ``scale=1.0``."""

    class M(torch.nn.Module):
        def forward(self, q, k, v):
            return torch.nn.functional.scaled_dot_product_attention(
                q, k, v, scale=1.0
            )

    # Parameters
    batch_size = 2  # Number of samples in the batch
    num_heads = 4  # Number of attention heads
    seq_length = 5  # Sequence length
    head_dim = 8  # Dimensionality of each head

    # Create random query, key, and value tensors
    q = torch.randn(batch_size, num_heads, seq_length, head_dim)
    k = torch.randn(batch_size, num_heads, seq_length, head_dim)
    v = torch.randn(batch_size, num_heads, seq_length, head_dim)

    self.run_test(M(), (q, k, v))

@skipScriptTest()
@skipIfUnsupportedMinOpsetVersion(11)
def test_dist_normal(self):
    """Export ``Normal(x, y).sample()`` and compare the sample's leading size under broadcasting."""

    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.distributions.Normal(x, y).sample().size(0), x, y

    self.run_test(M(), (torch.tensor([0.0]), torch.tensor([[1.0], [2.0]])))
    self.run_test(M(), (torch.tensor([0.0]), torch.tensor([1.0])))

    self.run_test(
        M(),
        (
            torch.tensor([[[0.0], [10.0]], [[2.0], [8.0]], [[2.0], [8.0]]]),
            torch.tensor([[1.0], [3.0]]),
        ),
    )

@skipScriptTest()
@skipIfUnsupportedMinOpsetVersion(11)
def test_dist_normal_correctness(self):
    """Check the statistics of ``Normal`` samples produced by the exported model in ORT.

    Draws 20000 samples and requires the sample mean and the sample standard
    deviation to each be within 10% of the distribution's parameters.
    """

    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.distributions.Normal(x, y).sample([20000])

    expected_mean = 5.0
    expected_std = 10.0

    model_export = M()
    dummy_input = (torch.tensor([expected_mean]), torch.tensor([expected_std]))
    model_onnx = io.BytesIO()
    torch.onnx.export(
        model_export, dummy_input, model_onnx, opset_version=self.opset_version
    )
    ort_sess = verification._ort_session(model_onnx)
    ort_out = verification._run_onnx(ort_sess, inputs=dummy_input)

    actual_std = np.std(ort_out)
    actual_mean = np.mean(ort_out)

    # Compare the signed mean: taking abs() of it first would let a
    # sign-flipped result (e.g. -5.0) pass.  unittest assertions also survive
    # `python -O` and report the offending values on failure.
    self.assertLessEqual(
        abs(actual_mean - expected_mean),
        expected_mean * 0.1,
        "the gap of mean between ort outputs and expected one is unacceptable.",
    )
    self.assertLessEqual(
        abs(actual_std - expected_std),
        expected_std * 0.1,
        "the gap of standard deviation between ort outputs and expected one "
        "is unacceptable.",
    )
@skipScriptTest()
@skipIfUnsupportedMinOpsetVersion(11)
def test_nn_init_normal_correctness(self):
    """Check the statistics of ``torch.nn.init.normal_`` output after ONNX export.

    The exported model takes no inputs.  It fills a (1, 400, 50) tensor
    (20000 values) in place with ``normal_``.  The ORT output's mean and
    standard deviation must each be within 10% of the requested parameters.
    """
    expected_mean = 5.0
    expected_std = 10.0

    class M(torch.nn.Module):
        def forward(self):
            x = torch.ones([]).new_empty(1, 400, 50)
            torch.nn.init.normal_(x, expected_mean, expected_std)
            return x

    model_export = M()
    model_onnx = io.BytesIO()
    test_inputs = ()
    torch.onnx.export(
        model_export, test_inputs, model_onnx, opset_version=self.opset_version
    )
    ort_sess = verification._ort_session(model_onnx)
    ort_out = verification._run_onnx(ort_sess, inputs=test_inputs)

    actual_std = np.std(ort_out)
    actual_mean = np.mean(ort_out)

    # Compare the signed mean: taking abs() of it first would let a
    # sign-flipped result (e.g. -5.0) pass.  unittest assertions also survive
    # `python -O` and report the offending values on failure.
    self.assertLessEqual(
        abs(actual_mean - expected_mean),
        expected_mean * 0.1,
        "the gap of mean between ort outputs and expected one is unacceptable.",
    )
    self.assertLessEqual(
        abs(actual_std - expected_std),
        expected_std * 0.1,
        "the gap of standard deviation between ort outputs and expected one "
        "is unacceptable.",
    )
@skipScriptTest()
@skipIfUnsupportedMinOpsetVersion(11)
def test_dist_uniform(self):
    """Export ``Uniform(x, y).sample()`` and compare the sample's leading size under broadcasting."""

    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.distributions.Uniform(x, y).sample().size(0), x, y

    self.run_test(M(), (torch.tensor([0.0]), torch.tensor([10.0])))
    self.run_test(M(), (torch.tensor([[0.0], [6.0]]), torch.tensor([[1.0], [7.0]])))
    self.run_test(
        M(), (torch.tensor([1.0]), torch.tensor([[10.0], [7.0], [9.0], [20.0]]))
    )

@skipScriptTest()
@skipIfUnsupportedMinOpsetVersion(11)
def test_dist_uniform_correctness(self):
    """Check the statistics of ``Uniform`` samples produced by the exported model in ORT.

    Draws 10000 samples from Uniform(5, 10).  All samples must lie within
    [5, 10], and the sample mean must be within 5% of the midpoint.
    """

    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.distributions.Uniform(x, y).sample([10000])

    expected_min = 5.0
    expected_max = 10.0
    expected_mean = (expected_min + expected_max) / 2

    model_export = M()
    dummy_input = (torch.tensor([expected_min]), torch.tensor([expected_max]))
    model_onnx = io.BytesIO()
    torch.onnx.export(
        model_export, dummy_input, model_onnx, opset_version=self.opset_version
    )
    ort_sess = verification._ort_session(model_onnx)

    ort_out = verification._run_onnx(ort_sess, inputs=dummy_input)
    actual_min = np.min(ort_out)
    actual_max = np.max(ort_out)
    actual_mean = np.mean(ort_out)

    # unittest assertions (unlike bare `assert`) are not stripped under
    # `python -O` and include the compared values in the failure message.
    self.assertGreaterEqual(
        actual_min,
        expected_min,
        "the minimum value of ort outputs is out of scope.",
    )
    self.assertLessEqual(
        actual_max,
        expected_max,
        "the maximum value of ort outputs is out of scope.",
    )
    self.assertLessEqual(
        abs(actual_mean - expected_mean),
        expected_mean * 0.05,
        "the mean value of ort outputs is out of scope.",
    )
12682 12683 @skipIfUnsupportedMinOpsetVersion(13) 12684 def test_sequence_to_int(self): 12685 class M(torch.nn.Module): 12686 def forward(self, x): 12687 result = torch.tensor([2 for i in range(x.size()[0])], dtype=torch.int) 12688 return x, result 12689 12690 x = torch.randn(10, 5) 12691 self.run_test(M(), (x,)) 12692 12693 @skipIfUnsupportedMinOpsetVersion(13) 12694 def test_sequence_to_float(self): 12695 class M(torch.nn.Module): 12696 def forward(self, x): 12697 result = torch.tensor( 12698 [1.1 for i in range(x.size()[0])], dtype=torch.float 12699 ) 12700 return x, result 12701 12702 x = torch.randn(10, 5) 12703 self.run_test(M(), (x,)) 12704 12705 @skipIfUnsupportedMinOpsetVersion(13) 12706 def test_sequence_to_bool(self): 12707 class M(torch.nn.Module): 12708 def forward(self, x): 12709 result = torch.tensor( 12710 [False for i in range(x.size()[0])], dtype=torch.bool 12711 ) 12712 return x, result 12713 12714 x = torch.randn(10, 5) 12715 self.run_test(M(), (x,)) 12716 12717 def test_tuple_output_from_if_with_raised_exception(self): 12718 class M(torch.nn.Module): 12719 def forward(self, t: Tensor) -> Tuple[Tensor, Tensor]: 12720 if float(t) < 0: 12721 raise Exception("Negative input") # noqa: TRY002 12722 else: 12723 return torch.zeros(5), torch.zeros(5) 12724 12725 x = torch.zeros(1) 12726 self.run_test(torch.jit.script(M()), (x,)) 12727 12728 # NOTE: For quantization tests, choose scale and zero point carefully 12729 # such that inputs and outputs do not always overflow/underflow. 12730 # Otherwise test results could be inaccurate. 12731 @skipIfUnsupportedMinOpsetVersion(10) 12732 def test_quantized_linear(self): 12733 model = torch.ao.nn.quantized.Linear(4, 8) 12734 # Set fixed weight to avoid flaky test. 12735 weight = torch.quantize_per_tensor( 12736 torch.arange(32, dtype=torch.float).view(8, 4), 0.5, 0, torch.qint8 12737 ) 12738 # Set non-zero bias. 
12739 bias = torch.arange(8, dtype=torch.float) 12740 model.set_weight_bias(weight, bias) 12741 # Set fixed input to avoid flaky test. 12742 input = torch.randn(4, 4) 12743 input = torch.arange(16, dtype=torch.float).view(4, 4) - 8 12744 input_tensor = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8) 12745 self.run_test(model, input_tensor) 12746 12747 @skipIfUnsupportedMinOpsetVersion(10) 12748 def test_quantized_conv1d(self): 12749 model = torch.ao.nn.quantized.Conv1d(16, 33, 3, stride=2) 12750 # Manually initialize model weight and bias to random numbers. 12751 # By default all zeros. 12752 q_weight = torch.quantize_per_tensor( 12753 torch.randn(33, 16, 3), 0.5, 0, torch.qint8 12754 ) 12755 bias = torch.arange(33).to(torch.float) - 16 12756 model.set_weight_bias(q_weight, bias) 12757 input = torch.randn(3, 16, 32) 12758 q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8) 12759 self.run_test(model, q_input) 12760 12761 @skipIfUnsupportedMinOpsetVersion(10) 12762 def test_quantized_conv2d(self): 12763 model = torch.ao.nn.quantized.Conv2d(16, 33, 3, stride=2) 12764 # Manually initialize model weight and bias to random numbers. 12765 # By default all zeros. 12766 q_weight = torch.quantize_per_tensor( 12767 torch.randn(33, 16, 3, 3), 0.5, 0, torch.qint8 12768 ) 12769 bias = torch.arange(33).to(torch.float) - 16 12770 model.set_weight_bias(q_weight, bias) 12771 input = torch.randn(3, 16, 32, 32) 12772 q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8) 12773 self.run_test(model, q_input) 12774 12775 @skipIfUnsupportedMinOpsetVersion(10) 12776 @skipIfQuantizationBackendQNNPack 12777 def test_quantized_conv3d(self): 12778 model = torch.ao.nn.quantized.Conv3d(16, 33, [2, 3, 4], stride=[3, 1, 2]) 12779 # Manually initialize model weight and bias to random numbers. 12780 # By default all zeros. 
12781 q_weight = torch.quantize_per_tensor( 12782 torch.randn(33, 16, 2, 3, 4), 0.5, 0, torch.qint8 12783 ) 12784 bias = torch.arange(33).to(torch.float) - 16 12785 model.set_weight_bias(q_weight, bias) 12786 input = torch.randn(3, 16, 8, 8, 8) 12787 q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8) 12788 self.run_test(model, q_input) 12789 12790 @skipIfUnsupportedMinOpsetVersion(10) 12791 def test_quantized_adaptive_avg_pool2d(self): 12792 model = torch.nn.AdaptiveAvgPool2d((5, 7)) 12793 input = torch.randn(4, 3, 10, 14) 12794 q_input = torch.quantize_per_tensor(input, 0.2, 128, torch.quint8) 12795 self.run_test(model, q_input) 12796 12797 @skipIfUnsupportedMinOpsetVersion(10) 12798 def test_quantized_conv1d_relu(self): 12799 model = torch.ao.nn.intrinsic.quantized.ConvReLU1d(16, 33, 3, stride=2) 12800 # Manually initialize model weight and bias to random numbers. 12801 # By default all zeros. 12802 q_weight = torch.quantize_per_tensor( 12803 torch.randn(33, 16, 3), 0.5, 0, torch.qint8 12804 ) 12805 bias = torch.arange(33).to(torch.float) - 16 12806 model.set_weight_bias(q_weight, bias) 12807 input = torch.randn(3, 16, 32) 12808 q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8) 12809 self.run_test(model, q_input) 12810 12811 @skipIfUnsupportedMinOpsetVersion(10) 12812 def test_quantized_conv2d_relu(self): 12813 model = torch.ao.nn.intrinsic.quantized.ConvReLU2d(16, 33, 3, stride=2) 12814 # Manually initialize model weight and bias to random numbers. 12815 # By default all zeros. 
12816 q_weight = torch.quantize_per_tensor( 12817 torch.randn(33, 16, 3, 3), 0.5, 0, torch.qint8 12818 ) 12819 bias = torch.arange(33).to(torch.float) - 16 12820 model.set_weight_bias(q_weight, bias) 12821 input = torch.randn(3, 16, 32, 32) 12822 q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8) 12823 self.run_test(model, q_input) 12824 12825 @skipIfUnsupportedMinOpsetVersion(10) 12826 @skipIfQuantizationBackendQNNPack 12827 def test_quantized_conv3d_relu(self): 12828 model = torch.ao.nn.intrinsic.quantized.ConvReLU3d( 12829 16, 33, [2, 3, 4], stride=[3, 1, 2] 12830 ) 12831 # Manually initialize model weight and bias to random numbers. 12832 # By default all zeros. 12833 q_weight = torch.quantize_per_tensor( 12834 torch.randn(33, 16, 2, 3, 4), 0.5, 0, torch.qint8 12835 ) 12836 bias = torch.arange(33).to(torch.float) - 16 12837 model.set_weight_bias(q_weight, bias) 12838 input = torch.randn(3, 16, 8, 8, 8) 12839 q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8) 12840 self.run_test(model, q_input) 12841 12842 @skipIfUnsupportedMinOpsetVersion(10) 12843 def test_quantized_conv_transpose1d(self): 12844 model = torch.ao.nn.quantized.ConvTranspose1d( 12845 16, 33, 3, output_padding=1, stride=2 12846 ) 12847 # Manually initialize model weight and bias to random numbers. 12848 # By default all zeros. 12849 q_weight = torch.quantize_per_tensor( 12850 torch.randn(16, 33, 3), 0.5, 0, torch.qint8 12851 ) 12852 bias = torch.arange(33).to(torch.float) - 16 12853 model.set_weight_bias(q_weight, bias) 12854 input = torch.randn(3, 16, 32) 12855 q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8) 12856 self.run_test(model, q_input) 12857 12858 @skipIfUnsupportedMinOpsetVersion(10) 12859 def test_quantized_conv_transpose2d(self): 12860 model = torch.ao.nn.quantized.ConvTranspose2d( 12861 16, 33, 3, output_padding=(0, 1), stride=2 12862 ) 12863 # Manually initialize model weight and bias to random numbers. 12864 # By default all zeros. 
12865 q_weight = torch.quantize_per_tensor( 12866 torch.randn(16, 33, 3, 3), 0.5, 0, torch.qint8 12867 ) 12868 bias = torch.arange(33).to(torch.float) - 16 12869 model.set_weight_bias(q_weight, bias) 12870 input = torch.randn(3, 16, 32, 32) 12871 q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8) 12872 self.run_test(model, q_input) 12873 12874 @skipIfUnsupportedMinOpsetVersion(10) 12875 @skipIfQuantizationBackendQNNPack 12876 def test_quantized_conv_transpose3d(self): 12877 model = torch.ao.nn.quantized.ConvTranspose3d( 12878 16, 33, [2, 3, 4], output_padding=(0, 1, 2), stride=[3, 1, 2] 12879 ) 12880 # Manually initialize model weight and bias to random numbers. 12881 # By default all zeros. 12882 q_weight = torch.quantize_per_tensor( 12883 torch.randn(16, 33, 2, 3, 4), 0.5, 0, torch.qint8 12884 ) 12885 bias = torch.arange(33).to(torch.float) - 16 12886 model.set_weight_bias(q_weight, bias) 12887 input = torch.randn(3, 16, 8, 8, 8) 12888 q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8) 12889 self.run_test(model, q_input) 12890 12891 @common_utils.parametrize( 12892 "function_or_module", 12893 [ 12894 common_utils.subtest( 12895 torch.nn.ReLU(), 12896 name="relu", 12897 ), 12898 common_utils.subtest( 12899 torch.nn.LeakyReLU(), 12900 name="leaky_relu", 12901 ), 12902 common_utils.subtest( 12903 torch.ao.nn.quantized.LeakyReLU(2.0, 1), 12904 name="quantized_leaky_relu", 12905 ), 12906 common_utils.subtest( 12907 torch.ao.nn.quantized.Hardswish(2.0, 1), 12908 name="quantized_hardswish", 12909 ), 12910 common_utils.subtest( 12911 torch.nn.Sigmoid(), 12912 name="sigmoid", 12913 ), 12914 common_utils.subtest( 12915 torch.ao.nn.quantized.Sigmoid(2.0, 1), 12916 name="quantized_sigmoid", 12917 ), 12918 common_utils.subtest( 12919 torch.nn.Hardsigmoid(), 12920 name="hardsigmoid", 12921 ), 12922 common_utils.subtest( 12923 torch.nn.Tanh(), 12924 name="tanh", 12925 ), 12926 common_utils.subtest( 12927 torch.nn.Hardtanh(), 12928 
name="hardtanh", 12929 ), 12930 common_utils.subtest( 12931 lambda x: torch.transpose(x, 0, 1), 12932 name="transpose", 12933 ), 12934 common_utils.subtest( 12935 lambda x: x.expand(2, 4, 2, 3), 12936 name="expand", 12937 ), 12938 common_utils.subtest( 12939 lambda x: x.view(1, 4, 6), 12940 name="view", 12941 ), 12942 common_utils.subtest( 12943 lambda x: x.select(1, 1), 12944 name="select", 12945 ), 12946 common_utils.subtest( 12947 torch.ao.nn.quantized.LayerNorm( 12948 [4, 2, 3], 12949 torch.nn.Parameter(torch.ones([4, 2, 3])), 12950 torch.nn.Parameter(torch.zeros([4, 2, 3])), 12951 2.0, 12952 1, 12953 ), 12954 name="layer_norm", 12955 ), 12956 common_utils.subtest( 12957 torch.ao.nn.quantized.InstanceNorm1d( 12958 2, 12959 torch.nn.Parameter(torch.ones(4)), 12960 torch.nn.Parameter(torch.zeros(4)), 12961 2.0, 12962 1, 12963 ), 12964 name="instance_norm", 12965 ), 12966 common_utils.subtest( 12967 torch.ao.nn.quantized.GroupNorm( 12968 2, 12969 4, 12970 torch.nn.Parameter(torch.zeros(4)), 12971 torch.nn.Parameter(torch.zeros(4)), 12972 2.0, 12973 1, 12974 ), 12975 name="group_norm", 12976 ), 12977 common_utils.subtest( 12978 lambda x: torch.as_strided(x, (2, 2), (1, 2)), 12979 name="as_strided", 12980 ), 12981 ], 12982 ) 12983 @skipScriptTest() 12984 @skipIfUnsupportedMinOpsetVersion(10) 12985 def test_quantized_unary_ops(self, function_or_module): 12986 input = torch.randn(1, 4, 2, 3) 12987 q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8) 12988 12989 class Model(torch.nn.Module): 12990 def __init__(self, function_or_module): 12991 super().__init__() 12992 self.function_or_module = function_or_module 12993 12994 def forward(self, x): 12995 return self.function_or_module(x) 12996 12997 self.run_test(Model(function_or_module), q_input) 12998 12999 @skipIfUnsupportedMinOpsetVersion(10) 13000 def test_quantized_flatten(self): 13001 class FlattenModel(torch.nn.Module): 13002 def forward(self, input): 13003 return torch.flatten(input) 13004 13005 x 
= torch.quantize_per_tensor(torch.randn(1, 2, 3, 4), 1, 0, torch.quint8) 13006 self.run_test(FlattenModel(), x) 13007 13008 @skipIfUnsupportedMinOpsetVersion(10) 13009 @skipScriptTest() # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function: 13010 def test_quantized_cat_when_concatinating_the_same_tensor(self): 13011 class QuantizedSelfConcatenationModel(torch.nn.Module): 13012 def forward(self, x): 13013 return torch.ao.nn.quantized.QFunctional().cat((x, x), dim=1) 13014 13015 q_input = torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 128, torch.quint8) 13016 self.run_test(QuantizedSelfConcatenationModel(), q_input) 13017 13018 @common_utils.parametrize( 13019 "x, y", 13020 [ 13021 common_utils.subtest( 13022 [ 13023 torch.quantize_per_tensor( 13024 torch.ones(2, 3), 0.26, 128, torch.quint8 13025 ), 13026 torch.quantize_per_tensor( 13027 torch.zeros(1, 3), 0.26, 128, torch.quint8 13028 ), 13029 ], 13030 name="different_shape", 13031 ), 13032 common_utils.subtest( 13033 [ 13034 torch.quantize_per_tensor( 13035 torch.ones(2, 3), 0.26, 128, torch.quint8 13036 ), 13037 torch.quantize_per_tensor(torch.ones(2, 3), 42, 1, torch.quint8), 13038 ], 13039 name="different_scale", 13040 ), 13041 common_utils.subtest( 13042 [ 13043 torch.quantize_per_tensor( 13044 torch.ones(2, 3), 0.26, 128, torch.quint8 13045 ), 13046 torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 63, torch.quint8), 13047 ], 13048 name="different_zero_point", 13049 ), 13050 common_utils.subtest( 13051 [ 13052 torch.quantize_per_tensor( 13053 torch.ones(2, 3), 0.26, 128, torch.quint8 13054 ), 13055 torch.quantize_per_tensor(torch.ones(2, 3), 0.1, 63, torch.quint8), 13056 ], 13057 name="different_zero_point_and_scale", 13058 ), 13059 ], 13060 ) 13061 @skipIfUnsupportedMinOpsetVersion(10) 13062 @skipScriptTest() # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function: 13063 def test_quantized_cat(self, x: torch.Tensor, y: 
torch.Tensor): 13064 class QuantizedConcatenationModel(torch.nn.Module): 13065 def forward(self, x, y): 13066 return torch.ao.nn.quantized.QFunctional().cat((x, y), dim=0) 13067 13068 self.run_test(QuantizedConcatenationModel(), (x, y)) 13069 13070 @skipIfUnsupportedMinOpsetVersion(10) 13071 # torch.jit.frontend.FrontendError: 13072 # Cannot instantiate class 'QFunctional' in a script function 13073 @skipScriptTest() 13074 def test_quantized_arithmetic_qfunctional(self): 13075 x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8) 13076 y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8) 13077 13078 class ArithmeticModel(torch.nn.Module): 13079 def forward(self, x, y): 13080 o = torch.ao.nn.quantized.QFunctional().add(x, y) 13081 o = torch.ao.nn.quantized.QFunctional().mul(o, x) 13082 return o 13083 13084 self.run_test(ArithmeticModel(), (x, y)) 13085 13086 @skipIfUnsupportedMinOpsetVersion(10) 13087 def test_quantized_arithmetic(self): 13088 x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8) 13089 y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8) 13090 13091 class ArithmeticModel2(torch.nn.Module): 13092 def forward(self, x, y): 13093 o = torch.ops.quantized.add(x, y, 0.4, 100) 13094 o = torch.ops.quantized.mul(o, x, 0.4, 100) 13095 return o 13096 13097 self.run_test(ArithmeticModel2(), (x, y)) 13098 13099 @skipIfUnsupportedMinOpsetVersion(10) 13100 def test_quantize_per_tensor(self): 13101 class Module(torch.nn.Module): 13102 def forward(self, x): 13103 return ( 13104 torch.quantize_per_tensor(x, 0.2, 0, torch.qint8), 13105 torch.quantize_per_tensor(x, 0.2, 128, torch.quint8), 13106 ) 13107 13108 x = torch.randn(4, 6) 13109 self.run_test(Module(), x) 13110 13111 @skipIfUnsupportedMinOpsetVersion(10) 13112 def test_dequantize(self): 13113 class Module(torch.nn.Module): 13114 def forward(self, x): 13115 return torch.dequantize(x) 13116 13117 x = 
torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 0, torch.qint8) 13118 self.run_test(Module(), x) 13119 13120 @skipIfUnsupportedMinOpsetVersion(13) 13121 def test_qat_linear_per_channel(self): 13122 class M(torch.nn.Module): 13123 def __init__(self) -> None: 13124 super().__init__() 13125 self.quant = torch.ao.quantization.QuantStub() 13126 self.linear = torch.nn.Linear(4, 3) 13127 self.dequant = torch.ao.quantization.DeQuantStub() 13128 13129 def forward(self, x): 13130 x = self.quant(x) 13131 x = self.linear(x) 13132 x = self.dequant(x) 13133 return x 13134 13135 model = M() 13136 model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 13137 model = torch.ao.quantization.prepare_qat(model) 13138 # Set fixed weight and bias to avoid flaky test. 13139 model.linear.weight = torch.nn.Parameter( 13140 _construct_tensor_for_quantization_test((3, 4)) 13141 ) 13142 model.linear.bias = torch.nn.Parameter(torch.arange(3, dtype=torch.float)) 13143 model = torch.ao.quantization.convert(model) 13144 13145 # Set fixed input to avoid flaky test. 
13146 input = _construct_tensor_for_quantization_test((4, 4), offset=-8) 13147 self.run_test(model, input) 13148 13149 @unittest.skip( 13150 "ORT fails with Validating no unexpected access using an invalid node_index on torch converted model" 13151 ) 13152 @skipIfUnsupportedMinOpsetVersion(13) 13153 def test_quantized_list_of_inputs_with_cat(self): 13154 class TestModel(torch.nn.Module): 13155 def __init__(self) -> None: 13156 super().__init__() 13157 self.quant = torch.ao.quantization.QuantStub() 13158 self.dequant = torch.ao.quantization.DeQuantStub() 13159 13160 def forward(self, x): 13161 x = self.quant(x) 13162 x = torch.cat([x, x], 1) 13163 x = self.dequant(x) 13164 return x 13165 13166 model = TestModel() 13167 model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 13168 model = torch.ao.quantization.prepare_qat(model) 13169 model = torch.ao.quantization.convert(model) 13170 x = torch.randn(2, 4, 6) 13171 self.run_test(model, x) 13172 13173 @skipIfUnsupportedMinOpsetVersion(13) 13174 def test_qat_relu(self): 13175 class M(torch.nn.Module): 13176 def __init__(self) -> None: 13177 super().__init__() 13178 self.quant = torch.ao.quantization.QuantStub() 13179 self.relu = torch.nn.ReLU() 13180 self.dequant = torch.ao.quantization.DeQuantStub() 13181 13182 def forward(self, x): 13183 x = self.quant(x) 13184 x = self.relu(x) 13185 x = self.dequant(x) 13186 return x 13187 13188 model = M() 13189 model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 13190 model = torch.ao.quantization.prepare_qat(model) 13191 model = torch.ao.quantization.convert(model) 13192 input = torch.randn(8, 4) 13193 self.run_test(model, input) 13194 13195 @skipIfUnsupportedMinOpsetVersion(13) 13196 def test_qat_conv2d(self): 13197 class M(torch.nn.Module): 13198 def __init__(self) -> None: 13199 super().__init__() 13200 self.quant = torch.ao.quantization.QuantStub() 13201 self.conv = torch.nn.Conv2d(4, 2, 3, stride=2) 13202 self.dequant = 
torch.ao.quantization.DeQuantStub() 13203 13204 def forward(self, x): 13205 x = self.quant(x) 13206 x = self.conv(x) 13207 x = self.dequant(x) 13208 return x 13209 13210 model = M() 13211 model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 13212 model = torch.ao.quantization.prepare_qat(model) 13213 # Set fixed weight and bias to avoid flaky test. 13214 model.conv.weight = torch.nn.Parameter( 13215 _construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2) 13216 ) 13217 model.conv.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0])) 13218 model = torch.ao.quantization.convert(model) 13219 13220 # Set fixed input to avoid flaky test. 13221 input = _construct_tensor_for_quantization_test( 13222 (3, 4, 8, 8), offset=-384, max_val=12 13223 ) 13224 self.run_test(model, input) 13225 13226 @skipIfUnsupportedMinOpsetVersion(13) 13227 def test_qat_conv2d_relu(self): 13228 class M(torch.nn.Module): 13229 def __init__(self) -> None: 13230 super().__init__() 13231 self.quant = torch.ao.quantization.QuantStub() 13232 self.conv = torch.nn.Conv2d(4, 2, 3, stride=2) 13233 self.relu = torch.nn.ReLU() 13234 self.dequant = torch.ao.quantization.DeQuantStub() 13235 13236 def forward(self, x): 13237 x = self.quant(x) 13238 x = self.conv(x) 13239 x = self.relu(x) 13240 x = self.dequant(x) 13241 return x 13242 13243 model = M() 13244 model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 13245 model = torch.ao.quantization.prepare_qat(model) 13246 # Set fixed weight and bias to avoid flaky test. 13247 model.conv.weight = torch.nn.Parameter( 13248 _construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2) 13249 ) 13250 model.conv.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0])) 13251 model = torch.ao.quantization.convert(model) 13252 13253 # Set fixed input to avoid flaky test. 
13254 input = _construct_tensor_for_quantization_test( 13255 (3, 4, 8, 8), offset=-384, max_val=12 13256 ) 13257 self.run_test(model, input) 13258 13259 @skipIfUnsupportedMinOpsetVersion(13) 13260 def test_qat_conv2d_relu_fused(self): 13261 class M(torch.nn.Module): 13262 def __init__(self) -> None: 13263 super().__init__() 13264 self.quant = torch.ao.quantization.QuantStub() 13265 self.conv = torch.nn.Conv2d(4, 2, 3, stride=2) 13266 self.relu = torch.nn.ReLU() 13267 self.dequant = torch.ao.quantization.DeQuantStub() 13268 13269 def forward(self, x): 13270 x = self.quant(x) 13271 x = self.conv(x) 13272 x = self.relu(x) 13273 x = self.dequant(x) 13274 return x 13275 13276 model = M() 13277 model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 13278 model = torch.ao.quantization.fuse_modules(model.eval(), [["conv", "relu"]]) 13279 model = torch.ao.quantization.prepare_qat(model.train()) 13280 # Set fixed weight and bias to avoid flaky test. 13281 model.conv.weight = torch.nn.Parameter( 13282 _construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2) 13283 ) 13284 model.conv.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0])) 13285 model = torch.ao.quantization.convert(model) 13286 13287 # Set fixed input to avoid flaky test. 
13288 input = _construct_tensor_for_quantization_test( 13289 (3, 4, 8, 8), offset=-384, max_val=12 13290 ) 13291 self.run_test(model, input) 13292 13293 @skipIfUnsupportedMinOpsetVersion(13) 13294 def test_qat_linear_relu_fused(self): 13295 class M(torch.nn.Module): 13296 def __init__(self) -> None: 13297 super().__init__() 13298 self.quant = torch.ao.quantization.QuantStub() 13299 self.linear = torch.nn.Linear(4, 2) 13300 self.relu = torch.nn.ReLU() 13301 self.dequant = torch.ao.quantization.DeQuantStub() 13302 13303 def forward(self, x): 13304 x = self.quant(x) 13305 x = self.linear(x) 13306 x = self.relu(x) 13307 x = self.dequant(x) 13308 return x 13309 13310 model = M() 13311 model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 13312 model = torch.ao.quantization.fuse_modules(model.eval(), [["linear", "relu"]]) 13313 model = torch.ao.quantization.prepare_qat(model.train()) 13314 # Set fixed weight and bias to avoid flaky test. 13315 model.linear.weight = torch.nn.Parameter( 13316 _construct_tensor_for_quantization_test((2, 4), max_val=2) 13317 ) 13318 model.linear.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0])) 13319 model = torch.ao.quantization.convert(model) 13320 13321 # Set fixed input to avoid flaky test. 
13322 input = _construct_tensor_for_quantization_test((3, 4), offset=-384, max_val=12) 13323 self.run_test(model, input) 13324 13325 @skipIfUnsupportedMinOpsetVersion(10) 13326 def test_qat_maxpool2d(self): 13327 class M(torch.nn.Module): 13328 def __init__(self) -> None: 13329 super().__init__() 13330 self.quant = torch.ao.quantization.QuantStub() 13331 self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 13332 self.dequant = torch.ao.quantization.DeQuantStub() 13333 13334 def forward(self, x): 13335 x = self.quant(x) 13336 x = self.pool(x) 13337 x = self.dequant(x) 13338 return x 13339 13340 model = M() 13341 model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 13342 model = torch.ao.quantization.prepare_qat(model.train()) 13343 model = torch.ao.quantization.convert(model) 13344 13345 # Set fixed input to avoid flaky test. 13346 input = _construct_tensor_for_quantization_test((4, 4, 3, 2)) 13347 self.run_test(model, input) 13348 13349 @skipIfUnsupportedMinOpsetVersion(10) 13350 @skipScriptTest() # Scale and Zero-point must be a scalar in ORT:optimization 13351 def test_qat_avg_pool2d(self): 13352 model = torch.nn.Sequential( 13353 torch.ao.quantization.QuantStub(), 13354 torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1), 13355 torch.ao.quantization.DeQuantStub(), 13356 ) 13357 model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 13358 model = torch.ao.quantization.prepare_qat(model.train()) 13359 model = torch.ao.quantization.convert(model) 13360 input = _construct_tensor_for_quantization_test((4, 4, 3, 2)) 13361 self.run_test(model, input) 13362 13363 @skipIfUnsupportedMinOpsetVersion(11) 13364 def test_qat_upsample_nearest2d(self): 13365 model = torch.nn.Sequential( 13366 torch.ao.quantization.QuantStub(), 13367 torch.nn.UpsamplingNearest2d(scale_factor=1.5), 13368 torch.ao.quantization.DeQuantStub(), 13369 ) 13370 model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 13371 model = 
torch.ao.quantization.prepare_qat(model.train()) 13372 model = torch.ao.quantization.convert(model) 13373 input = _construct_tensor_for_quantization_test((4, 3, 2, 2)) 13374 self.run_test(model, input) 13375 13376 def test_0d_tensor_broadcast(self): 13377 class fn(torch.nn.Module): 13378 def forward(self, x, y): 13379 a = torch.add(x, y) 13380 b = torch.mul(y, y) 13381 return a + b 13382 13383 x = torch.ones(0) 13384 y = torch.ones(1) 13385 self.run_test(fn(), (x, y), input_names=["x", "y"], output_names=["output"]) 13386 13387 @skipIfUnsupportedMinOpsetVersion(9) 13388 def test_convolution_allow_tf32(self): 13389 class Module(torch.nn.Module): 13390 def __init__(self, allow_tf32): 13391 super().__init__() 13392 13393 self.allow_tf32 = allow_tf32 13394 weight = torch.rand(32, 3, 3, 3) 13395 self.weight = torch.nn.Parameter(weight) 13396 13397 def forward(self, x): 13398 if self.allow_tf32: 13399 return torch._convolution( 13400 x, 13401 self.weight, 13402 None, 13403 [2, 2], 13404 [0, 0], 13405 [1, 1], 13406 False, 13407 [0, 0], 13408 1, 13409 False, 13410 False, 13411 True, 13412 True, 13413 ) 13414 else: 13415 return torch._convolution( 13416 x, 13417 self.weight, 13418 None, 13419 [2, 2], 13420 [0, 0], 13421 [1, 1], 13422 False, 13423 [0, 0], 13424 1, 13425 False, 13426 False, 13427 True, 13428 ) 13429 13430 x = torch.randn(1, 3, 224, 224) 13431 self.run_test(Module(False), x, rtol=1e-3, atol=1e-6) 13432 self.run_test(Module(True), x, rtol=1e-3, atol=1e-6) 13433 13434 class AffineGridModule(torch.nn.Module): 13435 def __init__(self, align_corners) -> None: 13436 super().__init__() 13437 self.align_corners = align_corners 13438 13439 def forward(self, theta, size): 13440 return torch.nn.functional.affine_grid(theta, size, self.align_corners) 13441 13442 @skipIfUnsupportedMinOpsetVersion(20) 13443 @skipScriptTest() 13444 @common_utils.parametrize( 13445 "align_corners", 13446 (True, False), 13447 ) 13448 @common_utils.parametrize( 13449 "theta_params", 13450 ( 
13451 ( 13452 10, 13453 np.array([0.3, -0.5]), 13454 np.array([1.5, 0.5]), 13455 ), 13456 ( 13457 60, 13458 np.array([-0.5, -0.5]), 13459 np.array([3.0, 5.5]), 13460 ), 13461 ), 13462 ) 13463 @common_utils.parametrize( 13464 "size", 13465 ([1, 1, 3, 2], [2, 10, 2, 3]), 13466 ) 13467 def test_affine_grid_2d(self, align_corners, theta_params, size): 13468 angle, translation, scale = theta_params 13469 theta = np.array([], dtype=np.float32) 13470 for _ in range(size[0]): 13471 angle_radian = (angle / 180.0) * np.pi 13472 theta = np.append( 13473 theta, 13474 [ 13475 np.cos(angle_radian) * scale[0], 13476 -np.sin(angle_radian), 13477 translation[0], 13478 np.sin(angle_radian), 13479 np.cos(angle_radian) * scale[1], 13480 translation[1], 13481 ], 13482 ) 13483 theta = theta.reshape(size[0], 2, 3) 13484 theta = torch.Tensor(theta) 13485 self.run_test(TestONNXRuntime.AffineGridModule(align_corners), (theta, size)) 13486 13487 @skipIfUnsupportedMinOpsetVersion(20) 13488 @skipScriptTest() 13489 @common_utils.parametrize( 13490 "align_corners", 13491 (True, False), 13492 ) 13493 @common_utils.parametrize( 13494 "theta_params", 13495 ( 13496 ( 13497 [10, 20], 13498 np.array([0.3, -0.5, 1.8]), 13499 np.array([1.5, 2.0, 0.5]), 13500 ), 13501 ( 13502 [60, -30], 13503 np.array([-0.5, -0.5, 0.3]), 13504 np.array([0.3, 3.0, 5.5]), 13505 ), 13506 ), 13507 ) 13508 @common_utils.parametrize( 13509 "size", 13510 ([1, 1, 3, 2, 2], [2, 10, 2, 2, 3]), 13511 ) 13512 def test_affine_grid_3d(self, align_corners, theta_params, size): 13513 angle, translation, scale = theta_params 13514 theta = np.array([], dtype=np.float32) 13515 for _ in range(size[0]): 13516 angle_radian_x = (angle[0] / 180.0) * np.pi 13517 angle_radian_y = (angle[1] / 180.0) * np.pi 13518 rot_matrix_x = np.array( 13519 [ 13520 [1, 0, 0], 13521 [0, np.cos(angle_radian_x), -np.sin(angle_radian_x)], 13522 [0, np.sin(angle_radian_x), np.cos(angle_radian_x)], 13523 ] 13524 ) 13525 rot_matrix_y = np.array( 13526 [ 13527 
[np.cos(angle_radian_y), 0, np.sin(angle_radian_y)], 13528 [0, 1, 0], 13529 [-np.sin(angle_radian_y), 0, np.cos(angle_radian_y)], 13530 ] 13531 ) 13532 rot_matrix = np.matmul(rot_matrix_x, rot_matrix_y) 13533 rot_matrix = rot_matrix * scale.reshape(3, 1) 13534 rot_matrix = np.append(rot_matrix, np.reshape(translation, (3, 1)), axis=1) 13535 theta = np.append(theta, rot_matrix.flatten()) 13536 13537 theta = theta.reshape(size[0], 3, 4) 13538 theta = torch.Tensor(theta) 13539 self.run_test(TestONNXRuntime.AffineGridModule(align_corners), (theta, size)) 13540 13541 @skipIfUnsupportedMinOpsetVersion(16) 13542 @common_utils.parametrize( 13543 "mode", 13544 ("bilinear", "nearest", "bicubic"), 13545 ) 13546 @common_utils.parametrize( 13547 "padding_mode", 13548 ("zeros", "border", "reflection"), 13549 ) 13550 @common_utils.parametrize( 13551 "align_corners", 13552 (True, False), 13553 name_fn=lambda align_corners: str(align_corners), 13554 ) 13555 def test_grid_sample(self, mode, padding_mode, align_corners): 13556 n, c, d_in, h_in, w_in, d_out, h_out, w_out = 1, 1, 2, 3, 2, 3, 2, 4 13557 13558 atol_rtol = {} 13559 if (mode, padding_mode) == ("bicubic", "border"): 13560 if align_corners: 13561 atol_rtol.update({"atol": 0.3, "rtol": 0.4}) 13562 else: 13563 atol_rtol.update({"atol": 0.02, "rtol": 0.02}) 13564 input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2) 13565 13566 class GridSampleModule(torch.nn.Module): 13567 def __init__(self, mode, padding_mode, align_corners) -> None: 13568 super().__init__() 13569 self.mode, self.padding_mode, self.align_corners = ( 13570 mode, 13571 padding_mode, 13572 align_corners, 13573 ) 13574 13575 def forward(self, input, grid): 13576 return torch.nn.functional.grid_sample( 13577 input, grid, self.mode, self.padding_mode, self.align_corners 13578 ) 13579 13580 self.run_test( 13581 GridSampleModule(mode, padding_mode, align_corners), 13582 (input, grid), 13583 **atol_rtol, 13584 ) 13585 13586 # ONNX Opset 16 
GridSample with 5D volumetric input is not supported. 13587 volumetric_input_tensor = torch.randn(n, c, d_in, h_in, w_in) 13588 volumetric_grid_tensor = torch.randn(n, d_out, h_out, w_out, 3) 13589 for mode, padding_mode, align_corners in itertools.product( 13590 ( 13591 "bilinear", 13592 "nearest", 13593 ), # PyTorch grid_sample "bicubic" mode does not support 5D volumetric input. 13594 ( 13595 "zeros", 13596 "border", 13597 "reflection", 13598 ), 13599 ( 13600 True, 13601 False, 13602 ), 13603 ): 13604 if self.opset_version < 20: 13605 with self.assertRaises( 13606 torch.onnx.OnnxExporterError, 13607 ): 13608 self.run_test( 13609 GridSampleModule(mode, padding_mode, align_corners), 13610 (volumetric_input_tensor, volumetric_grid_tensor), 13611 **atol_rtol, 13612 ) 13613 else: 13614 self.run_test( 13615 GridSampleModule(mode, padding_mode, align_corners), 13616 (volumetric_input_tensor, volumetric_grid_tensor), 13617 **atol_rtol, 13618 ) 13619 13620 class IfNoneInput(torch.nn.Module): 13621 def forward(self, x) -> Optional[Tensor]: 13622 y: Optional[Tensor] = None 13623 if x.size(0) > 1: 13624 y = x 13625 return y 13626 13627 class IfNoneOutput(torch.nn.Module): 13628 def forward(self, x) -> Optional[Tensor]: 13629 y: Optional[Tensor] = x 13630 if x.size(0) > 1: 13631 y = None 13632 return y 13633 13634 class LoopNoneInput(torch.nn.Module): 13635 def forward(self, x) -> Optional[Tensor]: 13636 y: Optional[Tensor] = None 13637 for _ in range(x.size(0)): 13638 y = x 13639 return y 13640 13641 class LoopNoneOutput(torch.nn.Module): 13642 def forward(self, x) -> Optional[Tensor]: 13643 y: Optional[Tensor] = x 13644 for _ in range(x.size(0)): 13645 y = None 13646 return y 13647 13648 @common_utils.parametrize( 13649 "module_class", 13650 (IfNoneOutput, IfNoneInput, LoopNoneOutput, LoopNoneInput), 13651 name_fn=lambda module_class: module_class.__name__, 13652 ) 13653 @common_utils.parametrize("x_size", (0, 1), name_fn=lambda x_size: str(x_size)) 13654 @skipTraceTest() 
13655 @skipIfUnsupportedMinOpsetVersion(16) 13656 def test_optional_output(self, module_class: Type[torch.nn.Module], x_size: int): 13657 # Need scripting to preserve control flow for this test to be 13658 # meaningful. 13659 model = torch.jit.script(module_class()) 13660 f = io.BytesIO() 13661 x = torch.ones(x_size) 13662 dynamic_axis_name = "condition" 13663 torch.onnx.export( 13664 model, 13665 x, 13666 f, 13667 opset_version=self.opset_version, 13668 # Ensure condition is not constant 13669 dynamic_axes={"x": {0: dynamic_axis_name}}, 13670 input_names=["x"], 13671 ) 13672 exported = onnx.load_from_string(f.getvalue()) 13673 expected_elem_type = torch.onnx.JitScalarType.from_value(x).onnx_type() 13674 expected_output_type = onnx.helper.make_optional_type_proto( 13675 onnx.helper.make_tensor_type_proto(expected_elem_type, (dynamic_axis_name,)) 13676 ) 13677 self.assertEqual(expected_output_type, exported.graph.output[0].type) 13678 for node in exported.graph.node: 13679 # Both branches output types should match. 
13680 if node.op_type == "If": 13681 for attr in node.attribute: 13682 if attr.name in ("then_branch", "else_branch"): 13683 self.assertEqual(expected_output_type, attr.g.output[0].type) 13684 13685 self.run_test( 13686 module_class(), 13687 x, 13688 # Ensure condition is not constant 13689 dynamic_axes={"x": {0: dynamic_axis_name}}, 13690 input_names=["x"], 13691 ) 13692 13693 @skipTraceTest() 13694 @skipIfUnsupportedMinOpsetVersion(16) 13695 def test_uninitialized_optional(self): 13696 class Module(torch.nn.Module): 13697 def forward(self, y: Optional[Tensor]) -> Optional[Tensor]: 13698 if y is not None: 13699 if y.shape[1] < 5: 13700 if y.size(0) == 1: 13701 y = y + 4 13702 else: 13703 return y 13704 return y 13705 13706 self.run_test( 13707 Module(), 13708 torch.ones((3, 4), dtype=torch.int), 13709 dynamic_axes={"y": {0: "y0", 1: "y1"}}, 13710 input_names=["y"], 13711 ) 13712 13713 @skipIfUnsupportedMinOpsetVersion(9) 13714 def test_device_eq(self): 13715 class M(torch.nn.Module): 13716 def forward(self, a): 13717 # exercise both Tensor.device (prim::device) 13718 # and torch.device (prim::Constant). 13719 if a.device != torch.device("cpu"): 13720 return a 13721 return torch.zeros_like(a) 13722 13723 mod = torch.jit.script(M()) # preserve control flow 13724 13725 self.run_test( 13726 mod, 13727 # In order for the ONNX model behavior to match the torch model, we 13728 # need to construct input that has the same device that is checked for 13729 # in forward(). In ONNX there is no such thing as a device, so the if 13730 # condition is always false. 13731 torch.randn(3, 3, device="cpu"), 13732 # Force dynamic axes so that the output shape depends on the input. 13733 # Otherwise the entire model will just return a constant and not have 13734 # any inputs. 
13735 input_names=["a"], 13736 dynamic_axes={"a": {0: "a0"}}, 13737 ) 13738 13739 @skipIfUnsupportedMinOpsetVersion(9) 13740 def test_lerp(self): 13741 class LerpModel(torch.nn.Module): 13742 def forward(self, x): 13743 return ( 13744 x.lerp(torch.full_like(x, 10), 0.4), 13745 x.lerp(torch.full_like(x, 20), 0.7), 13746 x.lerp(torch.full_like(x, 30), torch.tensor(0.4)), 13747 x.lerp(torch.full_like(x, 40), x / 10.0), 13748 x.lerp(torch.tensor(10.0), x / 10.0), 13749 x.lerp(torch.tensor(10.0), 0.4), 13750 x.lerp(torch.tensor(10.0), torch.tensor(0.4)), 13751 ) 13752 13753 self.run_test(LerpModel(), torch.rand(5, 4, 3)) 13754 13755 @common_utils.parametrize("input_dtype", [torch.cfloat, torch.float]) 13756 @skipIfUnsupportedMinOpsetVersion(9) 13757 def test_print_tensor_within_torch_nn_module(self, input_dtype: torch.dtype): 13758 class PrintTensorOnMyModel(torch.nn.Module): 13759 def forward(self, x): 13760 # 'print' has side effect calling 'resolve_conj' and 'resolve_neg'. 13761 x_firsts = x[:, 0] 13762 print(f"x_firsts: {x_firsts}") 13763 # 'tolist' has side effect calling 'resolve_conj' and 'resolve_neg'. 13764 # Annotation added to pass torch script. 
13765 _: List[float] = x.tolist() 13766 return x_firsts 13767 13768 m = PrintTensorOnMyModel() 13769 x = torch.randn(10, 5, dtype=input_dtype) 13770 if input_dtype == torch.cfloat: 13771 with self.assertRaises(RuntimeError): 13772 self.run_test( 13773 m, 13774 x, 13775 ) 13776 else: 13777 self.run_test( 13778 m, 13779 x, 13780 ) 13781 13782 @skipScriptTest() 13783 @skipIfUnsupportedMinOpsetVersion(16) 13784 @unittest.skipIf( 13785 not torch.hub._check_module_exists("torch_geometric"), 13786 "torch_geometric not installed.", 13787 ) 13788 def test_sage_conv(self): 13789 from torch_geometric import nn as torch_geometric_nn 13790 13791 # Input 13792 coords0 = torch.randn(1, 6) 13793 coords1 = torch.randn(1, 6) 13794 coords = torch.transpose(torch.cat((coords0, coords1), dim=0), 0, 1) 13795 adj = torch_geometric_nn.knn_graph(coords, k=2, batch=None, loop=True) 13796 edge_from = adj[0:1, :] 13797 edge_to = adj[1:, :] 13798 inputs = (coords0, coords1, edge_from, edge_to) 13799 13800 class MySAGEConv(torch.nn.Module): 13801 def __init__(self) -> None: 13802 super().__init__() 13803 self.SAGEConvBlock1 = torch_geometric_nn.SAGEConv( 13804 2, 512, normalize=True 13805 ) 13806 self.bano1 = torch_geometric_nn.BatchNorm(512) 13807 self.relu = torch.nn.ReLU() 13808 self.dense1 = torch.nn.Seq(Lin(512, 1)) # noqa: F821 13809 self.sigmoid = torch.nn.Sigmoid() 13810 13811 def forward(self, coords0, coords1, edge_from, edge_to): 13812 adj = torch.cat((edge_from, edge_to), dim=0) 13813 gra = torch.transpose(torch.cat((coords0, coords1), dim=0), 0, 1) 13814 x1 = self.SAGEConvBlock1(gra, edge_index=adj) 13815 x = torch.unsqueeze(torch.sum(x1), dim=0) 13816 return x 13817 13818 input_names = ["coords0", "coords1", "edge_from", "edge_to"] 13819 output_names = ["outputs"] 13820 dynamic_axes = { 13821 "coords0": {0: "batch_size", 1: "features"}, 13822 "coords1": {0: "batch_size", 1: "features"}, 13823 "edge_from": {0: "batch_size", 1: "features"}, 13824 "edge_to": {0: "batch_size", 1: 
"features"}, 13825 "outputs": {0: "batch_size"}, 13826 } 13827 self.run_test( 13828 MySAGEConv(), 13829 inputs, 13830 input_names=input_names, 13831 output_names=output_names, 13832 dynamic_axes=dynamic_axes, 13833 ) 13834 13835 # Cannot export with older opsets because of "ConstantFill" op 13836 # ConstantFill was a temp op removed at opset 8. This is no longer supported by onnxruntime 13837 # There are still some issues prevent us from enabling script test for these scenarios: 13838 # test_gru_*: 13839 # Operator aten::as_tensor is not supported by exporter yet. 13840 # - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055382 13841 # Operator aten::_pack_padded_sequence is not supported by exporter yet. 13842 # - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055384 13843 # test_elman_*: 13844 # Compiling in script mode fails with errors like: 13845 # torch.jit.frontend.UnsupportedNodeError: annotated assignments 13846 # without assigned value aren't supported 13847 # - https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723 13848 # test_lstm_*: 13849 # Compiling in script mode fails with errors like: 13850 # RuntimeError: Arguments for call are not valid. 
13851 # - https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723 13852 @skipScriptTest() 13853 @skipIfUnsupportedMinOpsetVersion(9) 13854 @common_utils.parametrize( 13855 "name, nonlinearity", 13856 [ 13857 ("elman", "relu"), 13858 ("elman", "tanh"), 13859 ("lstm", None), 13860 ("gru", None), 13861 ], 13862 ) 13863 @common_utils.parametrize(**_parametrize_rnn_args("layers")) 13864 @common_utils.parametrize(**_parametrize_rnn_args("bidirectional")) 13865 @common_utils.parametrize(**_parametrize_rnn_args("initial_state")) 13866 @common_utils.parametrize(**_parametrize_rnn_args("packed_sequence")) 13867 @common_utils.parametrize(**_parametrize_rnn_args("dropout")) 13868 def test_rnn(self, *args, **kwargs): 13869 self._dispatch_rnn_test(*args, **kwargs) 13870 13871 13872if __name__ == "__main__": 13873 common_utils.TestCase._default_dtype_check_enabled = True 13874 common_utils.run_tests() 13875