1# Owner(s): ["oncall: jit"] 2import torch 3from torch import nn 4from torch.testing._internal.common_utils import TestCase 5 6 7r""" 8Test TorchScript exception handling. 9""" 10 11 12class TestException(TestCase): 13 def test_pyop_exception_message(self): 14 class Foo(torch.jit.ScriptModule): 15 def __init__(self) -> None: 16 super().__init__() 17 self.conv = nn.Conv2d(1, 10, kernel_size=5) 18 19 @torch.jit.script_method 20 def forward(self, x): 21 return self.conv(x) 22 23 foo = Foo() 24 # testing that the correct error message propagates 25 with self.assertRaisesRegex( 26 RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d" 27 ): 28 foo(torch.ones([123])) # wrong size 29 30 def test_builtin_error_messsage(self): 31 with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"): 32 33 @torch.jit.script 34 def close_match(x): 35 return x.masked_fill(True) 36 37 with self.assertRaisesRegex( 38 RuntimeError, 39 "This op may not exist or may not be currently " "supported in TorchScript", 40 ): 41 42 @torch.jit.script 43 def unknown_op(x): 44 torch.set_anomaly_enabled(True) 45 return x 46 47 def test_exceptions(self): 48 cu = torch.jit.CompilationUnit( 49 """ 50 def foo(cond): 51 if bool(cond): 52 raise ValueError(3) 53 return 1 54 """ 55 ) 56 57 cu.foo(torch.tensor(0)) 58 with self.assertRaisesRegex(torch.jit.Error, "3"): 59 cu.foo(torch.tensor(1)) 60 61 def foo(cond): 62 a = 3 63 if bool(cond): 64 raise ArbitraryError(a, "hi") # noqa: F821 65 if 1 == 2: 66 raise ArbitraryError # noqa: F821 67 return a 68 69 with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"): 70 torch.jit.script(foo) 71 72 def exception_as_value(): 73 a = Exception() 74 print(a) 75 76 with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"): 77 torch.jit.script(exception_as_value) 78 79 @torch.jit.script 80 def foo_no_decl_always_throws(): 81 raise RuntimeError("Hi") 82 83 # function that has no declared type but always throws set to None 84 output_type = next(foo_no_decl_always_throws.graph.outputs()).type() 85 self.assertTrue(str(output_type) == "NoneType") 86 87 @torch.jit.script 88 def foo_decl_always_throws(): 89 # type: () -> Tensor 90 raise Exception("Hi") # noqa: TRY002 91 92 output_type = next(foo_decl_always_throws.graph.outputs()).type() 93 self.assertTrue(str(output_type) == "Tensor") 94 95 def foo(): 96 raise 3 + 4 97 98 with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"): 99 torch.jit.script(foo) 100 101 # a escapes scope 102 @torch.jit.script 103 def foo(): 104 if 1 == 1: 105 a = 1 106 else: 107 if 1 == 1: 108 raise Exception("Hi") # noqa: TRY002 109 else: 110 raise Exception("Hi") # noqa: TRY002 111 return a 112 113 self.assertEqual(foo(), 1) 114 115 @torch.jit.script 116 def tuple_fn(): 117 raise RuntimeError("hello", "goodbye") 118 119 with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"): 120 tuple_fn() 121 122 @torch.jit.script 123 def no_message(): 124 raise RuntimeError 125 126 with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"): 127 no_message() 128 129 def test_assertions(self): 130 cu = torch.jit.CompilationUnit( 131 """ 132 def foo(cond): 133 assert bool(cond), "hi" 134 return 0 135 """ 136 ) 137 138 cu.foo(torch.tensor(1)) 139 with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"): 140 cu.foo(torch.tensor(0)) 141 142 @torch.jit.script 143 def foo(cond): 144 assert bool(cond), "hi" 145 146 foo(torch.tensor(1)) 147 # we don't currently validate the name of the exception 148 with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"): 149 foo(torch.tensor(0)) 150 151 def test_python_op_exception(self): 152 @torch.jit.ignore 153 def python_op(x): 154 raise Exception("bad!") # noqa: TRY002 155 156 @torch.jit.script 157 def fn(x): 158 return python_op(x) 159 160 with self.assertRaisesRegex( 161 RuntimeError, "operation failed in the TorchScript interpreter" 162 ): 163 fn(torch.tensor(4)) 164 165 def test_dict_expansion_raises_error(self): 166 def fn(self): 167 d = {"foo": 1, "bar": 2, "baz": 3} 168 return {**d} 169 170 with self.assertRaisesRegex( 171 torch.jit.frontend.NotSupportedError, "Dict expansion " 172 ): 173 torch.jit.script(fn) 174 175 def test_custom_python_exception(self): 176 class MyValueError(ValueError): 177 pass 178 179 @torch.jit.script 180 def fn(): 181 raise MyValueError("test custom exception") 182 183 with self.assertRaisesRegex( 184 torch.jit.Error, "jit.test_exception.MyValueError: test custom exception" 185 ): 186 fn() 187 188 def test_custom_python_exception_defined_elsewhere(self): 189 from jit.myexception import MyKeyError 190 191 @torch.jit.script 192 def fn(): 193 raise MyKeyError("This is a user defined key error") 194 195 with self.assertRaisesRegex( 196 torch.jit.Error, 197 "jit.myexception.MyKeyError: This is a user defined key error", 198 ): 199 fn() 200