1# Owner(s): ["module: onnx"] 2from __future__ import annotations 3 4import logging 5import tempfile 6from typing import Mapping, Tuple, TYPE_CHECKING 7 8import onnx 9import onnx.inliner 10 11import pytorch_test_common 12import transformers # type: ignore[import] 13 14import torch 15from torch import nn 16from torch._subclasses import fake_tensor 17from torch.nn import functional as F 18from torch.onnx import dynamo_export, ExportOptions 19from torch.onnx._internal.fx import diagnostics, registration 20from torch.testing._internal import common_utils 21 22 23if TYPE_CHECKING: 24 from torch.onnx._internal.diagnostics import infra 25 26 27def assert_has_diagnostics( 28 diagnostic_context: diagnostics.DiagnosticContext, 29 rule: infra.Rule, 30 level: infra.Level, 31 expected_node: str, 32): 33 rule_level_pairs = (rule.id, level.name.lower()) 34 sarif_log = diagnostic_context.sarif_log() 35 actual_results = [] 36 for run in sarif_log.runs: 37 if run.results is None: 38 continue 39 for result in run.results: 40 id_level_pair = (result.rule_id, result.level) 41 actual_results.append(id_level_pair) 42 if ( 43 rule_level_pairs == id_level_pair 44 and result.message.text 45 and result.message.markdown 46 and expected_node in result.message.text 47 ): 48 return 49 50 raise AssertionError( 51 f"Expected diagnostic results of rule id and level pair {rule_level_pairs} " 52 f"not found with expected error node {expected_node} and " 53 f"Actual diagnostic results: {actual_results}" 54 ) 55 56 57@common_utils.instantiate_parametrized_tests 58class TestFxToOnnx(pytorch_test_common.ExportTestCase): 59 def setUp(self): 60 super().setUp() 61 self.export_options = ExportOptions() 62 63 def tearDown(self): 64 super().tearDown() 65 66 def test_simple_function(self): 67 def func(x): 68 y = x + 1 69 z = y.relu() 70 return (y, z) 71 72 _ = dynamo_export( 73 func, torch.randn(1, 1, 2), export_options=self.export_options 74 ) 75 76 def test_empty(self): 77 # Since `torch.empty` returns tensor with uninitialized data, we cannot 78 # test this under `test_fx_to_onnx_with_onnxruntime.py` with result comparison. 
@common_utils.instantiate_parametrized_tests
class TestFxToOnnx(pytorch_test_common.ExportTestCase):
    def setUp(self):
        super().setUp()
        self.export_options = ExportOptions()

    def tearDown(self):
        super().tearDown()

    def test_simple_function(self):
        def func(x):
            y = x + 1
            z = y.relu()
            return (y, z)

        _ = dynamo_export(
            func, torch.randn(1, 1, 2), export_options=self.export_options
        )

    def test_empty(self):
        # Since `torch.empty` returns a tensor with uninitialized data, we cannot
        # test it under `test_fx_to_onnx_with_onnxruntime.py` with result comparison.
        def func(x):
            return torch.empty(x.size(), dtype=torch.int64)

        tensor_x = torch.randn(1, 1, 2)
        _ = dynamo_export(func, tensor_x, export_options=self.export_options)

    def test_args_used_for_export_is_not_converted_to_fake_tensors(self):
        def func(x, y):
            return x + y

        tensor_x = torch.randn(1, 1, 2)
        tensor_y = torch.randn(1, 1, 2)
        _ = dynamo_export(func, tensor_x, tensor_y, export_options=self.export_options)
        self.assertNotIsInstance(tensor_x, fake_tensor.FakeTensor)
        self.assertNotIsInstance(tensor_y, fake_tensor.FakeTensor)

    @common_utils.parametrize(
        "diagnostic_rule",
        [
            common_utils.subtest(
                diagnostics.rules.find_opschema_matched_symbolic_function,
                name="optional_inputs",
            ),
        ],
    )
    def test_mnist_exported_with_no_warnings(self, diagnostic_rule):
        class MNISTModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
                self.fc1 = nn.Linear(9216, 128, bias=False)
                self.fc2 = nn.Linear(128, 10, bias=False)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.conv2(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = F.max_pool2d(tensor_x, 2)
                tensor_x = torch.flatten(tensor_x, 1)
                tensor_x = self.fc1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.fc2(tensor_x)
                output = F.log_softmax(tensor_x, dim=1)
                return output

        tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
        onnx_program = dynamo_export(MNISTModel(), tensor_x)

        assert_has_diagnostics(
            onnx_program.diagnostic_context,
            diagnostic_rule,
            diagnostics.levels.NONE,
            expected_node="aten.convolution.default",
        )

    def test_trace_only_op_with_evaluator(self):
        model_input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]])

        class ArgminArgmaxModel(torch.nn.Module):
            def forward(self, input):
                return (
                    torch.argmin(input),
                    torch.argmax(input),
                    torch.argmin(input, keepdim=True),
                    torch.argmax(input, keepdim=True),
                    torch.argmin(input, dim=0, keepdim=True),
                    torch.argmax(input, dim=1, keepdim=True),
                )

        _ = dynamo_export(
            ArgminArgmaxModel(), model_input, export_options=self.export_options
        )

    def test_multiple_outputs_op_with_evaluator(self):
        class TopKModel(torch.nn.Module):
            def forward(self, x):
                values, _ = torch.topk(x, 3)
                return torch.sum(values)

        x = torch.arange(1.0, 6.0, requires_grad=True)

        _ = dynamo_export(TopKModel(), x, export_options=self.export_options)

    def test_unsupported_function_schema_raises_diagnostic_warning_when_found_nearest_match(
        self,
    ):
        class TraceModel(torch.nn.Module):
            def forward(self, input):
                return input.new_zeros(())

        x = torch.randn((2, 3), dtype=torch.float32)
        onnx_program = dynamo_export(TraceModel(), x)

        assert_has_diagnostics(
            onnx_program.diagnostic_context,
            diagnostics.rules.find_opschema_matched_symbolic_function,
            diagnostics.levels.WARNING,
            expected_node="aten.new_zeros.default",
        )

    def test_perfect_match_on_sequence_and_bool_attributes(
        self,
    ):
        class TraceModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = torch.nn.Conv2d(
                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
                )

            def forward(self, input):
                return self.conv2(input)

        x = torch.randn(20, 16, 50, 50)
        onnx_program = dynamo_export(TraceModel(), x)
        assert_has_diagnostics(
            onnx_program.diagnostic_context,
            diagnostics.rules.find_opschema_matched_symbolic_function,
            diagnostics.levels.NONE,
            expected_node="aten.convolution.default",
        )

    def test_aten_clone_does_not_raise_warning_of_lack_of_memory_format(self):
        class CustomModule(torch.nn.Module):
            def forward(self, input):
                return torch.ops.aten.clone(input, memory_format=torch.preserve_format)

        x = torch.tensor(3)
        onnx_program = dynamo_export(CustomModule(), x)
        assert_has_diagnostics(
            onnx_program.diagnostic_context,
            diagnostics.rules.find_opschema_matched_symbolic_function,
            diagnostics.levels.NONE,
            expected_node="aten.clone.default",
        )

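    # The test below simulates a missing complex-dtype implementation by pruning
    # the private `_registry` table directly. For context, a hedged sketch of the
    # public route for *adding* such support, via OnnxRegistry.register_op (the
    # `custom_aten_mul` function here is hypothetical and not used by this test):
    #
    #     custom_aten_mul = ...  # an onnxscript function implementing complex mul
    #     registry = torch.onnx.OnnxRegistry()
    #     registry.register_op(
    #         function=custom_aten_mul,
    #         namespace="aten",
    #         op_name="mul",
    #         overload="Tensor",
    #         is_complex=True,
    #     )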
    def test_missing_complex_onnx_variant_raises_errors_in_dispatcher(self):
        registry = torch.onnx.OnnxRegistry()

        # NOTE: simulate unsupported nodes
        aten_mul_tensor = registration.OpName.from_name_parts(
            namespace="aten", op_name="mul", overload="Tensor"
        )

        # Only keep the real-valued aten.mul overloads to simulate a missing
        # complex-valued aten.mul.
        registry._registry[aten_mul_tensor] = [
            onnx_func
            for onnx_func in registry._registry[aten_mul_tensor]
            if not onnx_func.is_complex
        ]

        class TraceModel(torch.nn.Module):
            def forward(self, input):
                return torch.ops.aten.mul.Tensor(input, input)

        x = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)

        with self.assertRaises(torch.onnx.OnnxExporterError):
            torch.onnx.dynamo_export(
                TraceModel(),
                x,
                export_options=torch.onnx.ExportOptions(onnx_registry=registry),
            )

    def test_symbolic_shape_of_values_inside_function_is_exported_as_graph_value_info(
        self,
    ):
        class SubModule(torch.nn.Module):
            def forward(self, x, y, bias):
                output = x @ y
                return output + bias

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submodule = SubModule()

            def forward(self, x, y, bias):
                return self.submodule(x, y, bias)

        x = torch.randn(2, 3)
        y = torch.randn(3, 4)
        bias = torch.randn(4)
        onnx_program = torch.onnx.dynamo_export(
            Module(),
            x,
            y,
            bias,
            export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
        )
        model_proto = onnx_program.model_proto

        # Assert value_info for values inside local functions can be retrieved.
        # Such values are keyed as "<function_domain>::<function_name>/<output>".
        def _assert_node_outputs_has_value_info(
            node: onnx.NodeProto,
            value_infos: Mapping[str, onnx.ValueInfoProto],
            local_functions: Mapping[Tuple[str, str], onnx.FunctionProto],
            exclude_names_in_value_info,
            function_id: str = "",
        ):
            for output in node.output:
                name = f"{function_id}/{output}" if function_id else output
                if name not in exclude_names_in_value_info:
                    self.assertIn(name, value_infos)
            if node.domain.startswith("pkg.onnxscript.torch_lib"):
                # No shape info is available for values inside torchlib functions.
                return
            if (
                function := local_functions.get((node.domain, node.op_type))
            ) is not None:
                for node in function.node:
                    function_id = f"{function.domain}::{function.name}"
                    _assert_node_outputs_has_value_info(
                        node,
                        value_infos,
                        local_functions,
                        exclude_names_in_value_info,
                        function_id,
                    )

        type_infos = {vi.name: vi for vi in model_proto.graph.value_info}
        functions = {(f.domain, f.name): f for f in model_proto.functions}
        # NOTE: inputs, outputs, and initializers are not included in the value_info spec.
        exclude_names_in_value_info = (
            [input.name for input in model_proto.graph.input]
            + [output.name for output in model_proto.graph.output]
            + [init.name for init in model_proto.graph.initializer]
        )
        for node in model_proto.graph.node:
            _assert_node_outputs_has_value_info(
                node, type_infos, functions, exclude_names_in_value_info
            )

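    # The test below checks name preservation end to end: every ONNX initializer
    # name should be a fully qualified state_dict key (for instance, one would
    # expect a name like "submodule.fc1.weight" here; the example name is
    # illustrative, derived from the module structure in the test).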
    def test_dynamo_export_retains_readable_parameter_and_buffer_names(self):
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
                self.fc1 = nn.Linear(9216, 128, bias=False)
                self.buffer = torch.nn.Buffer(torch.randn(1, 128))

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv2(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = F.max_pool2d(tensor_x, 2)
                tensor_x = torch.flatten(tensor_x, 1)
                tensor_x = self.fc1(tensor_x)
                tensor_x = tensor_x + self.buffer
                tensor_x = F.sigmoid(tensor_x)
                return tensor_x

        class MNISTModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                self.submodule = SubModule()
                self.fc2 = nn.Linear(128, 10, bias=False)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.submodule(tensor_x)
                tensor_x = self.fc2(tensor_x)
                output = F.log_softmax(tensor_x, dim=1)
                return output

        tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)

        model = MNISTModel()
        onnx_program = torch.onnx.dynamo_export(model, tensor_x)
        model_proto = onnx_program.model_proto

        # NOTE: initializers could be optimized away by the onnx optimizer.
        onnx_initializers = {init.name for init in model_proto.graph.initializer}
        torch_weights = {*model.state_dict().keys()}
        self.assertTrue(onnx_initializers.issubset(torch_weights))

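    # The next two tests exercise fake-tensor export. The workflow they follow,
    # in sketch form (mirroring the calls below; no new APIs assumed):
    #
    #     with torch.onnx.enable_fake_mode() as fake_context:
    #         model = Model()          # parameters are fake: no real memory
    #         x = torch.rand(5, 2, 2)  # inputs created here are fake too
    #     options = ExportOptions(fake_context=fake_context)
    #     onnx_program = torch.onnx.dynamo_export(model, x, export_options=options)
    #     # The proto has no initializers; real weights are attached at save time:
    #     onnx_program.save("model.onnx", model_state=real_state_dict)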
    @common_utils.parametrize(
        "checkpoint_type",
        [
            common_utils.subtest(
                "state_dict",
                name="state_dict",
            ),
            common_utils.subtest(
                "checkpoint_file",
                name="checkpoint_file",
            ),
        ],
    )
    def test_fake_tensor_mode_simple(self, checkpoint_type):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                out = self.linear(x)
                return out

        with torch.onnx.enable_fake_mode() as fake_context:
            x = torch.rand(5, 2, 2)
            model = Model()
            export_options = ExportOptions(fake_context=fake_context)
            onnx_program = torch.onnx.dynamo_export(
                model, x, export_options=export_options
            )

        assert (
            onnx_program is not None
        ), "ONNXProgram must be created on successful export"
        assert (
            onnx_program.model_proto is not None
        ), "A model protobuf must be created on a successful export"
        onnx.checker.check_model(onnx_program.model_proto, full_check=True)
        assert (
            len(onnx_program.model_proto.graph.initializer) == 0
        ), "Initializers cannot exist when fake mode is enabled"

        if checkpoint_type == "state_dict":
            # Variant 1: Save ONNX proto using the model's state_dict()
            with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
                model_state_dict = (
                    Model().state_dict()
                )  # Create a state_dict for testing
                onnx_program.save(tmp_onnx_file.name, model_state=model_state_dict)
                assert (
                    len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
                ), "Initializers must be present after loading from model_state"
                # Make sure consecutive `save` calls don't create dupes
                onnx_program.save(tmp_onnx_file.name, model_state=model_state_dict)
                assert (
                    len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
                ), "Initializers must be present after loading from model_state"
        elif checkpoint_type == "checkpoint_file":
            # Variant 2: Save ONNX proto using a model checkpoint file
            with tempfile.NamedTemporaryFile(
                suffix=".onnx"
            ) as tmp_onnx_file, tempfile.NamedTemporaryFile(
                suffix=".pt"
            ) as tmp_checkpoint_file:
                torch.save(
                    Model().state_dict(), tmp_checkpoint_file.name
                )  # Create a checkpoint file for testing
                onnx_program.save(
                    tmp_onnx_file.name, model_state=tmp_checkpoint_file.name
                )
                assert (
                    len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
                ), "Initializers must be present after loading from model_state"
                # Make sure consecutive `save` calls don't create dupes
                onnx_program.save(
                    tmp_onnx_file.name, model_state=tmp_checkpoint_file.name
                )
                assert (
                    len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
                ), "Initializers must be present after loading from model_state"

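    # Note on the two variants above: `ONNXProgram.save(..., model_state=...)`
    # accepts either an in-memory state_dict or a path to a checkpoint written by
    # torch.save; both routes materialize the same initializers.
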
    def test_fake_tensor_mode_simple_invalid_input(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                out = self.linear(x)
                return out

        real_model = Model()
        real_x = torch.rand(5, 2, 2)
        with torch.onnx.enable_fake_mode() as fake_context:
            fake_model = Model()
            fake_x = torch.rand(5, 2, 2)

            # TODO: Split each scenario into its own test case
            # Scenario 1: Fake model and fake input WITHOUT ExportOptions(fake_context=...)
            with self.assertRaises(torch.onnx.OnnxExporterError):
                export_options = ExportOptions(fake_context=None)
                _ = torch.onnx.dynamo_export(
                    fake_model, fake_x, export_options=export_options
                )

            # Scenario 2: Fake model and real input WITHOUT fake_context
            with self.assertRaises(torch.onnx.OnnxExporterError):
                export_options = ExportOptions(fake_context=None)
                _ = torch.onnx.dynamo_export(
                    fake_model, real_x, export_options=export_options
                )

            # Scenario 3: Real model and real input WITH fake_context
            with self.assertRaises(torch.onnx.OnnxExporterError):
                export_options = ExportOptions(fake_context=fake_context)
                _ = torch.onnx.dynamo_export(
                    real_model, real_x, export_options=export_options
                )

            # Scenario 4: Fake model and real input WITH fake_context
            with self.assertRaises(torch.onnx.OnnxExporterError):
                export_options = ExportOptions(fake_context=fake_context)
                _ = torch.onnx.dynamo_export(
                    fake_model, real_x, export_options=export_options
                )

    @pytorch_test_common.xfail(
        error_message="Dynamic control flow is not supported at the moment."
    )
    def test_fake_tensor_mode_huggingface_llama(self):
        config = transformers.LlamaConfig(
            vocab_size=8096, hidden_size=256, num_hidden_layers=2, num_attention_heads=2
        )
        batch, seq = 4, 256

        with torch.onnx.enable_fake_mode() as fake_context:
            model = transformers.LlamaModel(config).eval()
            input_ids = torch.randint(0, config.vocab_size, (batch, seq))
            attention_mask = torch.ones(batch, seq, dtype=torch.bool)
            position_ids = torch.arange(0, seq, dtype=torch.long)
            position_ids = position_ids.unsqueeze(0).view(-1, seq)

            export_options = torch.onnx.ExportOptions(fake_context=fake_context)
            onnx_program = torch.onnx.dynamo_export(
                model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                export_options=export_options,
            )
            onnx.checker.check_model(onnx_program.model_proto)
            onnx.shape_inference.infer_shapes(onnx_program.model_proto)

    @pytorch_test_common.xfail(
        error_message="Dynamic control flow is not supported at the moment."
    )
    def test_fake_tensor_mode_huggingface_tiiuae_falcon(self):
        config = transformers.FalconConfig()
        batch, seq = 4, 256

        with torch.onnx.enable_fake_mode() as fake_context:
            model = transformers.FalconModel(config).eval()
            input_ids = torch.randint(0, config.vocab_size, (batch, seq))
            attention_mask = torch.ones(batch, seq, dtype=torch.bool)

            export_options = torch.onnx.ExportOptions(fake_context=fake_context)
            onnx_program = torch.onnx.dynamo_export(
                model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                export_options=export_options,
            )
            onnx.checker.check_model(onnx_program.model_proto)
            onnx.shape_inference.infer_shapes(onnx_program.model_proto)

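    # The test below feeds a pre-built torch.export.ExportedProgram straight into
    # dynamo_export instead of an nn.Module; both entry points are accepted.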
    def test_exported_program_torch_distributions_normal_Normal(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                self.normal = torch.distributions.normal.Normal(0, 1)
                super().__init__()

            def forward(self, x):
                return self.normal.sample(x.shape)

        x = torch.randn(2, 3)
        with torch.no_grad():
            exported_program = torch.export.export(Model(), args=(x,))
            _ = torch.onnx.dynamo_export(
                exported_program,
                x,
            )

    def test_aten_div_no_opmath_type_promotion(self):
        class Model(torch.nn.Module):
            def forward(self, input):
                return input / 2

        model = Model()
        input = torch.randn(3, 5, requires_grad=True, dtype=torch.float16)

        model_proto = torch.onnx.dynamo_export(model, input).model_proto
        model_proto = onnx.inliner.inline_local_functions(model_proto)
        div_node = next(
            node for node in model_proto.graph.node if node.op_type == "Div"
        )
        # The input of the Div node should be the input of the model, with no
        # Cast node in between (i.e. no fp16 -> fp32 opmath upcast).
        self.assertEqual(div_node.input[0], model_proto.graph.input[0].name)

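    # The parametrization below covers the four torch float8 dtypes, which map to
    # the corresponding ONNX float8 element types (FLOAT8E5M2, FLOAT8E5M2FNUZ,
    # FLOAT8E4M3FN, FLOAT8E4M3FNUZ); shape inference in the optimizer does not yet
    # handle them, hence the optimizer warning asserted in the test.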
    @common_utils.parametrize(
        "float8_type",
        [
            common_utils.subtest(
                torch.float8_e5m2,
                name="torch_float8_e5m2",
            ),
            common_utils.subtest(
                torch.float8_e5m2fnuz,
                name="torch_float8_e5m2fnuz",
            ),
            common_utils.subtest(
                torch.float8_e4m3fn,
                name="torch_float8_e4m3fn",
            ),
            common_utils.subtest(
                torch.float8_e4m3fnuz,
                name="torch_float8_e4m3fnuz",
            ),
        ],
    )
    def test_float8_support(self, float8_type):
        class Float8Module(torch.nn.Module):
            def forward(self, input: torch.Tensor):
                input = input.to(float8_type)
                return input + torch.tensor(1.0, dtype=float8_type)

        # NOTE: a shape inference error is raised in the optimizer due to the
        # unsupported dtype, so the exporter warns and skips optimization.
        with self.assertWarnsOnceRegex(
            UserWarning, "ONNXScript optimizer failed. Skipping optimization."
        ):
            _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4))

    def test_export_with_logging_logger(self):
        logger = logging.getLogger(__name__)

        class LoggingLoggerModule(torch.nn.Module):
            def forward(self, x):
                logger.log("abc")
                return x + 1

        input = torch.randn(2, 3)
        model = LoggingLoggerModule()
        _ = torch.onnx.dynamo_export(model, input)

    def test_export_with_hf_logging_logger(self):
        logger = transformers.utils.logging.get_logger(__name__)

        class HFLoggingLoggerModule(torch.nn.Module):
            def forward(self, x):
                logger.warning_once("abc")
                return x + 1

        input = torch.randn(2, 3)
        model = HFLoggingLoggerModule()
        _ = torch.onnx.dynamo_export(model, input)

    def test_checkpoint_cast(self):
        model_id = "openai/whisper-large-v3"
        feature_extractor = transformers.WhisperFeatureExtractor(feature_size=128)
        batch = 4

        with torch.onnx.enable_fake_mode() as ctx:
            model = transformers.AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id, low_cpu_mem_usage=False, use_safetensors=False
            )
            input = {
                "input_features": torch.randn(
                    (
                        batch,
                        feature_extractor.feature_size,
                        feature_extractor.nb_max_frames,
                    )
                ),
                "decoder_input_ids": torch.tensor([[1, 1]]) * 8001,
                "return_dict": False,
            }

        export_options = torch.onnx.ExportOptions(fake_context=ctx)
        onnx_program = torch.onnx.dynamo_export(
            model, **input, export_options=export_options
        )
        with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
            onnx_program.save(tmp_onnx_file.name)
            onnx.checker.check_model(tmp_onnx_file.name, full_check=True)

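    # The matrix test below combines fake mode, ExportedProgram, and initializer
    # stripping. Sketch of the weight-attachment flow it relies on (mirroring the
    # calls in the test; `state_dict` is captured outside fake mode):
    #
    #     onnx_program = torch.onnx.dynamo_export(model, x, export_options=options)
    #     onnx_program.apply_weights(state_dict)  # bind real tensors
    #     onnx_program.save(path, include_initializers=False)  # or True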
    @common_utils.parametrize(
        "include_initializer",
        [
            common_utils.subtest(
                True,
                name="include_initializer",
            ),
            common_utils.subtest(
                False,
                name="dont_include_initializer",
            ),
        ],
    )
    @common_utils.parametrize(
        "use_fake_mode",
        [
            common_utils.subtest(
                True,
                name="use_fake_mode",
            ),
            common_utils.subtest(
                False,
                name="no_fake_mode",
            ),
        ],
    )
    @common_utils.parametrize(
        "use_exported_program",
        [
            common_utils.subtest(
                True,
                name="use_exported_program",
            ),
            common_utils.subtest(
                False,
                name="no_exported_program",
            ),
        ],
    )
    def test_save_with_without_initializer(
        self, include_initializer, use_fake_mode, use_exported_program
    ):
        class MNISTModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
                self.fc1 = nn.Linear(9216, 128, bias=False)
                self.fc2 = nn.Linear(128, 10, bias=False)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.conv2(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = F.max_pool2d(tensor_x, 2)
                tensor_x = torch.flatten(tensor_x, 1)
                tensor_x = self.fc1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.fc2(tensor_x)
                output = F.log_softmax(tensor_x, dim=1)
                return output

        state_dict = MNISTModel().state_dict()
        if use_fake_mode:
            with torch.onnx.enable_fake_mode() as ctx:
                model = MNISTModel()
                tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
                if use_exported_program:
                    model = torch.export.export(model, args=(tensor_x,))
                export_options = torch.onnx.ExportOptions(fake_context=ctx)
        else:
            model = MNISTModel()
            tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
            if use_exported_program:
                model = torch.export.export(model, args=(tensor_x,))
            export_options = torch.onnx.ExportOptions()

        onnx_program = torch.onnx.dynamo_export(
            model, tensor_x, export_options=export_options
        )
        onnx_program.apply_weights(state_dict)
        with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
            onnx_program.save(
                tmp_onnx_file.name,
                include_initializers=include_initializer,
            )
            onnx_model = onnx.load(tmp_onnx_file.name)
            # Initializers are present iff they were requested at save time.
            self.assertEqual(
                len(onnx_model.graph.initializer) > 0, include_initializer
            )

    def test_export_with_print(self):
        class PrintModule(torch.nn.Module):
            def forward(self, x):
                print("abc")
                return x + 1

        input = torch.randn(2, 3)
        model = PrintModule()
        _ = torch.onnx.dynamo_export(model, input)


if __name__ == "__main__":
    common_utils.run_tests()