# Owner(s): ["module: dynamo"]
import dataclasses
import importlib
import inspect
import math
import types
import unittest
import warnings
from typing import Any, Dict, Set

import torch
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.deprecated as deprecated_func
from torch._dynamo.trace_rules import (
    LEGACY_MOD_INLINELIST,
    load_object,
    manual_torch_name_rule_map,
    MOD_INLINELIST,
    torch_c_binding_in_graph_functions,
    torch_non_c_binding_in_graph_functions,
)
from torch._dynamo.utils import hashable, is_safe_constant, istype
from torch._dynamo.variables import TorchInGraphFunctionVariable, UserFunctionVariable
from torch.testing._internal.common_utils import skipIfWindows


try:
    from .utils import create_dummy_module_and_function
except ImportError:
    from utils import create_dummy_module_and_function


ignored_c_binding_in_graph_function_names = {
    # Ignored because they have manual rules defined at `trace_rules.manual_torch_name_rule_map`.
    "torch._nested_tensor_from_mask",
    "torch._nested_from_padded",
    "torch.sparse_compressed_tensor",
    "torch.sparse_bsc_tensor",
    "torch.sparse_bsr_tensor",
    "torch.sparse_coo_tensor",
    "torch.sparse_csc_tensor",
    "torch.sparse_csr_tensor",
    "torch.cuda._get_device_properties",
    # Ignored and go through rules defined at `trace_rules.check`.
    "torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode",
    "torch._cslt_sparse_mm_search",
    "torch._C._abort",
    "torch._C._mps_is_on_macos_or_newer",
    "torch._C._swap_tensor_impl",
    "torch._C._unsafe_reset_storage",
    "torch._dynamo.eval_frame.reset_code",
    "torch._C.autocast_decrement_nesting",
    "torch._C.autocast_increment_nesting",
    "torch._C.clear_autocast_cache",
    "torch._C.set_anomaly_enabled",
    "torch._C.set_autocast_cache_enabled",
    "torch._C.set_autocast_cpu_dtype",
    "torch._C.set_autocast_cpu_enabled",
    "torch._C.set_autocast_enabled",
    "torch._C.set_autocast_gpu_dtype",
    "torch._C.set_autocast_ipu_dtype",
    "torch._C.set_autocast_ipu_enabled",
    "torch._C.set_autocast_xla_dtype",
    "torch._C.set_autocast_xla_enabled",
    "torch.resize_as_",
    "torch.resize_as_sparse_",
    "torch._C._data_address",
    "torch._C._is_cow_tensor",
    "torch._lazy_clone",
    "torch._test_parallel_materialize",
    "torch._C._storage_address",
    "torch._C._pickle_save",
    "torch._validate_sparse_compressed_tensor_args",
    "torch._validate_sparse_csr_tensor_args",
    "torch._validate_sparse_bsr_tensor_args",
    "torch._validate_sparse_csc_tensor_args",
    "torch._validate_sparse_coo_tensor_args",
    "torch._validate_sparse_bsc_tensor_args",
    "torch._validate_compressed_sparse_indices",
}
if torch._C._llvm_enabled():
    ignored_c_binding_in_graph_function_names |= {
        "torch._C._te.set_llvm_aot_workflow",
        "torch._C._te.set_llvm_target_cpu",
        "torch._C._te.set_llvm_target_attrs",
        "torch._C._te.set_llvm_target_triple",
    }


# Helper function to dump the torch name rule map generated based on
# the heuristic defined in gen_allowed_objs_and_ids.
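# A minimal usage sketch (hypothetical; run manually, e.g. from a REPL, assuming
# this file is importable as `test_trace_rules`):
#
#   from test_trace_rules import dump_allowed_torch_name_rule_map
#   dump_allowed_torch_name_rule_map()
#
# Each printed line is of the form '"torch.<name>": TorchInGraphFunctionVariable,',
# matching the entry format of the rule maps in `torch/_dynamo/trace_rules.py`.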
def dump_allowed_torch_name_rule_map() -> None:
    m = gen_allowed_objs_and_ids(record=True, c_binding_only=False).name_rule_map
    for k, v in m.items():
        print(f'"{k}": {v.__name__},')


@dataclasses.dataclass
class AllowedObjects:
    """
    Track the objects, object id - name pairs, and name - dynamo wrapping rule pairs
    from the heuristic defined in `gen_allowed_objs_and_ids`.
    """

    object_ids: Dict[int, str]
    c_binding_in_graph_functions: Set[Any]
    non_c_binding_in_graph_functions: Set[Any]
    name_rule_map: Dict[str, Any]


def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects:
    """
    Walk torch.* and collect the ids of the objects found there. When `record`
    is True, also record which objects look like in-graph functions, split into
    C-binding and non-C-binding sets, along with their Dynamo wrapping rules.
    """

    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = {}
    c_binding_in_graph_functions = set()
    non_c_binding_in_graph_functions = set()
    torch_name_rule_map = {}

    # On some platforms, these functions are loaded as classes instead of functions.
    # To mitigate these weird cases, we need this special check.
    def is_special_functions(obj):
        return hashable(obj) and obj in {
            torch._C._cuda_isCurrentStreamCapturing,
            torch._C._graph_pool_handle,
        }

    # Add obj to the c_binding_in_graph_functions or non_c_binding_in_graph_functions
    # set if it is a torch function or method.
    # This is used to heuristically generate the in-graph function lists.
    def heuristic_record_if_in_graph_function(obj, module, name):
        try:
            # Unwrap decorated functions to inspect the underlying callable.
            if hasattr(obj, "__wrapped__"):
                obj = obj.__wrapped__
        except Exception:
            pass
        if isinstance(
            obj,
            (
                types.FunctionType,
                types.BuiltinFunctionType,
                types.MethodDescriptorType,
                types.WrapperDescriptorType,
            ),
        ) or is_special_functions(obj):
            torch_name_rule_map[
                f"{module.__name__}.{name}"
            ] = TorchInGraphFunctionVariable
            # C bindings carry no Python `__code__` object; use that to split the sets.
            if c_binding_only:
                if not hasattr(obj, "__code__"):
                    c_binding_in_graph_functions.add(obj)
            else:
                if hasattr(obj, "__code__"):
                    non_c_binding_in_graph_functions.add(obj)
                else:
                    c_binding_in_graph_functions.add(obj)

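    # Decide whether `obj` lives under an allowed module prefix ("torch" or "math")
    # while not matching any of the disallowed prefixes listed below.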
    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters.  This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break. To ensure this, we inline
        # these functions, rather than keep them opaquely in the graph.
        disallowed_modules = [
            "torch.optim.",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch._C._autograd",
            "torch._C._cudart",
            "torch._C._distributed_autograd",
            "torch._C._distributed_c10d",
            "torch._C._distributed_rpc",
            "torch._C._functorch",
            "torch._C._monitor",
            "torch._C._nvtx",
            "torch._C._lazy",
            "torch._C._profiler",
            "torch.__config__",
            "torch._custom_op",
            "torch._decomp",
            "torch._dispatch",
            "torch._export",
            "torch._functorch.make_functional",
            "torch._functorch.compile_utils",
            "torch._functorch.partitioners",
            "torch._functorch.aot_autograd",
            "torch._functorch.compilers",
            "torch._functorch.fx_minifier",
            "torch.autograd.profiler_util",
            "torch.autograd.profiler",
            "torch._jit_internal",
            "torch._library",
            "torch._lobpcg",
            "torch._logging",
            "torch._meta_registrations",
            "torch._namedtensor_internals",
            "torch._numpy",
            "torch._sources",
            "torch._subclasses",
            "torch._tensor",
            "torch._tensor_str",
            "torch._utils",
            "torch._utils_internal",
            "torch._vmap_internals",
            "torch.compiler",
            "torch.distributed",
            "torch.export",
            "torch.hub",
            "torch.jit",
            "torch.library",
            "torch.masked.maskedtensor",
            "torch.nn.init",
            "torch.nn.modules.module",
            "torch.nn.parallel",
            "torch.nn.utils",
            "torch.multiprocessing",
            "torch.onnx",
            "torch.overrides",
            "torch.package",
            "torch.profiler",
            "torch.serialization",
            "torch.storage",
            "torch.utils",
            "torch.distributed.",
        ]

        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__

        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False

        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

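    # Recursively walk a torch module's __dict__, recording the id -> name mapping of
    # every allowed object and, when `record` is set, applying the in-graph function
    # heuristic to each attribute.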
    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                # Dynamo allows all builtins into the graph and does not attempt
                # to introspect into them. We don't want to allow instances of
                # HigherOrderOperator into the graph all the time (Dynamo needs
                # to introspect the body functions of these HigherOrderOperator
                # first, decide they are safe, and then allow them into the graph).
                # So we exclude HigherOrderOperator from being a builtin.
                import torch._ops

                if isinstance(obj, torch._ops.HigherOrderOperator):
                    continue

                # We want to trace through `grad` and `vmap`
                if obj in (
                    torch.func.grad,
                    deprecated_func.grad,
                    torch.func.vmap,
                    deprecated_func.vmap,
                    torch.nn.functional.triplet_margin_with_distance_loss,
                    torch.cond,
                ):
                    continue

                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    if record:
                        heuristic_record_if_in_graph_function(obj, module, name)
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    if record:
                        heuristic_record_if_in_graph_function(obj, module, name)
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    return AllowedObjects(
        torch_object_ids,
        c_binding_in_graph_functions,
        non_c_binding_in_graph_functions,
        torch_name_rule_map,
    )


class TraceRuleTests(torch._dynamo.test_case.TestCase):
    def _check_set_equality(self, generated, used, rule_map, ignored_set):
        x = generated - used
        y = used - generated
        msg1 = (
            f"New torch objects: {x} "
            f"were not added to `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. "
            "Refer to the instructions in `torch/_dynamo/trace_rules.py` for more details."
        )
        msg2 = (
            f"Existing torch objects: {y} were removed. "
            f"Please remove them from `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. "
            "Refer to the instructions in `torch/_dynamo/trace_rules.py` for more details."
        )
        self.assertTrue(len(x) == 0, msg1)
        self.assertTrue(len(y) == 0, msg2)

    # We use Python function and module string names for these inlinelists;
    # this unit test makes sure the functions/modules can be correctly imported
    # or loaded in case there is a typo in the strings.
    def test_skipfiles_inlinelist(self):
        for m in LEGACY_MOD_INLINELIST.union(MOD_INLINELIST):
            self.assertTrue(
                isinstance(importlib.import_module(m), types.ModuleType),
                f"{m} from trace_rules.MOD_INLINELIST/LEGACY_MOD_INLINELIST is not a Python module; please check and correct it.",
            )

    @unittest.skip(
        "This test keeps getting broken and our disable infra is not handling it well. See #120627."
    )
    def test_torch_name_rule_map_updated(self):
        # Generate the allowed objects based on the heuristic defined in `gen_allowed_objs_and_ids`.
        objs = gen_allowed_objs_and_ids(record=True, c_binding_only=True)
        # Test that C-binding in-graph functions are updated in torch_name_rule_map.
        generated = objs.c_binding_in_graph_functions
        used = set()
        for x in (
            set(torch_c_binding_in_graph_functions.keys())
            | ignored_c_binding_in_graph_function_names
        ):
            obj = load_object(x)
            if obj is not None:
                used.add(obj)
        self._check_set_equality(
            generated,
            used,
            "torch_c_binding_in_graph_functions",
            "ignored_c_binding_in_graph_function_names",
        )
        # For non-C-binding in-graph functions, we only test whether they can be loaded successfully.
        for f in torch_non_c_binding_in_graph_functions:
            self.assertTrue(
                isinstance(
                    load_object(f),
                    (
                        types.FunctionType,
                        types.BuiltinFunctionType,
                        types.MethodDescriptorType,
                        types.WrapperDescriptorType,
                    ),
                )
            )

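    # The two tests below temporarily patch `trace_rules.torch_name_rule_map` (and
    # bypass the functools.lru_cache on `get_torch_obj_rule_map`) so that a function
    # which would normally be skipped gets inlined by Dynamo instead.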
    def test_force_inline_torch_function(self):
        # `torch._dynamo.utils.istype` is skipped by default
        def fn(x):
            if istype(x, torch.Tensor):
                return x + 1
            else:
                return x - 1

        _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
        # Force inline `torch._dynamo.utils.istype` by setting a trace rule.
        _manual_torch_name_rule_map["torch._dynamo.utils.istype"] = UserFunctionVariable

        _torch_name_rule_map = [
            _manual_torch_name_rule_map,
            torch_c_binding_in_graph_functions,
            torch_non_c_binding_in_graph_functions,
        ]

        self.assertTrue(
            "torch._dynamo" not in torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST
        )
        self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST)

        with unittest.mock.patch(
            "torch._dynamo.trace_rules.torch_name_rule_map",
            _torch_name_rule_map,
        ), unittest.mock.patch(
            "torch._dynamo.trace_rules.get_torch_obj_rule_map",
            torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,  # bypass functools.lru_cache
        ):
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_force_inline_custom_function(self):
        mod, func = create_dummy_module_and_function()

        def fn(x):
            return func(x)

        _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
        # Force inline `mod.func` by setting a trace rule.
        _manual_torch_name_rule_map[
            f"{mod.__name__}.{func.__name__}"
        ] = UserFunctionVariable

        _torch_name_rule_map = [
            _manual_torch_name_rule_map,
            torch_c_binding_in_graph_functions,
            torch_non_c_binding_in_graph_functions,
        ]

        with unittest.mock.patch(
            "torch._dynamo.trace_rules.torch_name_rule_map",
            _torch_name_rule_map,
        ), unittest.mock.patch(
            "torch._dynamo.trace_rules.get_torch_obj_rule_map",
            torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
        ):
            # First add the module to SKIP_DIRS so that it is skipped by default.
            torch._dynamo.trace_rules.add(mod.__name__)
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)


class TestModuleSurviveSkipFiles(torch._dynamo.test_case.TestCase):
    @unittest.skipIf(
        not torch.distributed.is_available(),
        "need to import MLP module from distributed",
    )
    @skipIfWindows(
        msg="AssertionError: False is not true : MLP did not survive skip files"
    )
    def test_module_survive_skip_files(self):
        from torch.testing._internal.common_fsdp import MLP

        model = MLP(3)
        inp = torch.randn((2, 3))
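        # FRAME_COUNTER advances when Dynamo processes a frame for compilation; if the
        # MLP module's code were swallowed by skip files, compiling and running the
        # model would leave the counter unchanged.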
        frame_count_before = torch._dynamo.convert_frame.FRAME_COUNTER
        model.compile(backend="eager")
        model(inp)
        frame_count_after = torch._dynamo.convert_frame.FRAME_COUNTER
        self.assertTrue(
            frame_count_after > frame_count_before, "MLP did not survive skip files"
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()