1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5 6import torch 7 8 9# Make the helper files in test/ importable 10pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 11sys.path.append(pytorch_test_dir) 12from torch.testing import FileCheck 13from torch.testing._internal.jit_utils import JitTestCase 14 15 16if __name__ == "__main__": 17 raise RuntimeError( 18 "This test file is not meant to be run directly, use:\n\n" 19 "\tpython test/test_jit.py TESTNAME\n\n" 20 "instead." 21 ) 22 23 24class TestTensorMethods(JitTestCase): 25 def test_getitem(self): 26 def tensor_getitem(inp: torch.Tensor): 27 indices = torch.tensor([0, 2], dtype=torch.long) 28 return inp.__getitem__(indices) 29 30 inp = torch.rand(3, 4) 31 self.checkScript(tensor_getitem, (inp,)) 32 33 scripted = torch.jit.script(tensor_getitem) 34 FileCheck().check("aten::index").run(scripted.graph) 35 36 def test_getitem_invalid(self): 37 def tensor_getitem_invalid(inp: torch.Tensor): 38 return inp.__getitem__() 39 40 with self.assertRaisesRegexWithHighlight( 41 RuntimeError, "expected exactly 1 argument", "inp.__getitem__" 42 ): 43 torch.jit.script(tensor_getitem_invalid) 44