xref: /aosp_15_r20/external/pytorch/test/dynamo/test_backends.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerimport sys
3*da0073e9SAndroid Build Coastguard Workerimport unittest
4*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import MagicMock, patch
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo
8*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.backends
9*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
10*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.backends.debugging import ExplainWithBackend
11*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.backends.onnxrt import has_onnxruntime
12*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.backends.tvm import has_tvm
13*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import same
14*da0073e9SAndroid Build Coastguard Workerfrom torch.fx._lazy_graph_module import _force_skip_lazy_graph_module
15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.inductor_utils import HAS_CUDA
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker
# Decorator that skips a test when HAS_CUDA is false.
requires_cuda = unittest.skipIf(not HAS_CUDA, "requires cuda")
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
class Seq(torch.nn.Module):
    """Small MLP over 10 features: Linear -> ReLU -> Linear -> Sigmoid."""

    def __init__(self) -> None:
        super().__init__()
        width = 10
        # Submodules are created in the same order as they run, so parameter
        # initialisation order (and state_dict keys) is layers.0 then layers.2.
        stages = [
            torch.nn.Linear(width, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width),
            torch.nn.Sigmoid(),
        ]
        self.layers = torch.nn.Sequential(*stages)

    def forward(self, x):
        return self.layers(x)
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker
class Conv_Bn_Relu(torch.nn.Module):
    """Bias-free Conv2d, then BatchNorm2d (eps=0.001), then ReLU.

    Any extra keyword arguments (kernel_size, stride, padding, ...) are
    forwarded unchanged to ``torch.nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, out_channels, bias=False, **kwargs
        )
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        features = self.conv(x)
        normalized = self.bn(features)
        return self.relu(normalized)
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker
class TestOptimizations(torch._dynamo.test_case.TestCase):
    """Backend smoke tests plus checks on the example inputs a backend receives."""

    def test_example_inputs(self):
        def fn(a, bc, d):
            b, c = bc
            return a / d - b / c

        # First graph output, produced by running the graph on the example
        # inputs handed to the backend at compile time.
        captured = None

        def compiler_fn(graph, example_inputs):
            nonlocal captured
            captured = graph(*example_inputs)[0]
            return graph.forward

        a, b, c = (torch.full((2,), fill) for fill in (1.0, 2.0, 3.0))
        d = 4
        eager_out = fn(a, (b, c), d)
        compiled_out = torch._dynamo.optimize_assert(compiler_fn)(fn)(a, (b, c), d)

        # The example-input result must agree in size/stride/dtype with both
        # the eager and the compiled output.
        self.assertIsNotNone(captured)
        for other in (eager_out, compiled_out):
            self.assertEqual(captured.size(), other.size())
            self.assertEqual(captured.stride(), other.stride())
            self.assertEqual(captured.dtype, other.dtype)

    def test_example_inputs_runtime_use(self):
        def fn(a, bc, d):
            b, c = bc
            return a / d - b / c

        # First graph output as seen by the compiled callable at runtime.
        captured = None

        def compiler_fn(graph, example_inputs):
            def fwd(*args):
                nonlocal captured
                outputs = graph.forward(*args)
                captured = outputs[0]
                return outputs

            return fwd

        a, b, c = (torch.full((2,), fill) for fill in (1.0, 2.0, 3.0))
        d = 4
        eager_out = fn(a, (b, c), d)
        compiled_out = torch._dynamo.optimize_assert(compiler_fn)(fn)(a, (b, c), d)

        self.assertIsNotNone(captured)
        self.assertTrue(same(captured, eager_out))
        self.assertTrue(same(captured, compiled_out))

    def _check_backend_works(self, backend, options=None):
        # Compile Seq with `backend` (and optional `options`) and compare with
        # eager; the compiled output is cast with .float() before comparing.
        model = Seq().eval()
        sample = torch.randn(2, 10)
        expected = model(sample)
        compiled_model = torch.compile(model, backend=backend, options=options)
        actual = compiled_model(sample)
        self.assertTrue(same(expected, actual.float(), tol=0.01))

    def test_eager(self):
        self._check_backend_works("eager")

    def test_eager_noexcept(self):
        self._check_backend_works("eager_noexcept")

    @_force_skip_lazy_graph_module()
    def test_torchscript(self):
        self._check_backend_works("ts")

    def test_aot_eager(self):
        self._check_backend_works("aot_eager")

    def test_aot_eager_decomp_partition(self):
        self._check_backend_works("aot_eager_decomp_partition")

    @_force_skip_lazy_graph_module()
    def test_aot_ts(self):
        self._check_backend_works("aot_ts")

    @requires_cuda
    def test_aot_cudagraphs(self):
        self._check_backend_works("cudagraphs")

    @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime")
    def test_onnxrt(self):
        self._check_backend_works("onnxrt")

    @unittest.skipIf(not has_tvm(), "requires tvm")
    def test_tvm(self):
        # Default options, then scheduler=None, then opt_level=0.
        for tvm_options in (None, {"scheduler": None}, {"opt_level": 0}):
            self._check_backend_works("tvm", options=tvm_options)

    def test_list_backends(self):
        list_backends = torch._dynamo.list_backends
        # "inductor" is listed under the default tag filter and with no filter.
        self.assertIn("inductor", list_backends())
        self.assertIn("inductor", list_backends(exclude_tags=None))
        # "eager" is hidden by the default filter and by excluding "debug",
        # and appears once no tags are excluded.
        self.assertNotIn("eager", list_backends())
        self.assertNotIn("eager", list_backends(exclude_tags=["debug"]))
        self.assertIn("eager", list_backends(exclude_tags=[]))
149*da0073e9SAndroid Build Coastguard Worker
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Workerclass NormalizeIRTests(torch._dynamo.test_case.TestCase):
152*da0073e9SAndroid Build Coastguard Worker    def test_inplace_normalize(self):
153*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
154*da0073e9SAndroid Build Coastguard Worker            x = torch.cos(a)
155*da0073e9SAndroid Build Coastguard Worker            x += b
156*da0073e9SAndroid Build Coastguard Worker            return torch.sin(x)
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10)
159*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(10).to(torch.float64)
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker        ref = fn(a, b)
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker        optimized_fn = torch._dynamo.optimize("aot_eager")(fn)
164*da0073e9SAndroid Build Coastguard Worker        res = optimized_fn(a, b)
165*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Worker
class MPSNotSupportedTest(torch._dynamo.test_case.TestCase):
    """Compiling with inductor on an MPS model is expected to raise RuntimeError."""

    @unittest.skipIf(not torch.backends.mps.is_available(), "requires mps")
    def test_mps_not_supported(self):
        device = "mps"
        model = Seq().to(device)
        example_input = torch.randn(1, 10).to(device)
        with self.assertRaises(RuntimeError):
            torch.compile(model, backend="inductor")(example_input)
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Workerclass TestExplainWithBackend(torch._dynamo.test_case.TestCase):
180*da0073e9SAndroid Build Coastguard Worker    def test_explain_with_backend(self):
181*da0073e9SAndroid Build Coastguard Worker        def fn3(x):
182*da0073e9SAndroid Build Coastguard Worker            x = torch.sin(x)
183*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
184*da0073e9SAndroid Build Coastguard Worker            x = torch.sin(x)
185*da0073e9SAndroid Build Coastguard Worker            return x
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker        def fn2(x):
188*da0073e9SAndroid Build Coastguard Worker            x = torch.cos(x)
189*da0073e9SAndroid Build Coastguard Worker            x = fn3(x)
190*da0073e9SAndroid Build Coastguard Worker            x = torch.cos(x)
191*da0073e9SAndroid Build Coastguard Worker            return x
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker        def fn1(x):
194*da0073e9SAndroid Build Coastguard Worker            x = torch.tan(x)
195*da0073e9SAndroid Build Coastguard Worker            x = fn2(x)
196*da0073e9SAndroid Build Coastguard Worker            x = torch.tan(x)
197*da0073e9SAndroid Build Coastguard Worker            return x
198*da0073e9SAndroid Build Coastguard Worker
199*da0073e9SAndroid Build Coastguard Worker        def fn(x):
200*da0073e9SAndroid Build Coastguard Worker            x = torch.sigmoid(x)
201*da0073e9SAndroid Build Coastguard Worker            x = fn1(x)
202*da0073e9SAndroid Build Coastguard Worker            x = torch.sigmoid(x)
203*da0073e9SAndroid Build Coastguard Worker            return x
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker        # Wrap TorchInductor with explain backend
206*da0073e9SAndroid Build Coastguard Worker        eb = ExplainWithBackend("inductor")
207*da0073e9SAndroid Build Coastguard Worker        optimized_fn = torch.compile(fn, backend=eb)
208*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.randn(5)
209*da0073e9SAndroid Build Coastguard Worker        result = optimized_fn(input_tensor)
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker        # Check that fn still produces the same output when wrapped by ExplainWithBackend
212*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(result, fn(input_tensor)))
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker        # Verify ExplainOutput object contents, output might change but make sure these fields are present
215*da0073e9SAndroid Build Coastguard Worker        explain_output = eb.output()
216*da0073e9SAndroid Build Coastguard Worker        explain_str = str(explain_output)
217*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Graph Count", explain_str)
218*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Graph Break Count", explain_str)
219*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Op Count", explain_str)
220*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Break Reasons", explain_str)
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker        # Verify that for the given functions above, we report the correct number of graphs, graph breaks, and ops
223*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(8, explain_output.graph_count)
224*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(7, explain_output.graph_break_count)
225*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(8, explain_output.op_count)
226*da0073e9SAndroid Build Coastguard Worker
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Workerclass TestCustomBackendAPI(torch._dynamo.test_case.TestCase):
229*da0073e9SAndroid Build Coastguard Worker    """Test APIs documented by https://pytorch.org/docs/main/torch.compiler_custom_backends.html"""
230*da0073e9SAndroid Build Coastguard Worker
231*da0073e9SAndroid Build Coastguard Worker    def test_register_backend_api(self):
232*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo import register_backend
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker        backend_run = False
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker        @register_backend
237*da0073e9SAndroid Build Coastguard Worker        def my_custom_backend(gm, example_inputs):
238*da0073e9SAndroid Build Coastguard Worker            nonlocal backend_run
239*da0073e9SAndroid Build Coastguard Worker            backend_run = True
240*da0073e9SAndroid Build Coastguard Worker            return gm.forward
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Worker        def f(x):
243*da0073e9SAndroid Build Coastguard Worker            return torch.relu(x)
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker        opt_f = torch.compile(f, backend="my_custom_backend")
246*da0073e9SAndroid Build Coastguard Worker        opt_f(torch.randn(3, 3))
247*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(backend_run)
248*da0073e9SAndroid Build Coastguard Worker
249*da0073e9SAndroid Build Coastguard Worker    def test_aot_autograd_api(self):
250*da0073e9SAndroid Build Coastguard Worker        from functorch.compile import make_boxed_func
251*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.backends.common import aot_autograd
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker        backend_run = False
254*da0073e9SAndroid Build Coastguard Worker
255*da0073e9SAndroid Build Coastguard Worker        def my_compiler(gm, example_inputs):
256*da0073e9SAndroid Build Coastguard Worker            nonlocal backend_run
257*da0073e9SAndroid Build Coastguard Worker            backend_run = True
258*da0073e9SAndroid Build Coastguard Worker            return make_boxed_func(gm.forward)
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker        my_backend = aot_autograd(fw_compiler=my_compiler)
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker        def f(x):
263*da0073e9SAndroid Build Coastguard Worker            return torch.relu(x)
264*da0073e9SAndroid Build Coastguard Worker
265*da0073e9SAndroid Build Coastguard Worker        opt_f = torch.compile(f, backend=my_backend)
266*da0073e9SAndroid Build Coastguard Worker        opt_f(torch.randn(3, 3))
267*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(backend_run)
268*da0073e9SAndroid Build Coastguard Worker
269*da0073e9SAndroid Build Coastguard Worker    def test_lookup_backend(self):
270*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo import list_backends, lookup_backend
271*da0073e9SAndroid Build Coastguard Worker
272*da0073e9SAndroid Build Coastguard Worker        backends = list_backends()
273*da0073e9SAndroid Build Coastguard Worker        backend_run = False
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker        def my_compiler(gm, example_inputs):
276*da0073e9SAndroid Build Coastguard Worker            nonlocal backend_run
277*da0073e9SAndroid Build Coastguard Worker            backend_run = True
278*da0073e9SAndroid Build Coastguard Worker            try:
279*da0073e9SAndroid Build Coastguard Worker                trt_compiled = lookup_backend("tensorrt")(gm, example_inputs)
280*da0073e9SAndroid Build Coastguard Worker                if trt_compiled is not None:
281*da0073e9SAndroid Build Coastguard Worker                    return trt_compiled
282*da0073e9SAndroid Build Coastguard Worker            except Exception:
283*da0073e9SAndroid Build Coastguard Worker                pass
284*da0073e9SAndroid Build Coastguard Worker            # first backend failed, try something else...
285*da0073e9SAndroid Build Coastguard Worker            try:
286*da0073e9SAndroid Build Coastguard Worker                inductor_compiled = lookup_backend("inductor")(gm, example_inputs)
287*da0073e9SAndroid Build Coastguard Worker                if inductor_compiled is not None:
288*da0073e9SAndroid Build Coastguard Worker                    return inductor_compiled
289*da0073e9SAndroid Build Coastguard Worker            except Exception:
290*da0073e9SAndroid Build Coastguard Worker                pass
291*da0073e9SAndroid Build Coastguard Worker            return gm.forward
292*da0073e9SAndroid Build Coastguard Worker
293*da0073e9SAndroid Build Coastguard Worker        def f(x):
294*da0073e9SAndroid Build Coastguard Worker            return torch.relu(x)
295*da0073e9SAndroid Build Coastguard Worker
296*da0073e9SAndroid Build Coastguard Worker        opt_f = torch.compile(f, backend=my_compiler)
297*da0073e9SAndroid Build Coastguard Worker        opt_f(torch.randn(3, 3))
298*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(backend_run)
299*da0073e9SAndroid Build Coastguard Worker
300*da0073e9SAndroid Build Coastguard Worker    def test_lookup_custom_backend(self):
301*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo import list_backends
302*da0073e9SAndroid Build Coastguard Worker
303*da0073e9SAndroid Build Coastguard Worker        backends_group = "torch_dynamo_backends"
304*da0073e9SAndroid Build Coastguard Worker        name = "mycustombackend"
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker        mock_3_9 = MagicMock()
307*da0073e9SAndroid Build Coastguard Worker        mock_3_9.load.return_value = lambda: "mocked 3.9"
308*da0073e9SAndroid Build Coastguard Worker        mock_3_9.name = name
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker        mock_3_10 = MagicMock()
311*da0073e9SAndroid Build Coastguard Worker        mock_3_10.load.return_value = lambda: "mocked 3.10"
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker        def mock_eps(group=None):
314*da0073e9SAndroid Build Coastguard Worker            if sys.version_info < (3, 10):
315*da0073e9SAndroid Build Coastguard Worker                return {backends_group: [mock_3_9]}
316*da0073e9SAndroid Build Coastguard Worker            else:
317*da0073e9SAndroid Build Coastguard Worker                assert group == backends_group, group
318*da0073e9SAndroid Build Coastguard Worker                mock_group = MagicMock()
319*da0073e9SAndroid Build Coastguard Worker                mock_group.names = [name]
320*da0073e9SAndroid Build Coastguard Worker                mock_group[name] = mock_3_10
321*da0073e9SAndroid Build Coastguard Worker                # mock_group[name].load.return_value = lambda: "mocked 3.10"
322*da0073e9SAndroid Build Coastguard Worker                return mock_group
323*da0073e9SAndroid Build Coastguard Worker
324*da0073e9SAndroid Build Coastguard Worker        with patch("importlib.metadata.entry_points", mock_eps):
325*da0073e9SAndroid Build Coastguard Worker            from torch._dynamo.backends import registry
326*da0073e9SAndroid Build Coastguard Worker
327*da0073e9SAndroid Build Coastguard Worker            registry._lazy_import.cache_clear()
328*da0073e9SAndroid Build Coastguard Worker            registry._discover_entrypoint_backends.cache_clear()
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker            backends = list_backends()
331*da0073e9SAndroid Build Coastguard Worker            assert name in backends, (name, backends)
332*da0073e9SAndroid Build Coastguard Worker
333*da0073e9SAndroid Build Coastguard Worker    def test_backend_recompilation(self):
334*da0073e9SAndroid Build Coastguard Worker        def fn(x):
335*da0073e9SAndroid Build Coastguard Worker            return x + x
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor(2.0)
338*da0073e9SAndroid Build Coastguard Worker
339*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(
340*da0073e9SAndroid Build Coastguard Worker            fn, backend="inductor", options={"_raise_error_for_testing": False}
341*da0073e9SAndroid Build Coastguard Worker        )
342*da0073e9SAndroid Build Coastguard Worker        opt_fn(input)
343*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed):
344*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch.compile(
345*da0073e9SAndroid Build Coastguard Worker                fn, backend="inductor", options={"_raise_error_for_testing": True}
346*da0073e9SAndroid Build Coastguard Worker            )
347*da0073e9SAndroid Build Coastguard Worker            opt_fn(input)
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run this module's tests through dynamo's test-case runner.
    torch._dynamo.test_case.run_tests()
354