# Owner(s): ["module: dynamo"]
import dataclasses
import importlib
import inspect
import math
import types
import unittest
import warnings
from typing import Any, Dict, Set

import torch
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.deprecated as deprecated_func
from torch._dynamo.trace_rules import (
    LEGACY_MOD_INLINELIST,
    load_object,
    manual_torch_name_rule_map,
    MOD_INLINELIST,
    torch_c_binding_in_graph_functions,
    torch_non_c_binding_in_graph_functions,
)
from torch._dynamo.utils import hashable, is_safe_constant, istype
from torch._dynamo.variables import TorchInGraphFunctionVariable, UserFunctionVariable
from torch.testing._internal.common_utils import skipIfWindows


try:
    from .utils import create_dummy_module_and_function
except ImportError:
    from utils import create_dummy_module_and_function


ignored_c_binding_in_graph_function_names = {
    # Ignored because they have manual rules defined at `trace_rules.manual_torch_name_rule_map`.
    "torch._nested_tensor_from_mask",
    "torch._nested_from_padded",
    "torch.sparse_compressed_tensor",
    "torch.sparse_bsc_tensor",
    "torch.sparse_bsr_tensor",
    "torch.sparse_coo_tensor",
    "torch.sparse_csc_tensor",
    "torch.sparse_csr_tensor",
    "torch.cuda._get_device_properties",
    # Ignored and go through rules defined at `trace_rules.check`.
    "torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode",
    "torch._cslt_sparse_mm_search",
    "torch._C._abort",
    "torch._C._mps_is_on_macos_or_newer",
    "torch._C._swap_tensor_impl",
    "torch._C._unsafe_reset_storage",
    "torch._dynamo.eval_frame.reset_code",
    "torch._C.autocast_decrement_nesting",
    "torch._C.autocast_increment_nesting",
    "torch._C.clear_autocast_cache",
    "torch._C.set_anomaly_enabled",
    "torch._C.set_autocast_cache_enabled",
    "torch._C.set_autocast_cpu_dtype",
    "torch._C.set_autocast_cpu_enabled",
    "torch._C.set_autocast_enabled",
    "torch._C.set_autocast_gpu_dtype",
    "torch._C.set_autocast_ipu_dtype",
    "torch._C.set_autocast_ipu_enabled",
    "torch._C.set_autocast_xla_dtype",
    "torch._C.set_autocast_xla_enabled",
    "torch.resize_as_",
    "torch.resize_as_sparse_",
    "torch._C._data_address",
    "torch._C._is_cow_tensor",
    "torch._lazy_clone",
    "torch._test_parallel_materialize",
    "torch._C._storage_address",
    "torch._C._pickle_save",
    "torch._validate_sparse_compressed_tensor_args",
    "torch._validate_sparse_csr_tensor_args",
    "torch._validate_sparse_bsr_tensor_args",
    "torch._validate_sparse_csc_tensor_args",
    "torch._validate_sparse_coo_tensor_args",
    "torch._validate_sparse_bsc_tensor_args",
    "torch._validate_compressed_sparse_indices",
}
if torch._C._llvm_enabled():
    ignored_c_binding_in_graph_function_names |= {
        "torch._C._te.set_llvm_aot_workflow",
        "torch._C._te.set_llvm_target_cpu",
        "torch._C._te.set_llvm_target_attrs",
        "torch._C._te.set_llvm_target_triple",
    }
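

# The two groups of names above are skipped by the generated C-binding list
# because they are resolved elsewhere: either by an explicit entry in
# `trace_rules.manual_torch_name_rule_map` or by the fallback logic in
# `trace_rules.check`. The helper below is only a minimal illustrative sketch
# of the first kind of lookup; it is not used by any test, and the example
# name in its comment is simply one of the entries listed above.
def _has_manual_rule(name: str) -> bool:
    # e.g. `_has_manual_rule("torch._nested_tensor_from_mask")` should be True,
    # since that name is covered by a manual rule rather than a generated one.
    return name in manual_torch_name_rule_map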


# Helper function to dump the torch name rule map generated based on
# the heuristic defined in gen_allowed_objs_and_ids.
def dump_allowed_torch_name_rule_map() -> None:
    m = gen_allowed_objs_and_ids(record=True, c_binding_only=False).name_rule_map
    for k, v in m.items():
        print(f'"{k}": {v.__name__},')


@dataclasses.dataclass
class AllowedObjects:
    """
    Track the objects, object id - name pairs, and name - dynamo wrapping rule pairs
    from the heuristic defined in `gen_allowed_objs_and_ids`.
    """

    object_ids: Dict[int, str]
    c_binding_in_graph_functions: Set[Any]
    non_c_binding_in_graph_functions: Set[Any]
    name_rule_map: Dict[str, Any]


def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects:
    """
    Walk torch.* and get the ids of all the stuff in it
    """

    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = {}
    c_binding_in_graph_functions = set()
    non_c_binding_in_graph_functions = set()
    torch_name_rule_map = {}

    # On some platforms, these functions are loaded as classes instead of functions.
    # To mitigate these weird cases, we need this special check.
    def is_special_functions(obj):
        return hashable(obj) and obj in {
            torch._C._cuda_isCurrentStreamCapturing,
            torch._C._graph_pool_handle,
        }

    # Add obj to the c_binding_in_graph_functions set or the
    # non_c_binding_in_graph_functions set if it's a torch function or method.
    # This is used to generate the in-graph function list based on the heuristic.
    def heuristic_record_if_in_graph_function(obj, module, name):
        try:
            if hasattr(obj, "__wrapped__"):
                obj = obj.__wrapped__
        except Exception:
            pass
        if isinstance(
            obj,
            (
                types.FunctionType,
                types.BuiltinFunctionType,
                types.MethodDescriptorType,
                types.WrapperDescriptorType,
            ),
        ) or is_special_functions(obj):
            torch_name_rule_map[
                f"{module.__name__}.{name}"
            ] = TorchInGraphFunctionVariable
            if c_binding_only:
                if not hasattr(obj, "__code__"):
                    c_binding_in_graph_functions.add(obj)
            else:
                if hasattr(obj, "__code__"):
                    non_c_binding_in_graph_functions.add(obj)
                else:
                    c_binding_in_graph_functions.add(obj)
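
    # For example (illustrative only; nothing here is executed by the tests):
    # `torch.add` is a C binding with no `__code__` attribute, so the heuristic
    # above would record it in `c_binding_in_graph_functions`, while a pure
    # Python function such as `torch.functional.einsum` has `__code__` and
    # would be recorded in `non_c_binding_in_graph_functions` when
    # `c_binding_only=False`.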

    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters. This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break. To ensure this, we inline
        # these functions, rather than keep them opaquely in the graph.
        disallowed_modules = [
            "torch.optim.",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch._C._autograd",
            "torch._C._cudart",
            "torch._C._distributed_autograd",
            "torch._C._distributed_c10d",
            "torch._C._distributed_rpc",
            "torch._C._functorch",
            "torch._C._monitor",
            "torch._C._nvtx",
            "torch._C._lazy",
            "torch._C._profiler",
            "torch.__config__",
            "torch._custom_op",
            "torch._decomp",
            "torch._dispatch",
            "torch._export",
            "torch._functorch.make_functional",
            "torch._functorch.compile_utils",
            "torch._functorch.partitioners",
            "torch._functorch.aot_autograd",
            "torch._functorch.compilers",
            "torch._functorch.fx_minifier",
            "torch.autograd.profiler_util",
            "torch.autograd.profiler",
            "torch._jit_internal",
            "torch._library",
            "torch._lobpcg",
            "torch._logging",
            "torch._meta_registrations",
            "torch._namedtensor_internals",
            "torch._numpy",
            "torch._sources",
            "torch._subclasses",
            "torch._tensor",
            "torch._tensor_str",
            "torch._utils",
            "torch._utils_internal",
            "torch._vmap_internals",
            "torch.compiler",
            "torch.distributed",
            "torch.export",
            "torch.hub",
            "torch.jit",
            "torch.library",
            "torch.masked.maskedtensor",
            "torch.nn.init",
            "torch.nn.modules.module",
            "torch.nn.parallel",
            "torch.nn.utils",
            "torch.multiprocessing",
            "torch.onnx",
            "torch.overrides",
            "torch.package",
            "torch.profiler",
            "torch.serialization",
            "torch.storage",
            "torch.utils",
            "torch.distributed.",
        ]

        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__

        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False

        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)
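
    # Illustrative examples (not executed anywhere): "torch.nn.functional"
    # starts with the allowed "torch." prefix and matches none of the
    # disallowed prefixes above, so it would be allowed, whereas submodules of
    # torch.fx match the "torch.fx." disallowed prefix and would not be.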

    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                # Dynamo allows all builtins into the graph and does not attempt
                # to introspect into them. We don't want to allow instances of
                # HigherOrderOperator into the graph all the time (Dynamo needs
                # to introspect the body functions of these HigherOrderOperators
                # first, decide they are safe, and then allow them into the
                # graph). So we exclude HigherOrderOperator from being a builtin.
                import torch._ops

                if isinstance(obj, torch._ops.HigherOrderOperator):
                    continue

                # We want to trace through `grad` and `vmap`
                if obj in (
                    torch.func.grad,
                    deprecated_func.grad,
                    torch.func.vmap,
                    deprecated_func.vmap,
                    torch.nn.functional.triplet_margin_with_distance_loss,
                    torch.cond,
                ):
                    continue

                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    if record:
                        heuristic_record_if_in_graph_function(obj, module, name)
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    if record:
                        heuristic_record_if_in_graph_function(obj, module, name)
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    return AllowedObjects(
        torch_object_ids,
        c_binding_in_graph_functions,
        non_c_binding_in_graph_functions,
        torch_name_rule_map,
    )


class TraceRuleTests(torch._dynamo.test_case.TestCase):
    def _check_set_equality(self, generated, used, rule_map, ignored_set):
        x = generated - used
        y = used - generated
        msg1 = (
            f"New torch objects: {x} "
            f"were not added to `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. "
            "Refer to the instructions in `torch/_dynamo/trace_rules.py` for more details."
        )
        msg2 = (
            f"Existing torch objects: {y} were removed. "
            f"Please remove them from `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. "
            "Refer to the instructions in `torch/_dynamo/trace_rules.py` for more details."
        )
        self.assertTrue(len(x) == 0, msg1)
        self.assertTrue(len(y) == 0, msg2)

    # We are using Python function and module string names for these inlinelists;
    # this unit test makes sure the functions/modules can be correctly imported
    # or loaded in case there is a typo in the strings.
    def test_skipfiles_inlinelist(self):
        for m in LEGACY_MOD_INLINELIST.union(MOD_INLINELIST):
            self.assertTrue(
                isinstance(importlib.import_module(m), types.ModuleType),
                f"{m} from trace_rules.MOD_INLINELIST/LEGACY_MOD_INLINELIST is not a python module, please check and correct it.",
            )

    @unittest.skip(
        "This test keeps getting broken and our disable infra is not handling it well. See #120627."
    )
    def test_torch_name_rule_map_updated(self):
        # Generate the allowed objects based on the heuristic defined in `allowed_functions.py`.
        objs = gen_allowed_objs_and_ids(record=True, c_binding_only=True)
        # Test that the C-binding in-graph functions are updated in torch_name_rule_map.
        generated = objs.c_binding_in_graph_functions
        used = set()
        for x in (
            set(torch_c_binding_in_graph_functions.keys())
            | ignored_c_binding_in_graph_function_names
        ):
            obj = load_object(x)
            if obj is not None:
                used.add(obj)
        self._check_set_equality(
            generated,
            used,
            "torch_c_binding_in_graph_functions",
            "ignored_c_binding_in_graph_function_names",
        )
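        # Illustrative failure mode (an assumption for exposition, not part of
        # the test): if the heuristic discovered a new C binding, say a
        # hypothetical `torch._C._new_binding`, that is neither in
        # `trace_rules.torch_c_binding_in_graph_functions` nor in
        # `ignored_c_binding_in_graph_function_names`, the check above would
        # fail with the `msg1` message from `_check_set_equality` and ask for
        # the rule map to be updated.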

        # For the non-C-binding in-graph functions, we only test whether they
        # can be loaded successfully.
        for f in torch_non_c_binding_in_graph_functions:
            self.assertTrue(
                isinstance(
                    load_object(f),
                    (
                        types.FunctionType,
                        types.BuiltinFunctionType,
                        types.MethodDescriptorType,
                        types.WrapperDescriptorType,
                    ),
                )
            )

    def test_force_inline_torch_function(self):
        # `torch._dynamo.utils.istype` is skipped by default
        def fn(x):
            if istype(x, torch.Tensor):
                return x + 1
            else:
                return x - 1

        _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
        # Force inline `torch._dynamo.utils.istype` by setting a trace rule.
        _manual_torch_name_rule_map["torch._dynamo.utils.istype"] = UserFunctionVariable

        _torch_name_rule_map = [
            _manual_torch_name_rule_map,
            torch_c_binding_in_graph_functions,
            torch_non_c_binding_in_graph_functions,
        ]

        self.assertTrue(
            "torch._dynamo" not in torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST
        )
        self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST)

        with unittest.mock.patch(
            "torch._dynamo.trace_rules.torch_name_rule_map",
            _torch_name_rule_map,
        ), unittest.mock.patch(
            "torch._dynamo.trace_rules.get_torch_obj_rule_map",
            torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,  # bypass functools.lru_cache
        ):
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_force_inline_custom_function(self):
        mod, func = create_dummy_module_and_function()

        def fn(x):
            return func(x)

        _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
        # Force inline `mod.func` by setting a trace rule.
        _manual_torch_name_rule_map[
            f"{mod.__name__}.{func.__name__}"
        ] = UserFunctionVariable

        _torch_name_rule_map = [
            _manual_torch_name_rule_map,
            torch_c_binding_in_graph_functions,
            torch_non_c_binding_in_graph_functions,
        ]

        with unittest.mock.patch(
            "torch._dynamo.trace_rules.torch_name_rule_map",
            _torch_name_rule_map,
        ), unittest.mock.patch(
            "torch._dynamo.trace_rules.get_torch_obj_rule_map",
            torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
        ):
            # First add the module to SKIP_DIRS so that it is skipped by default.
            torch._dynamo.trace_rules.add(mod.__name__)
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)


class TestModuleSurviveSkipFiles(torch._dynamo.test_case.TestCase):
    @unittest.skipIf(
        not torch.distributed.is_available(),
        "need to import MLP module from distributed",
    )
    @skipIfWindows(
        msg="AssertionError: False is not true : MLP did not survive skip files"
    )
    def test_module_survive_skip_files(self):
        from torch.testing._internal.common_fsdp import MLP

        model = MLP(3)
        inp = torch.randn((2, 3))
        frame_count_before = torch._dynamo.convert_frame.FRAME_COUNTER
        model.compile(backend="eager")
        model(inp)
        frame_count_after = torch._dynamo.convert_frame.FRAME_COUNTER
        self.assertTrue(
            frame_count_after > frame_count_before, "MLP did not survive skip files"
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()