xref: /aosp_15_r20/external/pytorch/test/dynamo/test_modules.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3import collections
4import contextlib
5import copy
6import itertools
7import os
8import tempfile
9import traceback
10import types
11import unittest
12from copy import deepcopy
13from functools import partial
14from typing import Dict, NamedTuple, Tuple
15from unittest.mock import patch
16
17import torch
18import torch._dynamo.test_case
19import torch._dynamo.testing
20import torch.nn.functional as F
21from torch._dynamo.debug_utils import same_two_models
22from torch._dynamo.eval_frame import unsupported
23from torch._dynamo.mutation_guard import GenerationTracker
24from torch._dynamo.testing import expectedFailureDynamic, same
25from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable
26from torch.nn.modules.lazy import LazyModuleMixin
27from torch.nn.parameter import Parameter, UninitializedParameter
28
29
30try:
31    from . import test_functions
32except ImportError:
33    import test_functions
34
35
36_variable = 0
37_variable1 = 0
38
39
40def update_global():
41    global _variable, _variable1
42    _variable += 1
43    _variable1 += 1
44
45
46class BasicModule(torch.nn.Module):
47    def __init__(self) -> None:
48        super().__init__()
49        self.linear1 = torch.nn.Linear(10, 10)
50        self.scale = torch.randn(1, 10)
51
52    def forward(self, x):
53        return F.relu(self.linear1(x)) * self.scale
54
55
56class FnMember(torch.nn.Module):
57    def __init__(self) -> None:
58        super().__init__()
59        self.linear1 = torch.nn.Linear(10, 10)
60        self.activation = F.relu
61
62    def forward(self, x):
63        x = self.linear1(x)
64        if self.activation:
65            x = self.activation(x)
66        return x
67
68
69class FnMemberCmp(torch.nn.Module):
70    def __init__(self, activation):
71        super().__init__()
72        self.linear1 = torch.nn.Linear(10, 10)
73        self.activation = activation
74
75    def forward(self, x):
76        x = self.linear1(x)
77        if self.activation is not None:
78            x = self.activation(x)
79        if self.activation is None:
80            x = torch.sigmoid(x)
81        return x
82
83
84class SubmoduleExample(torch.nn.Module):
85    def __init__(self) -> None:
86        super().__init__()
87        self.layer1 = BasicModule()
88        self.layer2 = BasicModule()
89        self.scale = torch.randn(1, 10)
90
91    def forward(self, x):
92        x = self.layer1(x)
93        x = self.layer2(x)
94        return x * self.scale
95
96
97class IsTrainingCheck(torch.nn.Module):
98    def __init__(self) -> None:
99        super().__init__()
100        self.linear1 = torch.nn.Linear(10, 10)
101        self.linear2 = torch.nn.Linear(10, 10)
102        self.train(True)
103
104    def forward(self, x):
105        if self.training:
106            mod = self.linear1
107        else:
108            mod = self.linear2
109        return F.relu(mod(x))
110
111
112class IsEvalCheck(IsTrainingCheck):
113    def __init__(self) -> None:
114        super().__init__()
115        self.train(False)
116
117
118class ModuleMethodCall(torch.nn.Module):
119    def __init__(self) -> None:
120        super().__init__()
121        self.layer1 = BasicModule()
122        self.layer2 = BasicModule()
123        self.scale = torch.randn(1, 10)
124
125    def call_and_scale(self, mod, x):
126        x = mod(x)
127        return x * self.scale
128
129    def forward(self, x):
130        x1 = self.call_and_scale(self.layer1, x)
131        x2 = self.call_and_scale(self.layer2, x)
132        return x1 + x2
133
134
135class UnsupportedMethodCall(torch.nn.Module):
136    def __init__(self) -> None:
137        super().__init__()
138        self.layer1 = BasicModule()
139        self.scale = torch.randn(1, 10)
140
141    def call_and_scale(self, mod, x):
142        x = mod(x)
143        x = x * self.scale
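        # `unsupported` is a helper dynamo does not trace, so this call is
        # expected to force a graph break inside the compiled region.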
144        return unsupported(x, x)
145
146    def forward(self, x):
147        x1 = self.call_and_scale(self.layer1, x)
148        return x + x1
149
150
151class UnsupportedModule(torch.nn.Module):
152    def __init__(self) -> None:
153        super().__init__()
154        self.layer1 = BasicModule()
155        self.scale = torch.randn(1, 10)
156
157    def forward(self, x):
158        x = self.layer1(x) * self.scale
159        return unsupported(x, x)
160
161
162class UnsupportedModuleCall(torch.nn.Module):
163    def __init__(self) -> None:
164        super().__init__()
165        self.mod = UnsupportedModule()
166
167    def forward(self, x):
168        return 1 + self.mod(x * 1.5)
169
170
171class ModuleWithStaticForward(torch.nn.Module):
172    @staticmethod
173    def forward(x):
174        return x * torch.sigmoid(x)
175
176
177class ModuleCallModuleWithStaticForward(torch.nn.Module):
178    def __init__(self) -> None:
179        super().__init__()
180        self.mod = ModuleWithStaticForward()
181
182    def forward(self, x):
183        return self.mod(x)
184
185
186class ModuleStaticMethodCall(torch.nn.Module):
187    def __init__(self) -> None:
188        super().__init__()
189        self.layer1 = BasicModule()
190        self.layer2 = BasicModule()
191        self.scale = torch.randn(1, 10)
192
193    @staticmethod
194    def call_and_scale(scale, mod, x):
195        x = mod(x)
196        return x * scale
197
198    def forward(self, x):
199        x1 = self.call_and_scale(self.scale, self.layer1, x)
200        x2 = self.call_and_scale(self.scale, self.layer2, x)
201        return x1 + x2
202
203
204class ModuleClassMethodCall(torch.nn.Module):
205    def __init__(self) -> None:
206        super().__init__()
207        self.layer1 = BasicModule()
208        self.layer2 = BasicModule()
209        self.scale = torch.randn(1, 10)
210
211    @classmethod
212    def call_and_scale(cls, scale, mod, x):
213        x = mod(x)
214        return x * scale
215
216    def forward(self, x):
217        x1 = self.call_and_scale(self.scale, self.layer1, x)
218        x2 = self.call_and_scale(self.scale, self.layer2, x)
219        return x1 + x2
220
221
222class ModuleProperty(torch.nn.Module):
223    def __init__(self) -> None:
224        super().__init__()
225        self.scale = torch.randn(1, 10)
226
227    @property
228    def scale_alias(self):
229        return self.scale
230
231    def forward(self, x):
232        return x * self.scale_alias
233
234
235class NestedModuleList(torch.nn.Module):
236    def __init__(self) -> None:
237        super().__init__()
238        self.layers = torch.nn.ModuleList([])
239        for _ in range(3):
240            self.layers.append(
241                torch.nn.ModuleList(
242                    [
243                        torch.nn.Linear(10, 10),
244                        torch.nn.ReLU(),
245                    ]
246                )
247            )
248
249    def forward(self, x):
250        for layer, act in self.layers:
251            x = act(layer(x))
252        return x
253
254
255class ConstLoop(torch.nn.Module):
256    def __init__(self) -> None:
257        super().__init__()
258        self.linear1 = torch.nn.Linear(10, 10)
259        self.count = 3
260
261    def forward(self, x):
262        for i in range(self.count):
263            x = torch.sigmoid(self.linear1(x))
264        return x
265
266
267class ViaModuleCall(torch.nn.Module):
268    def __init__(self) -> None:
269        super().__init__()
270        self.linear1 = torch.nn.Linear(10, 10)
271
272    def forward(self, x):
273        return test_functions.constant3(torch.sigmoid(self.linear1(x)), x)
274
275
276class IsNoneLayer(torch.nn.Module):
277    def __init__(self) -> None:
278        super().__init__()
279        self.layer1 = torch.nn.Linear(10, 10)
280        self.layer2 = None
281        self.train(True)
282
283    def forward(self, x):
284        if self.layer1 is not None:
285            x = self.layer1(x)
286        if self.layer2 is not None:
287            x = self.layer2(x)
288        return x
289
290
291class LayerList(torch.nn.Module):
292    def __init__(self) -> None:
293        super().__init__()
294        self.layers = [
295            torch.nn.Linear(10, 10),
296            torch.nn.ReLU(),
297            torch.nn.Linear(10, 10),
298        ]
299
300    def forward(self, x):
301        for layer in self.layers:
302            x = layer(x)
303        return x
304
305
306class ModuleList(torch.nn.Module):
307    def __init__(self) -> None:
308        super().__init__()
309        self.layers = torch.nn.ModuleList(
310            [
311                torch.nn.Linear(10, 10),
312                torch.nn.ReLU(),
313                torch.nn.Linear(10, 10),
314                torch.nn.ReLU(),
315            ]
316        )
317
318    def forward(self, x):
319        for i in range(len(self.layers)):
320            x = self.layers[i](x)
321
322        for layer in self.layers:
323            x = layer(x)
324
325        for layer, val in zip(self.layers, (x, x, x, x)):
326            x = layer(x) + val
327
328        for layer, val in zip(self.layers, (1, 2, 3, 4)):
329            x = layer(x) + val
330
331        for idx, layer in enumerate(self.layers):
332            x = layer(x) * idx
333
334        for idx, layer in enumerate(self.layers[::-1]):
335            x = layer(x) * idx
336
337        return x
338
339
340class CustomGetItemModuleList(torch.nn.Module):
341    def __init__(self) -> None:
342        super().__init__()
343        self.layers = torch.nn.ModuleList(
344            [
345                torch.nn.Linear(10, 10),
346                torch.nn.ReLU(),
347                torch.nn.Linear(10, 10),
348                torch.nn.ReLU(),
349            ]
350        )
351
352    def __getitem__(self, idx: int):
353        return self.layers[idx]
354
355    def __len__(self) -> int:
356        return len(self.layers)
357
358    def forward(self, x):
359        for i in range(len(self)):
360            x = self[i](x)
361
362        return x
363
364
365class ModuleDict(torch.nn.Module):
366    def __init__(self) -> None:
367        super().__init__()
368        self.layers = torch.nn.ModuleDict(
369            {
370                "0": torch.nn.Linear(10, 10),
371            }
372        )
373
374    def forward(self, x):
375        # TODO(future PR): handle more logic
376        x = self.layers["0"](x)
377        return x
378
379
380class ParameterDict(torch.nn.Module):
381    def __init__(self) -> None:
382        super().__init__()
383        self.layers = torch.nn.ParameterDict(
384            {
385                "0": torch.nn.Parameter(torch.randn(10, 10)),
386            }
387        )
388
389    def forward(self, x):
390        x = self.layers["0"].mm(x)
391        return x
392
393
394class CustomGetItemParameterDict(torch.nn.Module):
395    def __init__(self) -> None:
396        super().__init__()
397        self.layers = torch.nn.ParameterDict(
398            {
399                "0": torch.nn.Parameter(torch.randn(10, 10)),
400            }
401        )
402
403    def __getitem__(self, key: str) -> torch.nn.Module:
404        return self.layers[key]
405
406    def forward(self, x):
407        x = self["0"].mm(x)
408        return x
409
410
411class CustomGetItemModuleDict(torch.nn.Module):
412    def __init__(self) -> None:
413        super().__init__()
414        self.layers = torch.nn.ModuleDict(
415            {
416                "0": torch.nn.Linear(10, 10),
417            }
418        )
419
420    def __getitem__(self, key: str) -> torch.nn.Module:
421        return self.layers[key]
422
423    def forward(self, x):
424        x = self["0"](x)
425        return x
426
427
428class TensorList(torch.nn.Module):
429    def __init__(self) -> None:
430        super().__init__()
431        self.layers = (
432            torch.randn((1, 10)),
433            torch.randn((10, 1)),
434            torch.randn((1, 10)),
435            torch.randn((10, 1)),
436        )
437
438    def forward(self, x):
439        for layer in self.layers:
440            x = x * layer
441        return x
442
443
444class Children(torch.nn.Module):
445    def __init__(self) -> None:
446        super().__init__()
447        self.l1 = torch.nn.Linear(10, 10)
448        self.l2 = torch.nn.ReLU()
449        self.l3 = torch.nn.Linear(10, 10)
450        self.l4 = torch.nn.ReLU()
451
452    def forward(self, x):
453        for block in self.children():
454            x = block(x)
455        return x
456
457
458class NamedChildren(torch.nn.Module):
459    def __init__(self) -> None:
460        super().__init__()
461        self.l1 = torch.nn.Linear(10, 10)
462        self.l2 = torch.nn.ReLU()
463        self.l3 = torch.nn.Linear(10, 10)
464        self.l4 = torch.nn.ReLU()
465
466    def forward(self, x):
467        for _, block in self.named_children():
468            x = block(x)
469        return x
470
471
472class IntArg(torch.nn.Module):
473    def __init__(self) -> None:
474        super().__init__()
475        self.layer1 = torch.nn.Linear(10, 10)
476
477    def forward(self, x, offset=1):
478        x = F.relu(self.layer1(x)) + offset
479        return x
480
481
482class Seq(torch.nn.Module):
483    def __init__(self) -> None:
484        super().__init__()
485        self.layers = torch.nn.Sequential(
486            torch.nn.Linear(10, 10),
487            torch.nn.ReLU(),
488            torch.nn.Linear(10, 10),
489            torch.nn.ReLU(),
490        )
491
492    def forward(self, x):
493        return self.layers(x)
494
495
496class Cfg:
497    def __init__(self) -> None:
498        self.val = 0.5
499        self.count = 3
500
501
502class CfgModule(torch.nn.Module):
503    def __init__(self) -> None:
504        super().__init__()
505        self.cfg = Cfg()
506        self.layer = torch.nn.Linear(10, 10)
507
508    def forward(self, x):
509        for i in range(self.cfg.count):
510            x = self.layer(x + self.cfg.val)
511        return x
512
513
514class StringMember(torch.nn.Module):
515    def __init__(self) -> None:
516        super().__init__()
517        self.linear1 = torch.nn.Linear(10, 10)
518        self.mode = "some_string"
519
520    def forward(self, x):
521        if self.mode == "some_string":
522            return F.relu(self.linear1(x))
523
524
525class _Block(torch.nn.Module):
526    def forward(self, x):
527        return 1.5 * torch.cat(x, 1)
528
529
530class _DenseBlock(torch.nn.ModuleDict):
531    _version = 2
532
533    def __init__(
534        self,
535        num_layers: int = 3,
536    ) -> None:
537        super().__init__()
538        for i in range(num_layers):
539            self.add_module("denselayer%d" % (i + 1), _Block())
540
541    def forward(self, init_features):
542        features = [init_features]
543        for layer in self.values():
544            new_features = layer(features)
545            features.append(new_features)
546        return torch.cat(features, 1)
547
548
549class DenseNetBlocks(torch.nn.Module):
550    def __init__(self) -> None:
551        super().__init__()
552        self.layers = _DenseBlock()
553
554    def forward(self, x):
555        return self.layers(x)
556
557
558class MaterializedModule(torch.nn.Module):
559    """Once the below lazy module is initialized with its first input,
560    it is transformed into this module."""
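    # LazyModule below sets cls_to_become = MaterializedModule, so after its
    # first forward call the lazy instance is expected to become an instance of
    # this class (see test_lazy_module1).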
561
562    param: Parameter
563
564    def __init__(self) -> None:
565        super().__init__()
566        self.register_parameter("param", None)
567
568    def forward(self, x):
569        return x
570
571
572class LazyModule(LazyModuleMixin, MaterializedModule):
573    param: UninitializedParameter
574    cls_to_become = MaterializedModule
575
576    def __init__(self) -> None:
577        super().__init__()
578        self.param = UninitializedParameter()
579
580    def initialize_parameters(self, x):
581        # force graph break to ensure this was not inlined
582        torch._dynamo.graph_break()
583        self.param.materialize(x.shape)
584
585
586class LazyMLP(torch.nn.Module):
587    def __init__(self) -> None:
588        super().__init__()
589        self.fc1 = torch.nn.LazyLinear(10)
590        self.relu1 = torch.nn.ReLU()
591        self.fc2 = torch.nn.LazyLinear(1)
592        self.relu2 = torch.nn.ReLU()
593
594    def forward(self, input):
595        x = self.relu1(self.fc1(input))
596        y = self.relu2(self.fc2(x))
597        return y
598
599
600class MyInput(NamedTuple):
601    x: Dict[str, Dict[str, torch.Tensor]]
602    y: torch.Tensor
603
604
605class LazyLayerWithNamedTupleInput(LazyModuleMixin, torch.nn.Module):
606    def __init__(self) -> None:
607        super().__init__()
608
609    def initialize_parameters(self, input):
610        with torch.no_grad():
611            self._param = torch.nn.Parameter(
612                torch.empty(input.x["a"][0].shape).fill_(0.5)
613            )
614
615    def forward(self, input):
616        input = input.x["a"]
617        x = 0
618        for i in range(len(input)):
619            x = x + input[i]
620        return x
621
622
623class LazyModuleWithNamedTupleInput(torch.nn.Module):
624    def __init__(self) -> None:
625        super().__init__()
626        self.layer = LazyLayerWithNamedTupleInput()
627
628    def forward(self, input):
629        return self.layer(input)
630
631
632class LazyLayerWithListInput(LazyModuleMixin, torch.nn.Module):
633    def __init__(self) -> None:
634        super().__init__()
635
636    def initialize_parameters(self, input):
637        with torch.no_grad():
638            self._param = torch.nn.Parameter(torch.empty(input[0].shape).fill_(0.5))
639
640    def forward(self, input):
641        x = 0
642        for i in range(len(input)):
643            x = x + input[i]
644        return x
645
646
647class LazyModuleWithListInput(torch.nn.Module):
648    def __init__(self) -> None:
649        super().__init__()
650        self.layer = LazyLayerWithListInput()
651
652    def forward(self, input):
653        return self.layer(input[:-1])
654
655
656class LazyModuleWithLazySubmodule(LazyModuleMixin, torch.nn.Module):
657    def __init__(self) -> None:
658        super().__init__()
659
660    def initialize_parameters(self, input):
661        with torch.no_grad():
662            self.layer = LazyLayerWithListInput()
663
664    def forward(self, x):
665        return self.layer(x)
666
667
668class LazyLayerWithInputs(LazyModuleMixin, torch.nn.Module):
669    def __init__(self) -> None:
670        super().__init__()
671
672    def initialize_parameters(self, x, y):
673        with torch.no_grad():
674            self._param_x = torch.nn.Parameter(torch.empty(x[0].shape).fill_(0.5))
675            self._param_y = torch.nn.Parameter(torch.empty(y[0].shape).fill_(0.5))
676
677    def forward(self, x, y):
678        res_x = 0
679        for i in range(len(x)):
680            res_x = res_x + x[i]
681        res_y = 0
682        for i in range(len(y)):
683            res_y = res_y + y[i]
684        return res_x + res_y
685
686
687class LazyModuleKwArgs(LazyModuleMixin, torch.nn.Module):
688    def __init__(self) -> None:
689        super().__init__()
690
691    def initialize_parameters(self, *args, **kwargs):
692        with torch.no_grad():
693            self.layer = LazyLayerWithInputs()
694
695    def forward(self, x, y):
696        return self.layer(x, y=y)
697
698
699class LazyParentModule(LazyModuleMixin, torch.nn.Module):
700    def __init__(self) -> None:
701        super().__init__()
702
703    def impl(self, x):
704        return x.cos() + self._val
705
706
707class LazyChildModuleNoClsToBecome(LazyParentModule):
708    def __init__(self) -> None:
709        super().__init__()
710
711    def forward(self, x):
712        return super().impl(x.sin())
713
714    def initialize_parameters(self, input):
715        self._val = torch.nn.Parameter(torch.ones(2, 2))
716
717
718def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool:
719    requires_grad = any(p.requires_grad for p in module.parameters(recurse))
720    return requires_grad
721
722
723def requires_grad2(module: torch.nn.Module, recurse: bool = False) -> bool:
724    requires_grad = any(p.requires_grad for p in module.parameters(recurse))
725    return requires_grad
726
727
728class ParametersModule1(torch.nn.Module):
729    def __init__(self) -> None:
730        super().__init__()
731        self.linear1 = torch.nn.Linear(10, 10)
732        self.scale = torch.nn.Parameter(torch.randn(1, 10))
733
734    def forward(self, x):
735        if not requires_grad1(self):
736            return F.relu(self.linear1(x)) * self.scale
737        else:
738            return x + 1
739
740
741class ParametersModule2(ParametersModule1):
742    def forward(self, x):
743        if not requires_grad2(self):
744            return F.relu(self.linear1(x)) * self.scale
745        else:
746            return x + 1
747
748
749class ParametersModule3(ParametersModule1):
750    def forward(self, x):
751        ones = torch.ones(10, dtype=next(self.parameters()).dtype)
752        return F.relu(self.linear1(x)) * self.scale + ones
753
754
755class ParametersModule4(ParametersModule1):
756    def forward(self, x):
757        ones = torch.ones(10, dtype=next(self.parameters(recurse=False)).dtype)
758        return F.relu(self.linear1(x)) * self.scale + ones
759
760
761class ParametersModule5(torch.nn.Module):
762    def __init__(self) -> None:
763        super().__init__()
764        self.linear1 = torch.nn.Linear(10, 10)
765        self.scale = torch.nn.Parameter(torch.randn(10, 10))
766        self.scale_dup = self.scale
767
768    def forward(self, x):
769        counter = 0
770        for param in self.parameters():
771            counter += 1
772
773        return x * self.scale * counter
774
775
776class SuperModule(BasicModule):
777    def forward(self, x):
778        x = super().forward(x)
779        return x + 10.0
780
781
782class SuperModule2(BasicModule):
783    def forward(self, x):
784        return BasicModule.forward(self, x)
785
786
787class ComplicatedSuperParent(torch.nn.Module):
788    @classmethod
789    def custom_add(cls, x):
790        x = x + x
791        return x
792
793
794class SuperChildCallsClassMethod(ComplicatedSuperParent):
795    @classmethod
796    def child_func(cls, x):
797        x = super().custom_add(x)
798        return x
799
800    def forward(self, x):
801        x = self.child_func(x)
802        return x
803
804
805class HasAttrModule(torch.nn.Module):
806    def __init__(self) -> None:
807        super().__init__()
808        self.scale = torch.nn.Parameter(torch.randn(1, 10))
809
810    def forward(self, x):
811        x = F.relu(x)
812        if hasattr(self, "scale"):
813            x *= self.scale
814        if hasattr(self, "scale2"):
815            x *= self.scale2
816        return x
817
818
819class EnumValues(torch.nn.ModuleDict):
820    def __init__(
821        self,
822        num_layers: int = 3,
823    ) -> None:
824        super().__init__()
825        for i in range(num_layers):
826            self.add_module("denselayer%d" % (i + 1), _Block())
827
828    def forward(self, init_features):
829        features = [init_features]
830        for idx, layer in enumerate(self.values()):
831            new_features = layer(features)
832            features.append(new_features)
833        return torch.cat(features, 1)
834
835
836class AccessByKeys(torch.nn.ModuleDict):
837    def __init__(
838        self,
839        num_layers: int = 3,
840    ) -> None:
841        super().__init__()
842        for i in range(num_layers):
843            self.add_module("denselayer%d" % (i + 1), _Block())
844
845    def forward(self, init_features):
846        features = [init_features]
847        for k in self.keys():
848            new_features = self[k](features)
849            features.append(new_features)
850        return torch.cat(features, 1)
851
852
853class CallForwardDirectly(torch.nn.Module):
854    def __init__(self) -> None:
855        super().__init__()
856        self.layer1 = BasicModule()
857        self.layer2 = torch.nn.Linear(10, 10)
858
859    def forward(self, x):
860        x = self.layer1.forward(x)
861        x = self.layer2.forward(x)
862        return x
863
864
865class ConvCallForwardDirectly(torch.nn.Module):
866    def __init__(self) -> None:
867        super().__init__()
868        self.layer = torch.nn.Conv2d(3, 64, 3, 1, 1, bias=False)
869
870    def forward(self, x):
871        return self.layer.forward(x)
872
873
874class ConvTransposeCallForwardDirectly(torch.nn.Module):
875    def __init__(self) -> None:
876        super().__init__()
877        self.layer = torch.nn.ConvTranspose2d(4, 4, 4)
878
879    def forward(self, x):
880        return self.layer.forward(x)
881
882
883class ConvCallSuperForwardDirectly(torch.nn.Conv1d):
884    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
885        super().__init__(
886            in_channels=in_channels,
887            out_channels=out_channels,
888            kernel_size=kernel_size,
889            **kwargs,
890        )
891
892    def forward(self, inputs, mask=None):
893        outputs = super().forward(inputs)
894        return outputs
895
896
897class ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d):
898    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
899        super().__init__(
900            in_channels=in_channels,
901            out_channels=out_channels,
902            kernel_size=kernel_size,
903            **kwargs,
904        )
905
906    def forward(self, x):
907        if x.numel() > 0:
908            return super().forward(x)
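        # for empty inputs, compute the transposed-conv output size per spatial
        # dim by hand: (i - 1) * stride - 2 * padding + dilation * (kernel - 1)
        # + 1 + output_padding, then build an empty tensor of that shape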
909        output_shape = [
910            ((i - 1) * d - 2 * p + (di * (k - 1) + 1) + op)
911            for i, p, di, k, d, op in zip(
912                x.shape[-2:],
913                self.padding,
914                self.dilation,
915                self.kernel_size,
916                self.stride,
917                self.output_padding,
918            )
919        ]
920        output_shape = [x.shape[0], self.bias.shape[0]] + output_shape
921        return _NewEmptyTensorOp.apply(x, output_shape)  # noqa: F821
922
923
924class ModuleNameString(torch.nn.Module):
925    def __init__(self) -> None:
926        super().__init__()
927        self.linear1 = torch.nn.Linear(10, 10)
928
929    def forward(self, x):
930        if self.__class__.__name__ == "ABC":
931            return 10
932        if self.linear1.__class__.__name__ == "Linear":
933            return F.relu(self.linear1(x) + 10)
934        return 11
935
936
937class SelfMutatingModule(torch.nn.Module):
938    def __init__(self, layer):
939        super().__init__()
940        self.layer = layer
941        self.counter = 0
942
943    def forward(self, x):
944        result = self.layer(x) + self.counter
945        self.counter += 1
946        return F.relu(result)
947
948
949class ModuleAttributePrecedenceBase(torch.nn.Module):
950    def linear(self, x, flag=None):
951        if flag:
952            return x * 2.0
953        return x * 3.0
954
955
956class ModuleAttributePrecedence(ModuleAttributePrecedenceBase):
957    def __init__(self) -> None:
958        super().__init__()
959        self.activation = torch.nn.ReLU()
960        self.linear = torch.nn.Linear(10, 10)
961        self.initializer = torch.ones([10, 10])
962        self.scale = 0.5
963
964    def activation(self, x):
965        return x * 1.2
966
967    def initializer(self):
968        return torch.zeros([10, 10])
969
970    def scale(self):
971        return 2.0
972
973    def forward(self, x):
974        # an object attribute takes precedence unless it's an nn.Module
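        # `activation` and `linear` were assigned nn.Module values, which land in
        # self._modules, so normal attribute lookup still finds the methods of the
        # same name; `initializer` and `scale` live in the instance __dict__ and
        # shadow their methods.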
975        return self.activation(self.linear(self.initializer + x)) * self.scale
976
977
978class ModuleForwardHasGraphBreak(torch.nn.Module):
979    def __init__(self) -> None:
980        super().__init__()
981        self.layer1 = BasicModule()
982        self.layer2 = BasicModule()
983        self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule())
984        self.layer4 = torch.nn.ModuleList(
985            [
986                torch.nn.Linear(10, 10),
987                torch.nn.ReLU(),
988                torch.nn.Linear(10, 10),
989                torch.nn.ReLU(),
990            ]
991        )
992        self.layer5 = torch.nn.ModuleDict(
993            {
994                "0": torch.nn.Linear(10, 10),
995            }
996        )
997        self.scale = torch.randn(1, 10)
998
999    def forward(self, x):
1000        """
1001        This is used to test if the results of functions like `named_parameters`
1002        can be reconstructed correctly after a graph break.
1003
1004        https://github.com/pytorch/torchdynamo/issues/1931
1005        """
1006        x = self.layer1(x)
1007        params1 = dict(self.named_parameters())
1008        params2 = list(self.parameters())
1009        buffers1 = dict(self.named_buffers())
1010        buffers2 = list(self.buffers())
1011        modules1 = dict(self.named_modules())
1012        modules2 = list(self.modules())
1013        torch._dynamo.graph_break()
1014        y = modules2
1015        y = modules1
1016        y = buffers2
1017        y = buffers1
1018        y = params2
1019        y = params1
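        # after the rebinding chain above, y is the named_parameters dict built
        # before the graph break, so the key lookups below exercise its
        # reconstruction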
1020        x = (
1021            self.layer2(x)
1022            + y["layer3.1.linear1.weight"]
1023            + y["layer4.2.weight"]
1024            + y["layer5.0.weight"]
1025        )
1026        return x * self.scale
1027
1028
1029class ModuleGuardNameIsValid(torch.nn.ModuleDict):
1030    # Guard names should be valid Python identifiers because we use eval() to
1031    # get the corresponding guard value. Some guard names come from the source
1032    # (module path), where special symbols are allowed but are not valid Python
1033    # identifiers, so we should detect such names and rewrite them with getattr.
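    # For example, the submodules registered below as "l@yer-1"/"l@yer-2" cannot
    # appear in a guard as `self.l@yer-1`; something like getattr(self, "l@yer-1")
    # is needed instead.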
1034    def __init__(self) -> None:
1035        super().__init__()
1036        for i in range(2):
1037            self.add_module("l@yer-%d" % (i + 1), BasicModule())
1038
1039    def forward(self, x):
1040        for layer in self.values():
1041            x = layer(x)
1042        return x
1043
1044
1045class SequentialWithDuplicatedModule(torch.nn.Module):
1046    # The Sequential module (self.layer) contains the same ReLU module three times.
1047    def __init__(self) -> None:
1048        super().__init__()
1049        self.relu = torch.nn.ReLU()
1050        self.layer = torch.nn.Sequential(
1051            torch.nn.Linear(10, 20),
1052            self.relu,
1053            torch.nn.Linear(20, 20),
1054            self.relu,
1055            torch.nn.Linear(20, 10),
1056            self.relu,
1057        )
1058
1059    def forward(self, x):
1060        return self.layer(x)
1061
1062
1063class SequentialWithDuplicatedModule2(torch.nn.Module):
1064    def __init__(self) -> None:
1065        super().__init__()
1066        self.relu = torch.nn.ReLU()
1067        self.layer = torch.nn.Sequential(
1068            collections.OrderedDict(
1069                [
1070                    ("linear1", torch.nn.Linear(10, 20)),
1071                    ("relu1", self.relu),
1072                    ("linear2", torch.nn.Linear(20, 20)),
1073                    ("relu2", self.relu),
1074                    ("linear3", torch.nn.Linear(20, 10)),
1075                    ("relu3", self.relu),
1076                ]
1077            )
1078        )
1079
1080    def forward(self, x):
1081        return self.layer(x)
1082
1083
1084class ModuleComparison(torch.nn.Module):
1085    def __init__(self) -> None:
1086        super().__init__()
1087        self.layer0 = torch.nn.Linear(10, 10)
1088        self.layer1 = torch.nn.Linear(10, 10)
1089        self.layer2 = torch.nn.Linear(10, 10)
1090
1091    @property
1092    def encoder_layers(self):
1093        return [self.layer0, self.layer1, self.layer2]
1094
1095    def forward(self, x):
1096        for layer in self.encoder_layers:
1097            output = layer(x)
1098            if layer is None or layer == self.layer0:
1099                output = F.relu6(output)
1100            else:
1101                output = F.relu(output)
1102        return output
1103
1104
1105class ModulePatch1(torch.nn.Module):
1106    pass
1107
1108
1109class ModulePatch2(torch.nn.Module):
1110    def forward(self, x):
1111        return x - 1
1112
1113
1114class UnspecNonInlinableModule(torch.nn.Module):
1115    torchdynamo_force_dynamic = True  # forced to be an UnspecializedNNModule
1116
1117    def forward(self, x):
1118        if x.sum() > 0:
1119            return x + 1
1120        else:
1121            return x - 1
1122
1123
1124class UnspecNonInlinableToplevelModule(torch.nn.Module):
1125    def __init__(self) -> None:
1126        super().__init__()
1127        self.m = UnspecNonInlinableModule()
1128
1129    def forward(self, x):
1130        return self.m(x)
1131
1132
1133def make_test(fn, expected_ops=None):
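    # Turns an nn.Module instance into a TestCase method: standard_test compiles
    # `fn`, feeds it a single tensor argument (nargs=1), compares the result with
    # eager mode, and optionally checks the captured op count; `fn` is switched to
    # eval() mode when the test is constructed below.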
1134    def test_fn(self):
1135        return torch._dynamo.testing.standard_test(
1136            self, fn=fn, nargs=1, expected_ops=expected_ops
1137        )
1138
1139    fn.eval()
1140    return test_fn
1141
1142
1143@contextlib.contextmanager
1144def temporary_tensor_subclass(torch_function=None):
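    # Define a throwaway torch.Tensor subclass, register it as a traceable tensor
    # subclass for the duration of the with-block, and always unregister it on
    # exit; `torch_function` (if given) runs on every __torch_function__ dispatch.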
1145    class TensorProxy(torch.Tensor):
1146        @classmethod
1147        def __torch_function__(cls, func, types, args=(), kwargs=None):
1148            if torch_function is not None:
1149                torch_function()
1150            return super().__torch_function__(func, types, args, kwargs)
1151
1152    torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
1153    try:
1154        yield TensorProxy
1155    finally:
1156        torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy)
1157
1158
1159class NNModuleTests(torch._dynamo.test_case.TestCase):
1160    test_seq = make_test(Seq())
1161    test_basicmodule1 = make_test(BasicModule())
1162    test_basicmodule2 = make_test(BasicModule())
1163    test_submodules1 = make_test(SubmoduleExample())
1164    test_submodules2 = make_test(SubmoduleExample())
1165    test_modulemethod1 = make_test(ModuleMethodCall())
1166    test_modulemethod2 = make_test(ModuleMethodCall())
1167    test_module_call_module_with_static_forward = make_test(
1168        ModuleCallModuleWithStaticForward()
1169    )
1170    test_module_static_method = make_test(ModuleStaticMethodCall())
1171    test_fnmember = make_test(FnMember())
1172    test_fnmembercmp1 = make_test(FnMemberCmp(F.relu))
1173    test_fnmembercmp2 = make_test(FnMemberCmp(None))
1174    test_constloop = make_test(ConstLoop())
1175    test_istraining1 = make_test(IsTrainingCheck())
1176    test_istraining2 = make_test(IsTrainingCheck())
1177    test_iseval1 = make_test(IsEvalCheck())
1178    test_iseval2 = make_test(IsEvalCheck())
1179    test_viamodulecall = make_test(ViaModuleCall())
1180    test_isnonelayer = make_test(IsNoneLayer())
1181    test_layerlist = make_test(LayerList())
1182    test_tensorlist = make_test(TensorList())
1183    test_intarg = make_test(IntArg())
1184    test_cfgmod = make_test(CfgModule())
1185    test_stringmember = make_test(StringMember())
1186    test_modulelist = make_test(ModuleList())
1187    test_modulelist_nested = make_test(NestedModuleList())
1188    test_modulelist_custom = make_test(CustomGetItemModuleList())
1189    test_moduledict = make_test(ModuleDict())
1190    test_moduledict_custom = make_test(CustomGetItemModuleDict())
1191    test_parameterdict = make_test(ParameterDict())
1192    test_parameterdict_custom = make_test(CustomGetItemParameterDict())
1193    test_super1 = make_test(SuperModule())
1194    test_super2 = make_test(SuperModule2())
1195    test_super_class_method = make_test(SuperChildCallsClassMethod())
1196    test_children = make_test(Children())
1197    test_named_children = make_test(NamedChildren())
1198    test_densenet = make_test(DenseNetBlocks())
1199    test_parameters1 = make_test(ParametersModule1())
1200    test_parameters2 = make_test(ParametersModule2())
1201    test_parameters3 = make_test(ParametersModule3(), expected_ops=5)
1202    test_parameters4 = make_test(ParametersModule4())
1203    test_parameters5 = make_test(ParametersModule5())
1204    test_hasattr = make_test(HasAttrModule())
1205    test_enumvalues = make_test(EnumValues())
1206    test_access_by_keys = make_test(AccessByKeys())
1207    test_module_class_method = make_test(ModuleClassMethodCall())
1208    test_module_property = make_test(ModuleProperty())
1209    test_forward_directly = make_test(CallForwardDirectly())
1210    test_module_name_string = make_test(ModuleNameString())
1211    test_module_attribute_precedence = make_test(ModuleAttributePrecedence())
1212    test_module_guard_name_is_valid = make_test(ModuleGuardNameIsValid())
1213    test_sequential_with_duplicated_module = make_test(SequentialWithDuplicatedModule())
1214    test_sequential_with_duplicated_module2 = make_test(
1215        SequentialWithDuplicatedModule2()
1216    )
1217    test_module_comparison = make_test(ModuleComparison())
1218
1219    def test_module_forward_has_graph_break(self):
1220        m = ModuleForwardHasGraphBreak()
1221        x = torch.rand([10, 10])
1222        ref = m(x)
1223        opt_m = torch._dynamo.optimize("eager")(m)
1224        res = opt_m(x)
1225        self.assertTrue(torch.allclose(ref, res))
1226
1227    def test_unsupportedmethod(self):
1228        m = UnsupportedMethodCall()
1229        i = torch.randn(10)
1230        cnt = torch._dynamo.testing.CompileCounter()
1231        opt_m = torch._dynamo.optimize(cnt)(m)
1232        r = opt_m(i)
1233        self.assertTrue(torch._dynamo.testing.same(r, m(i)))
1234        self.assertEqual(cnt.op_count, 5)
1235
1236    def test_unsupportedmodule(self):
1237        m = UnsupportedModuleCall()
1238        i = torch.randn(10)
1239        cnt = torch._dynamo.testing.CompileCounter()
1240        opt_m = torch._dynamo.optimize(cnt)(m)
1241        r = opt_m(i)
1242        self.assertTrue(torch._dynamo.testing.same(r, m(i)))
1243        self.assertEqual(cnt.op_count, 6)
1244
1245    def test_self_mutating1(self):
1246        m1 = torch.nn.Linear(10, 10)
1247        m2 = SelfMutatingModule(m1)
1248        m3 = SelfMutatingModule(m1)
1249        m4 = SelfMutatingModule(m1)
1250        i = torch.randn(10)
1251        out2 = [m2(i), m2(i), m2(i)]
1252        cnt = torch._dynamo.testing.CompileCounter()
1253        opt_m3 = torch._dynamo.optimize_assert(cnt)(m3)
1254        opt_m4 = torch._dynamo.optimize_assert(cnt)(m4)
1255        out3 = [opt_m3(i), opt_m3(i), opt_m3(i)]
1256        out4 = [opt_m4(i), opt_m4(i), opt_m4(i)]
1257        self.assertTrue(torch._dynamo.testing.same(out2, out3))
1258        self.assertTrue(torch._dynamo.testing.same(out2, out4))
1259        if torch._dynamo.config.assume_static_by_default:
1260            self.assertExpectedInline(cnt.frame_count, """2""")
1261        else:
1262            self.assertExpectedInline(cnt.frame_count, """1""")
1263
1264    @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
1265    def test_generation_tag(self):
1266        cnt = torch._dynamo.testing.CompileCounter()
1267
1268        # guarantee that we have installed
1269        # the generation tagging function
1270        with torch._dynamo.optimize_assert(cnt):
1271            pass
1272
1273        m1 = torch.nn.Linear(10, 10)
1274        prev_generation = GenerationTracker.get_generation_value(m1)
1275        cur_generation = prev_generation + 1
1276
1277        with torch._dynamo.optimize_assert(cnt):
1278            m2 = torch.nn.Linear(10, 10)
1279
1280        self.assertEqual(GenerationTracker.get_generation_value(m1), prev_generation)
1281        self.assertEqual(GenerationTracker.get_generation_value(m2), cur_generation)
1282        # check that newly constructed instances
1283        # also have the same generation (even if copied from an old instance)
1284        m3 = deepcopy(m1)
1285        self.assertEqual(GenerationTracker.get_generation_value(m3), cur_generation)
1286
1287    def test_simple_torch_function(self):
1288        def foo(x):
1289            # function call, twice to test wrapping
1290            x = F.sigmoid(x)
1291            x = F.sigmoid(x)
1292            # method call, twice to test wrapping
1293            x = x.sigmoid()
1294            x = x.sigmoid()
1295            return x
1296
1297        with temporary_tensor_subclass() as TensorProxy:
1298            x = torch.randn(1).as_subclass(TensorProxy)
1299            cnt = torch._dynamo.testing.CompileCounter()
1300            out1 = foo(x)
1301            opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
1302            out2 = opt_foo(x)
1303
1304            self.assertEqual(cnt.op_count, 4)
1305            self.assertTrue(torch._dynamo.testing.same(out1, out2))
1306
1307    def test_torch_function_with_closure(self):
1308        def run():
1309            def foo(x):
1310                # function call, twice to test wrapping
1311                x = F.sigmoid(x)
1312                x = F.sigmoid(x)
1313                # method call, twice to test wrapping
1314                x = x.sigmoid()
1315                x = x.sigmoid()
1316                return x
1317
1318            counter = 0
1319
1320            def function():
1321                nonlocal counter
1322                # for now, only support reads from closure cells
1323                # TODO(future PR): support writes as well
1324                counter + 1
1325
1326            with temporary_tensor_subclass(function) as TensorProxy:
1327                x = torch.randn(1).as_subclass(TensorProxy)
1328                x = torch.randn(1)
1329                cnt = torch._dynamo.testing.CompileCounter()
1330                out1 = foo(x)
1331                opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
1332                out2 = opt_foo(x)
1333
1334                self.assertEqual(cnt.op_count, 4)
1335                self.assertTrue(torch._dynamo.testing.same(out1, out2))
1336
1337        run()
1338
1339    def test_torch_mangled_class_name(self):
1340        original = TensorWithTFOverrideVariable.global_mangled_class_name
1341        results = []
1342
1343        def instrumented(self, tx):
1344            result = original(self, tx)
1345            results.append(result)
1346            return result
1347
1348        TensorWithTFOverrideVariable.global_mangled_class_name = instrumented
1349
1350        def one_break(x):
1351            x = F.sigmoid(x)
1352            print()  # force break
1353            x = x.sigmoid()
1354            return x
1355
1356        try:
1357            with temporary_tensor_subclass() as TensorProxy:
1358                x = torch.randn(1).as_subclass(TensorProxy)
1359                x1 = one_break(x)
1360
1361                cnt = torch._dynamo.testing.CompileCounter()
1362                opt_one_break = torch._dynamo.optimize(cnt)(one_break)
1363                x2 = opt_one_break(x)
1364
1365                self.assertTrue(torch._dynamo.testing.same(x1, x2))
1366                self.assertEqual(cnt.frame_count, 2)
1367                self.assertEqual(cnt.op_count, 2)
1368
1369                compile_ids = set()
1370                for r in results:
1371                    # A mangled classname looks like __subclass_TensorProxy_94524181138240_c0
1372                    # where the last segment contains the compile_id.
1373                    prefix = "__subclass_TensorProxy_"
1374                    before, sep, after = r.partition(prefix)
1375                    self.assertEqual(before, "")
1376                    self.assertEqual(sep, prefix)
1377
1378                    class_type_id, compile_id = after.split("_")
1379                    self.assertTrue(class_type_id.isnumeric())
1380                    self.assertTrue(compile_id.startswith("c"))
1381
1382                    cid = compile_id[1:]
1383                    self.assertTrue(cid.isnumeric())
1384                    compile_ids.add(cid)
1385
1386                self.assertEqual(len(compile_ids), 3)
1387
1388        finally:
1389            TensorWithTFOverrideVariable.global_mangled_class_name = original
1390
1391    @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
1392    def test_nn_moduledict_contains(self):
1393        class M(torch.nn.Module):
1394            def __init__(self, module_dict):
1395                super().__init__()
1396                self.module_dict = module_dict
1397
1398            def forward(self, x):
1399                if "foo" in self.module_dict:
1400                    x = torch.mul(x, 1.0)
1401                x = torch.add(x, 1.0)
1402                return x
1403
1404        module_dict = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)})
1405        m = M(module_dict)
1406        data = torch.randn(1)
1407        out1 = m(data)
1408        cnt = torch._dynamo.testing.CompileCounter()
1409        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
1410        out2 = opt_m(data)
1411        self.assertEqual(cnt.op_count, 2)
1412        self.assertTrue(torch._dynamo.testing.same(out1, out2))
1413
1414        module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)})
1415        m = M(module_dict)
1416        data = torch.randn(1)
1417        out1 = m(data)
1418        cnt = torch._dynamo.testing.CompileCounter()
1419        torch._dynamo.reset()
1420        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
1421        out2 = opt_m(data)
1422
1423        self.assertEqual(cnt.op_count, 1)
1424        self.assertTrue(torch._dynamo.testing.same(out1, out2))
1425
1426        module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
1427        pre = m(data)
1428        cnt.clear()
1429
1430        with torch._dynamo.optimize(cnt, nopython=False):
1431            opt_pre = m(data)
1432            m = M(module_dict)
1433            data = torch.randn(1)
1434            out1 = m(data)
1435
1436        out_post = m(data)
1437        self.assertEqual(cnt.frame_count, 1)
1438        self.assertEqual(cnt.op_count, 1)
1439        self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
1440        self.assertTrue(torch._dynamo.testing.same(out1, out_post))
1441
1442    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1443    @expectedFailureDynamic
1444    def test_lazy_module1(self):
1445        input_shape = (16, 3, 6, 7, 8)
1446
1447        cnt = torch._dynamo.testing.CompileCounter()
1448        module = LazyModule()
1449
1450        def test_static_module():
1451            input = torch.ones(*input_shape)
1452            module(input)
1453
1454        # test no graph break
1455        opt_test_static_module = torch._dynamo.optimize(cnt, nopython=True)(
1456            test_static_module
1457        )
1458        opt_test_static_module()
1459
1460        self.assertTrue(
1461            isinstance(module, MaterializedModule),
1462            "Module should be transformed to an instance of MaterializedModule.",
1463        )
1464        self.assertEqual(module.param.shape, input_shape)
1465
1466        # test when mapped to UnspecializedNNModule
1467        module = LazyModule()
1468
1469        def test_unspecialized():
1470            nonlocal module
1471            module = LazyModule()
1472            input = torch.ones(*input_shape)
1473            module(input)
1474
1475        opt_test_unspecialized = torch._dynamo.optimize(cnt)(test_unspecialized)
1476        opt_test_unspecialized()
1477
1478        self.assertTrue(
1479            isinstance(module, MaterializedModule),
1480            "Module should be transformed to an instance of MaterializedModule.",
1481        )
1482        self.assertEqual(module.param.shape, input_shape)
1483
1484        # test with a static module in torch.*
1485        module = torch.nn.modules.LazyBatchNorm3d(
1486            affine=False, track_running_stats=False
1487        )
1488
1489        cnt = torch._dynamo.testing.CompileCounter()
1490
1491        torch._dynamo.reset()
1492
1493        def test_torch_static():
1494            input = torch.ones(*input_shape)
1495            return module(input)  # fully materialized
1496
1497        # test no graph break
1498        opt_test_torch_static = torch._dynamo.optimize(cnt, nopython=True)(
1499            test_torch_static
1500        )
1501        opt_test_torch_static()
1502        out = opt_test_torch_static()
1503
1504        self.assertTrue(same(out, module(torch.ones(*input_shape))))
1505
1506        self.assertTrue(
1507            isinstance(module, torch.nn.modules.batchnorm.BatchNorm3d),
1508            "Module should be transformed to an instance of BatchNorm3d.",
1509        )
1510        self.assertEqual(cnt.frame_count, 1, "No guards should have triggered.")
1511
1512    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1513    @expectedFailureDynamic
1514    def test_lazy_module2(self):
1515        # Test that FX graph 'call_module' works when the argument is a lazy module
1516        m = LazyMLP()
1517        x = torch.rand([10, 10])
1518        opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
1519        # Run compiled mode first; otherwise the module would already be
1520        # initialized by the eager run.
1521        res = opt_m(x)
1522        ref = m(x)
1523        self.assertTrue(torch.allclose(ref, res))
1524
1525    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1526    @expectedFailureDynamic
1527    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
1528    def test_lazy_module3(self):
1529        m = LazyMLP()
1530        x = torch.rand([10, 10])
1531        cnt = torch._dynamo.testing.CompileCounter()
1532        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
1533        # first iteration
1534        res = opt_m(x)
1535        ref = m(x)
1536        self.assertTrue(torch.allclose(ref, res))
1537        # move to cuda and second iteration
1538        m = m.to("cuda")
1539        x = x.to("cuda")
1540        res = opt_m(x)
1541        ref = m(x)
1542        self.assertTrue(torch.allclose(ref, res))
1543        self.assertEqual(cnt.frame_count, 2)
1544
1545    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1546    @expectedFailureDynamic
1547    def test_lazy_module4(self):
1548        m = LazyMLP()
1549        x = torch.rand([10, 10])
1550        cnt = torch._dynamo.testing.CompileCounter()
1551        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
1552        # first iteration
1553        res = opt_m(x)
1554        ref = m(x)
1555        self.assertTrue(torch.allclose(ref, res))
1556        # input shape changed; second iteration
1557        x = torch.rand([20, 20])
1558        try:
1559            opt_m(x)
1560        except RuntimeError:
1561            self.assertIn("must have same reduction dim", traceback.format_exc())
1562
1563    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1564    @expectedFailureDynamic
1565    def test_lazy_module5(self):
1566        # Test that a lazy module works with list/tuple input
1567        m = LazyModuleWithListInput()
1568        x = [torch.rand([5, 5])] * 3 + [None]
1569        opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
1570        res = opt_m(x)
1571        ref = m(x)
1572        self.assertTrue(torch.allclose(ref, res))
1573
1574    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1575    @expectedFailureDynamic
1576    def test_lazy_module6(self):
1577        # Test creating a new lazy submodule inside a lazy module's initialize_parameters
1578        m = LazyModuleWithLazySubmodule()
1579        x = [torch.rand([5, 5])] * 3
1580        opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
1581        res = opt_m(x)
1582        ref = m(x)
1583        self.assertTrue(torch.allclose(ref, res))
1584
1585    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1586    @expectedFailureDynamic
1587    def test_lazy_module7(self):
1588        # Test that a lazy module works with namedtuple/dict input
1589        m = LazyModuleWithNamedTupleInput()
1590        x = MyInput(
1591            x={"a": [torch.rand([5, 5])] * 3, "b": torch.rand([5, 5])},
1592            y=torch.rand([5, 5]),
1593        )
1594        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1595        res = opt_m(x)
1596        ref = m(x)
1597        self.assertTrue(torch.allclose(ref, res))
1598
1599    def test_lazy_module_no_cls_to_become(self):
1600        # make sure super() works in the case where cls_to_become is None
1601        m = LazyChildModuleNoClsToBecome()
1602        x = torch.rand(2, 2)
1603        opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
1604        res = opt_m(x)
1605        ref = m(x)
1606        self.assertTrue(torch.allclose(ref, res))
1607
1608    def test_lazy_module_kwargs(self):
1609        m = LazyModuleKwArgs()
1610        x = [torch.rand([5, 5])] * 3
1611        y = [torch.rand([5, 5])] * 2
1612        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1613        exp_res = m(x, y)
1614        self.assertTrue(torch.allclose(exp_res, opt_m(x, y)))
1615
1616    def test_call_fn_with_non_const_inputs_safe(self):
1617        class ModuleSpecialFwd(torch.nn.Module):
1618            def __init__(self) -> None:
1619                super().__init__()
1620                self.conv = torch.nn.Conv2d(
1621                    in_channels=3, out_channels=20, kernel_size=(5, 5)
1622                )
1623
1624            def _conv_forward(self, x):
1625                return self.conv._conv_forward(x, self.conv.weight, self.conv.bias)
1626
1627            def forward(self, x):
1628                return self._conv_forward(x)
1629
1630        mod = ModuleSpecialFwd()
1631        rx = torch.randn([3, 10, 10])
1632        real = mod(rx)
1633        graph, _ = torch._dynamo.export(mod)(rx)
1634        self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
1635
1636    def test_conv_call_forward_directly(self):
1637        m = ConvCallForwardDirectly()
1638        x = torch.rand([4, 3, 9, 9])
1639        ref = m(x)
1640        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1641        res = opt_m(x)
1642        self.assertTrue(torch.allclose(ref, res))
1643
1644    def test_conv_transpose_call_forward_directly(self):
1645        m = ConvTransposeCallForwardDirectly()
1646        x = torch.rand([4, 4, 4, 4])
1647        ref = m(x)
1648        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1649        res = opt_m(x)
1650        self.assertTrue(torch.allclose(ref, res))
1651
1652    def test_conv_call_super_forward_directly(self):
1653        x = torch.randn(4, 4)
1654        m = ConvCallSuperForwardDirectly(4, 4, 4)
1655        ref = m(x)
1656        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1657        res = opt_m(x)
1658        self.assertTrue(torch.allclose(ref, res))
1659
1660    def test_conv_transpose_call_super_forward_directly(self):
1661        x = torch.randn(4, 4, 4)
1662        m = ConvTransposeCallSuperForwardDirectly(4, 4, 4)
1663        ref = m(x)
1664        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1665        res = opt_m(x)
1666        self.assertTrue(torch.allclose(ref, res))
1667
1668
1669class MockModule(torch.nn.Module):
1670    def __init__(self) -> None:
1671        super().__init__()
1672        self.relu = torch.nn.ReLU()
1673        self.linear = torch.nn.Linear(10, 10)
1674        self.buf0 = torch.nn.Buffer(torch.randn(10, 10))
1675
1676    def forward(self, x):
1677        return self.relu(self.linear(x) + self.buf0)
1678
1679
1680class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
1681    def test_nn_module(self):
1682        mod = MockModule()
1683        cnt = torch._dynamo.testing.CompileCounter()
1684        opt_mod = torch._dynamo.optimize(cnt)(mod)
1685        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
1686
1687        x = torch.randn(10, 10)
1688        self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
1689        self.assertEqual(cnt.frame_count, 1)
1690
1691    @torch._dynamo.config.patch(guard_nn_modules=True)
1692    def test_attr_precedence(self):
1693        class Mod(torch.nn.Module):
1694            def __init__(self) -> None:
1695                super().__init__()
1696                self.a = 3
1697
1698            def forward(self, x, c=4):
1699                return x * c
1700
1701            def linear(self, x):
1702                return x
1703
1704            def b(self, x):
1705                raise RuntimeError("Should not be called")
1706
1707        class MyMod(Mod):
1708            def __init__(self) -> None:
1709                super().__init__()
1710                self.linear = torch.nn.Linear(11, 11)
1711                self.a = 2
1712                self.b = 2
1713                self.scale = 1
1714
1715            def scale(self, x):
1716                # Should not be called because it is shadowed by the instance
1717                # attribute
1718                raise RuntimeError("Should not be called")
1719
1720            def forward(self, x, c=None):
1721                return self.linear(x) * self.a * self.b * self.scale
1722
1723        mod = MyMod()
1724        x = torch.ones(3, 3)
1725        ref = mod(x)
1726
1727        cnts = torch._dynamo.testing.CompileCounter()
1728        opt_mod = torch.compile(mod, backend=cnts)
1729        opt_mod(torch.ones(3, 3))
1730        res = opt_mod(torch.ones(3, 3))
1731
1732        self.assertEqual(cnts.frame_count, 1)
1733        self.assertEqual(ref, res)
1734
1735    def test_to(self):
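            # Calling .to() on a compiled module keeps it an OptimizedModule; the
            # dtype change should trigger exactly one recompilation, and
            # torch._dynamo.reset() should force another compile.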
1736        mod = MockModule()
1737        cnt = torch._dynamo.testing.CompileCounter()
1738        opt_mod = torch._dynamo.optimize(cnt)(mod)
1739        x = torch.randn(10, 10)
1740        self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
1741        self.assertEqual(cnt.frame_count, 1)
1742
1743        # Ensure that there is no recompilation
1744        opt_mod(x)
1745        self.assertEqual(cnt.frame_count, 1)
1746
1747        opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64)
1748        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
1749        x = torch.randn(10, 10).to(dtype=torch.float64)
1750        opt_mod(x)
1751        # Ensure that there is a recompilation
1752        self.assertEqual(cnt.frame_count, 2)
1753
1754        # Ensure that there is no recompilation
1755        opt_mod(x)
1756        self.assertEqual(cnt.frame_count, 2)
1757
1758        torch._dynamo.reset()
1759        opt_mod(x)
1760        self.assertEqual(cnt.frame_count, 3)
1761
1762    @torch._dynamo.config.patch(guard_nn_modules=True)
1763    def test_param_order(self):
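            # Reordering _parameters (pop and reinsert) must be observed by
            # parameters()/named_parameters() and trigger a recompilation.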
1764        class MyModule(torch.nn.Module):
1765            def __init__(self) -> None:
1766                super().__init__()
1767                self.param1 = torch.nn.Parameter(torch.ones([1]))
1768                self.param2 = torch.nn.Parameter(torch.ones([2]))
1769
1770            def forward(self, x):
1771                return x
1772
1773        mod = MyModule()
1774        coeffs = [2, 3]
1775
1776        def fn(x):
1777            for idx, p in enumerate(mod.parameters()):
1778                x += p.sum() * coeffs[idx]
1779
1780            for idx, p in enumerate(mod.named_parameters()):
1781                x += p[1].sum() * coeffs[idx]
1782
1783            return x
1784
1785        ref = fn(torch.ones(1))
1786        cnts = torch._dynamo.testing.CompileCounter()
1787        opt_fn = torch._dynamo.optimize(cnts)(fn)
1788        res = opt_fn(torch.ones(1))
1789
1790        self.assertEqual(ref, res)
1791        self.assertEqual(cnts.frame_count, 1)
1792
1793        mod._parameters["param1"] = mod._parameters.pop("param1")
1794        ref = fn(torch.ones(1))
1795        res = opt_fn(torch.ones(1))
1796
1797        self.assertEqual(ref, res)
1798        self.assertEqual(cnts.frame_count, 2)
1799
1800    @torch._dynamo.config.patch(guard_nn_modules=True)
1801    def test_buffer_order(self):
1802        class MyModule(torch.nn.Module):
1803            def __init__(self) -> None:
1804                super().__init__()
1805                self.b1 = torch.nn.Buffer(torch.ones([1]))
1806                self.b2 = torch.nn.Buffer(torch.ones([2]))
1807
1808            def forward(self, x):
1809                return x
1810
1811        mod = MyModule()
1812        coeffs = [2, 3]
1813
1814        def fn(x):
1815            for idx, p in enumerate(mod.buffers()):
1816                x += p.sum() * coeffs[idx]
1817
1818            for idx, p in enumerate(mod.named_buffers()):
1819                x += p[1].sum() * coeffs[idx]
1820
1821            return x
1822
1823        ref = fn(torch.ones(1))
1824        cnts = torch._dynamo.testing.CompileCounter()
1825        opt_fn = torch._dynamo.optimize(cnts)(fn)
1826        res = opt_fn(torch.ones(1))
1827
1828        self.assertEqual(ref, res)
1829        self.assertEqual(cnts.frame_count, 1)
1830
1831        mod._buffers["b1"] = mod._buffers.pop("b1")
1832        ref = fn(torch.ones(1))
1833        res = opt_fn(torch.ones(1))
1834
1835        self.assertEqual(ref, res)
1836        self.assertEqual(cnts.frame_count, 2)
1837
1838    @torch._dynamo.config.patch(guard_nn_modules=True)
1839    def test_module_order(self):
1840        class MyModule(torch.nn.Module):
1841            def __init__(self) -> None:
1842                super().__init__()
1843                self.linear1 = torch.nn.Linear(3, 3)
1844                self.linear2 = torch.nn.Linear(10, 10)
1845
1846            def forward(self, x):
1847                return x
1848
1849        mod = MyModule()
1850        coeffs = [2, 3, 4]
1851
1852        coeffs_for_mod = {mod: 10, mod.linear1: 20, mod.linear2: 30}
1853
1854        # Check order of _modules
1855        def fn(x):
1856            for idx, p in enumerate(mod.modules()):
1857                # Something silly to force dependency on the order
1858                x += coeffs_for_mod[p] * coeffs[idx]
1859            for idx, p in enumerate(mod.named_modules()):
1860                x += coeffs_for_mod[p[1]] * coeffs[idx]
1861            for idx, p in enumerate(mod.children()):
1862                x += coeffs_for_mod[p] * coeffs[idx]
1863            for idx, p in enumerate(mod.named_children()):
1864                x += coeffs_for_mod[p[1]] * coeffs[idx]
1865            return x
1866
1867        ref = fn(torch.ones(1))
1868        cnts = torch._dynamo.testing.CompileCounter()
1869        opt_fn = torch._dynamo.optimize(cnts)(fn)
1870        res = opt_fn(torch.ones(1))
1871
1872        self.assertEqual(ref, res)
1873        self.assertEqual(cnts.frame_count, 1)
1874
1875        mod._modules["linear1"] = mod._modules.pop("linear1")
1876        ref = fn(torch.ones(1))
1877        res = opt_fn(torch.ones(1))
1878
1879        self.assertEqual(ref, res)
1880        self.assertEqual(cnts.frame_count, 2)
1881
1882    def test_attr(self):
1883        class MockModule(torch.nn.Module):
1884            def __init__(self) -> None:
1885                super().__init__()
1886                self.linear = torch.nn.Linear(10, 10)
1887                self.buf0 = torch.nn.Buffer(torch.randn(10, 10))
1888
1889            def forward(self, x):
1890                return F.relu(torch.sin(x)) + self.buf0
1891
1892        mod = MockModule()
1893        opt_mod = torch._dynamo.optimize("eager")(mod)
1894
1895        # Check parameters and buffers
1896        for p1, p2 in zip(mod.parameters(), opt_mod.parameters()):
1897            self.assertTrue(id(p1) == id(p2))
1898        for b1, b2 in zip(mod.buffers(), opt_mod.buffers()):
1899            self.assertTrue(id(b1) == id(b2))
1900
1901        def get_parameter_dtype(mod: torch.nn.Module):
1902            parameters_and_buffers = itertools.chain(mod.parameters(), mod.buffers())
1903            return next(parameters_and_buffers).dtype
1904
1905        opt_mod = torch._dynamo.optimize("eager")(get_parameter_dtype)
1906        out_dtype = opt_mod(mod)
1907        self.assertEqual(out_dtype, torch.float32)
1908
1909    def test_dir(self):
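            # dir() on the OptimizedModule should expose the wrapped module's
            # attributes, parameters and buffers.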
1910        class MockModule(torch.nn.Module):
1911            def __init__(self) -> None:
1912                super().__init__()
1913                self.linear = torch.nn.Linear(10, 10)
1914                self.buf0 = torch.nn.Buffer(torch.randn(10, 10))
1915                self.register_parameter(
1916                    name="param0", param=torch.nn.Parameter(torch.randn(10, 10))
1917                )
1918
1919            def forward(self, x):
1920                return F.relu(torch.sin(x)) + self.buf0
1921
1922        mod = MockModule()
1923        mod_keys = dir(mod)
1924        opt_mod = torch._dynamo.optimize("eager")(mod)
1925        opt_mod_keys = dir(opt_mod)
1926
1927        # Check user-defined attributes, parameters and buffers
1928        self.assertIn("linear", opt_mod_keys)
1929        self.assertIn("buf0", opt_mod_keys)
1930        self.assertIn("param0", opt_mod_keys)
1931
1932        # Check all attributes, parameters and buffers
1933        self.assertTrue(len(set(mod_keys).difference(opt_mod_keys)) == 0)
1934
1935    def test_no_recompile_on_nn_guarded_modules(self):
1936        size = (10, 10)
1937        cache_size_limit = 1
1938        num_submodules = 4
1939        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
1940
1941        class SubModule(torch.nn.Module):
1942            def __init__(self) -> None:
1943                super().__init__()
1944                self.linear = torch.nn.Linear(*size)
1945
1946            def forward(self, x):
1947                a = torch.sin(torch.cos(x))
1948                return self.linear(a)
1949
1950        class MockModule(torch.nn.Module):
1951            def __init__(self) -> None:
1952                super().__init__()
1953                self.mods = [SubModule() for _ in range(num_submodules)]
1954                self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods]
1955
1956            def forward(self, x):
1957                for mod in self.mods:
1958                    x = mod(x)
1959                return x
1960
1961        mod = MockModule()
1962        # Each submodule is compiled separately and has a different nn module
1963        # guard. Ensure that the recompilation logic is handled correctly.
1964        with unittest.mock.patch(
1965            "torch._dynamo.config.error_on_recompile", True
1966        ), unittest.mock.patch(
1967            "torch._dynamo.config.cache_size_limit",
1968            cache_size_limit,
1969        ):
1970            x = torch.randn(*size, requires_grad=True)
1971            mod(x)
1972            if torch._dynamo.config.inline_inbuilt_nn_modules:
1973                self.assertEqual(cnts.frame_count, 1)
1974            else:
1975                self.assertEqual(cnts.frame_count, num_submodules)
1976
1977    @patch.object(torch._dynamo.config, "accumulated_cache_size_limit", 2)
1978    @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False)
1979    def test_recompile_limit_on_freed_module(self):
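            # With inline_inbuilt_nn_modules disabled and a small accumulated
            # cache size limit, compiling against freshly created (and then
            # freed) modules should stop recompiling once the limit is hit.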
1980        class Mod(torch.nn.Module):
1981            def __init__(self) -> None:
1982                super().__init__()
1983                self.lin = torch.nn.Linear(5, 5)
1984
1985            def forward(self, x):
1986                return self.lin(x)
1987
1988        def fn(x, mod):
1989            return mod(x)
1990
1991        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
1992        opt_mod = torch.compile(fn, backend=cnts)
1993        for i in range(8):
1994            mod = Mod()
1995            opt_mod(torch.randn(5, 5), mod)
1996
1997        # fn compiles twice
1998        self.assertEqual(cnts.frame_count, 2)
1999
2000    @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True)
2001    def test_inline_inbuilt_nn_modules(self):
2002        size = (10, 10)
2003        cache_size_limit = 1
2004        num_submodules = 4
2005        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
2006
2007        class SubModule(torch.nn.Module):
2008            def __init__(self) -> None:
2009                super().__init__()
2010                self.linear = torch.nn.Linear(*size)
2011
2012            def forward(self, x):
2013                a = torch.sin(torch.cos(x))
2014                return self.linear(a)
2015
2016        class MockModule(torch.nn.Module):
2017            def __init__(self) -> None:
2018                super().__init__()
2019                self.mods = [SubModule() for _ in range(num_submodules)]
2020                self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods]
2021
2022            def forward(self, x):
2023                for mod in self.mods:
2024                    x = mod(x)
2025                return x
2026
2027        mod = MockModule()
2028        # Each submodule is compiled separately and has a different nn module
2029        # guard. Ensure that the recompilation logic is handled correctly.
2030        with unittest.mock.patch(
2031            "torch._dynamo.config.error_on_recompile", True
2032        ), unittest.mock.patch(
2033            "torch._dynamo.config.cache_size_limit",
2034            cache_size_limit,
2035        ):
2036            x = torch.randn(*size, requires_grad=True)
2037            mod(x)
2038            self.assertEqual(cnts.frame_count, 1)
2039
2040    def test_cache_size_limit_on_guarded_nn_modules(self):
2041        cache_size_limit = 2
2042        num_submodules = 4
2043        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
2044
2045        class SubModule(torch.nn.Module):
2046            def __init__(self) -> None:
2047                super().__init__()
2048                self.relu = torch.nn.ReLU()
2049
2050            def forward(self, x):
2051                a = torch.sin(torch.cos(x))
2052                return self.relu(a)
2053
2054        class MockModule(torch.nn.Module):
2055            def __init__(self) -> None:
2056                super().__init__()
2057                self.mods = [SubModule() for _ in range(num_submodules)]
2058                self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods]
2059
2060            def forward(self, x):
2061                for mod in self.mods:
2062                    x = mod(x)
2063                return x
2064
2065        mod = MockModule()
2066        # On the third iteration we reach the cache size limit, and
2067        # therefore the total expected frame count is 2 *
2068        # num_submodules.
2069        with unittest.mock.patch(
2070            "torch._dynamo.config.cache_size_limit",
2071            cache_size_limit,
2072        ):
2073            for size in [
2074                (4,),
2075                (4, 4),
2076                (4, 4, 4),
2077            ]:
2078                x = torch.randn(size)
2079                mod(x)
2080        if torch._dynamo.config.inline_inbuilt_nn_modules:
2081            self.assertEqual(cnts.frame_count, 2)
2082        else:
2083            self.assertEqual(cnts.frame_count, 2 * num_submodules)
2084
2085    def test_recursion(self):
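            # Re-wrapping an already optimized module several times should still
            # produce a single compiled frame.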
2086        mod = MockModule()
2087        cnt = torch._dynamo.testing.CompileCounter()
2088        opt_mod = torch._dynamo.optimize(cnt)(mod)
2089
2090        for _ in range(5):
2091            opt_mod = torch._dynamo.optimize(cnt)(opt_mod)
2092        opt_mod(torch.randn(10, 10))
2093        self.assertEqual(cnt.frame_count, 1)
2094
2095    def test_composition(self):
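            # A plain inner module called from a compiled outer module should be
            # inlined, yielding a single compiled frame.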
2096        class InnerModule(torch.nn.Module):
2097            def __init__(self) -> None:
2098                super().__init__()
2099                self.relu = torch.nn.ReLU()
2100
2101            def forward(self, x):
2102                return self.relu(torch.sin(x))
2103
2104        opt_inner_mod = InnerModule()
2105
2106        class OuterModule(torch.nn.Module):
2107            def __init__(self) -> None:
2108                super().__init__()
2109                self.mod = opt_inner_mod
2110
2111            def forward(self, x):
2112                return self.mod(torch.cos(x))
2113
2114        outer_mod = OuterModule()
2115        cnt = torch._dynamo.testing.CompileCounter()
2116        opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)
2117
2118        x = torch.randn(4)
2119        self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
2120        self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
2121        self.assertEqual(cnt.frame_count, 1)
2122
2123    def test_composition_with_opt_mod(self):
2124        class InnerModule(torch.nn.Module):
2125            def __init__(self) -> None:
2126                super().__init__()
2127                self.relu = torch.nn.ReLU()
2128
2129            def forward(self, x):
2130                return self.relu(torch.sin(x))
2131
2132        inner_mod = InnerModule()
2133        cnt = torch._dynamo.testing.CompileCounter()
2134        opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod)
2135
2136        class OuterModule(torch.nn.Module):
2137            def __init__(self) -> None:
2138                super().__init__()
2139                self.mod = opt_inner_mod
2140
2141            def forward(self, x):
2142                return self.mod(torch.cos(x))
2143
2144        outer_mod = OuterModule()
2145        opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)
2146
2147        x = torch.randn(4)
2148        self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
2149        self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
2150        # There will be a graph break because the inner mod is an OptimizedModule
2151        self.assertEqual(cnt.frame_count, 2)
2152
2153    def test_module_patch(self):
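            # A forward monkeypatched onto an existing module instance via
            # types.MethodType should be traced with nopython=True.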
2154        mod = ModulePatch1()
2155        mod.forward = types.MethodType(ModulePatch2.forward, mod)
2156
2157        def fn(x):
2158            return mod(x)
2159
2160        self.assertTrue(
2161            torch.allclose(
2162                torch._dynamo.optimize("eager", nopython=True)(fn)(torch.ones(10)),
2163                torch.zeros(1),
2164            )
2165        )
2166
2167    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
2168    def test_hooks_outer(self):
2169        class TestModule(torch.nn.Module):
2170            def forward(self, x: torch.Tensor) -> torch.Tensor:
2171                return 2 * x + 1
2172
2173        m = TestModule()
2174
2175        def forward_hook(
2176            module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
2177        ) -> torch.Tensor:
2178            return 2 * output + 1
2179
2180        handle = m.register_forward_hook(forward_hook)
2181        inp = torch.tensor(1.0, requires_grad=True)
2182
2183        failure_reason = None
2184
2185        def guard_fail_fn(failure):
2186            nonlocal failure_reason
2187            failure_reason = failure[0]
2188
2189        compiled_m = torch._dynamo.optimize(
2190            guard_fail_fn=guard_fail_fn, backend="eager"
2191        )(m)
2192
2193        self.assertEqual(compiled_m(inp), m(inp))
2194        self.assertEqual(compiled_m(inp).item(), 7)
2195        self.assertTrue(failure_reason is None)
2196
2197        # What if we remove our hook? Should we recompile?
2198        handle.remove()
2199        self.assertEqual(compiled_m(inp), m(inp))
2200        self.assertEqual(compiled_m(inp).item(), 3)
2201        # self.assertTrue(failure_reason == "hook")
2202
2203        """
2204        Summary:
2205          - removing a hook doesn't fail a guard, because we weren't compiling the hook
2206            (at least into the same graph) as forward in the first place! We do correctly
2207            omit calling the removed hook, but since this hook is a post forward hook,
2208            the 'RETURN' from forward is breaking the graph.
2209
2210            Why is 'forward' the entrypoint to an InstructionTranslator, after I changed
2211            the eval_frame entrypoint to Module.__call__?
2212        """
2213
2214    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
2215    def test_hooks_inner(self):
2216        class TestModule(torch.nn.Module):
2217            def forward(self, x: torch.Tensor) -> torch.Tensor:
2218                return 2 * x + 1
2219
2220        m = TestModule()
2221
2222        def forward_hook(
2223            module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
2224        ) -> torch.Tensor:
2225            return 2 * output + 1
2226
2227        handle = m.register_forward_hook(forward_hook)
2228
2229        def outer_func(tensor):
2230            x = tensor * 2 + 1
2231            y = m(x)
2232            return y
2233
2234        inp = torch.tensor(1.0, requires_grad=True)
2235
2236        failure_reason = None
2237
2238        def guard_fail_fn(failure):
2239            nonlocal failure_reason
2240            failure_reason = failure[0]
2241
2242        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
2243        compiled_func = torch._dynamo.optimize(
2244            guard_fail_fn=guard_fail_fn,
2245            backend=cc,
2246        )(outer_func)
2247
2248        self.assertEqual(compiled_func(inp), outer_func(inp))
2249        self.assertEqual(compiled_func(inp).item(), 15)
2250
2251        # We are compiling 1 big graph for all 3 functions including the hook.
2252        self.assertEqual(cc.frame_count, 1)
2253        self.assertEqual(cc.op_count, 6)
2254
2255        # If we remove the hook, we should recompile
2256        handle.remove()
2257        self.assertEqual(compiled_func(inp), outer_func(inp))
2258        self.assertEqual(compiled_func(inp).item(), 7)
2259        self.assertTrue("forward_hooks" in failure_reason)
2260        self.assertEqual(cc.frame_count, 1 + 1)
2261        self.assertEqual(cc.op_count, 6 + 4)
2262
2263        # what if instead of removing, we alter our hook?
2264        torch._dynamo.reset()
2265        m = TestModule()
2266        handle = m.register_forward_hook(forward_hook)
2267        failure_reason = None
2268        self.assertEqual(compiled_func(inp), outer_func(inp))
2269        self.assertEqual(compiled_func(inp).item(), 15)
2270
2271        def new_forward_hook(
2272            module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
2273        ) -> torch.Tensor:
2274            return 2 * output + 2
2275
2276        m._forward_hooks[handle.id] = new_forward_hook
2277        self.assertEqual(compiled_func(inp), outer_func(inp))
2278        self.assertEqual(compiled_func(inp).item(), 16)
2279        self.assertRegex(failure_reason, r"___check_obj_id\(L\['m'\]._forward_hooks")
2280
2281    @patch.object(torch._dynamo.config, "guard_nn_modules", False)
2282    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)
2283    @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False)
2284    def test_hooks_skip_guards(self):
2285        class TestModule(torch.nn.Module):
2286            def forward(self, x: torch.Tensor) -> torch.Tensor:
2287                return 2 * x + 1
2288
2289        m = TestModule()
2290
2291        def forward_hook(
2292            module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
2293        ) -> torch.Tensor:
2294            return 2 * output + 1
2295
2296        handle = m.register_forward_hook(forward_hook)
2297
2298        def outer_func(tensor):
2299            x = tensor * 2 + 1
2300            y = m(x)
2301            return y
2302
2303        inp = torch.tensor(1.0, requires_grad=True)
2304
2305        failure_reason = None
2306
2307        def guard_fail_fn(failure):
2308            nonlocal failure_reason
2309            failure_reason = failure[0]
2310
2311        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
2312        compiled_func = torch._dynamo.optimize(
2313            guard_fail_fn=guard_fail_fn,
2314            backend=cc,
2315        )(outer_func)
2316
2317        m = TestModule()
2318        handle = m.register_forward_hook(forward_hook)
2319        failure_reason = None
2320        self.assertEqual(compiled_func(inp), outer_func(inp))
2321        self.assertEqual(compiled_func(inp).item(), 15)
2322        self.assertEqual(cc.frame_count, 1)
2323        self.assertEqual(cc.op_count, 6)
2324
2325        # if we remove the hook, dynamo shouldn't notice
2326        handle.remove()
2327        self.assertNotEqual(compiled_func(inp), outer_func(inp))
2328        self.assertEqual(compiled_func(inp).item(), 15)
2329        self.assertEqual(cc.frame_count, 1)
2330
2331    def _forward_hook_test_helper(self, model):
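            # Register a forward hook on every submodule, then check that the
            # same set of hooks fires under torch.compile (aot_eager) as in
            # eager mode.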
2332        forward_handles = {}
2333        compiled_activations = {}
2334        eager_activations = {}
2335        activations = None
2336
2337        def save_activations(name, mod, inp, out):
2338            activations[name] = inp
2339
2340        for name, module in model.named_modules():
2341            forward_handles[name] = module.register_forward_hook(
2342                partial(save_activations, name)
2343            )
2344
2345        compiled_model = torch.compile(model, backend="aot_eager")
2346
2347        activations = compiled_activations
2348        for i in range(2):
2349            # second iteration is key, hooks would have fired during aot trace
2350            # on first iter
2351            compiled_activations.clear()
2352            x = torch.randn((20, 10))
2353            pred = compiled_model(x)
2354            loss = pred.sum()
2355            loss.backward()
2356
2357        activations = eager_activations
2358        for i in range(2):
2359            # second iteration is key, hooks would have fired during aot trace
2360            # on first iter
2361            eager_activations.clear()
2362            x = torch.randn((20, 10))
2363            pred = model(x)
2364            loss = pred.sum()
2365            loss.backward()
2366
2367        print(f"Recorded Layers: {compiled_activations.keys()}\n\n")
2368        print(f"Expected Layers: {eager_activations.keys()}")
2369
2370        self.assertTrue(compiled_activations.keys() == eager_activations.keys())
2371        self.assertTrue(activations.keys() == forward_handles.keys())
2372
2373    def test_hooks_allowed_modules(self):
2374        # this test shouldn't care whether hook guards are enabled or not
2375        class ToyModel(torch.nn.Module):
2376            def __init__(self) -> None:
2377                super().__init__()
2378                self.net = torch.nn.Sequential(
2379                    *[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
2380                    + [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
2381                )
2382
2383            def forward(self, x):
2384                return self.net(x)
2385
2386        model = ToyModel()
2387        self._forward_hook_test_helper(model)
2388
2389    def test_hooks_allowed_modules_compiles(self):
2390        class ToyModel(torch.nn.Module):
2391            def __init__(self) -> None:
2392                super().__init__()
2393                self.net = torch.nn.Sequential(
2394                    *[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
2395                    + [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
2396                )
2397
2398            def forward(self, x):
2399                return self.net(x)
2400
2401        model = ToyModel()
2402        activations = []
2403
2404        def save_activations(mod, inp, out):
2405            activations.append(inp)
2406
2407        for name, module in model.named_modules():
2408            module.register_forward_hook(save_activations)
2409
2410        cnt = torch._dynamo.testing.CompileCounter()
2411        model = torch._dynamo.optimize(cnt, nopython=True)(model)
2412        for i in range(2):
2413            # second iteration is key, hooks would have fired during aot trace
2414            # on first iter
2415            activations.clear()
2416            x = torch.randn((20, 10))
2417            pred = model(x)
2418            loss = pred.sum()
2419            loss.backward()
2420        self.assertEqual(len(activations), 6)
2421        self.assertEqual(cnt.frame_count, 1)
2422
2423    def test_hooks_allowed_modules_compiles_self_contained(self):
2424        class ToyModel(torch.nn.Module):
2425            def __init__(self) -> None:
2426                super().__init__()
2427                self.net = torch.nn.Sequential(
2428                    *[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
2429                    + [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
2430                )
2431
2432            def forward(self, x):
2433                return self.net(x) * self.net(x)
2434
2435        model = ToyModel()
2436        forward_handles = {}
2437
2438        def output_modifying_hook(mod, inp, out):
2439            return 2 * out + 1
2440
2441        for name, module in model.named_modules():
2442            forward_handles[name] = module.register_forward_hook(output_modifying_hook)
2443
2444        cnt = torch._dynamo.testing.CompileCounter()
2445
2446        x = torch.randn((20, 10))
2447        pred_eager = model(x)
2448        loss_eager = pred_eager.sum()
2449        eager_loss_bwd = loss_eager.backward()
2450
2451        model = torch._dynamo.optimize(cnt, nopython=True)(model)
2452        pred = model(x)
2453
2454        loss = pred.sum()
2455        loss_bwd = loss.backward()
2456
2457        self.assertEqual(eager_loss_bwd, loss_bwd)
2458        self.assertEqual(cnt.frame_count, 2)
2459
2460        # Ndim change, recompile
2461        pred = model(torch.randn([10, 10, 10]))
2462        self.assertEqual(cnt.frame_count, 4)
2463
2464        # Stable
2465        pred = model(torch.randn([10, 10, 10]))
2466        self.assertEqual(cnt.frame_count, 4)
2467
2468    def test_dunder_call_explicitly(self):
2469        # hooks should be triggered if explicit calling `__call__`
2470        class ToyModel(torch.nn.Module):
2471            def __init__(self) -> None:
2472                super().__init__()
2473                self.linear = torch.nn.Linear(10, 10000)
2474
2475            def forward(self, x):
2476                return self.linear.__call__(x)
2477
2478        model = ToyModel()
2479        self._forward_hook_test_helper(model)
2480
2481    def test_backward_hooks(self):
2482        # this test shouldn't care whether hook guards are enabled or not
2483
2484        class CustomLinear(torch.nn.Module):
2485            # not an 'allowed module', so should not graph-break
2486            def __init__(self, a, b):
2487                super().__init__()
2488                self.weight = torch.nn.Parameter(torch.randn(a, b))
2489
2490            def forward(self, x):
2491                return torch.mm(x, self.weight)
2492
2493        class ToyModel(torch.nn.Module):
2494            def __init__(self) -> None:
2495                super().__init__()
2496                self.net = torch.nn.Sequential(
2497                    *[CustomLinear(10, 10)]
2498                    + [CustomLinear(10, 10000)]
2499                    + [CustomLinear(10000, 5)]
2500                )
2501
2502            def forward(self, x):
2503                return self.net(x)
2504
2505        model = ToyModel()
2506        backward_hook_handles = {}
2507        pre_backward_hook_handles = {}
2508
2509        grad_sizes = {}
2510
2511        def backward_hook(name, mod, grad_inp, grad_out):
2512            grad_sizes[name] = (
2513                (gi.shape for gi in grad_inp),
2514                (go.shape for go in grad_out),
2515            )
2516            return None
2517
2518        pre_grad_sizes = {}
2519
2520        def backward_pre_hook(name, mod, grad_out):
2521            pre_grad_sizes[name] = (go.shape for go in grad_out)
2522            return None
2523
2524        for name, module in model.named_modules():
2525            backward_hook_handles[name] = module.register_full_backward_hook(
2526                partial(backward_hook, name)
2527            )
2528
2529            pre_backward_hook_handles[name] = module.register_full_backward_pre_hook(
2530                partial(backward_pre_hook, name)
2531            )
2532
2533        model = torch.compile(model, backend="aot_eager")
2534
2535        for i in range(2):
2536            # second iteration is key, hooks would have fired during aot trace
2537            # on first iter
2538            x = torch.randn((20, 10))
2539            pred = model(x)
2540            loss = pred.sum()
2541            loss.backward()
2542
2543        self.assertTrue(grad_sizes.keys() == backward_hook_handles.keys())
2544        self.assertTrue(pre_grad_sizes.keys() == pre_backward_hook_handles.keys())
2545
2546    def test_udo_instance_method_as_hook(self):
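            # A bound method of a user-defined object registered as a forward
            # pre-hook should be handled under fullgraph=True.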
2547        class CustomClass:
2548            def __init__(self, module):
2549                self.module = module
2550                self.handle = self.module.register_forward_pre_hook(
2551                    self.func1, prepend=True, with_kwargs=True
2552                )
2553
2554            def func1(self, module, args, kwargs):
2555                return (args[0] + 1,), kwargs
2556
2557            def __call__(self, x):
2558                return self.module(x)
2559
2560        class ToyModel(torch.nn.Module):
2561            def __init__(self) -> None:
2562                super().__init__()
2563
2564            def forward(self, x):
2565                return x * x
2566
2567        model = ToyModel()
2568        x = torch.zeros((3, 4))
2569        obj = CustomClass(model)
2570        out = torch.compile(obj, fullgraph=True)(x)
2571        self.assertEqual(out, (x + 1) * (x + 1))
2572
2573    def test_module_dict_iter_name(self):
2574        class MyModule(torch.nn.Module):
2575            def __init__(self) -> None:
2576                super().__init__()
2577                self.activations = torch.nn.ModuleDict(
2578                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
2579                )
2580
2581            def forward(self, x):
2582                for activation_name in self.activations:
2583                    x = self.activations[activation_name](x)
2584                return x
2585
2586        cnt = torch._dynamo.testing.CompileCounter()
2587        # Eager
2588        eager_res = MyModule()(torch.ones(10, 10))
2589
2590        # Compile
2591        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
2592        self.assertEqual(eager_res, optim_res)
2593        self.assertEqual(cnt.frame_count, 1)
2594
2595    def test_module_dict_iter_keys(self):
2596        class MyModule(torch.nn.Module):
2597            def __init__(self) -> None:
2598                super().__init__()
2599                self.activations = torch.nn.ModuleDict(
2600                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
2601                )
2602
2603            def forward(self, x):
2604                for activation_name in self.activations.keys():
2605                    x = self.activations[activation_name](x)
2606                return x
2607
2608        cnt = torch._dynamo.testing.CompileCounter()
2609        # Eager
2610        eager_res = MyModule()(torch.ones(10, 10))
2611
2612        # Compile
2613        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
2614        self.assertEqual(eager_res, optim_res)
2615        self.assertEqual(cnt.frame_count, 1)
2616
2617    def test_module_setattr(self):
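            # Setting a plain attribute on a submodule inside the compiled
            # region should be applied to the original module.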
2618        models = torch.nn.Sequential(torch.nn.Linear(3, 3))
2619        models[0].abc = False
2620
2621        def run():
2622            models[0].abc = True
2623            x = torch.randn(1, 3)
2624            return models(x)
2625
2626        run = torch.compile(run, fullgraph=True)
2627        run()
2628        self.assertTrue(models[0].abc)
2629
2630    def test_assign_does_not_exist(self):
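            # Assigning a previously nonexistent attribute on self inside
            # forward should work under fullgraph=True and be visible afterwards.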
2631        class MyModule(torch.nn.Module):
2632            def forward(self, x):
2633                self.text_encoding = x + 1
2634                return self.text_encoding
2635
2636        mod = MyModule()
2637        out = torch.compile(mod, fullgraph=True)(torch.randn(10))
2638        assert mod.text_encoding is out
2639
2640    def test_module_dict_iter_values(self):
2641        class MyModule(torch.nn.Module):
2642            def __init__(self) -> None:
2643                super().__init__()
2644                self.activations = torch.nn.ModuleDict(
2645                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
2646                )
2647
2648            def forward(self, x):
2649                for activation in self.activations.values():
2650                    x = activation(x)
2651                return x
2652
2653        cnt = torch._dynamo.testing.CompileCounter()
2654        # Eager
2655        eager_res = MyModule()(torch.ones(10, 10))
2656
2657        # Compile
2658        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
2659        self.assertEqual(eager_res, optim_res)
2660        self.assertEqual(cnt.frame_count, 1)
2661
2662    def test_unspecialized_seq(self):
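            # Mutating .training on a submodule of an nn.Sequential inside the
            # traced function should still match eager results.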
2663        models = torch.nn.Sequential(torch.nn.Linear(3, 3))
2664
2665        def fn(x):
2666            models[0].training = False
2667            return models(x)
2668
2669        opt_fn = torch._dynamo.optimize("eager")(fn)
2670        x = torch.randn(1, 3)
2671        ref = fn(x)
2672        res = opt_fn(x)
2673        self.assertEqual(ref, res)
2674
2675    def test_no_op_assignment(self):
2676        class Mod(torch.nn.Module):
2677            def __init__(self) -> None:
2678                super().__init__()
2679                self.buffer = torch.rand([4])
2680
2681            def forward(self, x):
2682                # should be a no-op, but causes dynamo to lose the static input
2683                x = x + 1
2684                self.buffer = self.buffer.to(x)
2685                return self.buffer + x
2686
2687        compiles_without_buffers = 0
2688
2689        def debug_compile(gm, *args, **kwargs):
2690            nonlocal compiles_without_buffers
2691            compiles_without_buffers += len(list(gm.buffers())) == 0
2692            return gm
2693
2694        @torch.compile(backend=debug_compile)
2695        def foo(mod, x):
2696            return mod(x)
2697
2698        mod = Mod()
2699        foo(mod, torch.rand([4]))
2700        if torch._dynamo.config.inline_inbuilt_nn_modules:
2701            self.assertEqual(compiles_without_buffers, 1)
2702        else:
2703            self.assertEqual(compiles_without_buffers, 0)
2704
2705        foo(mod, torch.rand([4], dtype=torch.half))
2706        if torch._dynamo.config.inline_inbuilt_nn_modules:
2707            self.assertEqual(compiles_without_buffers, 2)
2708        else:
2709            self.assertEqual(compiles_without_buffers, 1)
2710
2711        class Mod2(Mod):
2712            def __setattr__(self, name, value):
2713                return super().__setattr__(name, value)
2714
2715        foo(Mod2(), torch.rand([4]))
2716        # causes two compilations because of the unimplemented custom setattr
2717        self.assertTrue(compiles_without_buffers >= 2)
2718
2719    def test_unspec_non_inlinable_module(self):
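            # An unspecialized, non-inlinable module should still compile and
            # match the eager output.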
2720        mod = UnspecNonInlinableModule()
2721        opt_fn = torch._dynamo.optimize("eager")(mod)
2722        x = torch.randn(100)
2723        actual = opt_fn(x)
2724        expected = mod(x)
2725        self.assertEqual(actual, expected)
2726
2727    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
2728    def test_mark_static_previously_seen_tensor(self):
2729        # This test verifies that dynamo will mark
2730        # the buffers/params of a module as static
2731        # even if this param was previously seen
2732        # (ex. as a different input)
2733        num_compiles = 0
2734
2735        def debug_compiler(gm, _):
2736            nonlocal num_compiles
2737            num_compiles += 1
2738
2739            input_nodes = [
2740                n for n in gm.graph.nodes if n.op == "placeholder" and n.name == "l_b_"
2741            ]
2742
2743            self.assertGreater(len(input_nodes), 0)
2744            for input_node in input_nodes:
2745                self.assertEqual(
2746                    input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
2747                    "unguarded",
2748                )
2749
2750            return gm
2751
2752        class TestModule(torch.nn.Module):
2753            def __init__(self, buf) -> None:
2754                super().__init__()
2755                # Changing this one to nn.Buffer fails because `nn.Buffer` does a .detach()
2756                # so the value in self.tx.output.side_effects will no longer evaluate to True
2757                self.register_buffer("buf", buf)
2758
2759            def forward(self, x):
2760                return self.buf * x
2761
2762        @torch._dynamo.optimize(backend=debug_compiler)
2763        def fn(x, b, mod):
2764            z = b + 1
2765            return z * mod(x)
2766
2767        buf = torch.ones(2, 2)
2768        inp = torch.ones(2)
2769        mod = TestModule(buf)
2770        fn(inp, buf, mod)
2771        self.assertEqual(num_compiles, 1)
2772
2773    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
2774    def test_mark_static_nn_module_tensor(self):
2775        # This test verifies that dynamo will mark
2776        # the nn module tensor attributes as static
2777        num_compiles = 0
2778
2779        def debug_compiler(gm, _):
2780            nonlocal num_compiles
2781            num_compiles += 1
2782
2783            input_nodes = [
2784                n
2785                for n in gm.graph.nodes
2786                if n.op == "placeholder" and n.name == "l_mod_buf"
2787            ]
2788
2789            self.assertGreater(len(input_nodes), 0)
2790            for input_node in input_nodes:
2791                self.assertEqual(
2792                    input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
2793                    "unguarded",
2794                )
2795
2796            return gm
2797
2798        class TestModule(torch.nn.Module):
2799            def __init__(self) -> None:
2800                super().__init__()
2801                self.buf = torch.ones(2, 2)
2802
2803            def forward(self, x):
2804                return self.buf * x
2805
2806        mod = TestModule()
2807
2808        @torch._dynamo.optimize(backend=debug_compiler)
2809        def fn(x):
2810            return x * mod(x)
2811
2812        inp = torch.ones(2)
2813        fn(inp)
2814        self.assertEqual(num_compiles, 1)
2815
2816    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
2817    @torch._inductor.config.patch("freezing", True)
2818    @torch.no_grad()
2819    def test_mark_static_with_freezing(self):
2820        # This test verifies that dynamo will
2821        # add buffers/params as attributes of the
2822        # graph w/ guards if freezing is enabled
2823        num_compiles = 0
2824
2825        def debug_compiler(gm, _):
2826            nonlocal num_compiles
2827            num_compiles += 1
2828
2829            input_nodes = [
2830                n for n in gm.graph.nodes if n.op == "placeholder" and n.name == "l_b_"
2831            ]
2832            self.assertEqual(len(input_nodes), 0)
2833            self.assertEqual(len(list(gm.buffers())), 1)
2834            return gm
2835
2836        class TestModule(torch.nn.Module):
2837            def __init__(self, buf) -> None:
2838                super().__init__()
2839                self.buf = torch.nn.Buffer(buf)
2840
2841            def forward(self, x):
2842                return self.buf * x
2843
2844        @torch._dynamo.optimize(backend=debug_compiler)
2845        def fn(x, mod):
2846            return mod(x)
2847
2848        buf = torch.ones(2, 2)
2849        inp = torch.ones(2)
2850        mod = TestModule(buf)
2851        fn(inp, mod)
2852        self.assertEqual(num_compiles, 1)
2853        mod.buf = torch.rand_like(buf)
2854        fn(inp, mod)
2855        self.assertEqual(num_compiles, 2)
2856
2857    @patch.object(torch._dynamo.config, "guard_nn_modules", True)
2858    def test_guard_on_torch_nn_modules(self):
2859        # https://github.com/pytorch/pytorch/issues/110048
2860
2861        class MockModule(torch.nn.Module):
2862            def __init__(self) -> None:
2863                super().__init__()
2864                self.linear = torch.nn.Linear(10, 10)
2865                self.multiplier = 10
2866
2867            def forward(self, x):
2868                return self.linear(x) * self.multiplier
2869
2870        mod = MockModule()
2871
2872        cnt = torch._dynamo.testing.CompileCounter()
2873
2874        @torch.compile(backend=cnt)
2875        def generate(x, c):
2876            return mod(x) + c
2877
2878        for _ in range(0, 10):
2879            generate(torch.randn(10, 10), 0)
2880            generate(torch.randn(10, 10), 1)
2881        self.assertEqual(cnt.frame_count, 2)
2882
2883        # Ensure that modification in user module causes recompile
2884        mod.multiplier = 11
2885        generate(torch.randn(10, 10), 0)
2886        self.assertEqual(cnt.frame_count, 3)
2887
2888    def test_setattr_on_compiled_module(self):
2889        # https://github.com/pytorch/pytorch/issues/114844
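            # setattr on the OptimizedModule wrapper should reach the wrapped
            # module, so mutations made in forward stay in sync with eager.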
2890
2891        class ReplayMutation(torch.nn.Module):
2892            def __init__(self, inp_size, out_size, inner_size):
2893                super().__init__()
2894                self.Linear1 = torch.nn.Linear(inp_size, inner_size)
2895                self.Linear2 = torch.nn.Linear(inner_size, out_size)
2896                self.x = None
2897
2898            def forward(self, inp):
2899                res = self.Linear1(inp)
2900                self.x = res
2901                return self.Linear2(res)
2902
2903        N, D_in, H, D_out, inner = 2, 2, 2, 2, 4
2904        model = ReplayMutation(D_in, H, inner)
2905        model2 = copy.deepcopy(model)
2906        input = torch.ones(N, D_in)
2907
2908        # Keep some intermediate value in model.x
2909        model.x = torch.tensor([[100, 100, 100, 100], [200, 200, 200, 200]])
2910        model(input)
2911
2912        compiled_model = torch.compile(model2, backend="eager")
2913        compiled_model.x = torch.tensor([[100, 100, 100, 100], [200, 200, 200, 200]])
2914        compiled_model(input)
2915
2916        self.assertEqual(model.x, compiled_model.x)
2917
2918    def test_globals_change_in_other_file(self):
2919        @torch.compile(backend="eager", fullgraph=True)
2920        def fn(x):
2921            update_global()
2922            a = test_functions.update_global(x)
2923            # Ensure that the updated global values are read
2924            return x * a * (_variable + _variable1 + test_functions._variable)
2925
2926        res = fn(torch.ones(10))
2927        self.assertEqual(_variable, 1)
2928        self.assertEqual(_variable1, 1)
2929        # Ensure that the reconstructed bytecode updates the global value in the
2930        # other file.
2931        self.assertEqual(test_functions._variable, 1)
2932        self.assertEqual(res, 3 * torch.ones(10))
2933
2934    @unittest.skipIf(
2935        "inductor" not in torch._dynamo.list_backends(),
2936        "inductor backend is not available",
2937    )
2938    def test_save_and_load_inductor(self):
2939        mod = MockModule()
2940        opt_mod = torch.compile(mod, backend="inductor")
2941        inp = torch.randn(10, 10)
2942        opt_mod(inp)
2943
2944        with tempfile.TemporaryDirectory() as tmpdirname:
2945            torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
2946            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
2947        loaded_model(inp)
2948        self.assertTrue(same_two_models(loaded_model, mod, [inp]))
2949        self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))
2950
2951        torch._dynamo.reset()  # force recompiles
2952        torch._inductor.metrics.generated_kernel_count = 0
2953        loaded_model(inp)
2954        self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0)
2955
2956    def test_save_and_load_all_backends(self):
2957        mod = MockModule()
2958        inp = torch.randn(10, 10)
2959        for backend in torch._dynamo.list_backends():
2960            try:
2961                opt_mod = torch.compile(mod, backend=backend)
2962                with tempfile.TemporaryDirectory() as tmpdirname:
2963                    torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
2964                    loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
2965                torch._dynamo.reset()  # force recompiles
2966                torch._inductor.metrics.generated_kernel_count = 0
2967                opt_mod(inp)
2968                opt_success = torch._inductor.metrics.generated_kernel_count == 0
2969                torch._dynamo.reset()  # force recompiles
2970                torch._inductor.metrics.generated_kernel_count = 0
2971                loaded_model(inp)
2972                loaded_success = torch._inductor.metrics.generated_kernel_count == 0
2973                self.assertEqual(opt_success, loaded_success)
2974            except torch._dynamo.exc.BackendCompilerFailed:
2975                pass
2976
2977    def test_monkeypatching_forward(self):
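            # Monkeypatching mod.forward after the first compile should trigger
            # a recompilation and keep matching eager, both under the default
            # config and with inline_inbuilt_nn_modules=True.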
2978        class FakeModule(torch.nn.Module):
2979            def forward(self, x):
2980                return torch.sin(x)
2981
2982        class MyModule(torch.nn.Module):
2983            def __init__(self, x):
2984                super().__init__()
2985
2986            def forward(self, x):
2987                return torch.cos(x)
2988
2989        def helper():
2990            torch._dynamo.reset()
2991            mod = MyModule(3)
2992
2993            def fn(x):
2994                return mod(x)
2995
2996            cnt = torch._dynamo.testing.CompileCounter()
2997            opt_fn = torch._dynamo.optimize(cnt)(fn)
2998            x = torch.randn(10)
2999
3000            opt_fn(x)
3001            opt_fn(x)
3002            self.assertEqual(cnt.frame_count, 1)
3003
3004            # Monkeypatch forward
3005            mod.forward = types.MethodType(FakeModule.forward, mod)
3006            ref = fn(x)
3007            res = opt_fn(x)
3008            self.assertEqual(ref, res)
3009            self.assertEqual(cnt.frame_count, 2)
3010
3011        helper()
3012        with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True):
3013            helper()
3014
3015    def test_user_defined_nn_module_dynamic(self):
3016        class Conv2d(torch.nn.Conv2d):
3017            def __init__(self, *args, **kwargs):
3018                super().__init__(*args, **kwargs)
3019
3020            def forward(self, x):
3021                x = torch.nn.functional.conv2d(
3022                    x,
3023                    self.weight,
3024                    self.bias,
3025                    self.stride,
3026                    self.padding,
3027                    self.dilation,
3028                    self.groups,
3029                )
3030                return x
3031
3032        cnts = torch._dynamo.testing.CompileCounter()
3033        mod1 = Conv2d(64, 64, kernel_size=(2, 2), stride=(1, 1))
3034        mod2 = Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2))
3035        mod3 = Conv2d(64, 64, kernel_size=(2, 2), stride=(3, 3))
3036
3037        opt_mod1 = torch.compile(mod1, backend=cnts, fullgraph=True)
3038        opt_mod2 = torch.compile(mod2, backend=cnts, fullgraph=True)
3039        opt_mod3 = torch.compile(mod3, backend=cnts, fullgraph=True)
3040
3041        x = torch.randn(1, 64, 64, 64)
3042        opt_mod1(x)
3043        opt_mod2(x)
3044        opt_mod3(x)
3045
3046        # Must be 3 compilations. If not marked static there would be 2, because strides would be converted to symints.
3047        self.assertEqual(cnts.frame_count, 3)
3048
3049
3050if __name__ == "__main__":
3051    from torch._dynamo.test_case import run_tests
3052
3053    run_tests()
3054