1# Owner(s): ["module: inductor"] 2import torch 3from torch import _dynamo as dynamo, _inductor as inductor 4from torch._inductor.test_case import run_tests, TestCase 5from torch._inductor.utils import gen_gm_and_inputs 6from torch.fx import symbolic_trace 7from torch.fx.experimental.proxy_tensor import make_fx 8from torch.testing._internal.inductor_utils import HAS_CPU 9 10 11class MyModule(torch.nn.Module): 12 def __init__(self) -> None: 13 super().__init__() 14 self.a = torch.nn.Linear(10, 10) 15 self.b = torch.nn.Linear(10, 10) 16 self.relu = torch.nn.ReLU() 17 18 def forward(self, x): 19 x = self.relu(self.a(x)) 20 x = torch.sigmoid(self.b(x)) 21 return x 22 23 24class MyModule2(MyModule): 25 def forward(self, x): # takes a dict of list 26 a, b = x["key"] 27 return {"result": super().forward(a) + b} 28 29 30class MyModule3(MyModule): 31 def forward(self, x): 32 return (super().forward(x),) 33 34 35class TestStandaloneInductor(TestCase): 36 """ 37 These test check that you can call TorchInductor directly without 38 going through TorchDynamo. 39 """ 40 41 def test_inductor_via_fx(self): 42 mod = MyModule3().eval() 43 inp = torch.randn(10) 44 correct = mod(inp) 45 mod_opt = inductor.compile(symbolic_trace(mod), [inp]) 46 actual = mod_opt(inp) 47 self.assertEqual(actual, correct) 48 49 def test_inductor_via_fx_tensor_return(self): 50 mod = MyModule().eval() 51 inp = torch.randn(10) 52 correct = mod(inp) 53 mod_opt = inductor.compile(symbolic_trace(mod), [inp]) 54 actual = mod_opt(inp) 55 self.assertEqual(actual, correct) 56 57 def test_inductor_via_fx_dict_input(self): 58 mod = MyModule2().eval() 59 inp = {"key": [torch.randn(10), torch.randn(10)]} 60 correct = mod(inp) 61 mod_opt = inductor.compile(symbolic_trace(mod), [inp]) 62 actual = mod_opt(inp) 63 self.assertEqual(actual, correct) 64 65 def test_inductor_via_make_fx(self): 66 mod = MyModule().eval() 67 inp = torch.randn(10) 68 correct = mod(inp) 69 mod_opt = inductor.compile(make_fx(mod)(inp), [inp]) 70 actual = mod_opt(inp) 71 self.assertEqual(actual, correct) 72 73 def test_inductor_via_bare_module(self): 74 mod = MyModule3().eval() 75 inp = torch.randn(10) 76 correct = mod(inp) 77 # no FX graph at all (mod must return list/tuple in this case) 78 mod_opt = inductor.compile(mod, [inp]) 79 actual = mod_opt(inp) 80 self.assertEqual(actual, correct) 81 82 def test_inductor_via_export1(self): 83 mod = MyModule3().eval() 84 inp = torch.randn(10) 85 correct = mod(inp) 86 gm, guards = dynamo.export(mod, inp, aten_graph=True) 87 mod_opt = inductor.compile(gm, [inp]) 88 actual = mod_opt(inp) 89 self.assertEqual(actual, correct) 90 91 def test_inductor_via_export2(self): 92 mod = MyModule2().eval() 93 inp = {"key": [torch.randn(10), torch.randn(10)]} 94 correct = mod(inp) 95 gm, guards = dynamo.export(mod, inp) 96 mod_opt = inductor.compile(gm, [inp]) 97 actual = mod_opt(inp) 98 self.assertEqual(actual, correct) 99 100 def test_inductor_via_op_with_multiple_outputs(self): 101 x1 = torch.randn((2, 512, 128)) 102 x2 = [128] 103 x3 = torch.randn(128) 104 x4 = torch.randn((128,)) 105 x5 = 1e-6 106 mod, inp = gen_gm_and_inputs( 107 torch.ops.aten.native_layer_norm.default, (x1, x2, x3, x4, x5), {} 108 ) 109 mod_opt = inductor.compile(mod, inp) 110 self.assertEqual(mod(*inp), mod_opt(*inp)) 111 112 113if __name__ == "__main__": 114 if HAS_CPU: 115 run_tests() 116