# mypy: allow-untyped-defs
# mypy: disable-error-code="method-assign"
import atexit
import copy
import cProfile
import functools
import getpass
import inspect
import itertools
import logging
import os
import re
import subprocess
import sys
import tempfile
import textwrap
from collections import Counter
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional, TypeVar

import torch
import torch._prims_common as utils
import torch._subclasses.meta_utils
from torch import Tensor
from torch._dynamo.testing import rand_strided
from torch._prims_common import is_float_dtype
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter

from . import config
from .utils import clone_inputs, get_debug_dir


log = logging.getLogger(__name__)

T = TypeVar("T")


inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()

if use_buck:
    import libfb.py.build_info


extra_deps = []
extra_imports = ""
if use_buck:
    extra_deps = [
        "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
    ]
    cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//")  # type: ignore[possibly-undefined]
    extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])


BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"]


class BuckTargetWriter:
    def __init__(self, filename):
        self.subdir, self.py_file = os.path.split(os.path.abspath(filename))
        self.target = self.py_file.replace(".py", "")

        # Get main_module path from fbcode
        self.path = f'{self.subdir.replace("/", ".")}.{self.target}'
        self.path = self.path[self.path.find("fbcode.") :]
        self.path = self.path[7:]

        # Get cmd line path
        tmp = self.subdir
        tmp = tmp[tmp.find("fbcode/") :][7:]
        self.cmd_line_path = f"//{tmp}:{self.target}"

    def build(self):
        extra_cpp_deps = "\n".join([f'        "{x}",' for x in extra_deps])
        return textwrap.dedent(
            f"""
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")

python_binary(
    name="{self.target}",
    srcs = ["{self.py_file}"],
    compile = False,
    deps = [
        "//caffe2:torch",
        "//caffe2/functorch:functorch",
        "//triton:triton",
        "{cur_target}",
    ],
    cpp_deps = [
{extra_cpp_deps}
    ],
    main_module = "{self.path}",
    par_style = "xar",
)
"""
        )

    def write(self, print_msg=True):
        target_file = os.path.join(self.subdir, "TARGETS")
        with open(target_file, "w") as fd:
            fd.write(self.build())
        # log.warning("Wrote isolation TARGETS file at %s", target_file)
        cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path]
        if print_msg:
            log.warning(
                "Found an example that reproduces the error. Run this cmd to repro - %s",
                " ".join(cmd_split),
            )
        return cmd_split
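

# Illustrative sketch (hypothetical path, fbcode-only): BuckTargetWriter is
# pointed at a generated repro script; write() drops a TARGETS file next to it
# and returns the buck2 command that runs the repro.  Not invoked anywhere in
# this module.
def _example_buck_target_writer():
    writer = BuckTargetWriter("/data/users/me/fbcode/scripts/repro.py")
    # Returns e.g. ["buck2", "run", "@mode/dev-nosan", "//scripts:repro"]
    return writer.write(print_msg=False)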


def minifier_dir():
    path = os.path.join(get_debug_dir(), "minifier")
    if path is None:
        path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}"
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    return path


MAX_CONSTANT_NUMEL_INLINE = 4


class NNModuleToString:
    safe_reprs = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.LayerNorm,
        torch.nn.Dropout,
        torch.nn.Softmax,
        torch.nn.ReLU,
        torch.nn.GELU,
        torch.nn.Identity,
        torch.nn.MaxPool2d,
        torch.nn.Embedding,
        torch.nn.Tanh,
        torch.nn.ConvTranspose1d,
        torch.nn.GLU,
        torch.nn.LSTM,
        torch.nn.Flatten,
        torch.nn.AdaptiveAvgPool2d,
    ]

    @staticmethod
    def can_convert_to_string(gm):
        cant_convert = set()
        for _, module in gm.named_children():
            if type(module) not in NNModuleToString.safe_reprs:
                cant_convert.add(module)

        if len(cant_convert) > 0:
            log.warning("We have not tested reprs of some modules - %s", cant_convert)
        # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
        return True

    @staticmethod
    def convert(gm):
        from torch.nn.modules.module import _addindent

        tab = " " * 4

        model_str = textwrap.dedent(
            """
            from torch.nn import *
            class Repro(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
            """
        )

        for module_name, module in gm.named_children():
            module_str = f"{module.__repr__()}"
            # module should be a core torch.nn.Module, so all parameters
            # should be on the same device.
            example_param = next(module.parameters(), None)
            if example_param is not None and example_param.is_cuda:
                module_str = f"{module_str}.cuda()"
            model_str += f"{tab*2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in gm._buffers.items():
            if buffer is None:
                continue
            # Serialize full data for small buffers
            if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE:
                from torch._tensor_str import PRINT_OPTS

                assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE
                tensor_str = repr(buffer)
            elif torch.is_floating_point(buffer):
                tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
            else:
                tensor_str = (
                    f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
                )
            if buffer.is_cuda:
                tensor_str = f"{tensor_str}.cuda()"
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n"

        for param_name, param in gm._parameters.items():
            if param is None:
                continue
            maybe_device = ""
            if param.is_cuda:
                maybe_device = ', device="cuda"'
            tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))"
            model_str += f"{tab*2}self.{param_name} = {tensor_str}\n"

        # TODO - Keep this code for now. But, I don't think we will need this.
        # attrs = dir(gm)
        # for attr in attrs:
        #     if "_tensor_constant" in attr:
        #         val = getattr(gm, attr)
        #         model_str += f"    {attr} = {val!r}\n"

        model_str += f"{_addindent(gm.code, 4)}\n"
        return model_str
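

# Illustrative sketch: NNModuleToString is normally fed the GraphModules that
# Dynamo produces, but any fx-traced module whose children come from safe_reprs
# works.  The toy module below is hypothetical.
def _example_nn_module_to_string():
    class Toy(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc = torch.nn.Linear(4, 4)
            self.act = torch.nn.ReLU()

        def forward(self, x):
            return self.act(self.fc(x))

    gm = torch.fx.symbolic_trace(Toy())
    if NNModuleToString.can_convert_to_string(gm):
        # Source text defining `class Repro(torch.nn.Module)`
        return NNModuleToString.convert(gm)
    return None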


@functools.lru_cache(None)  # subprocess is expensive
def _cuda_system_info_comment():
    if not torch.cuda.is_available():
        return "# torch.cuda.is_available()==False, no GPU info collected\n"

    model_str = "# CUDA Info: \n"
    try:
        cuda_version_out = subprocess.check_output(["nvcc", "--version"])
        cuda_version_lines = cuda_version_out.decode().split("\n")
        comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
        model_str += f"{comment}\n"
    except (FileNotFoundError, subprocess.CalledProcessError):
        model_str += "# nvcc not found\n"

    gpu_names = Counter(
        torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
    )

    model_str += "# GPU Hardware Info: \n"
    for name, count in gpu_names.items():
        model_str += f"# {name} : {count} \n"
    model_str += "\n"
    return model_str


def generate_config_string(*, stable_output=False):
    import torch._functorch.config
    import torch._inductor.config

    if stable_output:
        return "# config omitted due to stable_output=True"

    experimental_config = torch.fx.experimental._config.codegen_config()  # type: ignore[attr-defined]
    return f"""\
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
{torch._functorch.config.codegen_config()}
{experimental_config}
"""


def get_minifier_repro_path():
    return os.path.join(minifier_dir(), "minifier_launcher.py")


def helper_for_dump_minify(contents):
    minified_repro_path = get_minifier_repro_path()
    log.warning("Writing minified repro to:\n%s", minified_repro_path)

    if use_buck:
        BuckTargetWriter(minified_repro_path).write()
    try:
        with open(minified_repro_path, "w") as fd:
            fd.write(contents)

    except OSError as e:
        log.exception("")
        raise NotImplementedError(f"Could not write to {minified_repro_path}") from e


class AccuracyError(Exception):
    pass


def clone_inputs_retaining_gradness(example_inputs):
    """
    This clone routine is different from utils.clone_input. In the minifier,
    all tensors are leaf tensors when creating a new graph, so we set the
    requires_grad field without checking whether the tensor is a leaf.
    """
    cloned_inputs = clone_inputs(example_inputs)
    for idx in range(len(example_inputs)):
        if isinstance(cloned_inputs[idx], torch.Tensor):
            cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad)
    return cloned_inputs
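

# Illustrative sketch (hypothetical inputs): the helper above copies
# requires_grad onto every cloned tensor so the new leaf inputs behave like the
# originals under autograd.
def _example_clone_retaining_gradness():
    example_inputs = [torch.randn(2, 2, requires_grad=True), torch.randn(3)]
    cloned = clone_inputs_retaining_gradness(example_inputs)
    assert cloned[0].requires_grad and not cloned[1].requires_grad
    return cloned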


def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False):
    """
    Runs a forward and possibly backward iteration for a given mod and args.

    When disable_clone is True, we will use args as-is without cloning.
    This is higher fidelity but we may destroy the args in the process.
    """
    from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass

    gm = copy.deepcopy(gm)
    if not disable_clone:
        args = clone_inputs_retaining_gradness(args)

    if hasattr(gm, "zero_grad"):
        gm.zero_grad(True)

    # The callable returned by TorchInductor expects a list of inputs, so we
    # may need the boxed calling convention.
    out = gm(args) if hasattr(gm, "_boxed_call") else gm(*args)

    if only_fwd:
        return out
    if requires_bwd_pass(out):
        loss = reduce_to_scalar_loss(out)
        loss.backward()
    return collect_results(gm, out, None, args)
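

# Illustrative sketch (hypothetical module/inputs): one forward + backward
# iteration through run_fwd_maybe_bwd; the return value is whatever
# collect_results gathers (outputs, gradients, ...).
def _example_run_fwd_maybe_bwd():
    mod = torch.nn.Linear(4, 2)
    args = [torch.randn(8, 4, requires_grad=True)]
    return run_fwd_maybe_bwd(mod, args)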


def same_two_models(
    gm,
    opt_gm,
    example_inputs,
    only_fwd=False,
    *,
    require_fp64=False,
    ignore_non_fp=False,
):
    """
    Check that two models have the same accuracy.

    require_fp64: if True, raise an error if we are unable to calculate the fp64 reference
    ignore_non_fp: if True, do not compare outputs which are not floating point. This
        is mostly useful for the minifier (which wants to avoid quantizing floating point
        error into integer/boolean error)
    """
    from .utils import same

    ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

    fp64_ref = None
    if config.same_two_models_use_fp64:
        try:
            fp64_model, fp64_examples = cast_to_fp64(
                copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
            )
            fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
        except Exception:
            if require_fp64:
                raise RuntimeError(  # noqa: B904
                    "Could not generate fp64 outputs, workaround with torch._dynamo.config.same_two_models_use_fp64 = False"
                )
            log.warning("Could not generate fp64 outputs")

    try:
        res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
    except Exception as e:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return True.
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue."
            " Skipping this graph."
        )
        return True

    passing = same(
        ref,
        res,
        fp64_ref,
        tol=config.repro_tolerance,
        equal_nan=True,
        ignore_non_fp=ignore_non_fp,
    )
    return passing


def cast_dtype_args_to_fp64(model):
    for node in model.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.prims.convert_element_type.default
        ):
            assert len(node.args) == 2
            if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
                node.args = (node.args[0], torch.float64)
        if node.op == "call_function":
            dtype = node.kwargs.get("dtype")
            if dtype is not None and is_float_dtype(dtype):
                new_kwargs = dict(node.kwargs)
                new_kwargs["dtype"] = torch.float64
                node.kwargs = new_kwargs

    model.graph.lint()
    model.recompile()
    return model


def cast_to(dtype, model, inputs):
    from torch.utils._pytree import tree_map

    model = model.to(dtype)
    if dtype == torch.float64:
        # If casting to fp64 for accuracy comparison, we need to
        # replace dtype arguments embedded in the graph with fp64
        model = cast_dtype_args_to_fp64(model)

    inputs = tree_map(
        lambda x: x.to(dtype)
        if isinstance(x, torch.Tensor) and x.is_floating_point()
        else x,
        inputs,
    )
    return model, inputs


def cast_to_fp64(model, inputs):
    return cast_to(torch.float64, model, inputs)


def backend_accuracy_fails(
    gm,
    example_inputs,
    compiler_fn,
    only_fwd=False,
    *,
    require_fp64=False,
    ignore_non_fp=False,
):
    try:
        compiled_gm = compiler_fn(
            copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
        )
        return not same_two_models(
            gm,
            compiled_gm,
            example_inputs,
            only_fwd,
            require_fp64=require_fp64,
            ignore_non_fp=ignore_non_fp,
        )
    except Exception as e:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return False.
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue."
            " Skipping this graph."
        )
        return False
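

# Illustrative sketch: backend_accuracy_fails expects a backend-style callable
# taking (gm, example_inputs) and returning a compiled callable.  The identity
# "backend" below is a stand-in for a real one (e.g. inductor), so the check
# trivially passes; the module and inputs are hypothetical.
def _example_backend_accuracy_check():
    gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
    example_inputs = [torch.randn(2, 4)]

    def identity_backend(gm, example_inputs):
        return gm

    # False means the "compiled" graph matched eager within repro_tolerance
    return backend_accuracy_fails(gm, example_inputs, identity_backend)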


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                           REPRO SUPPORT CODE
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# Helper functions for computing what the default values of tensor
# values should be.  These all coincide with factory functions, e.g., torch.empty


def _stride_or_default(
    stride: Optional["torch._prims_common.StrideType"],
    *,
    shape: "torch._prims_common.ShapeType",
) -> "torch._prims_common.StrideType":
    return stride if stride is not None else utils.make_contiguous_strides_for(shape)


def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]:
    return lambda x: x if x is not None else d


_dtype_or_default = _mk_defaulter(torch.float32)
_device_or_default = _mk_defaulter(torch.device("cpu"))
_storage_offset_or_default = _mk_defaulter(0)
_requires_grad_or_default = _mk_defaulter(False)
_is_leaf_or_default = _mk_defaulter(False)
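

# Illustrative sketch of the defaulters above: a field omitted from a
# serialized repro falls back to the same value the corresponding factory
# function would use.
def _example_repro_defaults():
    assert tuple(_stride_or_default(None, shape=(2, 3))) == (3, 1)
    assert _dtype_or_default(None) == torch.float32
    assert _device_or_default(None) == torch.device("cpu")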


class NopInputReader:
    def __init__(self) -> None:
        self.total = 0

    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
        self.total += 1

    def tensor(self, *args, **kwargs):
        pass

    def symint(self, *args, **kwargs):
        pass


# TODO: Support bundling the entire repro into a zip file for ease of
# transferring around
class InputReader:
    def __init__(self, save_dir=None, *, pbar=None):
        # If None, we will generate random data instead.  It's important
        # to natively support this use case as it will allow people to
        # share repros without including the real data, if the problem
        # reproduces even on random data.
        if save_dir is None:
            log.warning("no save_dir specified, will generate random data")
        self.store = ContentStoreReader(save_dir) if save_dir is not None else None
        self.args = []
        self.pbar = pbar

    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
        if self.pbar is not None:
            self.pbar.update(1)
        device = _device_or_default(device)
        dtype_hint = _dtype_or_default(dtype_hint)
        if self.store is not None and storage_hash is not None:
            try:
                storage = self.store.read_storage(storage_hash)
            except FileNotFoundError:
                pass
            else:
                if device != storage.device:
                    log.warning("device mismatch: %s != %s", device, storage.device)
                    # TODO: transfer it to the right device?  But failing this
                    # way would be very mysterious!  Would have been better
                    # not to store device in the serialized format...
                return storage
        log.warning("could not load %s, generating random data instead", storage_hash)
        shape = (nbytes // dtype_hint.itemsize,)
        stride = _stride_or_default(None, shape=shape)
        return rand_strided(shape, stride, dtype_hint, device).untyped_storage()

    def tensor(
        self,
        storage,
        shape,
        stride=None,
        *,
        storage_offset=None,
        dtype=None,
        requires_grad=None,
        is_leaf=None,
        **metadata,
    ):
        stride = _stride_or_default(stride, shape=shape)
        storage_offset = _storage_offset_or_default(storage_offset)
        dtype = _dtype_or_default(dtype)
        is_leaf = _is_leaf_or_default(is_leaf)
        requires_grad = _requires_grad_or_default(requires_grad)
        t = torch.tensor(
            [], dtype=dtype, device=storage.device, requires_grad=requires_grad
        )
        with torch.no_grad():
            t.set_(storage, storage_offset, shape, stride)
        if not is_leaf:
            # Fake up some autograd history in a very naughty way
            with torch.enable_grad():
                t = t.clone(memory_format=torch.preserve_format)
            with torch.no_grad():
                t.set_(storage, storage_offset, shape, stride)
        assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf
        torch._utils.set_tensor_metadata(t, metadata)
        self.args.append(t)
        return t  # for BC

    def symint(self, val):
        self.args.append(val)
        return val  # for BC
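

# Illustrative sketch (hypothetical hash and shapes): this is roughly what the
# body of a generated load_args function does.  With save_dir=None the storage
# hash cannot be resolved, so the reader falls back to random data of the
# right size.
def _example_input_reader():
    reader = InputReader(save_dir=None)
    buf0 = reader.storage("0123456789abcdef", 64)  # hypothetical content hash
    reader.tensor(buf0, (4, 4), is_leaf=True)
    reader.symint(7)
    return reader.args  # [a 4x4 fp32 tensor, 7]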


# Here is our writer strategy:
#  1. We will stream all of the inputs to disk
#  2. You can now deterministically randomize the inputs, or reload
#     the inputs from disk
#  3. You can YOLO run the script without the inputs, in which case
#     we'll fill the inputs with random data and pray.  This is the
#     legacy behavior, but it's also useful if you want to find out
#     if we're so broken even random inputs trigger it
#  4. We could offer an in process "check if the randomized thing
#     works too" but this is delicate so we don't do it


class InputWriter:
    def __init__(self, save_dir, *, stable_hash=False):
        self._lines = []
        # TODO: consider ensuring tensor and storage counters line up?
        self.storage_counter = itertools.count()
        self.save_dir = save_dir
        self.store = (
            ContentStoreWriter(save_dir, stable_hash=stable_hash)
            if save_dir is not None
            else None
        )
        self.seen_storages = {}

    def lines(self):
        r = [
            "def load_args(reader):",
        ]
        r.extend(f"    {l}" for l in self._lines)
        # In case we need to change the internal format of load_args
        # in an FC-breaking way
        r.append("load_args._version = 0")
        return r

    # Storages are untyped, but we need to initialize them with data if
    # we don't have the real data, so we give a hint saying what kind
    # of initialization may be appropriate
    #
    # If we had a FakeTensor, device_hint tells us what device should be
    def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str:
        ws = StorageWeakRef(untyped_storage)
        v = self.seen_storages.get(ws)
        if v is not None:
            return v
        v = f"buf{next(self.storage_counter)}"
        maybe_dtype_hint = ""
        if _dtype_or_default(None) != _dtype_or_default(dtype_hint):
            maybe_dtype_hint = f", dtype_hint={dtype_hint!r}"
        # TODO: being optional on device is kind of pointless as the default
        # is CPU but most repros we care about are CUDA
        maybe_device = ""
        device = untyped_storage.device
        if device.type == "meta":
            assert device_hint is not None
            device = device_hint
        if _device_or_default(None) != device:
            maybe_device = f", device={device!r}"
        nbytes = untyped_storage.nbytes()
        storage_hash = None
        if self.store is not None and untyped_storage.device.type != "meta":
            storage_hash = self.store.write_storage(untyped_storage)
        self._lines.append(
            f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})"
        )
        self.seen_storages[ws] = v
        return v

    def tensor(self, name, t) -> None:
        from torch.fx.experimental.symbolic_shapes import statically_known_true

        storage = self.storage(
            t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
        )
        args = []
        # NB: this is positional, must come first
        if _stride_or_default(None, shape=t.shape) != t.stride():
            args.append(str(tuple(t.stride())))
        if _dtype_or_default(None) != t.dtype:
            args.append(f"dtype={t.dtype!r}")
        if not statically_known_true(
            _storage_offset_or_default(None) == t.storage_offset()
        ):
            args.append(f"storage_offset={t.storage_offset()!r}")
        tensor_metadata = torch._utils.get_tensor_metadata(t)
        if tensor_metadata:
            args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items())
        if _requires_grad_or_default(None) != t.requires_grad:
            args.append(f"requires_grad={t.requires_grad!r}")
        is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t)
        if _is_leaf_or_default(None) != is_leaf:
            args.append(f"is_leaf={is_leaf!r}")
        self._lines.append(
            "reader.tensor("
            + ", ".join([storage, str(tuple(t.shape)), *args])
            + f")  # {name}"
        )

    # TODO: this doesn't actually handle symints atm
    def symint(self, name, val) -> None:
        if isinstance(val, torch.SymInt):
            val = val.node.hint
        self._lines.append(f"reader.symint({val!r})  # {name}")
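

# Illustrative sketch (hypothetical inputs): InputWriter is the other half of
# the round trip; lines() yields the load_args(...) source that repro scripts
# embed and that InputReader later replays.
def _example_input_writer():
    writer = InputWriter(save_dir=None)  # pass a real save_dir to persist data
    writer.tensor("arg0_1", torch.randn(4, 4))
    writer.symint("s0", 7)
    return "\n".join(writer.lines())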


def aot_graph_input_parser(
    func: Callable[[List[Tensor]], List[Tensor]],
    device: str = "cuda",
    sym_shapes: Optional[Dict[str, int]] = None,
    default_sym_shape: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Takes in a function which has been printed with print_readable() and constructs kwargs to run it.

    Handles Tensor inputs, SymInts, and a graph module which might have tensor constants.

    Consider a function `forward` defined as follows:

    def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "Sym(s0)",):
        _tensor_constant0: "i64[4190]" = self._tensor_constant0
        # Further implementation

    kwargs = aot_graph_input_parser(forward)
    forward(**kwargs)
    """

    from torch.fx.graph import dtype_abbrs

    dtype_map = {value: key for key, value in dtype_abbrs.items()}
    dtype_pattern = "|".join(dtype_abbrs.values())

    # Extracting the source code from the function
    source = inspect.getsource(func)

    # Regular expressions
    tensor_assignment_regex = rf"(_tensor_constant\d+): \"({dtype_pattern})\[\s*(.*?)\s*\]\" = self\.(_tensor_constant\d+)"
    tensor_regex = rf"({dtype_pattern})\[\s*(.*?)\s*\]"
    sym_shape_regex = r"Sym\((s\d+)\)"

    class TensorContainer:
        "Container for tensors as attributes"

    # Dictionary for tensors from annotations
    kwargs: Dict[str, Any] = {}

    sym_shapes = sym_shapes or {}

    def get_sym_int(symint):
        torch._check(
            symint in sym_shapes or default_sym_shape is not None,
            lambda: f"{symint} not in sym_shapes and default sym shape not passed in",
        )
        return sym_shapes.get(symint, default_sym_shape)

    def gen_tensor(shape, dtype) -> Tensor:
        # Resolve symbolic shapes to concrete values
        resolved_shape = []
        dynamic_dims = []
        for i, dim in enumerate(shape):
            dim = dim.strip()
            if "s" in dim:
                s = get_sym_int(dim)
                resolved_shape.append(s)
                dynamic_dims.append(i)
            else:
                if dim:
                    resolved_shape.append(int(dim))

        constructor = torch.randn if dtype.is_floating_point else torch.zeros
        out = constructor(resolved_shape, dtype=dtype, device=device)  # type: ignore[call-arg]
        for d in dynamic_dims:
            torch._dynamo.mark_dynamic(out, d)
        return out

    # Parse function annotations for tensor generation
    annotations = func.__annotations__
    for param, annotation in annotations.items():
        # Skip 'return' annotation
        if param == "return":
            continue

        match = re.search(tensor_regex, annotation)
        if match:
            data_type, shape_str = match.groups()
            shape = tuple(shape_str.split(","))
            dtype = dtype_map[data_type]
            kwargs[param] = gen_tensor(shape, dtype)

        match = re.search(sym_shape_regex, annotation)
        if match:
            kwargs[param] = get_sym_int(match.group(1))

    if "self" in inspect.signature(func).parameters:
        container = TensorContainer()
        kwargs["self"] = container
        for match in re.finditer(tensor_assignment_regex, source):
            attr_name, data_type, shape_str, _ = match.groups()
            shape = tuple(shape_str.split(","))
            dtype = dtype_map[data_type]
            setattr(container, attr_name, gen_tensor(shape, dtype))

    return kwargs
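

# Illustrative sketch (hypothetical forward): binding the "Sym(s0)" style
# annotations that show up in printed AOT graphs to concrete sizes via
# sym_shapes, then invoking the graph with the generated kwargs.
def _example_aot_graph_input_parser():
    def forward(primals_1: "f32[8, 4]", primals_2: "f32[s0]", primals_3: "Sym(s0)"):
        return primals_1, primals_2, primals_3

    kwargs = aot_graph_input_parser(forward, device="cpu", sym_shapes={"s0": 16})
    return forward(**kwargs)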


def profile_to_file(filename: str) -> Callable[[T], T]:
    """
    Decorator to cProfile a given function and save the result to disk on process exit.

    Args:
        filename: filename to save profile to
    """
    prof = cProfile.Profile()
    filename = os.path.abspath(os.path.expanduser(filename))

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            prof.enable()
            try:
                return fn(*args, **kwargs)
            finally:
                prof.disable()

        return wrapper

    def save_it():
        prof.dump_stats(filename)
        sys.stderr.write(
            textwrap.dedent(
                f"""\
                Wrote profile to {filename}, view with:

                    snakeviz {filename}

                """
            )
        )

    atexit.register(save_it)
    return decorator
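

# Illustrative sketch (hypothetical output path): the profile accumulates
# across every call to the decorated function and is dumped once, at
# interpreter exit, by the atexit hook registered above.
def _example_profile_to_file():
    @profile_to_file("/tmp/dynamo_example.prof")
    def busy_work(n: int) -> int:
        return sum(i * i for i in range(n))

    return busy_work(10_000)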