xref: /aosp_15_r20/external/pytorch/test/jit/test_tensor_methods.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5
6import torch
7
8
9# Make the helper files in test/ importable
10pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
11sys.path.append(pytorch_test_dir)
12from torch.testing import FileCheck
13from torch.testing._internal.jit_utils import JitTestCase
14
15
if __name__ == "__main__":
    # This module only defines test cases. Executing it directly is refused;
    # the error text below points the user at the test/test_jit.py entry point.
    _usage_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_message)
22
23
24class TestTensorMethods(JitTestCase):
25    def test_getitem(self):
26        def tensor_getitem(inp: torch.Tensor):
27            indices = torch.tensor([0, 2], dtype=torch.long)
28            return inp.__getitem__(indices)
29
30        inp = torch.rand(3, 4)
31        self.checkScript(tensor_getitem, (inp,))
32
33        scripted = torch.jit.script(tensor_getitem)
34        FileCheck().check("aten::index").run(scripted.graph)
35
36    def test_getitem_invalid(self):
37        def tensor_getitem_invalid(inp: torch.Tensor):
38            return inp.__getitem__()
39
40        with self.assertRaisesRegexWithHighlight(
41            RuntimeError, "expected exactly 1 argument", "inp.__getitem__"
42        ):
43            torch.jit.script(tensor_getitem_invalid)
44