# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.testing._internal.jit_utils
from jit.test_module_interface import TestModuleInterface  # noqa: F401
from torch import jit
from torch.testing import FileCheck
from torch.testing._internal.common_utils import freeze_rng_state
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestMisc(JitTestCase):
    """Miscellaneous TorchScript tests: frontend syntax, type annotations,
    legacy constructors, IR parsing and operator-overload resolution."""

    def test_joined_str(self):
        """f-strings produce the same stdout and result when scripted."""

        def func(x):
            hello, test = "Hello", "test"
            print(f"{hello + ' ' + test}, I'm a {test}")
            print("format blank")
            hi = "hi"
            print(f"stuff before {hi}")
            print(f"{hi} stuff after")
            return x + 1

        x = torch.arange(4.0, requires_grad=True)

        # TODO: Add support for f-strings in string parser frontend
        # self.checkScript(func, [x], optimize=True, capture_output=True)

        with self.capture_stdout() as captured:
            out = func(x)

        scripted = torch.jit.script(func)
        with self.capture_stdout() as captured_script:
            # Must run the scripted function; re-running `func` here would
            # compare eager output against itself and test nothing.
            out_script = scripted(x)

        self.assertEqual(out, out_script)
        self.assertEqual(captured, captured_script)

    def test_kwarg_support(self):
        """Keyword-only arguments: an ill-typed default is rejected, missing or
        positional arguments raise, and passing them by keyword works."""
        with self.assertRaisesRegex(
            torch.jit.frontend.NotSupportedError, "variable number of arguments"
        ):

            class M(torch.nn.Module):
                def forward(self, *, n_tokens: int, device_name: str = 2):
                    pass

            torch.jit.script(M())

        class M(torch.nn.Module):
            def forward(self, *, n_tokens: int, device_name: str):
                return n_tokens, device_name

        sm = torch.jit.script(M())

        with self.assertRaisesRegex(
            RuntimeError, "missing value for argument 'n_tokens'"
        ):
            sm()

        with self.assertRaisesRegex(RuntimeError, "positional arg"):
            sm(3, "hello")

        self.assertEqual(sm(n_tokens=3, device_name="hello"), (3, "hello"))

    def test_tuple_subscripted_assign(self):
        """Subscripted and augmented assignment into a Tuple fail to script."""
        with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):

            @torch.jit.script
            def foo(a: Tuple[int, int]) -> None:
                a[0] = a[1]

        with self.assertRaisesRegex(RuntimeError, "augmented assignment"):

            @torch.jit.script
            def bar(a: Tuple[int, int]) -> None:
                a[0] += a[1]

    def test_subexpression_List_Future(self):
        @torch.jit.script
        def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
            return x[0]

        FileCheck().check("Future[int]").check("Future[int]").run(fn.graph)

    def test_subexpression_Future_annotate(self):
        @torch.jit.script
        def fn() -> torch.jit.Future[int]:
            x: List[torch.jit.Future[int]] = []
            return x[0]

        FileCheck().check("Future[int][]").run(fn.graph)

    def test_future_isinstance(self):
        @torch.jit.script
        def fn(x: Any) -> torch.jit.Future[int]:
            assert isinstance(x, jit.Future[int])
            return x

        FileCheck().check("Future[int]").run(fn.graph)

    def test_str_refine_any(self):
        def forward(x: Any) -> str:
            if isinstance(x, str):
                return x
            return "foo"

        forward = torch.jit.script(forward)
        self.assertEqual(forward(1), "foo")
        self.assertEqual(forward("bar"), "bar")

    def test_subexpression_Tuple_int_int_Future(self):
        @torch.jit.script
        def fn(
            x: Tuple[int, int, torch.jit.Future[int]]
        ) -> Tuple[int, torch.jit.Future[int]]:
            return x[0], x[2]

        FileCheck().check("(int, int, Future[int])").check("(int, Future[int])").run(
            fn.graph
        )

    def test_subexpression_Dict_int_Future(self):
        @torch.jit.script
        def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
            return x[y]

        FileCheck().check("Dict(int, Future(int))").check("Future[int]").run(fn.graph)

    def test_subexpression_Optional(self):
        @torch.jit.script
        def fn(
            x: Optional[Dict[int, torch.jit.Future[int]]]
        ) -> Optional[torch.jit.Future[int]]:
            if x is not None:
                return x[0]
            else:
                return None

        FileCheck().check("Dict(int, Future(int))?").run(fn.graph)

    def test_if_returning_any(self):
        """
        Check that an if statement can return different types early from each
        branch when the return type of the function is Any.
        """

        def if_function(inp: torch.Tensor) -> Any:
            if inp.shape[0] == 1:
                return inp * inp
            else:
                return "str"

        self.checkScript(if_function, (torch.randn(5),))

    def test_hacked_twin(self):
        """index_put / index_put_ hacked-twin overloads match the regular ops."""

        def gen_data():
            # Identical RNG state on every call, so both copies hold equal data.
            with freeze_rng_state():
                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

        input, index, value = gen_data()
        input1, index1, value1 = gen_data()
        out1 = torch.ops.aten.index_put.hacked_twin(
            input, [index], value, accumulate=False
        )
        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        torch.ops.aten.index_put_.hacked_twin(input, [index], value, accumulate=False)
        torch.index_put_(input1, [index1], value1, accumulate=False)
        self.assertEqual(input, input1)

    def test_unsafe_hacked_twin(self):
        """_unsafe_index_put / _unsafe_index hacked twins match the regular ops,
        both eagerly and when called from a scripted function."""

        def gen_data():
            # Identical RNG state on every call, so both copies hold equal data.
            with freeze_rng_state():
                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

        input, index, value = gen_data()
        input1, index1, value1 = gen_data()

        out1 = torch.ops.aten._unsafe_index_put.hacked_twin(
            input, [index], value, accumulate=False
        )
        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        # Compare the _unsafe_index hacked twin against plain advanced indexing.
        # Its result used to be discarded, so nothing was actually checked.
        out3 = torch.ops.aten._unsafe_index.Tensor_hacked_twin(input, [index])
        out4 = input1[index1]
        self.assertEqual(out3, out4)
        # All ops above are out-of-place, so the inputs must be untouched.
        self.assertEqual(input, input1)

        def index_put_fn(input, index, value):
            return torch.ops.aten._unsafe_index_put(
                input, [index], value, accumulate=False
            )

        input2, index2, value2 = gen_data()
        script_index_put_fn = torch.jit.script(index_put_fn)
        expect = index_put_fn(input2.clone(), index2, value2)
        actual = script_index_put_fn(input2.clone(), index2, value2)
        self.assertEqual(expect, actual)

        # Script _unsafe_index itself; this previously duplicated the
        # _unsafe_index_put check above.
        def index_fn(input, index):
            return torch.ops.aten._unsafe_index(input, [index])

        script_index_fn = torch.jit.script(index_fn)
        expect = index_fn(input2.clone(), index2)
        actual = script_index_fn(input2.clone(), index2)
        self.assertEqual(expect, actual)

    def test_export_opnames_interface(self):
        """export_opnames sees ops reached through an interface-typed submodule
        and reflects swapping that submodule."""

        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                pass

            def two(self, x: torch.Tensor) -> torch.Tensor:
                pass

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                pass

        class FooMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.one(self.two(x), x)

        class BarMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 / x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.two(self.one(x, x))

        make_global(OneTwoModule)

        class M(nn.Module):
            sub: OneTwoModule

            def __init__(self) -> None:
                super().__init__()
                self.sub = BarMod()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.sub.forward(x)

        def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
            return mod_list[0].forward(x) + mod_list[1].forward(x)

        torch._C._enable_mobile_interface_call_export()
        scripted_M_mod = torch.jit.script(M())
        self.assertTrue(
            {"aten::mul.Scalar", "aten::mul.Tensor", "aten::reciprocal"}.issubset(
                set(torch.jit.export_opnames(scripted_M_mod))
            )
        )

        scripted_M_mod.sub = torch.jit.script(FooMod())
        self.assertTrue(
            {"aten::add.Tensor", "aten::mul.Scalar"}.issubset(
                set(torch.jit.export_opnames(scripted_M_mod))
            )
        )

    def test_math_inf(self):
        from math import inf

        def foo():
            return inf

        self.checkScript(foo, ())

    def test_list_literal_infer(self):
        """An empty list literal passed straight to a List[int] parameter takes
        that type. An explicitly annotated List[Tensor], or an empty list bound
        to a variable first, fails to script."""

        def expects_intlist(x: List[int]):
            x.append(3)
            return x

        def foo():
            return expects_intlist([])

        self.checkScript(foo, ())

        def annotated_list_fail():
            return expects_intlist(torch.jit.annotate([], List[Tensor]))  # noqa: F821

        with self.assertRaises(RuntimeError):
            torch.jit.script(annotated_list_fail)

        def non_temporary_fail():
            a = []
            return expects_intlist(a)

        with self.assertRaises(RuntimeError):
            torch.jit.script(non_temporary_fail)

        @torch.jit.script
        def test_return():
            return []

        # With no context, an empty list literal defaults to Tensor[].
        FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)

    def test_legacy_tensor_constructor(self):
        """Legacy typed constructors (torch.LongTensor, ...) script correctly,
        and invalid argument combinations are rejected."""

        # testing PyObject overload
        def test_all_dtypes():
            return (
                torch.BoolTensor([2]),
                torch.LongTensor([3]),
                torch.ByteTensor([4]),
                torch.CharTensor([5]),
                torch.DoubleTensor([6]),
                torch.FloatTensor([7]),
                torch.IntTensor([8]),
                torch.ShortTensor([1]),
                torch.HalfTensor([1]),
            )

        self.checkScript(test_all_dtypes, ())

        # now test empty overload
        def empty_overload():
            return torch.LongTensor(2, 3, 4)

        eager = empty_overload()
        # Named so it does not shadow the module-level `jit` import.
        scripted_out = torch.jit.script(empty_overload)()
        # Contents are uninitialized; fill both before comparing.
        eager[:] = 1
        scripted_out[:] = 1
        self.assertEqual(eager, scripted_out)

        def no_inputs():
            return torch.DoubleTensor()

        self.checkScript(no_inputs, ())

        # bad schema
        def multiple_args():
            return torch.LongTensor(1, [2])

        with self.assertRaisesRegex(
            RuntimeError, "multiple positional arguments that were not all integers"
        ):
            torch.jit.script(multiple_args)

        # kwarg bad schema
        def bad_kwarg():
            return torch.LongTensor(hello="1")

        with self.assertRaisesRegex(RuntimeError, "hello"):
            torch.jit.script(bad_kwarg)

    def test_broadcasting_list(self):
        """
        Test BroadcastingList and torch.nn._size_N_t alias
        """
        from torch._jit_internal import BroadcastingList2
        from torch.nn.common_types import _size_2_t

        def sum_i(x: _size_2_t) -> int:
            return x[0] + x[1]

        def sum_f(x: BroadcastingList2[float]) -> float:
            return x[0] + x[1]

        # A scalar argument is broadcast to a 2-element list.
        self.assertEqual(torch.jit.script(sum_i)(4), 8)
        self.assertEqual(torch.jit.script(sum_f)(4.5), 9.0)

    def test_parse_ir_annotate(self):
        ir = """
        graph():
          %3 : int[] = prim::Constant[value=annotate(List[int], [])]()
          return (%3)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret, [])

    def test_parse_ir_single_element_tensor_positive(self):
        ir = """
        graph():
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={0}]()
          return (%7)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret.numel(), 1)
        self.assertEqual(len(ret.size()), 1)

    def test_parse_ir_single_element_tensor_negative(self):
        ir = """
        graph():
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={-17}]()
          return (%7)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret.numel(), 1)
        self.assertEqual(len(ret.size()), 1)

    def test_script_many_decorators(self):
        def no_op_decorator(f):
            return f

        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        def foo(x, dim: int):
            return x.unsqueeze(dim)

        x = torch.randn(1)
        expected = foo(x, 0)
        scripted = torch.jit.script(foo)
        actual = scripted(x, 0)
        torch.testing.assert_close(expected, actual)

    @unittest.skipIf(not RUN_CUDA_HALF, "need CUDA half support")
    def test_pow_multiple_dtype(self):
        # https://github.com/pytorch/pytorch/issues/75476
        def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
            p = torch.sigmoid(p)
            result = p**gamma
            return result

        x = torch.rand((2, 2), dtype=torch.half, device="cuda")

        ref = fn(x)

        script_fn = torch.jit.script(fn)
        # Run several times so later (profiled/optimized) executions are covered.
        for _ in range(4):
            res = script_fn(x)

        self.assertEqual(ref, res)

    def test_jit_get_operation_order(self):
        # See https://github.com/pytorch/pytorch/pull/107138.
        # Depending on order of operator registration, you can get different
        # order of overloads in the JIT operator registry.
        # This is to verify that the order of operators returned by
        # _jit_get_operation always puts aten ops first (i.e. by sorting
        # to put them first)

        # Make sure that this chooses a "scalar" overload not a "complex" overload
        ret = torch.ops.aten.add(4, 3.3)
        self.assertNotIn("complex", str(ret.dtype))

        # "Scalar" overload is a normal aten op; "complex" is added by torchscript.
        # We want "Scalar" to come before "complex".
        op, override_names = torch._C._jit_get_operation("aten::add")
        complex_indices = [
            i for i, name in enumerate(override_names) if name == "complex"
        ]
        Scalar_indices = [
            i for i, name in enumerate(override_names) if name == "Scalar"
        ]

        self.assertGreater(len(complex_indices), 0)
        self.assertGreater(len(Scalar_indices), 0)
        self.assertLess(Scalar_indices[0], complex_indices[0])