xref: /aosp_15_r20/external/pytorch/test/jit/test_exception.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2import torch
3from torch import nn
4from torch.testing._internal.common_utils import TestCase
5
6
7r"""
8Test TorchScript exception handling.
9"""
10
11
12class TestException(TestCase):
13    def test_pyop_exception_message(self):
14        class Foo(torch.jit.ScriptModule):
15            def __init__(self) -> None:
16                super().__init__()
17                self.conv = nn.Conv2d(1, 10, kernel_size=5)
18
19            @torch.jit.script_method
20            def forward(self, x):
21                return self.conv(x)
22
23        foo = Foo()
24        # testing that the correct error message propagates
25        with self.assertRaisesRegex(
26            RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"
27        ):
28            foo(torch.ones([123]))  # wrong size
29
30    def test_builtin_error_messsage(self):
31        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
32
33            @torch.jit.script
34            def close_match(x):
35                return x.masked_fill(True)
36
37        with self.assertRaisesRegex(
38            RuntimeError,
39            "This op may not exist or may not be currently " "supported in TorchScript",
40        ):
41
42            @torch.jit.script
43            def unknown_op(x):
44                torch.set_anomaly_enabled(True)
45                return x
46
47    def test_exceptions(self):
48        cu = torch.jit.CompilationUnit(
49            """
50            def foo(cond):
51                if bool(cond):
52                    raise ValueError(3)
53                return 1
54        """
55        )
56
57        cu.foo(torch.tensor(0))
58        with self.assertRaisesRegex(torch.jit.Error, "3"):
59            cu.foo(torch.tensor(1))
60
61        def foo(cond):
62            a = 3
63            if bool(cond):
64                raise ArbitraryError(a, "hi")  # noqa: F821
65                if 1 == 2:
66                    raise ArbitraryError  # noqa: F821
67            return a
68
69        with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
70            torch.jit.script(foo)
71
72        def exception_as_value():
73            a = Exception()
74            print(a)
75
76        with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
77            torch.jit.script(exception_as_value)
78
79        @torch.jit.script
80        def foo_no_decl_always_throws():
81            raise RuntimeError("Hi")
82
83        # function that has no declared type but always throws set to None
84        output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
85        self.assertTrue(str(output_type) == "NoneType")
86
87        @torch.jit.script
88        def foo_decl_always_throws():
89            # type: () -> Tensor
90            raise Exception("Hi")  # noqa: TRY002
91
92        output_type = next(foo_decl_always_throws.graph.outputs()).type()
93        self.assertTrue(str(output_type) == "Tensor")
94
95        def foo():
96            raise 3 + 4
97
98        with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
99            torch.jit.script(foo)
100
101        # a escapes scope
102        @torch.jit.script
103        def foo():
104            if 1 == 1:
105                a = 1
106            else:
107                if 1 == 1:
108                    raise Exception("Hi")  # noqa: TRY002
109                else:
110                    raise Exception("Hi")  # noqa: TRY002
111            return a
112
113        self.assertEqual(foo(), 1)
114
115        @torch.jit.script
116        def tuple_fn():
117            raise RuntimeError("hello", "goodbye")
118
119        with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
120            tuple_fn()
121
122        @torch.jit.script
123        def no_message():
124            raise RuntimeError
125
126        with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
127            no_message()
128
129    def test_assertions(self):
130        cu = torch.jit.CompilationUnit(
131            """
132            def foo(cond):
133                assert bool(cond), "hi"
134                return 0
135        """
136        )
137
138        cu.foo(torch.tensor(1))
139        with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
140            cu.foo(torch.tensor(0))
141
142        @torch.jit.script
143        def foo(cond):
144            assert bool(cond), "hi"
145
146        foo(torch.tensor(1))
147        # we don't currently validate the name of the exception
148        with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
149            foo(torch.tensor(0))
150
151    def test_python_op_exception(self):
152        @torch.jit.ignore
153        def python_op(x):
154            raise Exception("bad!")  # noqa: TRY002
155
156        @torch.jit.script
157        def fn(x):
158            return python_op(x)
159
160        with self.assertRaisesRegex(
161            RuntimeError, "operation failed in the TorchScript interpreter"
162        ):
163            fn(torch.tensor(4))
164
165    def test_dict_expansion_raises_error(self):
166        def fn(self):
167            d = {"foo": 1, "bar": 2, "baz": 3}
168            return {**d}
169
170        with self.assertRaisesRegex(
171            torch.jit.frontend.NotSupportedError, "Dict expansion "
172        ):
173            torch.jit.script(fn)
174
175    def test_custom_python_exception(self):
176        class MyValueError(ValueError):
177            pass
178
179        @torch.jit.script
180        def fn():
181            raise MyValueError("test custom exception")
182
183        with self.assertRaisesRegex(
184            torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"
185        ):
186            fn()
187
188    def test_custom_python_exception_defined_elsewhere(self):
189        from jit.myexception import MyKeyError
190
191        @torch.jit.script
192        def fn():
193            raise MyKeyError("This is a user defined key error")
194
195        with self.assertRaisesRegex(
196            torch.jit.Error,
197            "jit.myexception.MyKeyError: This is a user defined key error",
198        ):
199            fn()
200