1# Owner(s): ["module: dynamo"] 2 3import inspect 4import io 5import os 6import tempfile 7from unittest.mock import patch 8 9import torch 10from torch._dynamo.test_case import run_tests, TestCase 11from torch._dynamo.testing import CompileCounter 12 13 14class ToyModel(torch.nn.Module): 15 def __init__(self) -> None: 16 super().__init__() 17 self.linear = torch.nn.Linear(10, 10) 18 self.relu = torch.nn.ReLU() 19 20 def forward(self, x): 21 return self.relu(self.linear(x)) 22 23 24class InPlaceCompilationTests(TestCase): 25 def test_compilation(self): 26 torch._dynamo.reset() 27 model = ToyModel() 28 cnt = CompileCounter() 29 model.compile(backend=cnt) 30 x = torch.randn(10, 10) 31 model(x) 32 self.assertEqual(cnt.frame_count, 1) 33 34 def test_overwrite_call_impl(self): 35 torch._dynamo.reset() 36 model = ToyModel() 37 self.assertTrue(model._compiled_call_impl is None) 38 model.compile() 39 self.assertTrue(model._compiled_call_impl is not None) 40 41 def test_save(self): 42 torch._dynamo.reset() 43 model = ToyModel() 44 model.compile() 45 model(torch.randn(1, 10)) 46 47 with tempfile.TemporaryDirectory() as tmpdirname: 48 torch.save(model, os.path.join(tmpdirname, "model.pt")) 49 loaded_model = torch.load(os.path.join(tmpdirname, "model.pt")) 50 loaded_model(torch.randn(1, 10)) 51 52 def test_state_dict_save(self): 53 torch._dynamo.reset() 54 model = ToyModel() 55 model.compile() 56 model(torch.randn(1, 10)) 57 with tempfile.TemporaryDirectory() as tmpdirname: 58 torch.save(model.state_dict(), os.path.join(tmpdirname, "model.pt")) 59 loaded_model = ToyModel() 60 loaded_model.load_state_dict( 61 torch.load(os.path.join(tmpdirname, "model.pt")) 62 ) 63 loaded_model(torch.randn(1, 10)) 64 65 def test_jit_save(self): 66 torch._dynamo.reset() 67 model = ToyModel() 68 model.compile() 69 model(torch.randn(1, 10)) 70 scripted_model = torch.jit.script(model) 71 with tempfile.TemporaryDirectory() as tmpdirname: 72 torch.jit.save(scripted_model, os.path.join(tmpdirname, "model.pt")) 73 loaded_model = torch.jit.load(os.path.join(tmpdirname, "model.pt")) 74 loaded_model(torch.randn(1, 10)) 75 76 def test_compilation_callback(self): 77 torch._dynamo.reset() 78 79 @torch._dynamo.on_compile_start 80 def start_callback(): 81 print("Compilation started.") 82 83 @torch._dynamo.on_compile_end 84 def end_callback(): 85 print("Compilation ended.") 86 87 mod = ToyModel() 88 x = torch.randn(10, 10) 89 90 with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: 91 opt_mod = torch.compile(backend="eager", fullgraph=True)(mod) 92 opt_mod(x) 93 printed_output = mock_stdout.getvalue().strip() 94 95 self.assertEqual(printed_output, "Compilation started.\nCompilation ended.") 96 97 def test_compile_eager_options(self): 98 @torch.compile(backend="eager", options={"foo": 2}) 99 def f(x): 100 return x + x 101 102 f(torch.randn(3)) 103 104 @torch.compile(backend="aot_eager", options={"foo": 2}) 105 def g(x): 106 return x + x 107 108 g(torch.randn(3)) 109 110 def test_compilation_callback_with_graph_break(self): 111 torch._dynamo.reset() 112 counter = 0 113 114 @torch._dynamo.on_compile_start 115 def start_callback(): 116 nonlocal counter 117 counter += 1 118 print(f"Counter = {counter}") 119 120 @torch._dynamo.on_compile_end 121 def end_callback(): 122 nonlocal counter 123 counter += 1 124 print(f"Counter = {counter}") 125 126 @torch.compile(backend="eager") 127 def fn(x): 128 x = x + 1 129 torch._dynamo.graph_break() 130 return torch.sin(x) 131 132 x = torch.randn(10, 10) 133 134 with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: 135 fn(x) 136 printed_output = mock_stdout.getvalue().strip() 137 138 self.assertEqual( 139 printed_output, "Counter = 1\nCounter = 2\nCounter = 3\nCounter = 4" 140 ) 141 142 143# The private variants of the below functions are extensively tested 144# So as long as the signatures match we're good 145class PublicTorchCompilerTests(TestCase): 146 def check_signature(self, public_fn_name, private_fn_name, private_namespace): 147 public_fn = getattr(torch.compiler, public_fn_name) 148 private_fn = getattr(private_namespace, private_fn_name) 149 150 public_sig = inspect.signature(public_fn) 151 private_sig = inspect.signature(private_fn) 152 153 self.assertEqual( 154 public_sig, 155 private_sig, 156 f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}", 157 ) 158 159 def test_dynamo_signatures(self): 160 function_names = [ 161 "reset", 162 "allow_in_graph", 163 "list_backends", 164 "assume_constant_result", 165 "disable", 166 ] 167 168 for fn_name in function_names: 169 self.check_signature(fn_name, fn_name, torch._dynamo) 170 171 172if __name__ == "__main__": 173 run_tests() 174