xref: /aosp_15_r20/external/pytorch/test/test_typing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: typing"]
2# based on NumPy numpy/typing/tests/test_typing.py
3
4import itertools
5import os
6import re
7import shutil
8import unittest
9from collections import defaultdict
10from threading import Lock
11from typing import Dict, IO, List, Optional
12
13from torch.testing._internal.common_utils import (
14    instantiate_parametrized_tests,
15    parametrize,
16    run_tests,
17    TestCase,
18)
19
20
# mypy is an optional dependency: record whether it could be imported so that
# the typing tests can be skipped (see ``unittest.skipIf(NO_MYPY, ...)``) when
# it is not installed.
try:
    from mypy import api
except ImportError:
    NO_MYPY = True
else:
    NO_MYPY = False
27
28
# Absolute path of the ``typing`` directory located next to this test file.
DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "typing"))
# Sample files consumed by ``TestTyping.test_reveal``.
REVEAL_DIR = os.path.join(DATA_DIR, "reveal")
# Sample files consumed by ``TestTyping.test_success``.
PASS_DIR = os.path.join(DATA_DIR, "pass")
# Sample files consumed by ``TestTyping.test_fail``.
FAIL_DIR = os.path.join(DATA_DIR, "fail")
# mypy configuration file, two directory levels above ``DATA_DIR``.
MYPY_INI = os.path.join(DATA_DIR, os.pardir, os.pardir, "mypy.ini")
# Cache directory handed to mypy via ``--cache-dir``; wiped by ``_run_mypy``.
CACHE_DIR = os.path.join(DATA_DIR, ".mypy_cache")
35
36
37def _key_func(key: str) -> str:
38    """Split at the first occurance of the ``:`` character.
39
40    Windows drive-letters (*e.g.* ``C:``) are ignored herein.
41    """
42    drive, tail = os.path.splitdrive(key)
43    return os.path.join(drive, tail.split(":", 1)[0])
44
45
46def _strip_filename(msg: str) -> str:
47    """Strip the filename from a mypy message."""
48    _, tail = os.path.splitdrive(msg)
49    return tail.split(":", 1)[-1]
50
51
def _run_mypy() -> Dict[str, List[str]]:
    """Clear the mypy cache and run mypy on the reveal, pass and fail directories.

    Returns:
        A dictionary mapping every non-empty key produced by `_key_func`
        (for ordinary message lines this is the absolute file path) to the
        list of raw mypy output lines that share this key.

    Raises:
        AssertionError: If mypy wrote anything to stderr.
    """
    if os.path.isdir(CACHE_DIR):
        shutil.rmtree(CACHE_DIR)

    rc: Dict[str, List[str]] = {}
    for directory in (REVEAL_DIR, PASS_DIR, FAIL_DIR):
        # Run mypy
        stdout, stderr, _ = api.run(
            [
                "--show-absolute-path",
                "--config-file",
                MYPY_INI,
                "--cache-dir",
                CACHE_DIR,
                directory,
            ]
        )
        assert not stderr, stderr
        # Asterisks are removed from the output before it is parsed.
        stdout = stdout.replace("*", "")

        # Parse the output.  ``groupby`` only merges *adjacent* lines, so the
        # groups are accumulated per key instead of overwriting each other:
        # a file whose messages are not contiguous keeps all of its lines.
        grouped = itertools.groupby(stdout.split("\n"), key=_key_func)
        for key, group in grouped:
            if key:
                rc.setdefault(key, []).extend(group)
    return rc
77
78
def get_test_cases(directory):
    """Yield the full path of every enabled ``.py`` file below ``directory``.

    The tree is walked recursively; files whose name starts with
    ``disabled_`` are left out.
    """
    for dirpath, _subdirs, filenames in os.walk(directory):
        for filename in filenames:
            _stem, extension = os.path.splitext(filename)
            if extension == ".py" and not filename.startswith("disabled_"):
                yield os.path.join(dirpath, filename)
87
88
# Template raised by ``_test_fail`` when its ``expected_error`` argument is
# ``None``; formatted with (line number, error).
_FAIL_MSG1 = """Extra error at line {}
Extra error: {!r}
"""

# Template raised by ``_test_fail`` when ``error`` is not a substring of
# ``expected_error``; formatted with (line number, expected_error, error).
_FAIL_MSG2 = """Error mismatch at line {}
Expected error: {!r}
Observed error: {!r}
"""
97
98
99def _test_fail(
100    path: str, error: str, expected_error: Optional[str], lineno: int
101) -> None:
102    if expected_error is None:
103        raise AssertionError(_FAIL_MSG1.format(lineno, error))
104    elif error not in expected_error:
105        raise AssertionError(_FAIL_MSG2.format(lineno, expected_error, error))
106
107
108def _construct_format_dict():
109    dct = {
110        "ModuleList": "torch.nn.modules.container.ModuleList",
111        "AdaptiveAvgPool2d": "torch.nn.modules.pooling.AdaptiveAvgPool2d",
112        "AdaptiveMaxPool2d": "torch.nn.modules.pooling.AdaptiveMaxPool2d",
113        "Tensor": "torch._tensor.Tensor",
114        "Adagrad": "torch.optim.adagrad.Adagrad",
115        "Adam": "torch.optim.adam.Adam",
116    }
117    return dct
118
119
#: Maps every supported format key (written as ``{key}`` inside a
#: ``# E:`` comment) to the string that `_parse_reveals` substitutes for it.
FORMAT_DICT: Dict[str, str] = _construct_format_dict()
123
124
125def _parse_reveals(file: IO[str]) -> List[str]:
126    """Extract and parse all ``"  # E: "`` comments from the passed file-like object.
127
128    All format keys will be substituted for their respective value from `FORMAT_DICT`,
129    *e.g.* ``"{Tensor}"`` becomes ``"torch.tensor.Tensor"``.
130    """
131    string = file.read().replace("*", "")
132
133    # Grab all `# E:`-based comments
134    comments_array = [str.partition("  # E: ")[2] for str in string.split("\n")]
135    comments = "/n".join(comments_array)
136
137    # Only search for the `{*}` pattern within comments,
138    # otherwise there is the risk of accidently grabbing dictionaries and sets
139    key_set = set(re.findall(r"\{(.*?)\}", comments))
140    kwargs = {
141        k: FORMAT_DICT.get(k, f"<UNRECOGNIZED FORMAT KEY {k!r}>") for k in key_set
142    }
143    fmt_str = comments.format(**kwargs)
144
145    return fmt_str.split("/n")
146
147
# Template raised by ``_test_reveal`` when ``reveal`` is not a substring of
# ``expected_reveal``; formatted with (line number, expected_reveal, reveal).
_REVEAL_MSG = """Reveal mismatch at line {}

Expected reveal: {!r}
Observed reveal: {!r}
"""
153
154
155def _test_reveal(path: str, reveal: str, expected_reveal: str, lineno: int) -> None:
156    if reveal not in expected_reveal:
157        raise AssertionError(_REVEAL_MSG.format(lineno, expected_reveal, reveal))
158
159
@unittest.skipIf(NO_MYPY, reason="Mypy is not installed")
class TestTyping(TestCase):
    """Compare mypy's output with the ``pass``, ``fail`` and ``reveal`` samples.

    mypy is executed lazily, at most once per process; its parsed output is
    cached on the class and shared by all parametrized tests.
    """

    # Guards the lazy initialisation of ``_cached_output``.
    _lock = Lock()
    # Parsed mypy output keyed by file path; ``None`` until mypy has been run.
    _cached_output: Optional[Dict[str, List[str]]] = None

    @classmethod
    def get_mypy_output(cls) -> Dict[str, List[str]]:
        """Return the parsed mypy output, running mypy on first use only."""
        with cls._lock:
            if cls._cached_output is None:
                cls._cached_output = _run_mypy()
            return cls._cached_output

    @parametrize(
        "path",
        get_test_cases(PASS_DIR),
        name_fn=lambda b: os.path.relpath(b, start=PASS_DIR),
    )
    def test_success(self, path) -> None:
        # A file below ``pass`` must not appear in mypy's output at all.
        output_mypy = self.get_mypy_output()
        if path in output_mypy:
            msg = "Unexpected mypy output\n\n"
            msg += "\n".join(_strip_filename(v) for v in output_mypy[path])
            raise AssertionError(msg)

    @parametrize(
        "path",
        get_test_cases(FAIL_DIR),
        name_fn=lambda b: os.path.relpath(b, start=FAIL_DIR),
    )
    def test_fail(self, path):
        # Every mypy message for a ``fail`` file must sit on a line carrying a
        # ``# E:`` marker whose text is contained in the message, and every
        # marker must be matched by a message.
        __tracebackhide__ = True

        # Read the sample as UTF-8, the default encoding of Python sources,
        # instead of relying on the locale.
        with open(path, encoding="utf-8") as fin:
            lines = fin.readlines()

        # Line number -> concatenation of all mypy messages for that line.
        errors: Dict[int, str] = defaultdict(str)

        output_mypy = self.get_mypy_output()
        self.assertIn(path, output_mypy)
        for error_line in output_mypy[path]:
            error_line = _strip_filename(error_line)
            match = re.match(
                r"(?P<lineno>\d+):(?P<colno>\d+): (error|note): .+$",
                error_line,
            )
            if match is None:
                raise ValueError(f"Unexpected error line format: {error_line}")
            lineno = int(match.group("lineno"))
            errors[lineno] += f"{error_line}\n"

        for lineno, line in enumerate(lines, start=1):
            if line.startswith("#") or (" E:" not in line and lineno not in errors):
                continue

            # Look the messages up with ``.get`` *before* building the failure
            # text: indexing the defaultdict would insert an empty entry and
            # hide the "no mypy message for this marker" case from
            # ``_test_fail``.
            expected_error = errors.get(lineno)
            self.assertIn(
                "# E:", line, f"Unexpected mypy output\n\n{expected_error or ''}"
            )
            marker = line.split("# E:")[-1].strip()
            _test_fail(path, marker, expected_error, lineno)

    @parametrize(
        "path",
        get_test_cases(REVEAL_DIR),
        name_fn=lambda b: os.path.relpath(b, start=REVEAL_DIR),
    )
    def test_reveal(self, path):
        # Every "Revealed type is" note must contain the text of the
        # ``# E:`` comment found on the line the note refers to.
        __tracebackhide__ = True

        with open(path, encoding="utf-8") as fin:
            lines = _parse_reveals(fin)

        output_mypy = self.get_mypy_output()
        self.assertIn(path, output_mypy)
        for error_line in output_mypy[path]:
            match = re.match(
                r"^.+\.py:(?P<lineno>\d+):(?P<colno>\d+): note: .+$",
                error_line,
            )
            if match is None:
                raise ValueError(f"Unexpected reveal line format: {error_line}")
            # mypy reports 1-based line numbers; ``lines`` is 0-based.
            lineno = int(match.group("lineno")) - 1
            self.assertIn("Revealed type is", error_line)

            marker = lines[lineno]
            _test_reveal(path, marker, error_line, 1 + lineno)
248
249
# Hand ``TestTyping`` to the ``@parametrize`` machinery; presumably this
# generates one concrete test per parametrized value (see
# ``torch.testing._internal.common_utils`` to confirm).
instantiate_parametrized_tests(TestTyping)

# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
253    run_tests()
254