xref: /aosp_15_r20/external/pytorch/test/jit/test_misc.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerimport unittest
6*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, List, Optional, Tuple
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn
10*da0073e9SAndroid Build Coastguard Workerimport torch.testing._internal.jit_utils
11*da0073e9SAndroid Build Coastguard Workerfrom jit.test_module_interface import TestModuleInterface  # noqa: F401
12*da0073e9SAndroid Build Coastguard Workerfrom torch import jit
13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import freeze_rng_state
15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker
# Put pytorch/test/ (the parent of this file's directory) on sys.path so that
# helper modules living there can be imported by name.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)

# This module is collected by test/test_jit.py; refuse to run standalone.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Workerclass TestMisc(JitTestCase):
31*da0073e9SAndroid Build Coastguard Worker    def test_joined_str(self):
32*da0073e9SAndroid Build Coastguard Worker        def func(x):
33*da0073e9SAndroid Build Coastguard Worker            hello, test = "Hello", "test"
34*da0073e9SAndroid Build Coastguard Worker            print(f"{hello + ' ' + test}, I'm a {test}")
35*da0073e9SAndroid Build Coastguard Worker            print("format blank")
36*da0073e9SAndroid Build Coastguard Worker            hi = "hi"
37*da0073e9SAndroid Build Coastguard Worker            print(f"stuff before {hi}")
38*da0073e9SAndroid Build Coastguard Worker            print(f"{hi} stuff after")
39*da0073e9SAndroid Build Coastguard Worker            return x + 1
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(4.0, requires_grad=True)
42*da0073e9SAndroid Build Coastguard Worker        # TODO: Add support for f-strings in string parser frontend
43*da0073e9SAndroid Build Coastguard Worker        # self.checkScript(func, [x], optimize=True, capture_output=True)
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker        with self.capture_stdout() as captured:
46*da0073e9SAndroid Build Coastguard Worker            out = func(x)
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(func)
49*da0073e9SAndroid Build Coastguard Worker        with self.capture_stdout() as captured_script:
50*da0073e9SAndroid Build Coastguard Worker            out_script = func(x)
51*da0073e9SAndroid Build Coastguard Worker
52*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, out_script)
53*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(captured, captured_script)
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker    def test_kwarg_support(self):
56*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
57*da0073e9SAndroid Build Coastguard Worker            torch.jit.frontend.NotSupportedError, "variable number of arguments"
58*da0073e9SAndroid Build Coastguard Worker        ):
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker            class M(torch.nn.Module):
61*da0073e9SAndroid Build Coastguard Worker                def forward(self, *, n_tokens: int, device_name: str = 2):
62*da0073e9SAndroid Build Coastguard Worker                    pass
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(M())
65*da0073e9SAndroid Build Coastguard Worker
66*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
67*da0073e9SAndroid Build Coastguard Worker            def forward(self, *, n_tokens: int, device_name: str):
68*da0073e9SAndroid Build Coastguard Worker                return n_tokens, device_name
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker        sm = torch.jit.script(M())
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
73*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "missing value for argument 'n_tokens'"
74*da0073e9SAndroid Build Coastguard Worker        ):
75*da0073e9SAndroid Build Coastguard Worker            sm()
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "positional arg"):
78*da0073e9SAndroid Build Coastguard Worker            sm(3, "hello")
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sm(n_tokens=3, device_name="hello"), (3, "hello"))
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker    def test_tuple_subscripted_assign(self):
83*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
86*da0073e9SAndroid Build Coastguard Worker            def foo(a: Tuple[int, int]) -> None:
87*da0073e9SAndroid Build Coastguard Worker                a[0] = a[1]
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "augmented assignment"):
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
92*da0073e9SAndroid Build Coastguard Worker            def bar(a: Tuple[int, int]) -> None:
93*da0073e9SAndroid Build Coastguard Worker                a[0] += a[1]
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker    def test_subexpression_List_Future(self):
96*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
97*da0073e9SAndroid Build Coastguard Worker        def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
98*da0073e9SAndroid Build Coastguard Worker            return x[0]
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Future[int]").check("Future[int]").run(fn.graph)
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker    def test_subexpression_Future_annotate(self):
103*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
104*da0073e9SAndroid Build Coastguard Worker        def fn() -> torch.jit.Future[int]:
105*da0073e9SAndroid Build Coastguard Worker            x: List[torch.jit.Future[int]] = []
106*da0073e9SAndroid Build Coastguard Worker            return x[0]
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Future[int][]").run(fn.graph)
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker    def test_future_isinstance(self):
111*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
112*da0073e9SAndroid Build Coastguard Worker        def fn(x: Any) -> torch.jit.Future[int]:
113*da0073e9SAndroid Build Coastguard Worker            assert isinstance(x, jit.Future[int])
114*da0073e9SAndroid Build Coastguard Worker            return x
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Future[int]").run(fn.graph)
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker    def test_str_refine_any(self):
119*da0073e9SAndroid Build Coastguard Worker        def forward(x: Any) -> str:
120*da0073e9SAndroid Build Coastguard Worker            if isinstance(x, str):
121*da0073e9SAndroid Build Coastguard Worker                return x
122*da0073e9SAndroid Build Coastguard Worker            return "foo"
123*da0073e9SAndroid Build Coastguard Worker
124*da0073e9SAndroid Build Coastguard Worker        forward = torch.jit.script(forward)
125*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(forward(1), "foo")
126*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(forward("bar"), "bar")
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker    def test_subexpression_Tuple_int_int_Future(self):
129*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
130*da0073e9SAndroid Build Coastguard Worker        def fn(
131*da0073e9SAndroid Build Coastguard Worker            x: Tuple[int, int, torch.jit.Future[int]]
132*da0073e9SAndroid Build Coastguard Worker        ) -> Tuple[int, torch.jit.Future[int]]:
133*da0073e9SAndroid Build Coastguard Worker            return x[0], x[2]
134*da0073e9SAndroid Build Coastguard Worker
135*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("(int, int, Future[int])").check("(int, Future[int])").run(
136*da0073e9SAndroid Build Coastguard Worker            fn.graph
137*da0073e9SAndroid Build Coastguard Worker        )
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker    def test_subexpression_Dict_int_Future(self):
140*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
141*da0073e9SAndroid Build Coastguard Worker        def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
142*da0073e9SAndroid Build Coastguard Worker            return x[y]
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Dict(int, Future(int))").check("Future[int]").run(fn.graph)
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker    def test_subexpression_Optional(self):
147*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
148*da0073e9SAndroid Build Coastguard Worker        def fn(
149*da0073e9SAndroid Build Coastguard Worker            x: Optional[Dict[int, torch.jit.Future[int]]]
150*da0073e9SAndroid Build Coastguard Worker        ) -> Optional[torch.jit.Future[int]]:
151*da0073e9SAndroid Build Coastguard Worker            if x is not None:
152*da0073e9SAndroid Build Coastguard Worker                return x[0]
153*da0073e9SAndroid Build Coastguard Worker            else:
154*da0073e9SAndroid Build Coastguard Worker                return None
155*da0073e9SAndroid Build Coastguard Worker
156*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Dict(int, Future(int))?").run(fn.graph)
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker    def test_if_returning_any(self):
159*da0073e9SAndroid Build Coastguard Worker        """
160*da0073e9SAndroid Build Coastguard Worker        Check that an if statement can return different
161*da0073e9SAndroid Build Coastguard Worker        types early from each branch when the return
162*da0073e9SAndroid Build Coastguard Worker        type of the function is Any.
163*da0073e9SAndroid Build Coastguard Worker        """
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker        def if_function(inp: torch.Tensor) -> Any:
166*da0073e9SAndroid Build Coastguard Worker            if inp.shape[0] == 1:
167*da0073e9SAndroid Build Coastguard Worker                return inp * inp
168*da0073e9SAndroid Build Coastguard Worker            else:
169*da0073e9SAndroid Build Coastguard Worker                return "str"
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker        self.checkScript(if_function, (torch.randn(5),))
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Worker    def test_hacked_twin(self):
174*da0073e9SAndroid Build Coastguard Worker        def gen_data():
175*da0073e9SAndroid Build Coastguard Worker            with freeze_rng_state():
176*da0073e9SAndroid Build Coastguard Worker                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker        (
179*da0073e9SAndroid Build Coastguard Worker            input,
180*da0073e9SAndroid Build Coastguard Worker            index,
181*da0073e9SAndroid Build Coastguard Worker            value,
182*da0073e9SAndroid Build Coastguard Worker        ) = gen_data()
183*da0073e9SAndroid Build Coastguard Worker        (
184*da0073e9SAndroid Build Coastguard Worker            input1,
185*da0073e9SAndroid Build Coastguard Worker            index1,
186*da0073e9SAndroid Build Coastguard Worker            value1,
187*da0073e9SAndroid Build Coastguard Worker        ) = gen_data()
188*da0073e9SAndroid Build Coastguard Worker        out1 = torch.ops.aten.index_put.hacked_twin(
189*da0073e9SAndroid Build Coastguard Worker            input, [index], value, accumulate=False
190*da0073e9SAndroid Build Coastguard Worker        )
191*da0073e9SAndroid Build Coastguard Worker        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
192*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker        torch.ops.aten.index_put_.hacked_twin(input, [index], value, accumulate=False)
195*da0073e9SAndroid Build Coastguard Worker        torch.index_put_(input1, [index1], value1, accumulate=False)
196*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input, input1)
197*da0073e9SAndroid Build Coastguard Worker
198*da0073e9SAndroid Build Coastguard Worker    def test_unsafe_hacked_twin(self):
199*da0073e9SAndroid Build Coastguard Worker        def gen_data():
200*da0073e9SAndroid Build Coastguard Worker            with freeze_rng_state():
201*da0073e9SAndroid Build Coastguard Worker                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker        (
204*da0073e9SAndroid Build Coastguard Worker            input,
205*da0073e9SAndroid Build Coastguard Worker            index,
206*da0073e9SAndroid Build Coastguard Worker            value,
207*da0073e9SAndroid Build Coastguard Worker        ) = gen_data()
208*da0073e9SAndroid Build Coastguard Worker        (
209*da0073e9SAndroid Build Coastguard Worker            input1,
210*da0073e9SAndroid Build Coastguard Worker            index1,
211*da0073e9SAndroid Build Coastguard Worker            value1,
212*da0073e9SAndroid Build Coastguard Worker        ) = gen_data()
213*da0073e9SAndroid Build Coastguard Worker        out1 = torch.ops.aten._unsafe_index_put.hacked_twin(
214*da0073e9SAndroid Build Coastguard Worker            input, [index], value, accumulate=False
215*da0073e9SAndroid Build Coastguard Worker        )
216*da0073e9SAndroid Build Coastguard Worker        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
217*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker        torch.ops.aten._unsafe_index.Tensor_hacked_twin(input, [index])
220*da0073e9SAndroid Build Coastguard Worker        torch.index_put(input1, [index1], value1, accumulate=False)
221*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input, input1)
222*da0073e9SAndroid Build Coastguard Worker
223*da0073e9SAndroid Build Coastguard Worker        def index_put_fn(input, index, value):
224*da0073e9SAndroid Build Coastguard Worker            return torch.ops.aten._unsafe_index_put(
225*da0073e9SAndroid Build Coastguard Worker                input, [index], value, accumulate=False
226*da0073e9SAndroid Build Coastguard Worker            )
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Worker        input2, index2, value2 = gen_data()
229*da0073e9SAndroid Build Coastguard Worker        script_index_put_fn = torch.jit.script(index_put_fn)
230*da0073e9SAndroid Build Coastguard Worker        expect = index_put_fn(input2.clone(), index2, value2)
231*da0073e9SAndroid Build Coastguard Worker        actual = script_index_put_fn(input2.clone(), index2, value2)
232*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expect, actual)
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker        def index_fn(input, index, value):
235*da0073e9SAndroid Build Coastguard Worker            return torch.ops.aten._unsafe_index_put(
236*da0073e9SAndroid Build Coastguard Worker                input, [index], value, accumulate=False
237*da0073e9SAndroid Build Coastguard Worker            )
238*da0073e9SAndroid Build Coastguard Worker
239*da0073e9SAndroid Build Coastguard Worker        script_index_fn = torch.jit.script(index_fn)
240*da0073e9SAndroid Build Coastguard Worker        expect = index_fn(input2.clone(), index2, value2)
241*da0073e9SAndroid Build Coastguard Worker        actual = script_index_fn(input2.clone(), index2, value2)
242*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expect, actual)
243*da0073e9SAndroid Build Coastguard Worker
244*da0073e9SAndroid Build Coastguard Worker    def test_export_opnames_interface(self):
245*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
246*da0073e9SAndroid Build Coastguard Worker        class OneTwoModule(nn.Module):
247*da0073e9SAndroid Build Coastguard Worker            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
248*da0073e9SAndroid Build Coastguard Worker                pass
249*da0073e9SAndroid Build Coastguard Worker
250*da0073e9SAndroid Build Coastguard Worker            def two(self, x: torch.Tensor) -> torch.Tensor:
251*da0073e9SAndroid Build Coastguard Worker                pass
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
254*da0073e9SAndroid Build Coastguard Worker                pass
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker        class FooMod(nn.Module):
257*da0073e9SAndroid Build Coastguard Worker            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
258*da0073e9SAndroid Build Coastguard Worker                return x + y
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker            def two(self, x: torch.Tensor) -> torch.Tensor:
261*da0073e9SAndroid Build Coastguard Worker                return 2 * x
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
264*da0073e9SAndroid Build Coastguard Worker                return self.one(self.two(x), x)
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker        class BarMod(nn.Module):
267*da0073e9SAndroid Build Coastguard Worker            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
268*da0073e9SAndroid Build Coastguard Worker                return x * y
269*da0073e9SAndroid Build Coastguard Worker
270*da0073e9SAndroid Build Coastguard Worker            def two(self, x: torch.Tensor) -> torch.Tensor:
271*da0073e9SAndroid Build Coastguard Worker                return 2 / x
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
274*da0073e9SAndroid Build Coastguard Worker                return self.two(self.one(x, x))
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker        make_global(OneTwoModule)
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker        class M(nn.Module):
279*da0073e9SAndroid Build Coastguard Worker            sub: OneTwoModule
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
282*da0073e9SAndroid Build Coastguard Worker                super().__init__()
283*da0073e9SAndroid Build Coastguard Worker                self.sub = BarMod()
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
286*da0073e9SAndroid Build Coastguard Worker                return self.sub.forward(x)
287*da0073e9SAndroid Build Coastguard Worker
288*da0073e9SAndroid Build Coastguard Worker        def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
289*da0073e9SAndroid Build Coastguard Worker            return mod_list[0].forward(x) + mod_list[1].forward(x)
290*da0073e9SAndroid Build Coastguard Worker
291*da0073e9SAndroid Build Coastguard Worker        torch._C._enable_mobile_interface_call_export()
292*da0073e9SAndroid Build Coastguard Worker        scripted_M_mod = torch.jit.script(M())
293*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
294*da0073e9SAndroid Build Coastguard Worker            {"aten::mul.Scalar", "aten::mul.Tensor", "aten::reciprocal"}.issubset(
295*da0073e9SAndroid Build Coastguard Worker                set(torch.jit.export_opnames(scripted_M_mod))
296*da0073e9SAndroid Build Coastguard Worker            )
297*da0073e9SAndroid Build Coastguard Worker        )
298*da0073e9SAndroid Build Coastguard Worker
299*da0073e9SAndroid Build Coastguard Worker        scripted_M_mod.sub = torch.jit.script(FooMod())
300*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
301*da0073e9SAndroid Build Coastguard Worker            {"aten::add.Tensor", "aten::mul.Scalar"}.issubset(
302*da0073e9SAndroid Build Coastguard Worker                set(torch.jit.export_opnames(scripted_M_mod))
303*da0073e9SAndroid Build Coastguard Worker            )
304*da0073e9SAndroid Build Coastguard Worker        )
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker    def test_math_inf(self):
307*da0073e9SAndroid Build Coastguard Worker        from math import inf
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker        def foo():
310*da0073e9SAndroid Build Coastguard Worker            return inf
311*da0073e9SAndroid Build Coastguard Worker
312*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, ())
313*da0073e9SAndroid Build Coastguard Worker
314*da0073e9SAndroid Build Coastguard Worker    def test_list_literal_infer(self):
315*da0073e9SAndroid Build Coastguard Worker        def expects_intlist(x: List[int]):
316*da0073e9SAndroid Build Coastguard Worker            x.append(3)
317*da0073e9SAndroid Build Coastguard Worker            return x
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker        def foo():
320*da0073e9SAndroid Build Coastguard Worker            return expects_intlist([])
321*da0073e9SAndroid Build Coastguard Worker
322*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, ())
323*da0073e9SAndroid Build Coastguard Worker
324*da0073e9SAndroid Build Coastguard Worker        def annotated_list_fail():
325*da0073e9SAndroid Build Coastguard Worker            return expects_intlist(torch.jit.annotate([], List[Tensor]))  # noqa: F821
326*da0073e9SAndroid Build Coastguard Worker
327*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
328*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(annotated_list_fail)
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker        def non_temporary_fail():
331*da0073e9SAndroid Build Coastguard Worker            a = []
332*da0073e9SAndroid Build Coastguard Worker            return expects_intlist(a)
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
335*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(non_temporary_fail)
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
338*da0073e9SAndroid Build Coastguard Worker        def test_return():
339*da0073e9SAndroid Build Coastguard Worker            return []
340*da0073e9SAndroid Build Coastguard Worker
341*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)
342*da0073e9SAndroid Build Coastguard Worker
343*da0073e9SAndroid Build Coastguard Worker    def test_legacy_tensor_constructor(self):
344*da0073e9SAndroid Build Coastguard Worker        # testing PyObject overload
345*da0073e9SAndroid Build Coastguard Worker        def test_all_dtypes():
346*da0073e9SAndroid Build Coastguard Worker            return (
347*da0073e9SAndroid Build Coastguard Worker                torch.BoolTensor([2]),
348*da0073e9SAndroid Build Coastguard Worker                torch.LongTensor([3]),
349*da0073e9SAndroid Build Coastguard Worker                torch.ByteTensor([4]),
350*da0073e9SAndroid Build Coastguard Worker                torch.CharTensor([5]),
351*da0073e9SAndroid Build Coastguard Worker                torch.DoubleTensor([6]),
352*da0073e9SAndroid Build Coastguard Worker                torch.FloatTensor([7]),
353*da0073e9SAndroid Build Coastguard Worker                torch.IntTensor([8]),
354*da0073e9SAndroid Build Coastguard Worker                torch.ShortTensor([1]),
355*da0073e9SAndroid Build Coastguard Worker                torch.HalfTensor([1]),
356*da0073e9SAndroid Build Coastguard Worker            )
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_all_dtypes, ())
359*da0073e9SAndroid Build Coastguard Worker
360*da0073e9SAndroid Build Coastguard Worker        # now test empty overload
361*da0073e9SAndroid Build Coastguard Worker        def empty_overload():
362*da0073e9SAndroid Build Coastguard Worker            return torch.LongTensor(2, 3, 4)
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker        eager = empty_overload()
365*da0073e9SAndroid Build Coastguard Worker        jit = torch.jit.script(empty_overload)()
366*da0073e9SAndroid Build Coastguard Worker        eager[:] = 1
367*da0073e9SAndroid Build Coastguard Worker        jit[:] = 1
368*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager, jit)
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker        def no_inputs():
371*da0073e9SAndroid Build Coastguard Worker            return torch.DoubleTensor()
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker        self.checkScript(no_inputs, ())
374*da0073e9SAndroid Build Coastguard Worker
375*da0073e9SAndroid Build Coastguard Worker        # bad schema
376*da0073e9SAndroid Build Coastguard Worker        def multiple_args():
377*da0073e9SAndroid Build Coastguard Worker            return torch.LongTensor(1, [2])
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
380*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "multiple positional arguments that were not all integers"
381*da0073e9SAndroid Build Coastguard Worker        ):
382*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(multiple_args)
383*da0073e9SAndroid Build Coastguard Worker
384*da0073e9SAndroid Build Coastguard Worker        # kwarg bad schema
385*da0073e9SAndroid Build Coastguard Worker        def bad_kwarg():
386*da0073e9SAndroid Build Coastguard Worker            return torch.LongTensor(hello="1")
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "hello"):
389*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(bad_kwarg)
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker    def test_broadcasting_list(self):
392*da0073e9SAndroid Build Coastguard Worker        """
393*da0073e9SAndroid Build Coastguard Worker        Test BroadcastingList and torch.nn._size_N_t alias
394*da0073e9SAndroid Build Coastguard Worker        """
395*da0073e9SAndroid Build Coastguard Worker        from torch._jit_internal import BroadcastingList2
396*da0073e9SAndroid Build Coastguard Worker        from torch.nn.common_types import _size_2_t
397*da0073e9SAndroid Build Coastguard Worker
398*da0073e9SAndroid Build Coastguard Worker        def sum_i(x: _size_2_t) -> int:
399*da0073e9SAndroid Build Coastguard Worker            return x[0] + x[1]
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker        def sum_f(x: BroadcastingList2[float]) -> float:
402*da0073e9SAndroid Build Coastguard Worker            return x[0] + x[1]
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.jit.script(sum_i)(4) == 8)
405*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.0)
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker    def test_parse_ir_annotate(self):
408*da0073e9SAndroid Build Coastguard Worker        ir = """
409*da0073e9SAndroid Build Coastguard Worker        graph():
410*da0073e9SAndroid Build Coastguard Worker          %3 : int[] = prim::Constant[value=annotate(List[int], [])]()
411*da0073e9SAndroid Build Coastguard Worker          return (%3)
412*da0073e9SAndroid Build Coastguard Worker        """
413*da0073e9SAndroid Build Coastguard Worker        graph = torch._C.parse_ir(ir, True)
414*da0073e9SAndroid Build Coastguard Worker        func = torch._C._create_function_from_graph("forward", graph)
415*da0073e9SAndroid Build Coastguard Worker        ret = func()
416*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ret == [])
417*da0073e9SAndroid Build Coastguard Worker
418*da0073e9SAndroid Build Coastguard Worker    def test_parse_ir_single_element_tensor_positive(self):
419*da0073e9SAndroid Build Coastguard Worker        ir = """
420*da0073e9SAndroid Build Coastguard Worker        graph():
421*da0073e9SAndroid Build Coastguard Worker          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={0}]()
422*da0073e9SAndroid Build Coastguard Worker          return (%7)
423*da0073e9SAndroid Build Coastguard Worker        """
424*da0073e9SAndroid Build Coastguard Worker        graph = torch._C.parse_ir(ir, True)
425*da0073e9SAndroid Build Coastguard Worker        func = torch._C._create_function_from_graph("forward", graph)
426*da0073e9SAndroid Build Coastguard Worker        ret = func()
427*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ret.numel() == 1)
428*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(ret.size()) == 1)
429*da0073e9SAndroid Build Coastguard Worker
430*da0073e9SAndroid Build Coastguard Worker    def test_parse_ir_single_element_tensor_negative(self):
431*da0073e9SAndroid Build Coastguard Worker        ir = """
432*da0073e9SAndroid Build Coastguard Worker        graph():
433*da0073e9SAndroid Build Coastguard Worker          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={-17}]()
434*da0073e9SAndroid Build Coastguard Worker          return (%7)
435*da0073e9SAndroid Build Coastguard Worker        """
436*da0073e9SAndroid Build Coastguard Worker        graph = torch._C.parse_ir(ir, True)
437*da0073e9SAndroid Build Coastguard Worker        func = torch._C._create_function_from_graph("forward", graph)
438*da0073e9SAndroid Build Coastguard Worker        ret = func()
439*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ret.numel() == 1)
440*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(ret.size()) == 1)
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker    def test_script_many_decorators(self):
443*da0073e9SAndroid Build Coastguard Worker        def no_op_decorator(f):
444*da0073e9SAndroid Build Coastguard Worker            return f
445*da0073e9SAndroid Build Coastguard Worker
446*da0073e9SAndroid Build Coastguard Worker        @no_op_decorator
447*da0073e9SAndroid Build Coastguard Worker        @no_op_decorator
448*da0073e9SAndroid Build Coastguard Worker        @no_op_decorator
449*da0073e9SAndroid Build Coastguard Worker        @no_op_decorator
450*da0073e9SAndroid Build Coastguard Worker        @no_op_decorator
451*da0073e9SAndroid Build Coastguard Worker        def foo(x, dim: int):
452*da0073e9SAndroid Build Coastguard Worker            return x.unsqueeze(dim)
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(
455*da0073e9SAndroid Build Coastguard Worker            1,
456*da0073e9SAndroid Build Coastguard Worker        )
457*da0073e9SAndroid Build Coastguard Worker        expected = foo(x, 0)
458*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(foo)
459*da0073e9SAndroid Build Coastguard Worker        actual = scripted(x, 0)
460*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(expected, actual)
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA_HALF, "need CUDA half support")
463*da0073e9SAndroid Build Coastguard Worker    def test_pow_multiple_dtype(self):
464*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/75476
465*da0073e9SAndroid Build Coastguard Worker        def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
466*da0073e9SAndroid Build Coastguard Worker            p = torch.sigmoid(p)
467*da0073e9SAndroid Build Coastguard Worker            result = p**gamma
468*da0073e9SAndroid Build Coastguard Worker            return result
469*da0073e9SAndroid Build Coastguard Worker
470*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((2, 2), dtype=torch.half, device="cuda")
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker        script_fn = torch.jit.script(fn)
475*da0073e9SAndroid Build Coastguard Worker        for i in range(4):
476*da0073e9SAndroid Build Coastguard Worker            res = script_fn(x)
477*da0073e9SAndroid Build Coastguard Worker
478*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker    def test_jit_get_operation_order(self):
481*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/pull/107138.
482*da0073e9SAndroid Build Coastguard Worker        # Depending on order of operator registration, you can get different
483*da0073e9SAndroid Build Coastguard Worker        # order of overloads in the JIT operator registry.
484*da0073e9SAndroid Build Coastguard Worker        # This is to verify that the order of operators returned by
485*da0073e9SAndroid Build Coastguard Worker        # _jit_get_operation always puts aten ops first (i.e. by sorting
486*da0073e9SAndroid Build Coastguard Worker        # to put them first)
487*da0073e9SAndroid Build Coastguard Worker
488*da0073e9SAndroid Build Coastguard Worker        # Make sure that this chooses a "scalar" overload not a "complex" overload
489*da0073e9SAndroid Build Coastguard Worker        ret = torch.ops.aten.add(4, 3.3)
490*da0073e9SAndroid Build Coastguard Worker        self.assertFalse("complex" in str(ret.dtype))
491*da0073e9SAndroid Build Coastguard Worker
492*da0073e9SAndroid Build Coastguard Worker        # "Scalar" overload is a normal aten op; "complex" is added by torchscript.
493*da0073e9SAndroid Build Coastguard Worker        # We want "Scalar" to come before "complex".
494*da0073e9SAndroid Build Coastguard Worker        op, override_names = torch._C._jit_get_operation("aten::add")
495*da0073e9SAndroid Build Coastguard Worker        print(override_names)
496*da0073e9SAndroid Build Coastguard Worker        complex_indices = [
497*da0073e9SAndroid Build Coastguard Worker            i for i, name in enumerate(override_names) if name == "complex"
498*da0073e9SAndroid Build Coastguard Worker        ]
499*da0073e9SAndroid Build Coastguard Worker        Scalar_indices = [
500*da0073e9SAndroid Build Coastguard Worker            i for i, name in enumerate(override_names) if name == "Scalar"
501*da0073e9SAndroid Build Coastguard Worker        ]
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(complex_indices) > 0)
504*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(Scalar_indices) > 0)
505*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(complex_indices[0] > Scalar_indices[0])
506