# Copyright (c) Meta Platforms, Inc. and affiliates
from fnmatch import fnmatch
from typing import Dict, Union

import torch
import torch.distributed.tensor._random as random
import torch.nn as nn
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor._random import (
    is_rng_supported_mesh,
    TensorParallelRNGTracker,
)
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import ParallelStyle


__all__ = [
    "parallelize_module",
]


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
) -> nn.Module:
    """
    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

    We parallelize the module or its sub-modules based on a parallelize_plan. The parallelize_plan
    contains :class:`ParallelStyle`, which indicates how the user wants the module or sub-module
    to be parallelized.

    Users can also specify different parallel styles per module fully qualified name (FQN).

    Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D or
    N-D :class:`DeviceMesh`, slice it to a 1-D sub DeviceMesh first and then pass that to this API
    (i.e. ``device_mesh["tp"]``).

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object, which describes how
            we prepare input/output for Tensor Parallelism, or a
            dict mapping module FQNs to their corresponding :class:`ParallelStyle` objects.
    Returns:
        The parallelized :class:`nn.Module` object.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
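        >>>
        >>> # A dict plan can also target nested sub-modules by FQN, and each path
        >>> # segment may use fnmatch-style wildcards. This sketch assumes a
        >>> # hypothetical model whose transformer blocks live under `m.layers`
        >>> # and expose `w1`/`w2` linear sub-layers.
        >>> m = parallelize_module(
        ...     m,
        ...     tp_mesh,
        ...     {"layers.*.w1": ColwiseParallel(), "layers.*.w2": RowwiseParallel()},
        ... )
        >>>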

    .. note:: For complex module architectures such as Attention and MLP layers, we recommend
        composing different ParallelStyles together (e.g. ``ColwiseParallel`` and ``RowwiseParallel``)
        and passing them as a parallelize_plan to achieve the desired sharding computation.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    _validate_tp_mesh_dim(device_mesh)

    # instantiate a TP RNG state tracker if it's not there
    if is_rng_supported_mesh(device_mesh) and not isinstance(
        random._rng_tracker, TensorParallelRNGTracker
    ):
        random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
        # TODO: we should allow user to pass in the default seed from a config
        random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
        # By default we execute random ops in the non-tensor-parallel region. If users want
        # to execute them in the tensor-parallel region, they can manually set this field to
        # True after parallelizing the model.
        random._rng_tracker.distribute_region_enabled = False

    if isinstance(parallelize_plan, ParallelStyle):
        # a single style is applied to the module directly
        return parallelize_plan._apply(module, device_mesh)
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            if not module_path:
                raise ValueError(
                    "Expect module path to be non-empty, but got empty string!"
                )
            path_splits = module_path.split(".")
            # walk the FQN one atom at a time; each atom may contain fnmatch-style
            # wildcards and is matched against the direct children of `module`
            while path_splits:
                atom = path_splits.pop(0)
                matched_children = filter(
                    # `t[0]` is child name
                    lambda t: fnmatch(t[0], atom),
                    module.named_children(),
                )
                # apply the plan to all matched submodules
                for _, submodule in matched_children:
                    if path_splits:
                        # we haven't reached the leaf, apply in dict style
                        leaf_path = ".".join(
                            path_splits
                        )  # rest of the path after `atom`
                        parallelize_module(
                            submodule, device_mesh, {leaf_path: parallelize_style}
                        )
                    else:
                        # otherwise, directly apply style to this submodule
                        parallelize_module(submodule, device_mesh, parallelize_style)
        return module
    else:
        raise TypeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )