# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import copy
from abc import ABC, abstractmethod
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterable,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

import torch.nn as nn


__all__ = [
    "always_wrap_policy",
    "lambda_auto_wrap_policy",
    "transformer_auto_wrap_policy",
    "size_based_auto_wrap_policy",
    "enable_wrap",
    "wrap",
    "CustomPolicy",
    "ModuleWrapPolicy",
]


# NOTE: We intentionally keep this function simple and isolate the complexity
# to `fn` to enable using this function generically. We may move this to a
# non-FSDP-specific folder and/or make it public in the future.
def _post_order_apply(
    root_module: nn.Module,
    fn: Callable[[nn.Module], Optional[nn.Module]],
):
    """
    This applies ``fn`` to every module in the module tree of ``root_module``
    following a post-order traversal. If ``fn`` returns an :class:`nn.Module`,
    then this replaces the original module with the newly returned one in the
    tree. Otherwise, ``fn`` should return ``None``, in which case the module is
    not changed.
    """
    # Track visited modules to avoid visiting shared modules multiple times
    visited_modules: Set[nn.Module] = {root_module}

    def _post_order_apply_inner(
        module: nn.Module,
        module_name: str,
        parent_module: Optional[nn.Module],
    ):
        # Children first (post-order), so that ``fn`` sees a module only after
        # its whole subtree has been processed
        for child_module_name, child_module in module.named_children():
            if child_module not in visited_modules:
                visited_modules.add(child_module)
                _post_order_apply_inner(child_module, child_module_name, module)
        optional_module = fn(module)
        if optional_module is not None:
            # A replacement can only be installed through the parent, so the
            # root module (no parent, empty name) cannot be replaced here
            assert isinstance(parent_module, nn.Module), (
                "Non-root modules should have their parent module set but got "
                f"{parent_module} for {module}"
            )
            assert module_name, (
                "Non-root modules should have their module name set but got "
                f"an empty module name for {module}"
            )
            assert isinstance(
                optional_module, nn.Module
            ), f"fn should return None or an nn.Module but got {optional_module}"
            setattr(parent_module, module_name, optional_module)

    _post_order_apply_inner(root_module, "", None)


def _construct_wrap_fn(
    root_module: nn.Module,
    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
    fsdp_fn: Callable,
) -> Callable[[nn.Module], Optional[nn.Module]]:
    """
    This constructs the "wrap" function to pass to :func:`_post_order_apply`
    based on ``target_module_to_kwargs``, which should be constructed from the
    wrapping policy.
    """

    def fn(module: nn.Module) -> Optional[nn.Module]:
        # Explicitly avoid wrapping the root module since for FSDP, it is
        # handled by the caller
        if module in target_module_to_kwargs and module is not root_module:
            kwargs = target_module_to_kwargs[module]
            return fsdp_fn(module, **kwargs)
        return None

    return fn


def _run_mixed_precision_override_policy(
    root_module: nn.Module,
    module_classes: Iterable[Type[nn.Module]],
    ignored_modules: Set[nn.Module],
    root_kwargs: Dict[str, Any],
    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
):
    """
    This forces ``mixed_precision=None`` in the kwargs of every non-ignored
    module in ``root_module``'s tree that is an instance of any class in
    ``module_classes``.

    Args:
        root_module (nn.Module): Root of the module tree to scan.
        module_classes (Iterable[Type[nn.Module]]): Module classes whose
            instances get the override.
        ignored_modules (Set[nn.Module]): Modules to skip.
        root_kwargs (Dict[str, Any]): Kwargs used as the starting point for a
            matching module that has no entry yet. This dict is not modified.
        target_module_to_kwargs (Dict[nn.Module, Dict[str, Any]]): Existing
            module-to-kwargs mapping; it is updated in place.

    Returns:
        The same ``target_module_to_kwargs`` dict that was passed in.
    """
    module_classes_tuple = tuple(set(module_classes))
    for module in root_module.modules():
        if module in ignored_modules:
            continue
        elif isinstance(module, module_classes_tuple):
            # This policy overrides any existing policy
            if module not in target_module_to_kwargs:
                # Only inherit from the root kwargs if not already specified.
                # Shallow copy so that the override below neither mutates the
                # caller's ``root_kwargs`` nor couples the targeted modules to
                # one shared dict.
                target_module_to_kwargs[module] = copy.copy(root_kwargs)
            target_module_to_kwargs[module]["mixed_precision"] = None
    return target_module_to_kwargs


def always_wrap_policy(*args, **kwargs) -> bool:
    """
    A simple recursive wrap policy that always returns ``True``. This means
    that every submodule is wrapped by the wrapper class in
    :func:`_recursive_wrap`.
    """
    return True


class _Policy(ABC):
    """
    This defines an abstract base class that represents a policy for applying
    a module-level API.
    """

    @abstractmethod
    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        """
        This should return a dict ``target_module_to_kwargs`` that maps from
        each target module to wrap to its kwargs.
        """
        ...


def _module_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    module_classes: Set[Type[nn.Module]],
) -> bool:
    """
    This auto wrap policy wraps every module that is an instance of any type in
    ``module_classes`` as its own FSDP instance. The root module given by
    ``module`` is always wrapped as an FSDP instance regardless. Since the
    wrapping proceeds bottom up, each FSDP instance manages the parameters in
    its subtree excluding any already managed by a child FSDP instance.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``False``, then this function must decide whether
            ``module`` should be wrapped as an FSDP instance or not. If
            ``True``, then the function is still recursing down the module
            tree as a part of the DFS.
        nonwrapped_numel (int): Parameter numel not yet wrapped.
        module_classes (Set[Type[nn.Module]]): Set of module classes that are
            wrapped as FSDP instances.

    Returns:
        ``True`` if ``recurse=True``, and whether ``module`` should be wrapped
        if ``recurse=False``.
    """
    if recurse:
        return True  # always recurse
    return isinstance(module, tuple(module_classes))


class ModuleWrapPolicy(_Policy):
    """
    This policy applies to every module of the specified module classes,
    passing in the kwargs given to the root.
    """

    def __init__(self, module_classes: Iterable[Type[nn.Module]]):
        module_classes_set = set(module_classes)
        self._module_classes = module_classes_set
        self._module_classes_str = str(module_classes_set)

    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        module_classes = tuple(self._module_classes)
        target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
        for module in root_module.modules():
            if module in ignored_modules:
                continue
            elif isinstance(module, module_classes):
                # Shallow copy to avoid coupling changes across modules
                target_module_to_kwargs[module] = copy.copy(root_kwargs)
        return target_module_to_kwargs

    def __call__(self, module, recurse, *args, **kwargs):
        # nonwrapped_numel is not used.
        return _module_wrap_policy(
            module, recurse, nonwrapped_numel=-1, module_classes=self._module_classes
        )

    def __repr__(self) -> str:
        return super().__repr__() + f"({self._module_classes_str})"


class CustomPolicy(_Policy):
    """
    This policy takes in a lambda function that maps a given ``nn.Module`` to
    either ``False``, ``True``, or a kwarg dictionary.
    - If the function returns ``False`` or an empty dictionary, then the module
      does not have the API applied.
    - If the function returns ``True``, then the module has the API applied
      with the root's kwargs.
    - If the function returns a non-empty dictionary, then the module has the
      API applied, and the dictionary overrides the root's kwargs.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = init_transformer_model(...)
        >>> def lambda_fn(module: nn.Module):
        >>>     if module is model.lm_head:
        >>>         return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
        >>>     elif isinstance(module, TransformerBlock):
        >>>         return True
        >>>     return False
        >>> policy = CustomPolicy(lambda_fn)
        >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)
    """

    def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]):
        self._lambda_fn = lambda_fn

    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
        for module in root_module.modules():
            if module in ignored_modules:
                continue
            res = self._lambda_fn(module)
            if not isinstance(res, (dict, bool)):
                raise ValueError(
                    "The lambda_fn passed to CustomPolicy should return "
                    f"False/True or a kwarg dict, but it returned {res}"
                )
            # ``False`` and an empty dict both mean "do not apply"
            if not res:
                continue
            kwargs = copy.copy(root_kwargs)
            if isinstance(res, dict):
                # Override the root kwargs with the ones specified by the
                # lambda function
                kwargs.update(res)
            target_module_to_kwargs[module] = kwargs
        return target_module_to_kwargs


def lambda_auto_wrap_policy(
    module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
) -> bool:
    """
    A convenient auto wrap policy to wrap submodules based on an arbitrary user
    function. If ``lambda_fn(submodule) == True``, the submodule will be wrapped
    as a ``wrapper_cls`` unit.

    Return if a module should be wrapped during auto wrapping.

    The first three parameters are required by :func:`_recursive_wrap`.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``False``, then this function must decide whether
            ``module`` should be wrapped as an FSDP instance or not. If
            ``True``, then the function is still recursing down the module
            tree as a part of the DFS.
        nonwrapped_numel (int): Parameter numel not yet wrapped.

        lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
            this module will be wrapped.
    """
    if recurse:
        return True  # always recurse
    return lambda_fn(module)


def transformer_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    transformer_layer_cls: Set[Type[nn.Module]],
) -> bool:
    """
    See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the
    same as ``module_classes``. Note that shared parameters must be wrapped in
    the same FSDP instance, so this auto wrap policy can help wrap shared
    embeddings into the same FSDP instance for transformer models.
    """
    return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls)


def _wrap_module_cls_individually(
    module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs
):
    """
    A policy that always recurses and, when asked to decide, wraps ``module``
    iff it is an instance of any type in ``module_classes``. Extra positional
    and keyword arguments are accepted and ignored.
    """
    if recurse:
        # always recurse
        return True
    else:
        # if not recursing, decide whether we should wrap based on whether the type of module
        # is in `module_classes`.
        return isinstance(module, tuple(module_classes))


def _or_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    policies,
) -> bool:
    """
    A policy that wraps ``module`` if any policy in the passed in iterable of
    ``policies`` returns ``True``.
    """
    return any(
        policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel)
        for policy in policies
    )


def size_based_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    # Additional custom arguments
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
) -> bool:
    """
    A size-based auto wrap policy.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``False``, then this function must decide whether
            ``module`` should be wrapped as an FSDP instance or not. If
            ``True``, then the function is still recursing down the module
            tree as a part of the DFS.
        nonwrapped_numel (int): Parameter numel not yet wrapped.

        min_num_params (int): Customizable policy input that controls the size
            threshold over which a module is ready to be wrapped. This is in
            units of numel.
        force_leaf_modules (Set[Type[nn.Module]]): Set of module types to keep
            as leaves, i.e. their children will never be wrapped. Defaults to
            ``size_based_auto_wrap_policy.FORCE_LEAF_MODULES`` when ``None``.
        exclude_wrap_modules (Set[Type[nn.Module]]): Set of module types to be
            excluded in wrapping. Defaults to
            ``size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES`` when ``None``.

    Returns:
        Whether ``module`` should be wrapped.
    """
    force_leaf_modules = (
        size_based_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore[attr-defined]
        if force_leaf_modules is None
        else force_leaf_modules
    )
    exclude_wrap_modules = (
        size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore[attr-defined]
        if exclude_wrap_modules is None
        else exclude_wrap_modules
    )

    # Keep the argument `min_num_params` for BC for now, but it represents the
    # minimum non-wrapped *numel* before triggering a wrapping
    min_nonwrapped_numel = min_num_params
    is_large = nonwrapped_numel >= min_nonwrapped_numel
    if recurse:
        # We should recurse if the module is big enough but not in force_leaf_modules list.
        return is_large and not isinstance(module, tuple(force_leaf_modules))
    else:
        # If we are not recursing, determine if we should wrap.
        return is_large and not isinstance(module, tuple(exclude_wrap_modules))


# Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported.
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}  # type: ignore[attr-defined]
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore[attr-defined]


@contextlib.contextmanager
def enable_wrap(
    *, wrapper_cls: Any, **wrapper_kwargs: Any
) -> Generator[None, None, None]:
    """
    Context manager to wrap modules using a wrapper.

    Useful for when you'd like to apply the same configuration arguments to all
    child modules that you wrap. A particularly important use case is wrapping
    large layers so that they get sharded (in-place) during initialization, to
    avoid running out of system memory. Large layers can indicate that they
    should be sharded via the ``wrap`` annotation and this context manager can
    provide the exact configuration for these nested instances.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        wrapper_cls:
            Class that `wrap` annotation will `wrap` modules with, such as
            `FullyShardedDataParallel`. Must be passed by keyword.
        **wrapper_kwargs:
            Configuration settings that will be passed to all ``wrap``
            instances inside the context
    """
    kwargs = {
        "wrapper_cls": wrapper_cls,
        **wrapper_kwargs,
    }
    with _ConfigAutoWrap(**kwargs):
        yield


def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
    """
    Annotate that a module should be wrapped. Annotated modules will only be
    wrapped if inside of an :func:`enable_wrap` context manager. This allows
    a module to be initialized both with and without a wrapper without code
    change.

    The class that this function wraps the passed in ``nn.Module`` with is the
    passed in ``wrapper_cls`` argument into ``enable_wrap``. Both
    ``enable_wrap`` and ``wrap`` can take in kwargs specifying how to construct
    the ``wrapper_cls`` instance. In the case of duplicate kwargs in
    ``enable_wrap`` and ``wrap``, the argument passed into ``wrap`` will be
    respected.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        **wrap_overrides: configuration overrides that will take priority over
            the values provided by the :func:`enable_wrap` context
    """
    if _ConfigAutoWrap.in_autowrap_context:
        assert _ConfigAutoWrap.wrapper_cls is not None

        wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
        return _wrap(
            module,
            _ConfigAutoWrap.wrapper_cls,
            **wrap_overrides,
        )
    return module


def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
    """
    Return ``wrapper_cls(module, **kwargs)``. If ``module`` has a
    ``_wrap_overrides`` attribute, its entries take precedence over ``kwargs``.
    """
    assert wrapper_cls is not None
    if hasattr(module, "_wrap_overrides"):
        # If module has a _wrap_overrides attribute, we force overriding the
        # FSDP config with these attributes for this module. Currently this
        # is only used to disable mixed precision for BatchNorm when
        # auto_wrapping.
        overrides = {**kwargs, **module._wrap_overrides}  # type: ignore[arg-type]
        return wrapper_cls(module, **overrides)

    return wrapper_cls(module, **kwargs)


def _recursive_wrap(
    module: nn.Module,
    auto_wrap_policy: Callable,
    wrapper_cls: Callable,
    ignored_modules: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    only_wrap_children: bool = False,
    **kwargs: Any,
) -> Tuple[nn.Module, int]:
    """
    Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns
    ``True`` with ``wrapper_cls``.

    Args:
        module (nn.Module): Module to recursively wrap.
        auto_wrap_policy (Callable): A callable representing a policy that
            determines which modules to recursively wrap with ``wrapper_cls``.
        wrapper_cls (Callable): Class or function used to wrap a module; it is
            called as ``wrapper_cls(module, **kwargs)``.
        ignored_modules (Set[torch.nn.Module]): Modules to ignore when
            wrapping.
        ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
            wrapping; these should be the parameters contained in the modules
            in ``ignored_modules``.
        only_wrap_children (bool): If ``True``, ``module`` itself is never
            wrapped by this call; only its descendants may be. The flag is
            not forwarded to the recursive calls on children.
        **kwargs: Forwarded to ``wrapper_cls`` for every wrapped module.
    Returns:
        (nn.Module, int):
            ``module`` after wrapping and the numel recursively wrapped.
    """
    assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
    assert wrapper_cls is not None, "Must specify wrapper_cls"
    # Make sure no child is already wrapped.
    for child in module.modules():
        if child in ignored_modules:
            continue
        try:
            assert not isinstance(child, cast(type, wrapper_cls))
        except TypeError:
            # wrapper_cls is a function as opposed to a class type, just bypass above check.
            pass

    # We count all params, assuming none of them are already wrapped.
    nonwrapped_numel = sum(
        p.numel() for p in module.parameters() if p not in ignored_params
    )

    if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
        total_wrapped_numel = 0
        # Iterate through the children, recursively wrap if necessary
        for name, child in module.named_children():
            if child in ignored_modules:
                continue
            wrapped_child, num_wrapped_params = _recursive_wrap(
                module=child,
                auto_wrap_policy=auto_wrap_policy,
                wrapper_cls=wrapper_cls,
                ignored_modules=ignored_modules,
                ignored_params=ignored_params,
                **kwargs,
            )
            setattr(module, name, wrapped_child)
            # Keep track of how many parameters have been wrapped
            total_wrapped_numel += num_wrapped_params
        # decide if we need to wrap the current module,
        # since the left over parameters exceed the number of params to wrap
        remainder = nonwrapped_numel - total_wrapped_numel
        if not only_wrap_children and auto_wrap_policy(
            module=module, recurse=False, nonwrapped_numel=remainder
        ):
            # Leaf node or final wrapping of the remainder both happen here.
            return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
        else:
            return module, total_wrapped_numel
    return module, 0


class _ConfigAutoWrap:
    """
    Helper class to wrap modules based on default config args via a context manager.
    See :func:`enable_wrap` for more information.

    The active configuration is stored on the class itself (not on instances),
    which is why nested contexts are rejected.
    """

    in_autowrap_context: bool = False  # Context flag
    wrapper_cls: Optional[Callable] = None  # The wrapper class
    kwargs: Dict[str, Any] = {}  # Wrapper's args

    def __init__(self, **kwargs: Any):
        # Must contain "wrapper_cls"; this is checked on context entry
        self.kwargs = kwargs

    @staticmethod
    def enable_autowrap_context(kwargs: Any) -> None:
        """
        Activate the class-level autowrap configuration from ``kwargs``.

        ``kwargs`` must contain ``"wrapper_cls"``; the remaining entries become
        the default wrapper kwargs. ``kwargs`` itself is not modified.

        Raises:
            NotImplementedError: If an autowrap context is already active.
            AssertionError: If ``"wrapper_cls"`` is missing from ``kwargs``.
        """
        if _ConfigAutoWrap.in_autowrap_context:
            raise NotImplementedError(
                "You are already within an autowrap context and we currently do not supported nested autowrap."
            )
        # Validate before touching any class-level state: if this raised after
        # the flag was set, ``__exit__`` would never run and the flag would
        # stay ``True``, breaking every later ``enable_wrap``.
        assert (
            "wrapper_cls" in kwargs
        ), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
        # Work on a copy so that the caller's dict (the instance's own
        # ``self.kwargs``) keeps ``wrapper_cls`` and the context manager can
        # be entered again.
        wrapper_kwargs = dict(kwargs)
        wrapper_cls = wrapper_kwargs.pop("wrapper_cls")
        _ConfigAutoWrap.in_autowrap_context = True
        # Save the wrapper cls for the context.
        _ConfigAutoWrap.wrapper_cls = cast(Callable, wrapper_cls)
        # Save the rest.
        _ConfigAutoWrap.kwargs = wrapper_kwargs

    @staticmethod
    def disable_autowrap_context() -> None:
        """Reset the class-level autowrap configuration to its inactive state."""
        _ConfigAutoWrap.in_autowrap_context = False
        _ConfigAutoWrap.wrapper_cls = None
        _ConfigAutoWrap.kwargs = {}

    def __enter__(self) -> None:
        self.enable_autowrap_context(self.kwargs)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.disable_autowrap_context()