# Owner(s): ["module: onnx"]
from __future__ import annotations

import logging
import tempfile
from typing import Mapping, Tuple, TYPE_CHECKING

import onnx
import onnx.inliner

import pytorch_test_common
import transformers  # type: ignore[import]

import torch
from torch import nn
from torch._subclasses import fake_tensor
from torch.nn import functional as F
from torch.onnx import dynamo_export, ExportOptions
from torch.onnx._internal.fx import diagnostics, registration
from torch.testing._internal import common_utils


if TYPE_CHECKING:
    from torch.onnx._internal.diagnostics import infra


def assert_has_diagnostics(
    diagnostic_context: diagnostics.DiagnosticContext,
    rule: infra.Rule,
    level: infra.Level,
    expected_node: str,
):
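    """Assert that the diagnostic context recorded a result matching ``rule`` and
    ``level`` whose message mentions ``expected_node``.

    Scans every run of the context's SARIF log; if no matching result is found,
    raises an AssertionError listing the actual (rule id, level) pairs.
    """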
    rule_level_pairs = (rule.id, level.name.lower())
    sarif_log = diagnostic_context.sarif_log()
    actual_results = []
    for run in sarif_log.runs:
        if run.results is None:
            continue
        for result in run.results:
            id_level_pair = (result.rule_id, result.level)
            actual_results.append(id_level_pair)
            if (
                rule_level_pairs == id_level_pair
                and result.message.text
                and result.message.markdown
                and expected_node in result.message.text
            ):
                return

    raise AssertionError(
        f"Expected a diagnostic result with rule id and level pair {rule_level_pairs} "
        f"and error node {expected_node}, but none was found. "
        f"Actual diagnostic results: {actual_results}"
    )


@common_utils.instantiate_parametrized_tests
class TestFxToOnnx(pytorch_test_common.ExportTestCase):
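    """Tests for the FX-based `dynamo_export` path that validate the exported
    proto without executing it (runtime result comparison lives in
    `test_fx_to_onnx_with_onnxruntime.py`)."""
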
    def setUp(self):
        super().setUp()
        self.export_options = ExportOptions()

    def tearDown(self):
        super().tearDown()

    def test_simple_function(self):
        def func(x):
            y = x + 1
            z = y.relu()
            return (y, z)

        _ = dynamo_export(
            func, torch.randn(1, 1, 2), export_options=self.export_options
        )

    def test_empty(self):
        # Since `torch.empty` returns a tensor with uninitialized data, we cannot
        # test it under `test_fx_to_onnx_with_onnxruntime.py` with result comparison.
        def func(x):
            return torch.empty(x.size(), dtype=torch.int64)

        tensor_x = torch.randn(1, 1, 2)
        _ = dynamo_export(func, tensor_x, export_options=self.export_options)

    def test_args_used_for_export_is_not_converted_to_fake_tensors(self):
        def func(x, y):
            return x + y

        tensor_x = torch.randn(1, 1, 2)
        tensor_y = torch.randn(1, 1, 2)
        _ = dynamo_export(func, tensor_x, tensor_y, export_options=self.export_options)
        self.assertNotIsInstance(tensor_x, fake_tensor.FakeTensor)
        self.assertNotIsInstance(tensor_y, fake_tensor.FakeTensor)

    @common_utils.parametrize(
        "diagnostic_rule",
        [
            common_utils.subtest(
                diagnostics.rules.find_opschema_matched_symbolic_function,
                name="optional_inputs",
            ),
        ],
    )
    def test_mnist_exported_with_no_warnings(self, diagnostic_rule):
        class MNISTModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
                self.fc1 = nn.Linear(9216, 128, bias=False)
                self.fc2 = nn.Linear(128, 10, bias=False)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.conv2(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = F.max_pool2d(tensor_x, 2)
                tensor_x = torch.flatten(tensor_x, 1)
                tensor_x = self.fc1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.fc2(tensor_x)
                output = F.log_softmax(tensor_x, dim=1)
                return output

        tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
        onnx_program = dynamo_export(MNISTModel(), tensor_x)

        assert_has_diagnostics(
            onnx_program.diagnostic_context,
            diagnostic_rule,
            diagnostics.levels.NONE,
            expected_node="aten.convolution.default",
        )

    def test_trace_only_op_with_evaluator(self):
        model_input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]])

        class ArgminArgmaxModel(torch.nn.Module):
            def forward(self, input):
                return (
                    torch.argmin(input),
                    torch.argmax(input),
                    torch.argmin(input, keepdim=True),
                    torch.argmax(input, keepdim=True),
                    torch.argmin(input, dim=0, keepdim=True),
                    torch.argmax(input, dim=1, keepdim=True),
                )

        _ = dynamo_export(
            ArgminArgmaxModel(), model_input, export_options=self.export_options
        )

    def test_multiple_outputs_op_with_evaluator(self):
        class TopKModel(torch.nn.Module):
            def forward(self, x):
                values, _ = torch.topk(x, 3)
                return torch.sum(values)

        x = torch.arange(1.0, 6.0, requires_grad=True)

        _ = dynamo_export(TopKModel(), x, export_options=self.export_options)

    def test_unsupported_function_schema_raises_diagnostic_warning_when_found_nearest_match(
        self,
    ):
        class TraceModel(torch.nn.Module):
            def forward(self, input):
                return input.new_zeros(())

        x = torch.randn((2, 3), dtype=torch.float32)
        onnx_program = dynamo_export(TraceModel(), x)

        assert_has_diagnostics(
            onnx_program.diagnostic_context,
            diagnostics.rules.find_opschema_matched_symbolic_function,
            diagnostics.levels.WARNING,
            expected_node="aten.new_zeros.default",
        )

    def test_perfect_match_on_sequence_and_bool_attributes(
        self,
    ):
        class TraceModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = torch.nn.Conv2d(
                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
                )

            def forward(self, input):
                return self.conv2(input)

        x = torch.randn(20, 16, 50, 50)
        onnx_program = dynamo_export(TraceModel(), x)
        assert_has_diagnostics(
            onnx_program.diagnostic_context,
            diagnostics.rules.find_opschema_matched_symbolic_function,
            diagnostics.levels.NONE,
            expected_node="aten.convolution.default",
        )

    def test_aten_clone_does_not_raise_warning_of_lack_of_memory_format(self):
        class CustomModule(torch.nn.Module):
            def forward(self, input):
                return torch.ops.aten.clone(input, memory_format=torch.preserve_format)

        x = torch.tensor(3)
        onnx_program = dynamo_export(CustomModule(), x)
        assert_has_diagnostics(
            onnx_program.diagnostic_context,
            diagnostics.rules.find_opschema_matched_symbolic_function,
            diagnostics.levels.NONE,
            expected_node="aten.clone.default",
        )

    def test_missing_complex_onnx_variant_raises_errors_in_dispatcher(self):
        registry = torch.onnx.OnnxRegistry()

        # NOTE: simulate unsupported nodes
        aten_mul_tensor = registration.OpName.from_name_parts(
            namespace="aten", op_name="mul", overload="Tensor"
        )

        # Only keep real aten.mul to test missing complex aten.mul
        registry._registry[aten_mul_tensor] = [
            onnx_func
            for onnx_func in registry._registry[aten_mul_tensor]
            if not onnx_func.is_complex
        ]

        class TraceModel(torch.nn.Module):
            def forward(self, input):
                return torch.ops.aten.mul.Tensor(input, input)

        x = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)

        with self.assertRaises(torch.onnx.OnnxExporterError):
            torch.onnx.dynamo_export(
                TraceModel(),
                x,
                export_options=torch.onnx.ExportOptions(onnx_registry=registry),
            )

    def test_symbolic_shape_of_values_inside_function_is_exported_as_graph_value_info(
        self,
    ):
        class SubModule(torch.nn.Module):
            def forward(self, x, y, bias):
                output = x @ y
                return output + bias

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submodule = SubModule()

            def forward(self, x, y, bias):
                return self.submodule(x, y, bias)

        x = torch.randn(2, 3)
        y = torch.randn(3, 4)
        bias = torch.randn(4)
        onnx_program = torch.onnx.dynamo_export(
            Module(),
            x,
            y,
            bias,
            export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
        )
        model_proto = onnx_program.model_proto

        # Assert that value_info for values inside local functions can be retrieved
        def _assert_node_outputs_has_value_info(
            node: onnx.NodeProto,
            value_infos: Mapping[str, onnx.ValueInfoProto],
            local_functions: Mapping[Tuple[str, str], onnx.FunctionProto],
            exclude_names_in_value_info,
            function_id: str = "",
        ):
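            """Recursively assert that every output of `node` has an entry in `value_infos`.

            Values produced inside a local function are recorded under the
            `<domain>::<name>/<output>` naming scheme, so `function_id` is threaded
            through the recursive calls. torchlib functions are skipped because no
            shape info is available for their internal values.
            """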
            for output in node.output:
                name = f"{function_id}/{output}" if function_id else output
                if name not in exclude_names_in_value_info:
                    self.assertIn(name, value_infos)
            if node.domain.startswith("pkg.onnxscript.torch_lib"):
                # No shape info available for values inside torchlib functions.
                return
            if (
                function := local_functions.get((node.domain, node.op_type))
            ) is not None:
                function_id = f"{function.domain}::{function.name}"
                for function_node in function.node:
                    _assert_node_outputs_has_value_info(
                        function_node,
                        value_infos,
                        local_functions,
                        exclude_names_in_value_info,
                        function_id,
                    )

        type_infos = {vi.name: vi for vi in model_proto.graph.value_info}
        functions = {(f.domain, f.name): f for f in model_proto.functions}
        # NOTE: inputs, outputs, and initializers are not included in value_info spec
        exclude_names_in_value_info = (
            [input.name for input in model_proto.graph.input]
            + [output.name for output in model_proto.graph.output]
            + [init.name for init in model_proto.graph.initializer]
        )
        for node in model_proto.graph.node:
            _assert_node_outputs_has_value_info(
                node, type_infos, functions, exclude_names_in_value_info
            )

    def test_dynamo_export_retains_readable_parameter_and_buffer_names(self):
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
                self.fc1 = nn.Linear(9216, 128, bias=False)
                self.buffer = torch.nn.Buffer(torch.randn(1, 128))

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv2(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = F.max_pool2d(tensor_x, 2)
                tensor_x = torch.flatten(tensor_x, 1)
                tensor_x = self.fc1(tensor_x)
                tensor_x = tensor_x + self.buffer
                tensor_x = F.sigmoid(tensor_x)
                return tensor_x

        class MNISTModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                self.submodule = SubModule()
                self.fc2 = nn.Linear(128, 10, bias=False)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.submodule(tensor_x)
                tensor_x = self.fc2(tensor_x)
                output = F.log_softmax(tensor_x, dim=1)
                return output

        tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)

        model = MNISTModel()
        onnx_program = torch.onnx.dynamo_export(model, tensor_x)
        model_proto = onnx_program.model_proto

        # NOTE: initializers could be optimized away by the ONNX optimizer
        onnx_initializers = {init.name for init in model_proto.graph.initializer}
        torch_weights = {*model.state_dict().keys()}
        self.assertTrue(onnx_initializers.issubset(torch_weights))

    @common_utils.parametrize(
        "checkpoint_type",
        [
            common_utils.subtest(
                "state_dict",
                name="state_dict",
            ),
            common_utils.subtest(
                "checkpoint_file",
                name="checkpoint_file",
            ),
        ],
    )
    def test_fake_tensor_mode_simple(self, checkpoint_type):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                out = self.linear(x)
                return out

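        # Under fake mode the model's weights are fake tensors, so the export
        # produces a graph with no initializers; real weights are supplied later
        # via `model_state` when saving.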
        with torch.onnx.enable_fake_mode() as fake_context:
            x = torch.rand(5, 2, 2)
            model = Model()
            export_options = ExportOptions(fake_context=fake_context)
            onnx_program = torch.onnx.dynamo_export(
                model, x, export_options=export_options
            )

        assert (
            onnx_program is not None
        ), "ONNXProgram must be created on successful export"
        assert (
            onnx_program.model_proto is not None
        ), "A model protobuf must be created on a successful export"
        onnx.checker.check_model(onnx_program.model_proto, full_check=True)
        assert (
            len(onnx_program.model_proto.graph.initializer) == 0
        ), "Initializers cannot exist when fake mode is enabled"

        if checkpoint_type == "state_dict":
            # Variant 1: Save ONNX proto using Model's state_dict()
            with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
                model_state_dict = (
                    Model().state_dict()
                )  # Create a state_dict for testing
                onnx_program.save(tmp_onnx_file.name, model_state=model_state_dict)
                assert (
                    len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
                ), "Initializers must be present after loading from the model state_dict"
                # Make sure consecutive `save` calls don't create duplicates
                onnx_program.save(tmp_onnx_file.name, model_state=model_state_dict)
                assert (
                    len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
                ), "Initializers must be present after loading from the model state_dict"
        elif checkpoint_type == "checkpoint_file":
            # Variant 2: Save ONNX proto using Model checkpoint file
            with tempfile.NamedTemporaryFile(
                suffix=".onnx"
            ) as tmp_onnx_file, tempfile.NamedTemporaryFile(
                suffix=".pt"
            ) as tmp_checkpoint_file:
                torch.save(
                    Model().state_dict(), tmp_checkpoint_file.name
                )  # Create checkpoint file for testing
                onnx_program.save(
                    tmp_onnx_file.name, model_state=tmp_checkpoint_file.name
                )
                assert (
                    len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
                ), "Initializers must be present after loading from the checkpoint file"
                # Make sure consecutive `save` calls don't create duplicates
                onnx_program.save(
                    tmp_onnx_file.name, model_state=tmp_checkpoint_file.name
                )
                assert (
                    len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
                ), "Initializers must be present after loading from the checkpoint file"

    def test_fake_tensor_mode_simple_invalid_input(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                out = self.linear(x)
                return out

        real_model = Model()
        real_x = torch.rand(5, 2, 2)
        with torch.onnx.enable_fake_mode() as fake_context:
            fake_model = Model()
            fake_x = torch.rand(5, 2, 2)

            # TODO: Split each scenario into its own test case
            # Scenario 1: Fake model and fake input WITHOUT ExportOptions(fake_context=...)
            with self.assertRaises(torch.onnx.OnnxExporterError):
                export_options = ExportOptions(fake_context=None)
                _ = torch.onnx.dynamo_export(
                    fake_model, fake_x, export_options=export_options
                )

            # Scenario 2: Fake model and real input WITHOUT fake_context
            with self.assertRaises(torch.onnx.OnnxExporterError):
                export_options = ExportOptions(fake_context=None)
                _ = torch.onnx.dynamo_export(
                    fake_model, real_x, export_options=export_options
                )

            # Scenario 3: Real model and real input WITH fake_context
            with self.assertRaises(torch.onnx.OnnxExporterError):
                export_options = ExportOptions(fake_context=fake_context)
                _ = torch.onnx.dynamo_export(
                    real_model, real_x, export_options=export_options
                )

            # Scenario 4: Fake model and real input WITH fake_context
            with self.assertRaises(torch.onnx.OnnxExporterError):
                export_options = ExportOptions(fake_context=fake_context)
                _ = torch.onnx.dynamo_export(
                    fake_model, real_x, export_options=export_options
                )

    @pytorch_test_common.xfail(
        error_message="Dynamic control flow is not supported at the moment."
    )
    def test_fake_tensor_mode_huggingface_llama(self):
        config = transformers.LlamaConfig(
            vocab_size=8096, hidden_size=256, num_hidden_layers=2, num_attention_heads=2
        )
        batch, seq = 4, 256

        with torch.onnx.enable_fake_mode() as fake_context:
            model = transformers.LlamaModel(config).eval()
            input_ids = torch.randint(0, config.vocab_size, (batch, seq))
            attention_mask = torch.ones(batch, seq, dtype=torch.bool)
            position_ids = torch.arange(0, seq, dtype=torch.long)
            position_ids = position_ids.unsqueeze(0).view(-1, seq)

            export_options = torch.onnx.ExportOptions(fake_context=fake_context)
            onnx_program = torch.onnx.dynamo_export(
                model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                export_options=export_options,
            )
            onnx.checker.check_model(onnx_program.model_proto)
            onnx.shape_inference.infer_shapes(onnx_program.model_proto)

    @pytorch_test_common.xfail(
        error_message="Dynamic control flow is not supported at the moment."
    )
    def test_fake_tensor_mode_huggingface_tiiuae_falcon(self):
        config = transformers.FalconConfig()
        batch, seq = 4, 256

        with torch.onnx.enable_fake_mode() as fake_context:
            model = transformers.FalconModel(config).eval()
            input_ids = torch.randint(0, config.vocab_size, (batch, seq))
            attention_mask = torch.ones(batch, seq, dtype=torch.bool)

            export_options = torch.onnx.ExportOptions(fake_context=fake_context)
            onnx_program = torch.onnx.dynamo_export(
                model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                export_options=export_options,
            )
            onnx.checker.check_model(onnx_program.model_proto)
            onnx.shape_inference.infer_shapes(onnx_program.model_proto)

    def test_exported_program_torch_distributions_normal_Normal(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.normal = torch.distributions.normal.Normal(0, 1)

            def forward(self, x):
                return self.normal.sample(x.shape)

        x = torch.randn(2, 3)
        with torch.no_grad():
            exported_program = torch.export.export(Model(), args=(x,))
            _ = torch.onnx.dynamo_export(
                exported_program,
                x,
            )

    def test_aten_div_no_opmath_type_promotion(self):
        class Model(torch.nn.Module):
            def forward(self, input):
                return input / 2

        model = Model()
        input = torch.randn(3, 5, requires_grad=True, dtype=torch.float16)

        model_proto = torch.onnx.dynamo_export(model, input).model_proto
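        # Inline model-local functions so the Div node is visible in the main
        # graph and its inputs can be inspected directly.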
        model_proto = onnx.inliner.inline_local_functions(model_proto)
        div_node = next(
            node for node in model_proto.graph.node if node.op_type == "Div"
        )
        # The input of the Div node should be the input of the model,
        # with no Cast node in between.
        self.assertEqual(div_node.input[0], model_proto.graph.input[0].name)

    @common_utils.parametrize(
        "float8_type",
        [
            common_utils.subtest(
                torch.float8_e5m2,
                name="torch_float8_e5m2",
            ),
            common_utils.subtest(
                torch.float8_e5m2fnuz,
                name="torch_float8_e5m2fnuz",
            ),
            common_utils.subtest(
                torch.float8_e4m3fn,
                name="torch_float8_e4m3fn",
            ),
            common_utils.subtest(
                torch.float8_e4m3fnuz,
                name="torch_float8_e4m3fnuz",
            ),
        ],
    )
    def test_float8_support(self, float8_type):
        class Float8Module(torch.nn.Module):
            def forward(self, input: torch.Tensor):
                input = input.to(float8_type)
                return input + torch.tensor(1.0, dtype=float8_type)

        # NOTE: shape inference error raised in optimizer due to unsupported dtype
        with self.assertWarnsOnceRegex(
            UserWarning, "ONNXScript optimizer failed. Skipping optimization."
        ):
            _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4))

    def test_export_with_logging_logger(self):
        logger = logging.getLogger(__name__)

        class LoggingLoggerModule(torch.nn.Module):
            def forward(self, x):
                logger.log(logging.INFO, "abc")
                return x + 1

        input = torch.randn(2, 3)
        model = LoggingLoggerModule()
        _ = torch.onnx.dynamo_export(model, input)

    def test_export_with_hf_logging_logger(self):
        logger = transformers.utils.logging.get_logger(__name__)

        class HFLoggingLoggerModule(torch.nn.Module):
            def forward(self, x):
                logger.warning_once("abc")
                return x + 1

        input = torch.randn(2, 3)
        model = HFLoggingLoggerModule()
        _ = torch.onnx.dynamo_export(model, input)

    def test_checkpoint_cast(self):
        model_id = "openai/whisper-large-v3"
        feature_extractor = transformers.WhisperFeatureExtractor(feature_size=128)
        batch = 4

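        # Fake mode lets `from_pretrained` construct this large model without
        # allocating real weight data in memory.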
        with torch.onnx.enable_fake_mode() as ctx:
            model = transformers.AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id, low_cpu_mem_usage=False, use_safetensors=False
            )
            input = {
                "input_features": torch.randn(
                    (
                        batch,
                        feature_extractor.feature_size,
                        feature_extractor.nb_max_frames,
                    )
                ),
                "decoder_input_ids": torch.tensor([[1, 1]]) * 8001,
                "return_dict": False,
            }

        export_options = torch.onnx.ExportOptions(fake_context=ctx)
        onnx_program = torch.onnx.dynamo_export(
            model, **input, export_options=export_options
        )
        with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
            onnx_program.save(tmp_onnx_file.name)
            onnx.checker.check_model(tmp_onnx_file.name, full_check=True)

    @common_utils.parametrize(
        "include_initializer",
        [
            common_utils.subtest(
                True,
                name="include_initializer",
            ),
            common_utils.subtest(
                False,
                name="dont_include_initializer",
            ),
        ],
    )
    @common_utils.parametrize(
        "use_fake_mode",
        [
            common_utils.subtest(
                True,
                name="use_fake_mode",
            ),
            common_utils.subtest(
                False,
                name="no_fake_mode",
            ),
        ],
    )
    @common_utils.parametrize(
        "use_exported_program",
        [
            common_utils.subtest(
                True,
                name="use_exported_program",
            ),
            common_utils.subtest(
                False,
                name="no_exported_program",
            ),
        ],
    )
    def test_save_with_without_initializer(
        self, include_initializer, use_fake_mode, use_exported_program
    ):
        class MNISTModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
                self.fc1 = nn.Linear(9216, 128, bias=False)
                self.fc2 = nn.Linear(128, 10, bias=False)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.conv2(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = F.max_pool2d(tensor_x, 2)
                tensor_x = torch.flatten(tensor_x, 1)
                tensor_x = self.fc1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.fc2(tensor_x)
                output = F.log_softmax(tensor_x, dim=1)
                return output

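        # Capture a real state_dict before (possibly) entering fake mode so the
        # exported program can be given real weights via `apply_weights` below.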
        state_dict = MNISTModel().state_dict()
        if use_fake_mode:
            with torch.onnx.enable_fake_mode() as ctx:
                model = MNISTModel()
                tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
                if use_exported_program:
                    model = torch.export.export(model, args=(tensor_x,))
                export_options = torch.onnx.ExportOptions(fake_context=ctx)
        else:
            model = MNISTModel()
            tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
            if use_exported_program:
                model = torch.export.export(model, args=(tensor_x,))
            export_options = torch.onnx.ExportOptions()

        onnx_program = torch.onnx.dynamo_export(
            model, tensor_x, export_options=export_options
        )
        onnx_program.apply_weights(state_dict)
        with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
            onnx_program.save(
                tmp_onnx_file.name,
                include_initializers=include_initializer,
            )
            onnx_model = onnx.load(tmp_onnx_file.name)
            if include_initializer:
                self.assertGreater(len(onnx_model.graph.initializer), 0)
            else:
                self.assertEqual(len(onnx_model.graph.initializer), 0)

    def test_export_with_print(self):
        class PrintModule(torch.nn.Module):
            def forward(self, x):
                print("abc")
                return x + 1

        input = torch.randn(2, 3)
        model = PrintModule()
        _ = torch.onnx.dynamo_export(model, input)


if __name__ == "__main__":
    common_utils.run_tests()