xref: /aosp_15_r20/external/pytorch/test/jit/test_misc.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5import unittest
6from typing import Any, Dict, List, Optional, Tuple
7
8import torch
9import torch.nn as nn
10import torch.testing._internal.jit_utils
11from jit.test_module_interface import TestModuleInterface  # noqa: F401
12from torch import jit
13from torch.testing import FileCheck
14from torch.testing._internal.common_utils import freeze_rng_state
15from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
16
17
18# Make the helper files in test/ importable
19pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
20sys.path.append(pytorch_test_dir)
21
22if __name__ == "__main__":
23    raise RuntimeError(
24        "This test file is not meant to be run directly, use:\n\n"
25        "\tpython test/test_jit.py TESTNAME\n\n"
26        "instead."
27    )
28
29
30class TestMisc(JitTestCase):
31    def test_joined_str(self):
32        def func(x):
33            hello, test = "Hello", "test"
34            print(f"{hello + ' ' + test}, I'm a {test}")
35            print("format blank")
36            hi = "hi"
37            print(f"stuff before {hi}")
38            print(f"{hi} stuff after")
39            return x + 1
40
41        x = torch.arange(4.0, requires_grad=True)
42        # TODO: Add support for f-strings in string parser frontend
43        # self.checkScript(func, [x], optimize=True, capture_output=True)
44
45        with self.capture_stdout() as captured:
46            out = func(x)
47
48        scripted = torch.jit.script(func)
49        with self.capture_stdout() as captured_script:
50            out_script = func(x)
51
52        self.assertEqual(out, out_script)
53        self.assertEqual(captured, captured_script)
54
55    def test_kwarg_support(self):
56        with self.assertRaisesRegex(
57            torch.jit.frontend.NotSupportedError, "variable number of arguments"
58        ):
59
60            class M(torch.nn.Module):
61                def forward(self, *, n_tokens: int, device_name: str = 2):
62                    pass
63
64            torch.jit.script(M())
65
66        class M(torch.nn.Module):
67            def forward(self, *, n_tokens: int, device_name: str):
68                return n_tokens, device_name
69
70        sm = torch.jit.script(M())
71
72        with self.assertRaisesRegex(
73            RuntimeError, "missing value for argument 'n_tokens'"
74        ):
75            sm()
76
77        with self.assertRaisesRegex(RuntimeError, "positional arg"):
78            sm(3, "hello")
79
80        self.assertEqual(sm(n_tokens=3, device_name="hello"), (3, "hello"))
81
82    def test_tuple_subscripted_assign(self):
83        with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):
84
85            @torch.jit.script
86            def foo(a: Tuple[int, int]) -> None:
87                a[0] = a[1]
88
89        with self.assertRaisesRegex(RuntimeError, "augmented assignment"):
90
91            @torch.jit.script
92            def bar(a: Tuple[int, int]) -> None:
93                a[0] += a[1]
94
95    def test_subexpression_List_Future(self):
96        @torch.jit.script
97        def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
98            return x[0]
99
100        FileCheck().check("Future[int]").check("Future[int]").run(fn.graph)
101
102    def test_subexpression_Future_annotate(self):
103        @torch.jit.script
104        def fn() -> torch.jit.Future[int]:
105            x: List[torch.jit.Future[int]] = []
106            return x[0]
107
108        FileCheck().check("Future[int][]").run(fn.graph)
109
110    def test_future_isinstance(self):
111        @torch.jit.script
112        def fn(x: Any) -> torch.jit.Future[int]:
113            assert isinstance(x, jit.Future[int])
114            return x
115
116        FileCheck().check("Future[int]").run(fn.graph)
117
118    def test_str_refine_any(self):
119        def forward(x: Any) -> str:
120            if isinstance(x, str):
121                return x
122            return "foo"
123
124        forward = torch.jit.script(forward)
125        self.assertEqual(forward(1), "foo")
126        self.assertEqual(forward("bar"), "bar")
127
128    def test_subexpression_Tuple_int_int_Future(self):
129        @torch.jit.script
130        def fn(
131            x: Tuple[int, int, torch.jit.Future[int]]
132        ) -> Tuple[int, torch.jit.Future[int]]:
133            return x[0], x[2]
134
135        FileCheck().check("(int, int, Future[int])").check("(int, Future[int])").run(
136            fn.graph
137        )
138
139    def test_subexpression_Dict_int_Future(self):
140        @torch.jit.script
141        def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
142            return x[y]
143
144        FileCheck().check("Dict(int, Future(int))").check("Future[int]").run(fn.graph)
145
146    def test_subexpression_Optional(self):
147        @torch.jit.script
148        def fn(
149            x: Optional[Dict[int, torch.jit.Future[int]]]
150        ) -> Optional[torch.jit.Future[int]]:
151            if x is not None:
152                return x[0]
153            else:
154                return None
155
156        FileCheck().check("Dict(int, Future(int))?").run(fn.graph)
157
158    def test_if_returning_any(self):
159        """
160        Check that an if statement can return different
161        types early from each branch when the return
162        type of the function is Any.
163        """
164
165        def if_function(inp: torch.Tensor) -> Any:
166            if inp.shape[0] == 1:
167                return inp * inp
168            else:
169                return "str"
170
171        self.checkScript(if_function, (torch.randn(5),))
172
173    def test_hacked_twin(self):
174        def gen_data():
175            with freeze_rng_state():
176                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
177
178        (
179            input,
180            index,
181            value,
182        ) = gen_data()
183        (
184            input1,
185            index1,
186            value1,
187        ) = gen_data()
188        out1 = torch.ops.aten.index_put.hacked_twin(
189            input, [index], value, accumulate=False
190        )
191        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
192        self.assertEqual(out1, out2)
193
194        torch.ops.aten.index_put_.hacked_twin(input, [index], value, accumulate=False)
195        torch.index_put_(input1, [index1], value1, accumulate=False)
196        self.assertEqual(input, input1)
197
198    def test_unsafe_hacked_twin(self):
199        def gen_data():
200            with freeze_rng_state():
201                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
202
203        (
204            input,
205            index,
206            value,
207        ) = gen_data()
208        (
209            input1,
210            index1,
211            value1,
212        ) = gen_data()
213        out1 = torch.ops.aten._unsafe_index_put.hacked_twin(
214            input, [index], value, accumulate=False
215        )
216        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
217        self.assertEqual(out1, out2)
218
219        torch.ops.aten._unsafe_index.Tensor_hacked_twin(input, [index])
220        torch.index_put(input1, [index1], value1, accumulate=False)
221        self.assertEqual(input, input1)
222
223        def index_put_fn(input, index, value):
224            return torch.ops.aten._unsafe_index_put(
225                input, [index], value, accumulate=False
226            )
227
228        input2, index2, value2 = gen_data()
229        script_index_put_fn = torch.jit.script(index_put_fn)
230        expect = index_put_fn(input2.clone(), index2, value2)
231        actual = script_index_put_fn(input2.clone(), index2, value2)
232        self.assertEqual(expect, actual)
233
234        def index_fn(input, index, value):
235            return torch.ops.aten._unsafe_index_put(
236                input, [index], value, accumulate=False
237            )
238
239        script_index_fn = torch.jit.script(index_fn)
240        expect = index_fn(input2.clone(), index2, value2)
241        actual = script_index_fn(input2.clone(), index2, value2)
242        self.assertEqual(expect, actual)
243
244    def test_export_opnames_interface(self):
245        @torch.jit.interface
246        class OneTwoModule(nn.Module):
247            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
248                pass
249
250            def two(self, x: torch.Tensor) -> torch.Tensor:
251                pass
252
253            def forward(self, x: torch.Tensor) -> torch.Tensor:
254                pass
255
256        class FooMod(nn.Module):
257            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
258                return x + y
259
260            def two(self, x: torch.Tensor) -> torch.Tensor:
261                return 2 * x
262
263            def forward(self, x: torch.Tensor) -> torch.Tensor:
264                return self.one(self.two(x), x)
265
266        class BarMod(nn.Module):
267            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
268                return x * y
269
270            def two(self, x: torch.Tensor) -> torch.Tensor:
271                return 2 / x
272
273            def forward(self, x: torch.Tensor) -> torch.Tensor:
274                return self.two(self.one(x, x))
275
276        make_global(OneTwoModule)
277
278        class M(nn.Module):
279            sub: OneTwoModule
280
281            def __init__(self) -> None:
282                super().__init__()
283                self.sub = BarMod()
284
285            def forward(self, x: torch.Tensor) -> torch.Tensor:
286                return self.sub.forward(x)
287
288        def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
289            return mod_list[0].forward(x) + mod_list[1].forward(x)
290
291        torch._C._enable_mobile_interface_call_export()
292        scripted_M_mod = torch.jit.script(M())
293        self.assertTrue(
294            {"aten::mul.Scalar", "aten::mul.Tensor", "aten::reciprocal"}.issubset(
295                set(torch.jit.export_opnames(scripted_M_mod))
296            )
297        )
298
299        scripted_M_mod.sub = torch.jit.script(FooMod())
300        self.assertTrue(
301            {"aten::add.Tensor", "aten::mul.Scalar"}.issubset(
302                set(torch.jit.export_opnames(scripted_M_mod))
303            )
304        )
305
306    def test_math_inf(self):
307        from math import inf
308
309        def foo():
310            return inf
311
312        self.checkScript(foo, ())
313
314    def test_list_literal_infer(self):
315        def expects_intlist(x: List[int]):
316            x.append(3)
317            return x
318
319        def foo():
320            return expects_intlist([])
321
322        self.checkScript(foo, ())
323
324        def annotated_list_fail():
325            return expects_intlist(torch.jit.annotate([], List[Tensor]))  # noqa: F821
326
327        with self.assertRaises(RuntimeError):
328            torch.jit.script(annotated_list_fail)
329
330        def non_temporary_fail():
331            a = []
332            return expects_intlist(a)
333
334        with self.assertRaises(RuntimeError):
335            torch.jit.script(non_temporary_fail)
336
337        @torch.jit.script
338        def test_return():
339            return []
340
341        FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)
342
343    def test_legacy_tensor_constructor(self):
344        # testing PyObject overload
345        def test_all_dtypes():
346            return (
347                torch.BoolTensor([2]),
348                torch.LongTensor([3]),
349                torch.ByteTensor([4]),
350                torch.CharTensor([5]),
351                torch.DoubleTensor([6]),
352                torch.FloatTensor([7]),
353                torch.IntTensor([8]),
354                torch.ShortTensor([1]),
355                torch.HalfTensor([1]),
356            )
357
358        self.checkScript(test_all_dtypes, ())
359
360        # now test empty overload
361        def empty_overload():
362            return torch.LongTensor(2, 3, 4)
363
364        eager = empty_overload()
365        jit = torch.jit.script(empty_overload)()
366        eager[:] = 1
367        jit[:] = 1
368        self.assertEqual(eager, jit)
369
370        def no_inputs():
371            return torch.DoubleTensor()
372
373        self.checkScript(no_inputs, ())
374
375        # bad schema
376        def multiple_args():
377            return torch.LongTensor(1, [2])
378
379        with self.assertRaisesRegex(
380            RuntimeError, "multiple positional arguments that were not all integers"
381        ):
382            torch.jit.script(multiple_args)
383
384        # kwarg bad schema
385        def bad_kwarg():
386            return torch.LongTensor(hello="1")
387
388        with self.assertRaisesRegex(RuntimeError, "hello"):
389            torch.jit.script(bad_kwarg)
390
391    def test_broadcasting_list(self):
392        """
393        Test BroadcastingList and torch.nn._size_N_t alias
394        """
395        from torch._jit_internal import BroadcastingList2
396        from torch.nn.common_types import _size_2_t
397
398        def sum_i(x: _size_2_t) -> int:
399            return x[0] + x[1]
400
401        def sum_f(x: BroadcastingList2[float]) -> float:
402            return x[0] + x[1]
403
404        self.assertTrue(torch.jit.script(sum_i)(4) == 8)
405        self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.0)
406
407    def test_parse_ir_annotate(self):
408        ir = """
409        graph():
410          %3 : int[] = prim::Constant[value=annotate(List[int], [])]()
411          return (%3)
412        """
413        graph = torch._C.parse_ir(ir, True)
414        func = torch._C._create_function_from_graph("forward", graph)
415        ret = func()
416        self.assertTrue(ret == [])
417
418    def test_parse_ir_single_element_tensor_positive(self):
419        ir = """
420        graph():
421          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={0}]()
422          return (%7)
423        """
424        graph = torch._C.parse_ir(ir, True)
425        func = torch._C._create_function_from_graph("forward", graph)
426        ret = func()
427        self.assertTrue(ret.numel() == 1)
428        self.assertTrue(len(ret.size()) == 1)
429
430    def test_parse_ir_single_element_tensor_negative(self):
431        ir = """
432        graph():
433          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={-17}]()
434          return (%7)
435        """
436        graph = torch._C.parse_ir(ir, True)
437        func = torch._C._create_function_from_graph("forward", graph)
438        ret = func()
439        self.assertTrue(ret.numel() == 1)
440        self.assertTrue(len(ret.size()) == 1)
441
442    def test_script_many_decorators(self):
443        def no_op_decorator(f):
444            return f
445
446        @no_op_decorator
447        @no_op_decorator
448        @no_op_decorator
449        @no_op_decorator
450        @no_op_decorator
451        def foo(x, dim: int):
452            return x.unsqueeze(dim)
453
454        x = torch.randn(
455            1,
456        )
457        expected = foo(x, 0)
458        scripted = torch.jit.script(foo)
459        actual = scripted(x, 0)
460        torch.testing.assert_close(expected, actual)
461
462    @unittest.skipIf(not RUN_CUDA_HALF, "need CUDA half support")
463    def test_pow_multiple_dtype(self):
464        # https://github.com/pytorch/pytorch/issues/75476
465        def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
466            p = torch.sigmoid(p)
467            result = p**gamma
468            return result
469
470        x = torch.rand((2, 2), dtype=torch.half, device="cuda")
471
472        ref = fn(x)
473
474        script_fn = torch.jit.script(fn)
475        for i in range(4):
476            res = script_fn(x)
477
478        self.assertEqual(ref, res)
479
480    def test_jit_get_operation_order(self):
481        # See https://github.com/pytorch/pytorch/pull/107138.
482        # Depending on order of operator registration, you can get different
483        # order of overloads in the JIT operator registry.
484        # This is to verify that the order of operators returned by
485        # _jit_get_operation always puts aten ops first (i.e. by sorting
486        # to put them first)
487
488        # Make sure that this chooses a "scalar" overload not a "complex" overload
489        ret = torch.ops.aten.add(4, 3.3)
490        self.assertFalse("complex" in str(ret.dtype))
491
492        # "Scalar" overload is a normal aten op; "complex" is added by torchscript.
493        # We want "Scalar" to come before "complex".
494        op, override_names = torch._C._jit_get_operation("aten::add")
495        print(override_names)
496        complex_indices = [
497            i for i, name in enumerate(override_names) if name == "complex"
498        ]
499        Scalar_indices = [
500            i for i, name in enumerate(override_names) if name == "Scalar"
501        ]
502
503        self.assertTrue(len(complex_indices) > 0)
504        self.assertTrue(len(Scalar_indices) > 0)
505        self.assertTrue(complex_indices[0] > Scalar_indices[0])
506