xref: /aosp_15_r20/external/pytorch/test/jit/test_scriptmod_ann.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5import unittest
6import warnings
7from typing import Dict, List, Optional
8
9import torch
10
11
# Make the helper files in test/ importable: take this file's resolved path,
# climb two directory levels (the file's own directory, then its parent) and
# append the resulting directory to sys.path.
pytorch_test_dir = os.path.split(os.path.dirname(os.path.realpath(__file__)))[0]
sys.path.append(pytorch_test_dir)
15from torch.testing._internal.jit_utils import JitTestCase
16
17
# Direct execution is unsupported: these tests are meant to be driven through
# test/test_jit.py, so fail loudly with instructions instead of running.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
24
25
class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
    """Tests for how TorchScript treats type annotations on module attributes.

    Every test defines a small ``nn.Module`` whose ``forward`` reassigns
    ``self.x`` and checks for one of two outcomes:

    * ``checkModule`` passes and no warnings are recorded. This is expected
      for base-type values, non-empty containers, Tensors (including an
      empty one), ``torch.jit.Attribute``, and class-level annotations.
    * ``torch.jit.script`` emits a ``UserWarning`` about instance-level
      annotations on empty non-base types, then raises a ``RuntimeError``
      that highlights the ``self.x = x`` assignment in ``forward``. This is
      expected when an empty ``List``/``Dict`` or a ``None`` ``Optional`` is
      annotated only inside ``__init__`` (inline or via ``torch.jit.annotate``).
    """

    # NB: There are no tests for `Tuple` or `NamedTuple` here. In fact,
    # reassigning a non-empty Tuple to an attribute previously typed
    # as containing an empty Tuple SHOULD fail. See note in `_check.py`

    # Regex fragment of the UserWarning that the failing cases below expect
    # `torch.jit.script` to emit.
    _EMPTY_ANNOTATION_WARNING = (
        "doesn't support "
        "instance-level annotations on "
        "empty non-base types"
    )

    def _assert_scripts_without_warnings(self, module, inputs):
        """Run ``checkModule(module, inputs)`` and assert no warnings were recorded.

        Args:
            module: the eager ``nn.Module`` instance to script and compare.
            inputs: tuple of positional arguments passed to ``forward``.
        """
        with warnings.catch_warnings(record=True) as caught:
            self.checkModule(module, inputs)
        # Use a TestCase assertion rather than a bare ``assert``, which is
        # stripped under ``python -O``. List the messages so a failure shows
        # which warnings were emitted.
        self.assertEqual(
            len(caught),
            0,
            msg=f"Unexpected warnings: {[str(w.message) for w in caught]}",
        )

    def _assert_script_warns_then_fails(self, module, error_regex):
        """Assert that scripting ``module`` warns and then fails at ``self.x = x``.

        Args:
            module: the eager ``nn.Module`` instance to pass to ``torch.jit.script``.
            error_regex: regex that the raised ``RuntimeError`` must match.

        The ``UserWarning`` matching ``_EMPTY_ANNOTATION_WARNING`` must be
        emitted inside the ``RuntimeError`` context. The error must also
        highlight the ``self.x = x`` source line.
        """
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, error_regex, "self.x = x"
        ):
            with self.assertWarnsRegex(UserWarning, self._EMPTY_ANNOTATION_WARNING):
                torch.jit.script(module)

    def test_annotated_falsy_base_type(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: int = 0

            def forward(self, x: int):
                self.x = x
                return 1

        self._assert_scripts_without_warnings(M(), (1,))

    def test_annotated_nonempty_container(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: List[int] = [1, 2, 3]

            def forward(self, x: List[int]):
                self.x = x
                return 1

        self._assert_scripts_without_warnings(M(), ([1, 2, 3],))

    def test_annotated_empty_tensor(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: torch.Tensor = torch.empty(0)

            def forward(self, x: torch.Tensor):
                self.x = x
                return self.x

        self._assert_scripts_without_warnings(M(), (torch.rand(2, 3),))

    def test_annotated_with_jit_attribute(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.jit.Attribute([], List[int])

            def forward(self, x: List[int]):
                self.x = x
                return self.x

        self._assert_scripts_without_warnings(M(), ([1, 2, 3],))

    def test_annotated_class_level_annotation_only(self):
        class M(torch.nn.Module):
            x: List[int]

            def __init__(self) -> None:
                super().__init__()
                self.x = []

            def forward(self, y: List[int]):
                self.x = y
                return self.x

        self._assert_scripts_without_warnings(M(), ([1, 2, 3],))

    def test_annotated_class_level_annotation_and_init_annotation(self):
        class M(torch.nn.Module):
            x: List[int]

            def __init__(self) -> None:
                super().__init__()
                self.x: List[int] = []

            def forward(self, y: List[int]):
                self.x = y
                return self.x

        self._assert_scripts_without_warnings(M(), ([1, 2, 3],))

    def test_annotated_class_level_jit_annotation(self):
        class M(torch.nn.Module):
            x: List[int]

            def __init__(self) -> None:
                super().__init__()
                self.x: List[int] = torch.jit.annotate(List[int], [])

            def forward(self, y: List[int]):
                self.x = y
                return self.x

        self._assert_scripts_without_warnings(M(), ([1, 2, 3],))

    def test_annotated_empty_list(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: List[int] = []

            def forward(self, x: List[int]):
                self.x = x
                return 1

        self._assert_script_warns_then_fails(M(), "Tried to set nonexistent attribute")

    @unittest.skipIf(
        sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
    )
    def test_annotated_empty_list_lowercase(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: list[int] = []

            def forward(self, x: list[int]):
                self.x = x
                return 1

        self._assert_script_warns_then_fails(M(), "Tried to set nonexistent attribute")

    def test_annotated_empty_dict(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: Dict[str, int] = {}

            def forward(self, x: Dict[str, int]):
                self.x = x
                return 1

        self._assert_script_warns_then_fails(M(), "Tried to set nonexistent attribute")

    @unittest.skipIf(
        sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
    )
    def test_annotated_empty_dict_lowercase(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: dict[str, int] = {}

            def forward(self, x: dict[str, int]):
                self.x = x
                return 1

        self._assert_script_warns_then_fails(M(), "Tried to set nonexistent attribute")

    def test_annotated_empty_optional(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: Optional[str] = None

            def forward(self, x: Optional[str]):
                self.x = x
                return 1

        self._assert_script_warns_then_fails(M(), "Wrong type for attribute assignment")

    def test_annotated_with_jit_empty_list(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.jit.annotate(List[int], [])

            def forward(self, x: List[int]):
                self.x = x
                return 1

        self._assert_script_warns_then_fails(M(), "Tried to set nonexistent attribute")

    @unittest.skipIf(
        sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
    )
    def test_annotated_with_jit_empty_list_lowercase(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.jit.annotate(list[int], [])

            def forward(self, x: list[int]):
                self.x = x
                return 1

        self._assert_script_warns_then_fails(M(), "Tried to set nonexistent attribute")

    def test_annotated_with_jit_empty_dict(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.jit.annotate(Dict[str, int], {})

            def forward(self, x: Dict[str, int]):
                self.x = x
                return 1

        self._assert_script_warns_then_fails(M(), "Tried to set nonexistent attribute")

    @unittest.skipIf(
        sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
    )
    def test_annotated_with_jit_empty_dict_lowercase(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.jit.annotate(dict[str, int], {})

            def forward(self, x: dict[str, int]):
                self.x = x
                return 1

        self._assert_script_warns_then_fails(M(), "Tried to set nonexistent attribute")

    def test_annotated_with_jit_empty_optional(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.jit.annotate(Optional[str], None)

            def forward(self, x: Optional[str]):
                self.x = x
                return 1

        self._assert_script_warns_then_fails(M(), "Wrong type for attribute assignment")

    def test_annotated_with_torch_jit_import(self):
        from torch import jit

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = jit.annotate(Optional[str], None)

            def forward(self, x: Optional[str]):
                self.x = x
                return 1

        self._assert_script_warns_then_fails(M(), "Wrong type for attribute assignment")
378                torch.jit.script(M())
379