1# Owner(s): ["module: dynamo"] 2 3import collections 4import contextlib 5import copy 6import itertools 7import os 8import tempfile 9import traceback 10import types 11import unittest 12from copy import deepcopy 13from functools import partial 14from typing import Dict, NamedTuple, Tuple 15from unittest.mock import patch 16 17import torch 18import torch._dynamo.test_case 19import torch._dynamo.testing 20import torch.nn.functional as F 21from torch._dynamo.debug_utils import same_two_models 22from torch._dynamo.eval_frame import unsupported 23from torch._dynamo.mutation_guard import GenerationTracker 24from torch._dynamo.testing import expectedFailureDynamic, same 25from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable 26from torch.nn.modules.lazy import LazyModuleMixin 27from torch.nn.parameter import Parameter, UninitializedParameter 28 29 30try: 31 from . import test_functions 32except ImportError: 33 import test_functions 34 35 36_variable = 0 37_variable1 = 0 38 39 40def update_global(): 41 global _variable, _variable1 42 _variable += 1 43 _variable1 += 1 44 45 46class BasicModule(torch.nn.Module): 47 def __init__(self) -> None: 48 super().__init__() 49 self.linear1 = torch.nn.Linear(10, 10) 50 self.scale = torch.randn(1, 10) 51 52 def forward(self, x): 53 return F.relu(self.linear1(x)) * self.scale 54 55 56class FnMember(torch.nn.Module): 57 def __init__(self) -> None: 58 super().__init__() 59 self.linear1 = torch.nn.Linear(10, 10) 60 self.activation = F.relu 61 62 def forward(self, x): 63 x = self.linear1(x) 64 if self.activation: 65 x = self.activation(x) 66 return x 67 68 69class FnMemberCmp(torch.nn.Module): 70 def __init__(self, activation): 71 super().__init__() 72 self.linear1 = torch.nn.Linear(10, 10) 73 self.activation = activation 74 75 def forward(self, x): 76 x = self.linear1(x) 77 if self.activation is not None: 78 x = self.activation(x) 79 if self.activation is None: 80 x = torch.sigmoid(x) 81 return x 82 83 84class SubmoduleExample(torch.nn.Module): 85 def __init__(self) -> None: 86 super().__init__() 87 self.layer1 = BasicModule() 88 self.layer2 = BasicModule() 89 self.scale = torch.randn(1, 10) 90 91 def forward(self, x): 92 x = self.layer1(x) 93 x = self.layer2(x) 94 return x * self.scale 95 96 97class IsTrainingCheck(torch.nn.Module): 98 def __init__(self) -> None: 99 super().__init__() 100 self.linear1 = torch.nn.Linear(10, 10) 101 self.linear2 = torch.nn.Linear(10, 10) 102 self.train(True) 103 104 def forward(self, x): 105 if self.training: 106 mod = self.linear1 107 else: 108 mod = self.linear2 109 return F.relu(mod(x)) 110 111 112class IsEvalCheck(IsTrainingCheck): 113 def __init__(self) -> None: 114 super().__init__() 115 self.train(False) 116 117 118class ModuleMethodCall(torch.nn.Module): 119 def __init__(self) -> None: 120 super().__init__() 121 self.layer1 = BasicModule() 122 self.layer2 = BasicModule() 123 self.scale = torch.randn(1, 10) 124 125 def call_and_scale(self, mod, x): 126 x = mod(x) 127 return x * self.scale 128 129 def forward(self, x): 130 x1 = self.call_and_scale(self.layer1, x) 131 x2 = self.call_and_scale(self.layer2, x) 132 return x1 + x2 133 134 135class UnsupportedMethodCall(torch.nn.Module): 136 def __init__(self) -> None: 137 super().__init__() 138 self.layer1 = BasicModule() 139 self.scale = torch.randn(1, 10) 140 141 def call_and_scale(self, mod, x): 142 x = mod(x) 143 x = x * self.scale 144 return unsupported(x, x) 145 146 def forward(self, x): 147 x1 = self.call_and_scale(self.layer1, x) 148 return x + x1 149 150 151class UnsupportedModule(torch.nn.Module): 152 def __init__(self) -> None: 153 super().__init__() 154 self.layer1 = BasicModule() 155 self.scale = torch.randn(1, 10) 156 157 def forward(self, x): 158 x = self.layer1(x) * self.scale 159 return unsupported(x, x) 160 161 162class UnsupportedModuleCall(torch.nn.Module): 163 def __init__(self) -> None: 164 super().__init__() 165 self.mod = UnsupportedModule() 166 167 def forward(self, x): 168 return 1 + self.mod(x * 1.5) 169 170 171class ModuleWithStaticForward(torch.nn.Module): 172 @staticmethod 173 def forward(x): 174 return x * torch.sigmoid(x) 175 176 177class ModuleCallModuleWithStaticForward(torch.nn.Module): 178 def __init__(self) -> None: 179 super().__init__() 180 self.mod = ModuleWithStaticForward() 181 182 def forward(self, x): 183 return self.mod(x) 184 185 186class ModuleStaticMethodCall(torch.nn.Module): 187 def __init__(self) -> None: 188 super().__init__() 189 self.layer1 = BasicModule() 190 self.layer2 = BasicModule() 191 self.scale = torch.randn(1, 10) 192 193 @staticmethod 194 def call_and_scale(scale, mod, x): 195 x = mod(x) 196 return x * scale 197 198 def forward(self, x): 199 x1 = self.call_and_scale(self.scale, self.layer1, x) 200 x2 = self.call_and_scale(self.scale, self.layer2, x) 201 return x1 + x2 202 203 204class ModuleClassMethodCall(torch.nn.Module): 205 def __init__(self) -> None: 206 super().__init__() 207 self.layer1 = BasicModule() 208 self.layer2 = BasicModule() 209 self.scale = torch.randn(1, 10) 210 211 @classmethod 212 def call_and_scale(cls, scale, mod, x): 213 x = mod(x) 214 return x * scale 215 216 def forward(self, x): 217 x1 = self.call_and_scale(self.scale, self.layer1, x) 218 x2 = self.call_and_scale(self.scale, self.layer2, x) 219 return x1 + x2 220 221 222class ModuleProperty(torch.nn.Module): 223 def __init__(self) -> None: 224 super().__init__() 225 self.scale = torch.randn(1, 10) 226 227 @property 228 def scale_alias(self): 229 return self.scale 230 231 def forward(self, x): 232 return x * self.scale_alias 233 234 235class NestedModuleList(torch.nn.Module): 236 def __init__(self) -> None: 237 super().__init__() 238 self.layers = torch.nn.ModuleList([]) 239 for _ in range(3): 240 self.layers.append( 241 torch.nn.ModuleList( 242 [ 243 torch.nn.Linear(10, 10), 244 torch.nn.ReLU(), 245 ] 246 ) 247 ) 248 249 def forward(self, x): 250 for layer, act in self.layers: 251 x = act(layer(x)) 252 return x 253 254 255class ConstLoop(torch.nn.Module): 256 def __init__(self) -> None: 257 super().__init__() 258 self.linear1 = torch.nn.Linear(10, 10) 259 self.count = 3 260 261 def forward(self, x): 262 for i in range(self.count): 263 x = torch.sigmoid(self.linear1(x)) 264 return x 265 266 267class ViaModuleCall(torch.nn.Module): 268 def __init__(self) -> None: 269 super().__init__() 270 self.linear1 = torch.nn.Linear(10, 10) 271 272 def forward(self, x): 273 return test_functions.constant3(torch.sigmoid(self.linear1(x)), x) 274 275 276class IsNoneLayer(torch.nn.Module): 277 def __init__(self) -> None: 278 super().__init__() 279 self.layer1 = torch.nn.Linear(10, 10) 280 self.layer2 = None 281 self.train(True) 282 283 def forward(self, x): 284 if self.layer1 is not None: 285 x = self.layer1(x) 286 if self.layer2 is not None: 287 x = self.layer2(x) 288 return x 289 290 291class LayerList(torch.nn.Module): 292 def __init__(self) -> None: 293 super().__init__() 294 self.layers = [ 295 torch.nn.Linear(10, 10), 296 torch.nn.ReLU(), 297 torch.nn.Linear(10, 10), 298 ] 299 300 def forward(self, x): 301 for layer in self.layers: 302 x = layer(x) 303 return x 304 305 306class ModuleList(torch.nn.Module): 307 def __init__(self) -> None: 308 super().__init__() 309 self.layers = torch.nn.ModuleList( 310 [ 311 torch.nn.Linear(10, 10), 312 torch.nn.ReLU(), 313 torch.nn.Linear(10, 10), 314 torch.nn.ReLU(), 315 ] 316 ) 317 318 def forward(self, x): 319 for i in range(len(self.layers)): 320 x = self.layers[i](x) 321 322 for layer in self.layers: 323 x = layer(x) 324 325 for layer, val in zip(self.layers, (x, x, x, x)): 326 x = layer(x) + val 327 328 for layer, val in zip(self.layers, (1, 2, 3, 4)): 329 x = layer(x) + val 330 331 for idx, layer in enumerate(self.layers): 332 x = layer(x) * idx 333 334 for idx, layer in enumerate(self.layers[::-1]): 335 x = layer(x) * idx 336 337 return x 338 339 340class CustomGetItemModuleList(torch.nn.Module): 341 def __init__(self) -> None: 342 super().__init__() 343 self.layers = torch.nn.ModuleList( 344 [ 345 torch.nn.Linear(10, 10), 346 torch.nn.ReLU(), 347 torch.nn.Linear(10, 10), 348 torch.nn.ReLU(), 349 ] 350 ) 351 352 def __getitem__(self, idx: int): 353 return self.layers[idx] 354 355 def __len__(self) -> int: 356 return len(self.layers) 357 358 def forward(self, x): 359 for i in range(len(self)): 360 x = self[i](x) 361 362 return x 363 364 365class ModuleDict(torch.nn.Module): 366 def __init__(self) -> None: 367 super().__init__() 368 self.layers = torch.nn.ModuleDict( 369 { 370 "0": torch.nn.Linear(10, 10), 371 } 372 ) 373 374 def forward(self, x): 375 # TODO(future PR): handle more logic 376 x = self.layers["0"](x) 377 return x 378 379 380class ParameterDict(torch.nn.Module): 381 def __init__(self) -> None: 382 super().__init__() 383 self.layers = torch.nn.ParameterDict( 384 { 385 "0": torch.nn.Parameter(torch.randn(10, 10)), 386 } 387 ) 388 389 def forward(self, x): 390 x = self.layers["0"].mm(x) 391 return x 392 393 394class CustomGetItemParameterDict(torch.nn.Module): 395 def __init__(self) -> None: 396 super().__init__() 397 self.layers = torch.nn.ParameterDict( 398 { 399 "0": torch.nn.Parameter(torch.randn(10, 10)), 400 } 401 ) 402 403 def __getitem__(self, key: str) -> torch.nn.Module: 404 return self.layers[key] 405 406 def forward(self, x): 407 x = self["0"].mm(x) 408 return x 409 410 411class CustomGetItemModuleDict(torch.nn.Module): 412 def __init__(self) -> None: 413 super().__init__() 414 self.layers = torch.nn.ModuleDict( 415 { 416 "0": torch.nn.Linear(10, 10), 417 } 418 ) 419 420 def __getitem__(self, key: str) -> torch.nn.Module: 421 return self.layers[key] 422 423 def forward(self, x): 424 x = self["0"](x) 425 return x 426 427 428class TensorList(torch.nn.Module): 429 def __init__(self) -> None: 430 super().__init__() 431 self.layers = ( 432 torch.randn((1, 10)), 433 torch.randn((10, 1)), 434 torch.randn((1, 10)), 435 torch.randn((10, 1)), 436 ) 437 438 def forward(self, x): 439 for layer in self.layers: 440 x = x * layer 441 return x 442 443 444class Children(torch.nn.Module): 445 def __init__(self) -> None: 446 super().__init__() 447 self.l1 = torch.nn.Linear(10, 10) 448 self.l2 = torch.nn.ReLU() 449 self.l3 = torch.nn.Linear(10, 10) 450 self.l4 = torch.nn.ReLU() 451 452 def forward(self, x): 453 for block in self.children(): 454 x = block(x) 455 return x 456 457 458class NamedChildren(torch.nn.Module): 459 def __init__(self) -> None: 460 super().__init__() 461 self.l1 = torch.nn.Linear(10, 10) 462 self.l2 = torch.nn.ReLU() 463 self.l3 = torch.nn.Linear(10, 10) 464 self.l4 = torch.nn.ReLU() 465 466 def forward(self, x): 467 for _, block in self.named_children(): 468 x = block(x) 469 return x 470 471 472class IntArg(torch.nn.Module): 473 def __init__(self) -> None: 474 super().__init__() 475 self.layer1 = torch.nn.Linear(10, 10) 476 477 def forward(self, x, offset=1): 478 x = F.relu(self.layer1(x)) + offset 479 return x 480 481 482class Seq(torch.nn.Module): 483 def __init__(self) -> None: 484 super().__init__() 485 self.layers = torch.nn.Sequential( 486 torch.nn.Linear(10, 10), 487 torch.nn.ReLU(), 488 torch.nn.Linear(10, 10), 489 torch.nn.ReLU(), 490 ) 491 492 def forward(self, x): 493 return self.layers(x) 494 495 496class Cfg: 497 def __init__(self) -> None: 498 self.val = 0.5 499 self.count = 3 500 501 502class CfgModule(torch.nn.Module): 503 def __init__(self) -> None: 504 super().__init__() 505 self.cfg = Cfg() 506 self.layer = torch.nn.Linear(10, 10) 507 508 def forward(self, x): 509 for i in range(self.cfg.count): 510 x = self.layer(x + self.cfg.val) 511 return x 512 513 514class StringMember(torch.nn.Module): 515 def __init__(self) -> None: 516 super().__init__() 517 self.linear1 = torch.nn.Linear(10, 10) 518 self.mode = "some_string" 519 520 def forward(self, x): 521 if self.mode == "some_string": 522 return F.relu(self.linear1(x)) 523 524 525class _Block(torch.nn.Module): 526 def forward(self, x): 527 return 1.5 * torch.cat(x, 1) 528 529 530class _DenseBlock(torch.nn.ModuleDict): 531 _version = 2 532 533 def __init__( 534 self, 535 num_layers: int = 3, 536 ) -> None: 537 super().__init__() 538 for i in range(num_layers): 539 self.add_module("denselayer%d" % (i + 1), _Block()) 540 541 def forward(self, init_features): 542 features = [init_features] 543 for layer in self.values(): 544 new_features = layer(features) 545 features.append(new_features) 546 return torch.cat(features, 1) 547 548 549class DenseNetBlocks(torch.nn.Module): 550 def __init__(self) -> None: 551 super().__init__() 552 self.layers = _DenseBlock() 553 554 def forward(self, x): 555 return self.layers(x) 556 557 558class MaterializedModule(torch.nn.Module): 559 """Once the below lazy module is initialized with its first input, 560 it is transformed into this module.""" 561 562 param: Parameter 563 564 def __init__(self) -> None: 565 super().__init__() 566 self.register_parameter("param", None) 567 568 def forward(self, x): 569 return x 570 571 572class LazyModule(LazyModuleMixin, MaterializedModule): 573 param: UninitializedParameter 574 cls_to_become = MaterializedModule 575 576 def __init__(self) -> None: 577 super().__init__() 578 self.param = UninitializedParameter() 579 580 def initialize_parameters(self, x): 581 # force graph break to ensure this was not inlined 582 torch._dynamo.graph_break() 583 self.param.materialize(x.shape) 584 585 586class LazyMLP(torch.nn.Module): 587 def __init__(self) -> None: 588 super().__init__() 589 self.fc1 = torch.nn.LazyLinear(10) 590 self.relu1 = torch.nn.ReLU() 591 self.fc2 = torch.nn.LazyLinear(1) 592 self.relu2 = torch.nn.ReLU() 593 594 def forward(self, input): 595 x = self.relu1(self.fc1(input)) 596 y = self.relu2(self.fc2(x)) 597 return y 598 599 600class MyInput(NamedTuple): 601 x: Dict[str, Dict[str, torch.Tensor]] 602 y: torch.Tensor 603 604 605class LazyLayerWithNamedTupleInput(LazyModuleMixin, torch.nn.Module): 606 def __init__(self) -> None: 607 super().__init__() 608 609 def initialize_parameters(self, input): 610 with torch.no_grad(): 611 self._param = torch.nn.Parameter( 612 torch.empty(input.x["a"][0].shape).fill_(0.5) 613 ) 614 615 def forward(self, input): 616 input = input.x["a"] 617 x = 0 618 for i in range(len(input)): 619 x = x + input[i] 620 return x 621 622 623class LazyModuleWithNamedTupleInput(torch.nn.Module): 624 def __init__(self) -> None: 625 super().__init__() 626 self.layer = LazyLayerWithNamedTupleInput() 627 628 def forward(self, input): 629 return self.layer(input) 630 631 632class LazyLayerWithListInput(LazyModuleMixin, torch.nn.Module): 633 def __init__(self) -> None: 634 super().__init__() 635 636 def initialize_parameters(self, input): 637 with torch.no_grad(): 638 self._param = torch.nn.Parameter(torch.empty(input[0].shape).fill_(0.5)) 639 640 def forward(self, input): 641 x = 0 642 for i in range(len(input)): 643 x = x + input[i] 644 return x 645 646 647class LazyModuleWithListInput(torch.nn.Module): 648 def __init__(self) -> None: 649 super().__init__() 650 self.layer = LazyLayerWithListInput() 651 652 def forward(self, input): 653 return self.layer(input[:-1]) 654 655 656class LazyModuleWithLazySubmodule(LazyModuleMixin, torch.nn.Module): 657 def __init__(self) -> None: 658 super().__init__() 659 660 def initialize_parameters(self, input): 661 with torch.no_grad(): 662 self.layer = LazyLayerWithListInput() 663 664 def forward(self, x): 665 return self.layer(x) 666 667 668class LazyLayerWithInputs(LazyModuleMixin, torch.nn.Module): 669 def __init__(self) -> None: 670 super().__init__() 671 672 def initialize_parameters(self, x, y): 673 with torch.no_grad(): 674 self._param_x = torch.nn.Parameter(torch.empty(x[0].shape).fill_(0.5)) 675 self._param_y = torch.nn.Parameter(torch.empty(y[0].shape).fill_(0.5)) 676 677 def forward(self, x, y): 678 res_x = 0 679 for i in range(len(x)): 680 res_x = res_x + x[i] 681 res_y = 0 682 for i in range(len(y)): 683 res_y = res_y + y[i] 684 return res_x + res_y 685 686 687class LazyModuleKwArgs(LazyModuleMixin, torch.nn.Module): 688 def __init__(self) -> None: 689 super().__init__() 690 691 def initialize_parameters(self, *args, **kwargs): 692 with torch.no_grad(): 693 self.layer = LazyLayerWithInputs() 694 695 def forward(self, x, y): 696 return self.layer(x, y=y) 697 698 699class LazyParentModule(LazyModuleMixin, torch.nn.Module): 700 def __init__(self) -> None: 701 super().__init__() 702 703 def impl(self, x): 704 return x.cos() + self._val 705 706 707class LazyChildModuleNoClsToBecome(LazyParentModule): 708 def __init__(self) -> None: 709 super().__init__() 710 711 def forward(self, x): 712 return super().impl(x.sin()) 713 714 def initialize_parameters(self, input): 715 self._val = torch.nn.Parameter(torch.ones(2, 2)) 716 717 718def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool: 719 requires_grad = any(p.requires_grad for p in module.parameters(recurse)) 720 return requires_grad 721 722 723def requires_grad2(module: torch.nn.Module, recurse: bool = False) -> bool: 724 requires_grad = any(p.requires_grad for p in module.parameters(recurse)) 725 return requires_grad 726 727 728class ParametersModule1(torch.nn.Module): 729 def __init__(self) -> None: 730 super().__init__() 731 self.linear1 = torch.nn.Linear(10, 10) 732 self.scale = torch.nn.Parameter(torch.randn(1, 10)) 733 734 def forward(self, x): 735 if not requires_grad1(self): 736 return F.relu(self.linear1(x)) * self.scale 737 else: 738 return x + 1 739 740 741class ParametersModule2(ParametersModule1): 742 def forward(self, x): 743 if not requires_grad2(self): 744 return F.relu(self.linear1(x)) * self.scale 745 else: 746 return x + 1 747 748 749class ParametersModule3(ParametersModule1): 750 def forward(self, x): 751 ones = torch.ones(10, dtype=next(self.parameters()).dtype) 752 return F.relu(self.linear1(x)) * self.scale + ones 753 754 755class ParametersModule4(ParametersModule1): 756 def forward(self, x): 757 ones = torch.ones(10, dtype=next(self.parameters(recurse=False)).dtype) 758 return F.relu(self.linear1(x)) * self.scale + ones 759 760 761class ParametersModule5(torch.nn.Module): 762 def __init__(self) -> None: 763 super().__init__() 764 self.linear1 = torch.nn.Linear(10, 10) 765 self.scale = torch.nn.Parameter(torch.randn(10, 10)) 766 self.scale_dup = self.scale 767 768 def forward(self, x): 769 counter = 0 770 for param in self.parameters(): 771 counter += 1 772 773 return x * self.scale * counter 774 775 776class SuperModule(BasicModule): 777 def forward(self, x): 778 x = super().forward(x) 779 return x + 10.0 780 781 782class SuperModule2(BasicModule): 783 def forward(self, x): 784 return BasicModule.forward(self, x) 785 786 787class ComplicatedSuperParent(torch.nn.Module): 788 @classmethod 789 def custom_add(cls, x): 790 x = x + x 791 return x 792 793 794class SuperChildCallsClassMethod(ComplicatedSuperParent): 795 @classmethod 796 def child_func(cls, x): 797 x = super().custom_add(x) 798 return x 799 800 def forward(self, x): 801 x = self.child_func(x) 802 return x 803 804 805class HasAttrModule(torch.nn.Module): 806 def __init__(self) -> None: 807 super().__init__() 808 self.scale = torch.nn.Parameter(torch.randn(1, 10)) 809 810 def forward(self, x): 811 x = F.relu(x) 812 if hasattr(self, "scale"): 813 x *= self.scale 814 if hasattr(self, "scale2"): 815 x *= self.scale2 816 return x 817 818 819class EnumValues(torch.nn.ModuleDict): 820 def __init__( 821 self, 822 num_layers: int = 3, 823 ) -> None: 824 super().__init__() 825 for i in range(num_layers): 826 self.add_module("denselayer%d" % (i + 1), _Block()) 827 828 def forward(self, init_features): 829 features = [init_features] 830 for idx, layer in enumerate(self.values()): 831 new_features = layer(features) 832 features.append(new_features) 833 return torch.cat(features, 1) 834 835 836class AccessByKeys(torch.nn.ModuleDict): 837 def __init__( 838 self, 839 num_layers: int = 3, 840 ) -> None: 841 super().__init__() 842 for i in range(num_layers): 843 self.add_module("denselayer%d" % (i + 1), _Block()) 844 845 def forward(self, init_features): 846 features = [init_features] 847 for k in self.keys(): 848 new_features = self[k](features) 849 features.append(new_features) 850 return torch.cat(features, 1) 851 852 853class CallForwardDirectly(torch.nn.Module): 854 def __init__(self) -> None: 855 super().__init__() 856 self.layer1 = BasicModule() 857 self.layer2 = torch.nn.Linear(10, 10) 858 859 def forward(self, x): 860 x = self.layer1.forward(x) 861 x = self.layer2.forward(x) 862 return x 863 864 865class ConvCallForwardDirectly(torch.nn.Module): 866 def __init__(self) -> None: 867 super().__init__() 868 self.layer = torch.nn.Conv2d(3, 64, 3, 1, 1, bias=False) 869 870 def forward(self, x): 871 return self.layer.forward(x) 872 873 874class ConvTransposeCallForwardDirectly(torch.nn.Module): 875 def __init__(self) -> None: 876 super().__init__() 877 self.layer = torch.nn.ConvTranspose2d(4, 4, 4) 878 879 def forward(self, x): 880 return self.layer.forward(x) 881 882 883class ConvCallSuperForwardDirectly(torch.nn.Conv1d): 884 def __init__(self, in_channels, out_channels, kernel_size, **kwargs): 885 super().__init__( 886 in_channels=in_channels, 887 out_channels=out_channels, 888 kernel_size=kernel_size, 889 **kwargs, 890 ) 891 892 def forward(self, inputs, mask=None): 893 outputs = super().forward(inputs) 894 return outputs 895 896 897class ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d): 898 def __init__(self, in_channels, out_channels, kernel_size, **kwargs): 899 super().__init__( 900 in_channels=in_channels, 901 out_channels=out_channels, 902 kernel_size=kernel_size, 903 **kwargs, 904 ) 905 906 def forward(self, x): 907 if x.numel() > 0: 908 return super().forward(x) 909 output_shape = [ 910 ((i - 1) * d - 2 * p + (di * (k - 1) + 1) + op) 911 for i, p, di, k, d, op in zip( 912 x.shape[-2:], 913 self.padding, 914 self.dilation, 915 self.kernel_size, 916 self.stride, 917 self.output_padding, 918 ) 919 ] 920 output_shape = [x.shape[0], self.bias.shape[0]] + output_shape 921 return _NewEmptyTensorOp.apply(x, output_shape) # noqa: F821 922 923 924class ModuleNameString(torch.nn.Module): 925 def __init__(self) -> None: 926 super().__init__() 927 self.linear1 = torch.nn.Linear(10, 10) 928 929 def forward(self, x): 930 if self.__class__.__name__ == "ABC": 931 return 10 932 if self.linear1.__class__.__name__ == "Linear": 933 return F.relu(self.linear1(x) + 10) 934 return 11 935 936 937class SelfMutatingModule(torch.nn.Module): 938 def __init__(self, layer): 939 super().__init__() 940 self.layer = layer 941 self.counter = 0 942 943 def forward(self, x): 944 result = self.layer(x) + self.counter 945 self.counter += 1 946 return F.relu(result) 947 948 949class ModuleAttributePrecedenceBase(torch.nn.Module): 950 def linear(self, x, flag=None): 951 if flag: 952 return x * 2.0 953 return x * 3.0 954 955 956class ModuleAttributePrecedence(ModuleAttributePrecedenceBase): 957 def __init__(self) -> None: 958 super().__init__() 959 self.activation = torch.nn.ReLU() 960 self.linear = torch.nn.Linear(10, 10) 961 self.initializer = torch.ones([10, 10]) 962 self.scale = 0.5 963 964 def activation(self, x): 965 return x * 1.2 966 967 def initializer(self): 968 return torch.zeros([10, 10]) 969 970 def scale(self): 971 return 2.0 972 973 def forward(self, x): 974 # object attribute takes precedence unless it's a nn.Module 975 return self.activation(self.linear(self.initializer + x)) * self.scale 976 977 978class ModuleForwardHasGraphBreak(torch.nn.Module): 979 def __init__(self) -> None: 980 super().__init__() 981 self.layer1 = BasicModule() 982 self.layer2 = BasicModule() 983 self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule()) 984 self.layer4 = torch.nn.ModuleList( 985 [ 986 torch.nn.Linear(10, 10), 987 torch.nn.ReLU(), 988 torch.nn.Linear(10, 10), 989 torch.nn.ReLU(), 990 ] 991 ) 992 self.layer5 = torch.nn.ModuleDict( 993 { 994 "0": torch.nn.Linear(10, 10), 995 } 996 ) 997 self.scale = torch.randn(1, 10) 998 999 def forward(self, x): 1000 """ 1001 This is used to test if the results of functions like `named_parameters` 1002 can be reconstructed correctly after graph break. 1003 1004 https://github.com/pytorch/torchdynamo/issues/1931 1005 """ 1006 x = self.layer1(x) 1007 params1 = dict(self.named_parameters()) 1008 params2 = list(self.parameters()) 1009 buffers1 = dict(self.named_buffers()) 1010 buffers2 = list(self.buffers()) 1011 modules1 = dict(self.named_modules()) 1012 modules2 = list(self.modules()) 1013 torch._dynamo.graph_break() 1014 y = modules2 1015 y = modules1 1016 y = buffers2 1017 y = buffers1 1018 y = params2 1019 y = params1 1020 x = ( 1021 self.layer2(x) 1022 + y["layer3.1.linear1.weight"] 1023 + y["layer4.2.weight"] 1024 + y["layer5.0.weight"] 1025 ) 1026 return x * self.scale 1027 1028 1029class ModuleGuardNameIsValid(torch.nn.ModuleDict): 1030 # Guard names should be valid python identifier as we use eval() to get 1031 # corresponding guard value. Some guard names come from source(module path) 1032 # where special symbols are valid. But they are not valid python identifier, 1033 # we should identify these pattern and rewrite them with getattr. 1034 def __init__(self) -> None: 1035 super().__init__() 1036 for i in range(2): 1037 self.add_module("l@yer-%d" % (i + 1), BasicModule()) 1038 1039 def forward(self, x): 1040 for layer in self.values(): 1041 x = layer(x) 1042 return x 1043 1044 1045class SequentialWithDuplicatedModule(torch.nn.Module): 1046 # Sequential module(self.layer) contains three duplicated ReLU module. 1047 def __init__(self) -> None: 1048 super().__init__() 1049 self.relu = torch.nn.ReLU() 1050 self.layer = torch.nn.Sequential( 1051 torch.nn.Linear(10, 20), 1052 self.relu, 1053 torch.nn.Linear(20, 20), 1054 self.relu, 1055 torch.nn.Linear(20, 10), 1056 self.relu, 1057 ) 1058 1059 def forward(self, x): 1060 return self.layer(x) 1061 1062 1063class SequentialWithDuplicatedModule2(torch.nn.Module): 1064 def __init__(self) -> None: 1065 super().__init__() 1066 self.relu = torch.nn.ReLU() 1067 self.layer = torch.nn.Sequential( 1068 collections.OrderedDict( 1069 [ 1070 ("linear1", torch.nn.Linear(10, 20)), 1071 ("relu1", self.relu), 1072 ("linear2", torch.nn.Linear(20, 20)), 1073 ("relu2", self.relu), 1074 ("linear3", torch.nn.Linear(20, 10)), 1075 ("relu3", self.relu), 1076 ] 1077 ) 1078 ) 1079 1080 def forward(self, x): 1081 return self.layer(x) 1082 1083 1084class ModuleComparison(torch.nn.Module): 1085 def __init__(self) -> None: 1086 super().__init__() 1087 self.layer0 = torch.nn.Linear(10, 10) 1088 self.layer1 = torch.nn.Linear(10, 10) 1089 self.layer2 = torch.nn.Linear(10, 10) 1090 1091 @property 1092 def encoder_layers(self): 1093 return [self.layer0, self.layer1, self.layer2] 1094 1095 def forward(self, x): 1096 for layer in self.encoder_layers: 1097 output = layer(x) 1098 if layer is None or layer == self.layer0: 1099 output = F.relu6(output) 1100 else: 1101 output = F.relu(output) 1102 return output 1103 1104 1105class ModulePatch1(torch.nn.Module): 1106 pass 1107 1108 1109class ModulePatch2(torch.nn.Module): 1110 def forward(self, x): 1111 return x - 1 1112 1113 1114class UnspecNonInlinableModule(torch.nn.Module): 1115 torchdynamo_force_dynamic = True # forced to be a UnspecializedNNModule 1116 1117 def forward(self, x): 1118 if x.sum() > 0: 1119 return x + 1 1120 else: 1121 return x - 1 1122 1123 1124class UnspecNonInlinableToplevelModule(torch.nn.Module): 1125 def __init__(self) -> None: 1126 super().__init__() 1127 self.m = UnspecNonInlinableModule() 1128 1129 def forward(self, x): 1130 return self.m(x) 1131 1132 1133def make_test(fn, expected_ops=None): 1134 def test_fn(self): 1135 return torch._dynamo.testing.standard_test( 1136 self, fn=fn, nargs=1, expected_ops=expected_ops 1137 ) 1138 1139 fn.eval() 1140 return test_fn 1141 1142 1143@contextlib.contextmanager 1144def temporary_tensor_subclass(torch_function=None): 1145 class TensorProxy(torch.Tensor): 1146 @classmethod 1147 def __torch_function__(cls, func, types, args=(), kwargs=None): 1148 if torch_function is not None: 1149 torch_function() 1150 return super().__torch_function__(func, types, args, kwargs) 1151 1152 torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy) 1153 try: 1154 yield TensorProxy 1155 finally: 1156 torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy) 1157 1158 1159class NNModuleTests(torch._dynamo.test_case.TestCase): 1160 test_seq = make_test(Seq()) 1161 test_basicmodule1 = make_test(BasicModule()) 1162 test_basicmodule2 = make_test(BasicModule()) 1163 test_submodules1 = make_test(SubmoduleExample()) 1164 test_submodules2 = make_test(SubmoduleExample()) 1165 test_modulemethod1 = make_test(ModuleMethodCall()) 1166 test_modulemethod2 = make_test(ModuleMethodCall()) 1167 test_module_call_module_with_static_forward = make_test( 1168 ModuleCallModuleWithStaticForward() 1169 ) 1170 test_module_static_method = make_test(ModuleStaticMethodCall()) 1171 test_fnmember = make_test(FnMember()) 1172 test_fnmembercmp1 = make_test(FnMemberCmp(F.relu)) 1173 test_fnmembercmp2 = make_test(FnMemberCmp(None)) 1174 test_constloop = make_test(ConstLoop()) 1175 test_istraining1 = make_test(IsTrainingCheck()) 1176 test_istraining2 = make_test(IsTrainingCheck()) 1177 test_iseval1 = make_test(IsEvalCheck()) 1178 test_iseval2 = make_test(IsEvalCheck()) 1179 test_viamodulecall = make_test(ViaModuleCall()) 1180 test_isnonelayer = make_test(IsNoneLayer()) 1181 test_layerlist = make_test(LayerList()) 1182 test_tensorlist = make_test(TensorList()) 1183 test_intarg = make_test(IntArg()) 1184 test_cfgmod = make_test(CfgModule()) 1185 test_stringmember = make_test(StringMember()) 1186 test_modulelist = make_test(ModuleList()) 1187 test_modulelist_nested = make_test(NestedModuleList()) 1188 test_modulelist_custom = make_test(CustomGetItemModuleList()) 1189 test_moduledict = make_test(ModuleDict()) 1190 test_moduledict_custom = make_test(CustomGetItemModuleDict()) 1191 test_parameterdict = make_test(ParameterDict()) 1192 test_parameterdict_custom = make_test(CustomGetItemParameterDict()) 1193 test_super1 = make_test(SuperModule()) 1194 test_super2 = make_test(SuperModule2()) 1195 test_super_class_method = make_test(SuperChildCallsClassMethod()) 1196 test_children = make_test(Children()) 1197 test_named_children = make_test(NamedChildren()) 1198 test_densenet = make_test(DenseNetBlocks()) 1199 test_parameters1 = make_test(ParametersModule1()) 1200 test_parameters2 = make_test(ParametersModule2()) 1201 test_parameters3 = make_test(ParametersModule3(), expected_ops=5) 1202 test_parameters4 = make_test(ParametersModule4()) 1203 test_parameters5 = make_test(ParametersModule5()) 1204 test_hasattr = make_test(HasAttrModule()) 1205 test_enumvalues = make_test(EnumValues()) 1206 test_access_by_keys = make_test(AccessByKeys()) 1207 test_module_class_method = make_test(ModuleClassMethodCall()) 1208 test_module_property = make_test(ModuleProperty()) 1209 test_forward_directly = make_test(CallForwardDirectly()) 1210 test_module_name_string = make_test(ModuleNameString()) 1211 test_module_attribute_precedence = make_test(ModuleAttributePrecedence()) 1212 test_module_guard_name_is_valid = make_test(ModuleGuardNameIsValid()) 1213 test_sequential_with_duplicated_module = make_test(SequentialWithDuplicatedModule()) 1214 test_sequential_with_duplicated_module2 = make_test( 1215 SequentialWithDuplicatedModule2() 1216 ) 1217 test_module_comparison = make_test(ModuleComparison()) 1218 1219 def test_module_forward_has_graph_break(self): 1220 m = ModuleForwardHasGraphBreak() 1221 x = torch.rand([10, 10]) 1222 ref = m(x) 1223 opt_m = torch._dynamo.optimize("eager")(m) 1224 res = opt_m(x) 1225 self.assertTrue(torch.allclose(ref, res)) 1226 1227 def test_unsupportedmethod(self): 1228 m = UnsupportedMethodCall() 1229 i = torch.randn(10) 1230 cnt = torch._dynamo.testing.CompileCounter() 1231 opt_m = torch._dynamo.optimize(cnt)(m) 1232 r = opt_m(i) 1233 self.assertTrue(torch._dynamo.testing.same(r, m(i))) 1234 self.assertEqual(cnt.op_count, 5) 1235 1236 def test_unsupportedmodule(self): 1237 m = UnsupportedModuleCall() 1238 i = torch.randn(10) 1239 cnt = torch._dynamo.testing.CompileCounter() 1240 opt_m = torch._dynamo.optimize(cnt)(m) 1241 r = opt_m(i) 1242 self.assertTrue(torch._dynamo.testing.same(r, m(i))) 1243 self.assertEqual(cnt.op_count, 6) 1244 1245 def test_self_mutating1(self): 1246 m1 = torch.nn.Linear(10, 10) 1247 m2 = SelfMutatingModule(m1) 1248 m3 = SelfMutatingModule(m1) 1249 m4 = SelfMutatingModule(m1) 1250 i = torch.randn(10) 1251 out2 = [m2(i), m2(i), m2(i)] 1252 cnt = torch._dynamo.testing.CompileCounter() 1253 opt_m3 = torch._dynamo.optimize_assert(cnt)(m3) 1254 opt_m4 = torch._dynamo.optimize_assert(cnt)(m4) 1255 out3 = [opt_m3(i), opt_m3(i), opt_m3(i)] 1256 out4 = [opt_m4(i), opt_m4(i), opt_m4(i)] 1257 self.assertTrue(torch._dynamo.testing.same(out2, out3)) 1258 self.assertTrue(torch._dynamo.testing.same(out2, out4)) 1259 if torch._dynamo.config.assume_static_by_default: 1260 self.assertExpectedInline(cnt.frame_count, """2""") 1261 else: 1262 self.assertExpectedInline(cnt.frame_count, """1""") 1263 1264 @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False) 1265 def test_generation_tag(self): 1266 cnt = torch._dynamo.testing.CompileCounter() 1267 1268 # guarantee that we have installed 1269 # the generation tagging function 1270 with torch._dynamo.optimize_assert(cnt): 1271 pass 1272 1273 m1 = torch.nn.Linear(10, 10) 1274 prev_generation = GenerationTracker.get_generation_value(m1) 1275 cur_generation = prev_generation + 1 1276 1277 with torch._dynamo.optimize_assert(cnt): 1278 m2 = torch.nn.Linear(10, 10) 1279 1280 self.assertEqual(GenerationTracker.get_generation_value(m1), prev_generation) 1281 self.assertEqual(GenerationTracker.get_generation_value(m2), cur_generation) 1282 # check that newly constructed instances 1283 # also have the same generation (even if copied from an old instance) 1284 m3 = deepcopy(m1) 1285 self.assertEqual(GenerationTracker.get_generation_value(m3), cur_generation) 1286 1287 def test_simple_torch_function(self): 1288 def foo(x): 1289 # function call, twice to test wrapping 1290 x = F.sigmoid(x) 1291 x = F.sigmoid(x) 1292 # method call, twice to test wrapping 1293 x = x.sigmoid() 1294 x = x.sigmoid() 1295 return x 1296 1297 with temporary_tensor_subclass() as TensorProxy: 1298 x = torch.randn(1).as_subclass(TensorProxy) 1299 cnt = torch._dynamo.testing.CompileCounter() 1300 out1 = foo(x) 1301 opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) 1302 out2 = opt_foo(x) 1303 1304 self.assertEqual(cnt.op_count, 4) 1305 self.assertTrue(torch._dynamo.testing.same(out1, out2)) 1306 1307 def test_torch_function_with_closure(self): 1308 def run(): 1309 def foo(x): 1310 # function call, twice to test wrapping 1311 x = F.sigmoid(x) 1312 x = F.sigmoid(x) 1313 # method call, twice to test wrapping 1314 x = x.sigmoid() 1315 x = x.sigmoid() 1316 return x 1317 1318 counter = 0 1319 1320 def function(): 1321 nonlocal counter 1322 # for now, only support reads from closure cells 1323 # TODO(future PR): support writes as well 1324 counter + 1 1325 1326 with temporary_tensor_subclass(function) as TensorProxy: 1327 x = torch.randn(1).as_subclass(TensorProxy) 1328 x = torch.randn(1) 1329 cnt = torch._dynamo.testing.CompileCounter() 1330 out1 = foo(x) 1331 opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) 1332 out2 = opt_foo(x) 1333 1334 self.assertEqual(cnt.op_count, 4) 1335 self.assertTrue(torch._dynamo.testing.same(out1, out2)) 1336 1337 run() 1338 1339 def test_torch_mangled_class_name(self): 1340 original = TensorWithTFOverrideVariable.global_mangled_class_name 1341 results = [] 1342 1343 def instrumented(self, tx): 1344 result = original(self, tx) 1345 results.append(result) 1346 return result 1347 1348 TensorWithTFOverrideVariable.global_mangled_class_name = instrumented 1349 1350 def one_break(x): 1351 x = F.sigmoid(x) 1352 print() # force break 1353 x = x.sigmoid() 1354 return x 1355 1356 try: 1357 with temporary_tensor_subclass() as TensorProxy: 1358 x = torch.randn(1).as_subclass(TensorProxy) 1359 x1 = one_break(x) 1360 1361 cnt = torch._dynamo.testing.CompileCounter() 1362 opt_one_break = torch._dynamo.optimize(cnt)(one_break) 1363 x2 = opt_one_break(x) 1364 1365 self.assertTrue(torch._dynamo.testing.same(x1, x2)) 1366 self.assertEqual(cnt.frame_count, 2) 1367 self.assertEqual(cnt.op_count, 2) 1368 1369 compile_ids = set() 1370 for r in results: 1371 # A mangled classname looks like __subclass_TensorProxy_94524181138240_c0 1372 # where the last segment contains the compile_id. 1373 prefix = "__subclass_TensorProxy_" 1374 before, sep, after = r.partition(prefix) 1375 self.assertEqual(before, "") 1376 self.assertEqual(sep, prefix) 1377 1378 class_type_id, compile_id = after.split("_") 1379 self.assertTrue(class_type_id.isnumeric()) 1380 self.assertTrue(compile_id.startswith("c")) 1381 1382 cid = compile_id[1:] 1383 self.assertTrue(cid.isnumeric()) 1384 compile_ids.add(cid) 1385 1386 self.assertEqual(len(compile_ids), 3) 1387 1388 finally: 1389 TensorWithTFOverrideVariable.global_mangled_class_name = original 1390 1391 @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False) 1392 def test_nn_moduledict_contains(self): 1393 class M(torch.nn.Module): 1394 def __init__(self, module_dict): 1395 super().__init__() 1396 self.module_dict = module_dict 1397 1398 def forward(self, x): 1399 if "foo" in self.module_dict: 1400 x = torch.mul(x, 1.0) 1401 x = torch.add(x, 1.0) 1402 return x 1403 1404 module_dict = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)}) 1405 m = M(module_dict) 1406 data = torch.randn(1) 1407 out1 = m(data) 1408 cnt = torch._dynamo.testing.CompileCounter() 1409 opt_m = torch._dynamo.optimize(cnt, nopython=True)(m) 1410 out2 = opt_m(data) 1411 self.assertEqual(cnt.op_count, 2) 1412 self.assertTrue(torch._dynamo.testing.same(out1, out2)) 1413 1414 module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)}) 1415 m = M(module_dict) 1416 data = torch.randn(1) 1417 out1 = m(data) 1418 cnt = torch._dynamo.testing.CompileCounter() 1419 torch._dynamo.reset() 1420 opt_m = torch._dynamo.optimize(cnt, nopython=True)(m) 1421 out2 = opt_m(data) 1422 1423 self.assertEqual(cnt.op_count, 1) 1424 self.assertTrue(torch._dynamo.testing.same(out1, out2)) 1425 1426 module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)}) 1427 pre = m(data) 1428 cnt.clear() 1429 1430 with torch._dynamo.optimize(cnt, nopython=False): 1431 opt_pre = m(data) 1432 m = M(module_dict) 1433 data = torch.randn(1) 1434 out1 = m(data) 1435 1436 out_post = m(data) 1437 self.assertEqual(cnt.frame_count, 1) 1438 self.assertEqual(cnt.op_count, 1) 1439 self.assertTrue(torch._dynamo.testing.same(pre, opt_pre)) 1440 self.assertTrue(torch._dynamo.testing.same(out1, out_post)) 1441 1442 # RuntimeError: SymIntArrayRef expected to contain only concrete integers 1443 @expectedFailureDynamic 1444 def test_lazy_module1(self): 1445 input_shape = (16, 3, 6, 7, 8) 1446 1447 cnt = torch._dynamo.testing.CompileCounter() 1448 module = LazyModule() 1449 1450 def test_static_module(): 1451 input = torch.ones(*input_shape) 1452 module(input) 1453 1454 # test no graph break 1455 opt_test_static_module = torch._dynamo.optimize(cnt, nopython=True)( 1456 test_static_module 1457 ) 1458 opt_test_static_module() 1459 1460 self.assertTrue( 1461 isinstance(module, MaterializedModule), 1462 "Module should be transformed to an instance of MaterializedModule.", 1463 ) 1464 self.assertEqual(module.param.shape, input_shape) 1465 1466 # test when mapped to UnspecializedNNModule 1467 module = LazyModule() 1468 1469 def test_unspecialized(): 1470 nonlocal module 1471 module = LazyModule() 1472 input = torch.ones(*input_shape) 1473 module(input) 1474 1475 opt_test_unspecialized = torch._dynamo.optimize(cnt)(test_unspecialized) 1476 opt_test_unspecialized() 1477 1478 self.assertTrue( 1479 isinstance(module, MaterializedModule), 1480 "Module should be transformed to an instance of MaterializedModule.", 1481 ) 1482 self.assertEqual(module.param.shape, input_shape) 1483 1484 # test with a static module in torch.* 1485 module = torch.nn.modules.LazyBatchNorm3d( 1486 affine=False, track_running_stats=False 1487 ) 1488 1489 cnt = torch._dynamo.testing.CompileCounter() 1490 1491 torch._dynamo.reset() 1492 1493 def test_torch_static(): 1494 input = torch.ones(*input_shape) 1495 return module(input) # fully materialized 1496 1497 # test no graph break 1498 opt_test_torch_static = torch._dynamo.optimize(cnt, nopython=True)( 1499 test_torch_static 1500 ) 1501 opt_test_torch_static() 1502 out = opt_test_torch_static() 1503 1504 self.assertTrue(same(out, module(torch.ones(*input_shape)))) 1505 1506 self.assertTrue( 1507 isinstance(module, torch.nn.modules.batchnorm.BatchNorm3d), 1508 "Module should be transformed to an instance of BatchNorm3d.", 1509 ) 1510 self.assertEqual(cnt.frame_count, 1, "No guards should have triggered.") 1511 1512 # RuntimeError: SymIntArrayRef expected to contain only concrete integers 1513 @expectedFailureDynamic 1514 def test_lazy_module2(self): 1515 # Test FX graph 'call_module' works well if argument is lazy module 1516 m = LazyMLP() 1517 x = torch.rand([10, 10]) 1518 opt_m = torch._dynamo.optimize("eager", nopython=True)(m) 1519 # We should run compile mode firstly, otherwise the module 1520 # would be initialized when running eager mode. 1521 res = opt_m(x) 1522 ref = m(x) 1523 self.assertTrue(torch.allclose(ref, res)) 1524 1525 # RuntimeError: SymIntArrayRef expected to contain only concrete integers 1526 @expectedFailureDynamic 1527 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") 1528 def test_lazy_module3(self): 1529 m = LazyMLP() 1530 x = torch.rand([10, 10]) 1531 cnt = torch._dynamo.testing.CompileCounter() 1532 opt_m = torch._dynamo.optimize(cnt, nopython=True)(m) 1533 # first iteration 1534 res = opt_m(x) 1535 ref = m(x) 1536 self.assertTrue(torch.allclose(ref, res)) 1537 # move to cuda and second iteration 1538 m = m.to("cuda") 1539 x = x.to("cuda") 1540 res = opt_m(x) 1541 ref = m(x) 1542 self.assertTrue(torch.allclose(ref, res)) 1543 self.assertEqual(cnt.frame_count, 2) 1544 1545 # RuntimeError: SymIntArrayRef expected to contain only concrete integers 1546 @expectedFailureDynamic 1547 def test_lazy_module4(self): 1548 m = LazyMLP() 1549 x = torch.rand([10, 10]) 1550 cnt = torch._dynamo.testing.CompileCounter() 1551 opt_m = torch._dynamo.optimize(cnt, nopython=True)(m) 1552 # first iteration 1553 res = opt_m(x) 1554 ref = m(x) 1555 self.assertTrue(torch.allclose(ref, res)) 1556 # input shape changed and second iteration 1557 x = torch.rand([20, 20]) 1558 try: 1559 opt_m(x) 1560 except RuntimeError: 1561 self.assertIn("must have same reduction dim", traceback.format_exc()) 1562 1563 # RuntimeError: SymIntArrayRef expected to contain only concrete integers 1564 @expectedFailureDynamic 1565 def test_lazy_module5(self): 1566 # Test lazy module works well with list/tuple input 1567 m = LazyModuleWithListInput() 1568 x = [torch.rand([5, 5])] * 3 + [None] 1569 opt_m = torch._dynamo.optimize("eager", nopython=True)(m) 1570 res = opt_m(x) 1571 ref = m(x) 1572 self.assertTrue(torch.allclose(ref, res)) 1573 1574 # RuntimeError: SymIntArrayRef expected to contain only concrete integers 1575 @expectedFailureDynamic 1576 def test_lazy_module6(self): 1577 # Test new lazy submodule in lazy module's initialize_parameters 1578 m = LazyModuleWithLazySubmodule() 1579 x = [torch.rand([5, 5])] * 3 1580 opt_m = torch._dynamo.optimize("eager", nopython=True)(m) 1581 res = opt_m(x) 1582 ref = m(x) 1583 self.assertTrue(torch.allclose(ref, res)) 1584 1585 # RuntimeError: SymIntArrayRef expected to contain only concrete integers 1586 @expectedFailureDynamic 1587 def test_lazy_module7(self): 1588 # Test lazy module works well with namedtuple/dict input 1589 m = LazyModuleWithNamedTupleInput() 1590 x = MyInput( 1591 x={"a": [torch.rand([5, 5])] * 3, "b": torch.rand([5, 5])}, 1592 y=torch.rand([5, 5]), 1593 ) 1594 opt_m = torch.compile(backend="eager", fullgraph=True)(m) 1595 res = opt_m(x) 1596 ref = m(x) 1597 self.assertTrue(torch.allclose(ref, res)) 1598 1599 def test_lazy_module_no_cls_to_become(self): 1600 # make sure super() works in the case where cls_to_become is None 1601 m = LazyChildModuleNoClsToBecome() 1602 x = torch.rand(2, 2) 1603 opt_m = torch._dynamo.optimize("eager", nopython=True)(m) 1604 res = opt_m(x) 1605 ref = m(x) 1606 self.assertTrue(torch.allclose(ref, res)) 1607 1608 def test_lazy_module_kwargs(self): 1609 m = LazyModuleKwArgs() 1610 x = [torch.rand([5, 5])] * 3 1611 y = [torch.rand([5, 5])] * 2 1612 opt_m = torch.compile(backend="eager", fullgraph=True)(m) 1613 exp_res = m(x, y) 1614 self.assertTrue(torch.allclose(exp_res, opt_m(x, y))) 1615 1616 def test_call_fn_with_non_const_inputs_safe(self): 1617 class ModuleSpecialFwd(torch.nn.Module): 1618 def __init__(self) -> None: 1619 super().__init__() 1620 self.conv = torch.nn.Conv2d( 1621 in_channels=3, out_channels=20, kernel_size=(5, 5) 1622 ) 1623 1624 def _conv_forward(self, x): 1625 return self.conv._conv_forward(x, self.conv.weight, self.conv.bias) 1626 1627 def forward(self, x): 1628 return self._conv_forward(x) 1629 1630 mod = ModuleSpecialFwd() 1631 rx = torch.randn([3, 10, 10]) 1632 real = mod(rx) 1633 graph, _ = torch._dynamo.export(mod)(rx) 1634 self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) 1635 1636 def test_conv_call_forward_directly(self): 1637 m = ConvCallForwardDirectly() 1638 x = torch.rand([4, 3, 9, 9]) 1639 ref = m(x) 1640 opt_m = torch.compile(backend="eager", fullgraph=True)(m) 1641 res = opt_m(x) 1642 self.assertTrue(torch.allclose(ref, res)) 1643 1644 def test_conv_transpose_call_forward_directly(self): 1645 m = ConvTransposeCallForwardDirectly() 1646 x = torch.rand([4, 4, 4, 4]) 1647 ref = m(x) 1648 opt_m = torch.compile(backend="eager", fullgraph=True)(m) 1649 res = opt_m(x) 1650 self.assertTrue(torch.allclose(ref, res)) 1651 1652 def test_conv_call_super_forward_directly(self): 1653 x = torch.randn(4, 4) 1654 m = ConvCallSuperForwardDirectly(4, 4, 4) 1655 ref = m(x) 1656 opt_m = torch.compile(backend="eager", fullgraph=True)(m) 1657 res = opt_m(x) 1658 self.assertTrue(torch.allclose(ref, res)) 1659 1660 def test_conv_transpose_call_super_forward_directly(self): 1661 x = torch.randn(4, 4, 4) 1662 m = ConvTransposeCallSuperForwardDirectly(4, 4, 4) 1663 ref = m(x) 1664 opt_m = torch.compile(backend="eager", fullgraph=True)(m) 1665 res = opt_m(x) 1666 self.assertTrue(torch.allclose(ref, res)) 1667 1668 1669class MockModule(torch.nn.Module): 1670 def __init__(self) -> None: 1671 super().__init__() 1672 self.relu = torch.nn.ReLU() 1673 self.linear = torch.nn.Linear(10, 10) 1674 self.buf0 = torch.nn.Buffer(torch.randn(10, 10)) 1675 1676 def forward(self, x): 1677 return self.relu(self.linear(x) + self.buf0) 1678 1679 1680class OptimizedModuleTest(torch._dynamo.test_case.TestCase): 1681 def test_nn_module(self): 1682 mod = MockModule() 1683 cnt = torch._dynamo.testing.CompileCounter() 1684 opt_mod = torch._dynamo.optimize(cnt)(mod) 1685 self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) 1686 1687 x = torch.randn(10, 10) 1688 self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) 1689 self.assertEqual(cnt.frame_count, 1) 1690 1691 @torch._dynamo.config.patch(guard_nn_modules=True) 1692 def test_attr_precedence(self): 1693 class Mod(torch.nn.Module): 1694 def __init__(self) -> None: 1695 super().__init__() 1696 self.a = 3 1697 1698 def forward(self, x, c=4): 1699 return x * c 1700 1701 def linear(self, x): 1702 return x 1703 1704 def b(self, x): 1705 raise RuntimeError("Should not be called") 1706 1707 class MyMod(Mod): 1708 def __init__(self) -> None: 1709 super().__init__() 1710 self.linear = torch.nn.Linear(11, 11) 1711 self.a = 2 1712 self.b = 2 1713 self.scale = 1 1714 1715 def scale(self, x): 1716 # Should not be called because it is shadowed by the instance 1717 # attribute 1718 raise RuntimeError("Should not be called") 1719 1720 def forward(self, x, c=None): 1721 return self.linear(x) * self.a * self.b * self.scale 1722 1723 mod = MyMod() 1724 x = torch.ones(3, 3) 1725 ref = mod(x) 1726 1727 cnts = torch._dynamo.testing.CompileCounter() 1728 opt_mod = torch.compile(mod, backend=cnts) 1729 opt_mod(torch.ones(3, 3)) 1730 res = opt_mod(torch.ones(3, 3)) 1731 1732 self.assertEqual(cnts.frame_count, 1) 1733 self.assertEqual(ref, res) 1734 1735 def test_to(self): 1736 mod = MockModule() 1737 cnt = torch._dynamo.testing.CompileCounter() 1738 opt_mod = torch._dynamo.optimize(cnt)(mod) 1739 x = torch.randn(10, 10) 1740 self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) 1741 self.assertEqual(cnt.frame_count, 1) 1742 1743 # Ensure that there is no recompilation 1744 opt_mod(x) 1745 self.assertEqual(cnt.frame_count, 1) 1746 1747 opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64) 1748 self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) 1749 x = torch.randn(10, 10).to(dtype=torch.float64) 1750 opt_mod(x) 1751 # Ensure that there is a recompilation 1752 self.assertEqual(cnt.frame_count, 2) 1753 1754 # Ensure that there is no recompilation 1755 opt_mod(x) 1756 self.assertEqual(cnt.frame_count, 2) 1757 1758 torch._dynamo.reset() 1759 opt_mod(x) 1760 self.assertEqual(cnt.frame_count, 3) 1761 1762 @torch._dynamo.config.patch(guard_nn_modules=True) 1763 def test_param_order(self): 1764 class MyModule(torch.nn.Module): 1765 def __init__(self) -> None: 1766 super().__init__() 1767 self.param1 = torch.nn.Parameter(torch.ones([1])) 1768 self.param2 = torch.nn.Parameter(torch.ones([2])) 1769 1770 def forward(self, x): 1771 return x 1772 1773 mod = MyModule() 1774 coeffs = [2, 3] 1775 1776 def fn(x): 1777 for idx, p in enumerate(mod.parameters()): 1778 x += p.sum() * coeffs[idx] 1779 1780 for idx, p in enumerate(mod.named_parameters()): 1781 x += p[1].sum() * coeffs[idx] 1782 1783 return x 1784 1785 ref = fn(torch.ones(1)) 1786 cnts = torch._dynamo.testing.CompileCounter() 1787 opt_fn = torch._dynamo.optimize(cnts)(fn) 1788 res = opt_fn(torch.ones(1)) 1789 1790 self.assertEqual(ref, res) 1791 self.assertEqual(cnts.frame_count, 1) 1792 1793 mod._parameters["param1"] = mod._parameters.pop("param1") 1794 ref = fn(torch.ones(1)) 1795 res = opt_fn(torch.ones(1)) 1796 1797 self.assertEqual(ref, res) 1798 self.assertEqual(cnts.frame_count, 2) 1799 1800 @torch._dynamo.config.patch(guard_nn_modules=True) 1801 def test_buffer_order(self): 1802 class MyModule(torch.nn.Module): 1803 def __init__(self) -> None: 1804 super().__init__() 1805 self.b1 = torch.nn.Buffer(torch.ones([1])) 1806 self.b2 = torch.nn.Buffer(torch.ones([2])) 1807 1808 def forward(self, x): 1809 return x 1810 1811 mod = MyModule() 1812 coeffs = [2, 3] 1813 1814 def fn(x): 1815 for idx, p in enumerate(mod.buffers()): 1816 x += p.sum() * coeffs[idx] 1817 1818 for idx, p in enumerate(mod.named_buffers()): 1819 x += p[1].sum() * coeffs[idx] 1820 1821 return x 1822 1823 ref = fn(torch.ones(1)) 1824 cnts = torch._dynamo.testing.CompileCounter() 1825 opt_fn = torch._dynamo.optimize(cnts)(fn) 1826 res = opt_fn(torch.ones(1)) 1827 1828 self.assertEqual(ref, res) 1829 self.assertEqual(cnts.frame_count, 1) 1830 1831 mod._buffers["b1"] = mod._buffers.pop("b1") 1832 ref = fn(torch.ones(1)) 1833 res = opt_fn(torch.ones(1)) 1834 1835 self.assertEqual(ref, res) 1836 self.assertEqual(cnts.frame_count, 2) 1837 1838 @torch._dynamo.config.patch(guard_nn_modules=True) 1839 def test_module_order(self): 1840 class MyModule(torch.nn.Module): 1841 def __init__(self) -> None: 1842 super().__init__() 1843 self.linear1 = torch.nn.Linear(3, 3) 1844 self.linear2 = torch.nn.Linear(10, 10) 1845 1846 def forward(self, x): 1847 return x 1848 1849 mod = MyModule() 1850 coeffs = [2, 3, 4] 1851 1852 coeffs_for_mod = {mod: 10, mod.linear1: 20, mod.linear2: 30} 1853 1854 # Check order of _modules 1855 def fn(x): 1856 for idx, p in enumerate(mod.modules()): 1857 # Something silly to force depedency on the order 1858 x += coeffs_for_mod[p] * coeffs[idx] 1859 for idx, p in enumerate(mod.named_modules()): 1860 x += coeffs_for_mod[p[1]] * coeffs[idx] 1861 for idx, p in enumerate(mod.children()): 1862 x += coeffs_for_mod[p] * coeffs[idx] 1863 for idx, p in enumerate(mod.named_children()): 1864 x += coeffs_for_mod[p[1]] * coeffs[idx] 1865 return x 1866 1867 ref = fn(torch.ones(1)) 1868 cnts = torch._dynamo.testing.CompileCounter() 1869 opt_fn = torch._dynamo.optimize(cnts)(fn) 1870 res = opt_fn(torch.ones(1)) 1871 1872 self.assertEqual(ref, res) 1873 self.assertEqual(cnts.frame_count, 1) 1874 1875 mod._modules["linear1"] = mod._modules.pop("linear1") 1876 ref = fn(torch.ones(1)) 1877 res = opt_fn(torch.ones(1)) 1878 1879 self.assertEqual(ref, res) 1880 self.assertEqual(cnts.frame_count, 2) 1881 1882 def test_attr(self): 1883 class MockModule(torch.nn.Module): 1884 def __init__(self) -> None: 1885 super().__init__() 1886 self.linear = torch.nn.Linear(10, 10) 1887 self.buf0 = torch.nn.Buffer(torch.randn(10, 10)) 1888 1889 def forward(self, x): 1890 return self.r(torch.sin(x)) + self.buf0 1891 1892 mod = MockModule() 1893 opt_mod = torch._dynamo.optimize("eager")(mod) 1894 1895 # Check parameters and buffers 1896 for p1, p2 in zip(mod.parameters(), opt_mod.parameters()): 1897 self.assertTrue(id(p1) == id(p2)) 1898 for b1, b2 in zip(mod.buffers(), opt_mod.buffers()): 1899 self.assertTrue(id(b1) == id(b2)) 1900 1901 def get_parameter_dtype(mod: torch.nn.Module): 1902 parameters_and_buffers = itertools.chain(mod.parameters(), mod.buffers()) 1903 return next(parameters_and_buffers).dtype 1904 1905 opt_mod = torch._dynamo.optimize("eager")(get_parameter_dtype) 1906 out_dtype = opt_mod(mod) 1907 self.assertEqual(out_dtype, torch.float32) 1908 1909 def test_dir(self): 1910 class MockModule(torch.nn.Module): 1911 def __init__(self) -> None: 1912 super().__init__() 1913 self.linear = torch.nn.Linear(10, 10) 1914 self.buf0 = torch.nn.Buffer(torch.nn.Buffer(torch.randn(10, 10))) 1915 self.register_parameter( 1916 name="param0", param=torch.nn.Parameter(torch.randn(10, 10)) 1917 ) 1918 1919 def forward(self, x): 1920 return self.r(torch.sin(x)) + self.buf0 1921 1922 mod = MockModule() 1923 mod_keys = dir(mod) 1924 opt_mod = torch._dynamo.optimize("eager")(mod) 1925 opt_mod_keys = dir(opt_mod) 1926 1927 # Check user-defined attributes, parameters and buffers 1928 self.assertIn("linear", opt_mod_keys) 1929 self.assertIn("buf0", opt_mod_keys) 1930 self.assertIn("param0", opt_mod_keys) 1931 1932 # Check all attributes, parameters and buffers 1933 self.assertTrue(len(set(mod_keys).difference(opt_mod_keys)) == 0) 1934 1935 def test_no_recompile_on_nn_guarded_modules(self): 1936 size = (10, 10) 1937 cache_size_limit = 1 1938 num_submodules = 4 1939 cnts = torch._dynamo.testing.CompileCounterWithBackend("eager") 1940 1941 class SubModule(torch.nn.Module): 1942 def __init__(self) -> None: 1943 super().__init__() 1944 self.linear = torch.nn.Linear(*size) 1945 1946 def forward(self, x): 1947 a = torch.sin(torch.cos(x)) 1948 return self.linear(a) 1949 1950 class MockModule(torch.nn.Module): 1951 def __init__(self) -> None: 1952 super().__init__() 1953 self.mods = [SubModule() for _ in range(num_submodules)] 1954 self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods] 1955 1956 def forward(self, x): 1957 for mod in self.mods: 1958 x = mod(x) 1959 return x 1960 1961 mod = MockModule() 1962 # Each submod is compiled separately and has a different nn module 1963 # guard. Ensure that recompilation logic is handle correctly. 1964 with unittest.mock.patch( 1965 "torch._dynamo.config.error_on_recompile", True 1966 ), unittest.mock.patch( 1967 "torch._dynamo.config.cache_size_limit", 1968 cache_size_limit, 1969 ): 1970 x = torch.randn(*size, requires_grad=True) 1971 mod(x) 1972 if torch._dynamo.config.inline_inbuilt_nn_modules: 1973 self.assertEqual(cnts.frame_count, 1) 1974 else: 1975 self.assertEqual(cnts.frame_count, num_submodules) 1976 1977 @patch.object(torch._dynamo.config, "accumulated_cache_size_limit", 2) 1978 @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False) 1979 def test_recompile_limit_on_freed_module(self): 1980 class Mod(torch.nn.Module): 1981 def __init__(self) -> None: 1982 super().__init__() 1983 self.lin = torch.nn.Linear(5, 5) 1984 1985 def forward(self, x): 1986 return self.lin(x) 1987 1988 def fn(x, mod): 1989 return mod(x) 1990 1991 cnts = torch._dynamo.testing.CompileCounterWithBackend("eager") 1992 opt_mod = torch.compile(fn, backend=cnts) 1993 for i in range(8): 1994 mod = Mod() 1995 opt_mod(torch.randn(5, 5), mod) 1996 1997 # fn compiles twice 1998 self.assertEqual(cnts.frame_count, 2) 1999 2000 @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True) 2001 def test_inline_inbuilt_nn_modules(self): 2002 size = (10, 10) 2003 cache_size_limit = 1 2004 num_submodules = 4 2005 cnts = torch._dynamo.testing.CompileCounterWithBackend("eager") 2006 2007 class SubModule(torch.nn.Module): 2008 def __init__(self) -> None: 2009 super().__init__() 2010 self.linear = torch.nn.Linear(*size) 2011 2012 def forward(self, x): 2013 a = torch.sin(torch.cos(x)) 2014 return self.linear(a) 2015 2016 class MockModule(torch.nn.Module): 2017 def __init__(self) -> None: 2018 super().__init__() 2019 self.mods = [SubModule() for _ in range(num_submodules)] 2020 self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods] 2021 2022 def forward(self, x): 2023 for mod in self.mods: 2024 x = mod(x) 2025 return x 2026 2027 mod = MockModule() 2028 # Each submod is compiled separately and has a different nn module 2029 # guard. Ensure that recompilation logic is handle correctly. 2030 with unittest.mock.patch( 2031 "torch._dynamo.config.error_on_recompile", True 2032 ), unittest.mock.patch( 2033 "torch._dynamo.config.cache_size_limit", 2034 cache_size_limit, 2035 ): 2036 x = torch.randn(*size, requires_grad=True) 2037 mod(x) 2038 self.assertEqual(cnts.frame_count, 1) 2039 2040 def test_cache_size_limit_on_guarded_nn_modules(self): 2041 cache_size_limit = 2 2042 num_submodules = 4 2043 cnts = torch._dynamo.testing.CompileCounterWithBackend("eager") 2044 2045 class SubModule(torch.nn.Module): 2046 def __init__(self) -> None: 2047 super().__init__() 2048 self.relu = torch.nn.ReLU() 2049 2050 def forward(self, x): 2051 a = torch.sin(torch.cos(x)) 2052 return self.relu(a) 2053 2054 class MockModule(torch.nn.Module): 2055 def __init__(self) -> None: 2056 super().__init__() 2057 self.mods = [SubModule() for _ in range(num_submodules)] 2058 self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods] 2059 2060 def forward(self, x): 2061 for mod in self.mods: 2062 x = mod(x) 2063 return x 2064 2065 mod = MockModule() 2066 # For the third iteration, we would reach the cache size limit, and 2067 # therefore the total number of expected frame count is 2 * 2068 # num_submodules. 2069 with unittest.mock.patch( 2070 "torch._dynamo.config.cache_size_limit", 2071 cache_size_limit, 2072 ): 2073 for size in [ 2074 (4,), 2075 (4, 4), 2076 (4, 4, 4), 2077 ]: 2078 x = torch.randn(size) 2079 mod(x) 2080 if torch._dynamo.config.inline_inbuilt_nn_modules: 2081 self.assertEqual(cnts.frame_count, 2) 2082 else: 2083 self.assertEqual(cnts.frame_count, 2 * num_submodules) 2084 2085 def test_recursion(self): 2086 mod = MockModule() 2087 cnt = torch._dynamo.testing.CompileCounter() 2088 opt_mod = torch._dynamo.optimize(cnt)(mod) 2089 2090 for _ in range(5): 2091 opt_mod = torch._dynamo.optimize(cnt)(opt_mod) 2092 opt_mod(torch.randn(10, 10)) 2093 self.assertEqual(cnt.frame_count, 1) 2094 2095 def test_composition(self): 2096 class InnerModule(torch.nn.Module): 2097 def __init__(self) -> None: 2098 super().__init__() 2099 self.relu = torch.nn.ReLU() 2100 2101 def forward(self, x): 2102 return self.relu(torch.sin(x)) 2103 2104 opt_inner_mod = InnerModule() 2105 2106 class OuterModule(torch.nn.Module): 2107 def __init__(self) -> None: 2108 super().__init__() 2109 self.mod = opt_inner_mod 2110 2111 def forward(self, x): 2112 return self.mod(torch.cos(x)) 2113 2114 outer_mod = OuterModule() 2115 cnt = torch._dynamo.testing.CompileCounter() 2116 opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) 2117 2118 x = torch.randn(4) 2119 self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) 2120 self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) 2121 self.assertEqual(cnt.frame_count, 1) 2122 2123 def test_composition_with_opt_mod(self): 2124 class InnerModule(torch.nn.Module): 2125 def __init__(self) -> None: 2126 super().__init__() 2127 self.relu = torch.nn.ReLU() 2128 2129 def forward(self, x): 2130 return self.relu(torch.sin(x)) 2131 2132 inner_mod = InnerModule() 2133 cnt = torch._dynamo.testing.CompileCounter() 2134 opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod) 2135 2136 class OuterModule(torch.nn.Module): 2137 def __init__(self) -> None: 2138 super().__init__() 2139 self.mod = opt_inner_mod 2140 2141 def forward(self, x): 2142 return self.mod(torch.cos(x)) 2143 2144 outer_mod = OuterModule() 2145 opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) 2146 2147 x = torch.randn(4) 2148 self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) 2149 self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) 2150 # There will be a graph break for the inner mod being OptimizedModule 2151 self.assertEqual(cnt.frame_count, 2) 2152 2153 def test_module_patch(self): 2154 mod = ModulePatch1() 2155 mod.forward = types.MethodType(ModulePatch2.forward, mod) 2156 2157 def fn(x): 2158 return mod(x) 2159 2160 self.assertTrue( 2161 torch.allclose( 2162 torch._dynamo.optimize("eager", nopython=True)(fn)(torch.ones(10)), 2163 torch.zeros(1), 2164 ) 2165 ) 2166 2167 @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False) 2168 def test_hooks_outer(self): 2169 class TestModule(torch.nn.Module): 2170 def forward(self, x: torch.Tensor) -> torch.Tensor: 2171 return 2 * x + 1 2172 2173 m = TestModule() 2174 2175 def forward_hook( 2176 module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor 2177 ) -> torch.Tensor: 2178 return 2 * output + 1 2179 2180 handle = m.register_forward_hook(forward_hook) 2181 inp = torch.tensor(1.0, requires_grad=True) 2182 2183 failure_reason = None 2184 2185 def guard_fail_fn(failure): 2186 nonlocal failure_reason 2187 failure_reason = failure[0] 2188 2189 compiled_m = torch._dynamo.optimize( 2190 guard_fail_fn=guard_fail_fn, backend="eager" 2191 )(m) 2192 2193 self.assertEqual(compiled_m(inp), m(inp)) 2194 self.assertEqual(compiled_m(inp).item(), 7) 2195 self.assertTrue(failure_reason is None) 2196 2197 # what if we remove our hook? we should recompile? 2198 handle.remove() 2199 self.assertEqual(compiled_m(inp), m(inp)) 2200 self.assertEqual(compiled_m(inp).item(), 3) 2201 # self.assertTrue(failure_reason == "hook") 2202 2203 """ 2204 Summary: 2205 - removing a hook doesn't fail a guard, because we weren't compiling the hook 2206 (at least into the same graph) as forward in the first place! We do correctly 2207 omit calling the removed hook, but since this hook is a post forward hook, 2208 the 'RETURN' from forward is breaking the graph. 2209 2210 Why is 'forward' the entrypoint to an InstructionTranslator, after I changed 2211 the eval_frame entrypoint to Module.__call__? 2212 """ 2213 2214 @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False) 2215 def test_hooks_inner(self): 2216 class TestModule(torch.nn.Module): 2217 def forward(self, x: torch.Tensor) -> torch.Tensor: 2218 return 2 * x + 1 2219 2220 m = TestModule() 2221 2222 def forward_hook( 2223 module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor 2224 ) -> torch.Tensor: 2225 return 2 * output + 1 2226 2227 handle = m.register_forward_hook(forward_hook) 2228 2229 def outer_func(tensor): 2230 x = tensor * 2 + 1 2231 y = m(x) 2232 return y 2233 2234 inp = torch.tensor(1.0, requires_grad=True) 2235 2236 failure_reason = None 2237 2238 def guard_fail_fn(failure): 2239 nonlocal failure_reason 2240 failure_reason = failure[0] 2241 2242 cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") 2243 compiled_func = torch._dynamo.optimize( 2244 guard_fail_fn=guard_fail_fn, 2245 backend=cc, 2246 )(outer_func) 2247 2248 self.assertEqual(compiled_func(inp), outer_func(inp)) 2249 self.assertEqual(compiled_func(inp).item(), 15) 2250 2251 # We are compiling 1 big graph for all 3 functions including the hook. 2252 self.assertEqual(cc.frame_count, 1) 2253 self.assertEqual(cc.op_count, 6) 2254 2255 # If we remove the hook, we should recompile 2256 handle.remove() 2257 self.assertEqual(compiled_func(inp), outer_func(inp)) 2258 self.assertEqual(compiled_func(inp).item(), 7) 2259 self.assertTrue("forward_hooks" in failure_reason) 2260 self.assertEqual(cc.frame_count, 1 + 1) 2261 self.assertEqual(cc.op_count, 6 + 4) 2262 2263 # what if instead of removing, we alter our hook? 2264 torch._dynamo.reset() 2265 m = TestModule() 2266 handle = m.register_forward_hook(forward_hook) 2267 failure_reason = None 2268 self.assertEqual(compiled_func(inp), outer_func(inp)) 2269 self.assertEqual(compiled_func(inp).item(), 15) 2270 2271 def new_forward_hook( 2272 module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor 2273 ) -> torch.Tensor: 2274 return 2 * output + 2 2275 2276 m._forward_hooks[handle.id] = new_forward_hook 2277 self.assertEqual(compiled_func(inp), outer_func(inp)) 2278 self.assertEqual(compiled_func(inp).item(), 16) 2279 self.assertRegex(failure_reason, r"___check_obj_id\(L\['m'\]._forward_hooks") 2280 2281 @patch.object(torch._dynamo.config, "guard_nn_modules", False) 2282 @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True) 2283 @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False) 2284 def test_hooks_skip_guards(self): 2285 class TestModule(torch.nn.Module): 2286 def forward(self, x: torch.Tensor) -> torch.Tensor: 2287 return 2 * x + 1 2288 2289 m = TestModule() 2290 2291 def forward_hook( 2292 module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor 2293 ) -> torch.Tensor: 2294 return 2 * output + 1 2295 2296 handle = m.register_forward_hook(forward_hook) 2297 2298 def outer_func(tensor): 2299 x = tensor * 2 + 1 2300 y = m(x) 2301 return y 2302 2303 inp = torch.tensor(1.0, requires_grad=True) 2304 2305 failure_reason = None 2306 2307 def guard_fail_fn(failure): 2308 nonlocal failure_reason 2309 failure_reason = failure[0] 2310 2311 cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") 2312 compiled_func = torch._dynamo.optimize( 2313 guard_fail_fn=guard_fail_fn, 2314 backend=cc, 2315 )(outer_func) 2316 2317 m = TestModule() 2318 handle = m.register_forward_hook(forward_hook) 2319 failure_reason = None 2320 self.assertEqual(compiled_func(inp), outer_func(inp)) 2321 self.assertEqual(compiled_func(inp).item(), 15) 2322 self.assertEqual(cc.frame_count, 1) 2323 self.assertEqual(cc.op_count, 6) 2324 2325 # if we remove the hook, dynamo shouldn't notice 2326 handle.remove() 2327 self.assertNotEqual(compiled_func(inp), outer_func(inp)) 2328 self.assertEqual(compiled_func(inp).item(), 15) 2329 self.assertEqual(cc.frame_count, 1) 2330 2331 def _forward_hook_test_helper(self, model): 2332 forward_handles = {} 2333 compiled_activations = {} 2334 eager_activations = {} 2335 activations = None 2336 2337 def save_activations(name, mod, inp, out): 2338 activations[name] = inp 2339 2340 for name, module in model.named_modules(): 2341 forward_handles[name] = module.register_forward_hook( 2342 partial(save_activations, name) 2343 ) 2344 2345 compiled_model = torch.compile(model, backend="aot_eager") 2346 2347 activations = compiled_activations 2348 for i in range(2): 2349 # second iteration is key, hooks would have fired during aot trace 2350 # on first iter 2351 compiled_activations.clear() 2352 x = torch.randn((20, 10)) 2353 pred = compiled_model(x) 2354 loss = pred.sum() 2355 loss.backward() 2356 2357 activations = eager_activations 2358 for i in range(2): 2359 # second iteration is key, hooks would have fired during aot trace 2360 # on first iter 2361 eager_activations.clear() 2362 x = torch.randn((20, 10)) 2363 pred = model(x) 2364 loss = pred.sum() 2365 loss.backward() 2366 2367 print(f"Recorded Layers: {compiled_activations.keys()}\n\n") 2368 print(f"Expected Layers: {eager_activations.keys()}") 2369 2370 self.assertTrue(compiled_activations.keys() == eager_activations.keys()) 2371 self.assertTrue(activations.keys() == forward_handles.keys()) 2372 2373 def test_hooks_allowed_modules(self): 2374 # this test shouldn't care whether hook guards are enabled or not 2375 class ToyModel(torch.nn.Module): 2376 def __init__(self) -> None: 2377 super().__init__() 2378 self.net = torch.nn.Sequential( 2379 *[torch.nn.Linear(10, 10000), torch.nn.ReLU()] 2380 + [torch.nn.Linear(10000, 5), torch.nn.ReLU()] 2381 ) 2382 2383 def forward(self, x): 2384 return self.net(x) 2385 2386 model = ToyModel() 2387 self._forward_hook_test_helper(model) 2388 2389 def test_hooks_allowed_modules_compiles(self): 2390 class ToyModel(torch.nn.Module): 2391 def __init__(self) -> None: 2392 super().__init__() 2393 self.net = torch.nn.Sequential( 2394 *[torch.nn.Linear(10, 10000), torch.nn.ReLU()] 2395 + [torch.nn.Linear(10000, 5), torch.nn.ReLU()] 2396 ) 2397 2398 def forward(self, x): 2399 return self.net(x) 2400 2401 model = ToyModel() 2402 activations = [] 2403 2404 def save_activations(mod, inp, out): 2405 activations.append(inp) 2406 2407 for name, module in model.named_modules(): 2408 module.register_forward_hook(save_activations) 2409 2410 cnt = torch._dynamo.testing.CompileCounter() 2411 model = torch._dynamo.optimize(cnt, nopython=True)(model) 2412 for i in range(2): 2413 # second iteration is key, hooks would have fired during aot trace 2414 # on first iter 2415 activations.clear() 2416 x = torch.randn((20, 10)) 2417 pred = model(x) 2418 loss = pred.sum() 2419 loss.backward() 2420 self.assertEqual(len(activations), 6) 2421 self.assertEqual(cnt.frame_count, 1) 2422 2423 def test_hooks_allowed_modules_compiles_self_contained(self): 2424 class ToyModel(torch.nn.Module): 2425 def __init__(self) -> None: 2426 super().__init__() 2427 self.net = torch.nn.Sequential( 2428 *[torch.nn.Linear(10, 10000), torch.nn.ReLU()] 2429 + [torch.nn.Linear(10000, 5), torch.nn.ReLU()] 2430 ) 2431 2432 def forward(self, x): 2433 return self.net(x) * self.net(x) 2434 2435 model = ToyModel() 2436 forward_handles = {} 2437 2438 def output_modifying_hook(mod, inp, out): 2439 return 2 * out + 1 2440 2441 for name, module in model.named_modules(): 2442 forward_handles[name] = module.register_forward_hook(output_modifying_hook) 2443 2444 cnt = torch._dynamo.testing.CompileCounter() 2445 2446 x = torch.randn((20, 10)) 2447 pred_eager = model(x) 2448 loss_eager = pred_eager.sum() 2449 eager_loss_bwd = loss_eager.backward() 2450 2451 model = torch._dynamo.optimize(cnt, nopython=True)(model) 2452 pred = model(x) 2453 2454 loss = pred.sum() 2455 loss_bwd = loss.backward() 2456 2457 self.assertEqual(eager_loss_bwd, loss_bwd) 2458 self.assertEqual(cnt.frame_count, 2) 2459 2460 # Ndim change, recompile 2461 pred = model(torch.randn([10, 10, 10])) 2462 self.assertEqual(cnt.frame_count, 4) 2463 2464 # Stable 2465 pred = model(torch.randn([10, 10, 10])) 2466 self.assertEqual(cnt.frame_count, 4) 2467 2468 def test_dunder_call_explicitly(self): 2469 # hooks should be triggered if explicit calling `__call__` 2470 class ToyModel(torch.nn.Module): 2471 def __init__(self) -> None: 2472 super().__init__() 2473 self.linear = torch.nn.Linear(10, 10000) 2474 2475 def forward(self, x): 2476 return self.linear.__call__(x) 2477 2478 model = ToyModel() 2479 self._forward_hook_test_helper(model) 2480 2481 def test_backward_hooks(self): 2482 # this test shouldn't care whether hook guards are enabled or not 2483 2484 class CustomLinear(torch.nn.Module): 2485 # not an 'allowed module', so should not graph-break 2486 def __init__(self, a, b): 2487 super().__init__() 2488 self.weight = torch.nn.Parameter(torch.randn(a, b)) 2489 2490 def forward(self, x): 2491 return torch.mm(x, self.weight) 2492 2493 class ToyModel(torch.nn.Module): 2494 def __init__(self) -> None: 2495 super().__init__() 2496 self.net = torch.nn.Sequential( 2497 *[CustomLinear(10, 10)] 2498 + [CustomLinear(10, 10000)] 2499 + [CustomLinear(10000, 5)] 2500 ) 2501 2502 def forward(self, x): 2503 return self.net(x) 2504 2505 model = ToyModel() 2506 backward_hook_handles = {} 2507 pre_backward_hook_handles = {} 2508 2509 grad_sizes = {} 2510 2511 def backward_hook(name, mod, grad_inp, grad_out): 2512 grad_sizes[name] = ( 2513 (gi.shape for gi in grad_inp), 2514 (go.shape for go in grad_out), 2515 ) 2516 return None 2517 2518 pre_grad_sizes = {} 2519 2520 def backward_pre_hook(name, mod, grad_out): 2521 pre_grad_sizes[name] = (go.shape for go in grad_out) 2522 return None 2523 2524 for name, module in model.named_modules(): 2525 backward_hook_handles[name] = module.register_full_backward_hook( 2526 partial(backward_hook, name) 2527 ) 2528 2529 pre_backward_hook_handles[name] = module.register_full_backward_pre_hook( 2530 partial(backward_pre_hook, name) 2531 ) 2532 2533 model = torch.compile(model, backend="aot_eager") 2534 2535 for i in range(2): 2536 # second iteration is key, hooks would have fired during aot trace 2537 # on first iter 2538 x = torch.randn((20, 10)) 2539 pred = model(x) 2540 loss = pred.sum() 2541 loss.backward() 2542 2543 self.assertTrue(grad_sizes.keys() == backward_hook_handles.keys()) 2544 self.assertTrue(pre_grad_sizes.keys() == pre_backward_hook_handles.keys()) 2545 2546 def test_udo_instance_method_as_hook(self): 2547 class CustomClass: 2548 def __init__(self, module): 2549 self.module = module 2550 self.handle = self.module.register_forward_pre_hook( 2551 self.func1, prepend=True, with_kwargs=True 2552 ) 2553 2554 def func1(self, module, args, kwargs): 2555 return (args[0] + 1,), kwargs 2556 2557 def __call__(self, x): 2558 return self.module(x) 2559 2560 class ToyModel(torch.nn.Module): 2561 def __init__(self) -> None: 2562 super().__init__() 2563 2564 def forward(self, x): 2565 return x * x 2566 2567 model = ToyModel() 2568 x = torch.zeros((3, 4)) 2569 obj = CustomClass(model) 2570 out = torch.compile(obj, fullgraph=True)(x) 2571 self.assertEqual(out, (x + 1) * (x + 1)) 2572 2573 def test_module_dict_iter_name(self): 2574 class MyModule(torch.nn.Module): 2575 def __init__(self) -> None: 2576 super().__init__() 2577 self.activations = torch.nn.ModuleDict( 2578 [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]] 2579 ) 2580 2581 def forward(self, x): 2582 for activation_name in self.activations: 2583 x = self.activations[activation_name](x) 2584 return x 2585 2586 cnt = torch._dynamo.testing.CompileCounter() 2587 # Eager 2588 eager_res = MyModule()(torch.ones(10, 10)) 2589 2590 # Compile 2591 optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10)) 2592 self.assertEqual(eager_res, optim_res) 2593 self.assertEqual(cnt.frame_count, 1) 2594 2595 def test_module_dict_iter_keys(self): 2596 class MyModule(torch.nn.Module): 2597 def __init__(self) -> None: 2598 super().__init__() 2599 self.activations = torch.nn.ModuleDict( 2600 [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]] 2601 ) 2602 2603 def forward(self, x): 2604 for activation_name in self.activations.keys(): 2605 x = self.activations[activation_name](x) 2606 return x 2607 2608 cnt = torch._dynamo.testing.CompileCounter() 2609 # Eager 2610 eager_res = MyModule()(torch.ones(10, 10)) 2611 2612 # Compile 2613 optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10)) 2614 self.assertEqual(eager_res, optim_res) 2615 self.assertEqual(cnt.frame_count, 1) 2616 2617 def test_module_setattr(self): 2618 models = torch.nn.Sequential(torch.nn.Linear(3, 3)) 2619 models[0].abc = False 2620 2621 def run(): 2622 models[0].abc = True 2623 x = torch.randn(1, 3) 2624 return models(x) 2625 2626 run = torch.compile(run, fullgraph=True) 2627 run() 2628 self.assertTrue(models[0].abc) 2629 2630 def test_assign_does_not_exist(self): 2631 class MyModule(torch.nn.Module): 2632 def forward(self, x): 2633 self.text_encoding = x + 1 2634 return self.text_encoding 2635 2636 mod = MyModule() 2637 out = torch.compile(mod, fullgraph=True)(torch.randn(10)) 2638 assert mod.text_encoding is out 2639 2640 def test_module_dict_iter_values(self): 2641 class MyModule(torch.nn.Module): 2642 def __init__(self) -> None: 2643 super().__init__() 2644 self.activations = torch.nn.ModuleDict( 2645 [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]] 2646 ) 2647 2648 def forward(self, x): 2649 for activation in self.activations.values(): 2650 x = activation(x) 2651 return x 2652 2653 cnt = torch._dynamo.testing.CompileCounter() 2654 # Eager 2655 eager_res = MyModule()(torch.ones(10, 10)) 2656 2657 # Compile 2658 optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10)) 2659 self.assertEqual(eager_res, optim_res) 2660 self.assertEqual(cnt.frame_count, 1) 2661 2662 def test_unspecialized_seq(self): 2663 models = torch.nn.Sequential(torch.nn.Linear(3, 3)) 2664 2665 def fn(x): 2666 models[0].training = False 2667 return models(x) 2668 2669 opt_fn = torch._dynamo.optimize("eager")(fn) 2670 x = torch.randn(1, 3) 2671 ref = fn(x) 2672 res = opt_fn(x) 2673 self.assertEqual(ref, res) 2674 2675 def test_no_op_assignment(self): 2676 class Mod(torch.nn.Module): 2677 def __init__(self) -> None: 2678 super().__init__() 2679 self.buffer = torch.rand([4]) 2680 2681 def forward(self, x): 2682 # should be a no-op, but causes dynamo to lose the static input 2683 x = x + 1 2684 self.buffer = self.buffer.to(x) 2685 return self.buffer + x 2686 2687 compiles_without_buffers = 0 2688 2689 def debug_compile(gm, *args, **kwargs): 2690 nonlocal compiles_without_buffers 2691 compiles_without_buffers += len(list(gm.buffers())) == 0 2692 return gm 2693 2694 @torch.compile(backend=debug_compile) 2695 def foo(mod, x): 2696 return mod(x) 2697 2698 mod = Mod() 2699 foo(mod, torch.rand([4])) 2700 if torch._dynamo.config.inline_inbuilt_nn_modules: 2701 self.assertEqual(compiles_without_buffers, 1) 2702 else: 2703 self.assertEqual(compiles_without_buffers, 0) 2704 2705 foo(mod, torch.rand([4], dtype=torch.half)) 2706 if torch._dynamo.config.inline_inbuilt_nn_modules: 2707 self.assertEqual(compiles_without_buffers, 2) 2708 else: 2709 self.assertEqual(compiles_without_buffers, 1) 2710 2711 class Mod2(Mod): 2712 def __setattr__(self, name, value): 2713 return super().__setattr__(name, value) 2714 2715 foo(Mod2(), torch.rand([4])) 2716 # causes two compilations, bc unimplemented custom setattr 2717 self.assertTrue(compiles_without_buffers >= 2) 2718 2719 def test_unspec_non_inlinable_module(self): 2720 mod = UnspecNonInlinableModule() 2721 opt_fn = torch._dynamo.optimize("eager")(mod) 2722 x = torch.randn(100) 2723 actual = opt_fn(x) 2724 expected = mod(x) 2725 self.assertEqual(actual, expected) 2726 2727 @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) 2728 def test_mark_static_previously_seen_tensor(self): 2729 # This test verifies that dynamo will mark 2730 # the buffers/params of a module as static 2731 # even if this param was previously seen 2732 # (ex. as a different input) 2733 num_compiles = 0 2734 2735 def debug_compiler(gm, _): 2736 nonlocal num_compiles 2737 num_compiles += 1 2738 2739 input_nodes = [ 2740 n for n in gm.graph.nodes if n.op == "placeholder" and n.name == "l_b_" 2741 ] 2742 2743 self.assertGreater(len(input_nodes), 0) 2744 for input_node in input_nodes: 2745 self.assertEqual( 2746 input_node.meta["tensor_dict"]["_dynamo_static_input_type"], 2747 "unguarded", 2748 ) 2749 2750 return gm 2751 2752 class TestModule(torch.nn.Module): 2753 def __init__(self, buf) -> None: 2754 super().__init__() 2755 # Changing this one to nn.Buffer fails because `nn.Buffer` does a .detach() 2756 # so the value in self.tx.output.side_effects will no longer evaluate to True 2757 self.register_buffer("buf", buf) 2758 2759 def forward(self, x): 2760 return self.buf * x 2761 2762 @torch._dynamo.optimize(backend=debug_compiler) 2763 def fn(x, b, mod): 2764 z = b + 1 2765 return z * mod(x) 2766 2767 buf = torch.ones(2, 2) 2768 inp = torch.ones(2) 2769 mod = TestModule(buf) 2770 fn(inp, buf, mod) 2771 self.assertEqual(num_compiles, 1) 2772 2773 @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) 2774 def test_mark_static_nn_module_tensor(self): 2775 # This test verifies that dynamo will mark 2776 # the nn module tensor attributes as static 2777 num_compiles = 0 2778 2779 def debug_compiler(gm, _): 2780 nonlocal num_compiles 2781 num_compiles += 1 2782 2783 input_nodes = [ 2784 n 2785 for n in gm.graph.nodes 2786 if n.op == "placeholder" and n.name == "l_mod_buf" 2787 ] 2788 2789 self.assertGreater(len(input_nodes), 0) 2790 for input_node in input_nodes: 2791 self.assertEqual( 2792 input_node.meta["tensor_dict"]["_dynamo_static_input_type"], 2793 "unguarded", 2794 ) 2795 2796 return gm 2797 2798 class TestModule(torch.nn.Module): 2799 def __init__(self) -> None: 2800 super().__init__() 2801 self.buf = torch.ones(2, 2) 2802 2803 def forward(self, x): 2804 return self.buf * x 2805 2806 mod = TestModule() 2807 2808 @torch._dynamo.optimize(backend=debug_compiler) 2809 def fn(x): 2810 return x * mod(x) 2811 2812 inp = torch.ones(2) 2813 fn(inp) 2814 self.assertEqual(num_compiles, 1) 2815 2816 @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) 2817 @torch._inductor.config.patch("freezing", True) 2818 @torch.no_grad() 2819 def test_mark_static_with_freezing(self): 2820 # This test verifies that dynamo will 2821 # add buffers/params as attributes of the 2822 # graph w/ guards if freezing is enabled 2823 num_compiles = 0 2824 2825 def debug_compiler(gm, _): 2826 nonlocal num_compiles 2827 num_compiles += 1 2828 2829 input_nodes = [ 2830 n for n in gm.graph.nodes if n.op == "placeholder" and n.name == "l_b_" 2831 ] 2832 self.assertEqual(len(input_nodes), 0) 2833 self.assertEqual(len(list(gm.buffers())), 1) 2834 return gm 2835 2836 class TestModule(torch.nn.Module): 2837 def __init__(self, buf) -> None: 2838 super().__init__() 2839 self.buf = torch.nn.Buffer(buf) 2840 2841 def forward(self, x): 2842 return self.buf * x 2843 2844 @torch._dynamo.optimize(backend=debug_compiler) 2845 def fn(x, mod): 2846 return mod(x) 2847 2848 buf = torch.ones(2, 2) 2849 inp = torch.ones(2) 2850 mod = TestModule(buf) 2851 fn(inp, mod) 2852 self.assertEqual(num_compiles, 1) 2853 mod.buf = torch.rand_like(buf) 2854 fn(inp, mod) 2855 self.assertEqual(num_compiles, 2) 2856 2857 @patch.object(torch._dynamo.config, "guard_nn_modules", True) 2858 def test_guard_on_torch_nn_modules(self): 2859 # https://github.com/pytorch/pytorch/issues/110048 2860 2861 class MockModule(torch.nn.Module): 2862 def __init__(self) -> None: 2863 super().__init__() 2864 self.linear = torch.nn.Linear(10, 10) 2865 self.multiplier = 10 2866 2867 def forward(self, x): 2868 return self.linear(x) * self.multiplier 2869 2870 mod = MockModule() 2871 2872 cnt = torch._dynamo.testing.CompileCounter() 2873 2874 @torch.compile(backend=cnt) 2875 def generate(x, c): 2876 return mod(x) + c 2877 2878 for _ in range(0, 10): 2879 generate(torch.randn(10, 10), 0) 2880 generate(torch.randn(10, 10), 1) 2881 self.assertEqual(cnt.frame_count, 2) 2882 2883 # Ensure that modification in user module causes recompile 2884 mod.multiplier = 11 2885 generate(torch.randn(10, 10), 0) 2886 self.assertEqual(cnt.frame_count, 3) 2887 2888 def test_setattr_on_compiled_module(self): 2889 # https://github.com/pytorch/pytorch/issues/114844 2890 2891 class ReplayMutation(torch.nn.Module): 2892 def __init__(self, inp_size, out_size, inner_size): 2893 super().__init__() 2894 self.Linear1 = torch.nn.Linear(inp_size, inner_size) 2895 self.Linear2 = torch.nn.Linear(inner_size, out_size) 2896 self.x = None 2897 2898 def forward(self, inp): 2899 res = self.Linear1(inp) 2900 self.x = res 2901 return self.Linear2(res) 2902 2903 N, D_in, H, D_out, inner = 2, 2, 2, 2, 4 2904 model = ReplayMutation(D_in, H, inner) 2905 model2 = copy.deepcopy(model) 2906 input = torch.ones(N, D_in) 2907 2908 # Keep some intermediate value in model.x 2909 model.x = torch.tensor([[100, 100, 100, 100], [200, 200, 200, 200]]) 2910 model(input) 2911 2912 compiled_model = torch.compile(model2, backend="eager") 2913 compiled_model.x = torch.tensor([[100, 100, 100, 100], [200, 200, 200, 200]]) 2914 compiled_model(input) 2915 2916 self.assertEqual(model.x, compiled_model.x) 2917 2918 def test_globals_change_in_other_file(self): 2919 @torch.compile(backend="eager", fullgraph=True) 2920 def fn(x): 2921 update_global() 2922 a = test_functions.update_global(x) 2923 # Ensure that the updated global values are read 2924 return x * a * (_variable + _variable1 + test_functions._variable) 2925 2926 res = fn(torch.ones(10)) 2927 self.assertEqual(_variable, 1) 2928 self.assertEqual(_variable1, 1) 2929 # Ensure that the reconstructed bytecode updates the global value in the 2930 # other file. 2931 self.assertEqual(test_functions._variable, 1) 2932 self.assertEqual(res, 3 * torch.ones(10)) 2933 2934 @unittest.skipIf( 2935 "inductor" not in torch._dynamo.list_backends(), 2936 "inductor backend is not available", 2937 ) 2938 def test_save_and_load_inductor(self): 2939 mod = MockModule() 2940 opt_mod = torch.compile(mod, backend="inductor") 2941 inp = torch.randn(10, 10) 2942 opt_mod(inp) 2943 2944 with tempfile.TemporaryDirectory() as tmpdirname: 2945 torch.save(opt_mod, os.path.join(tmpdirname, "model.pt")) 2946 loaded_model = torch.load(os.path.join(tmpdirname, "model.pt")) 2947 loaded_model(inp) 2948 self.assertTrue(same_two_models(loaded_model, mod, [inp])) 2949 self.assertTrue(same_two_models(loaded_model, opt_mod, [inp])) 2950 2951 torch._dynamo.reset() # force recompiles 2952 torch._inductor.metrics.generated_kernel_count = 0 2953 loaded_model(inp) 2954 self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0) 2955 2956 def test_save_and_load_all_backends(self): 2957 mod = MockModule() 2958 inp = torch.randn(10, 10) 2959 for backend in torch._dynamo.list_backends(): 2960 try: 2961 opt_mod = torch.compile(mod, backend=backend) 2962 with tempfile.TemporaryDirectory() as tmpdirname: 2963 torch.save(opt_mod, os.path.join(tmpdirname, "model.pt")) 2964 loaded_model = torch.load(os.path.join(tmpdirname, "model.pt")) 2965 torch._dynamo.reset() # force recompiles 2966 torch._inductor.metrics.generated_kernel_count = 0 2967 opt_mod(inp) 2968 opt_success = torch._inductor.metrics.generated_kernel_count == 0 2969 torch._dynamo.reset() # force recompiles 2970 torch._inductor.metrics.generated_kernel_count = 0 2971 loaded_model(inp) 2972 loaded_success = torch._inductor.metrics.generated_kernel_count == 0 2973 self.assertEqual(opt_success, loaded_success) 2974 except torch._dynamo.exc.BackendCompilerFailed: 2975 pass 2976 2977 def test_monkeypatching_forward(self): 2978 class FakeModule(torch.nn.Module): 2979 def forward(self, x): 2980 return torch.sin(x) 2981 2982 class MyModule(torch.nn.Module): 2983 def __init__(self, x): 2984 super().__init__() 2985 2986 def forward(self, x): 2987 return torch.cos(x) 2988 2989 def helper(): 2990 torch._dynamo.reset() 2991 mod = MyModule(3) 2992 2993 def fn(x): 2994 return mod(x) 2995 2996 cnt = torch._dynamo.testing.CompileCounter() 2997 opt_fn = torch._dynamo.optimize(cnt)(fn) 2998 x = torch.randn(10) 2999 3000 opt_fn(x) 3001 opt_fn(x) 3002 self.assertEqual(cnt.frame_count, 1) 3003 3004 # Monkeypatch forward 3005 mod.forward = types.MethodType(FakeModule.forward, mod) 3006 ref = fn(x) 3007 res = opt_fn(x) 3008 self.assertEqual(ref, res) 3009 self.assertEqual(cnt.frame_count, 2) 3010 3011 helper() 3012 with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True): 3013 helper() 3014 3015 def test_user_defined_nn_module_dynamic(self): 3016 class Conv2d(torch.nn.Conv2d): 3017 def __init__(self, *args, **kwargs): 3018 super().__init__(*args, **kwargs) 3019 3020 def forward(self, x): 3021 x = torch.nn.functional.conv2d( 3022 x, 3023 self.weight, 3024 self.bias, 3025 self.stride, 3026 self.padding, 3027 self.dilation, 3028 self.groups, 3029 ) 3030 return x 3031 3032 cnts = torch._dynamo.testing.CompileCounter() 3033 mod1 = Conv2d(64, 64, kernel_size=(2, 2), stride=(1, 1)) 3034 mod2 = Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2)) 3035 mod3 = Conv2d(64, 64, kernel_size=(2, 2), stride=(3, 3)) 3036 3037 opt_mod1 = torch.compile(mod1, backend=cnts, fullgraph=True) 3038 opt_mod2 = torch.compile(mod2, backend=cnts, fullgraph=True) 3039 opt_mod3 = torch.compile(mod3, backend=cnts, fullgraph=True) 3040 3041 x = torch.randn(1, 64, 64, 64) 3042 opt_mod1(x) 3043 opt_mod2(x) 3044 opt_mod3(x) 3045 3046 # Must be 3 compilations. If not marked static there would be 2, because strides would be converted to symints. 3047 self.assertEqual(cnts.frame_count, 3) 3048 3049 3050if __name__ == "__main__": 3051 from torch._dynamo.test_case import run_tests 3052 3053 run_tests() 3054