1# Owner(s): ["module: onnx"] 2 3import copy 4import functools 5import io 6import re 7import warnings 8from typing import Callable 9 10import onnx 11 12import parameterized 13import pytorch_test_common 14import torchvision 15from autograd_helper import CustomFunction as CustomFunction2 16from pytorch_test_common import ( 17 skipIfNoCuda, 18 skipIfUnsupportedMaxOpsetVersion, 19 skipIfUnsupportedMinOpsetVersion, 20) 21 22import torch 23import torch.onnx 24import torch.utils.cpp_extension 25from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils 26from torch.onnx._globals import GLOBALS 27from torch.onnx.symbolic_helper import _unpack_list, parse_args 28from torch.testing._internal import common_utils 29from torch.testing._internal.common_utils import skipIfNoLapack 30 31 32def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str: 33 """Remove test environment prefix added to module. 34 35 Remove prefix to normalize scope names, since different test environments add 36 prefixes with slight differences. 37 38 Example: 39 40 >>> _remove_test_environment_prefix_from_scope_name( 41 >>> "test_utility_funs.M" 42 >>> ) 43 "M" 44 >>> _remove_test_environment_prefix_from_scope_name( 45 >>> "test_utility_funs.test_abc.<locals>.M" 46 >>> ) 47 "M" 48 >>> _remove_test_environment_prefix_from_scope_name( 49 >>> "__main__.M" 50 >>> ) 51 "M" 52 """ 53 prefixes_to_remove = ["test_utility_funs", "__main__"] 54 for prefix in prefixes_to_remove: 55 scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name) 56 return scope_name 57 58 59class _BaseTestCase(pytorch_test_common.ExportTestCase): 60 def _model_to_graph( 61 self, 62 model, 63 input, 64 do_constant_folding=True, 65 training=TrainingMode.EVAL, 66 operator_export_type=OperatorExportTypes.ONNX, 67 input_names=None, 68 dynamic_axes=None, 69 ): 70 torch.onnx.utils._setup_trace_module_map(model, False) 71 if training == torch.onnx.TrainingMode.TRAINING: 72 model.train() 73 elif training == torch.onnx.TrainingMode.EVAL: 74 model.eval() 75 utils._validate_dynamic_axes(dynamic_axes, model, None, None) 76 graph, params_dict, torch_out = utils._model_to_graph( 77 model, 78 input, 79 do_constant_folding=do_constant_folding, 80 _disable_torch_constant_prop=True, 81 operator_export_type=operator_export_type, 82 training=training, 83 input_names=input_names, 84 dynamic_axes=dynamic_axes, 85 ) 86 return graph, params_dict, torch_out 87 88 89@common_utils.instantiate_parametrized_tests 90class TestUnconvertibleOps(pytorch_test_common.ExportTestCase): 91 """Unit tests for the `unconvertible_ops` function.""" 92 93 def setUp(self): 94 class EinsumModule(torch.nn.Module): 95 def forward(self, x): 96 return torch.einsum("ii", x) 97 98 self.einsum_module = EinsumModule() 99 100 def test_it_returns_graph_and_unconvertible_ops_at_lower_opset_version(self): 101 x = torch.randn(4, 4) 102 103 # Einsum is supported since opset 12. It should be unconvertible at opset 9. 104 graph, unconvertible_ops = utils.unconvertible_ops( 105 self.einsum_module, (x,), opset_version=9 106 ) 107 nodes = graph.nodes() 108 self.assertEqual(next(nodes).kind(), "prim::Constant") 109 self.assertEqual(next(nodes).kind(), "prim::ListConstruct") 110 self.assertEqual(next(nodes).kind(), "prim::Constant") 111 self.assertEqual(next(nodes).kind(), "aten::einsum") 112 self.assertEqual(unconvertible_ops, ["aten::einsum"]) 113 114 @common_utils.parametrize( 115 "jit_function", 116 [ 117 common_utils.subtest( 118 functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)), 119 name="traced", 120 ), 121 common_utils.subtest(torch.jit.script, name="scripted"), 122 ], 123 ) 124 def test_it_returns_unconvertible_ops_at_lower_opset_version_for_jit_module( 125 self, jit_function: Callable 126 ): 127 module = jit_function(self.einsum_module) 128 x = torch.randn(4, 4) 129 130 # Einsum is supported since opset 12. It should be unconvertible at opset 9. 131 _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=9) 132 self.assertEqual(unconvertible_ops, ["aten::einsum"]) 133 134 @common_utils.parametrize( 135 "jit_function", 136 [ 137 common_utils.subtest(lambda x: x, name="nn_module"), 138 common_utils.subtest( 139 functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)), 140 name="traced", 141 ), 142 common_utils.subtest(torch.jit.script, name="scripted"), 143 ], 144 ) 145 def test_it_returns_empty_list_when_all_ops_convertible( 146 self, jit_function: Callable 147 ): 148 module = jit_function(self.einsum_module) 149 x = torch.randn(4, 4) 150 151 # Einsum is supported since opset 12 152 _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12) 153 self.assertEqual(unconvertible_ops, []) 154 155 def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self): 156 class SkipConnectionModule(torch.nn.Module): 157 def forward(self, x): 158 out = x 159 out += x 160 out = torch.nn.functional.relu(out, inplace=True) 161 return out 162 163 module = SkipConnectionModule() 164 x = torch.randn(4, 4) 165 _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13) 166 self.assertEqual(unconvertible_ops, []) 167 168 169@parameterized.parameterized_class( 170 [ 171 {"opset_version": opset} 172 for opset in range( 173 _constants.ONNX_BASE_OPSET, 174 _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET + 1, 175 ) 176 ], 177 class_name_func=lambda cls, 178 num, 179 params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}", 180) 181class TestUtilityFuns(_BaseTestCase): 182 opset_version = None 183 184 def test_is_in_onnx_export(self): 185 test_self = self 186 187 class MyModule(torch.nn.Module): 188 def forward(self, x): 189 test_self.assertTrue(torch.onnx.is_in_onnx_export()) 190 raise ValueError 191 return x + 1 192 193 x = torch.randn(3, 4) 194 f = io.BytesIO() 195 try: 196 torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version) 197 except ValueError: 198 self.assertFalse(torch.onnx.is_in_onnx_export()) 199 200 def test_validate_dynamic_axes_invalid_input_output_name(self): 201 with warnings.catch_warnings(record=True) as w: 202 warnings.simplefilter("always") 203 utils._validate_dynamic_axes( 204 {"input1": {}, "output": {}, "invalid_name1": {}, "invalid_name2": {}}, 205 None, 206 ["input1", "input2"], 207 ["output"], 208 ) 209 messages = [str(warning.message) for warning in w] 210 self.assertIn( 211 "Provided key invalid_name1 for dynamic axes is not a valid input/output name", 212 messages, 213 ) 214 self.assertIn( 215 "Provided key invalid_name2 for dynamic axes is not a valid input/output name", 216 messages, 217 ) 218 self.assertEqual(len(messages), 2) 219 220 @skipIfUnsupportedMinOpsetVersion(11) 221 def test_split_to_slice(self): 222 class SplitModule(torch.nn.Module): 223 def forward(self, x, y, t): 224 splits = (x.size(1), y.size(1)) 225 out, out2 = torch.split(t, splits, dim=1) 226 return out, out2 227 228 GLOBALS.export_onnx_opset_version = self.opset_version 229 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 230 x = torch.randn(2, 3) 231 y = torch.randn(2, 4) 232 t = torch.randn(2, 7) 233 graph, _, _ = self._model_to_graph( 234 SplitModule(), 235 (x, y, t), 236 input_names=["x", "y", "t"], 237 dynamic_axes={"x": [0, 1], "y": [0, 1], "t": [0, 1]}, 238 ) 239 for node in graph.nodes(): 240 self.assertNotEqual(node.kind(), "onnx::SplitToSequence") 241 242 def test_constant_fold_transpose(self): 243 class TransposeModule(torch.nn.Module): 244 def forward(self, x): 245 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 246 b = torch.transpose(a, 1, 0) 247 return b + x 248 249 GLOBALS.export_onnx_opset_version = self.opset_version 250 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 251 x = torch.ones(3, 2) 252 graph, _, __ = self._model_to_graph( 253 TransposeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 254 ) 255 256 for node in graph.nodes(): 257 self.assertNotEqual(node.kind(), "onnx::Transpose") 258 self.assertNotEqual(node.kind(), "onnx::Cast") 259 self.assertEqual(len(list(graph.nodes())), 2) 260 261 @skipIfUnsupportedMaxOpsetVersion(17) 262 def test_constant_fold_reduceL2(self): 263 class ReduceModule(torch.nn.Module): 264 def forward(self, x): 265 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 266 b = torch.norm(a, p=2, dim=-2, keepdim=False) 267 return b + x 268 269 GLOBALS.export_onnx_opset_version = self.opset_version 270 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 271 x = torch.ones(2, 3) 272 graph, _, __ = self._model_to_graph( 273 ReduceModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 274 ) 275 276 for node in graph.nodes(): 277 self.assertNotEqual(node.kind(), "onnx::ReduceL2") 278 279 @skipIfUnsupportedMaxOpsetVersion(17) 280 def test_constant_fold_reduceL1(self): 281 class NormModule(torch.nn.Module): 282 def forward(self, x): 283 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 284 b = torch.norm(a, p=1, dim=-2) 285 return b + x 286 287 GLOBALS.export_onnx_opset_version = self.opset_version 288 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 289 x = torch.ones(2, 3) 290 graph, _, __ = self._model_to_graph( 291 NormModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 292 ) 293 294 for node in graph.nodes(): 295 self.assertNotEqual(node.kind(), "onnx::ReduceL1") 296 297 def test_constant_fold_slice(self): 298 class NarrowModule(torch.nn.Module): 299 def forward(self, x): 300 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 301 b = torch.narrow(a, 0, 0, 1) 302 return b + x 303 304 GLOBALS.export_onnx_opset_version = self.opset_version 305 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 306 x = torch.ones(1, 3) 307 graph, _, __ = self._model_to_graph( 308 NarrowModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 309 ) 310 311 for node in graph.nodes(): 312 self.assertNotEqual(node.kind(), "onnx::Slice") 313 self.assertNotEqual(node.kind(), "onnx::Cast") 314 self.assertEqual(len(list(graph.nodes())), 2) 315 316 def test_constant_fold_slice_index_exceeds_dim(self): 317 class SliceIndexExceedsDimModule(torch.nn.Module): 318 def forward(self, x): 319 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 320 b = a[1:10] # index exceeds dimension 321 return b + x 322 323 GLOBALS.export_onnx_opset_version = self.opset_version 324 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 325 x = torch.ones(1, 3) 326 graph, _, __ = self._model_to_graph( 327 SliceIndexExceedsDimModule(), 328 (x,), 329 input_names=["x"], 330 dynamic_axes={"x": [0, 1]}, 331 ) 332 333 for node in graph.nodes(): 334 self.assertNotEqual(node.kind(), "onnx::Slice") 335 self.assertNotEqual(node.kind(), "onnx::Cast") 336 self.assertEqual(len(list(graph.nodes())), 2) 337 338 def test_constant_fold_slice_negative_index(self): 339 class SliceNegativeIndexModule(torch.nn.Module): 340 def forward(self, x): 341 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 342 b = a[0:-1] # index relative to the end 343 c = torch.select(a, dim=-1, index=-2) 344 d = torch.select(a, dim=1, index=0) 345 return b + x, c + d 346 347 GLOBALS.export_onnx_opset_version = self.opset_version 348 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 349 x = torch.ones(1, 3) 350 graph, _, __ = self._model_to_graph( 351 SliceNegativeIndexModule(), 352 (x,), 353 input_names=["x"], 354 dynamic_axes={"x": [0, 1]}, 355 ) 356 357 for node in graph.nodes(): 358 self.assertNotEqual(node.kind(), "onnx::Slice") 359 self.assertNotEqual(node.kind(), "onnx::Cast") 360 361 def test_constant_fold_gather(self): 362 class GatherModule(torch.nn.Module): 363 def forward(self, x): 364 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 365 b = torch.select(a, dim=1, index=-2) 366 c = torch.index_select(a, dim=-2, index=torch.tensor([0, 1])) 367 return b + 1, c + x 368 369 GLOBALS.export_onnx_opset_version = self.opset_version 370 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 371 x = torch.ones(1, 3) 372 model = GatherModule() 373 model(x) 374 graph, _, __ = self._model_to_graph( 375 GatherModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 376 ) 377 378 for node in graph.nodes(): 379 self.assertNotEqual(node.kind(), "onnx::Gather") 380 381 def test_constant_fold_unsqueeze(self): 382 class UnsqueezeModule(torch.nn.Module): 383 def forward(self, x): 384 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 385 b = torch.unsqueeze(a, -2) 386 return b + x 387 388 GLOBALS.export_onnx_opset_version = self.opset_version 389 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 390 x = torch.ones(1, 2, 3) 391 graph, _, __ = self._model_to_graph( 392 UnsqueezeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]} 393 ) 394 395 for node in graph.nodes(): 396 self.assertNotEqual(node.kind(), "onnx::Unsqueeze") 397 self.assertNotEqual(node.kind(), "onnx::Cast") 398 self.assertEqual(len(list(graph.nodes())), 2) 399 400 def test_constant_fold_unsqueeze_multi_axies(self): 401 class PReluModel(torch.nn.Module): 402 def __init__(self) -> None: 403 super().__init__() 404 self.prelu = torch.nn.PReLU() 405 406 def forward(self, x): 407 a = torch.randn(2, 3, 4, 5, 8, 7) 408 return self.prelu(x) + a 409 410 GLOBALS.export_onnx_opset_version = self.opset_version 411 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 412 x = torch.randn(2, 3, 4, 5, 8, 7) 413 graph, _, __ = self._model_to_graph( 414 PReluModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3, 4, 5]} 415 ) 416 417 for node in graph.nodes(): 418 self.assertNotEqual(node.kind(), "onnx::Unsqueeze") 419 self.assertNotEqual(node.kind(), "onnx::Cast") 420 self.assertEqual(len(list(graph.nodes())), 5) 421 422 def test_constant_fold_squeeze_without_axes(self): 423 class SqueezeModule(torch.nn.Module): 424 def forward(self, x): 425 a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]) 426 return torch.squeeze(a) + x + torch.squeeze(a) 427 428 GLOBALS.export_onnx_opset_version = self.opset_version 429 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 430 x = torch.ones(2, 3) 431 graph, _, __ = self._model_to_graph( 432 SqueezeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 433 ) 434 for node in graph.nodes(): 435 self.assertNotEqual(node.kind(), "onnx::Squeeze") 436 self.assertNotEqual(node.kind(), "onnx::Cast") 437 self.assertEqual(len(list(graph.nodes())), 4) 438 439 def test_constant_fold_squeeze_with_axes(self): 440 class SqueezeAxesModule(torch.nn.Module): 441 def forward(self, x): 442 a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]) 443 return torch.squeeze(a, dim=-3) + x 444 445 GLOBALS.export_onnx_opset_version = self.opset_version 446 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 447 x = torch.ones(2, 3) 448 graph, _, __ = self._model_to_graph( 449 SqueezeAxesModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 450 ) 451 452 for node in graph.nodes(): 453 self.assertNotEqual(node.kind(), "onnx::Squeeze") 454 self.assertNotEqual(node.kind(), "onnx::Cast") 455 self.assertEqual(len(list(graph.nodes())), 2) 456 457 def test_constant_fold_concat(self): 458 class ConcatModule(torch.nn.Module): 459 def forward(self, x): 460 # Why did I insert a Cast here? There appears to be intentional 461 # behavior in ONNX constant folding where constant tensors which 462 # are not attached to any known to be foldable onnx 463 # operations don't get extracted into the initializer graph. So 464 # without these casts, we will actually fail to pull out one of 465 # the constants, thus failing constant folding. I think the 466 # test is wrong but I don't have time to write a more correct 467 # test (I think the right way to go about the test is to setup 468 # a predicate for what invariant graphs should hold after 469 # constant folding, and then verify this predicate holds. 470 # I think the asserts below are an attempt at this predicate, 471 # but it is not right!) 472 # 473 # More commentary at 474 # https://github.com/pytorch/pytorch/pull/18698/files#r340107552 475 a = torch.tensor([[1.0, 2.0, 3.0]]).to(torch.float) 476 b = torch.tensor([[4.0, 5.0, 6.0]]).to(torch.float) 477 c = torch.cat((a, b), 0) 478 d = b + c 479 return x + d 480 481 GLOBALS.export_onnx_opset_version = self.opset_version 482 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 483 x = torch.ones(2, 3) 484 graph, _, __ = self._model_to_graph( 485 ConcatModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 486 ) 487 488 for node in graph.nodes(): 489 self.assertNotEqual(node.kind(), "onnx::Concat") 490 self.assertNotEqual(node.kind(), "onnx::Cast") 491 self.assertEqual(len(list(graph.nodes())), 2) 492 493 def test_constant_fold_lstm(self): 494 class GruNet(torch.nn.Module): 495 def __init__(self) -> None: 496 super().__init__() 497 self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False) 498 499 def forward(self, input, initial_state): 500 return self.mygru(input, initial_state) 501 502 GLOBALS.export_onnx_opset_version = self.opset_version 503 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 504 input = torch.randn(5, 3, 7) 505 h0 = torch.randn(1, 3, 3) 506 graph, _, __ = self._model_to_graph( 507 GruNet(), 508 (input, h0), 509 input_names=["input", "h0"], 510 dynamic_axes={"input": [0, 1, 2], "h0": [0, 1, 2]}, 511 ) 512 513 for node in graph.nodes(): 514 self.assertNotEqual(node.kind(), "onnx::Slice") 515 self.assertNotEqual(node.kind(), "onnx::Concat") 516 self.assertNotEqual(node.kind(), "onnx::Unsqueeze") 517 518 if self.opset_version <= 12: 519 self.assertEqual(len(list(graph.nodes())), 3) 520 else: 521 # Unsqueeze op parameter "axes" as an input instead of as an attribute when opset version >= 13 522 self.assertEqual(len(list(graph.nodes())), 4) 523 524 def test_constant_fold_transpose_matmul(self): 525 class MatMulNet(torch.nn.Module): 526 def __init__(self) -> None: 527 super().__init__() 528 self.B = torch.nn.Parameter(torch.ones(5, 3)) 529 530 def forward(self, A): 531 return torch.matmul(A, torch.transpose(self.B, -1, -2)) 532 533 GLOBALS.export_onnx_opset_version = self.opset_version 534 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 535 A = torch.randn(2, 3) 536 graph, _, __ = self._model_to_graph( 537 MatMulNet(), (A,), input_names=["A"], dynamic_axes={"A": [0, 1]} 538 ) 539 540 for node in graph.nodes(): 541 self.assertNotEqual(node.kind(), "onnx::Transpose") 542 self.assertEqual(len(list(graph.nodes())), 1) 543 544 def test_constant_fold_reshape(self): 545 class ReshapeModule(torch.nn.Module): 546 def __init__( 547 self, 548 ): 549 super().__init__() 550 self.weight = torch.nn.Buffer(torch.ones(5)) 551 552 def forward(self, x): 553 b = self.weight.reshape(1, -1, 1, 1) 554 return x * b 555 556 GLOBALS.export_onnx_opset_version = self.opset_version 557 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 558 x = torch.randn(4, 5) 559 graph, _, __ = self._model_to_graph( 560 ReshapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 561 ) 562 563 for node in graph.nodes(): 564 self.assertNotEqual(node.kind(), "onnx::Reshape") 565 self.assertEqual(len(list(graph.nodes())), 1) 566 567 def test_constant_fold_div(self): 568 class Module(torch.nn.Module): 569 def __init__( 570 self, 571 ): 572 super().__init__() 573 self.weight = torch.nn.Buffer(torch.ones(5)) 574 575 def forward(self, x): 576 div = self.weight.div(torch.tensor([1, 2, 3, 4, 5])) 577 return div * x 578 579 x = torch.randn(2, 5) 580 GLOBALS.export_onnx_opset_version = self.opset_version 581 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 582 graph, _, __ = self._model_to_graph( 583 Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 584 ) 585 586 for node in graph.nodes(): 587 self.assertNotEqual(node.kind(), "onnx::Div") 588 self.assertEqual(len(list(graph.nodes())), 1) 589 590 def test_constant_fold_mul(self): 591 class Module(torch.nn.Module): 592 def __init__( 593 self, 594 ): 595 super().__init__() 596 self.weight = torch.nn.Buffer(torch.ones(5)) 597 598 def forward(self, x): 599 mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5])) 600 return mul / x 601 602 x = torch.randn(2, 5) 603 GLOBALS.export_onnx_opset_version = self.opset_version 604 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 605 graph, _, __ = self._model_to_graph( 606 Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 607 ) 608 609 for node in graph.nodes(): 610 self.assertNotEqual(node.kind(), "onnx::Mul") 611 self.assertEqual(len(list(graph.nodes())), 1) 612 613 def test_constant_fold_add(self): 614 class Module(torch.nn.Module): 615 def __init__( 616 self, 617 ): 618 super().__init__() 619 self.weight = torch.nn.Buffer(torch.ones(5)) 620 621 def forward(self, x): 622 add = self.weight + torch.tensor([1, 2, 3, 4, 5]) 623 return add - x 624 625 x = torch.randn(2, 5) 626 GLOBALS.export_onnx_opset_version = self.opset_version 627 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 628 graph, params_dict, __ = self._model_to_graph( 629 Module(), 630 (x,), 631 do_constant_folding=True, 632 operator_export_type=OperatorExportTypes.ONNX, 633 input_names=["x"], 634 dynamic_axes={"x": [0, 1]}, 635 ) 636 for node in graph.nodes(): 637 self.assertTrue(node.kind() != "onnx::Add") 638 self.assertEqual(len(list(graph.nodes())), 1) 639 params = list(params_dict.values()) 640 self.assertEqual(len(params), 1) 641 weight = params[0] 642 self.assertEqual(weight, torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0])) 643 644 def test_constant_fold_sub(self): 645 class Module(torch.nn.Module): 646 def __init__( 647 self, 648 ): 649 super().__init__() 650 self.weight = torch.nn.Buffer(torch.ones(5)) 651 652 def forward(self, x): 653 sub = self.weight - torch.tensor([1, 2, 3, 4, 5]) 654 return sub + x 655 656 x = torch.randn(2, 5) 657 GLOBALS.export_onnx_opset_version = self.opset_version 658 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 659 graph, params_dict, __ = self._model_to_graph( 660 Module(), 661 (x,), 662 do_constant_folding=True, 663 operator_export_type=OperatorExportTypes.ONNX, 664 input_names=["x"], 665 dynamic_axes={"x": [0, 1]}, 666 ) 667 for node in graph.nodes(): 668 self.assertNotEqual(node.kind(), "onnx::Sub") 669 self.assertEqual(len(list(graph.nodes())), 1) 670 params = list(params_dict.values()) 671 self.assertEqual(len(params), 1) 672 weight = params[0] 673 self.assertEqual(weight, torch.tensor([0.0, -1.0, -2.0, -3.0, -4.0])) 674 675 def test_constant_fold_sqrt(self): 676 class Module(torch.nn.Module): 677 def __init__( 678 self, 679 ): 680 super().__init__() 681 self.weight = torch.nn.Buffer(torch.ones(5)) 682 683 def forward(self, x): 684 sqrt = torch.sqrt(self.weight) 685 return sqrt / x 686 687 x = torch.randn(2, 5) 688 GLOBALS.export_onnx_opset_version = self.opset_version 689 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 690 graph, _, __ = self._model_to_graph( 691 Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 692 ) 693 for node in graph.nodes(): 694 self.assertNotEqual(node.kind(), "onnx::Sqrt") 695 self.assertEqual(len(list(graph.nodes())), 1) 696 697 def test_constant_fold_shape(self): 698 class ShapeModule(torch.nn.Module): 699 def __init__(self) -> None: 700 super().__init__() 701 self.weight = torch.nn.Buffer(torch.ones(5)) 702 703 def forward(self, x): 704 shape = self.weight.shape[0] 705 return x + shape 706 707 x = torch.randn(2, 5) 708 GLOBALS.export_onnx_opset_version = self.opset_version 709 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 710 graph, _, __ = self._model_to_graph( 711 ShapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]} 712 ) 713 for node in graph.nodes(): 714 self.assertNotEqual(node.kind(), "onnx::Shape") 715 self.assertEqual(len(list(graph.nodes())), 2) 716 717 def test_constant_fold_upsample_scale_fold_as_constant(self): 718 # upsample scale is a constant, not a model parameter, 719 # therefore should not be added as initializer after constant folding. 720 model = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 721 x = torch.randn(1, 32, 224, 224) 722 f = io.BytesIO() 723 torch.onnx.export(model, x, f) 724 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 725 self.assertEqual(len(onnx_model.graph.initializer), 0) 726 727 def test_verbose(self): 728 class MyModule(torch.nn.Module): 729 def forward(self, input): 730 return torch.exp(input) 731 732 x = torch.randn(3, 4) 733 734 def is_model_stripped(f, verbose=None): 735 if verbose is None: 736 torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version) 737 else: 738 torch.onnx.export( 739 MyModule(), x, f, verbose=verbose, opset_version=self.opset_version 740 ) 741 model = onnx.load(io.BytesIO(f.getvalue())) 742 model_strip = copy.copy(model) 743 onnx.helper.strip_doc_string(model_strip) 744 return model == model_strip 745 746 # test verbose=False (default) 747 self.assertTrue(is_model_stripped(io.BytesIO())) 748 # test verbose=True 749 self.assertFalse(is_model_stripped(io.BytesIO(), True)) 750 751 # NB: remove this test once DataParallel can be correctly handled 752 def test_error_on_data_parallel(self): 753 model = torch.nn.DataParallel(torch.nn.ReflectionPad2d((1, 2, 3, 4))) 754 x = torch.randn(1, 2, 3, 4) 755 f = io.BytesIO() 756 with self.assertRaisesRegex( 757 ValueError, 758 "torch.nn.DataParallel is not supported by ONNX " 759 "exporter, please use 'attribute' module to " 760 "unwrap model from torch.nn.DataParallel. Try ", 761 ): 762 torch.onnx.export(model, x, f, opset_version=self.opset_version) 763 764 @skipIfUnsupportedMinOpsetVersion(11) 765 def test_sequence_dim(self): 766 class Module(torch.nn.Module): 767 def forward(self, x, y): 768 return [x, y] 769 770 model = Module() 771 # Export with scripting to keep output as Sequence type. 772 # Tracing unpacks the list. 773 script_model = torch.jit.script(model) 774 x = torch.randn(2, 3) 775 776 # Case 1: dynamic axis 777 f = io.BytesIO() 778 y = torch.randn(2, 3) 779 torch.onnx.export( 780 script_model, 781 (x, y), 782 f, 783 opset_version=self.opset_version, 784 input_names=["x", "y"], 785 dynamic_axes={"y": [1]}, 786 ) 787 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 788 loop_output_value_info_proto = onnx_model.graph.output[0] 789 ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info( 790 loop_output_value_info_proto.name, 1, [2, None] 791 ) 792 self.assertEqual(loop_output_value_info_proto, ref_value_info_proto) 793 794 # Case 2: no dynamic axes. 795 f = io.BytesIO() 796 y = torch.randn(2, 3) 797 torch.onnx.export(script_model, (x, y), f, opset_version=self.opset_version) 798 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 799 loop_output_value_info_proto = onnx_model.graph.output[0] 800 ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info( 801 loop_output_value_info_proto.name, 1, [2, 3] 802 ) 803 self.assertEqual(loop_output_value_info_proto, ref_value_info_proto) 804 805 def test_export_mode(self): 806 class MyModule(torch.nn.Module): 807 def forward(self, x): 808 y = x + 1 809 return y 810 811 model = MyModule() 812 x = torch.randn(10, 3, 128, 128) 813 f = io.BytesIO() 814 815 # set mode to in inference mode and export in training mode 816 model.eval() 817 old_state = model.training 818 torch.onnx.export( 819 model, 820 (x,), 821 f, 822 opset_version=self.opset_version, 823 training=torch.onnx.TrainingMode.TRAINING, 824 ) 825 # verify that the model state is preserved 826 self.assertEqual(model.training, old_state) 827 828 # set mode to training mode and export in inference mode 829 model.train() 830 old_state = model.training 831 torch.onnx.export( 832 model, 833 (x,), 834 f, 835 opset_version=self.opset_version, 836 training=torch.onnx.TrainingMode.EVAL, 837 ) 838 # verify that the model state is preserved 839 self.assertEqual(model.training, old_state) 840 841 def test_export_does_not_fail_on_frozen_scripted_module(self): 842 class Inner(torch.nn.Module): 843 def forward(self, x): 844 if x > 0: 845 return x 846 else: 847 return x * x 848 849 class Outer(torch.nn.Module): 850 def __init__(self) -> None: 851 super().__init__() 852 self.inner = torch.jit.script(Inner()) 853 854 def forward(self, x): 855 return self.inner(x) 856 857 x = torch.zeros(1) 858 # Freezing is only implemented in eval mode. So we need to call eval() 859 outer_module = Outer().eval() 860 module = torch.jit.trace_module(outer_module, {"forward": (x)}) 861 # jit.freeze removes the training attribute in the module 862 module = torch.jit.freeze(module) 863 864 torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version) 865 866 @skipIfUnsupportedMinOpsetVersion(15) 867 def test_local_function(self): 868 class N(torch.nn.Module): 869 def __init__(self, prob): 870 super().__init__() 871 self.dropout = torch.nn.Dropout(prob) 872 873 def forward(self, x): 874 return self.dropout(x) 875 876 class M(torch.nn.Module): 877 def __init__(self, num_layers): 878 super().__init__() 879 self.num_layers = num_layers 880 self.lns = torch.nn.ModuleList( 881 [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)] 882 ) 883 self.celu1 = torch.nn.CELU(1.0) 884 self.celu2 = torch.nn.CELU(2.0) 885 self.dropout = N(0.5) 886 887 def forward(self, x, y, z): 888 res1 = self.celu1(x) 889 res2 = self.celu2(y) 890 for ln in self.lns: 891 z = ln(z) 892 return res1 + res2, self.dropout(z) 893 894 x = torch.randn(2, 3) 895 y = torch.randn(2, 3) 896 z = torch.randn(2, 3) 897 898 # Export specified modules. Test against specifying modules that won't 899 # exist in the exported model. 900 # Model export in inference mode will remove dropout node, 901 # thus the dropout module no longer exist in graph. 902 f = io.BytesIO() 903 torch.onnx.export( 904 M(3), 905 (x, y, z), 906 f, 907 opset_version=self.opset_version, 908 export_modules_as_functions={ 909 torch.nn.CELU, 910 torch.nn.Dropout, 911 torch.nn.LayerNorm, 912 }, 913 ) 914 915 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 916 917 # Check function definition 918 funcs = onnx_model.functions 919 celu_funcs = [f for f in funcs if f.name == "CELU"] 920 self.assertEqual(len(celu_funcs), 1) 921 self.assertEqual(celu_funcs[0].domain, "torch.nn.modules.activation") 922 self.assertEqual(len(celu_funcs[0].attribute), 3) 923 ln_funcs = [f for f in funcs if f.name == "LayerNorm"] 924 self.assertEqual(len(ln_funcs), 1) 925 self.assertEqual(ln_funcs[0].domain, "torch.nn.modules.normalization") 926 self.assertEqual(len(ln_funcs[0].attribute), 3) 927 928 # Check local function nodes 929 nodes = onnx_model.graph.node 930 celu_ns = [n for n in nodes if n.op_type == "CELU"] 931 ln_ns = [n for n in nodes if n.op_type == "LayerNorm"] 932 self.assertEqual(len(celu_ns), 2) 933 self.assertEqual(celu_ns[0].domain, "torch.nn.modules.activation") 934 self.assertEqual(len(celu_ns[0].attribute), 3) 935 self.assertEqual(len(ln_ns), 3) 936 self.assertEqual(ln_ns[0].domain, "torch.nn.modules.normalization") 937 self.assertEqual(len(ln_ns[0].attribute), 3) 938 939 # Export specified modules. 940 f = io.BytesIO() 941 torch.onnx.export( 942 M(3), 943 (x, y, z), 944 f, 945 opset_version=self.opset_version, 946 export_modules_as_functions={torch.nn.CELU}, 947 ) 948 949 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 950 funcs = onnx_model.functions 951 self.assertEqual(len(funcs), 1) 952 self.assertEqual(funcs[0].name, "CELU") 953 954 # Export with empty specified modules. Normal export. 955 f = io.BytesIO() 956 torch.onnx.export( 957 M(3), 958 (x, y, z), 959 f, 960 opset_version=self.opset_version, 961 export_modules_as_functions=set(), 962 ) 963 964 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 965 funcs = onnx_model.functions 966 self.assertEqual(len(funcs), 0) 967 968 # Export all modules. Should contain {M, CELU, LayerNorm}. 969 f = io.BytesIO() 970 torch.onnx.export( 971 M(3), 972 (x, y, z), 973 f, 974 opset_version=self.opset_version, 975 export_modules_as_functions=True, 976 ) 977 978 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 979 funcs = onnx_model.functions 980 self.assertEqual(len(funcs), 3) 981 982 @skipIfUnsupportedMinOpsetVersion(15) 983 def test_local_function_overloads(self): 984 class NWithOverloads(torch.nn.Module): 985 def forward(self, x, y=None, z=None): 986 if y is None: 987 return x + 1 988 elif z is None: 989 return x + y 990 else: 991 return x + y, x + z 992 993 class M(torch.nn.Module): 994 def __init__(self, num_layers): 995 super().__init__() 996 self.n = NWithOverloads() 997 998 def forward(self, x, y, z): 999 return self.n(x), self.n(x, y), self.n(x, y, z) 1000 1001 x = torch.randn(2, 3) 1002 y = torch.randn(2, 3) 1003 z = torch.randn(2, 3) 1004 1005 f = io.BytesIO() 1006 torch.onnx.export( 1007 M(3), 1008 (x, y, z), 1009 f, 1010 opset_version=self.opset_version, 1011 export_modules_as_functions={NWithOverloads}, 1012 ) 1013 1014 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 1015 funcs = onnx_model.functions 1016 self.assertEqual(len(funcs), 3) 1017 func_names = [f.name for f in funcs] 1018 self.assertIn("NWithOverloads", func_names) 1019 self.assertIn("NWithOverloads.1", func_names) 1020 self.assertIn("NWithOverloads.2", func_names) 1021 1022 # Failing after ONNX 1.13.0 1023 @skipIfUnsupportedMaxOpsetVersion(1) 1024 def test_local_function_infer_scopes(self): 1025 class M(torch.nn.Module): 1026 def forward(self, x): 1027 # Concatenation of scalars inserts unscoped tensors in IR graph. 1028 new_tensor_shape = x.size()[:-1] + (1, 1, -1) 1029 tensor = x.view(*new_tensor_shape) 1030 return tensor 1031 1032 x = torch.randn(4, 5) 1033 f = io.BytesIO() 1034 torch.onnx.export( 1035 M(), 1036 (x,), 1037 f, 1038 export_modules_as_functions=True, 1039 opset_version=self.opset_version, 1040 do_constant_folding=False, 1041 ) 1042 1043 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 1044 funcs = onnx_model.functions 1045 self.assertIn("M", [f.name for f in funcs]) 1046 1047 @skipIfUnsupportedMinOpsetVersion(15) 1048 def test_local_function_predefined_attributes(self): 1049 class M(torch.nn.Module): 1050 num_layers: int 1051 1052 def __init__(self, num_layers): 1053 super().__init__() 1054 self.num_layers = num_layers 1055 self.lns = torch.nn.ModuleList( 1056 [torch.nn.LayerNorm(3, eps=1e-4) for _ in range(num_layers)] 1057 ) 1058 1059 def forward(self, x): 1060 for ln in self.lns: 1061 x = ln(x) 1062 return x 1063 1064 x = torch.randn(2, 3) 1065 f = io.BytesIO() 1066 model = M(3) 1067 torch.onnx.export( 1068 model, 1069 (x,), 1070 f, 1071 export_modules_as_functions=True, 1072 opset_version=self.opset_version, 1073 ) 1074 1075 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 1076 funcs = onnx_model.functions 1077 m_funcs = [fn for fn in funcs if fn.name == "M"] 1078 self.assertEqual(m_funcs[0].attribute, ["num_layers"]) 1079 ln_funcs = [fn for fn in funcs if fn.name == "LayerNorm"] 1080 self.assertEqual(ln_funcs[0].attribute, ["eps", "elementwise_affine"]) 1081 1082 from onnx import helper 1083 1084 m_node = [n for n in onnx_model.graph.node if n.op_type == "M"] 1085 self.assertEqual( 1086 m_node[0].attribute[0], 1087 helper.make_attribute("num_layers", model.num_layers), 1088 ) 1089 1090 ln_nodes = [n for n in m_funcs[0].node if n.op_type == "LayerNorm"] 1091 expected_ln_attrs = [ 1092 helper.make_attribute( 1093 "elementwise_affine", model.lns[0].elementwise_affine 1094 ), 1095 helper.make_attribute("eps", model.lns[0].eps), 1096 ] 1097 for ln_node in ln_nodes: 1098 self.assertIn(ln_node.attribute[0], expected_ln_attrs) 1099 self.assertIn(ln_node.attribute[1], expected_ln_attrs) 1100 1101 # This test cases checks the issue where an object does not have an attribute. 1102 # When enabling `export_modules_as_functions = True`, the exporter could return an 1103 # AttributeError. With this test case, we check that the export passes successfully 1104 # without any AttributeError exceptions. 1105 # See https://github.com/pytorch/pytorch/pull/109759 for an example. The exception that 1106 # this test tries to avoid is `AttributeError: 'Embedding' object has no attribute 'freeze'`. 1107 @skipIfUnsupportedMinOpsetVersion(15) 1108 def test_local_function_subset_of_predefined_attributes(self): 1109 class M(torch.nn.Module): 1110 num_layers: int 1111 1112 def __init__(self, num_layers): 1113 super().__init__() 1114 self.embed_layer = torch.nn.Embedding.from_pretrained( 1115 torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) 1116 ) 1117 self.num_layers = num_layers 1118 self.lns = torch.nn.ModuleList( 1119 [torch.nn.LayerNorm(3, eps=1e-4) for _ in range(num_layers)] 1120 ) 1121 1122 def forward(self, x): 1123 e = self.embed_layer(torch.LongTensor([1])) 1124 for ln in self.lns: 1125 x = ln(x) 1126 return x, e 1127 1128 x = torch.randn(2, 3) 1129 f = io.BytesIO() 1130 model = M(3) 1131 torch.onnx.export( 1132 model, 1133 (x,), 1134 f, 1135 export_modules_as_functions=True, 1136 opset_version=self.opset_version, 1137 verbose=True, # Allows the test case to print `Skipping module attribute 'freeze'` 1138 ) 1139 1140 def test_node_scope(self): 1141 class N(torch.nn.Module): 1142 def __init__(self) -> None: 1143 super().__init__() 1144 self.relu = torch.nn.ReLU() 1145 1146 def forward(self, x): 1147 return self.relu(x) 1148 1149 class M(torch.nn.Module): 1150 def __init__(self, num_layers): 1151 super().__init__() 1152 self.num_layers = num_layers 1153 self.lns = torch.nn.ModuleList( 1154 [torch.nn.LayerNorm(3, eps=float(i)) for i in range(num_layers)] 1155 ) 1156 self.gelu1 = torch.nn.GELU() 1157 self.gelu2 = torch.nn.GELU() 1158 self.relu = N() 1159 1160 def forward(self, x, y, z): 1161 res1 = self.gelu1(x) 1162 res2 = self.gelu2(y) 1163 for ln in self.lns: 1164 z = ln(z) 1165 return res1 + res2, self.relu(z) 1166 1167 x = torch.randn(2, 3) 1168 y = torch.randn(2, 3) 1169 z = torch.randn(2, 3) 1170 1171 model = M(3) 1172 expected_scope_names = { 1173 "M::/torch.nn.modules.activation.GELU::gelu1", 1174 "M::/torch.nn.modules.activation.GELU::gelu2", 1175 "M::/torch.nn.modules.normalization.LayerNorm::lns.0", 1176 "M::/torch.nn.modules.normalization.LayerNorm::lns.1", 1177 "M::/torch.nn.modules.normalization.LayerNorm::lns.2", 1178 "M::/N::relu/torch.nn.modules.activation.ReLU::relu", 1179 "M::", 1180 } 1181 1182 graph, _, _ = self._model_to_graph( 1183 model, (x, y, z), input_names=[], dynamic_axes={} 1184 ) 1185 for node in graph.nodes(): 1186 self.assertIn( 1187 _remove_test_environment_prefix_from_scope_name(node.scopeName()), 1188 expected_scope_names, 1189 ) 1190 1191 graph, _, _ = self._model_to_graph( 1192 torch.jit.script(model), (x, y, z), input_names=[], dynamic_axes={} 1193 ) 1194 for node in graph.nodes(): 1195 self.assertIn( 1196 _remove_test_environment_prefix_from_scope_name(node.scopeName()), 1197 expected_scope_names, 1198 ) 1199 1200 def test_scope_of_constants_when_combined_by_cse_pass(self): 1201 layer_num = 3 1202 1203 class M(torch.nn.Module): 1204 def __init__(self, constant): 1205 super().__init__() 1206 self.constant = constant 1207 1208 def forward(self, x): 1209 # 'self.constant' is designed to be the same for all layers, 1210 # hence it is common sub expression. 1211 return x + self.constant 1212 1213 class N(torch.nn.Module): 1214 def __init__(self, layers: int = layer_num): 1215 super().__init__() 1216 self.layers = torch.nn.ModuleList( 1217 [M(constant=torch.tensor(1.0)) for i in range(layers)] 1218 ) 1219 1220 def forward(self, x): 1221 for layer in self.layers: 1222 x = layer(x) 1223 return x 1224 1225 graph, _, _ = self._model_to_graph( 1226 N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={} 1227 ) 1228 1229 # NOTE: Duplicated constants are populated due to implicit casting in scalar_type_analysis, 1230 # so we expect 3 constants with different scopes. The 3 constants are for the 3 layers. 1231 # If CSE in exporter is improved later, this test needs to be updated. 1232 # It should expect 1 constant, with same scope as root. 1233 expected_root_scope_name = "N::" 1234 expected_layer_scope_name = "M::layers" 1235 expected_constant_scope_name = [ 1236 f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}" 1237 for i in range(layer_num) 1238 ] 1239 1240 constant_scope_names = [] 1241 for node in graph.nodes(): 1242 if node.kind() == "onnx::Constant": 1243 constant_scope_names.append( 1244 _remove_test_environment_prefix_from_scope_name(node.scopeName()) 1245 ) 1246 self.assertEqual(constant_scope_names, expected_constant_scope_name) 1247 1248 def test_scope_of_nodes_when_combined_by_cse_pass(self): 1249 layer_num = 3 1250 1251 class M(torch.nn.Module): 1252 def __init__(self, constant, bias): 1253 super().__init__() 1254 self.constant = constant 1255 self.bias = bias 1256 1257 def forward(self, x): 1258 # 'constant' and 'x' is designed to be the same for all layers, 1259 # hence `x + self.constant` is common sub expression. 1260 # 'bias' is designed to be different for all layers, 1261 # hence `* self.bias` is not common sub expression. 1262 return (x + self.constant) * self.bias 1263 1264 class N(torch.nn.Module): 1265 def __init__(self, layers: int = layer_num): 1266 super().__init__() 1267 1268 self.layers = torch.nn.ModuleList( 1269 [ 1270 M(constant=torch.tensor([1.0]), bias=torch.randn(1)) 1271 for i in range(layers) 1272 ] 1273 ) 1274 1275 def forward(self, x): 1276 y = [] 1277 for layer in self.layers: 1278 y.append(layer(x)) 1279 return y[0], y[1], y[2] 1280 1281 graph, _, _ = self._model_to_graph( 1282 N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={} 1283 ) 1284 expected_root_scope_name = "N::" 1285 expected_layer_scope_name = "M::layers" 1286 expected_add_scope_names = [ 1287 f"{expected_root_scope_name}/{expected_layer_scope_name}.0" 1288 ] 1289 expected_mul_scope_names = [ 1290 f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}" 1291 for i in range(layer_num) 1292 ] 1293 1294 add_scope_names = [] 1295 mul_scope_names = [] 1296 for node in graph.nodes(): 1297 if node.kind() == "onnx::Add": 1298 add_scope_names.append( 1299 _remove_test_environment_prefix_from_scope_name(node.scopeName()) 1300 ) 1301 elif node.kind() == "onnx::Mul": 1302 mul_scope_names.append( 1303 _remove_test_environment_prefix_from_scope_name(node.scopeName()) 1304 ) 1305 self.assertEqual(add_scope_names, expected_add_scope_names) 1306 self.assertEqual(mul_scope_names, expected_mul_scope_names) 1307 1308 def test_aten_fallthrough(self): 1309 # Test aten export of op with no symbolic 1310 class Module(torch.nn.Module): 1311 def forward(self, x): 1312 return torch.erfc(x) 1313 1314 x = torch.randn(2, 3, 4) 1315 GLOBALS.export_onnx_opset_version = self.opset_version 1316 graph, _, __ = self._model_to_graph( 1317 Module(), 1318 (x,), 1319 operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, 1320 input_names=["x"], 1321 dynamic_axes={"x": [0, 1, 2]}, 1322 ) 1323 iter = graph.nodes() 1324 self.assertEqual(next(iter).kind(), "aten::erfc") 1325 1326 def test_custom_op_fallthrough(self): 1327 # Test custom op 1328 op_source = """ 1329 #include <torch/script.h> 1330 1331 torch::Tensor custom_add(torch::Tensor self, torch::Tensor other) { 1332 return self + other; 1333 } 1334 1335 static auto registry = 1336 torch::RegisterOperators("custom_namespace::custom_op", &custom_add); 1337 """ 1338 1339 torch.utils.cpp_extension.load_inline( 1340 name="custom_add", 1341 cpp_sources=op_source, 1342 is_python_module=False, 1343 verbose=True, 1344 ) 1345 1346 class FooModel(torch.nn.Module): 1347 def forward(self, input, other): 1348 # Calling custom op 1349 return torch.ops.custom_namespace.custom_op(input, other) 1350 1351 x = torch.randn(2, 3, 4, requires_grad=False) 1352 y = torch.randn(2, 3, 4, requires_grad=False) 1353 model = FooModel() 1354 graph, _, __ = self._model_to_graph( 1355 model, 1356 (x, y), 1357 operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH, 1358 input_names=["x", "y"], 1359 dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]}, 1360 ) 1361 iter = graph.nodes() 1362 self.assertEqual(next(iter).kind(), "custom_namespace::custom_op") 1363 1364 # gelu is exported as onnx::Gelu for opset >= 20 1365 @skipIfUnsupportedMaxOpsetVersion(19) 1366 def test_custom_opsets_gelu(self): 1367 self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::gelu", 9) 1368 1369 def gelu(g, self, approximate): 1370 return g.op("com.microsoft::Gelu", self).setType(self.type()) 1371 1372 torch.onnx.register_custom_op_symbolic("::gelu", gelu, 9) 1373 model = torch.nn.GELU(approximate="none") 1374 x = torch.randn(3, 3) 1375 f = io.BytesIO() 1376 torch.onnx.export( 1377 model, 1378 (x,), 1379 f, 1380 opset_version=self.opset_version, 1381 custom_opsets={"com.microsoft": 1}, 1382 ) 1383 1384 graph = onnx.load(io.BytesIO(f.getvalue())) 1385 self.assertEqual(graph.graph.node[0].op_type, "Gelu") 1386 self.assertEqual(graph.opset_import[0].version, self.opset_version) 1387 self.assertEqual(graph.opset_import[1].domain, "com.microsoft") 1388 self.assertEqual(graph.opset_import[1].version, 1) 1389 1390 # gelu is exported as onnx::Gelu for opset >= 20 1391 @skipIfUnsupportedMaxOpsetVersion(19) 1392 def test_register_aten_custom_op_symbolic(self): 1393 self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "aten::gelu", 9) 1394 1395 def gelu(g, self, approximate): 1396 return g.op("com.microsoft::Gelu", self).setType(self.type()) 1397 1398 torch.onnx.register_custom_op_symbolic("aten::gelu", gelu, 9) 1399 model = torch.nn.GELU(approximate="none") 1400 x = torch.randn(3, 3) 1401 f = io.BytesIO() 1402 torch.onnx.export(model, (x,), f, opset_version=self.opset_version) 1403 graph = onnx.load(io.BytesIO(f.getvalue())) 1404 1405 self.assertEqual(graph.graph.node[0].op_type, "Gelu") 1406 self.assertEqual(graph.opset_import[1].domain, "com.microsoft") 1407 1408 @skipIfNoLapack 1409 def test_custom_opsets_inverse(self): 1410 self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) 1411 1412 class CustomInverse(torch.nn.Module): 1413 def forward(self, x): 1414 return torch.inverse(x) + x 1415 1416 def linalg_inv(g, self): 1417 return g.op("com.microsoft::Inverse", self).setType(self.type()) 1418 1419 torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv, 9) 1420 model = CustomInverse() 1421 x = torch.randn(2, 3, 3) 1422 f = io.BytesIO() 1423 torch.onnx.export( 1424 model, 1425 (x,), 1426 f, 1427 opset_version=self.opset_version, 1428 custom_opsets={"com.microsoft": 1}, 1429 ) 1430 1431 graph = onnx.load(io.BytesIO(f.getvalue())) 1432 self.assertEqual(graph.graph.node[0].op_type, "Inverse") 1433 self.assertEqual(graph.opset_import[0].version, self.opset_version) 1434 self.assertEqual(graph.opset_import[1].domain, "com.microsoft") 1435 self.assertEqual(graph.opset_import[1].version, 1) 1436 1437 def test_onnx_fallthrough(self): 1438 # Test aten export of op with symbolic for aten 1439 class Module(torch.nn.Module): 1440 def forward(self, x): 1441 return torch.digamma(x) 1442 1443 x = torch.randn(100, 128) 1444 graph, _, __ = self._model_to_graph( 1445 Module(), 1446 (x,), 1447 operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, 1448 input_names=["x"], 1449 dynamic_axes={"x": [0, 1]}, 1450 ) 1451 iter = graph.nodes() 1452 self.assertEqual(next(iter).kind(), "aten::digamma") 1453 1454 # prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11 1455 @skipIfUnsupportedMaxOpsetVersion(10) 1456 def test_prim_fallthrough(self): 1457 # Test prim op 1458 class PrimModule(torch.jit.ScriptModule): 1459 @torch.jit.script_method 1460 def forward(self, x): 1461 if isinstance(x, list): 1462 y = x 1463 else: 1464 y = [x] 1465 return y 1466 1467 x = torch.tensor([2]) 1468 model = PrimModule() 1469 model.eval() 1470 graph, _, __ = self._model_to_graph( 1471 model, 1472 (x,), 1473 operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, 1474 input_names=["x"], 1475 dynamic_axes={"x": [0]}, 1476 ) 1477 iter = graph.nodes() 1478 self.assertEqual(next(iter).kind(), "prim::ListConstruct") 1479 1480 def test_custom_layer_tuple(self): 1481 class CustomFunction(torch.autograd.Function): 1482 @staticmethod 1483 def symbolic(g, input): 1484 return g.op("CustomNamespace::Custom", input, outputs=2) 1485 1486 @staticmethod 1487 def forward(ctx, input): 1488 return input, input 1489 1490 class Custom(torch.nn.Module): 1491 def forward(self, input): 1492 return CustomFunction.apply(input) 1493 1494 model = Custom() 1495 batch = torch.FloatTensor(1, 3) 1496 1497 graph, _, _ = self._model_to_graph( 1498 model, batch, input_names=["batch"], dynamic_axes={"batch": [0, 1]} 1499 ) 1500 iter = graph.nodes() 1501 self.assertEqual(next(iter).kind(), "CustomNamespace::Custom") 1502 1503 def test_autograd_onnx_fallthrough(self): 1504 class CustomFunction(torch.autograd.Function): 1505 @staticmethod 1506 def forward(ctx, input): 1507 ctx.save_for_backward(input) 1508 return input.clamp(min=0) 1509 1510 @staticmethod 1511 def backward(ctx, grad_output): 1512 (input,) = ctx.saved_tensors 1513 grad_input = grad_output.clone() 1514 grad_input[input < 0] = 0 1515 return grad_input 1516 1517 class Custom(torch.nn.Module): 1518 def forward(self, input): 1519 return CustomFunction.apply(input) 1520 1521 model = Custom() 1522 batch = torch.FloatTensor(1, 3) 1523 1524 graph, _, _ = self._model_to_graph( 1525 model, 1526 batch, 1527 operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, 1528 input_names=["batch"], 1529 dynamic_axes={"batch": [0, 1]}, 1530 ) 1531 iter = graph.nodes() 1532 self.assertEqual(next(iter).kind(), "prim::PythonOp") 1533 1534 def test_autograd_module_name(self): 1535 class CustomFunction(torch.autograd.Function): 1536 @staticmethod 1537 def forward(ctx, input): 1538 ctx.save_for_backward(input) 1539 return input.clamp(min=0) 1540 1541 @staticmethod 1542 def backward(ctx, grad_output): 1543 (input,) = ctx.saved_tensors 1544 grad_input = grad_output.clone() 1545 grad_input[input < 0] = 0 1546 return grad_input 1547 1548 class Custom(torch.nn.Module): 1549 def forward(self, input): 1550 return CustomFunction.apply(input) + CustomFunction2.apply(input) 1551 1552 model = Custom() 1553 batch = torch.FloatTensor(1, 3) 1554 1555 graph, _, _ = self._model_to_graph( 1556 model, 1557 batch, 1558 operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, 1559 input_names=["batch"], 1560 dynamic_axes={"batch": [0, 1]}, 1561 ) 1562 iter = graph.nodes() 1563 autograd1 = next(iter) 1564 autograd2 = next(iter) 1565 self.assertEqual(autograd1.kind(), "prim::PythonOp") 1566 self.assertEqual(autograd2.kind(), "prim::PythonOp") 1567 self.assertNotEqual(autograd1.s("module"), autograd2.s("module")) 1568 1569 def test_unused_initializers(self): 1570 class Model(torch.nn.Module): 1571 def __init__(self) -> None: 1572 super().__init__() 1573 self.conv2 = torch.nn.ConvTranspose2d( 1574 16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1) 1575 ) 1576 self.k_proj = torch.nn.Linear(5, 5, bias=True) 1577 1578 def forward(self, x): 1579 x = self.conv2(x) 1580 return x 1581 1582 x = torch.randn(20, 16, 50, 100) 1583 GLOBALS.export_onnx_opset_version = self.opset_version 1584 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 1585 _, params_dict, __ = self._model_to_graph( 1586 Model(), 1587 (x,), 1588 do_constant_folding=False, 1589 operator_export_type=OperatorExportTypes.ONNX, 1590 input_names=["x"], 1591 dynamic_axes={"x": [0, 1, 2, 3]}, 1592 ) 1593 1594 self.assertEqual(len(params_dict), 2) 1595 1596 def test_scripting_param(self): 1597 class MyModule(torch.nn.Module): 1598 def __init__(self) -> None: 1599 super().__init__() 1600 self.conv = torch.nn.Conv2d( 1601 3, 16, kernel_size=1, stride=2, padding=3, bias=True 1602 ) 1603 self.bn = torch.nn.BatchNorm2d(16, affine=True) 1604 1605 def forward(self, x): 1606 x = self.conv(x) 1607 bn = self.bn(x) 1608 return bn 1609 1610 model = torch.jit.script(MyModule()) 1611 x = torch.randn(10, 3, 128, 128) 1612 GLOBALS.export_onnx_opset_version = self.opset_version 1613 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 1614 graph, _, __ = self._model_to_graph( 1615 model, 1616 (x,), 1617 do_constant_folding=True, 1618 operator_export_type=OperatorExportTypes.ONNX, 1619 training=torch.onnx.TrainingMode.TRAINING, 1620 input_names=["x"], 1621 dynamic_axes={"x": [0, 1, 2, 3]}, 1622 ) 1623 1624 graph_input_params = [param.debugName() for param in graph.inputs()] 1625 for item in dict(model.named_parameters()): 1626 self.assertIn( 1627 item, 1628 graph_input_params, 1629 "Graph parameter names does not match model parameters.", 1630 ) 1631 1632 def test_fuse_conv_bn(self): 1633 class Fuse(torch.nn.Module): 1634 def __init__(self) -> None: 1635 super().__init__() 1636 self.conv = torch.nn.Conv2d( 1637 3, 2, kernel_size=1, stride=2, padding=3, bias=True 1638 ) 1639 self.bn = torch.nn.BatchNorm2d(2) 1640 1641 def forward(self, x): 1642 out = self.conv(x) 1643 return self.bn(out) 1644 1645 x = torch.randn(2, 3, 2, 2, requires_grad=True) 1646 graph, _, __ = self._model_to_graph( 1647 Fuse(), 1648 (x,), 1649 training=TrainingMode.EVAL, 1650 input_names=["x"], 1651 dynamic_axes={"x": [0, 1, 2, 3]}, 1652 ) 1653 for node in graph.nodes(): 1654 self.assertNotEqual(node.kind(), "onnx::BatchNormalization") 1655 self.assertEqual(node.kind(), "onnx::Conv") 1656 1657 self.assertEqual(len(list(graph.nodes())), 1) 1658 1659 def test_fuse_resnet18(self): 1660 model = torchvision.models.resnet18(weights=None) 1661 x = torch.randn(2, 3, 224, 224, requires_grad=True) 1662 graph, _, __ = self._model_to_graph( 1663 model, 1664 (x,), 1665 training=TrainingMode.EVAL, 1666 input_names=["x"], 1667 dynamic_axes={"x": [0, 1, 2, 3]}, 1668 ) 1669 1670 for node in graph.nodes(): 1671 self.assertNotEqual(node.kind(), "onnx::BatchNormalization") 1672 1673 def test_onnx_function_substitution_pass(self): 1674 @torch.jit.script 1675 def f(x: torch.Tensor, y: torch.Tensor): 1676 z = x - y 1677 return x + z 1678 1679 class MyModule(torch.nn.Module): 1680 def forward(self, x, y): 1681 return f(x, y) 1682 1683 input_1 = torch.tensor([11]) 1684 input_2 = torch.tensor([12]) 1685 GLOBALS.export_onnx_opset_version = self.opset_version 1686 GLOBALS.operator_export_type = OperatorExportTypes.ONNX 1687 graph, _, __ = self._model_to_graph( 1688 MyModule(), 1689 (input_1, input_2), 1690 do_constant_folding=True, 1691 operator_export_type=OperatorExportTypes.ONNX, 1692 input_names=["input_1", "input_2"], 1693 dynamic_axes={"input_1": [0], "input_2": [0]}, 1694 ) 1695 # Check that the prim::Constant node in the graph for representing the 1696 # scripted function `f` is removed and the following prim::CallFunction 1697 # is replced by inline graph, with onnx::Sub and onnx::Add nodes. 1698 for node in graph.nodes(): 1699 self.assertNotEqual(node.kind(), "prim::Constant") 1700 self.assertEqual( 1701 len(list(graph.nodes())), 2 1702 ) # onnx::Sub and onnx::Add nodes only. 1703 1704 def test_onnx_value_name(self): 1705 class MyModule(torch.nn.Module): 1706 def __init__(self) -> None: 1707 super().__init__() 1708 self.in_weight = torch.nn.Parameter(torch.Tensor(3, 3)) 1709 self.in_bias = torch.nn.Parameter(torch.Tensor(3)) 1710 1711 def forward(self, x): 1712 start = 0 1713 end = None 1714 weight = self.in_weight 1715 bias = self.in_bias 1716 weight = weight[start:end, :] 1717 if bias is not None: 1718 bias = bias[start:end] 1719 return torch.nn.functional.linear(x, weight, bias) 1720 1721 model = MyModule() 1722 x = torch.randn(3, 3) 1723 f = io.BytesIO() 1724 1725 model.eval() 1726 torch.onnx.export( 1727 model, 1728 (x,), 1729 f, 1730 opset_version=self.opset_version, 1731 keep_initializers_as_inputs=True, 1732 ) 1733 graph = onnx.load(io.BytesIO(f.getvalue())) 1734 self.assertEqual(graph.graph.input[1].name, "in_weight") 1735 self.assertEqual(graph.graph.input[2].name, "in_bias") 1736 1737 def test_onnx_node_naming(self): 1738 class MainModule(torch.nn.Module): 1739 def __init__(self) -> None: 1740 super().__init__() 1741 self._module_1 = torch.nn.Linear(10, 10) 1742 self._module_2 = torch.nn.Linear(10, 10) 1743 self._module_3 = torch.nn.Linear(10, 10) 1744 self._module_4 = torch.nn.Linear(10, 10) 1745 1746 def forward(self, x): 1747 y = self._module_1(x) 1748 z = self._module_2(y) 1749 z = self._module_3(y * z) 1750 z = self._module_4(y * z) 1751 return z 1752 1753 module = MainModule() 1754 ref_node_names = [ 1755 "/_module_1/Gemm", 1756 "/_module_2/Gemm", 1757 "/_module_3/Gemm", 1758 "/_module_4/Gemm", 1759 "/Mul", 1760 "/Mul_1", 1761 ] 1762 f = io.BytesIO() 1763 1764 torch.onnx.export(module, torch.ones(1, 10), f, output_names=["y"]) 1765 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 1766 for n in onnx_model.graph.node: 1767 self.assertIn(n.name, ref_node_names) 1768 1769 torch.onnx.export( 1770 torch.jit.script(module), torch.ones(1, 10), f, output_names=["y"] 1771 ) 1772 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 1773 for n in onnx_model.graph.node: 1774 self.assertIn(n.name, ref_node_names) 1775 1776 def _test_deduplicate_initializers(self, torchscript=False): 1777 class MyModule(torch.nn.Module): 1778 def __init__(self) -> None: 1779 super().__init__() 1780 self.layer1 = torch.nn.Linear(3, 3) 1781 self.layer2 = torch.nn.Linear(3, 3) 1782 1783 # Reusing layers. 1784 self.layer3 = self.layer1 1785 1786 # Reusing parameters. 1787 self.layer2.weight = self.layer1.weight 1788 self.layer1.bias = self.layer2.bias 1789 1790 # Parameter with different tensors equal in value. 1791 self.param1 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0])) 1792 self.param2 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0])) 1793 1794 def forward(self, x): 1795 return ( 1796 self.layer3(self.layer2(self.layer1(x))) + self.param1 + self.param2 1797 ) 1798 1799 model = torch.jit.script(MyModule()) if torchscript else MyModule() 1800 1801 x = torch.randn(3, 3) 1802 param_name_set = {k for k, _ in model.named_parameters()} 1803 1804 # Test training mode. 1805 model.train() 1806 f = io.BytesIO() 1807 torch.onnx.export( 1808 model, 1809 (x,), 1810 f, 1811 training=TrainingMode.TRAINING, 1812 opset_version=self.opset_version, 1813 ) 1814 graph = onnx.load(io.BytesIO(f.getvalue())) 1815 self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set) 1816 1817 model.train() 1818 f = io.BytesIO() 1819 torch.onnx.export( 1820 model, 1821 (x,), 1822 f, 1823 training=TrainingMode.PRESERVE, 1824 opset_version=self.opset_version, 1825 ) 1826 graph = onnx.load(io.BytesIO(f.getvalue())) 1827 self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set) 1828 1829 # Test eval mode. 1830 model.eval() 1831 f = io.BytesIO() 1832 torch.onnx.export(model, (x,), f, opset_version=self.opset_version) 1833 graph = onnx.load(io.BytesIO(f.getvalue())) 1834 param_name_set.remove("param2") 1835 self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set) 1836 1837 def test_deduplicate_initializers(self): 1838 self._test_deduplicate_initializers(torchscript=False) 1839 1840 def test_deduplicate_initializers_torchscript(self): 1841 self._test_deduplicate_initializers(torchscript=True) 1842 1843 @skipIfNoCuda 1844 def test_deduplicate_initializers_diff_devices(self): 1845 class Model(torch.nn.Module): 1846 def __init__(self) -> None: 1847 super().__init__() 1848 self.w_cpu = torch.nn.Parameter( 1849 torch.ones(3, device=torch.device("cpu")) 1850 ) 1851 self.w_cuda = torch.nn.Parameter( 1852 torch.ones(3, device=torch.device("cuda")) 1853 ) 1854 1855 def forward(self, x, y): 1856 return x + self.w_cpu, y + self.w_cuda 1857 1858 x = torch.randn(3, 3, device=torch.device("cpu")) 1859 y = torch.randn(3, 3, device=torch.device("cuda")) 1860 f = io.BytesIO() 1861 torch.onnx.export(Model(), (x, y), f, opset_version=self.opset_version) 1862 graph = onnx.load(io.BytesIO(f.getvalue())) 1863 self.assertSetEqual({i.name for i in graph.graph.initializer}, {"w_cpu"}) 1864 1865 def test_duplicated_output_node(self): 1866 class DuplicatedOutputNet(torch.nn.Module): 1867 def __init__(self, input_size, num_classes): 1868 super().__init__() 1869 self.fc1 = torch.nn.Linear(input_size, num_classes) 1870 1871 def forward(self, input0, input1): 1872 out1 = self.fc1(input0) 1873 out2 = self.fc1(input1) 1874 return out1, out1, out2, out1, out2 1875 1876 N, D_in, H, D_out = 64, 784, 500, 10 1877 pt_model = DuplicatedOutputNet(D_in, D_out) 1878 1879 f = io.BytesIO() 1880 x = torch.randn(N, D_in) 1881 dynamic_axes = { 1882 "input0": {0: "input0_dim0", 1: "input0_dim1"}, 1883 "input1": {0: "input1_dim0", 1: "input1_dim1"}, 1884 "output-0": {0: "output-0_dim0", 1: "output-0_dim1"}, 1885 "output-1": {0: "output-1_dim0", 1: "output-1_dim1"}, 1886 "output-2": {0: "output-2_dim0", 1: "output-2_dim1"}, 1887 "output-3": {0: "output-3_dim0", 1: "output-3_dim1"}, 1888 "output-4": {0: "output-4_dim0", 1: "output-4_dim1"}, 1889 } 1890 1891 torch.onnx.export( 1892 pt_model, 1893 (x, x), 1894 f, 1895 input_names=["input0", "input1"], 1896 output_names=["output-0", "output-1", "output-2", "output-3", "output-4"], 1897 do_constant_folding=False, 1898 training=torch.onnx.TrainingMode.TRAINING, 1899 dynamic_axes=dynamic_axes, 1900 verbose=True, 1901 keep_initializers_as_inputs=True, 1902 ) 1903 1904 graph = onnx.load(io.BytesIO(f.getvalue())) 1905 self.assertEqual(graph.graph.input[0].name, "input0") 1906 self.assertEqual(graph.graph.input[1].name, "input1") 1907 for i in range(5): 1908 self.assertEqual(graph.graph.output[i].name, f"output-{i}") 1909 self.assertEqual(graph.graph.node[0].op_type, "Gemm") 1910 self.assertEqual(graph.graph.node[1].op_type, "Identity") 1911 self.assertEqual(graph.graph.node[2].op_type, "Identity") 1912 self.assertEqual(graph.graph.node[3].op_type, "Gemm") 1913 self.assertEqual(graph.graph.node[4].op_type, "Identity") 1914 1915 def test_deduplicate_ignore_upsample_scale(self): 1916 # upsample scale is a constant, not a model parameter, 1917 # therefore should be ignored by shared weight deduplication. 1918 class Model(torch.nn.Module): 1919 def __init__(self) -> None: 1920 super().__init__() 1921 self.upsample_1 = torch.nn.Upsample(scale_factor=2) 1922 self.upsample_2 = torch.nn.Upsample(scale_factor=2) 1923 1924 def forward(self, x): 1925 return self.upsample_1(x), self.upsample_2(x) 1926 1927 f = io.BytesIO() 1928 x = torch.randn(1, 32, 224, 224) 1929 torch.onnx.export(Model(), x, f) 1930 onnx_model = onnx.load(io.BytesIO(f.getvalue())) 1931 # aten::upsample converts to onnx::resize 1932 resize_nodes = [n for n in onnx_model.graph.node if n.op_type == "Resize"] 1933 self.assertEqual(len(resize_nodes), 2) 1934 for resize_node in resize_nodes: 1935 scale_node = [ 1936 n for n in onnx_model.graph.node if n.output[0] == resize_node.input[2] 1937 ] 1938 self.assertEqual(len(scale_node), 1) 1939 self.assertEqual(scale_node[0].op_type, "Constant") 1940 1941 def test_bad_symbolic_registration(self): 1942 _onnx_opset_version = 9 1943 1944 @parse_args("v") 1945 def cat(g, tensor_list, dim): 1946 tensors = _unpack_list(tensor_list) 1947 return g.op("Concat", *tensors, axis_i=dim) 1948 1949 torch.onnx.register_custom_op_symbolic("::cat", cat, _onnx_opset_version) 1950 1951 class CatModel(torch.nn.Module): 1952 def forward(self, x): 1953 return torch.cat((x, x, x), 0) 1954 1955 model = CatModel() 1956 x = torch.randn(2, 3) 1957 f = io.BytesIO() 1958 self.assertExpectedRaisesInline( 1959 AssertionError, 1960 lambda: torch.onnx.export( 1961 model, (x,), f, opset_version=_onnx_opset_version 1962 ), 1963 ( 1964 "A mismatch between the number of arguments (2) and their descriptors (1) was found at symbolic function " 1965 "'cat'. If you believe this is not due to custom symbolic implementation within your code or an external " 1966 "library, please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to " 1967 "report this bug." 1968 ), 1969 ) 1970 torch.onnx.unregister_custom_op_symbolic("::cat", _onnx_opset_version) 1971 1972 1973if __name__ == "__main__": 1974 common_utils.run_tests() 1975