# mypy: allow-untyped-defs
import logging
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Iterator, List, Set, Tuple

import torch
import torch.distributed as dist
import torch.distributed.fsdp._flat_param as flat_param_file
from torch.distributed.fsdp._common_utils import (
    _apply_to_modules,
    _get_module_fsdp_state,
    clean_tensor_name,
)


logger = logging.getLogger(__name__)


class SimpleProfiler:
    """A lightweight profiler that accumulates wall-clock time per profiling
    type in class-level state via the ``profile()`` context manager."""

    class Type(str, Enum):
        ALL = "all"
        ALLGATHER = "all_gather"
        ALLGATHER_OBJ = "all_gather_object"
        RESHARDING = "resharding"
        H2D = "H2D"
        D2H = "D2H"

    results: Dict[str, float] = defaultdict(float)
    profiling: Set[str] = set()

    @classmethod
    def reset(cls) -> None:
        cls.results.clear()
        cls.profiling.clear()

    @classmethod
    @contextmanager
    def profile(cls, profile_type: str) -> Iterator[None]:
        assert profile_type not in cls.profiling, (
            f"{profile_type} is already being profiled. "
            "SimpleProfiler does not support profiling multiple instances of "
            "the same type at the same time."
        )

        cls.profiling.add(profile_type)
        begin = time.monotonic()
        try:
            yield
        finally:
            end = time.monotonic()
            cls.results[profile_type] += end - begin
            cls.profiling.remove(profile_type)

    @classmethod
    def dump_and_reset(cls, msg: str) -> None:
        # Only log on rank 0 at the INFO distributed debug level; profiling
        # results are highly inaccurate when combined with the DETAIL level.
        if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO:
            logger.info("%s %s", msg, cls.results)
        cls.reset()


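# The function below is a minimal usage sketch for ``SimpleProfiler`` (an
# illustrative addition, not part of the original module, and never called by
# FSDP itself). It assumes a default process group has already been
# initialized, since ``dump_and_reset`` calls ``dist.get_rank()`` and
# ``dist.get_debug_level()``.
def _example_simple_profiler_usage() -> None:
    # Time an arbitrary piece of work under the H2D profiling type; the
    # elapsed wall-clock seconds accumulate in ``SimpleProfiler.results``.
    with SimpleProfiler.profile(SimpleProfiler.Type.H2D):
        _ = torch.randn(256, 256) @ torch.randn(256, 256)
    # Rank 0 logs the accumulated results (only at the INFO debug level), and
    # every rank then clears the profiler state.
    SimpleProfiler.dump_and_reset("SimpleProfiler example: ")

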
def _get_sharded_module_tree_with_module_name_to_fqns(
    model: torch.nn.Module,
) -> Tuple[str, Dict[str, List[str]]]:
    """
    This is used for the composable fully_shard() code path. It returns
      1. sharded module tree info: each line represents a submodule and contains the
    submodule's FQN and its class name. If the submodule is sharded by `fully_shard`,
    the submodule name is suffixed with ' FULLY SHARDED'. Each increase in tree
    level adds 4 spaces of indentation before the printed name. The printed sharded
    module tree for a toy model looks like this:
        [CompositeModel] FULLY SHARDED
            l1[Linear]
            u1[UnitModule] FULLY SHARDED
                u1.l1[Linear]
                u1.seq[Sequential]
                    u1.seq.0[ReLU]
                    u1.seq.1[Linear]
                    u1.seq.2[ReLU]
                u1.l2[Linear]
            u2[UnitModule] FULLY SHARDED
                u2.l1[Linear]
                u2.seq[Sequential]
                    u2.seq.0[ReLU]
                    u2.seq.1[Linear]
                    u2.seq.2[ReLU]
                u2.l2[Linear]
            l2[Linear]
      2. a dict mapping from the concatenated module FQN and class name to a list of its
    managed original parameters' FQNs. The dict for the above toy sharded model looks like this:
            {'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
             'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
             'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias']
            }
    All FQNs are prefixed starting from ``model``.

    Args:
        model (torch.nn.Module): Root module (which may or may not be passed to
                                 the composable `fully_shard()`).
    """

    def module_fn(
        module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
    ):
        num_spaces = tree_level * 4
        trimed_prefix = (
            prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix
        )
        prefixed_module_name = trimed_prefix + "[" + module.__class__.__name__ + "]"
        printed_prefixed_module_name = " " * num_spaces + prefixed_module_name

        state = _get_module_fsdp_state(module)
        if state is None:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"
            return

        handle = state._fully_sharded_module_to_handle.get(module, None)

        if handle:
            sharded_tree_info[0] += (
                printed_prefixed_module_name + " FULLY SHARDED" + "\n"
            )
        else:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"

        if handle:
            param = handle.flat_param
            assert isinstance(param, flat_param_file.FlatParameter)
            global_fqns = [
                clean_tensor_name(prefix + name) for name in param._fqns
            ]  # prefixed from the top level `model` (i.e. including `prefix`)

            if prefixed_module_name in sharded_module_name_to_fqns:
                sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns)
            else:
                sharded_module_name_to_fqns[prefixed_module_name] = global_fqns

    def return_fn(sharded_tree_info, sharded_module_name_to_fqns):
        return sharded_tree_info[0], sharded_module_name_to_fqns

    # Use a list to mutate its value in place while running the recursive functions
    sharded_tree_info: List[str] = [
        "",
    ]
    sharded_module_name_to_fqns: Dict[str, List[str]] = {}
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        [key for key, _ in model.named_parameters()],
        sharded_tree_info,
        sharded_module_name_to_fqns,
    )

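
# A usage sketch for the helper above (an illustrative addition, not part of
# the original module). It assumes the process group is initialized and that
# ``model`` has been wrapped with the composable ``fully_shard`` API.
def _example_sharded_module_tree_usage(model: torch.nn.Module) -> None:
    tree_str, module_name_to_fqns = _get_sharded_module_tree_with_module_name_to_fqns(
        model
    )
    # ``tree_str`` is the indented tree described in the docstring; submodules
    # managed by ``fully_shard`` are suffixed with " FULLY SHARDED".
    print(tree_str)
    # ``module_name_to_fqns`` maps "<fqn>[<ClassName>]" to the FQNs of the
    # original parameters managed by that sharded submodule.
    for module_name, fqns in module_name_to_fqns.items():
        print(module_name, fqns)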