# mypy: allow-untyped-defs
import logging
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Iterator, List, Set, Tuple

import torch
import torch.distributed as dist
import torch.distributed.fsdp._flat_param as flat_param_file
from torch.distributed.fsdp._common_utils import (
    _apply_to_modules,
    _get_module_fsdp_state,
    clean_tensor_name,
)


logger = logging.getLogger(__name__)


class SimpleProfiler:
    """Class-level profiler that accumulates wall-clock time per profiling type."""

    class Type(str, Enum):
        ALL = "all"
        ALLGATHER = "all_gather"
        ALLGATHER_OBJ = "all_gather_object"
        RESHARDING = "resharding"
        H2D = "H2D"
        D2H = "D2H"

    results: Dict[str, float] = defaultdict(float)
    profiling: Set[str] = set()

    @classmethod
    def reset(cls) -> None:
        cls.results.clear()
        cls.profiling.clear()

    @classmethod
    @contextmanager
    def profile(cls, profile_type: str) -> Iterator[None]:
        assert profile_type not in cls.profiling, (
            f"{profile_type} is already being profiled. "
            "SimpleProfiler does not support profiling the same type "
            "concurrently."
        )

        cls.profiling.add(profile_type)
        begin = time.monotonic()
        try:
            yield
        finally:
            end = time.monotonic()
            cls.results[profile_type] += end - begin
            cls.profiling.remove(profile_type)

    @classmethod
    def dump_and_reset(cls, msg: str) -> None:
        # Do not combine this with the DETAIL distributed debug level:
        # the extra logging makes the measured timings very inaccurate.
        if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO:
            logger.info("%s %s", msg, cls.results)
        cls.reset()


def _get_sharded_module_tree_with_module_name_to_fqns(
    model: torch.nn.Module,
) -> Tuple[str, Dict[str, List[str]]]:
    """
    Return debugging info for the composable ``fully_shard()`` code path:

    1. The sharded module tree info: each line contains a submodule's FQN and its
       class name; if the submodule is sharded by ``fully_shard``, the line is
       suffixed with ' FULLY SHARDED'. Each increased tree level adds 4 spaces of
       indentation before the printed name. The printed sharded module tree info
       for a toy model looks like this:

        [CompositeModel] FULLY SHARDED
            l1[Linear]
            u1[UnitModule] FULLY SHARDED
                u1.l1[Linear]
                u1.seq[Sequential]
                    u1.seq.0[ReLU]
                    u1.seq.1[Linear]
                    u1.seq.2[ReLU]
                u1.l2[Linear]
            u2[UnitModule] FULLY SHARDED
                u2.l1[Linear]
                u2.seq[Sequential]
                    u2.seq.0[ReLU]
                    u2.seq.1[Linear]
                    u2.seq.2[ReLU]
                u2.l2[Linear]
            l2[Linear]

    2. A dict mapping from each sharded module's concatenated FQN and class name to
       the list of its managed original parameters' FQNs. For the toy sharded model
       above, the dict looks like this:

        {
            '[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
            'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
            'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias'],
        }

    All FQNs are prefixed starting from ``model``.

    Args:
        model (torch.nn.Module): Root module (which may or may not be passed to
            composable ``fully_shard()``).
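
    Example (illustrative only; assumes ``model`` has already been sharded with
    the composable ``fully_shard()`` API before calling this helper)::

        >>> tree_str, module_name_to_fqns = (
        ...     _get_sharded_module_tree_with_module_name_to_fqns(model)
        ... )
        >>> print(tree_str)
        >>> for module_name, fqns in module_name_to_fqns.items():
        ...     print(module_name, fqns)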
104 """ 105 106 def module_fn( 107 module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns 108 ): 109 num_spaces = tree_level * 4 110 trimed_prefix = ( 111 prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix 112 ) 113 prefixed_module_name = trimed_prefix + "[" + module.__class__.__name__ + "]" 114 printed_prefixed_module_name = " " * num_spaces + prefixed_module_name 115 116 state = _get_module_fsdp_state(module) 117 if state is None: 118 sharded_tree_info[0] += printed_prefixed_module_name + "\n" 119 return 120 121 handle = state._fully_sharded_module_to_handle.get(module, None) 122 123 if handle: 124 sharded_tree_info[0] += ( 125 printed_prefixed_module_name + " FULLY SHARDED" + "\n" 126 ) 127 else: 128 sharded_tree_info[0] += printed_prefixed_module_name + "\n" 129 130 if handle: 131 param = handle.flat_param 132 assert isinstance(param, flat_param_file.FlatParameter) 133 global_fqns = [ 134 clean_tensor_name(prefix + name) for name in param._fqns 135 ] # prefixed from the top level `model` (i.e. including `prefix`) 136 137 if prefixed_module_name in sharded_module_name_to_fqns: 138 sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns) 139 else: 140 sharded_module_name_to_fqns[prefixed_module_name] = global_fqns 141 142 def return_fn(sharded_tree_info, sharded_module_name_to_fqns): 143 return sharded_tree_info[0], sharded_module_name_to_fqns 144 145 # Use List to mutate its value in place while running the recursive functions 146 sharded_tree_info: List[str] = [ 147 "", 148 ] 149 sharded_module_name_to_fqns: Dict[str, List[str]] = {} 150 return _apply_to_modules( 151 model, 152 module_fn, 153 return_fn, 154 [key for key, _ in model.named_parameters()], 155 sharded_tree_info, 156 sharded_module_name_to_fqns, 157 ) 158