1# Owner(s): ["module: onnx"] 2from __future__ import annotations 3 4import itertools 5import math 6import operator 7import os 8import tempfile 9import unittest 10from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type 11 12import onnx_test_common 13import onnxruntime # type: ignore[import] 14import parameterized # type: ignore[import] 15import pytorch_test_common 16import transformers # type: ignore[import] 17 18import torch 19import torch.onnx 20from torch import nn 21from torch._subclasses import fake_tensor 22from torch.onnx._internal import _exporter_legacy 23from torch.onnx._internal.fx import ( 24 diagnostics, 25 fx_symbolic_graph_extractor, 26 patcher, 27 serialization as fx_serialization, 28) 29from torch.testing._internal import common_utils 30 31 32try: 33 import torchvision # type: ignore[import] 34 35 HAS_TORCHVISION = True 36except ImportError: 37 HAS_TORCHVISION = False 38except RuntimeError: 39 HAS_TORCHVISION = False 40skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") 41 42 43def _parameterized_class_attrs_and_values(): 44 input_values = [] 45 input_values.extend( 46 itertools.product( 47 (True, False), 48 (pytorch_test_common.TorchModelType.TORCH_NN_MODULE,), 49 ) 50 ) 51 return { 52 "attrs": ["dynamic_shapes", "model_type"], 53 "input_values": input_values, 54 } 55 56 57def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]): 58 """Combine class name with the parameterized arguments. 59 60 This function is passed to `parameterized.parameterized_class` as the 61 `class_name_func` argument. 62 """ 63 suffixes = [] 64 for k, v in input_dicts.items(): 65 suffixes.append(f"{k}_{v}") 66 return f"{cls.__name__}_{'_'.join(suffixes)}" 67 68 69@parameterized.parameterized_class( 70 **_parameterized_class_attrs_and_values(), 71 class_name_func=_parameterize_class_name, 72) 73class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime): 74 dynamic_shapes: bool 75 model_type: pytorch_test_common.TorchModelType 76 77 def setUp(self): 78 super().setUp() 79 self.ort_version = onnxruntime.__version__ 80 81 def test_simple_function(self): 82 class Foo(torch.nn.Module): 83 def forward(self, x): 84 # TODO(justinchuby): Replicate torch's type casting policy 85 # in the exporter for type promotion support 86 y = x + 1.0 87 z = y.relu() 88 return (y, z) 89 90 func = Foo() 91 92 tensor_x = torch.randn(1, 1, 2, dtype=torch.float32) 93 94 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (tensor_x,)) 95 96 @pytorch_test_common.xfail( 97 error_message="Tracing through optional input is not supported yet", 98 reason="https://github.com/pytorch/pytorch/issues/96379", 99 ) 100 def test_func_with_args_and_tensor_kwargs(self): 101 # Non-tensor optional kwargs are always folded into constant and 102 # removed from input list in Dynamo-traced graph, if its value is not provided 103 # to tracer. So for a function like 104 # def func(x, b=1.0) 105 # here. E.g., if you first Dynamo-trace the model with arguments (x,), 106 # and then call the traced graph with arguments (x, b=2.0), it will complain 107 # somewhere that model is called with extra args because the modified 108 # function is traced into 109 # def forward(self, x : torch.Tensor): 110 # add = x + 1.0; x = None 111 # relu = add.relu() 112 # return (add, relu) 113 # To summarize, in order to be traced as graph input, the value of optional kwarg 114 # must be provided. Otherwise, they are treated as in-graph constants in Dynamo. 115 # Tensor optional kwargs are an exception. It is always traced as input. 116 # It is unclear if this behavior is intended or not. But in general it is bad 117 # practice to set mutable default values. 118 # `DynamoOptimizeExporter` applies a workaround by binding args and kwargs to 119 # model signature and fill in the default values of unprovided optional arguments. 120 class Foo(torch.nn.Module): 121 def forward(self, x, b=torch.tensor(1.0)): 122 y = x + b 123 z = y.relu() 124 return (y, z) 125 126 func = Foo() 127 128 tensor_x = torch.randn(1, 2, 3, dtype=torch.float32) 129 130 # Test without providing optional kwarg. 131 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (tensor_x,)) 132 # Test with only positional args. 133 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 134 func, (tensor_x, torch.tensor(8.0)) 135 ) 136 # Test while specifying optional kwarg. 137 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 138 func, (tensor_x,), input_kwargs={"b": torch.tensor(5.0)} 139 ) 140 141 @pytorch_test_common.skip_dynamic_fx_test( 142 "sympy operation tests don't need dynamic shape" 143 ) 144 def test_sympy_operatons_return_numeric(self): 145 class Foo(torch.nn.Module): 146 def forward(self, x, y): 147 # TODO: add boolean tests when SymBool is supported 148 # to infer types 149 return ( 150 torch.tensor([operator.add(x.item(), y.item())]), 151 torch.tensor([operator.sub(x.item(), y.item())]), 152 torch.tensor([operator.mul(x.item(), y.item())]), 153 torch.tensor([operator.truediv(x.item(), y.item())]), 154 # This requires torch.sym_float, probably easy to lower to 155 # ONNX but I don't know where to put it 156 # torch.tensor([operator.floordiv(x.item(), y.item())]), 157 # NB: abs so that the base and exponent are provably 158 # non-negative, so we don't generate runtime asserts 159 torch.tensor([operator.pow(abs(x.item()), abs(y.item()))]), 160 torch.tensor([operator.abs(x.item())]), 161 torch.tensor([operator.neg(x.item())]), 162 torch.tensor([math.ceil(x.item())]), 163 torch.tensor([math.floor(x.item())]), 164 ) 165 166 func = Foo() 167 168 x = torch.randn(1, dtype=torch.float32) 169 y = torch.randn(1, dtype=torch.float32) 170 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 171 func, 172 ( 173 x, 174 y, 175 ), 176 ) 177 178 @pytorch_test_common.xfail( 179 error_message="Model inputs incompatible with the format that was exported", 180 reason="https://github.com/pytorch/pytorch/issues/99534", 181 ) 182 def test_xfail_func_with_non_tensor_args(self): 183 class Foo(torch.nn.Module): 184 def forward(self, x, b=1.0): 185 y = x + b 186 z = y.relu() 187 return (y, z) 188 189 func = Foo() 190 191 tensor_x = torch.randn(1, 1, 2, dtype=torch.float32) 192 193 onnx_program = torch.onnx.dynamo_export( 194 func, 195 tensor_x, 196 8.0, 197 export_options=torch.onnx.ExportOptions( 198 dynamic_shapes=self.dynamic_shapes, 199 ), 200 ) 201 onnx_test_common.assert_dynamic_shapes(onnx_program, self.dynamic_shapes) 202 onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, b=8.0) 203 ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func(tensor_x, 8.0)) 204 ort_outputs = onnx_test_common.run_ort(onnx_program, onnx_format_args) 205 for ref_output, ort_output in zip(ref_outputs, ort_outputs): 206 torch.testing.assert_close(ref_output, torch.tensor(ort_output)) 207 208 # test on different non-tensor input - xfail 209 onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, b=9.0) 210 ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func(tensor_x, 9.0)) 211 _ = onnx_test_common.run_ort(onnx_program, onnx_format_args) 212 for ref_output, ort_output in zip(ref_outputs, ort_outputs): 213 torch.testing.assert_close(ref_output, torch.tensor(ort_output)) 214 215 def test_func_with_nested_input_structure(self): 216 class Foo(torch.nn.Module): 217 def forward( 218 self, 219 x_dict: Dict[str, torch.Tensor], 220 y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], 221 z_list: List[List[torch.Tensor]], 222 ): 223 if "a" in x_dict: 224 x = x_dict["a"] 225 elif "b" in x_dict: 226 x = x_dict["b"] 227 else: 228 x = torch.randn(3) 229 230 y1, (y2, y3) = y_tuple 231 232 z = x + y1 + y2 + y3 233 for z_sub_list in z_list: 234 z = z + torch.stack(z_sub_list).sum() 235 236 return z 237 238 func = Foo() 239 240 x_dict = {"a": torch.randn(3), "c": torch.randn(3)} 241 y_tuple = (torch.randn(3), (torch.randn(3), torch.randn(3))) 242 z_list = [ 243 [torch.randn(3), torch.randn(3)], 244 [torch.randn(3), torch.randn(3), torch.randn(3)], 245 ] 246 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 247 func, (x_dict, y_tuple, z_list) 248 ) 249 250 def test_func_with_nested_output_structure(self): 251 class Foo(torch.nn.Module): 252 def forward(self, x, y, z): 253 x = x + y 254 y = y + z 255 z = x + y 256 out1 = (x, (y, z)) 257 out2 = [[x, y], [y, z]] 258 out3 = {"z": z, "x": x} 259 return out1, out2, out3 260 261 func = Foo() 262 263 x = torch.randn(3) 264 y = torch.randn(3) 265 z = torch.randn(3) 266 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (x, y, z)) 267 268 def test_mnist(self): 269 class MNISTModel(nn.Module): 270 def __init__(self) -> None: 271 super().__init__() 272 self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=True) 273 self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=True) 274 self.fc1 = nn.Linear(9216, 128, bias=True) 275 self.fc2 = nn.Linear(128, 10, bias=True) 276 277 def forward(self, tensor_x: torch.Tensor): 278 tensor_x = self.conv1(tensor_x) 279 tensor_x = torch.sigmoid(tensor_x) 280 tensor_x = self.conv2(tensor_x) 281 tensor_x = torch.sigmoid(tensor_x) 282 tensor_x = torch.max_pool2d(tensor_x, 2) 283 tensor_x = torch.flatten(tensor_x, 1) 284 tensor_x = self.fc1(tensor_x) 285 tensor_x = torch.sigmoid(tensor_x) 286 tensor_x = self.fc2(tensor_x) 287 output = torch.log_softmax(tensor_x, dim=1) 288 return output 289 290 tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32) 291 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 292 MNISTModel(), (tensor_x,) 293 ) 294 295 def test_log_sigmoid(self): 296 # This produces op as `torch.ops.aten.log_sigmoid_forward`, instead of the more 297 # conventional `torch.ops.aten.log_sigmoid`. 298 class Model(torch.nn.Module): 299 def __init__(self) -> None: 300 super().__init__() 301 self.m = torch.nn.LogSigmoid() 302 303 def forward(self, x): 304 return self.m(x) 305 306 input = torch.randn(2) 307 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(Model(), (input,)) 308 309 @skip_if_no_torchvision 310 def test_resnet18(self): 311 # TODO(bowbao): Note [training vs eval in dynamo_export] 312 # So we are effectively exporting all models in traning mode by 313 # default. But for the sake of this export we are only interested in eval mode. 314 # The question is, should we call `model.eval()` in `dynamo_export`? 315 # This particular test fails 'functionalization' in training mode. 316 # So we are explicitly calling `model.eval()` for any model that contains 317 # batch norm. 318 # Ref: https://github.com/pytorch/pytorch/issues/99662#issuecomment-1528178221 319 model = torchvision.models.resnet18(weights=None).eval() 320 dummy_input = torch.randn(1, 3, 224, 224) 321 322 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 323 model, 324 (dummy_input,), 325 ) 326 327 @pytorch_test_common.xfail_dynamic_fx_test( 328 error_message="[ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input" 329 ) 330 @skip_if_no_torchvision 331 def test_shufflenet_v2(self): 332 # TODO(bowbao): see Note [training vs eval in dynamo_export] 333 model = torchvision.models.shufflenet_v2_x0_5(weights=None).eval() 334 dummy_input = torch.randn(1, 3, 224, 224, requires_grad=False) 335 test_inputs = torch.randn(3, 3, 224, 224, requires_grad=False) 336 337 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 338 model, 339 (dummy_input,), 340 additional_test_inputs=[((test_inputs,),)], 341 rtol=1e-3, 342 atol=1e-5, 343 ) 344 345 def test_add(self): 346 class DynamicAdd(torch.nn.Module): 347 def forward(self, x, y): 348 return torch.ops.aten.add(x, y) 349 350 x = torch.randn(2, 3) 351 y = torch.randn(2, 3) 352 another_x = torch.randn(3, 4) 353 another_y = torch.randn(3, 4) 354 355 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 356 DynamicAdd(), 357 (x, y), 358 additional_test_inputs=[((another_x, another_y),)], 359 ) 360 361 def test_sigmoid_add(self): 362 class DynamicAdd(torch.nn.Module): 363 def __init__(self, *args, **kwargs) -> None: 364 super().__init__(*args, **kwargs) 365 self.sigmoid = torch.nn.Sigmoid() 366 367 def forward(self, x, y): 368 z = torch.ops.aten.add(x, y) 369 return self.sigmoid(z) 370 371 x = torch.randn(2, 3) 372 y = torch.randn(2, 3) 373 x = x[1:, :] 374 y = y[1:, :] 375 input_x = torch.randn(1, 4) 376 input_y = torch.randn(1, 4) 377 378 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 379 DynamicAdd(), (x, y), additional_test_inputs=[((input_x, input_y),)] 380 ) 381 382 def test_matmul(self): 383 class DynamicMatMul(torch.nn.Module): 384 def forward(self, x, y): 385 return torch.ops.aten.matmul(x, y) 386 387 x = torch.randn(2, 3, 6) 388 y = torch.randn(2, 6, 4) 389 input_x = torch.randn(2, 3, 4) 390 input_y = torch.randn(2, 4, 4) 391 392 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 393 DynamicMatMul(), (x, y), additional_test_inputs=[((input_x, input_y),)] 394 ) 395 396 @pytorch_test_common.xfail_dynamic_fx_test( 397 error_message="The values for attribute 'shape' do not match: torch.Size([]) != torch.Size([1])" 398 ) 399 def test_scalar_tensor(self): 400 class test(torch.nn.Module): 401 def forward(self, x): 402 return torch.scalar_tensor(x.size(0)), torch.scalar_tensor( 403 x.size(1), dtype=torch.int64 404 ) 405 406 x = torch.randn(2, 3, 4) 407 y = torch.randn(7, 8, 9) 408 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 409 test(), 410 (x,), 411 additional_test_inputs=[((y,),)], 412 ) 413 414 def test_transpose_infer_shape(self): 415 class TransposeModule(torch.nn.Module): 416 def __init__(self) -> None: 417 super().__init__() 418 self.conv = torch.nn.Conv2d(3, 1, 3, stride=2) 419 420 def forward(self, x): 421 x = self.conv(x) 422 return x.transpose(0, 1) 423 424 x = torch.randn(32, 3, 64, 64) 425 y = torch.randn(16, 3, 8, 64) 426 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 427 TransposeModule(), 428 (x,), 429 additional_test_inputs=[((y,),)], 430 ) 431 432 @pytorch_test_common.xfail_dynamic_fx_test # no dynamic shapes present 433 def test_squeeze_runtime_dim(self): 434 class Squeeze(torch.nn.Module): 435 def forward(self, d1, d2): 436 t = torch.zeros(d1[0], d2[0]) # problematic user code for dynamo 437 return t.squeeze(0) 438 439 d1 = torch.tensor([1]) 440 d3 = torch.tensor([3]) 441 d4 = torch.tensor([4]) 442 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 443 Squeeze(), (d1, d4), additional_test_inputs=[((d3, d4),)] 444 ) 445 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 446 Squeeze(), (d3, d4), additional_test_inputs=[((d1, d3),)] 447 ) 448 449 def test_slice(self): 450 class DynamicSliceExportMod(torch.nn.Module): 451 def forward(self, x): 452 results = [] 453 for i in range(4): 454 results.append(x[: x.size(0) - i, i : x.size(2), i:3]) 455 return tuple(results) 456 457 x = torch.rand(5, 5, 5) 458 y = torch.randn(6, 7, 8) 459 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 460 DynamicSliceExportMod(), 461 (x,), 462 additional_test_inputs=[((y,),)], 463 ) 464 465 @pytorch_test_common.xfail_if_model_type_is_exportedprogram( 466 error_message="Expected 1 outputs, got 2", 467 ) 468 def test_mutation(self): 469 class MutationModel(torch.nn.Module): 470 def forward(self, x): 471 x.view(3, 2, -1).add_(2.0) 472 return x 473 474 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 475 MutationModel(), (torch.randn(12),), has_mutation=True 476 ) 477 478 @unittest.skip( 479 "Fixme: arange in torchlib does not support dynamic start and end yet." 480 ) 481 def test_arange(self): 482 class ArangeModel(torch.nn.Module): 483 def forward(self, input): 484 return ( 485 torch.arange(input.shape[0]), 486 torch.arange(12), 487 torch.arange(start=input.shape[0], end=input.shape[0] + 5), 488 ) 489 490 x = torch.randn(5, 3, 2) 491 y = torch.randn(8, 3, 2) 492 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 493 ArangeModel(), 494 (x,), 495 additional_test_inputs=[((y,),)], 496 ) 497 498 @pytorch_test_common.xfail_dynamic_fx_test( 499 error_message="[ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Slice node. " 500 ) 501 @pytorch_test_common.xfail_if_model_type_is_exportedprogram( 502 error_message="Expected 1 outputs, got 2" 503 ) 504 def test_expand_as_fill_zero(self): 505 class Model(torch.nn.Module): 506 def forward(self, x): 507 x[:, x.size(0) :] = 0 508 return x 509 510 x = torch.ones(2, 5) 511 x2 = torch.randn(3, 4) 512 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 513 Model(), 514 (x,), 515 additional_test_inputs=[((x2,),)], 516 ) 517 518 @pytorch_test_common.xfail_dynamic_fx_test( 519 error_message="[ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Slice node. " 520 ) 521 @pytorch_test_common.xfail_if_model_type_is_exportedprogram( 522 error_message="Expected 1 outputs, got 2" 523 ) 524 def test_expand_as_fill_tensor(self): 525 class Model(torch.nn.Module): 526 def forward(self, x): 527 x[:, x.size(0) :] = torch.tensor([1, 2, 3]) 528 return x 529 530 x = torch.ones(2, 5, 3) 531 x2 = torch.randn(3, 4, 3) 532 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 533 Model(), 534 (x,), 535 additional_test_inputs=[((x2,),)], 536 ) 537 538 @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( 539 error_message="at::functionalization::impl::isFunctionalTensor(t) INTERNAL ASSERT FAILED" 540 ) 541 def test_expand_as_fill_separate_tensor(self): 542 class Model(torch.nn.Module): 543 def forward(self, x): 544 aa = torch.tensor([[0], [1], [2]]) 545 return aa.expand_as(x) 546 547 x = torch.ones(3, 2) 548 x2 = torch.randn(3, 5) 549 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 550 Model(), 551 (x,), 552 additional_test_inputs=[((x2,),)], 553 ) 554 555 @pytorch_test_common.skipIfNoCuda 556 def test__scaled_dot_product_flash_attention(self): 557 class Foo(torch.nn.Module): 558 def forward(self, x): 559 ( 560 output, 561 _, 562 _, 563 _, 564 _, 565 _, 566 _, 567 _, 568 _, 569 ) = torch.ops.aten._scaled_dot_product_flash_attention(x, x, x) 570 return output 571 572 func = Foo() 573 574 x = torch.randn(1, 1, 1, 32, device=torch.device("cuda")) 575 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (x,)) 576 577 def test_view_dynamic_zero_dim(self): 578 class ViewModel(torch.nn.Module): 579 def forward(self, input): 580 input = input.view(-1, 2) 581 return input.view(1, -1) 582 583 x = torch.ones(2) 584 y = torch.empty(0) 585 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 586 ViewModel(), 587 (x,), 588 additional_test_inputs=[((y,),)], 589 ) 590 591 def test_flatten_dynamic_axes(self): 592 class MyModule(torch.nn.Module): 593 def forward(self, x): 594 return torch.flatten(x, start_dim=2, end_dim=3) 595 596 batch_size = 3 597 x = torch.randn(batch_size, 5, 4, 5) 598 y = torch.randn(5, 5, 4, 5) 599 model = MyModule() 600 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 601 model, (x,), additional_test_inputs=[((y,),)] 602 ) 603 604 def test_none_input(self): 605 class NoneInputModel(torch.nn.Module): 606 def forward( 607 self, x: torch.Tensor, y: Optional[torch.Tensor], z: torch.Tensor 608 ): 609 if y is None: 610 return x + z 611 return x + y + z 612 613 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 614 NoneInputModel(), (torch.randn(1, 2), None, torch.randn(1, 2)) 615 ) 616 617 def test_operator_with_data_dependent_output(self): 618 class Foo(torch.nn.Module): 619 def forward(self, x): 620 # Repro from llama. Emits `torch.ops.aten._local_scalar_dense`. 621 return x + torch.full(x.shape, torch.tensor(torch.finfo(x.dtype).min)) 622 623 func = Foo() 624 625 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 626 func, (torch.randn(3, 4),) 627 ) 628 629 def test_operator_with_scalar_output(self): 630 class Foo(torch.nn.Module): 631 def forward(self, x, y): 632 return x.item() + y 633 634 func = Foo() 635 636 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 637 func, (torch.tensor([1]), torch.randn(3, 4)) 638 ) 639 640 def test_operator_with_dynamic_output_shape(self): 641 class Foo(torch.nn.Module): 642 def forward(self, x): 643 return x.nonzero() 644 645 func = Foo() 646 647 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 648 func, (torch.randn(3, 4),) 649 ) 650 651 @pytorch_test_common.xfail_if_model_type_is_exportedprogram( 652 error_message="Trying to flatten user inputs with exported input tree spec" 653 ) 654 @pytorch_test_common.xfail_dynamic_fx_test( 655 error_message="!(it.GetName().empty())", 656 reason="With after onnx==1.16, constant folding in optimizer causes this error.", 657 model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, 658 ) 659 def test_gpt2_tiny_from_config(self): 660 # Model 661 config = transformers.GPT2Config( 662 num_hidden_layers=4, 663 vocab_size=8096, 664 hidden_size=16, 665 intermediate_size=16, 666 max_position_embeddings=512, 667 num_attention_heads=2, 668 hidden_dropout_prob=0.0, 669 attention_dropout_prob=0.0, 670 ) 671 model = transformers.GPT2Model(config).eval() 672 673 def input_generator(batch: int, seq: int): 674 input_ids = torch.randint(0, 8096, (batch, seq)) 675 attention_mask = torch.ones(batch, seq, dtype=torch.bool) 676 position_ids = torch.arange(0, seq, dtype=torch.long) 677 position_ids = position_ids.unsqueeze(0).view(-1, seq) 678 return input_ids, attention_mask, position_ids 679 680 # Encoded inputs 681 input_ids, attention_mask, position_ids = input_generator(2, 128) 682 683 # Another encoded inputs to test dynamic shapes 684 ( 685 another_input_ids, 686 another_attention_mask, 687 another_position_ids, 688 ) = input_generator(3, 256) 689 690 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 691 model, 692 (input_ids,), 693 input_kwargs={ 694 "attention_mask": attention_mask, 695 "position_ids": position_ids, 696 }, 697 additional_test_inputs=[ 698 ( 699 (another_input_ids,), 700 { 701 "attention_mask": another_attention_mask, 702 "position_ids": another_position_ids, 703 }, 704 ) 705 ], 706 ) 707 708 def test_prims_device_put(self): 709 class CustomModule(nn.Module): 710 def forward(self, x): 711 # Assuming x is a tensor on the CPU, move it to the desired device using device_put() 712 x = torch.ops.prims.device_put(x, "cpu") 713 return x 714 715 self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 716 CustomModule(), (torch.randn(1, 2, 3),) 717 ) 718 719 def _test_fx_symbolic_tracer_large_scale_exporter( 720 self, 721 model_name: str, 722 create_model: Callable, 723 create_args: Callable, 724 create_pytorch_only_kwargs: Callable, 725 ): 726 """Test helper for large-scale exporter. 727 728 Arguments: 729 model_name: Name of the model. It used to name temporary files. 730 create_model: A function that creates a model. It should always create the same model. 731 create_args: A function that creates random input arguments for the model. 732 create_pytorch_only_kwargs: A function that creates kwargs for calling PyTorch model with real tensors. 733 734 This test contains several steps. 735 736 1. Create a toy model. 737 2. Save the toy's state (parameters) to a file. This is for simulating a checkpoint file. 738 3. Load it back and export it to ONNX with large-scale exporter. 739 All operations (including model loading) are done under 740 FakeTensorMode so no real tensor is created and no real 741 computation happens. 742 4. The ONNX model generated in step 3 doesn't contain parameters, 743 and this step adds them as external data and save a new ONNX model. 744 5. Run PyTorch and ONNX models and compare their results. 745 """ 746 747 # Create the toy model. 748 model = create_model() 749 750 with tempfile.NamedTemporaryFile( 751 prefix=model_name, suffix=".pt" 752 ) as tmp_file, tempfile.TemporaryDirectory( 753 suffix="large_scale_export" 754 ) as tmp_folder: 755 # Dump state_dict to a file to simulate how HuggingFace model is initialized. 756 # The file will be loaded via .load_state_dict(...) 757 torch.save(model.state_dict(), tmp_file.name) 758 759 ftm = fake_tensor.FakeTensorMode( 760 allow_non_fake_inputs=True, allow_fallback_kernels=False 761 ) 762 ctx = patcher.ONNXTorchPatcher() 763 # NOTE: FakeTensorMode disallows symbolic shape of fx graph 764 # The following coed block does several things. 765 # 1. Create a model whose parameters and buffers are all FakeTensor's. 766 # 2. Convert nn.Module into ONNX model without initializers. 767 # 3. Record the file paths to find real initializers. 768 with ctx, ftm: 769 # Toy model with parameters and buffers as FakeTensor's. 770 fake_model = create_model() 771 fake_model.load_state_dict(torch.load(tmp_file.name)) 772 # Toy inputs as FakeTensor's. 773 fake_args = create_args() 774 # Export ONNX model without initializers while ctx.paths records 775 # all files that contains real initializers. 776 777 options = torch.onnx.ExportOptions( 778 dynamic_shapes=self.dynamic_shapes, 779 ) 780 export_options = _exporter_legacy.ResolvedExportOptions(options) 781 export_options.fx_tracer = ( 782 fx_symbolic_graph_extractor.FXSymbolicTracer() 783 ) 784 onnx_program = torch.onnx.dynamo_export( 785 fake_model, 786 *fake_args, 787 export_options=export_options, 788 ) 789 onnx_model = onnx_program.model_proto 790 791 onnx_test_common.assert_dynamic_shapes(onnx_program, self.dynamic_shapes) 792 793 # Tasks done by the following block. 794 # 1. Iterate through all tensors stored in ctx.paths (the file content is loaded torch.load) 795 # 2. If a tensor's name matches a "onnx_model"'s input name, an initializer is created and saved to 796 # a seperated folder. 797 # 3. A new ONNX model is saved into file with the initializers saved in the previous step. 798 # 4. ORT executes the new ONNX model and compares the results with the original GPT model. 799 800 # Model saved to tmp_folder/onnx_model_location 801 # Initializers are saved to tmp_folder/onnx_initializer_location/*.onnx 802 onnx_model_location = model_name + "_external_data.onnx" 803 onnx_initializer_location = model_name + "_initializers" 804 # TODO: We are using the internal `save_model_with_external_data` instead of public 805 # `ONNXProgram.save` because we need to rename ONNX initializers before saving. 806 # This is only needed/allowed because we are using `fx_tracer=FXSymbolicTracer`, 807 # which is not an official FX tracer. 808 fx_serialization.save_model_with_external_data( 809 tmp_folder, 810 onnx_model_location, 811 onnx_initializer_location, 812 tuple(ctx.paths), 813 onnx_model, 814 rename_initializer=True, 815 ) 816 # Generate random inputs. 817 args = create_args() 818 kwargs = create_pytorch_only_kwargs() 819 # Original outputs. 820 ref_outputs = onnx_program.adapt_torch_outputs_to_onnx( 821 model(*args, **kwargs) 822 ) 823 # ORT outputs. 824 args_not_none = onnx_program.adapt_torch_inputs_to_onnx(*args) 825 826 # Drop Parameters and buffers added by fx_serialization.save_model_with_external_data 827 args_not_none = args_not_none[: len(args) - len(kwargs)] 828 829 ort_outputs = onnx_test_common.run_ort( 830 os.path.join(tmp_folder, onnx_model_location), 831 args_not_none, 832 ) 833 834 assert len(ref_outputs) == len(ort_outputs) 835 836 for ref_output, ort_output in zip(ref_outputs, ort_outputs): 837 torch.testing.assert_close(ref_output, torch.tensor(ort_output)) 838 839 @pytorch_test_common.xfail_dynamic_fx_test( 840 error_message="shape_env should be set if tracing with 'symbolic'" 841 ) 842 def test_fx_symbolic_tracer_large_scale_exporter_with_toy_mlp(self): 843 class MLPModel(nn.Module): 844 def __init__(self) -> None: 845 super().__init__() 846 self.fc0 = nn.Linear(8, 8, bias=True) 847 self.fc1 = nn.Linear(8, 4, bias=True) 848 self.fc2 = nn.Linear(4, 2, bias=True) 849 self.fc3 = nn.Linear(2, 2, bias=True) 850 851 def forward(self, tensor_x: torch.Tensor): 852 tensor_x = self.fc0(tensor_x) 853 tensor_x = torch.sigmoid(tensor_x) 854 tensor_x = self.fc1(tensor_x) 855 tensor_x = torch.sigmoid(tensor_x) 856 tensor_x = self.fc2(tensor_x) 857 tensor_x = torch.sigmoid(tensor_x) 858 output = self.fc3(tensor_x) 859 return output 860 861 def create_model() -> nn.Module: 862 return MLPModel() 863 864 def create_args(): 865 return (torch.rand((97, 8), dtype=torch.float32),) 866 867 def create_pytorch_only_extra_kwargs(): 868 return {} 869 870 self._test_fx_symbolic_tracer_large_scale_exporter( 871 "toy_mlp1", 872 create_model, 873 create_args, 874 create_pytorch_only_extra_kwargs, 875 ) 876 877 @pytorch_test_common.xfail_dynamic_fx_test( 878 error_message="shape_env should be set if tracing with 'symbolic'" 879 ) 880 def test_fx_symbolic_tracer_large_scale_exporter_with_tiny_gpt2(self): 881 model_name = "sshleifer/tiny-gpt2" 882 device = "cpu" 883 884 def create_model() -> nn.Module: 885 return transformers.AutoModel.from_pretrained(model_name).to(device).eval() 886 887 def create_args(): 888 tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) 889 kwargs = tokenizer("Hello world!", return_tensors="pt") 890 input_ids = kwargs["input_ids"] 891 attention_mask = kwargs["attention_mask"] 892 return input_ids, None, attention_mask 893 894 def create_pytorch_only_extra_kwargs(): 895 return {"return_dict": False} 896 897 self._test_fx_symbolic_tracer_large_scale_exporter( 898 "tiny_gpt2", 899 create_model, 900 create_args, 901 create_pytorch_only_extra_kwargs, 902 ) 903 904 905def _parameterized_class_attrs_and_values_with_fake_options(): 906 input_values = [] 907 input_values.extend( 908 itertools.product( 909 (True, False), 910 (True, False), 911 (True, False), 912 (pytorch_test_common.TorchModelType.TORCH_NN_MODULE,), 913 ) 914 ) 915 return { 916 "attrs": [ 917 "dynamic_shapes", 918 "load_checkpoint_during_init", 919 "export_within_fake_mode", 920 "model_type", 921 ], 922 "input_values": input_values, 923 } 924 925 926@parameterized.parameterized_class( 927 **_parameterized_class_attrs_and_values_with_fake_options(), 928 class_name_func=_parameterize_class_name, 929) 930class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): 931 """ONNX export test for specific Fake Tensor scenarios 932 933 TODO: Should we merge this with `TestFxToOnnxWithOnnxRuntime`? Considerably increases export time 934 """ 935 936 dynamic_shapes: bool 937 load_checkpoint_during_init: bool 938 export_within_fake_mode: bool 939 model_type: pytorch_test_common.TorchModelType 940 941 def setUp(self): 942 super().setUp() 943 self.ort_version = onnxruntime.__version__ 944 945 def _test_fake_tensor_mode_exporter( 946 self, 947 model_name: str, 948 create_model: Callable, 949 create_args: Callable, 950 create_kwargs: Callable, 951 load_checkpoint_during_init: bool, 952 export_within_fake_mode: bool, 953 model_type: pytorch_test_common.TorchModelType, 954 ): 955 """Test helper for FakeTensorMode-enabled exporter. 956 957 Arguments: 958 model_name: Name of the model. It used to name temporary files. 959 create_model: A function that creates a model. 960 create_args: A function that creates positional inputs for the model. 961 create_kwargs: A function that creates keyword inputs for ther model. 962 load_checkpoint_during_init: Whether to load a checkpoint during model initialization. 963 (after or during model creation, but before exporting starts) 964 export_within_fake_mode: Whether to call torch.onnx._dynamo_export within torch._subclasses.FakeTensorMode 965 model_type: Type of user model. Used to determine whether the user model must be exported to 966 torch.export.ExportedProgram before passing it to torch.onnx.dynamo_export 967 968 This test contains several steps. 969 970 1. Create a toy model. 971 2. Save the toy's state (parameters) to a file. This is for simulating a checkpoint file. 972 3. Load it back and export it to ONNX with Fake Mode enabled. 973 Because all operations (including model and input loading) are done under 974 FakeTensorMode, no real tensor are created and no real computation happens. 975 4. The ONNX model generated in step 3 doesn't contain parameters, 976 and this step adds them as external data on an ONNX model. 977 5. Run PyTorch and ONNX models and compare their results. 978 """ 979 980 # Create the toy model with real weight. 981 real_model = create_model() 982 state_dict = real_model.state_dict() # concrete (non-fake) state_dict 983 984 with tempfile.NamedTemporaryFile( 985 prefix=model_name, suffix=".pt" 986 ) as tmp_checkpoint_file: 987 # Dump state_dict to a file to simulate how HuggingFace model is initialized. 988 # The file will be loaded via .load_state_dict(...) 989 torch.save(state_dict, tmp_checkpoint_file.name) 990 991 with torch.onnx.enable_fake_mode() as fake_context: 992 fake_args = create_args() 993 fake_kwargs = create_kwargs() 994 fake_model = create_model() 995 if load_checkpoint_during_init: 996 fake_model.load_state_dict(torch.load(tmp_checkpoint_file.name)) 997 998 # Export the model with fake inputs and parameters 999 export_options = torch.onnx.ExportOptions( 1000 dynamic_shapes=self.dynamic_shapes, 1001 fake_context=fake_context, 1002 ) 1003 1004 if export_within_fake_mode: 1005 onnx_program = torch.onnx.dynamo_export( 1006 fake_model, 1007 *fake_args, 1008 **fake_kwargs, 1009 export_options=export_options, 1010 ) 1011 1012 if not export_within_fake_mode: 1013 onnx_program = torch.onnx.dynamo_export( 1014 fake_model, *fake_args, **fake_kwargs, export_options=export_options 1015 ) 1016 1017 onnx_test_common.assert_dynamic_shapes(onnx_program, self.dynamic_shapes) 1018 1019 if diagnostics.is_onnx_diagnostics_log_artifact_enabled(): 1020 onnx_program.save_diagnostics( 1021 f"test_report_{self._testMethodName}" 1022 f"_dynamic_axes_{self.dynamic_shapes}" 1023 f"_load_checkpoint_{self.load_checkpoint_during_init}" 1024 f"_export_within_fake_mode_{self.export_within_fake_mode}" 1025 f"model_type_{self.model_type}" 1026 ".sarif" 1027 ) 1028 1029 with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file: 1030 onnx_program.save( 1031 tmp_onnx_file.name, model_state=tmp_checkpoint_file.name 1032 ) 1033 1034 # Generate random inputs. 1035 args = create_args() 1036 kwargs = create_kwargs() 1037 # Original outputs. 1038 # model_with_state_dict=real_model is used to create non-fake weights 1039 if isinstance(real_model, torch.export.ExportedProgram): 1040 outputs = real_model.module()(*args, **kwargs) 1041 else: 1042 outputs = real_model(*args, **kwargs) 1043 ref_outputs = onnx_program.adapt_torch_outputs_to_onnx( 1044 outputs, model_with_state_dict=real_model 1045 ) 1046 # ORT outputs. 1047 # model_with_state_dict=real_model is used to create non-fake weights 1048 args_not_none = onnx_program.adapt_torch_inputs_to_onnx( 1049 *args, model_with_state_dict=real_model, **kwargs 1050 ) 1051 1052 ort_outputs = onnx_test_common.run_ort( 1053 tmp_onnx_file.name, 1054 args_not_none, 1055 ) 1056 1057 assert len(ref_outputs) == len(ort_outputs) 1058 for ref_output, ort_output in zip(ref_outputs, ort_outputs): 1059 torch.testing.assert_close(ref_output, torch.tensor(ort_output)) 1060 1061 # Test ONNXProgram.__call__ interface 1062 ort_outputs = onnx_program( 1063 *args, model_with_state_dict=real_model, **kwargs 1064 ) 1065 assert len(ref_outputs) == len(ort_outputs) 1066 for ref_output, ort_output in zip(ref_outputs, ort_outputs): 1067 torch.testing.assert_close(ref_output, torch.tensor(ort_output)) 1068 1069 def test_fake_tensor_mode_simple(self): 1070 def create_model() -> nn.Module: 1071 class Model(torch.nn.Module): 1072 def __init__(self) -> None: 1073 super().__init__() 1074 self.linear = torch.nn.Linear(2, 2) 1075 1076 def forward(self, x): 1077 out = self.linear(x) 1078 return out 1079 1080 return Model() 1081 1082 def create_args(): 1083 return (torch.rand(5, 2, 2),) 1084 1085 def create_kwargs(): 1086 return {} 1087 1088 self._test_fake_tensor_mode_exporter( 1089 "simple", 1090 create_model, 1091 create_args, 1092 create_kwargs, 1093 load_checkpoint_during_init=self.load_checkpoint_during_init, 1094 export_within_fake_mode=self.export_within_fake_mode, 1095 model_type=self.model_type, 1096 ) 1097 1098 @pytorch_test_common.xfail_dynamic_fx_test( 1099 error_message="!(it.GetName().empty())", 1100 reason="With after onnx==1.16, constant folding in optimizer causes this error.", 1101 model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, 1102 ) 1103 @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( 1104 error_message="Expected 4 inputs, got 2", 1105 reason="https://github.com/pytorch/pytorch/issues/115745", 1106 ) 1107 def test_fake_tensor_mode_huggingface_tiny_gpt2(self): 1108 model_name = "sshleifer/tiny-gpt2" 1109 device = "cpu" 1110 1111 def create_model() -> nn.Module: 1112 return transformers.AutoModel.from_pretrained(model_name).to(device).eval() 1113 1114 def create_args(): 1115 tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) 1116 kwargs = tokenizer("Hello world!", return_tensors="pt") 1117 input_ids = kwargs["input_ids"] 1118 attention_mask = kwargs["attention_mask"] 1119 return input_ids, None, attention_mask 1120 1121 def create_kwargs(): 1122 return {"return_dict": False} 1123 1124 self._test_fake_tensor_mode_exporter( 1125 "tiny_gpt2", 1126 create_model, 1127 create_args, 1128 create_kwargs, 1129 load_checkpoint_during_init=self.load_checkpoint_during_init, 1130 export_within_fake_mode=self.export_within_fake_mode, 1131 model_type=self.model_type, 1132 ) 1133 1134 def test_large_scale_exporter_with_toy_mlp(self): 1135 class MLPModel(nn.Module): 1136 def __init__(self) -> None: 1137 super().__init__() 1138 self.fc0 = nn.Linear(8, 8, bias=True) 1139 self.fc1 = nn.Linear(8, 4, bias=True) 1140 self.fc2 = nn.Linear(4, 2, bias=True) 1141 self.fc3 = nn.Linear(2, 2, bias=True) 1142 1143 def forward(self, tensor_x: torch.Tensor): 1144 tensor_x = self.fc0(tensor_x) 1145 tensor_x = torch.sigmoid(tensor_x) 1146 tensor_x = self.fc1(tensor_x) 1147 tensor_x = torch.sigmoid(tensor_x) 1148 tensor_x = self.fc2(tensor_x) 1149 tensor_x = torch.sigmoid(tensor_x) 1150 output = self.fc3(tensor_x) 1151 return output 1152 1153 def create_model() -> nn.Module: 1154 return MLPModel() 1155 1156 def create_args(): 1157 return (torch.rand((97, 8), dtype=torch.float32),) 1158 1159 def create_kwargs(): 1160 return {} 1161 1162 self._test_fake_tensor_mode_exporter( 1163 "toy_mlp1", 1164 create_model, 1165 create_args, 1166 create_kwargs, 1167 load_checkpoint_during_init=self.load_checkpoint_during_init, 1168 export_within_fake_mode=self.export_within_fake_mode, 1169 model_type=self.model_type, 1170 ) 1171 1172 def test_fake_tensor_mode_huggingface_google_t5(self): 1173 config = transformers.T5Config( 1174 vocab_size=8096, d_model=64, num_layers=2, num_heads=2 1175 ) 1176 batch, seq = 4, 256 1177 1178 def create_args(): 1179 return () 1180 1181 def create_kwargs(): 1182 input_ids = torch.randint(0, config.vocab_size, (batch, seq)) 1183 attention_mask = torch.ones((batch, seq), dtype=torch.bool) 1184 decoder_input_ids = torch.randint(0, config.vocab_size, (batch, seq)) 1185 return { 1186 "input_ids": input_ids, 1187 "attention_mask": attention_mask, 1188 "decoder_input_ids": decoder_input_ids, 1189 } 1190 1191 def create_model(): 1192 return transformers.T5Model(config).eval() 1193 1194 self._test_fake_tensor_mode_exporter( 1195 "huggingface_google_t5", 1196 create_model, 1197 create_args, 1198 create_kwargs, 1199 load_checkpoint_during_init=self.load_checkpoint_during_init, 1200 export_within_fake_mode=self.export_within_fake_mode, 1201 model_type=self.model_type, 1202 ) 1203 1204 @pytorch_test_common.xfail_dynamic_fx_test( 1205 error_message="scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool", 1206 reason="Dynamo error: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool", 1207 model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, 1208 ) 1209 @pytorch_test_common.xfail( 1210 error_message="Could not find an implementation for Trilu(14) node", 1211 reason="ORT error during op level dubug", 1212 ) 1213 def test_fake_tensor_mode_huggingface_openai_whisper(self): 1214 config = transformers.WhisperConfig( 1215 vocab_size=8096, 1216 num_mel_bins=40, 1217 encoder_layers=2, 1218 encoder_attention_heads=2, 1219 decoder_layers=2, 1220 decoder_attention_heads=2, 1221 decoder_ffn_dim=384, 1222 encoder_ffn_dim=384, 1223 d_model=64, 1224 decoder_start_token_id=8001, 1225 pad_token_id=8000, 1226 bos_token_id=8000, 1227 eos_token_id=8000, 1228 begin_suppress_tokens=[220, 8000], 1229 ) 1230 feature_extractor = transformers.WhisperFeatureExtractor(feature_size=40) 1231 device = "cpu" 1232 batch = 4 1233 1234 def create_model() -> nn.Module: 1235 return transformers.AutoModel.from_config(config).to(device).eval() 1236 1237 def create_args(): 1238 return () 1239 1240 def create_kwargs(): 1241 input_features = torch.randn( 1242 ( 1243 batch, 1244 feature_extractor.feature_size, 1245 feature_extractor.nb_max_frames, 1246 ), 1247 dtype=torch.float32, 1248 ) 1249 decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id 1250 return { 1251 "input_features": input_features, 1252 "decoder_input_ids": decoder_input_ids, 1253 "return_dict": False, 1254 } 1255 1256 self._test_fake_tensor_mode_exporter( 1257 "openai_whisper", 1258 create_model, 1259 create_args, 1260 create_kwargs, 1261 load_checkpoint_during_init=self.load_checkpoint_during_init, 1262 export_within_fake_mode=self.export_within_fake_mode, 1263 model_type=self.model_type, 1264 ) 1265 1266 def test_fake_tensor_mode_huggingface_mosaicml_mpt(self): 1267 config = transformers.MptConfig( 1268 vocab_size=8096, d_model=64, n_heads=2, n_layers=3 1269 ) 1270 batch, seq = 4, 256 1271 1272 def create_args(): 1273 return () 1274 1275 def create_kwargs(): 1276 input_ids = torch.randint(0, config.vocab_size, (batch, seq)) 1277 attention_mask = torch.ones(batch, seq, dtype=torch.bool) 1278 return {"input_ids": input_ids, "attention_mask": attention_mask} 1279 1280 def create_model(): 1281 return transformers.MptModel(config).eval() 1282 1283 self._test_fake_tensor_mode_exporter( 1284 "huggingface_mosaicml_mpt", 1285 create_model, 1286 create_args, 1287 create_kwargs, 1288 load_checkpoint_during_init=self.load_checkpoint_during_init, 1289 export_within_fake_mode=self.export_within_fake_mode, 1290 model_type=self.model_type, 1291 ) 1292 1293 @pytorch_test_common.xfail_dynamic_fx_test( 1294 error_message="SymIntArrayRef expected to contain only concrete integers", 1295 model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, 1296 ) 1297 def test_fake_tensor_mode_huggingface_bigscience_bloom_560m(self): 1298 config = transformers.BloomConfig() 1299 batch, seq = 4, 256 1300 1301 def create_args(): 1302 return () 1303 1304 def create_kwargs(): 1305 input_ids = torch.randint(0, config.vocab_size, (batch, seq)) 1306 attention_mask = torch.ones(batch, seq, dtype=torch.bool) 1307 return {"input_ids": input_ids, "attention_mask": attention_mask} 1308 1309 def create_model(): 1310 return transformers.BloomModel(config).eval() 1311 1312 self._test_fake_tensor_mode_exporter( 1313 "huggingface_bigscience_bloom_560m", 1314 create_model, 1315 create_args, 1316 create_kwargs, 1317 load_checkpoint_during_init=self.load_checkpoint_during_init, 1318 export_within_fake_mode=self.export_within_fake_mode, 1319 model_type=self.model_type, 1320 ) 1321 1322 @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( 1323 error_message="Expected 5 inputs, got 3", 1324 reason="https://github.com/pytorch/pytorch/issues/115745", 1325 ) 1326 def test_fake_tensor_mode_huggingface_gpt2(self): 1327 config = transformers.GPT2Config( 1328 vocab_size=8096, n_positions=256, n_embd=256, n_layer=2, n_head=2 1329 ) 1330 1331 def create_model(): 1332 return transformers.GPT2Model(config).eval() 1333 1334 def create_args(): 1335 return () 1336 1337 def create_kwargs(): 1338 batch, seq = 4, 256 1339 1340 input_ids = torch.randint(0, config.vocab_size, (batch, seq)) 1341 attention_mask = torch.ones(batch, seq, dtype=torch.bool) 1342 position_ids = torch.arange(0, seq, dtype=torch.long) 1343 position_ids = position_ids.unsqueeze(0).view(-1, seq) 1344 1345 return { 1346 "input_ids": input_ids, 1347 "attention_mask": attention_mask, 1348 "position_ids": position_ids, 1349 } 1350 1351 self._test_fake_tensor_mode_exporter( 1352 "huggingface_gpt2", 1353 create_model, 1354 create_args, 1355 create_kwargs, 1356 load_checkpoint_during_init=self.load_checkpoint_during_init, 1357 export_within_fake_mode=self.export_within_fake_mode, 1358 model_type=self.model_type, 1359 ) 1360 1361 @pytorch_test_common.xfail_dynamic_fx_test( 1362 error_message="SymIntArrayRef expected to contain only concrete integers", 1363 model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, 1364 ) 1365 @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( 1366 error_message="Expected 9 inputs, got 3", 1367 reason="https://github.com/pytorch/pytorch/issues/115745", 1368 ) 1369 def test_fake_tensor_mode_huggingface_databricks_dolly_v2_3b(self): 1370 config = transformers.GPTNeoXConfig( 1371 vocab_size=8096, hidden_size=256, num_hidden_layers=2, num_attention_heads=2 1372 ) 1373 batch, seq = 4, 256 1374 1375 def create_model(): 1376 return transformers.GPTNeoXModel(config).eval() 1377 1378 def create_args(): 1379 return () 1380 1381 def create_kwargs(): 1382 input_ids = torch.randint(0, config.vocab_size, (batch, seq)) 1383 attention_mask = torch.ones(batch, seq, dtype=torch.bool) 1384 position_ids = torch.arange(0, seq, dtype=torch.long) 1385 position_ids = position_ids.unsqueeze(0).view(-1, seq) 1386 1387 return { 1388 "input_ids": input_ids, 1389 "attention_mask": attention_mask, 1390 "position_ids": position_ids, 1391 } 1392 1393 self._test_fake_tensor_mode_exporter( 1394 "huggingface_databricks_dolly_v2_3b", 1395 create_model, 1396 create_args, 1397 create_kwargs, 1398 load_checkpoint_during_init=self.load_checkpoint_during_init, 1399 export_within_fake_mode=self.export_within_fake_mode, 1400 model_type=self.model_type, 1401 ) 1402 1403 1404if __name__ == "__main__": 1405 common_utils.run_tests() 1406