1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Workerimport sys 3*da0073e9SAndroid Build Coastguard Workerimport unittest 4*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import MagicMock, patch 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo 8*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.backends 9*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 10*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.backends.debugging import ExplainWithBackend 11*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.backends.onnxrt import has_onnxruntime 12*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.backends.tvm import has_tvm 13*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import same 14*da0073e9SAndroid Build Coastguard Workerfrom torch.fx._lazy_graph_module import _force_skip_lazy_graph_module 15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.inductor_utils import HAS_CUDA 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Workerrequires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Workerclass Seq(torch.nn.Module): 22*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 23*da0073e9SAndroid Build Coastguard Worker super().__init__() 24*da0073e9SAndroid Build Coastguard Worker self.layers = torch.nn.Sequential( 25*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(10, 10), 26*da0073e9SAndroid Build Coastguard Worker torch.nn.ReLU(), 27*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(10, 10), 28*da0073e9SAndroid Build Coastguard Worker torch.nn.Sigmoid(), 29*da0073e9SAndroid Build Coastguard Worker ) 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 32*da0073e9SAndroid Build Coastguard Worker return self.layers(x) 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Workerclass Conv_Bn_Relu(torch.nn.Module): 36*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_channels, out_channels, **kwargs): 37*da0073e9SAndroid Build Coastguard Worker super().__init__() 38*da0073e9SAndroid Build Coastguard Worker self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 39*da0073e9SAndroid Build Coastguard Worker self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) 40*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 43*da0073e9SAndroid Build Coastguard Worker return self.relu(self.bn(self.conv(x))) 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Workerclass TestOptimizations(torch._dynamo.test_case.TestCase): 47*da0073e9SAndroid Build Coastguard Worker def test_example_inputs(self): 48*da0073e9SAndroid Build Coastguard Worker def fn(a, bc, d): 49*da0073e9SAndroid Build Coastguard Worker b, c = bc 50*da0073e9SAndroid Build Coastguard Worker return a / d - b / c 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker def compiler_fn(graph, example_inputs): 53*da0073e9SAndroid Build Coastguard Worker nonlocal r1 54*da0073e9SAndroid Build Coastguard Worker r1 = graph(*example_inputs)[0] 55*da0073e9SAndroid Build Coastguard Worker return graph.forward 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker a = torch.empty(2).fill_(1) 58*da0073e9SAndroid Build Coastguard Worker b = torch.empty(2).fill_(2) 59*da0073e9SAndroid Build Coastguard Worker c = torch.empty(2).fill_(3) 60*da0073e9SAndroid Build Coastguard Worker d = 4 61*da0073e9SAndroid Build Coastguard Worker r1 = None 62*da0073e9SAndroid Build Coastguard Worker r2 = fn(a, (b, c), d) 63*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn) 64*da0073e9SAndroid Build Coastguard Worker r3 = opt_fn(a, (b, c), d) 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(r1) 67*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1.size(), r2.size()) 68*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1.stride(), r2.stride()) 69*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1.dtype, r2.dtype) 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1.size(), r3.size()) 72*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1.stride(), r3.stride()) 73*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1.dtype, r3.dtype) 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker def test_example_inputs_runtime_use(self): 76*da0073e9SAndroid Build Coastguard Worker def fn(a, bc, d): 77*da0073e9SAndroid Build Coastguard Worker b, c = bc 78*da0073e9SAndroid Build Coastguard Worker return a / d - b / c 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker def compiler_fn(graph, example_inputs): 81*da0073e9SAndroid Build Coastguard Worker def fwd(*args): 82*da0073e9SAndroid Build Coastguard Worker nonlocal r1 83*da0073e9SAndroid Build Coastguard Worker r = graph.forward(*args) 84*da0073e9SAndroid Build Coastguard Worker r1 = r[0] 85*da0073e9SAndroid Build Coastguard Worker return r 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker return fwd 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker a = torch.empty(2).fill_(1) 90*da0073e9SAndroid Build Coastguard Worker b = torch.empty(2).fill_(2) 91*da0073e9SAndroid Build Coastguard Worker c = torch.empty(2).fill_(3) 92*da0073e9SAndroid Build Coastguard Worker d = 4 93*da0073e9SAndroid Build Coastguard Worker r1 = None 94*da0073e9SAndroid Build Coastguard Worker r2 = fn(a, (b, c), d) 95*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn) 96*da0073e9SAndroid Build Coastguard Worker r3 = opt_fn(a, (b, c), d) 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(r1) 99*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(r1, r2)) 100*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(r1, r3)) 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker def _check_backend_works(self, backend, options=None): 103*da0073e9SAndroid Build Coastguard Worker model = Seq().eval() 104*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 10) 105*da0073e9SAndroid Build Coastguard Worker r1 = model(input) 106*da0073e9SAndroid Build Coastguard Worker r2 = torch.compile(model, backend=backend, options=options)(input) 107*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(r1, r2.float(), tol=0.01)) 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker def test_eager(self): 110*da0073e9SAndroid Build Coastguard Worker self._check_backend_works("eager") 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker def test_eager_noexcept(self): 113*da0073e9SAndroid Build Coastguard Worker self._check_backend_works("eager_noexcept") 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker @_force_skip_lazy_graph_module() 116*da0073e9SAndroid Build Coastguard Worker def test_torchscript(self): 117*da0073e9SAndroid Build Coastguard Worker self._check_backend_works("ts") 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker def test_aot_eager(self): 120*da0073e9SAndroid Build Coastguard Worker self._check_backend_works("aot_eager") 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker def test_aot_eager_decomp_partition(self): 123*da0073e9SAndroid Build Coastguard Worker self._check_backend_works("aot_eager_decomp_partition") 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker @_force_skip_lazy_graph_module() 126*da0073e9SAndroid Build Coastguard Worker def test_aot_ts(self): 127*da0073e9SAndroid Build Coastguard Worker self._check_backend_works("aot_ts") 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker @requires_cuda 130*da0073e9SAndroid Build Coastguard Worker def test_aot_cudagraphs(self): 131*da0073e9SAndroid Build Coastguard Worker self._check_backend_works("cudagraphs") 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime") 134*da0073e9SAndroid Build Coastguard Worker def test_onnxrt(self): 135*da0073e9SAndroid Build Coastguard Worker self._check_backend_works("onnxrt") 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not has_tvm(), "requires tvm") 138*da0073e9SAndroid Build Coastguard Worker def test_tvm(self): 139*da0073e9SAndroid Build Coastguard Worker self._check_backend_works("tvm") 140*da0073e9SAndroid Build Coastguard Worker self._check_backend_works("tvm", options={"scheduler": None}) 141*da0073e9SAndroid Build Coastguard Worker self._check_backend_works("tvm", options={"opt_level": 0}) 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker def test_list_backends(self): 144*da0073e9SAndroid Build Coastguard Worker self.assertIn("inductor", torch._dynamo.list_backends()) 145*da0073e9SAndroid Build Coastguard Worker self.assertIn("inductor", torch._dynamo.list_backends(exclude_tags=None)) 146*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("eager", torch._dynamo.list_backends()) 147*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("eager", torch._dynamo.list_backends(exclude_tags=["debug"])) 148*da0073e9SAndroid Build Coastguard Worker self.assertIn("eager", torch._dynamo.list_backends(exclude_tags=[])) 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Workerclass NormalizeIRTests(torch._dynamo.test_case.TestCase): 152*da0073e9SAndroid Build Coastguard Worker def test_inplace_normalize(self): 153*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 154*da0073e9SAndroid Build Coastguard Worker x = torch.cos(a) 155*da0073e9SAndroid Build Coastguard Worker x += b 156*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10) 159*da0073e9SAndroid Build Coastguard Worker b = torch.randn(10).to(torch.float64) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker ref = fn(a, b) 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker optimized_fn = torch._dynamo.optimize("aot_eager")(fn) 164*da0073e9SAndroid Build Coastguard Worker res = optimized_fn(a, b) 165*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Workerclass MPSNotSupportedTest(torch._dynamo.test_case.TestCase): 169*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.backends.mps.is_available(), "requires mps") 170*da0073e9SAndroid Build Coastguard Worker def test_mps_not_supported(self): 171*da0073e9SAndroid Build Coastguard Worker model = Seq().to("mps") 172*da0073e9SAndroid Build Coastguard Worker example_input = torch.randn(1, 10).to("mps") 173*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 174*da0073e9SAndroid Build Coastguard Worker RuntimeError, 175*da0073e9SAndroid Build Coastguard Worker lambda: torch.compile(model, backend="inductor")(example_input), 176*da0073e9SAndroid Build Coastguard Worker ) 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Workerclass TestExplainWithBackend(torch._dynamo.test_case.TestCase): 180*da0073e9SAndroid Build Coastguard Worker def test_explain_with_backend(self): 181*da0073e9SAndroid Build Coastguard Worker def fn3(x): 182*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 183*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 184*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 185*da0073e9SAndroid Build Coastguard Worker return x 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker def fn2(x): 188*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 189*da0073e9SAndroid Build Coastguard Worker x = fn3(x) 190*da0073e9SAndroid Build Coastguard Worker x = torch.cos(x) 191*da0073e9SAndroid Build Coastguard Worker return x 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker def fn1(x): 194*da0073e9SAndroid Build Coastguard Worker x = torch.tan(x) 195*da0073e9SAndroid Build Coastguard Worker x = fn2(x) 196*da0073e9SAndroid Build Coastguard Worker x = torch.tan(x) 197*da0073e9SAndroid Build Coastguard Worker return x 198*da0073e9SAndroid Build Coastguard Worker 199*da0073e9SAndroid Build Coastguard Worker def fn(x): 200*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 201*da0073e9SAndroid Build Coastguard Worker x = fn1(x) 202*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 203*da0073e9SAndroid Build Coastguard Worker return x 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker # Wrap TorchInductor with explain backend 206*da0073e9SAndroid Build Coastguard Worker eb = ExplainWithBackend("inductor") 207*da0073e9SAndroid Build Coastguard Worker optimized_fn = torch.compile(fn, backend=eb) 208*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.randn(5) 209*da0073e9SAndroid Build Coastguard Worker result = optimized_fn(input_tensor) 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker # Check that fn still produces the same output when wrapped by ExplainWithBackend 212*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(result, fn(input_tensor))) 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker # Verify ExplainOutput object contents, output might change but make sure these fields are present 215*da0073e9SAndroid Build Coastguard Worker explain_output = eb.output() 216*da0073e9SAndroid Build Coastguard Worker explain_str = str(explain_output) 217*da0073e9SAndroid Build Coastguard Worker self.assertIn("Graph Count", explain_str) 218*da0073e9SAndroid Build Coastguard Worker self.assertIn("Graph Break Count", explain_str) 219*da0073e9SAndroid Build Coastguard Worker self.assertIn("Op Count", explain_str) 220*da0073e9SAndroid Build Coastguard Worker self.assertIn("Break Reasons", explain_str) 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker # Verify that for the given functions above, we report the correct number of graphs, graph breaks, and ops 223*da0073e9SAndroid Build Coastguard Worker self.assertEqual(8, explain_output.graph_count) 224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(7, explain_output.graph_break_count) 225*da0073e9SAndroid Build Coastguard Worker self.assertEqual(8, explain_output.op_count) 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Workerclass TestCustomBackendAPI(torch._dynamo.test_case.TestCase): 229*da0073e9SAndroid Build Coastguard Worker """Test APIs documented by https://pytorch.org/docs/main/torch.compiler_custom_backends.html""" 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker def test_register_backend_api(self): 232*da0073e9SAndroid Build Coastguard Worker from torch._dynamo import register_backend 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker backend_run = False 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker @register_backend 237*da0073e9SAndroid Build Coastguard Worker def my_custom_backend(gm, example_inputs): 238*da0073e9SAndroid Build Coastguard Worker nonlocal backend_run 239*da0073e9SAndroid Build Coastguard Worker backend_run = True 240*da0073e9SAndroid Build Coastguard Worker return gm.forward 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker def f(x): 243*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker opt_f = torch.compile(f, backend="my_custom_backend") 246*da0073e9SAndroid Build Coastguard Worker opt_f(torch.randn(3, 3)) 247*da0073e9SAndroid Build Coastguard Worker self.assertTrue(backend_run) 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker def test_aot_autograd_api(self): 250*da0073e9SAndroid Build Coastguard Worker from functorch.compile import make_boxed_func 251*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.backends.common import aot_autograd 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker backend_run = False 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker def my_compiler(gm, example_inputs): 256*da0073e9SAndroid Build Coastguard Worker nonlocal backend_run 257*da0073e9SAndroid Build Coastguard Worker backend_run = True 258*da0073e9SAndroid Build Coastguard Worker return make_boxed_func(gm.forward) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker my_backend = aot_autograd(fw_compiler=my_compiler) 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker def f(x): 263*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker opt_f = torch.compile(f, backend=my_backend) 266*da0073e9SAndroid Build Coastguard Worker opt_f(torch.randn(3, 3)) 267*da0073e9SAndroid Build Coastguard Worker self.assertTrue(backend_run) 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker def test_lookup_backend(self): 270*da0073e9SAndroid Build Coastguard Worker from torch._dynamo import list_backends, lookup_backend 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker backends = list_backends() 273*da0073e9SAndroid Build Coastguard Worker backend_run = False 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker def my_compiler(gm, example_inputs): 276*da0073e9SAndroid Build Coastguard Worker nonlocal backend_run 277*da0073e9SAndroid Build Coastguard Worker backend_run = True 278*da0073e9SAndroid Build Coastguard Worker try: 279*da0073e9SAndroid Build Coastguard Worker trt_compiled = lookup_backend("tensorrt")(gm, example_inputs) 280*da0073e9SAndroid Build Coastguard Worker if trt_compiled is not None: 281*da0073e9SAndroid Build Coastguard Worker return trt_compiled 282*da0073e9SAndroid Build Coastguard Worker except Exception: 283*da0073e9SAndroid Build Coastguard Worker pass 284*da0073e9SAndroid Build Coastguard Worker # first backend failed, try something else... 285*da0073e9SAndroid Build Coastguard Worker try: 286*da0073e9SAndroid Build Coastguard Worker inductor_compiled = lookup_backend("inductor")(gm, example_inputs) 287*da0073e9SAndroid Build Coastguard Worker if inductor_compiled is not None: 288*da0073e9SAndroid Build Coastguard Worker return inductor_compiled 289*da0073e9SAndroid Build Coastguard Worker except Exception: 290*da0073e9SAndroid Build Coastguard Worker pass 291*da0073e9SAndroid Build Coastguard Worker return gm.forward 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker def f(x): 294*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker opt_f = torch.compile(f, backend=my_compiler) 297*da0073e9SAndroid Build Coastguard Worker opt_f(torch.randn(3, 3)) 298*da0073e9SAndroid Build Coastguard Worker self.assertTrue(backend_run) 299*da0073e9SAndroid Build Coastguard Worker 300*da0073e9SAndroid Build Coastguard Worker def test_lookup_custom_backend(self): 301*da0073e9SAndroid Build Coastguard Worker from torch._dynamo import list_backends 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker backends_group = "torch_dynamo_backends" 304*da0073e9SAndroid Build Coastguard Worker name = "mycustombackend" 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker mock_3_9 = MagicMock() 307*da0073e9SAndroid Build Coastguard Worker mock_3_9.load.return_value = lambda: "mocked 3.9" 308*da0073e9SAndroid Build Coastguard Worker mock_3_9.name = name 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker mock_3_10 = MagicMock() 311*da0073e9SAndroid Build Coastguard Worker mock_3_10.load.return_value = lambda: "mocked 3.10" 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker def mock_eps(group=None): 314*da0073e9SAndroid Build Coastguard Worker if sys.version_info < (3, 10): 315*da0073e9SAndroid Build Coastguard Worker return {backends_group: [mock_3_9]} 316*da0073e9SAndroid Build Coastguard Worker else: 317*da0073e9SAndroid Build Coastguard Worker assert group == backends_group, group 318*da0073e9SAndroid Build Coastguard Worker mock_group = MagicMock() 319*da0073e9SAndroid Build Coastguard Worker mock_group.names = [name] 320*da0073e9SAndroid Build Coastguard Worker mock_group[name] = mock_3_10 321*da0073e9SAndroid Build Coastguard Worker # mock_group[name].load.return_value = lambda: "mocked 3.10" 322*da0073e9SAndroid Build Coastguard Worker return mock_group 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker with patch("importlib.metadata.entry_points", mock_eps): 325*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.backends import registry 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker registry._lazy_import.cache_clear() 328*da0073e9SAndroid Build Coastguard Worker registry._discover_entrypoint_backends.cache_clear() 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker backends = list_backends() 331*da0073e9SAndroid Build Coastguard Worker assert name in backends, (name, backends) 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Worker def test_backend_recompilation(self): 334*da0073e9SAndroid Build Coastguard Worker def fn(x): 335*da0073e9SAndroid Build Coastguard Worker return x + x 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker input = torch.tensor(2.0) 338*da0073e9SAndroid Build Coastguard Worker 339*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile( 340*da0073e9SAndroid Build Coastguard Worker fn, backend="inductor", options={"_raise_error_for_testing": False} 341*da0073e9SAndroid Build Coastguard Worker ) 342*da0073e9SAndroid Build Coastguard Worker opt_fn(input) 343*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed): 344*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile( 345*da0073e9SAndroid Build Coastguard Worker fn, backend="inductor", options={"_raise_error_for_testing": True} 346*da0073e9SAndroid Build Coastguard Worker ) 347*da0073e9SAndroid Build Coastguard Worker opt_fn(input) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 351*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 352*da0073e9SAndroid Build Coastguard Worker 353*da0073e9SAndroid Build Coastguard Worker run_tests() 354