# Owner(s): ["module: dynamo"] import torch import torch._dynamo.config import torch._dynamo.test_case import torch._functorch.config import torch.nn import torch.utils.checkpoint class ExceptionTests(torch._dynamo.test_case.TestCase): def test_exception(self): def fn(x): x = torch.cos(x) try: x = torch.sin(x) raise NotImplementedError except Exception: x = torch.sigmoid(x) return x x = torch.randn(4) ref = fn(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) res = opt_fn(x) self.assertEqual(ref, res) def test_exception2(self): def fn(x): x = torch.cos(x) try: x = torch.sin(x) raise NotImplementedError except (NotImplementedError, AttributeError) as e: x = torch.sigmoid(x) return x x = torch.randn(4) ref = fn(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) res = opt_fn(x) self.assertEqual(ref, res) def test_exception3(self): def fn(x): x = torch.cos(x) try: x = torch.sin(x) raise NotImplementedError("Not implemented") except AssertionError: x = torch.sigmoid(x) except NotImplementedError: x = torch.cos(x) finally: x = torch.cos(x) return x x = torch.randn(4) ref = fn(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) res = opt_fn(x) self.assertEqual(ref, res) def test_exception4(self): def fn(x): for i in range(10): if i == 5: return x try: x = torch.sin(x) raise NotImplementedError except Exception: x = torch.sigmoid(x) return x x = torch.randn(4) ref = fn(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) res = opt_fn(x) self.assertEqual(ref, res) def test_exception_with_another_exception(self): def fn(x): x = torch.cos(x) try: x = torch.sin(x) raise NotImplementedError("Not implemented") except NotImplementedError as e: x = torch.sigmoid(x) try: x = torch.cos(x) raise AssertionError except AssertionError: x = torch.cos(x) x = torch.randn(4) ref = fn(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) res = opt_fn(x) self.assertEqual(ref, res) def test_exception_else(self): def gn(x): return torch.cos(x) def fn(x): x = torch.cos(x) try: x = torch.sin(x) x = gn(x) except Exception: x = torch.sigmoid(x) else: x = torch.cos(x) return x x = torch.randn(4) ref = fn(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) res = opt_fn(x) self.assertEqual(ref, res) # TODO(anijain2305) - does not work with fullgraph=True def test_exception_with_another_exception2(self): def gn(x): try: x = torch.cos(x) raise NotImplementedError("Not implemented") except NotImplementedError as e: x = torch.sigmoid(x) raise def fn(x): try: x = torch.cos(x) gn(x) except Exception: pass return x x = torch.randn(4) ref = fn(x) # Cant use fullgraph=True because RERAISE is not supported opt_fn = torch.compile(fn, backend="eager") res = opt_fn(x) # TODO(anijain2305) - does not work with fullgraph=True def test_exception_with_ctx_manager(self): def fn(x): x = torch.cos(x) try: with torch.no_grad(): x = torch.sin(x) raise NotImplementedError("Not implemented") except NotImplementedError as e: x = torch.sigmoid(x) return x x = torch.randn(4) ref = fn(x) # Cant use fullgraph=True because WITH_EXCEPT_START is not supported opt_fn = torch.compile(fn, backend="eager") res = opt_fn(x) self.assertEqual(ref, res) def test_exception_raised_from_child(self): def gn(): raise NotImplementedError("foo") def fn(x): x = torch.cos(x) try: x = torch.sin(x) gn() x = torch.sin(x) except Exception: x = torch.sigmoid(x) return x x = torch.randn(4) ref = fn(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) res = opt_fn(x) self.assertEqual(ref, res) def test_dynamo_undo_kw_names(self): def g(x, k=None): if k: raise TypeError("error") return x.sin() def fn(x): d = {"a": x} try: g(x, k=True) except Exception: y = 0 for _, b in d.items(): # noqa: PERF102 y += b.sum() return y x = torch.randn(2, 3) expected = fn(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) got = opt_fn(x) self.assertEqual(expected, got) def test_nn_module_getattr(self): class A: def __init__(self) -> None: self._b = 20 def __getattr__(self, name): fixed_name = "_" + name if fixed_name in self.__dict__: return self.__dict__[fixed_name] raise AttributeError(f"{name} absent") class B(A): def __init__(self) -> None: self.a = 10 def __getattr__(self, name): try: return super().__getattr__(name) except AttributeError: return 30 obj = B() def fn(x): return x * obj.a * obj.b * obj.c x = torch.ones(4) ref = fn(x) print(ref) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) res = opt_fn(x) self.assertEqual(ref, res) @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) def test_custom_getattr_on_module_exception(self): class Foo(torch.nn.Module): def __init__(self, a=3): super().__init__() self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2)) def __getattr__(self, name): try: return super().__getattr__(name) # defer to nn.Module's logic except AttributeError: if name == "a_copy": return self.a raise def forward(self, x): return x * self.a * self.a_copy mod = Foo() opt_mod = torch.compile(mod, backend="eager", fullgraph=True) x = torch.ones(4) self.assertEqual(mod(x), opt_mod(x)) def test_attribute_error_from_getattr(self): class Mock: def __init__(self): self.a = 5 def __getattr__(self, name): if name != "a": raise AttributeError("missing") return self.__dict__["a"] mock = Mock() def fn(x): if hasattr(mock, "b"): return torch.cos(x) return torch.sin(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) ref = fn(x) res = opt_fn(x) self.assertEqual(ref, res) def test_stop_iteration(self): def zip_longest(*iterables, fillvalue=None): # Get the iterators for each iterable iterators = [iter(it) for it in iterables] result = [] while True: for it in iterators: try: value = next(it) except StopIteration: result.append(fillvalue) return result result.append(value) def fn(x, y): torch.cos(torch.randn(4)) return tuple(zip_longest(x, y)) x = [1, 2, 3, 4] y = [10, 11, 12] opt_fn = torch.compile(fn, backend="eager", fullgraph=True) ref = fn(x, y) res = opt_fn(x, y) self.assertEqual(ref, res) def test_nn_reraise(self): class M(torch.nn.Module): def forward(self, x): raise ValueError("woof") return x + 2 m = M() m.register_forward_pre_hook(lambda m, go: None) torch._dynamo.utils.clear_compilation_metrics() opt_call = torch.compile(lambda x: m(x), backend="eager") self.assertRaises(ValueError, lambda: opt_call(torch.randn(3))) metrics = torch._dynamo.utils.get_compilation_metrics() self.assertEqual(metrics[0].fail_reason, "Observed exception") def test_key_error(self): def fn(x, d): try: a = d["b"] except KeyError: a = 2 return x * a opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) d = {"a": 1} ref = fn(x, d) res = opt_fn(x, d) self.assertEqual(ref, res) def test_atrribute_error(self): class Mock: def __init__(self): self.a = 1 mock = Mock() def fn(x): try: c = 2 mock.b except AttributeError: c = 3 return torch.sin(x) * c opt_fn = torch.compile(fn, backend="eager") x = torch.randn(4) ref = fn(x) res = opt_fn(x) self.assertEqual(ref, res) def test_raise_from_None(self): # Inspired from os.environ class MyMapping: def __init__(self, d): self._d = d def __getitem__(self, key): try: value = self._d[key] except KeyError: raise KeyError(key) from None return value d = MyMapping({"a": 10, "b": 20}) def mapping_get(obj, key, value=None): try: return obj.__getitem__(key) except KeyError: return value def fn(x, d, key): x = torch.sin(x + 1) return x, mapping_get(d, key) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.rand(2, 3) ref = fn(x, d, "m") res = opt_fn(x, d, "m") self.assertEqual(ref[0], res[0]) self.assertEqual(ref[1], res[1]) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()