# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.testing._internal.jit_utils
from jit.test_module_interface import TestModuleInterface  # noqa: F401
from torch import jit
from torch.testing import FileCheck
from torch.testing._internal.common_utils import freeze_rng_state
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestMisc(JitTestCase):
    """Miscellaneous TorchScript frontend and operator tests."""

    def test_joined_str(self):
        """Printing f-strings gives the same output eagerly and when scripted."""

        def func(x):
            hello, test = "Hello", "test"
            print(f"{hello + ' ' + test}, I'm a {test}")
            print("format blank")
            hi = "hi"
            print(f"stuff before {hi}")
            print(f"{hi} stuff after")
            return x + 1

        x = torch.arange(4.0, requires_grad=True)
        # TODO: Add support for f-strings in string parser frontend
        # self.checkScript(func, [x], optimize=True, capture_output=True)

        with self.capture_stdout() as captured:
            out = func(x)

        scripted = torch.jit.script(func)
        with self.capture_stdout() as captured_script:
            # Run the compiled function; calling `func` again here would only
            # compare eager output against itself.
            out_script = scripted(x)

        self.assertEqual(out, out_script)
        self.assertEqual(captured, captured_script)

    def test_kwarg_support(self):
        """Keyword-only arguments are scriptable only without default values."""
        # A keyword-only argument carrying a default is rejected at script time.
        with self.assertRaisesRegex(
            torch.jit.frontend.NotSupportedError, "variable number of arguments"
        ):

            class M(torch.nn.Module):
                def forward(self, *, n_tokens: int, device_name: str = 2):
                    pass

            torch.jit.script(M())

        class M(torch.nn.Module):
            def forward(self, *, n_tokens: int, device_name: str):
                return n_tokens, device_name

        sm = torch.jit.script(M())

        with self.assertRaisesRegex(
            RuntimeError, "missing value for argument 'n_tokens'"
        ):
            sm()

        with self.assertRaisesRegex(RuntimeError, "positional arg"):
            sm(3, "hello")

        self.assertEqual(sm(n_tokens=3, device_name="hello"), (3, "hello"))

    def test_tuple_subscripted_assign(self):
        """Plain and augmented subscript assignment on a tuple fail to script."""
        with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):

            @torch.jit.script
            def foo(a: Tuple[int, int]) -> None:
                a[0] = a[1]

        with self.assertRaisesRegex(RuntimeError, "augmented assignment"):

            @torch.jit.script
            def bar(a: Tuple[int, int]) -> None:
                a[0] += a[1]

    def test_subexpression_List_Future(self):
        """`List[Future[int]]` annotations parse and appear in the graph."""

        @torch.jit.script
        def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
            return x[0]

        FileCheck().check("Future[int]").check("Future[int]").run(fn.graph)

    def test_subexpression_Future_annotate(self):
        """An annotated local `List[Future[int]]` shows up as `Future[int][]`."""

        @torch.jit.script
        def fn() -> torch.jit.Future[int]:
            x: List[torch.jit.Future[int]] = []
            return x[0]

        FileCheck().check("Future[int][]").run(fn.graph)

    def test_future_isinstance(self):
        """`isinstance(x, jit.Future[int])` is accepted inside a scripted function."""

        @torch.jit.script
        def fn(x: Any) -> torch.jit.Future[int]:
            assert isinstance(x, jit.Future[int])
            return x

        FileCheck().check("Future[int]").run(fn.graph)

    def test_str_refine_any(self):
        """An `Any` argument is refined to `str` by an `isinstance` check."""

        def forward(x: Any) -> str:
            if isinstance(x, str):
                return x
            return "foo"

        forward = torch.jit.script(forward)
        self.assertEqual(forward(1), "foo")
        self.assertEqual(forward("bar"), "bar")

    def test_subexpression_Tuple_int_int_Future(self):
        """`Future[int]` nested inside tuple annotations appears in the graph."""

        @torch.jit.script
        def fn(
            x: Tuple[int, int, torch.jit.Future[int]]
        ) -> Tuple[int, torch.jit.Future[int]]:
            return x[0], x[2]

        FileCheck().check("(int, int, Future[int])").check("(int, Future[int])").run(
            fn.graph
        )

    def test_subexpression_Dict_int_Future(self):
        """`Future[int]` used as a dict value type appears in the graph."""

        @torch.jit.script
        def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
            return x[y]

        FileCheck().check("Dict(int, Future(int))").check("Future[int]").run(fn.graph)

    def test_subexpression_Optional(self):
        """`Optional[Dict[int, Future[int]]]` appears in the graph with a `?` suffix."""

        @torch.jit.script
        def fn(
            x: Optional[Dict[int, torch.jit.Future[int]]]
        ) -> Optional[torch.jit.Future[int]]:
            if x is not None:
                return x[0]
            else:
                return None

        FileCheck().check("Dict(int, Future(int))?").run(fn.graph)

    def test_if_returning_any(self):
        """
        Check that an if statement can return different
        types early from each branch when the return
        type of the function is Any.
        """

        def if_function(inp: torch.Tensor) -> Any:
            if inp.shape[0] == 1:
                return inp * inp
            else:
                return "str"

        self.checkScript(if_function, (torch.randn(5),))

    def test_hacked_twin(self):
        """`index_put` hacked_twin overloads agree with `torch.index_put(_)`."""

        def gen_data():
            # The frozen RNG state makes successive calls return equal tensors.
            with freeze_rng_state():
                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

        input, index, value = gen_data()
        input1, index1, value1 = gen_data()

        # Out-of-place variants.
        out1 = torch.ops.aten.index_put.hacked_twin(
            input, [index], value, accumulate=False
        )
        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        # In-place variants: compare the mutated inputs.
        torch.ops.aten.index_put_.hacked_twin(input, [index], value, accumulate=False)
        torch.index_put_(input1, [index1], value1, accumulate=False)
        self.assertEqual(input, input1)

    def test_unsafe_hacked_twin(self):
        """`_unsafe_index_put` agrees with `torch.index_put`, eagerly and scripted."""

        def gen_data():
            # The frozen RNG state makes successive calls return equal tensors.
            with freeze_rng_state():
                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

        input, index, value = gen_data()
        input1, index1, value1 = gen_data()

        out1 = torch.ops.aten._unsafe_index_put.hacked_twin(
            input, [index], value, accumulate=False
        )
        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        # NOTE(review): both results below are discarded and both calls are
        # out-of-place, so the assertion only compares the untouched inputs.
        # This effectively checks that the ops run; confirm whether a result
        # comparison was intended.
        torch.ops.aten._unsafe_index.Tensor_hacked_twin(input, [index])
        torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(input, input1)

        def index_put_fn(input, index, value):
            return torch.ops.aten._unsafe_index_put(
                input, [index], value, accumulate=False
            )

        input2, index2, value2 = gen_data()
        script_index_put_fn = torch.jit.script(index_put_fn)
        expect = index_put_fn(input2.clone(), index2, value2)
        actual = script_index_put_fn(input2.clone(), index2, value2)
        self.assertEqual(expect, actual)

        # NOTE(review): despite its name this calls `_unsafe_index_put`, exactly
        # like `index_put_fn` above; presumably `_unsafe_index` was intended.
        # TODO confirm.
        def index_fn(input, index, value):
            return torch.ops.aten._unsafe_index_put(
                input, [index], value, accumulate=False
            )

        script_index_fn = torch.jit.script(index_fn)
        expect = index_fn(input2.clone(), index2, value2)
        actual = script_index_fn(input2.clone(), index2, value2)
        self.assertEqual(expect, actual)

    def test_export_opnames_interface(self):
        """`export_opnames` reports ops of the module behind an interface attribute."""

        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                pass

            def two(self, x: torch.Tensor) -> torch.Tensor:
                pass

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                pass

        class FooMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.one(self.two(x), x)

        class BarMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 / x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.two(self.one(x, x))

        make_global(OneTwoModule)

        class M(nn.Module):
            sub: OneTwoModule

            def __init__(self) -> None:
                super().__init__()
                self.sub = BarMod()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.sub.forward(x)

        # NOTE(review): defined but never referenced later in this test.
        def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
            return mod_list[0].forward(x) + mod_list[1].forward(x)

        torch._C._enable_mobile_interface_call_export()
        scripted_M_mod = torch.jit.script(M())
        self.assertTrue(
            {"aten::mul.Scalar", "aten::mul.Tensor", "aten::reciprocal"}.issubset(
                set(torch.jit.export_opnames(scripted_M_mod))
            )
        )

        # Swap the submodule behind the interface and check the reported ops again.
        scripted_M_mod.sub = torch.jit.script(FooMod())
        self.assertTrue(
            {"aten::add.Tensor", "aten::mul.Scalar"}.issubset(
                set(torch.jit.export_opnames(scripted_M_mod))
            )
        )

    def test_math_inf(self):
        """A function returning `math.inf` can be scripted."""
        from math import inf

        def foo():
            return inf

        self.checkScript(foo, ())

    def test_list_literal_infer(self):
        """An empty list literal takes its element type from the call site."""

        def expects_intlist(x: List[int]):
            x.append(3)
            return x

        def foo():
            return expects_intlist([])

        self.checkScript(foo, ())

        def annotated_list_fail():
            return expects_intlist(torch.jit.annotate([], List[Tensor]))  # noqa: F821

        with self.assertRaises(RuntimeError):
            torch.jit.script(annotated_list_fail)

        # An empty list bound to a name first is not accepted as List[int].
        def non_temporary_fail():
            a = []
            return expects_intlist(a)

        with self.assertRaises(RuntimeError):
            torch.jit.script(non_temporary_fail)

        @torch.jit.script
        def test_return():
            return []

        FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)

    def test_legacy_tensor_constructor(self):
        """Legacy `torch.*Tensor` constructors script for the supported overloads."""

        # testing PyObject overload
        def test_all_dtypes():
            return (
                torch.BoolTensor([2]),
                torch.LongTensor([3]),
                torch.ByteTensor([4]),
                torch.CharTensor([5]),
                torch.DoubleTensor([6]),
                torch.FloatTensor([7]),
                torch.IntTensor([8]),
                torch.ShortTensor([1]),
                torch.HalfTensor([1]),
            )

        self.checkScript(test_all_dtypes, ())

        # now test empty overload
        def empty_overload():
            return torch.LongTensor(2, 3, 4)

        eager_out = empty_overload()
        scripted_out = torch.jit.script(empty_overload)()
        # Presumably the sized constructor leaves its contents unspecified, so
        # both results are filled before being compared.
        eager_out[:] = 1
        scripted_out[:] = 1
        self.assertEqual(eager_out, scripted_out)

        def no_inputs():
            return torch.DoubleTensor()

        self.checkScript(no_inputs, ())

        # bad schema
        def multiple_args():
            return torch.LongTensor(1, [2])

        with self.assertRaisesRegex(
            RuntimeError, "multiple positional arguments that were not all integers"
        ):
            torch.jit.script(multiple_args)

        # kwarg bad schema
        def bad_kwarg():
            return torch.LongTensor(hello="1")

        with self.assertRaisesRegex(RuntimeError, "hello"):
            torch.jit.script(bad_kwarg)

    def test_broadcasting_list(self):
        """
        Test BroadcastingList and torch.nn._size_N_t alias
        """
        from torch._jit_internal import BroadcastingList2
        from torch.nn.common_types import _size_2_t

        def sum_i(x: _size_2_t) -> int:
            return x[0] + x[1]

        def sum_f(x: BroadcastingList2[float]) -> float:
            return x[0] + x[1]

        # A scalar argument is broadcast to a two-element list before summing.
        self.assertEqual(torch.jit.script(sum_i)(4), 8)
        self.assertEqual(torch.jit.script(sum_f)(4.5), 9.0)

    def test_parse_ir_annotate(self):
        """`parse_ir` accepts an `annotate(List[int], [])` constant."""
        ir = """
        graph():
          %3 : int[] = prim::Constant[value=annotate(List[int], [])]()
          return (%3)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret, [])

    def test_parse_ir_single_element_tensor_positive(self):
        """`parse_ir` builds a 1-D, one-element tensor from `value={0}`."""
        ir = """
        graph():
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={0}]()
          return (%7)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret.numel(), 1)
        self.assertEqual(len(ret.size()), 1)

    def test_parse_ir_single_element_tensor_negative(self):
        """`parse_ir` builds a 1-D, one-element tensor from `value={-17}`."""
        ir = """
        graph():
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={-17}]()
          return (%7)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret.numel(), 1)
        self.assertEqual(len(ret.size()), 1)

    def test_script_many_decorators(self):
        """A function under several stacked decorators can still be scripted."""

        def no_op_decorator(f):
            return f

        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        def foo(x, dim: int):
            return x.unsqueeze(dim)

        x = torch.randn(1)
        expected = foo(x, 0)
        scripted = torch.jit.script(foo)
        actual = scripted(x, 0)
        torch.testing.assert_close(expected, actual)

    @unittest.skipIf(not RUN_CUDA_HALF, "need CUDA half support")
    def test_pow_multiple_dtype(self):
        """`half_tensor ** float` matches eager after repeated scripted calls."""

        # https://github.com/pytorch/pytorch/issues/75476
        def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
            p = torch.sigmoid(p)
            result = p**gamma
            return result

        x = torch.rand((2, 2), dtype=torch.half, device="cuda")

        ref = fn(x)

        script_fn = torch.jit.script(fn)
        # Call several times and keep the last result; presumably this gets
        # past the executor's warm-up runs. TODO confirm.
        for _ in range(4):
            res = script_fn(x)

        self.assertEqual(ref, res)

    def test_jit_get_operation_order(self):
        """The `Scalar` overload of `aten::add` is listed before `complex`."""
        # See https://github.com/pytorch/pytorch/pull/107138.
        # Depending on order of operator registration, you can get different
        # order of overloads in the JIT operator registry.
        # This is to verify that the order of operators returned by
        # _jit_get_operation always puts aten ops first (i.e. by sorting
        # to put them first)

        # Make sure that this chooses a "scalar" overload not a "complex" overload
        ret = torch.ops.aten.add(4, 3.3)
        self.assertNotIn("complex", str(ret.dtype))

        # "Scalar" overload is a normal aten op; "complex" is added by torchscript.
        # We want "Scalar" to come before "complex".
        op, override_names = torch._C._jit_get_operation("aten::add")
        complex_indices = [
            i for i, name in enumerate(override_names) if name == "complex"
        ]
        Scalar_indices = [
            i for i, name in enumerate(override_names) if name == "Scalar"
        ]

        self.assertGreater(len(complex_indices), 0)
        self.assertGreater(len(Scalar_indices), 0)
        self.assertGreater(complex_indices[0], Scalar_indices[0])