1# mypy: allow-untyped-defs 2import dataclasses 3import io 4import logging 5import os 6import re 7import shutil 8import subprocess 9import sys 10import tempfile 11import traceback 12from typing import Optional 13from unittest.mock import patch 14 15import torch 16import torch._dynamo 17import torch._dynamo.test_case 18from torch._dynamo.trace_rules import _as_posix_path 19from torch.utils._traceback import report_compile_source_on_error 20 21 22@dataclasses.dataclass 23class MinifierTestResult: 24 minifier_code: str 25 repro_code: str 26 27 def _get_module(self, t): 28 match = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t) 29 assert match is not None, "failed to find module" 30 r = match.group(0) 31 r = re.sub(r"\s+$", "\n", r, flags=re.MULTILINE) 32 r = re.sub(r"\n{3,}", "\n\n", r) 33 return r.strip() 34 35 def minifier_module(self): 36 return self._get_module(self.minifier_code) 37 38 def repro_module(self): 39 return self._get_module(self.repro_code) 40 41 42class MinifierTestBase(torch._dynamo.test_case.TestCase): 43 DEBUG_DIR = tempfile.mkdtemp() 44 45 @classmethod 46 def setUpClass(cls): 47 super().setUpClass() 48 cls._exit_stack.enter_context( # type: ignore[attr-defined] 49 torch._dynamo.config.patch(debug_dir_root=cls.DEBUG_DIR) 50 ) 51 # These configurations make new process startup slower. Disable them 52 # for the minification tests to speed them up. 53 cls._exit_stack.enter_context( # type: ignore[attr-defined] 54 torch._inductor.config.patch( 55 { 56 # https://github.com/pytorch/pytorch/issues/100376 57 "pattern_matcher": False, 58 # multiprocess compilation takes a long time to warmup 59 "compile_threads": 1, 60 # https://github.com/pytorch/pytorch/issues/100378 61 "cpp.vec_isa_ok": False, 62 } 63 ) 64 ) 65 66 @classmethod 67 def tearDownClass(cls): 68 if os.getenv("PYTORCH_KEEP_TMPDIR", "0") != "1": 69 shutil.rmtree(cls.DEBUG_DIR) 70 else: 71 print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}") 72 cls._exit_stack.close() # type: ignore[attr-defined] 73 74 def _gen_codegen_fn_patch_code(self, device, bug_type): 75 assert bug_type in ("compile_error", "runtime_error", "accuracy") 76 return f"""\ 77{torch._dynamo.config.codegen_config()} 78{torch._inductor.config.codegen_config()} 79torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r} 80""" 81 82 def _maybe_subprocess_run(self, args, *, isolate, cwd=None): 83 if not isolate: 84 assert len(args) >= 2, args 85 assert args[0] == "python3", args 86 if args[1] == "-c": 87 assert len(args) == 3, args 88 code = args[2] 89 args = ["-c"] 90 else: 91 assert len(args) >= 2, args 92 with open(args[1]) as f: 93 code = f.read() 94 args = args[1:] 95 96 # WARNING: This is not a perfect simulation of running 97 # the program out of tree. We only interpose on things we KNOW we 98 # need to handle for tests. If you need more stuff, you will 99 # need to augment this appropriately. 100 101 # NB: Can't use save_config because that will omit some fields, 102 # but we must save and reset ALL fields 103 dynamo_config = torch._dynamo.config.shallow_copy_dict() 104 inductor_config = torch._inductor.config.shallow_copy_dict() 105 try: 106 stderr = io.StringIO() 107 log_handler = logging.StreamHandler(stderr) 108 log = logging.getLogger("torch._dynamo") 109 log.addHandler(log_handler) 110 try: 111 prev_cwd = _as_posix_path(os.getcwd()) 112 if cwd is not None: 113 cwd = _as_posix_path(cwd) 114 os.chdir(cwd) 115 with patch("sys.argv", args), report_compile_source_on_error(): 116 exec(code, {"__name__": "__main__", "__compile_source__": code}) 117 rc = 0 118 except Exception: 119 rc = 1 120 traceback.print_exc(file=stderr) 121 finally: 122 log.removeHandler(log_handler) 123 if cwd is not None: 124 os.chdir(prev_cwd) # type: ignore[possibly-undefined] 125 # Make sure we don't leave buggy compiled frames lying 126 # around 127 torch._dynamo.reset() 128 finally: 129 torch._dynamo.config.load_config(dynamo_config) 130 torch._inductor.config.load_config(inductor_config) 131 132 # TODO: return a more appropriate data structure here 133 return subprocess.CompletedProcess( 134 args, 135 rc, 136 b"", 137 stderr.getvalue().encode("utf-8"), 138 ) 139 else: 140 if cwd is not None: 141 cwd = _as_posix_path(cwd) 142 return subprocess.run(args, capture_output=True, cwd=cwd, check=False) 143 144 # Run `code` in a separate python process. 145 # Returns the completed process state and the directory containing the 146 # minifier launcher script, if `code` outputted it. 147 def _run_test_code(self, code, *, isolate): 148 proc = self._maybe_subprocess_run( 149 ["python3", "-c", code], isolate=isolate, cwd=self.DEBUG_DIR 150 ) 151 152 print("test stdout:", proc.stdout.decode("utf-8")) 153 print("test stderr:", proc.stderr.decode("utf-8")) 154 repro_dir_match = re.search( 155 r"(\S+)minifier_launcher.py", proc.stderr.decode("utf-8") 156 ) 157 if repro_dir_match is not None: 158 return proc, repro_dir_match.group(1) 159 return proc, None 160 161 # Runs the minifier launcher script in `repro_dir` 162 def _run_minifier_launcher(self, repro_dir, isolate, *, minifier_args=()): 163 self.assertIsNotNone(repro_dir) 164 launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py")) 165 with open(launch_file) as f: 166 launch_code = f.read() 167 self.assertTrue(os.path.exists(launch_file)) 168 169 args = ["python3", launch_file, "minify", *minifier_args] 170 if not isolate: 171 args.append("--no-isolate") 172 launch_proc = self._maybe_subprocess_run(args, isolate=isolate, cwd=repro_dir) 173 print("minifier stdout:", launch_proc.stdout.decode("utf-8")) 174 stderr = launch_proc.stderr.decode("utf-8") 175 print("minifier stderr:", stderr) 176 self.assertNotIn("Input graph did not fail the tester", stderr) 177 178 return launch_proc, launch_code 179 180 # Runs the repro script in `repro_dir` 181 def _run_repro(self, repro_dir, *, isolate=True): 182 self.assertIsNotNone(repro_dir) 183 repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py")) 184 with open(repro_file) as f: 185 repro_code = f.read() 186 self.assertTrue(os.path.exists(repro_file)) 187 188 repro_proc = self._maybe_subprocess_run( 189 ["python3", repro_file], isolate=isolate, cwd=repro_dir 190 ) 191 print("repro stdout:", repro_proc.stdout.decode("utf-8")) 192 print("repro stderr:", repro_proc.stderr.decode("utf-8")) 193 return repro_proc, repro_code 194 195 # Template for testing code. 196 # `run_code` is the code to run for the test case. 197 # `patch_code` is the code to be patched in every generated file; usually 198 # just use this to turn on bugs via the config 199 def _gen_test_code(self, run_code, repro_after, repro_level): 200 return f"""\ 201import torch 202import torch._dynamo 203{_as_posix_path(torch._dynamo.config.codegen_config())} 204{_as_posix_path(torch._inductor.config.codegen_config())} 205torch._dynamo.config.repro_after = "{repro_after}" 206torch._dynamo.config.repro_level = {repro_level} 207torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}" 208{run_code} 209""" 210 211 # Runs a full minifier test. 212 # Minifier tests generally consist of 3 stages: 213 # 1. Run the problematic code 214 # 2. Run the generated minifier launcher script 215 # 3. Run the generated repro script 216 # 217 # If possible, you should run the test with isolate=False; use 218 # isolate=True only if the bug you're testing would otherwise 219 # crash the process 220 def _run_full_test( 221 self, run_code, repro_after, expected_error, *, isolate, minifier_args=() 222 ) -> Optional[MinifierTestResult]: 223 if isolate: 224 repro_level = 3 225 elif expected_error is None or expected_error == "AccuracyError": 226 repro_level = 4 227 else: 228 repro_level = 2 229 test_code = self._gen_test_code(run_code, repro_after, repro_level) 230 print("running test", file=sys.stderr) 231 test_proc, repro_dir = self._run_test_code(test_code, isolate=isolate) 232 if expected_error is None: 233 # Just check that there was no error 234 self.assertEqual(test_proc.returncode, 0) 235 self.assertIsNone(repro_dir) 236 return None 237 # NB: Intentionally do not test return code; we only care about 238 # actually generating the repro, we don't have to crash 239 self.assertIn(expected_error, test_proc.stderr.decode("utf-8")) 240 self.assertIsNotNone(repro_dir) 241 print("running minifier", file=sys.stderr) 242 minifier_proc, minifier_code = self._run_minifier_launcher( 243 repro_dir, isolate=isolate, minifier_args=minifier_args 244 ) 245 print("running repro", file=sys.stderr) 246 repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate) 247 self.assertIn(expected_error, repro_proc.stderr.decode("utf-8")) 248 self.assertNotEqual(repro_proc.returncode, 0) 249 return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code) 250