1# Owner(s): ["oncall: jit"] 2 3import os 4import random 5import sys 6import tempfile 7from textwrap import dedent 8 9import torch 10from torch.testing._internal.jit_utils import execWrapper, JitTestCase 11 12 13# Make the helper files in test/ importable 14pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 15sys.path.append(pytorch_test_dir) 16 17if __name__ == "__main__": 18 raise RuntimeError( 19 "This test file is not meant to be run directly, use:\n\n" 20 "\tpython test/test_jit.py TESTNAME\n\n" 21 "instead." 22 ) 23 24 25def get_fn(file_name, script_path): 26 import importlib.util 27 28 spec = importlib.util.spec_from_file_location(file_name, script_path) 29 module = importlib.util.module_from_spec(spec) 30 spec.loader.exec_module(module) 31 fn = module.fn 32 return fn 33 34 35class TestPythonBuiltinOP(JitTestCase): 36 def test_add(self): 37 def func(a, b): 38 c = a + b 39 c += a 40 return c 41 42 a = torch.rand(1, requires_grad=True) 43 b = torch.rand(1, requires_grad=True) 44 self.checkScript(func, (a, b), optimize=True) 45 46 def test_mul(self): 47 def func(a, b): 48 return a * b 49 50 a = torch.rand(1, requires_grad=True) 51 b = torch.rand(1, requires_grad=True) 52 self.checkScript(func, (a, b), optimize=True) 53 54 def test_matmul_py3(self): 55 code = dedent( 56 """ 57 def fn(a, b): 58 return a @ b 59 """ 60 ) 61 62 with tempfile.TemporaryDirectory() as tmp_dir: 63 script_path = os.path.join(tmp_dir, "script.py") 64 with open(script_path, "w") as f: 65 f.write(code) 66 fn = get_fn("test_matmul_py3", script_path) 67 68 a = torch.rand(4, 3, requires_grad=True) 69 b = torch.rand(3, 2, requires_grad=True) 70 self.checkScript(fn, (a, b), optimize=True) 71 72 def test_pow(self): 73 def func(a, b): 74 return a**b 75 76 def func2(a, b, c, d): 77 return c + a**b**d 78 79 def func3(a, b): 80 # type: (int, float) -> float 81 return a**b 82 83 def func4(): 84 # type: () -> float 85 return 2**-2 86 87 def func5(x, y): 88 return x.item() ** y.item() 89 90 a = torch.rand(1, requires_grad=True) 91 b = torch.rand(1, requires_grad=True) 92 c = torch.rand(1, requires_grad=True) 93 d = torch.rand(1, requires_grad=True) 94 self.checkScript(func, (a, b), optimize=True) 95 self.checkScript(func2, (a, b, c, d), optimize=True) 96 self.checkScript(func3, (4, -0.5), optimize=True) 97 self.checkScript(func4, ()) 98 99 inputs = [ 100 torch.tensor(2), 101 torch.tensor(-2), 102 torch.tensor(0.5), 103 torch.tensor(0.2), 104 ] 105 for x in inputs: 106 for y in inputs: 107 if x < 0: 108 continue 109 else: 110 self.checkScript(func5, (x, y)) 111 112 def test_triple(self): 113 def func(x): 114 return 3.0 * x 115 116 x = torch.rand(1, dtype=torch.float, requires_grad=True) 117 self.checkScript(func, [x], optimize=True) 118 119 def test_slice(self): 120 def func(x): 121 return x[:5] 122 123 x = torch.rand(10, dtype=torch.float, requires_grad=True) 124 self.checkScript(func, [x], optimize=True) 125 126 def func2(x): 127 return x[5:] 128 129 self.checkScript(func2, [x], optimize=True) 130 131 def func3(x): 132 return x[:8:2] 133 134 self.checkScript(func3, [x], optimize=True) 135 136 def func4(x): 137 return x[1::4] 138 139 self.checkScript(func4, [x], optimize=True) 140 141 def test_gather(self): 142 def func(x): 143 return x[0] 144 145 x = torch.rand(10, dtype=torch.float, requires_grad=True) 146 self.checkScript(func, [x], optimize=True) 147 148 def test_random(self): 149 @torch.jit.script 150 def f(mean, std): 151 return torch.normal(mean, std) 152 153 mean, std = torch.zeros(5, 5), torch.ones(5, 5) 154 with torch.random.fork_rng(devices=[]): 155 output = torch.normal(mean, std) 156 with torch.random.fork_rng(devices=[]): 157 script_output = f(mean, std) 158 self.assertEqual(output, script_output) 159 160 def _check_code(self, code_str, fn_name, inputs): 161 scope = {} 162 exec(code_str, globals(), scope) 163 cu = torch.jit.CompilationUnit(code_str) 164 self.assertEqual(cu.func(*inputs), scope[fn_name](*inputs)) 165 166 def test_stepped_tuple_slicing(self): 167 def check_slicing_tuple(slicing, tuple_type, tuple): 168 template = dedent( 169 """ 170 def func(x): 171 # type: ({}) -> Any 172 return x{} 173 """ 174 ) 175 self._check_code(template.format(tuple_type, slicing), "func", [tuple]) 176 177 check_slicing_tuple("[-3:3:2]", "Tuple[int, int, int]", (0, 1, 2)) 178 check_slicing_tuple("[::55]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) 179 check_slicing_tuple("[:4:4]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) 180 check_slicing_tuple( 181 "[::-1]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6) 182 ) 183 check_slicing_tuple( 184 "[7:5:2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6) 185 ) 186 check_slicing_tuple( 187 "[5:7:-2]", 188 "Tuple[int, int, int, int, int, int, int]", 189 (0, 1, 2, 3, 4, 5, 6), 190 ) 191 check_slicing_tuple("[::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) 192 check_slicing_tuple( 193 "[:4:-3]", "Tuple[int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5) 194 ) 195 check_slicing_tuple( 196 "[3::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4) 197 ) 198 199 def test_index(self): 200 def consec(size, start=0): 201 numel = torch.tensor(size).prod().item() 202 return torch.arange(numel).view(size) 203 204 def check_indexing(indexing, tensor): 205 template = dedent( 206 """ 207 def func(x): 208 return x{} 209 """ 210 ) 211 212 self._check_code(template.format(indexing), "func", [tensor]) 213 214 def check_dynamic_indexing(indexing, tensor, value1, value2): 215 value1 = torch.tensor(value1) 216 value2 = torch.tensor(value2) 217 218 template = dedent( 219 """ 220 def func(x, value1, value2): 221 i = int(value1) 222 j = int(value2) 223 return x{} 224 """ 225 ) 226 227 self._check_code( 228 template.format(indexing), "func", [tensor, value1, value2] 229 ) 230 231 # basic slices 232 check_indexing("[0]", consec((3, 3))) 233 check_indexing("[1]", consec((3, 3), 10)) 234 check_indexing("[2]", consec((3, 3), 19)) 235 check_indexing("[2]", consec((3,))) 236 check_indexing("[-1]", consec((3, 3), 19)) 237 check_indexing("[0:2]", consec((3, 3, 3))) 238 check_indexing("[1:-1]", consec((3, 3, 3))) 239 check_indexing("[-3:-1]", consec((6, 3))) 240 check_indexing("[1:]", consec((3, 3))) 241 check_indexing("[:1]", consec((3, 3))) 242 check_indexing("[:]", consec((3, 2))) 243 244 # multi-dim: indexes 245 check_indexing("[0, 1]", consec((3, 3))) 246 check_indexing("[0, 1]", consec((3, 3, 2))) 247 check_indexing("[1, 0, 2]", consec((3, 3, 3))) 248 check_indexing("[2, -1]", consec((3, 3))) 249 250 # multi-dim: mixed slicing and indexing 251 check_indexing("[0, 1:2]", consec((3, 3))) 252 check_indexing("[0, :1]", consec((3, 3, 2))) 253 check_indexing("[1, 2:]", consec((3, 3, 3))) 254 check_indexing("[-1, 1:, 0]", consec((3, 3, 3, 3))) 255 check_indexing("[1:, -1, 0]", consec((3, 3, 3, 3))) 256 check_indexing("[-1, 2:, 1:2]", consec((3, 3, 3, 3))) 257 check_indexing("[-1, 1:, 0]", consec((3, 3, 3, 3))) 258 check_indexing("[-1, :, 0, 2]", consec((3, 3, 3, 3))) 259 260 # zero-sized slices 261 check_indexing("[0:0]", consec((2, 2))) 262 check_indexing("[0:0, 1]", consec((3, 3))) 263 264 # trivial expression usage 265 check_indexing("[1+1]", consec((3, 3))) 266 check_indexing("[1:(0 + 2)]", consec((3, 3, 3))) 267 268 # None for new dimensions 269 check_indexing("[None, 0]", consec((3, 3))) 270 check_indexing("[1, None]", consec((3, 3), 10)) 271 check_indexing("[None, None, 2]", consec((3, 3), 19)) 272 check_indexing("[None, 2, None]", consec((3,))) 273 check_indexing("[0:2, None]", consec((3, 3, 3))) 274 check_indexing("[None, 1:-1]", consec((3, 3, 3))) 275 check_indexing("[None, -3:-1, None]", consec((6, 3))) 276 check_indexing("[-1, None, 2:, None, 1:2]", consec((3, 3, 3, 3))) 277 check_indexing("[None, -1, None, 2:, None, 1:2, None]", consec((3, 3, 3, 3))) 278 279 # dynamic expression usage 280 check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1) 281 check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2) 282 283 def test_advancedindex(self): 284 def consec(size, start=0): 285 numel = torch.tensor(size).prod().item() 286 return torch.arange(numel).view(size) 287 288 def check_indexing(indexing, tensor, **kwargs): 289 indices_dict = kwargs 290 291 template = dedent( 292 """ 293 def func(x{formals}): 294 return x{expr} 295 """ 296 ) 297 298 formals = [] 299 values = [] 300 for formal, value in indices_dict.items(): 301 formals.append(formal) 302 values.append(value) 303 304 formals = "".join(map(", {}".format, formals)) 305 inputs = [tensor] + values 306 self._check_code( 307 template.format(formals=formals, expr=indexing), "func", inputs 308 ) 309 310 # Indexing with tensor (basic) 311 check_indexing("[i]", consec((3, 3)), i=torch.tensor([0])) 312 check_indexing("[i]", consec((3, 3)), i=torch.tensor(1)) 313 check_indexing("[i]", consec((3, 3)), i=torch.tensor([-2])) 314 check_indexing("[i]", consec((3, 3), 2), i=torch.tensor([0, 0])) 315 check_indexing("[i]", consec((3, 3, 2, 2)), i=torch.tensor([0, -2, 1])) 316 317 # NB: indexing with tensors and indexing with sequences can be implemented 318 # in a very similar way (sequences are converted to tensors), so only one 319 # case needs to be tested extensively. 320 # XXX: When we can index with sequences, replace these cases with 321 # sequence indexing expressions; those are much easier to read. 322 323 # Misc sequence advanced indexing 324 inp = consec((4, 8, 5)) 325 to_check = [ 326 # [[0, 1, 3]] 327 ["[i]", {"i": [0, 1, 3]}], 328 # [[0, 2], [1, 3]] 329 ["[i, j]", {"i": [0, 2], "j": [1, 3]}], 330 # [[[0, 1], [0, 1]], [[0, 1], [0, 1]]] 331 ["[i, j]", {"i": [[0, 1], [0, 1]], "j": [[0, 1], [0, 1]]}], 332 # [[0, 2], [1, 3], [1, 1]] 333 ["[i, j, k]", {"i": [0, 2], "j": [1, 3], "k": [1, 1]}], 334 # [[0, 2], 1, [1, 1]] 335 ["[i, j, k]", {"i": [0, 2], "j": 1, "k": [1, 1]}], 336 # [:, :, [0, 3, 4]] 337 ["[:, :, i]", {"i": [0, 3, 4]}], 338 # [:, [2, 4, 5, 7], 2:4] 339 ["[:, i, 2:4]", {"i": [0, 2, 3]}], 340 # [[2, 3], :, :] 341 ["[i, :, :]", {"i": [2, 3]}], 342 # [:, [0, 2, 3], [1, 3, 4]] 343 ["[:, i, j]", {"i": [0, 2, 3], "j": [1, 3, 4]}], 344 # [:, [0], [1, 2, 4]] 345 ["[:, i, j]", {"i": [0], "j": [1, 2, 4]}], 346 # [:, [0, 1, 3], [4]] 347 ["[:, i, j]", {"i": [0, 1, 3], "j": [4]}], 348 # [:, [[0, 1], [1, 0]], [[2, 3]]] 349 ["[:, i, j]", {"i": [[0, 1], [1, 0]], "j": [[2, 3]]}], 350 # [:, [[0, 1], [2, 3]], [[0]]] 351 ["[:, i, j]", {"i": [[0, 1], [2, 3]], "j": [[0]]}], 352 # [:, [[5, 6]], [[0, 3], [4, 4]]] 353 ["[:, i, j]", {"i": [[5, 6]], "j": [[0, 3], [4, 4]]}], 354 # [[0, 2, 3], [1, 3, 4], :] 355 ["[i, j, :]", {"i": [0, 2, 3], "j": [1, 3, 4]}], 356 # [0, [1, 2, 4], :] 357 ["[i, j, :]", {"i": 0, "j": [1, 2, 4]}], 358 # [[0, 1, 3], 4, :] 359 ["[i, j, :]", {"i": [0, 1, 3], "j": 4}], 360 # [[[0, 1], [1, 0]], [[2, 1], [3, 5]], :] 361 ["[i, j, :]", {"i": [[0, 1], [1, 0]], "j": [[2, 1], [3, 5]]}], 362 # [[[0, 1], [1, 0]], [[2, 3]], :] 363 ["[i, j, :]", {"i": [[0, 1], [1, 0]], "j": [[2, 3]]}], 364 # [[[0, 1], [2, 3]], [[0]], :] 365 ["[i, j, :]", {"i": [[0, 1], [2, 3]], "j": [[0]]}], 366 # [[[2, 1]], [[0, 3], [4, 4]], :] 367 ["[i, j, :]", {"i": [[2, 1]], "j": [[0, 3], [4, 4]]}], 368 # [[[2]], [[0, 3], [4, 1]], 0:2] 369 ["[i, j, 0:2]", {"i": [[2]], "j": [[0, 3], [4, 1]]}], 370 ] 371 372 for expr, argdict in to_check: 373 tensordict = {k: torch.tensor(v) for (k, v) in argdict.items()} 374 check_indexing(expr, inp, **tensordict) 375 376 def test_adv_indexing_list(self): 377 # indexing with list is equivalent to indexing with tensor 378 def func1(x): 379 return x[[0, 1, 5]] 380 381 def func2(x): 382 return x[[0, 1], [0, 1]] 383 384 def func3(x): 385 return x[[[0, 1], [0, 1]], [[0, 1], [0, 1]]] 386 387 def func4(x): 388 ls = [0] 389 ls.append(1) 390 ls.append(2) 391 return x[ls] 392 393 def func5(x): 394 ls = [0.1, 1.2, 2.3] 395 return x[ls] 396 397 input = torch.rand((6, 2)) 398 self.checkScript(func1, (input,)) 399 self.checkScript(func2, (input,)) 400 self.checkScript(func3, (input,)) 401 self.checkScript(func4, (input,)) 402 self.checkScript(func5, (input,)) 403 404 def test_index_ellipses(self): 405 vals = [":", 1, None] 406 for _ in range(100): 407 indices = [random.choice(vals) for _ in range(4)] 408 indices[random.randint(0, len(indices) - 1)] = "..." 409 test_str = dedent( 410 """ 411 def f(): 412 x = torch.ones(10, 9, 8, 7, 6) 413 return x{indices}.shape 414 """.format( 415 indices=indices 416 ) 417 ) 418 test_str = test_str.replace(r"'", r"") 419 scope = {} 420 execWrapper(test_str, globals(), scope) 421 cu = torch.jit.CompilationUnit(test_str) 422 res1 = cu.f() 423 res2 = scope["f"]() 424 self.assertEqual(res1, res2) 425 426 def test_inf(self): 427 @torch.jit.script 428 def foo(a): 429 return a < float("inf") 430 431 s = torch.rand(1) 432 self.assertTrue(foo(s)) 433 434 @torch.jit.script 435 def bar(a): 436 return a > float("-inf") 437 438 s = torch.rand(1) 439 self.assertTrue(foo(s)) 440 441 # test re-assignment on imported source 442 str = """ 443 def foo(x): 444 # type: (bool) 445 a = float("-inf") 446 if not x: 447 a = float(torch.tensor([5])) 448 return a < 4 449 """ 450 cu = torch.jit.CompilationUnit(str) 451 self.assertTrue(cu.foo(True)) 452 self.assertFalse(cu.foo(False)) 453 454 def test_str_to_float(self): 455 @torch.jit.script 456 def foo(a): 457 return 0.5 == float("0.5 hello") 458 459 s = torch.rand(1) 460 with self.assertRaisesRegex(RuntimeError, "could not convert string to float"): 461 self.assertTrue(foo(s)) 462 463 @torch.jit.script 464 def foo(a): 465 return 0.5 == float("0.5") 466 467 s = torch.rand(1) 468 self.assertTrue(foo(s)) 469 470 @torch.jit.script 471 def foo(a): 472 return 0.0 == float("0") 473 474 s = torch.rand(1) 475 self.assertTrue(foo(s)) 476