# Owner(s): ["oncall: jit"]

import os
import random
import sys
import tempfile
from textwrap import dedent

import torch
from torch.testing._internal.jit_utils import execWrapper, JitTestCase


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


def get_fn(file_name, script_path):
    # Import the file at `script_path` as a module named `file_name` and
    # return its `fn` attribute.
    import importlib.util

    spec = importlib.util.spec_from_file_location(file_name, script_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    fn = module.fn
    return fn


class TestPythonBuiltinOP(JitTestCase):
    def test_add(self):
        def func(a, b):
            c = a + b
            c += a
            return c

        a = torch.rand(1, requires_grad=True)
        b = torch.rand(1, requires_grad=True)
        self.checkScript(func, (a, b), optimize=True)

    def test_mul(self):
        def func(a, b):
            return a * b

        a = torch.rand(1, requires_grad=True)
        b = torch.rand(1, requires_grad=True)
        self.checkScript(func, (a, b), optimize=True)

    def test_matmul_py3(self):
        code = dedent(
            """
        def fn(a, b):
            return a @ b
        """
        )

        with tempfile.TemporaryDirectory() as tmp_dir:
            script_path = os.path.join(tmp_dir, "script.py")
            with open(script_path, "w") as f:
                f.write(code)
            fn = get_fn("test_matmul_py3", script_path)

            a = torch.rand(4, 3, requires_grad=True)
            b = torch.rand(3, 2, requires_grad=True)
            self.checkScript(fn, (a, b), optimize=True)

    def test_pow(self):
        def func(a, b):
            return a**b

        def func2(a, b, c, d):
            return c + a**b**d

        def func3(a, b):
            # type: (int, float) -> float
            return a**b

        def func4():
            # type: () -> float
            return 2**-2

        def func5(x, y):
            return x.item() ** y.item()

        a = torch.rand(1, requires_grad=True)
        b = torch.rand(1, requires_grad=True)
        c = torch.rand(1, requires_grad=True)
        d = torch.rand(1, requires_grad=True)
        self.checkScript(func, (a, b), optimize=True)
        self.checkScript(func2, (a, b, c, d), optimize=True)
        self.checkScript(func3, (4, -0.5), optimize=True)
        self.checkScript(func4, ())

        inputs = [
            torch.tensor(2),
            torch.tensor(-2),
            torch.tensor(0.5),
            torch.tensor(0.2),
        ]
        # Skip negative bases: a negative base raised to a fractional exponent
        # yields a complex result in Python, so eager and script would diverge.
        for x in inputs:
            for y in inputs:
                if x >= 0:
                    self.checkScript(func5, (x, y))

    def test_triple(self):
        def func(x):
            return 3.0 * x

        x = torch.rand(1, dtype=torch.float, requires_grad=True)
        self.checkScript(func, [x], optimize=True)

    def test_slice(self):
        def func(x):
            return x[:5]

        x = torch.rand(10, dtype=torch.float, requires_grad=True)
        self.checkScript(func, [x], optimize=True)

        def func2(x):
            return x[5:]

        self.checkScript(func2, [x], optimize=True)

        def func3(x):
            return x[:8:2]

        self.checkScript(func3, [x], optimize=True)

        def func4(x):
            return x[1::4]

        self.checkScript(func4, [x], optimize=True)

    def test_gather(self):
        def func(x):
            return x[0]

        x = torch.rand(10, dtype=torch.float, requires_grad=True)
        self.checkScript(func, [x], optimize=True)

    def test_random(self):
        @torch.jit.script
        def f(mean, std):
            return torch.normal(mean, std)

        mean, std = torch.zeros(5, 5), torch.ones(5, 5)
        # fork_rng restores the RNG state on exit, so the eager and scripted
        # calls sample from identical seeds and must produce identical output.
        with torch.random.fork_rng(devices=[]):
            output = torch.normal(mean, std)
        with torch.random.fork_rng(devices=[]):
            script_output = f(mean, std)
        self.assertEqual(output, script_output)

    def _check_code(self, code_str, fn_name, inputs):
        # Run `fn_name` from `code_str` both as plain Python and as a
        # TorchScript CompilationUnit, and check that the results agree.
        scope = {}
        exec(code_str, globals(), scope)
        cu = torch.jit.CompilationUnit(code_str)
        self.assertEqual(getattr(cu, fn_name)(*inputs), scope[fn_name](*inputs))

    def test_stepped_tuple_slicing(self):
        def check_slicing_tuple(slicing, tuple_type, tup):
            template = dedent(
                """
            def func(x):
                # type: ({}) -> Any
                return x{}
            """
            )
            self._check_code(template.format(tuple_type, slicing), "func", [tup])

        check_slicing_tuple("[-3:3:2]", "Tuple[int, int, int]", (0, 1, 2))
        check_slicing_tuple("[::55]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
        check_slicing_tuple("[:4:4]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
        check_slicing_tuple(
            "[::-1]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)
        )
        check_slicing_tuple(
            "[7:5:2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)
        )
        check_slicing_tuple(
            "[5:7:-2]",
            "Tuple[int, int, int, int, int, int, int]",
            (0, 1, 2, 3, 4, 5, 6),
        )
        check_slicing_tuple("[::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
        check_slicing_tuple(
            "[:4:-3]", "Tuple[int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5)
        )
        check_slicing_tuple(
            "[3::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)
        )

    def test_index(self):
        def consec(size, start=0):
            # Tensor of consecutive integers beginning at `start`, shaped `size`.
            numel = torch.tensor(size).prod().item()
            return torch.arange(start, start + numel).view(size)

        def check_indexing(indexing, tensor):
            template = dedent(
                """
            def func(x):
                return x{}
            """
            )

            self._check_code(template.format(indexing), "func", [tensor])

        def check_dynamic_indexing(indexing, tensor, value1, value2):
            value1 = torch.tensor(value1)
            value2 = torch.tensor(value2)

            template = dedent(
                """
            def func(x, value1, value2):
                i = int(value1)
                j = int(value2)
                return x{}
            """
            )

            self._check_code(
                template.format(indexing), "func", [tensor, value1, value2]
            )

        # basic slices
        check_indexing("[0]", consec((3, 3)))
        check_indexing("[1]", consec((3, 3), 10))
        check_indexing("[2]", consec((3, 3), 19))
        check_indexing("[2]", consec((3,)))
        check_indexing("[-1]", consec((3, 3), 19))
        check_indexing("[0:2]", consec((3, 3, 3)))
        check_indexing("[1:-1]", consec((3, 3, 3)))
        check_indexing("[-3:-1]", consec((6, 3)))
        check_indexing("[1:]", consec((3, 3)))
        check_indexing("[:1]", consec((3, 3)))
        check_indexing("[:]", consec((3, 2)))

        # multi-dim: indexes
        check_indexing("[0, 1]", consec((3, 3)))
        check_indexing("[0, 1]", consec((3, 3, 2)))
        check_indexing("[1, 0, 2]", consec((3, 3, 3)))
        check_indexing("[2, -1]", consec((3, 3)))

        # multi-dim: mixed slicing and indexing
        check_indexing("[0, 1:2]", consec((3, 3)))
        check_indexing("[0, :1]", consec((3, 3, 2)))
        check_indexing("[1, 2:]", consec((3, 3, 3)))
        check_indexing("[-1, 1:, 0]", consec((3, 3, 3, 3)))
        check_indexing("[1:, -1, 0]", consec((3, 3, 3, 3)))
        check_indexing("[-1, 2:, 1:2]", consec((3, 3, 3, 3)))
        check_indexing("[-1, 1:, 0]", consec((3, 3, 3, 3)))
        check_indexing("[-1, :, 0, 2]", consec((3, 3, 3, 3)))

        # zero-sized slices
        check_indexing("[0:0]", consec((2, 2)))
        check_indexing("[0:0, 1]", consec((3, 3)))

        # trivial expression usage
        check_indexing("[1+1]", consec((3, 3)))
        check_indexing("[1:(0 + 2)]", consec((3, 3, 3)))

        # None for new dimensions
        check_indexing("[None, 0]", consec((3, 3)))
        check_indexing("[1, None]", consec((3, 3), 10))
        check_indexing("[None, None, 2]", consec((3, 3), 19))
        check_indexing("[None, 2, None]", consec((3,)))
        check_indexing("[0:2, None]", consec((3, 3, 3)))
        check_indexing("[None, 1:-1]", consec((3, 3, 3)))
        check_indexing("[None, -3:-1, None]", consec((6, 3)))
        check_indexing("[-1, None, 2:, None, 1:2]", consec((3, 3, 3, 3)))
        check_indexing("[None, -1, None, 2:, None, 1:2, None]", consec((3, 3, 3, 3)))

        # dynamic expression usage
        check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
        check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)

    def test_advancedindex(self):
        def consec(size, start=0):
            # Tensor of consecutive integers beginning at `start`, shaped `size`.
            numel = torch.tensor(size).prod().item()
            return torch.arange(start, start + numel).view(size)

        def check_indexing(indexing, tensor, **kwargs):
            indices_dict = kwargs

            template = dedent(
                """
            def func(x{formals}):
                return x{expr}
            """
            )

            formals = []
            values = []
            for formal, value in indices_dict.items():
                formals.append(formal)
                values.append(value)

            formals = "".join(map(", {}".format, formals))
            inputs = [tensor] + values
            self._check_code(
                template.format(formals=formals, expr=indexing), "func", inputs
            )

        # Indexing with tensor (basic)
        check_indexing("[i]", consec((3, 3)), i=torch.tensor([0]))
        check_indexing("[i]", consec((3, 3)), i=torch.tensor(1))
        check_indexing("[i]", consec((3, 3)), i=torch.tensor([-2]))
        check_indexing("[i]", consec((3, 3), 2), i=torch.tensor([0, 0]))
        check_indexing("[i]", consec((3, 3, 2, 2)), i=torch.tensor([0, -2, 1]))

        # NB: indexing with tensors and indexing with sequences can be implemented
        # in a very similar way (sequences are converted to tensors), so only one
        # case needs to be tested extensively; see the short illustration below.
        # XXX: When we can index with sequences, replace these cases with
        # sequence indexing expressions; those are much easier to read.

        # Misc sequence advanced indexing
        inp = consec((4, 8, 5))
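        # A minimal illustration of the sequence/tensor equivalence noted
        # above (an editorial sketch, not one of the generated checks): a
        # nested Python list index behaves like the corresponding tensor index.
        assert torch.equal(
            inp[[0, 2], [1, 3]], inp[torch.tensor([0, 2]), torch.tensor([1, 3])]
        )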
        to_check = [
            # [[0, 1, 3]]
            ["[i]", {"i": [0, 1, 3]}],
            # [[0, 2], [1, 3]]
            ["[i, j]", {"i": [0, 2], "j": [1, 3]}],
            # [[[0, 1], [0, 1]], [[0, 1], [0, 1]]]
            ["[i, j]", {"i": [[0, 1], [0, 1]], "j": [[0, 1], [0, 1]]}],
            # [[0, 2], [1, 3], [1, 1]]
            ["[i, j, k]", {"i": [0, 2], "j": [1, 3], "k": [1, 1]}],
            # [[0, 2], 1, [1, 1]]
            ["[i, j, k]", {"i": [0, 2], "j": 1, "k": [1, 1]}],
            # [:, :, [0, 3, 4]]
            ["[:, :, i]", {"i": [0, 3, 4]}],
            # [:, [0, 2, 3], 2:4]
            ["[:, i, 2:4]", {"i": [0, 2, 3]}],
            # [[2, 3], :, :]
            ["[i, :, :]", {"i": [2, 3]}],
            # [:, [0, 2, 3], [1, 3, 4]]
            ["[:, i, j]", {"i": [0, 2, 3], "j": [1, 3, 4]}],
            # [:, [0], [1, 2, 4]]
            ["[:, i, j]", {"i": [0], "j": [1, 2, 4]}],
            # [:, [0, 1, 3], [4]]
            ["[:, i, j]", {"i": [0, 1, 3], "j": [4]}],
            # [:, [[0, 1], [1, 0]], [[2, 3]]]
            ["[:, i, j]", {"i": [[0, 1], [1, 0]], "j": [[2, 3]]}],
            # [:, [[0, 1], [2, 3]], [[0]]]
            ["[:, i, j]", {"i": [[0, 1], [2, 3]], "j": [[0]]}],
            # [:, [[5, 6]], [[0, 3], [4, 4]]]
            ["[:, i, j]", {"i": [[5, 6]], "j": [[0, 3], [4, 4]]}],
            # [[0, 2, 3], [1, 3, 4], :]
            ["[i, j, :]", {"i": [0, 2, 3], "j": [1, 3, 4]}],
            # [0, [1, 2, 4], :]
            ["[i, j, :]", {"i": 0, "j": [1, 2, 4]}],
            # [[0, 1, 3], 4, :]
            ["[i, j, :]", {"i": [0, 1, 3], "j": 4}],
            # [[[0, 1], [1, 0]], [[2, 1], [3, 5]], :]
            ["[i, j, :]", {"i": [[0, 1], [1, 0]], "j": [[2, 1], [3, 5]]}],
            # [[[0, 1], [1, 0]], [[2, 3]], :]
            ["[i, j, :]", {"i": [[0, 1], [1, 0]], "j": [[2, 3]]}],
            # [[[0, 1], [2, 3]], [[0]], :]
            ["[i, j, :]", {"i": [[0, 1], [2, 3]], "j": [[0]]}],
            # [[[2, 1]], [[0, 3], [4, 4]], :]
            ["[i, j, :]", {"i": [[2, 1]], "j": [[0, 3], [4, 4]]}],
            # [[[2]], [[0, 3], [4, 1]], 0:2]
            ["[i, j, 0:2]", {"i": [[2]], "j": [[0, 3], [4, 1]]}],
        ]

        for expr, argdict in to_check:
            tensordict = {k: torch.tensor(v) for (k, v) in argdict.items()}
            check_indexing(expr, inp, **tensordict)

    def test_adv_indexing_list(self):
        # Indexing with a list is equivalent to indexing with a tensor,
        # as the short sketch below illustrates.
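        # (Editorial sketch of that equivalence; `t` is a scratch tensor
        # introduced here for illustration, not part of the original checks.)
        t = torch.arange(12).view(6, 2)
        assert torch.equal(t[[0, 1, 5]], t[torch.tensor([0, 1, 5])])
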
        def func1(x):
            return x[[0, 1, 5]]

        def func2(x):
            return x[[0, 1], [0, 1]]

        def func3(x):
            return x[[[0, 1], [0, 1]], [[0, 1], [0, 1]]]

        def func4(x):
            ls = [0]
            ls.append(1)
            ls.append(2)
            return x[ls]

        def func5(x):
            ls = [0.1, 1.2, 2.3]
            return x[ls]

        inp = torch.rand((6, 2))
        self.checkScript(func1, (inp,))
        self.checkScript(func2, (inp,))
        self.checkScript(func3, (inp,))
        self.checkScript(func4, (inp,))
        self.checkScript(func5, (inp,))

    def test_index_ellipses(self):
        vals = [":", 1, None]
        for _ in range(100):
            # Build a random indexing expression with exactly one ellipsis,
            # e.g. [':', 1, '...', None] renders as x[:, 1, ..., None] once
            # the quotes are stripped below.
            indices = [random.choice(vals) for _ in range(4)]
            indices[random.randint(0, len(indices) - 1)] = "..."
            test_str = dedent(
                """
            def f():
                x = torch.ones(10, 9, 8, 7, 6)
                return x{indices}.shape
            """.format(
                    indices=indices
                )
            )
            test_str = test_str.replace(r"'", r"")
            scope = {}
            execWrapper(test_str, globals(), scope)
            cu = torch.jit.CompilationUnit(test_str)
            res1 = cu.f()
            res2 = scope["f"]()
            self.assertEqual(res1, res2)

    def test_inf(self):
        @torch.jit.script
        def foo(a):
            return a < float("inf")

        s = torch.rand(1)
        self.assertTrue(foo(s))

        @torch.jit.script
        def bar(a):
            return a > float("-inf")

        s = torch.rand(1)
        self.assertTrue(bar(s))

        # test re-assignment on imported source
        code = """
        def foo(x):
            # type: (bool)
            a = float("-inf")
            if not x:
                a = float(torch.tensor([5]))
            return a < 4
        """
        cu = torch.jit.CompilationUnit(code)
        self.assertTrue(cu.foo(True))
        self.assertFalse(cu.foo(False))

    def test_str_to_float(self):
        @torch.jit.script
        def foo(a):
            return 0.5 == float("0.5 hello")

        s = torch.rand(1)
        with self.assertRaisesRegex(RuntimeError, "could not convert string to float"):
            self.assertTrue(foo(s))

        @torch.jit.script
        def foo(a):
            return 0.5 == float("0.5")

        s = torch.rand(1)
        self.assertTrue(foo(s))

        @torch.jit.script
        def foo(a):
            return 0.0 == float("0")

        s = torch.rand(1)
        self.assertTrue(foo(s))
476