1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.config 5*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 6*da0073e9SAndroid Build Coastguard Workerimport torch._functorch.config 7*da0073e9SAndroid Build Coastguard Workerimport torch.nn 8*da0073e9SAndroid Build Coastguard Workerimport torch.utils.checkpoint 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Workerclass ExceptionTests(torch._dynamo.test_case.TestCase): 12*da0073e9SAndroid Build Coastguard Worker def test_exception(self): 13*da0073e9SAndroid Build Coastguard Worker def fn(x): 14*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 15*da0073e9SAndroid Build Coastguard Worker try: 16*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 17*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 18*da0073e9SAndroid Build Coastguard Worker except Exception: 19*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker return x 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 24*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 25*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 26*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 27*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker def test_exception2(self): 30*da0073e9SAndroid Build Coastguard Worker def fn(x): 31*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 32*da0073e9SAndroid Build Coastguard Worker try: 33*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 34*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 35*da0073e9SAndroid Build Coastguard Worker except (NotImplementedError, AttributeError) as e: 36*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker return x 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 41*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 42*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 43*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 44*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker def test_exception3(self): 47*da0073e9SAndroid Build Coastguard Worker def fn(x): 48*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 49*da0073e9SAndroid Build Coastguard Worker try: 50*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 51*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError("Not implemented") 52*da0073e9SAndroid Build Coastguard Worker except AssertionError: 53*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 54*da0073e9SAndroid Build Coastguard Worker except NotImplementedError: 55*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 56*da0073e9SAndroid Build Coastguard Worker finally: 57*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker return x 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 62*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 63*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 64*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 65*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker def test_exception4(self): 68*da0073e9SAndroid Build Coastguard Worker def fn(x): 69*da0073e9SAndroid Build Coastguard Worker for i in range(10): 70*da0073e9SAndroid Build Coastguard Worker if i == 5: 71*da0073e9SAndroid Build Coastguard Worker return x 72*da0073e9SAndroid Build Coastguard Worker try: 73*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 74*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 75*da0073e9SAndroid Build Coastguard Worker except Exception: 76*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker return x 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 81*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 82*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 83*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 84*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker def test_exception_with_another_exception(self): 87*da0073e9SAndroid Build Coastguard Worker def fn(x): 88*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 89*da0073e9SAndroid Build Coastguard Worker try: 90*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 91*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError("Not implemented") 92*da0073e9SAndroid Build Coastguard Worker except NotImplementedError as e: 93*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 94*da0073e9SAndroid Build Coastguard Worker try: 95*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 96*da0073e9SAndroid Build Coastguard Worker raise AssertionError 97*da0073e9SAndroid Build Coastguard Worker except AssertionError: 98*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 101*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 102*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 103*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 104*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 105*da0073e9SAndroid Build Coastguard Worker 106*da0073e9SAndroid Build Coastguard Worker def test_exception_else(self): 107*da0073e9SAndroid Build Coastguard Worker def gn(x): 108*da0073e9SAndroid Build Coastguard Worker return torch.cos(x) 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker def fn(x): 111*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 112*da0073e9SAndroid Build Coastguard Worker try: 113*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 114*da0073e9SAndroid Build Coastguard Worker x = gn(x) 115*da0073e9SAndroid Build Coastguard Worker except Exception: 116*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 117*da0073e9SAndroid Build Coastguard Worker else: 118*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker return x 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 123*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 124*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 125*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 126*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker # TODO(anijain2305) - does not work with fullgraph=True 129*da0073e9SAndroid Build Coastguard Worker def test_exception_with_another_exception2(self): 130*da0073e9SAndroid Build Coastguard Worker def gn(x): 131*da0073e9SAndroid Build Coastguard Worker try: 132*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 133*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError("Not implemented") 134*da0073e9SAndroid Build Coastguard Worker except NotImplementedError as e: 135*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 136*da0073e9SAndroid Build Coastguard Worker raise 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker def fn(x): 139*da0073e9SAndroid Build Coastguard Worker try: 140*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 141*da0073e9SAndroid Build Coastguard Worker gn(x) 142*da0073e9SAndroid Build Coastguard Worker except Exception: 143*da0073e9SAndroid Build Coastguard Worker pass 144*da0073e9SAndroid Build Coastguard Worker return x 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 147*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 148*da0073e9SAndroid Build Coastguard Worker # Cant use fullgraph=True because RERAISE is not supported 149*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager") 150*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 151*da0073e9SAndroid Build Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker # TODO(anijain2305) - does not work with fullgraph=True 153*da0073e9SAndroid Build Coastguard Worker def test_exception_with_ctx_manager(self): 154*da0073e9SAndroid Build Coastguard Worker def fn(x): 155*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 156*da0073e9SAndroid Build Coastguard Worker try: 157*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 158*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 159*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError("Not implemented") 160*da0073e9SAndroid Build Coastguard Worker except NotImplementedError as e: 161*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 162*da0073e9SAndroid Build Coastguard Worker return x 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 165*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 166*da0073e9SAndroid Build Coastguard Worker # Cant use fullgraph=True because WITH_EXCEPT_START is not supported 167*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager") 168*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 169*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker def test_exception_raised_from_child(self): 172*da0073e9SAndroid Build Coastguard Worker def gn(): 173*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError("foo") 174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Worker def fn(x): 176*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 177*da0073e9SAndroid Build Coastguard Worker try: 178*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 179*da0073e9SAndroid Build Coastguard Worker gn() 180*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 181*da0073e9SAndroid Build Coastguard Worker except Exception: 182*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 183*da0073e9SAndroid Build Coastguard Worker 184*da0073e9SAndroid Build Coastguard Worker return x 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 187*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 188*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 189*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 190*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker def test_dynamo_undo_kw_names(self): 193*da0073e9SAndroid Build Coastguard Worker def g(x, k=None): 194*da0073e9SAndroid Build Coastguard Worker if k: 195*da0073e9SAndroid Build Coastguard Worker raise TypeError("error") 196*da0073e9SAndroid Build Coastguard Worker return x.sin() 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker def fn(x): 199*da0073e9SAndroid Build Coastguard Worker d = {"a": x} 200*da0073e9SAndroid Build Coastguard Worker try: 201*da0073e9SAndroid Build Coastguard Worker g(x, k=True) 202*da0073e9SAndroid Build Coastguard Worker except Exception: 203*da0073e9SAndroid Build Coastguard Worker y = 0 204*da0073e9SAndroid Build Coastguard Worker for _, b in d.items(): # noqa: PERF102 205*da0073e9SAndroid Build Coastguard Worker y += b.sum() 206*da0073e9SAndroid Build Coastguard Worker return y 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3) 209*da0073e9SAndroid Build Coastguard Worker expected = fn(x) 210*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 211*da0073e9SAndroid Build Coastguard Worker got = opt_fn(x) 212*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, got) 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker def test_nn_module_getattr(self): 215*da0073e9SAndroid Build Coastguard Worker class A: 216*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 217*da0073e9SAndroid Build Coastguard Worker self._b = 20 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker def __getattr__(self, name): 220*da0073e9SAndroid Build Coastguard Worker fixed_name = "_" + name 221*da0073e9SAndroid Build Coastguard Worker if fixed_name in self.__dict__: 222*da0073e9SAndroid Build Coastguard Worker return self.__dict__[fixed_name] 223*da0073e9SAndroid Build Coastguard Worker raise AttributeError(f"{name} absent") 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker class B(A): 226*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 227*da0073e9SAndroid Build Coastguard Worker self.a = 10 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker def __getattr__(self, name): 230*da0073e9SAndroid Build Coastguard Worker try: 231*da0073e9SAndroid Build Coastguard Worker return super().__getattr__(name) 232*da0073e9SAndroid Build Coastguard Worker except AttributeError: 233*da0073e9SAndroid Build Coastguard Worker return 30 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker obj = B() 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker def fn(x): 238*da0073e9SAndroid Build Coastguard Worker return x * obj.a * obj.b * obj.c 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker x = torch.ones(4) 241*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 242*da0073e9SAndroid Build Coastguard Worker print(ref) 243*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 244*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 245*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) 248*da0073e9SAndroid Build Coastguard Worker def test_custom_getattr_on_module_exception(self): 249*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 250*da0073e9SAndroid Build Coastguard Worker def __init__(self, a=3): 251*da0073e9SAndroid Build Coastguard Worker super().__init__() 252*da0073e9SAndroid Build Coastguard Worker self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2)) 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker def __getattr__(self, name): 255*da0073e9SAndroid Build Coastguard Worker try: 256*da0073e9SAndroid Build Coastguard Worker return super().__getattr__(name) # defer to nn.Module's logic 257*da0073e9SAndroid Build Coastguard Worker except AttributeError: 258*da0073e9SAndroid Build Coastguard Worker if name == "a_copy": 259*da0073e9SAndroid Build Coastguard Worker return self.a 260*da0073e9SAndroid Build Coastguard Worker raise 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 263*da0073e9SAndroid Build Coastguard Worker return x * self.a * self.a_copy 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker mod = Foo() 266*da0073e9SAndroid Build Coastguard Worker opt_mod = torch.compile(mod, backend="eager", fullgraph=True) 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker x = torch.ones(4) 269*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod(x), opt_mod(x)) 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker def test_attribute_error_from_getattr(self): 272*da0073e9SAndroid Build Coastguard Worker class Mock: 273*da0073e9SAndroid Build Coastguard Worker def __init__(self): 274*da0073e9SAndroid Build Coastguard Worker self.a = 5 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker def __getattr__(self, name): 277*da0073e9SAndroid Build Coastguard Worker if name != "a": 278*da0073e9SAndroid Build Coastguard Worker raise AttributeError("missing") 279*da0073e9SAndroid Build Coastguard Worker return self.__dict__["a"] 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker mock = Mock() 282*da0073e9SAndroid Build Coastguard Worker 283*da0073e9SAndroid Build Coastguard Worker def fn(x): 284*da0073e9SAndroid Build Coastguard Worker if hasattr(mock, "b"): 285*da0073e9SAndroid Build Coastguard Worker return torch.cos(x) 286*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 289*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 290*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 291*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 292*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker def test_stop_iteration(self): 295*da0073e9SAndroid Build Coastguard Worker def zip_longest(*iterables, fillvalue=None): 296*da0073e9SAndroid Build Coastguard Worker # Get the iterators for each iterable 297*da0073e9SAndroid Build Coastguard Worker iterators = [iter(it) for it in iterables] 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker result = [] 300*da0073e9SAndroid Build Coastguard Worker while True: 301*da0073e9SAndroid Build Coastguard Worker for it in iterators: 302*da0073e9SAndroid Build Coastguard Worker try: 303*da0073e9SAndroid Build Coastguard Worker value = next(it) 304*da0073e9SAndroid Build Coastguard Worker except StopIteration: 305*da0073e9SAndroid Build Coastguard Worker result.append(fillvalue) 306*da0073e9SAndroid Build Coastguard Worker return result 307*da0073e9SAndroid Build Coastguard Worker result.append(value) 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 310*da0073e9SAndroid Build Coastguard Worker torch.cos(torch.randn(4)) 311*da0073e9SAndroid Build Coastguard Worker return tuple(zip_longest(x, y)) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker x = [1, 2, 3, 4] 314*da0073e9SAndroid Build Coastguard Worker y = [10, 11, 12] 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 317*da0073e9SAndroid Build Coastguard Worker ref = fn(x, y) 318*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, y) 319*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Worker def test_nn_reraise(self): 322*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 323*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 324*da0073e9SAndroid Build Coastguard Worker raise ValueError("woof") 325*da0073e9SAndroid Build Coastguard Worker return x + 2 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker m = M() 328*da0073e9SAndroid Build Coastguard Worker m.register_forward_pre_hook(lambda m, go: None) 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.clear_compilation_metrics() 331*da0073e9SAndroid Build Coastguard Worker opt_call = torch.compile(lambda x: m(x), backend="eager") 332*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: opt_call(torch.randn(3))) 333*da0073e9SAndroid Build Coastguard Worker metrics = torch._dynamo.utils.get_compilation_metrics() 334*da0073e9SAndroid Build Coastguard Worker self.assertEqual(metrics[0].fail_reason, "Observed exception") 335*da0073e9SAndroid Build Coastguard Worker 336*da0073e9SAndroid Build Coastguard Worker def test_key_error(self): 337*da0073e9SAndroid Build Coastguard Worker def fn(x, d): 338*da0073e9SAndroid Build Coastguard Worker try: 339*da0073e9SAndroid Build Coastguard Worker a = d["b"] 340*da0073e9SAndroid Build Coastguard Worker except KeyError: 341*da0073e9SAndroid Build Coastguard Worker a = 2 342*da0073e9SAndroid Build Coastguard Worker return x * a 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 345*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 346*da0073e9SAndroid Build Coastguard Worker d = {"a": 1} 347*da0073e9SAndroid Build Coastguard Worker ref = fn(x, d) 348*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, d) 349*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 350*da0073e9SAndroid Build Coastguard Worker 351*da0073e9SAndroid Build Coastguard Worker def test_atrribute_error(self): 352*da0073e9SAndroid Build Coastguard Worker class Mock: 353*da0073e9SAndroid Build Coastguard Worker def __init__(self): 354*da0073e9SAndroid Build Coastguard Worker self.a = 1 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Worker mock = Mock() 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker def fn(x): 359*da0073e9SAndroid Build Coastguard Worker try: 360*da0073e9SAndroid Build Coastguard Worker c = 2 361*da0073e9SAndroid Build Coastguard Worker mock.b 362*da0073e9SAndroid Build Coastguard Worker except AttributeError: 363*da0073e9SAndroid Build Coastguard Worker c = 3 364*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) * c 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager") 367*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 368*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 369*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 370*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker def test_raise_from_None(self): 373*da0073e9SAndroid Build Coastguard Worker # Inspired from os.environ 374*da0073e9SAndroid Build Coastguard Worker class MyMapping: 375*da0073e9SAndroid Build Coastguard Worker def __init__(self, d): 376*da0073e9SAndroid Build Coastguard Worker self._d = d 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, key): 379*da0073e9SAndroid Build Coastguard Worker try: 380*da0073e9SAndroid Build Coastguard Worker value = self._d[key] 381*da0073e9SAndroid Build Coastguard Worker except KeyError: 382*da0073e9SAndroid Build Coastguard Worker raise KeyError(key) from None 383*da0073e9SAndroid Build Coastguard Worker return value 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker d = MyMapping({"a": 10, "b": 20}) 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker def mapping_get(obj, key, value=None): 388*da0073e9SAndroid Build Coastguard Worker try: 389*da0073e9SAndroid Build Coastguard Worker return obj.__getitem__(key) 390*da0073e9SAndroid Build Coastguard Worker except KeyError: 391*da0073e9SAndroid Build Coastguard Worker return value 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker def fn(x, d, key): 394*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x + 1) 395*da0073e9SAndroid Build Coastguard Worker return x, mapping_get(d, key) 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 398*da0073e9SAndroid Build Coastguard Worker 399*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, 3) 400*da0073e9SAndroid Build Coastguard Worker ref = fn(x, d, "m") 401*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x, d, "m") 402*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref[0], res[0]) 403*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref[1], res[1]) 404*da0073e9SAndroid Build Coastguard Worker 405*da0073e9SAndroid Build Coastguard Worker 406*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 407*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Worker run_tests() 410