# mypy: allow-untyped-defs
import contextlib
import dis
import functools
import logging
import os.path
import random
import re
import sys
import types
import unittest
from typing import List, Optional, Sequence, Union
from unittest.mock import patch

import torch
from torch import fx
from torch._dynamo.output_graph import OutputGraph

from . import config, eval_frame, optimize_assert, reset
from .bytecode_transformation import (
    create_instruction,
    debug_checks,
    is_generator,
    transform_code_object,
)
from .guards import CheckFunctionManager, CompileId, GuardedCode
from .utils import same


np: Optional[types.ModuleType] = None
try:
    import numpy as np
except ModuleNotFoundError:
    np = None


unsupported = eval_frame.unsupported
three = 3

log = logging.getLogger(__name__)


def clone_me(x):
    if x is None:
        return None
    return x.detach().clone().requires_grad_(x.requires_grad)


def remove_optimized_module_prefix(name) -> str:
    return re.sub(r"^_orig_mod[.]", "", name)


def collect_results(model, prediction, loss, example_inputs):
    results = []
    results.append(prediction)
    results.append(loss)
    # if isinstance(loss, torch.Tensor) and loss.item() > 1:
    #     log.warning(
    #         f"High loss value alert - {loss:.2f}. Can result in unstable gradients."
    #     )

    grads = {}
    params = {}
    for name, param in model.named_parameters():
        if isinstance(model, eval_frame.OptimizedModule):
            name = remove_optimized_module_prefix(name)
        param_copy = param
        grad = param.grad
        # Treat None and zero grad as the same
        if param.grad is None:
            grad = torch.zeros_like(param)
        grads[name + ".grad"] = grad
        params[name] = param_copy
    results.append(grads)
    results.append(params)
    buffers = {}
    for name, buffer in model.named_buffers():
        if isinstance(model, eval_frame.OptimizedModule):
            name = remove_optimized_module_prefix(name)
        buffers[name] = buffer
    results.append(buffers)
    for example in example_inputs:
        if isinstance(example, (tuple, list)):
            for inp in example:
                if isinstance(inp, torch.Tensor):
                    results.append(inp.grad)
        else:
            if isinstance(example, torch.Tensor):
                results.append(example.grad)
    return results


def requires_bwd_pass(out):
    if isinstance(out, torch.Tensor):
        return out.requires_grad
    elif isinstance(out, (list, tuple)):
        return any(requires_bwd_pass(x) for x in out)
    elif out is None:
        return False
    elif isinstance(out, int):
        return False
    raise NotImplementedError("Don't know how to check requires_grad of", type(out))


def reduce_to_scalar_loss(out):
    """Reduce the output of a model to a scalar loss."""
    if isinstance(out, torch.Tensor):
        # Mean does not work on integer tensors
        return out.sum() / out.numel()
    elif isinstance(out, (list, tuple)):
        return sum(reduce_to_scalar_loss(x) for x in out) / len(out)
    elif type(out).__name__ in (
        "MaskedLMOutput",
        "Seq2SeqLMOutput",
        "CausalLMOutputWithCrossAttentions",
    ):
        return reduce_to_scalar_loss(out.logits)
    elif type(out).__name__ == "SquashedNormal":
        return out.mean.sum()
    elif isinstance(out, dict):
        return sum(reduce_to_scalar_loss(value) for value in out.values()) / len(
            out.keys()
        )
    raise NotImplementedError("Don't know how to reduce", type(out))


def debug_dir() -> str:
    path = os.path.join(os.path.dirname(__file__), "../debug")
    if not os.path.exists(path):
        os.mkdir(path)
    return path
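

# A minimal sketch (hypothetical helper, not used elsewhere in this module) of
# how the helpers above typically fit together: run the model, reduce its
# output to a scalar loss, backprop when gradients are required, then gather
# everything that `same` can compare against a second (e.g. compiled) run.
def _example_collect_eager_results(model, example_inputs):
    prediction = model(*example_inputs)
    loss = reduce_to_scalar_loss(prediction)
    if requires_bwd_pass(prediction):
        loss.backward()
    return collect_results(model, prediction, loss, example_inputs)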


def debug_dump(name, code: types.CodeType, extra="") -> None:
    with open(os.path.join(debug_dir(), name), "w") as fd:
        fd.write(
            f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n"
        )


def debug_insert_nops(
    frame, cache_size, hooks, _, *, skip: int = 0
) -> Optional[GuardedCode]:
    """Used to debug jump updates."""

    def insert_nops(instructions, code_options):
        instructions.insert(0, create_instruction("NOP"))
        instructions.insert(0, create_instruction("NOP"))

    if is_generator(frame.f_code):
        return None

    debug_checks(frame.f_code)
    code = transform_code_object(frame.f_code, insert_nops)
    graph = OutputGraph(
        code_options={},
        compiler_fn=None,
        root_tx=None,
        export=False,
        export_constraints=None,
        frame_state={"_id": 0},
        # TODO: shouldn't this be f_locals/f_globals from frame?
        local_scope=locals(),
        global_scope=globals(),
        f_code=frame.f_code,
    )

    return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))


class CompileCounter:
    """Backend that counts compiled frames and the call_* ops they contain."""

    def __init__(self):
        self.frame_count = 0
        self.op_count = 0

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.frame_count += 1
        for node in gm.graph.nodes:
            if "call" in node.op:
                self.op_count += 1
        return gm.forward

    def clear(self):
        self.frame_count = 0
        self.op_count = 0


class CompileCounterWithBackend:
    """Like CompileCounter, but delegates compilation to a real backend and
    records the graphs it sees."""

    def __init__(self, backend):
        self.frame_count = 0
        self.op_count = 0
        self.backend = backend
        self.graphs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        from .backends.registry import lookup_backend

        self.frame_count += 1
        for node in gm.graph.nodes:
            if "call" in node.op:
                self.op_count += 1
        self.graphs.append(gm)
        return lookup_backend(self.backend)(gm, example_inputs)


# Equivalent to backend="eager", but also records graphs that
# we can assert on
class EagerAndRecordGraphs:
    def __init__(self):
        self.graphs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.graphs.append(gm)
        return gm.forward


def strip_comment(code) -> str:
    code = str(code)
    return re.sub(r"(?m)^ *#.*\n?", "", code)


def remove_trailing_space(code) -> str:
    return "\n".join([line.rstrip() for line in code.split("\n")])


def normalize_gm(gm_str) -> str:
    # Strip comments, since they contain file paths that can differ from
    # system to system.
    return remove_trailing_space(strip_comment(gm_str))


def empty_line_normalizer(code: str) -> str:
    """Normalize code by collapsing runs of blank lines."""
    normal_code = re.sub(r"[\r\n]+", "\n", code)
    return normal_code
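

# A minimal sketch of how `EagerAndRecordGraphs` pairs with `normalize_gm` in
# tests (hypothetical helper; assumes `fn` is traceable by Dynamo and that
# GraphModule.print_readable(print_output=False) returns the printed string).
def _example_record_normalized_graphs(fn, *args):
    backend = EagerAndRecordGraphs()
    torch.compile(fn, backend=backend)(*args)
    return [
        normalize_gm(gm.print_readable(print_output=False)) for gm in backend.graphs
    ]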


def standard_test(
    self,
    fn,
    nargs,
    expected_ops=None,
    expected_ops_dynamic=None,
    expected_frame_count=1,
):
    if not config.assume_static_by_default and expected_ops_dynamic is not None:
        expected_ops = expected_ops_dynamic

    actual = CompileCounter()

    args1 = [torch.randn(10, 10) for _ in range(nargs)]
    args2 = [torch.randn(10, 10) for _ in range(nargs)]
    correct1 = fn(*args1)
    correct2 = fn(*args2)
    reset()
    opt_fn = optimize_assert(actual)(fn)
    val1a = opt_fn(*args1)
    val2a = opt_fn(*args2)
    val1b = opt_fn(*args1)
    val2b = opt_fn(*args2)
    reset()
    self.assertTrue(same(val1a, correct1))
    self.assertTrue(same(val1b, correct1))
    self.assertTrue(same(val2a, correct2))
    self.assertTrue(same(val2b, correct2))
    self.assertEqual(actual.frame_count, expected_frame_count)
    if expected_ops is not None:
        self.assertEqual(actual.op_count, expected_ops)


def dummy_fx_compile(gm: fx.GraphModule, example_inputs):
    return gm.forward


def format_speedup(speedup, pvalue, is_correct=True, pvalue_threshold=0.1):
    if not is_correct:
        return "ERROR"
    if pvalue > pvalue_threshold:
        return f"{speedup:.3f}x SAME"
    return f"{speedup:.3f}x p={pvalue:.2f}"


def rand_strided(
    size: Sequence[int],
    stride: Sequence[int],
    dtype: torch.dtype = torch.float32,
    device: Union[str, torch.device] = "cpu",
    extra_size: int = 0,
):
    needed_size = (
        sum((shape - 1) * stride for shape, stride in zip(size, stride))
        + 1
        + extra_size
    )
    if dtype.is_floating_point:
        if dtype.itemsize == 1:
            # The normal distribution kernel is not implemented for fp8;
            # work around that by creating a fp16 tensor and then casting.
            buffer = torch.randn(needed_size, dtype=torch.float16, device=device).to(
                dtype=dtype
            )
        else:
            buffer = torch.randn(needed_size, dtype=dtype, device=device)
    else:
        buffer = torch.zeros(size=[needed_size], dtype=dtype, device=device)
    return torch.as_strided(buffer, size, stride)
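

# A minimal sketch of `rand_strided` (hypothetical helper): it can materialize
# layouts that `torch.randn` alone cannot express, such as a column-major
# buffer or an overlapped view whose rows share storage.
def _example_rand_strided_layouts():
    col_major = rand_strided((4, 3), (1, 4))  # column-major 4x3
    overlapped = rand_strided((4, 4), (2, 1))  # consecutive rows overlap by 2
    assert col_major.stride() == (1, 4)
    assert overlapped.stride() == (2, 1)
    return col_major, overlapped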


def _make_fn_with_patches(fn, *patches):
    @functools.wraps(fn)
    def _fn(*args, **kwargs):
        with contextlib.ExitStack() as stack:
            for module, attr, val in patches:
                stack.enter_context(patch.object(module, attr, val))

            return fn(*args, **kwargs)

    return _fn


def make_test_cls_with_patches(
    cls, cls_prefix, fn_suffix, *patches, xfail_prop=None, decorator=lambda x: x
):
    DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
    DummyTestClass.__qualname__ = DummyTestClass.__name__

    for name in dir(cls):
        if name.startswith("test_"):
            fn = getattr(cls, name)
            if not callable(fn):
                setattr(DummyTestClass, name, getattr(cls, name))
                continue
            new_name = f"{name}{fn_suffix}"
            new_fn = _make_fn_with_patches(fn, *patches)
            new_fn.__name__ = new_name
            if xfail_prop is not None and hasattr(fn, xfail_prop):
                new_fn = unittest.expectedFailure(new_fn)
            setattr(DummyTestClass, new_name, decorator(new_fn))
        # NB: Doesn't handle slots correctly, but whatever
        elif not hasattr(DummyTestClass, name):
            setattr(DummyTestClass, name, getattr(cls, name))

    return DummyTestClass


# Test Python 3.11+ specific features
def skipIfNotPy311(fn):
    if sys.version_info >= (3, 11):
        return fn
    return unittest.skip("requires Python 3.11+")(fn)


def skipIfNotPy312(fn):
    if sys.version_info >= (3, 12):
        return fn
    return unittest.skip("requires Python 3.12+")(fn)


def xfailIfPy312(fn):
    if sys.version_info >= (3, 12):
        return unittest.expectedFailure(fn)
    return fn


def skipIfPy312(fn):
    if sys.version_info >= (3, 12):
        return unittest.skip("not supported on Python 3.12+")(fn)
    return fn


def requiresPy310(fn):
    if sys.version_info >= (3, 10):
        return fn
    return unittest.skip("requires Python 3.10+")(fn)


# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
# and test/dynamo/test_dynamic_shapes.py
def expectedFailureDynamic(fn):
    fn._expected_failure_dynamic = True
    return fn


# Controls tests generated in test/inductor/test_torchinductor_codegen_dynamic_shapes.py
def expectedFailureCodegenDynamic(fn):
    fn._expected_failure_codegen_dynamic = True
    return fn


# Controls tests generated in test/inductor/test_cpp_wrapper.py
def expectedFailureDynamicWrapper(fn):
    fn._expected_failure_dynamic_wrapper = True
    return fn


def reset_rng_state(use_xla=False):
    torch.manual_seed(1337)
    random.seed(1337)
    if np:
        np.random.seed(1337)
    if use_xla:
        import torch_xla.core.xla_model as xm

        xm.set_rng_state(1337, str(xm.xla_device()))
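

# A minimal usage sketch for `make_test_cls_with_patches` (hypothetical names;
# the config knob mirrors how dynamic-shapes test suites generate patched
# variants of an existing test class, turning tests marked with
# `expectedFailureDynamic` into expected failures).
def _example_make_dynamic_cls(cls):
    return make_test_cls_with_patches(
        cls,
        "DynamicShapes",
        "_dynamic_shapes",
        (config, "assume_static_by_default", False),
        xfail_prop="_expected_failure_dynamic",
    )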