xref: /aosp_15_r20/external/pytorch/test/onnx/test_fx_to_onnx_with_onnxruntime.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2from __future__ import annotations
3
4import itertools
5import math
6import operator
7import os
8import tempfile
9import unittest
10from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type
11
12import onnx_test_common
13import onnxruntime  # type: ignore[import]
14import parameterized  # type: ignore[import]
15import pytorch_test_common
16import transformers  # type: ignore[import]
17
18import torch
19import torch.onnx
20from torch import nn
21from torch._subclasses import fake_tensor
22from torch.onnx._internal import _exporter_legacy
23from torch.onnx._internal.fx import (
24    diagnostics,
25    fx_symbolic_graph_extractor,
26    patcher,
27    serialization as fx_serialization,
28)
29from torch.testing._internal import common_utils
30
31
# torchvision is optional. Tests that need it are skipped when importing it
# raises ImportError or RuntimeError.
try:
    import torchvision  # type: ignore[import]
except (ImportError, RuntimeError):
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True
skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
41
42
def _parameterized_class_attrs_and_values():
    """Build the keyword arguments for `parameterized.parameterized_class`.

    Returns:
        A dict with the parameterized attribute names under "attrs" and, under
        "input_values", a list holding one ``(dynamic_shapes, model_type)``
        tuple per combination.
    """
    dynamic_shapes_options = (True, False)
    model_type_options = (pytorch_test_common.TorchModelType.TORCH_NN_MODULE,)
    return {
        "attrs": ["dynamic_shapes", "model_type"],
        "input_values": list(
            itertools.product(dynamic_shapes_options, model_type_options)
        ),
    }
55
56
57def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]):
58    """Combine class name with the parameterized arguments.
59
60    This function is passed to `parameterized.parameterized_class` as the
61    `class_name_func` argument.
62    """
63    suffixes = []
64    for k, v in input_dicts.items():
65        suffixes.append(f"{k}_{v}")
66    return f"{cls.__name__}_{'_'.join(suffixes)}"
67
68
69@parameterized.parameterized_class(
70    **_parameterized_class_attrs_and_values(),
71    class_name_func=_parameterize_class_name,
72)
73class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
74    dynamic_shapes: bool
75    model_type: pytorch_test_common.TorchModelType
76
77    def setUp(self):
78        super().setUp()
79        self.ort_version = onnxruntime.__version__
80
81    def test_simple_function(self):
82        class Foo(torch.nn.Module):
83            def forward(self, x):
84                # TODO(justinchuby): Replicate torch's type casting policy
85                # in the exporter for type promotion support
86                y = x + 1.0
87                z = y.relu()
88                return (y, z)
89
90        func = Foo()
91
92        tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)
93
94        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (tensor_x,))
95
96    @pytorch_test_common.xfail(
97        error_message="Tracing through optional input is not supported yet",
98        reason="https://github.com/pytorch/pytorch/issues/96379",
99    )
100    def test_func_with_args_and_tensor_kwargs(self):
101        # Non-tensor optional kwargs are always folded into constant and
102        # removed from input list in Dynamo-traced graph, if its value is not provided
103        # to tracer. So for a function like
104        #   def func(x, b=1.0)
105        # here. E.g., if you first Dynamo-trace the model with arguments (x,),
106        # and then call the traced graph with arguments (x, b=2.0), it will complain
107        # somewhere that model is called with extra args because the modified
108        # function is traced into
109        #   def forward(self, x : torch.Tensor):
110        #     add = x + 1.0;  x = None
111        #     relu = add.relu()
112        #     return (add, relu)
113        # To summarize, in order to be traced as graph input, the value of optional kwarg
114        # must be provided. Otherwise, they are treated as in-graph constants in Dynamo.
115        # Tensor optional kwargs are an exception. It is always traced as input.
116        # It is unclear if this behavior is intended or not. But in general it is bad
117        # practice to set mutable default values.
118        # `DynamoOptimizeExporter` applies a workaround by binding args and kwargs to
119        # model signature and fill in the default values of unprovided optional arguments.
120        class Foo(torch.nn.Module):
121            def forward(self, x, b=torch.tensor(1.0)):
122                y = x + b
123                z = y.relu()
124                return (y, z)
125
126        func = Foo()
127
128        tensor_x = torch.randn(1, 2, 3, dtype=torch.float32)
129
130        # Test without providing optional kwarg.
131        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (tensor_x,))
132        # Test with only positional args.
133        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
134            func, (tensor_x, torch.tensor(8.0))
135        )
136        # Test while specifying optional kwarg.
137        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
138            func, (tensor_x,), input_kwargs={"b": torch.tensor(5.0)}
139        )
140
141    @pytorch_test_common.skip_dynamic_fx_test(
142        "sympy operation tests don't need dynamic shape"
143    )
144    def test_sympy_operatons_return_numeric(self):
145        class Foo(torch.nn.Module):
146            def forward(self, x, y):
147                # TODO: add boolean tests when SymBool is supported
148                # to infer types
149                return (
150                    torch.tensor([operator.add(x.item(), y.item())]),
151                    torch.tensor([operator.sub(x.item(), y.item())]),
152                    torch.tensor([operator.mul(x.item(), y.item())]),
153                    torch.tensor([operator.truediv(x.item(), y.item())]),
154                    # This requires torch.sym_float, probably easy to lower to
155                    # ONNX but I don't know where to put it
156                    # torch.tensor([operator.floordiv(x.item(), y.item())]),
157                    # NB: abs so that the base and exponent are provably
158                    # non-negative, so we don't generate runtime asserts
159                    torch.tensor([operator.pow(abs(x.item()), abs(y.item()))]),
160                    torch.tensor([operator.abs(x.item())]),
161                    torch.tensor([operator.neg(x.item())]),
162                    torch.tensor([math.ceil(x.item())]),
163                    torch.tensor([math.floor(x.item())]),
164                )
165
166        func = Foo()
167
168        x = torch.randn(1, dtype=torch.float32)
169        y = torch.randn(1, dtype=torch.float32)
170        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
171            func,
172            (
173                x,
174                y,
175            ),
176        )
177
178    @pytorch_test_common.xfail(
179        error_message="Model inputs incompatible with the format that was exported",
180        reason="https://github.com/pytorch/pytorch/issues/99534",
181    )
182    def test_xfail_func_with_non_tensor_args(self):
183        class Foo(torch.nn.Module):
184            def forward(self, x, b=1.0):
185                y = x + b
186                z = y.relu()
187                return (y, z)
188
189        func = Foo()
190
191        tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)
192
193        onnx_program = torch.onnx.dynamo_export(
194            func,
195            tensor_x,
196            8.0,
197            export_options=torch.onnx.ExportOptions(
198                dynamic_shapes=self.dynamic_shapes,
199            ),
200        )
201        onnx_test_common.assert_dynamic_shapes(onnx_program, self.dynamic_shapes)
202        onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, b=8.0)
203        ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func(tensor_x, 8.0))
204        ort_outputs = onnx_test_common.run_ort(onnx_program, onnx_format_args)
205        for ref_output, ort_output in zip(ref_outputs, ort_outputs):
206            torch.testing.assert_close(ref_output, torch.tensor(ort_output))
207
208        # test on different non-tensor input - xfail
209        onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, b=9.0)
210        ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func(tensor_x, 9.0))
211        _ = onnx_test_common.run_ort(onnx_program, onnx_format_args)
212        for ref_output, ort_output in zip(ref_outputs, ort_outputs):
213            torch.testing.assert_close(ref_output, torch.tensor(ort_output))
214
215    def test_func_with_nested_input_structure(self):
216        class Foo(torch.nn.Module):
217            def forward(
218                self,
219                x_dict: Dict[str, torch.Tensor],
220                y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
221                z_list: List[List[torch.Tensor]],
222            ):
223                if "a" in x_dict:
224                    x = x_dict["a"]
225                elif "b" in x_dict:
226                    x = x_dict["b"]
227                else:
228                    x = torch.randn(3)
229
230                y1, (y2, y3) = y_tuple
231
232                z = x + y1 + y2 + y3
233                for z_sub_list in z_list:
234                    z = z + torch.stack(z_sub_list).sum()
235
236                return z
237
238        func = Foo()
239
240        x_dict = {"a": torch.randn(3), "c": torch.randn(3)}
241        y_tuple = (torch.randn(3), (torch.randn(3), torch.randn(3)))
242        z_list = [
243            [torch.randn(3), torch.randn(3)],
244            [torch.randn(3), torch.randn(3), torch.randn(3)],
245        ]
246        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
247            func, (x_dict, y_tuple, z_list)
248        )
249
250    def test_func_with_nested_output_structure(self):
251        class Foo(torch.nn.Module):
252            def forward(self, x, y, z):
253                x = x + y
254                y = y + z
255                z = x + y
256                out1 = (x, (y, z))
257                out2 = [[x, y], [y, z]]
258                out3 = {"z": z, "x": x}
259                return out1, out2, out3
260
261        func = Foo()
262
263        x = torch.randn(3)
264        y = torch.randn(3)
265        z = torch.randn(3)
266        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (x, y, z))
267
268    def test_mnist(self):
269        class MNISTModel(nn.Module):
270            def __init__(self) -> None:
271                super().__init__()
272                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=True)
273                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=True)
274                self.fc1 = nn.Linear(9216, 128, bias=True)
275                self.fc2 = nn.Linear(128, 10, bias=True)
276
277            def forward(self, tensor_x: torch.Tensor):
278                tensor_x = self.conv1(tensor_x)
279                tensor_x = torch.sigmoid(tensor_x)
280                tensor_x = self.conv2(tensor_x)
281                tensor_x = torch.sigmoid(tensor_x)
282                tensor_x = torch.max_pool2d(tensor_x, 2)
283                tensor_x = torch.flatten(tensor_x, 1)
284                tensor_x = self.fc1(tensor_x)
285                tensor_x = torch.sigmoid(tensor_x)
286                tensor_x = self.fc2(tensor_x)
287                output = torch.log_softmax(tensor_x, dim=1)
288                return output
289
290        tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
291        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
292            MNISTModel(), (tensor_x,)
293        )
294
295    def test_log_sigmoid(self):
296        # This produces op as `torch.ops.aten.log_sigmoid_forward`, instead of the more
297        # conventional `torch.ops.aten.log_sigmoid`.
298        class Model(torch.nn.Module):
299            def __init__(self) -> None:
300                super().__init__()
301                self.m = torch.nn.LogSigmoid()
302
303            def forward(self, x):
304                return self.m(x)
305
306        input = torch.randn(2)
307        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(Model(), (input,))
308
309    @skip_if_no_torchvision
310    def test_resnet18(self):
311        # TODO(bowbao): Note [training vs eval in dynamo_export]
312        # So we are effectively exporting all models in traning mode by
313        # default. But for the sake of this export we are only interested in eval mode.
314        # The question is, should we call `model.eval()` in `dynamo_export`?
315        # This particular test fails 'functionalization' in training mode.
316        # So we are explicitly calling `model.eval()` for any model that contains
317        # batch norm.
318        # Ref: https://github.com/pytorch/pytorch/issues/99662#issuecomment-1528178221
319        model = torchvision.models.resnet18(weights=None).eval()
320        dummy_input = torch.randn(1, 3, 224, 224)
321
322        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
323            model,
324            (dummy_input,),
325        )
326
327    @pytorch_test_common.xfail_dynamic_fx_test(
328        error_message="[ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input"
329    )
330    @skip_if_no_torchvision
331    def test_shufflenet_v2(self):
332        # TODO(bowbao): see Note [training vs eval in dynamo_export]
333        model = torchvision.models.shufflenet_v2_x0_5(weights=None).eval()
334        dummy_input = torch.randn(1, 3, 224, 224, requires_grad=False)
335        test_inputs = torch.randn(3, 3, 224, 224, requires_grad=False)
336
337        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
338            model,
339            (dummy_input,),
340            additional_test_inputs=[((test_inputs,),)],
341            rtol=1e-3,
342            atol=1e-5,
343        )
344
345    def test_add(self):
346        class DynamicAdd(torch.nn.Module):
347            def forward(self, x, y):
348                return torch.ops.aten.add(x, y)
349
350        x = torch.randn(2, 3)
351        y = torch.randn(2, 3)
352        another_x = torch.randn(3, 4)
353        another_y = torch.randn(3, 4)
354
355        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
356            DynamicAdd(),
357            (x, y),
358            additional_test_inputs=[((another_x, another_y),)],
359        )
360
361    def test_sigmoid_add(self):
362        class DynamicAdd(torch.nn.Module):
363            def __init__(self, *args, **kwargs) -> None:
364                super().__init__(*args, **kwargs)
365                self.sigmoid = torch.nn.Sigmoid()
366
367            def forward(self, x, y):
368                z = torch.ops.aten.add(x, y)
369                return self.sigmoid(z)
370
371        x = torch.randn(2, 3)
372        y = torch.randn(2, 3)
373        x = x[1:, :]
374        y = y[1:, :]
375        input_x = torch.randn(1, 4)
376        input_y = torch.randn(1, 4)
377
378        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
379            DynamicAdd(), (x, y), additional_test_inputs=[((input_x, input_y),)]
380        )
381
382    def test_matmul(self):
383        class DynamicMatMul(torch.nn.Module):
384            def forward(self, x, y):
385                return torch.ops.aten.matmul(x, y)
386
387        x = torch.randn(2, 3, 6)
388        y = torch.randn(2, 6, 4)
389        input_x = torch.randn(2, 3, 4)
390        input_y = torch.randn(2, 4, 4)
391
392        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
393            DynamicMatMul(), (x, y), additional_test_inputs=[((input_x, input_y),)]
394        )
395
396    @pytorch_test_common.xfail_dynamic_fx_test(
397        error_message="The values for attribute 'shape' do not match: torch.Size([]) != torch.Size([1])"
398    )
399    def test_scalar_tensor(self):
400        class test(torch.nn.Module):
401            def forward(self, x):
402                return torch.scalar_tensor(x.size(0)), torch.scalar_tensor(
403                    x.size(1), dtype=torch.int64
404                )
405
406        x = torch.randn(2, 3, 4)
407        y = torch.randn(7, 8, 9)
408        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
409            test(),
410            (x,),
411            additional_test_inputs=[((y,),)],
412        )
413
414    def test_transpose_infer_shape(self):
415        class TransposeModule(torch.nn.Module):
416            def __init__(self) -> None:
417                super().__init__()
418                self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)
419
420            def forward(self, x):
421                x = self.conv(x)
422                return x.transpose(0, 1)
423
424        x = torch.randn(32, 3, 64, 64)
425        y = torch.randn(16, 3, 8, 64)
426        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
427            TransposeModule(),
428            (x,),
429            additional_test_inputs=[((y,),)],
430        )
431
432    @pytorch_test_common.xfail_dynamic_fx_test  # no dynamic shapes present
433    def test_squeeze_runtime_dim(self):
434        class Squeeze(torch.nn.Module):
435            def forward(self, d1, d2):
436                t = torch.zeros(d1[0], d2[0])  # problematic user code for dynamo
437                return t.squeeze(0)
438
439        d1 = torch.tensor([1])
440        d3 = torch.tensor([3])
441        d4 = torch.tensor([4])
442        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
443            Squeeze(), (d1, d4), additional_test_inputs=[((d3, d4),)]
444        )
445        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
446            Squeeze(), (d3, d4), additional_test_inputs=[((d1, d3),)]
447        )
448
449    def test_slice(self):
450        class DynamicSliceExportMod(torch.nn.Module):
451            def forward(self, x):
452                results = []
453                for i in range(4):
454                    results.append(x[: x.size(0) - i, i : x.size(2), i:3])
455                return tuple(results)
456
457        x = torch.rand(5, 5, 5)
458        y = torch.randn(6, 7, 8)
459        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
460            DynamicSliceExportMod(),
461            (x,),
462            additional_test_inputs=[((y,),)],
463        )
464
465    @pytorch_test_common.xfail_if_model_type_is_exportedprogram(
466        error_message="Expected 1 outputs, got 2",
467    )
468    def test_mutation(self):
469        class MutationModel(torch.nn.Module):
470            def forward(self, x):
471                x.view(3, 2, -1).add_(2.0)
472                return x
473
474        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
475            MutationModel(), (torch.randn(12),), has_mutation=True
476        )
477
478    @unittest.skip(
479        "Fixme: arange in torchlib does not support dynamic start and end yet."
480    )
481    def test_arange(self):
482        class ArangeModel(torch.nn.Module):
483            def forward(self, input):
484                return (
485                    torch.arange(input.shape[0]),
486                    torch.arange(12),
487                    torch.arange(start=input.shape[0], end=input.shape[0] + 5),
488                )
489
490        x = torch.randn(5, 3, 2)
491        y = torch.randn(8, 3, 2)
492        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
493            ArangeModel(),
494            (x,),
495            additional_test_inputs=[((y,),)],
496        )
497
498    @pytorch_test_common.xfail_dynamic_fx_test(
499        error_message="[ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Slice node. "
500    )
501    @pytorch_test_common.xfail_if_model_type_is_exportedprogram(
502        error_message="Expected 1 outputs, got 2"
503    )
504    def test_expand_as_fill_zero(self):
505        class Model(torch.nn.Module):
506            def forward(self, x):
507                x[:, x.size(0) :] = 0
508                return x
509
510        x = torch.ones(2, 5)
511        x2 = torch.randn(3, 4)
512        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
513            Model(),
514            (x,),
515            additional_test_inputs=[((x2,),)],
516        )
517
518    @pytorch_test_common.xfail_dynamic_fx_test(
519        error_message="[ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Slice node. "
520    )
521    @pytorch_test_common.xfail_if_model_type_is_exportedprogram(
522        error_message="Expected 1 outputs, got 2"
523    )
524    def test_expand_as_fill_tensor(self):
525        class Model(torch.nn.Module):
526            def forward(self, x):
527                x[:, x.size(0) :] = torch.tensor([1, 2, 3])
528                return x
529
530        x = torch.ones(2, 5, 3)
531        x2 = torch.randn(3, 4, 3)
532        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
533            Model(),
534            (x,),
535            additional_test_inputs=[((x2,),)],
536        )
537
538    @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
539        error_message="at::functionalization::impl::isFunctionalTensor(t) INTERNAL ASSERT FAILED"
540    )
541    def test_expand_as_fill_separate_tensor(self):
542        class Model(torch.nn.Module):
543            def forward(self, x):
544                aa = torch.tensor([[0], [1], [2]])
545                return aa.expand_as(x)
546
547        x = torch.ones(3, 2)
548        x2 = torch.randn(3, 5)
549        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
550            Model(),
551            (x,),
552            additional_test_inputs=[((x2,),)],
553        )
554
555    @pytorch_test_common.skipIfNoCuda
556    def test__scaled_dot_product_flash_attention(self):
557        class Foo(torch.nn.Module):
558            def forward(self, x):
559                (
560                    output,
561                    _,
562                    _,
563                    _,
564                    _,
565                    _,
566                    _,
567                    _,
568                    _,
569                ) = torch.ops.aten._scaled_dot_product_flash_attention(x, x, x)
570                return output
571
572        func = Foo()
573
574        x = torch.randn(1, 1, 1, 32, device=torch.device("cuda"))
575        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (x,))
576
577    def test_view_dynamic_zero_dim(self):
578        class ViewModel(torch.nn.Module):
579            def forward(self, input):
580                input = input.view(-1, 2)
581                return input.view(1, -1)
582
583        x = torch.ones(2)
584        y = torch.empty(0)
585        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
586            ViewModel(),
587            (x,),
588            additional_test_inputs=[((y,),)],
589        )
590
591    def test_flatten_dynamic_axes(self):
592        class MyModule(torch.nn.Module):
593            def forward(self, x):
594                return torch.flatten(x, start_dim=2, end_dim=3)
595
596        batch_size = 3
597        x = torch.randn(batch_size, 5, 4, 5)
598        y = torch.randn(5, 5, 4, 5)
599        model = MyModule()
600        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
601            model, (x,), additional_test_inputs=[((y,),)]
602        )
603
604    def test_none_input(self):
605        class NoneInputModel(torch.nn.Module):
606            def forward(
607                self, x: torch.Tensor, y: Optional[torch.Tensor], z: torch.Tensor
608            ):
609                if y is None:
610                    return x + z
611                return x + y + z
612
613        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
614            NoneInputModel(), (torch.randn(1, 2), None, torch.randn(1, 2))
615        )
616
617    def test_operator_with_data_dependent_output(self):
618        class Foo(torch.nn.Module):
619            def forward(self, x):
620                # Repro from llama. Emits `torch.ops.aten._local_scalar_dense`.
621                return x + torch.full(x.shape, torch.tensor(torch.finfo(x.dtype).min))
622
623        func = Foo()
624
625        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
626            func, (torch.randn(3, 4),)
627        )
628
629    def test_operator_with_scalar_output(self):
630        class Foo(torch.nn.Module):
631            def forward(self, x, y):
632                return x.item() + y
633
634        func = Foo()
635
636        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
637            func, (torch.tensor([1]), torch.randn(3, 4))
638        )
639
640    def test_operator_with_dynamic_output_shape(self):
641        class Foo(torch.nn.Module):
642            def forward(self, x):
643                return x.nonzero()
644
645        func = Foo()
646
647        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
648            func, (torch.randn(3, 4),)
649        )
650
651    @pytorch_test_common.xfail_if_model_type_is_exportedprogram(
652        error_message="Trying to flatten user inputs with exported input tree spec"
653    )
654    @pytorch_test_common.xfail_dynamic_fx_test(
655        error_message="!(it.GetName().empty())",
656        reason="With after onnx==1.16, constant folding in optimizer causes this error.",
657        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
658    )
659    def test_gpt2_tiny_from_config(self):
660        # Model
661        config = transformers.GPT2Config(
662            num_hidden_layers=4,
663            vocab_size=8096,
664            hidden_size=16,
665            intermediate_size=16,
666            max_position_embeddings=512,
667            num_attention_heads=2,
668            hidden_dropout_prob=0.0,
669            attention_dropout_prob=0.0,
670        )
671        model = transformers.GPT2Model(config).eval()
672
673        def input_generator(batch: int, seq: int):
674            input_ids = torch.randint(0, 8096, (batch, seq))
675            attention_mask = torch.ones(batch, seq, dtype=torch.bool)
676            position_ids = torch.arange(0, seq, dtype=torch.long)
677            position_ids = position_ids.unsqueeze(0).view(-1, seq)
678            return input_ids, attention_mask, position_ids
679
680        # Encoded inputs
681        input_ids, attention_mask, position_ids = input_generator(2, 128)
682
683        # Another encoded inputs to test dynamic shapes
684        (
685            another_input_ids,
686            another_attention_mask,
687            another_position_ids,
688        ) = input_generator(3, 256)
689
690        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
691            model,
692            (input_ids,),
693            input_kwargs={
694                "attention_mask": attention_mask,
695                "position_ids": position_ids,
696            },
697            additional_test_inputs=[
698                (
699                    (another_input_ids,),
700                    {
701                        "attention_mask": another_attention_mask,
702                        "position_ids": another_position_ids,
703                    },
704                )
705            ],
706        )
707
708    def test_prims_device_put(self):
709        class CustomModule(nn.Module):
710            def forward(self, x):
711                # Assuming x is a tensor on the CPU, move it to the desired device using device_put()
712                x = torch.ops.prims.device_put(x, "cpu")
713                return x
714
715        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
716            CustomModule(), (torch.randn(1, 2, 3),)
717        )
718
    def _test_fx_symbolic_tracer_large_scale_exporter(
        self,
        model_name: str,
        create_model: Callable,
        create_args: Callable,
        create_pytorch_only_kwargs: Callable,
    ):
        """Test helper for large-scale exporter.

        Arguments:
            model_name: Name of the model. It is used to name temporary files.
            create_model: A function that creates a model. It should always create the same model.
            create_args: A function that creates random input arguments for the model.
            create_pytorch_only_kwargs: A function that creates kwargs for calling PyTorch model with real tensors.

        This test contains several steps.

        1. Create a toy model.
        2. Save the toy's state (parameters) to a file. This is for simulating a checkpoint file.
        3. Load it back and export it to ONNX with large-scale exporter.
            All operations (including model loading) are done under
            FakeTensorMode so no real tensor is created and no real
            computation happens.
        4. The ONNX model generated in step 3 doesn't contain parameters,
            and this step adds them as external data and saves a new ONNX model.
        5. Run PyTorch and ONNX models and compare their results.
        """

        # Create the toy model.
        model = create_model()

        # Both the checkpoint file and the output folder are temporary and are
        # cleaned up when this `with` block exits.
        with tempfile.NamedTemporaryFile(
            prefix=model_name, suffix=".pt"
        ) as tmp_file, tempfile.TemporaryDirectory(
            suffix="large_scale_export"
        ) as tmp_folder:
            # Dump state_dict to a file to simulate how HuggingFace model is initialized.
            # The file will be loaded via .load_state_dict(...)
            torch.save(model.state_dict(), tmp_file.name)

            ftm = fake_tensor.FakeTensorMode(
                allow_non_fake_inputs=True, allow_fallback_kernels=False
            )
            ctx = patcher.ONNXTorchPatcher()
            # NOTE: FakeTensorMode disallows symbolic shape of fx graph
            # The following code block does several things.
            #  1. Create a model whose parameters and buffers are all FakeTensor's.
            #  2. Convert nn.Module into ONNX model without initializers.
            #  3. Record the file paths to find real initializers.
            with ctx, ftm:
                # Toy model with parameters and buffers as FakeTensor's.
                fake_model = create_model()
                fake_model.load_state_dict(torch.load(tmp_file.name))
                # Toy inputs as FakeTensor's.
                fake_args = create_args()
                # Export ONNX model without initializers while ctx.paths records
                # all files that contains real initializers.

                options = torch.onnx.ExportOptions(
                    dynamic_shapes=self.dynamic_shapes,
                )
                # Resolve the options first so the FX tracer can be replaced
                # before exporting.
                export_options = _exporter_legacy.ResolvedExportOptions(options)
                export_options.fx_tracer = (
                    fx_symbolic_graph_extractor.FXSymbolicTracer()
                )
                onnx_program = torch.onnx.dynamo_export(
                    fake_model,
                    *fake_args,
                    export_options=export_options,
                )
                onnx_model = onnx_program.model_proto

            onnx_test_common.assert_dynamic_shapes(onnx_program, self.dynamic_shapes)

            # Tasks done by the following block.
            #  1. Iterate through all tensors stored in ctx.paths (the file content is loaded torch.load)
            #  2. If a tensor's name matches a "onnx_model"'s input name, an initializer is created and saved to
            #     a separate folder.
            #  3. A new ONNX model is saved into file with the initializers saved in the previous step.
            #  4. ORT executes the new ONNX model and compares the results with the original GPT model.

            # Model saved to tmp_folder/onnx_model_location
            # Initializers are saved to tmp_folder/onnx_initializer_location/*.onnx
            onnx_model_location = model_name + "_external_data.onnx"
            onnx_initializer_location = model_name + "_initializers"
            # TODO: We are using the internal `save_model_with_external_data` instead of public
            # `ONNXProgram.save` because we need to rename ONNX initializers before saving.
            # This is only needed/allowed because we are using `fx_tracer=FXSymbolicTracer`,
            # which is not an official FX tracer.
            fx_serialization.save_model_with_external_data(
                tmp_folder,
                onnx_model_location,
                onnx_initializer_location,
                tuple(ctx.paths),
                onnx_model,
                rename_initializer=True,
            )
            # Generate random inputs.
            args = create_args()
            kwargs = create_pytorch_only_kwargs()
            # Original outputs.
            ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
                model(*args, **kwargs)
            )
            # ORT outputs.
            args_not_none = onnx_program.adapt_torch_inputs_to_onnx(*args)

            # Drop Parameters and buffers added by fx_serialization.save_model_with_external_data
            # NOTE(review): the slice bound is `len(args) - len(kwargs)`; confirm
            # this is the intended number of ONNX graph inputs to keep.
            args_not_none = args_not_none[: len(args) - len(kwargs)]

            # The ONNX model is run from the file saved in tmp_folder above.
            ort_outputs = onnx_test_common.run_ort(
                os.path.join(tmp_folder, onnx_model_location),
                args_not_none,
            )

            # Check the lengths first: zip() below would silently stop at the
            # shorter of the two sequences.
            assert len(ref_outputs) == len(ort_outputs)

            for ref_output, ort_output in zip(ref_outputs, ort_outputs):
                torch.testing.assert_close(ref_output, torch.tensor(ort_output))
838
839    @pytorch_test_common.xfail_dynamic_fx_test(
840        error_message="shape_env should be set if tracing with 'symbolic'"
841    )
842    def test_fx_symbolic_tracer_large_scale_exporter_with_toy_mlp(self):
843        class MLPModel(nn.Module):
844            def __init__(self) -> None:
845                super().__init__()
846                self.fc0 = nn.Linear(8, 8, bias=True)
847                self.fc1 = nn.Linear(8, 4, bias=True)
848                self.fc2 = nn.Linear(4, 2, bias=True)
849                self.fc3 = nn.Linear(2, 2, bias=True)
850
851            def forward(self, tensor_x: torch.Tensor):
852                tensor_x = self.fc0(tensor_x)
853                tensor_x = torch.sigmoid(tensor_x)
854                tensor_x = self.fc1(tensor_x)
855                tensor_x = torch.sigmoid(tensor_x)
856                tensor_x = self.fc2(tensor_x)
857                tensor_x = torch.sigmoid(tensor_x)
858                output = self.fc3(tensor_x)
859                return output
860
861        def create_model() -> nn.Module:
862            return MLPModel()
863
864        def create_args():
865            return (torch.rand((97, 8), dtype=torch.float32),)
866
867        def create_pytorch_only_extra_kwargs():
868            return {}
869
870        self._test_fx_symbolic_tracer_large_scale_exporter(
871            "toy_mlp1",
872            create_model,
873            create_args,
874            create_pytorch_only_extra_kwargs,
875        )
876
877    @pytorch_test_common.xfail_dynamic_fx_test(
878        error_message="shape_env should be set if tracing with 'symbolic'"
879    )
880    def test_fx_symbolic_tracer_large_scale_exporter_with_tiny_gpt2(self):
881        model_name = "sshleifer/tiny-gpt2"
882        device = "cpu"
883
884        def create_model() -> nn.Module:
885            return transformers.AutoModel.from_pretrained(model_name).to(device).eval()
886
887        def create_args():
888            tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
889            kwargs = tokenizer("Hello world!", return_tensors="pt")
890            input_ids = kwargs["input_ids"]
891            attention_mask = kwargs["attention_mask"]
892            return input_ids, None, attention_mask
893
894        def create_pytorch_only_extra_kwargs():
895            return {"return_dict": False}
896
897        self._test_fx_symbolic_tracer_large_scale_exporter(
898            "tiny_gpt2",
899            create_model,
900            create_args,
901            create_pytorch_only_extra_kwargs,
902        )
903
904
def _parameterized_class_attrs_and_values_with_fake_options():
    """Build the keyword arguments for ``parameterized.parameterized_class``.

    Returns a dict holding the attribute names under ``"attrs"`` and, under
    ``"input_values"``, one tuple per combination of the three boolean flags
    paired with ``TorchModelType.TORCH_NN_MODULE``.
    """
    on_off = (True, False)
    model_types = (pytorch_test_common.TorchModelType.TORCH_NN_MODULE,)
    attr_names = [
        "dynamic_shapes",
        "load_checkpoint_during_init",
        "export_within_fake_mode",
        "model_type",
    ]
    # Cartesian product in the same positional order as ``attr_names``.
    combinations = list(itertools.product(on_off, on_off, on_off, model_types))
    return {
        "attrs": attr_names,
        "input_values": combinations,
    }
924
925
@parameterized.parameterized_class(
    **_parameterized_class_attrs_and_values_with_fake_options(),
    class_name_func=_parameterize_class_name,
)
class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
    """ONNX export tests for specific Fake Tensor scenarios.

    Every test builds model/argument factories and hands them, together with
    the parameterized class attributes declared below, to
    ``_test_fake_tensor_mode_exporter``.

    TODO: Should we merge this with `TestFxToOnnxWithOnnxRuntime`? Considerably increases export time
    """

    # Attributes injected by ``parameterized.parameterized_class``.
    dynamic_shapes: bool
    load_checkpoint_during_init: bool
    export_within_fake_mode: bool
    model_type: pytorch_test_common.TorchModelType

    def setUp(self):
        """Record the installed onnxruntime version on the test instance."""
        super().setUp()
        self.ort_version = onnxruntime.__version__

    def _assert_outputs_close(self, ref_outputs, ort_outputs):
        """Check that ``ort_outputs`` match ``ref_outputs`` pairwise.

        The length check uses ``assertEqual`` rather than a bare ``assert`` so
        it still runs under ``python -O``; otherwise ``zip`` would silently
        ignore surplus outputs on either side.
        """
        self.assertEqual(len(ref_outputs), len(ort_outputs))
        for ref_output, ort_output in zip(ref_outputs, ort_outputs):
            torch.testing.assert_close(ref_output, torch.tensor(ort_output))

    def _test_fake_tensor_mode_exporter(
        self,
        model_name: str,
        create_model: Callable,
        create_args: Callable,
        create_kwargs: Callable,
        load_checkpoint_during_init: bool,
        export_within_fake_mode: bool,
        model_type: pytorch_test_common.TorchModelType,
    ):
        """Test helper for FakeTensorMode-enabled exporter.

        Arguments:
            model_name: Name of the model. It is used as the prefix of the
                temporary checkpoint file.
            create_model: A function that creates a model.
            create_args: A function that creates positional inputs for the model.
            create_kwargs: A function that creates keyword inputs for the model.
            load_checkpoint_during_init: Whether to load the checkpoint into the
                fake model (after model creation, but before exporting starts).
            export_within_fake_mode: Whether to call torch.onnx.dynamo_export
                inside the torch.onnx.enable_fake_mode() context (True) or
                right after leaving it (False).
            model_type: Type of user model. This helper does not branch on it;
                the value is only written into the diagnostics file name.

        This test contains several steps.

        1. Create a toy model.
        2. Save the toy's state (parameters) to a file. This is for simulating a checkpoint file.
        3. Create the model and its inputs again under fake mode, optionally
            load the checkpoint, and export to ONNX.
        4. Save the ONNX program to a temporary file, passing the checkpoint
            file as ``model_state``.
        5. Run PyTorch and ONNX models and compare their results, both through
            ``onnx_test_common.run_ort`` and through ``ONNXProgram.__call__``.
        """

        # Create the toy model with real weight.
        real_model = create_model()
        state_dict = real_model.state_dict()  # concrete (non-fake) state_dict

        with tempfile.NamedTemporaryFile(
            prefix=model_name, suffix=".pt"
        ) as tmp_checkpoint_file:
            # Dump state_dict to a file to simulate how HuggingFace model is initialized.
            # The file will be loaded via .load_state_dict(...)
            torch.save(state_dict, tmp_checkpoint_file.name)

            with torch.onnx.enable_fake_mode() as fake_context:
                fake_args = create_args()
                fake_kwargs = create_kwargs()
                fake_model = create_model()
                if load_checkpoint_during_init:
                    fake_model.load_state_dict(torch.load(tmp_checkpoint_file.name))

                # Export the model with fake inputs and parameters
                export_options = torch.onnx.ExportOptions(
                    dynamic_shapes=self.dynamic_shapes,
                    fake_context=fake_context,
                )

                if export_within_fake_mode:
                    onnx_program = torch.onnx.dynamo_export(
                        fake_model,
                        *fake_args,
                        **fake_kwargs,
                        export_options=export_options,
                    )

            # Same export call, issued only after the fake-mode context exited.
            if not export_within_fake_mode:
                onnx_program = torch.onnx.dynamo_export(
                    fake_model, *fake_args, **fake_kwargs, export_options=export_options
                )

            onnx_test_common.assert_dynamic_shapes(onnx_program, self.dynamic_shapes)

            if diagnostics.is_onnx_diagnostics_log_artifact_enabled():
                # Each name segment starts with "_" so the fields stay separable;
                # the values are the ones this call actually ran with.
                onnx_program.save_diagnostics(
                    f"test_report_{self._testMethodName}"
                    f"_dynamic_axes_{self.dynamic_shapes}"
                    f"_load_checkpoint_{load_checkpoint_during_init}"
                    f"_export_within_fake_mode_{export_within_fake_mode}"
                    f"_model_type_{model_type}"
                    ".sarif"
                )

            with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
                onnx_program.save(
                    tmp_onnx_file.name, model_state=tmp_checkpoint_file.name
                )

                # Generate random inputs.
                args = create_args()
                kwargs = create_kwargs()
                # Original outputs.
                # model_with_state_dict=real_model is used to create non-fake weights
                if isinstance(real_model, torch.export.ExportedProgram):
                    outputs = real_model.module()(*args, **kwargs)
                else:
                    outputs = real_model(*args, **kwargs)
                ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
                    outputs, model_with_state_dict=real_model
                )
                # ORT outputs.
                # model_with_state_dict=real_model is used to create non-fake weights
                args_not_none = onnx_program.adapt_torch_inputs_to_onnx(
                    *args, model_with_state_dict=real_model, **kwargs
                )

                ort_outputs = onnx_test_common.run_ort(
                    tmp_onnx_file.name,
                    args_not_none,
                )
                self._assert_outputs_close(ref_outputs, ort_outputs)

                # Test ONNXProgram.__call__ interface
                ort_outputs = onnx_program(
                    *args, model_with_state_dict=real_model, **kwargs
                )
                self._assert_outputs_close(ref_outputs, ort_outputs)

    def test_fake_tensor_mode_simple(self):
        """Export a module holding a single ``Linear(2, 2)`` layer."""

        def create_model() -> nn.Module:
            class Model(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = torch.nn.Linear(2, 2)

                def forward(self, x):
                    out = self.linear(x)
                    return out

            return Model()

        def create_args():
            return (torch.rand(5, 2, 2),)

        def create_kwargs():
            return {}

        self._test_fake_tensor_mode_exporter(
            "simple",
            create_model,
            create_args,
            create_kwargs,
            load_checkpoint_during_init=self.load_checkpoint_during_init,
            export_within_fake_mode=self.export_within_fake_mode,
            model_type=self.model_type,
        )

    @pytorch_test_common.xfail_dynamic_fx_test(
        error_message="!(it.GetName().empty())",
        reason="With after onnx==1.16, constant folding in optimizer causes this error.",
        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
    )
    @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
        error_message="Expected 4 inputs, got 2",
        reason="https://github.com/pytorch/pytorch/issues/115745",
    )
    def test_fake_tensor_mode_huggingface_tiny_gpt2(self):
        """Export the pretrained ``sshleifer/tiny-gpt2`` model on CPU."""
        model_name = "sshleifer/tiny-gpt2"
        device = "cpu"

        def create_model() -> nn.Module:
            return transformers.AutoModel.from_pretrained(model_name).to(device).eval()

        def create_args():
            tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
            kwargs = tokenizer("Hello world!", return_tensors="pt")
            input_ids = kwargs["input_ids"]
            attention_mask = kwargs["attention_mask"]
            # The second positional slot is deliberately passed as None.
            return input_ids, None, attention_mask

        def create_kwargs():
            return {"return_dict": False}

        self._test_fake_tensor_mode_exporter(
            "tiny_gpt2",
            create_model,
            create_args,
            create_kwargs,
            load_checkpoint_during_init=self.load_checkpoint_during_init,
            export_within_fake_mode=self.export_within_fake_mode,
            model_type=self.model_type,
        )

    def test_large_scale_exporter_with_toy_mlp(self):
        """Export a toy MLP made of four ``Linear`` layers with sigmoids between them."""

        class MLPModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc0 = nn.Linear(8, 8, bias=True)
                self.fc1 = nn.Linear(8, 4, bias=True)
                self.fc2 = nn.Linear(4, 2, bias=True)
                self.fc3 = nn.Linear(2, 2, bias=True)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.fc0(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                tensor_x = self.fc1(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                tensor_x = self.fc2(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                output = self.fc3(tensor_x)
                return output

        def create_model() -> nn.Module:
            return MLPModel()

        def create_args():
            return (torch.rand((97, 8), dtype=torch.float32),)

        def create_kwargs():
            return {}

        self._test_fake_tensor_mode_exporter(
            "toy_mlp1",
            create_model,
            create_args,
            create_kwargs,
            load_checkpoint_during_init=self.load_checkpoint_during_init,
            export_within_fake_mode=self.export_within_fake_mode,
            model_type=self.model_type,
        )

    def test_fake_tensor_mode_huggingface_google_t5(self):
        """Export a small, randomly initialized ``transformers.T5Model``."""
        config = transformers.T5Config(
            vocab_size=8096, d_model=64, num_layers=2, num_heads=2
        )
        batch, seq = 4, 256

        def create_args():
            return ()

        def create_kwargs():
            input_ids = torch.randint(0, config.vocab_size, (batch, seq))
            attention_mask = torch.ones((batch, seq), dtype=torch.bool)
            decoder_input_ids = torch.randint(0, config.vocab_size, (batch, seq))
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            }

        def create_model():
            return transformers.T5Model(config).eval()

        self._test_fake_tensor_mode_exporter(
            "huggingface_google_t5",
            create_model,
            create_args,
            create_kwargs,
            load_checkpoint_during_init=self.load_checkpoint_during_init,
            export_within_fake_mode=self.export_within_fake_mode,
            model_type=self.model_type,
        )

    @pytorch_test_common.xfail_dynamic_fx_test(
        error_message="scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool",
        reason="Dynamo error: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool",
        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
    )
    @pytorch_test_common.xfail(
        error_message="Could not find an implementation for Trilu(14) node",
        reason="ORT error during op level dubug",
    )
    def test_fake_tensor_mode_huggingface_openai_whisper(self):
        """Export a small Whisper model built from a ``WhisperConfig``."""
        config = transformers.WhisperConfig(
            vocab_size=8096,
            num_mel_bins=40,
            encoder_layers=2,
            encoder_attention_heads=2,
            decoder_layers=2,
            decoder_attention_heads=2,
            decoder_ffn_dim=384,
            encoder_ffn_dim=384,
            d_model=64,
            decoder_start_token_id=8001,
            pad_token_id=8000,
            bos_token_id=8000,
            eos_token_id=8000,
            begin_suppress_tokens=[220, 8000],
        )
        feature_extractor = transformers.WhisperFeatureExtractor(feature_size=40)
        device = "cpu"
        batch = 4

        def create_model() -> nn.Module:
            return transformers.AutoModel.from_config(config).to(device).eval()

        def create_args():
            return ()

        def create_kwargs():
            input_features = torch.randn(
                (
                    batch,
                    feature_extractor.feature_size,
                    feature_extractor.nb_max_frames,
                ),
                dtype=torch.float32,
            )
            # NOTE(review): decoder_input_ids has a leading dimension of 1 while
            # input_features uses `batch` (4) — confirm this is intended.
            decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id
            return {
                "input_features": input_features,
                "decoder_input_ids": decoder_input_ids,
                "return_dict": False,
            }

        self._test_fake_tensor_mode_exporter(
            "openai_whisper",
            create_model,
            create_args,
            create_kwargs,
            load_checkpoint_during_init=self.load_checkpoint_during_init,
            export_within_fake_mode=self.export_within_fake_mode,
            model_type=self.model_type,
        )

    def test_fake_tensor_mode_huggingface_mosaicml_mpt(self):
        """Export a small, randomly initialized ``transformers.MptModel``."""
        config = transformers.MptConfig(
            vocab_size=8096, d_model=64, n_heads=2, n_layers=3
        )
        batch, seq = 4, 256

        def create_args():
            return ()

        def create_kwargs():
            input_ids = torch.randint(0, config.vocab_size, (batch, seq))
            attention_mask = torch.ones(batch, seq, dtype=torch.bool)
            return {"input_ids": input_ids, "attention_mask": attention_mask}

        def create_model():
            return transformers.MptModel(config).eval()

        self._test_fake_tensor_mode_exporter(
            "huggingface_mosaicml_mpt",
            create_model,
            create_args,
            create_kwargs,
            load_checkpoint_during_init=self.load_checkpoint_during_init,
            export_within_fake_mode=self.export_within_fake_mode,
            model_type=self.model_type,
        )

    @pytorch_test_common.xfail_dynamic_fx_test(
        error_message="SymIntArrayRef expected to contain only concrete integers",
        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
    )
    def test_fake_tensor_mode_huggingface_bigscience_bloom_560m(self):
        """Export a ``transformers.BloomModel`` built from the default ``BloomConfig``."""
        config = transformers.BloomConfig()
        batch, seq = 4, 256

        def create_args():
            return ()

        def create_kwargs():
            input_ids = torch.randint(0, config.vocab_size, (batch, seq))
            attention_mask = torch.ones(batch, seq, dtype=torch.bool)
            return {"input_ids": input_ids, "attention_mask": attention_mask}

        def create_model():
            return transformers.BloomModel(config).eval()

        self._test_fake_tensor_mode_exporter(
            "huggingface_bigscience_bloom_560m",
            create_model,
            create_args,
            create_kwargs,
            load_checkpoint_during_init=self.load_checkpoint_during_init,
            export_within_fake_mode=self.export_within_fake_mode,
            model_type=self.model_type,
        )

    @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
        error_message="Expected 5 inputs, got 3",
        reason="https://github.com/pytorch/pytorch/issues/115745",
    )
    def test_fake_tensor_mode_huggingface_gpt2(self):
        """Export a small, randomly initialized ``transformers.GPT2Model``."""
        config = transformers.GPT2Config(
            vocab_size=8096, n_positions=256, n_embd=256, n_layer=2, n_head=2
        )

        def create_model():
            return transformers.GPT2Model(config).eval()

        def create_args():
            return ()

        def create_kwargs():
            batch, seq = 4, 256

            input_ids = torch.randint(0, config.vocab_size, (batch, seq))
            attention_mask = torch.ones(batch, seq, dtype=torch.bool)
            position_ids = torch.arange(0, seq, dtype=torch.long)
            position_ids = position_ids.unsqueeze(0).view(-1, seq)

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            }

        self._test_fake_tensor_mode_exporter(
            "huggingface_gpt2",
            create_model,
            create_args,
            create_kwargs,
            load_checkpoint_during_init=self.load_checkpoint_during_init,
            export_within_fake_mode=self.export_within_fake_mode,
            model_type=self.model_type,
        )

    @pytorch_test_common.xfail_dynamic_fx_test(
        error_message="SymIntArrayRef expected to contain only concrete integers",
        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
    )
    @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
        error_message="Expected 9 inputs, got 3",
        reason="https://github.com/pytorch/pytorch/issues/115745",
    )
    def test_fake_tensor_mode_huggingface_databricks_dolly_v2_3b(self):
        """Export a small, randomly initialized ``transformers.GPTNeoXModel``."""
        config = transformers.GPTNeoXConfig(
            vocab_size=8096, hidden_size=256, num_hidden_layers=2, num_attention_heads=2
        )
        batch, seq = 4, 256

        def create_model():
            return transformers.GPTNeoXModel(config).eval()

        def create_args():
            return ()

        def create_kwargs():
            input_ids = torch.randint(0, config.vocab_size, (batch, seq))
            attention_mask = torch.ones(batch, seq, dtype=torch.bool)
            position_ids = torch.arange(0, seq, dtype=torch.long)
            position_ids = position_ids.unsqueeze(0).view(-1, seq)

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            }

        self._test_fake_tensor_mode_exporter(
            "huggingface_databricks_dolly_v2_3b",
            create_model,
            create_args,
            create_kwargs,
            load_checkpoint_during_init=self.load_checkpoint_during_init,
            export_within_fake_mode=self.export_within_fake_mode,
            model_type=self.model_type,
        )
1402
1403
# Script entry point: when this file is executed directly, delegate to
# common_utils.run_tests().
if __name__ == "__main__":
    common_utils.run_tests()
1406