xref: /aosp_15_r20/external/pytorch/test/dynamo/test_exceptions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.config
5*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
6*da0073e9SAndroid Build Coastguard Workerimport torch._functorch.config
7*da0073e9SAndroid Build Coastguard Workerimport torch.nn
8*da0073e9SAndroid Build Coastguard Workerimport torch.utils.checkpoint
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Workerclass ExceptionTests(torch._dynamo.test_case.TestCase):
12*da0073e9SAndroid Build Coastguard Worker    def test_exception(self):
13*da0073e9SAndroid Build Coastguard Worker        def fn(x):
14*da0073e9SAndroid Build Coastguard Worker            x = torch.cos(x)
15*da0073e9SAndroid Build Coastguard Worker            try:
16*da0073e9SAndroid Build Coastguard Worker                x = torch.sin(x)
17*da0073e9SAndroid Build Coastguard Worker                raise NotImplementedError
18*da0073e9SAndroid Build Coastguard Worker            except Exception:
19*da0073e9SAndroid Build Coastguard Worker                x = torch.sigmoid(x)
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker            return x
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
24*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
25*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
26*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
27*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker    def test_exception2(self):
30*da0073e9SAndroid Build Coastguard Worker        def fn(x):
31*da0073e9SAndroid Build Coastguard Worker            x = torch.cos(x)
32*da0073e9SAndroid Build Coastguard Worker            try:
33*da0073e9SAndroid Build Coastguard Worker                x = torch.sin(x)
34*da0073e9SAndroid Build Coastguard Worker                raise NotImplementedError
35*da0073e9SAndroid Build Coastguard Worker            except (NotImplementedError, AttributeError) as e:
36*da0073e9SAndroid Build Coastguard Worker                x = torch.sigmoid(x)
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker            return x
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
41*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
42*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
43*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
44*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker    def test_exception3(self):
47*da0073e9SAndroid Build Coastguard Worker        def fn(x):
48*da0073e9SAndroid Build Coastguard Worker            x = torch.cos(x)
49*da0073e9SAndroid Build Coastguard Worker            try:
50*da0073e9SAndroid Build Coastguard Worker                x = torch.sin(x)
51*da0073e9SAndroid Build Coastguard Worker                raise NotImplementedError("Not implemented")
52*da0073e9SAndroid Build Coastguard Worker            except AssertionError:
53*da0073e9SAndroid Build Coastguard Worker                x = torch.sigmoid(x)
54*da0073e9SAndroid Build Coastguard Worker            except NotImplementedError:
55*da0073e9SAndroid Build Coastguard Worker                x = torch.cos(x)
56*da0073e9SAndroid Build Coastguard Worker            finally:
57*da0073e9SAndroid Build Coastguard Worker                x = torch.cos(x)
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker            return x
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
62*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
63*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
64*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
65*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker    def test_exception4(self):
68*da0073e9SAndroid Build Coastguard Worker        def fn(x):
69*da0073e9SAndroid Build Coastguard Worker            for i in range(10):
70*da0073e9SAndroid Build Coastguard Worker                if i == 5:
71*da0073e9SAndroid Build Coastguard Worker                    return x
72*da0073e9SAndroid Build Coastguard Worker                try:
73*da0073e9SAndroid Build Coastguard Worker                    x = torch.sin(x)
74*da0073e9SAndroid Build Coastguard Worker                    raise NotImplementedError
75*da0073e9SAndroid Build Coastguard Worker                except Exception:
76*da0073e9SAndroid Build Coastguard Worker                    x = torch.sigmoid(x)
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker            return x
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
81*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
82*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
83*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
84*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker    def test_exception_with_another_exception(self):
87*da0073e9SAndroid Build Coastguard Worker        def fn(x):
88*da0073e9SAndroid Build Coastguard Worker            x = torch.cos(x)
89*da0073e9SAndroid Build Coastguard Worker            try:
90*da0073e9SAndroid Build Coastguard Worker                x = torch.sin(x)
91*da0073e9SAndroid Build Coastguard Worker                raise NotImplementedError("Not implemented")
92*da0073e9SAndroid Build Coastguard Worker            except NotImplementedError as e:
93*da0073e9SAndroid Build Coastguard Worker                x = torch.sigmoid(x)
94*da0073e9SAndroid Build Coastguard Worker                try:
95*da0073e9SAndroid Build Coastguard Worker                    x = torch.cos(x)
96*da0073e9SAndroid Build Coastguard Worker                    raise AssertionError
97*da0073e9SAndroid Build Coastguard Worker                except AssertionError:
98*da0073e9SAndroid Build Coastguard Worker                    x = torch.cos(x)
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
101*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
102*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
103*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
104*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
105*da0073e9SAndroid Build Coastguard Worker
106*da0073e9SAndroid Build Coastguard Worker    def test_exception_else(self):
107*da0073e9SAndroid Build Coastguard Worker        def gn(x):
108*da0073e9SAndroid Build Coastguard Worker            return torch.cos(x)
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker        def fn(x):
111*da0073e9SAndroid Build Coastguard Worker            x = torch.cos(x)
112*da0073e9SAndroid Build Coastguard Worker            try:
113*da0073e9SAndroid Build Coastguard Worker                x = torch.sin(x)
114*da0073e9SAndroid Build Coastguard Worker                x = gn(x)
115*da0073e9SAndroid Build Coastguard Worker            except Exception:
116*da0073e9SAndroid Build Coastguard Worker                x = torch.sigmoid(x)
117*da0073e9SAndroid Build Coastguard Worker            else:
118*da0073e9SAndroid Build Coastguard Worker                x = torch.cos(x)
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker            return x
121*da0073e9SAndroid Build Coastguard Worker
122*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
123*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
124*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
125*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
126*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker    # TODO(anijain2305) - does not work with fullgraph=True
129*da0073e9SAndroid Build Coastguard Worker    def test_exception_with_another_exception2(self):
130*da0073e9SAndroid Build Coastguard Worker        def gn(x):
131*da0073e9SAndroid Build Coastguard Worker            try:
132*da0073e9SAndroid Build Coastguard Worker                x = torch.cos(x)
133*da0073e9SAndroid Build Coastguard Worker                raise NotImplementedError("Not implemented")
134*da0073e9SAndroid Build Coastguard Worker            except NotImplementedError as e:
135*da0073e9SAndroid Build Coastguard Worker                x = torch.sigmoid(x)
136*da0073e9SAndroid Build Coastguard Worker                raise
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker        def fn(x):
139*da0073e9SAndroid Build Coastguard Worker            try:
140*da0073e9SAndroid Build Coastguard Worker                x = torch.cos(x)
141*da0073e9SAndroid Build Coastguard Worker                gn(x)
142*da0073e9SAndroid Build Coastguard Worker            except Exception:
143*da0073e9SAndroid Build Coastguard Worker                pass
144*da0073e9SAndroid Build Coastguard Worker            return x
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
147*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
148*da0073e9SAndroid Build Coastguard Worker        # Cant use fullgraph=True because RERAISE is not supported
149*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager")
150*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker    # TODO(anijain2305) - does not work with fullgraph=True
153*da0073e9SAndroid Build Coastguard Worker    def test_exception_with_ctx_manager(self):
154*da0073e9SAndroid Build Coastguard Worker        def fn(x):
155*da0073e9SAndroid Build Coastguard Worker            x = torch.cos(x)
156*da0073e9SAndroid Build Coastguard Worker            try:
157*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
158*da0073e9SAndroid Build Coastguard Worker                    x = torch.sin(x)
159*da0073e9SAndroid Build Coastguard Worker                    raise NotImplementedError("Not implemented")
160*da0073e9SAndroid Build Coastguard Worker            except NotImplementedError as e:
161*da0073e9SAndroid Build Coastguard Worker                x = torch.sigmoid(x)
162*da0073e9SAndroid Build Coastguard Worker            return x
163*da0073e9SAndroid Build Coastguard Worker
164*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
165*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
166*da0073e9SAndroid Build Coastguard Worker        # Cant use fullgraph=True because WITH_EXCEPT_START is not supported
167*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager")
168*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
169*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker    def test_exception_raised_from_child(self):
172*da0073e9SAndroid Build Coastguard Worker        def gn():
173*da0073e9SAndroid Build Coastguard Worker            raise NotImplementedError("foo")
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker        def fn(x):
176*da0073e9SAndroid Build Coastguard Worker            x = torch.cos(x)
177*da0073e9SAndroid Build Coastguard Worker            try:
178*da0073e9SAndroid Build Coastguard Worker                x = torch.sin(x)
179*da0073e9SAndroid Build Coastguard Worker                gn()
180*da0073e9SAndroid Build Coastguard Worker                x = torch.sin(x)
181*da0073e9SAndroid Build Coastguard Worker            except Exception:
182*da0073e9SAndroid Build Coastguard Worker                x = torch.sigmoid(x)
183*da0073e9SAndroid Build Coastguard Worker
184*da0073e9SAndroid Build Coastguard Worker            return x
185*da0073e9SAndroid Build Coastguard Worker
186*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
187*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
188*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
189*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
190*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
191*da0073e9SAndroid Build Coastguard Worker
192*da0073e9SAndroid Build Coastguard Worker    def test_dynamo_undo_kw_names(self):
193*da0073e9SAndroid Build Coastguard Worker        def g(x, k=None):
194*da0073e9SAndroid Build Coastguard Worker            if k:
195*da0073e9SAndroid Build Coastguard Worker                raise TypeError("error")
196*da0073e9SAndroid Build Coastguard Worker            return x.sin()
197*da0073e9SAndroid Build Coastguard Worker
198*da0073e9SAndroid Build Coastguard Worker        def fn(x):
199*da0073e9SAndroid Build Coastguard Worker            d = {"a": x}
200*da0073e9SAndroid Build Coastguard Worker            try:
201*da0073e9SAndroid Build Coastguard Worker                g(x, k=True)
202*da0073e9SAndroid Build Coastguard Worker            except Exception:
203*da0073e9SAndroid Build Coastguard Worker                y = 0
204*da0073e9SAndroid Build Coastguard Worker                for _, b in d.items():  # noqa: PERF102
205*da0073e9SAndroid Build Coastguard Worker                    y += b.sum()
206*da0073e9SAndroid Build Coastguard Worker            return y
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3)
209*da0073e9SAndroid Build Coastguard Worker        expected = fn(x)
210*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
211*da0073e9SAndroid Build Coastguard Worker        got = opt_fn(x)
212*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, got)
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker    def test_nn_module_getattr(self):
215*da0073e9SAndroid Build Coastguard Worker        class A:
216*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
217*da0073e9SAndroid Build Coastguard Worker                self._b = 20
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker            def __getattr__(self, name):
220*da0073e9SAndroid Build Coastguard Worker                fixed_name = "_" + name
221*da0073e9SAndroid Build Coastguard Worker                if fixed_name in self.__dict__:
222*da0073e9SAndroid Build Coastguard Worker                    return self.__dict__[fixed_name]
223*da0073e9SAndroid Build Coastguard Worker                raise AttributeError(f"{name} absent")
224*da0073e9SAndroid Build Coastguard Worker
225*da0073e9SAndroid Build Coastguard Worker        class B(A):
226*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
227*da0073e9SAndroid Build Coastguard Worker                self.a = 10
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker            def __getattr__(self, name):
230*da0073e9SAndroid Build Coastguard Worker                try:
231*da0073e9SAndroid Build Coastguard Worker                    return super().__getattr__(name)
232*da0073e9SAndroid Build Coastguard Worker                except AttributeError:
233*da0073e9SAndroid Build Coastguard Worker                    return 30
234*da0073e9SAndroid Build Coastguard Worker
235*da0073e9SAndroid Build Coastguard Worker        obj = B()
236*da0073e9SAndroid Build Coastguard Worker
237*da0073e9SAndroid Build Coastguard Worker        def fn(x):
238*da0073e9SAndroid Build Coastguard Worker            return x * obj.a * obj.b * obj.c
239*da0073e9SAndroid Build Coastguard Worker
240*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(4)
241*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
242*da0073e9SAndroid Build Coastguard Worker        print(ref)
243*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
244*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
245*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
246*da0073e9SAndroid Build Coastguard Worker
247*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
248*da0073e9SAndroid Build Coastguard Worker    def test_custom_getattr_on_module_exception(self):
249*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
250*da0073e9SAndroid Build Coastguard Worker            def __init__(self, a=3):
251*da0073e9SAndroid Build Coastguard Worker                super().__init__()
252*da0073e9SAndroid Build Coastguard Worker                self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2))
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker            def __getattr__(self, name):
255*da0073e9SAndroid Build Coastguard Worker                try:
256*da0073e9SAndroid Build Coastguard Worker                    return super().__getattr__(name)  # defer to nn.Module's logic
257*da0073e9SAndroid Build Coastguard Worker                except AttributeError:
258*da0073e9SAndroid Build Coastguard Worker                    if name == "a_copy":
259*da0073e9SAndroid Build Coastguard Worker                        return self.a
260*da0073e9SAndroid Build Coastguard Worker                    raise
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
263*da0073e9SAndroid Build Coastguard Worker                return x * self.a * self.a_copy
264*da0073e9SAndroid Build Coastguard Worker
265*da0073e9SAndroid Build Coastguard Worker        mod = Foo()
266*da0073e9SAndroid Build Coastguard Worker        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
267*da0073e9SAndroid Build Coastguard Worker
268*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(4)
269*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod(x), opt_mod(x))
270*da0073e9SAndroid Build Coastguard Worker
271*da0073e9SAndroid Build Coastguard Worker    def test_attribute_error_from_getattr(self):
272*da0073e9SAndroid Build Coastguard Worker        class Mock:
273*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
274*da0073e9SAndroid Build Coastguard Worker                self.a = 5
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker            def __getattr__(self, name):
277*da0073e9SAndroid Build Coastguard Worker                if name != "a":
278*da0073e9SAndroid Build Coastguard Worker                    raise AttributeError("missing")
279*da0073e9SAndroid Build Coastguard Worker                return self.__dict__["a"]
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker        mock = Mock()
282*da0073e9SAndroid Build Coastguard Worker
283*da0073e9SAndroid Build Coastguard Worker        def fn(x):
284*da0073e9SAndroid Build Coastguard Worker            if hasattr(mock, "b"):
285*da0073e9SAndroid Build Coastguard Worker                return torch.cos(x)
286*da0073e9SAndroid Build Coastguard Worker            return torch.sin(x)
287*da0073e9SAndroid Build Coastguard Worker
288*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
289*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
290*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
291*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
292*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker    def test_stop_iteration(self):
295*da0073e9SAndroid Build Coastguard Worker        def zip_longest(*iterables, fillvalue=None):
296*da0073e9SAndroid Build Coastguard Worker            # Get the iterators for each iterable
297*da0073e9SAndroid Build Coastguard Worker            iterators = [iter(it) for it in iterables]
298*da0073e9SAndroid Build Coastguard Worker
299*da0073e9SAndroid Build Coastguard Worker            result = []
300*da0073e9SAndroid Build Coastguard Worker            while True:
301*da0073e9SAndroid Build Coastguard Worker                for it in iterators:
302*da0073e9SAndroid Build Coastguard Worker                    try:
303*da0073e9SAndroid Build Coastguard Worker                        value = next(it)
304*da0073e9SAndroid Build Coastguard Worker                    except StopIteration:
305*da0073e9SAndroid Build Coastguard Worker                        result.append(fillvalue)
306*da0073e9SAndroid Build Coastguard Worker                        return result
307*da0073e9SAndroid Build Coastguard Worker                    result.append(value)
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
310*da0073e9SAndroid Build Coastguard Worker            torch.cos(torch.randn(4))
311*da0073e9SAndroid Build Coastguard Worker            return tuple(zip_longest(x, y))
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker        x = [1, 2, 3, 4]
314*da0073e9SAndroid Build Coastguard Worker        y = [10, 11, 12]
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
317*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, y)
318*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
319*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
320*da0073e9SAndroid Build Coastguard Worker
321*da0073e9SAndroid Build Coastguard Worker    def test_nn_reraise(self):
322*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
323*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
324*da0073e9SAndroid Build Coastguard Worker                raise ValueError("woof")
325*da0073e9SAndroid Build Coastguard Worker                return x + 2
326*da0073e9SAndroid Build Coastguard Worker
327*da0073e9SAndroid Build Coastguard Worker        m = M()
328*da0073e9SAndroid Build Coastguard Worker        m.register_forward_pre_hook(lambda m, go: None)
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.utils.clear_compilation_metrics()
331*da0073e9SAndroid Build Coastguard Worker        opt_call = torch.compile(lambda x: m(x), backend="eager")
332*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, lambda: opt_call(torch.randn(3)))
333*da0073e9SAndroid Build Coastguard Worker        metrics = torch._dynamo.utils.get_compilation_metrics()
334*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(metrics[0].fail_reason, "Observed exception")
335*da0073e9SAndroid Build Coastguard Worker
336*da0073e9SAndroid Build Coastguard Worker    def test_key_error(self):
337*da0073e9SAndroid Build Coastguard Worker        def fn(x, d):
338*da0073e9SAndroid Build Coastguard Worker            try:
339*da0073e9SAndroid Build Coastguard Worker                a = d["b"]
340*da0073e9SAndroid Build Coastguard Worker            except KeyError:
341*da0073e9SAndroid Build Coastguard Worker                a = 2
342*da0073e9SAndroid Build Coastguard Worker            return x * a
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
345*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
346*da0073e9SAndroid Build Coastguard Worker        d = {"a": 1}
347*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, d)
348*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, d)
349*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
350*da0073e9SAndroid Build Coastguard Worker
351*da0073e9SAndroid Build Coastguard Worker    def test_atrribute_error(self):
352*da0073e9SAndroid Build Coastguard Worker        class Mock:
353*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
354*da0073e9SAndroid Build Coastguard Worker                self.a = 1
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker        mock = Mock()
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker        def fn(x):
359*da0073e9SAndroid Build Coastguard Worker            try:
360*da0073e9SAndroid Build Coastguard Worker                c = 2
361*da0073e9SAndroid Build Coastguard Worker                mock.b
362*da0073e9SAndroid Build Coastguard Worker            except AttributeError:
363*da0073e9SAndroid Build Coastguard Worker                c = 3
364*da0073e9SAndroid Build Coastguard Worker            return torch.sin(x) * c
365*da0073e9SAndroid Build Coastguard Worker
366*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager")
367*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
368*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
369*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
370*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
371*da0073e9SAndroid Build Coastguard Worker
372*da0073e9SAndroid Build Coastguard Worker    def test_raise_from_None(self):
373*da0073e9SAndroid Build Coastguard Worker        # Inspired from os.environ
374*da0073e9SAndroid Build Coastguard Worker        class MyMapping:
375*da0073e9SAndroid Build Coastguard Worker            def __init__(self, d):
376*da0073e9SAndroid Build Coastguard Worker                self._d = d
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker            def __getitem__(self, key):
379*da0073e9SAndroid Build Coastguard Worker                try:
380*da0073e9SAndroid Build Coastguard Worker                    value = self._d[key]
381*da0073e9SAndroid Build Coastguard Worker                except KeyError:
382*da0073e9SAndroid Build Coastguard Worker                    raise KeyError(key) from None
383*da0073e9SAndroid Build Coastguard Worker                return value
384*da0073e9SAndroid Build Coastguard Worker
385*da0073e9SAndroid Build Coastguard Worker        d = MyMapping({"a": 10, "b": 20})
386*da0073e9SAndroid Build Coastguard Worker
387*da0073e9SAndroid Build Coastguard Worker        def mapping_get(obj, key, value=None):
388*da0073e9SAndroid Build Coastguard Worker            try:
389*da0073e9SAndroid Build Coastguard Worker                return obj.__getitem__(key)
390*da0073e9SAndroid Build Coastguard Worker            except KeyError:
391*da0073e9SAndroid Build Coastguard Worker                return value
392*da0073e9SAndroid Build Coastguard Worker
393*da0073e9SAndroid Build Coastguard Worker        def fn(x, d, key):
394*da0073e9SAndroid Build Coastguard Worker            x = torch.sin(x + 1)
395*da0073e9SAndroid Build Coastguard Worker            return x, mapping_get(d, key)
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
398*da0073e9SAndroid Build Coastguard Worker
399*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 3)
400*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, d, "m")
401*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, d, "m")
402*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref[0], res[0])
403*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref[1], res[1])
404*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executed only when this file is run as a script: hand control to the
    # test runner provided by torch._dynamo.test_case.
    from torch._dynamo.test_case import run_tests as _run_dynamo_tests

    _run_dynamo_tests()
410