# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.fx.node import map_aggregate
from torch.utils._pytree import tree_flatten, tree_unflatten


__all__ = [
    "TensorChunkSpec",
    "split_args_kwargs_into_chunks",
    "merge_chunks",
]

logger = logging.getLogger(__name__)

"""
_debug_mask_minibatches specifies whether to send masked versions of the full
mini-batch through instead of micro-batch slices -- this can be used for more
stable numerical testing (see [A Note About Correctness Testing])
"""
_debug_mask_minibatches = False


class _CustomReducer:
    """
    Custom reducer class that can be used to specify a custom operation that
    reduces losses of multiple microbatches into one value.

    Example:
    >>> # xdoctest: +SKIP
    >>> sum_reducer = _CustomReducer(
    >>>     torch.tensor(0.0),
    >>>     lambda a, b: a + b
    >>> )
    """

    def __init__(self, init_value, reduce_fn):
        self.init_value = init_value
        self.reduce_fn = reduce_fn


class _LossReducer(_CustomReducer):
    pass


sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b)

# Default chunking dimension is 0. This is used for the case where the user did
# not specify a chunking dimension.
DEFAULT_CHUNK_DIM = 0


class TensorChunkSpec:
    """
    Class used to specify chunking of inputs.
    """

    def __init__(self, split_dim):
        self.split_dim = split_dim

    split_dim: int

    def __repr__(self):
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
        )

    def __str__(self):
        return f"TensorChunkSpec({self.split_dim})"

    @staticmethod
    def from_tuple(
        chunk_dims: Tuple[int, ...],
    ):
        """
        A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
        dimensions (ints).
        Example:
        >>> # xdoctest: +SKIP
        >>> # There are three positional arguments to the model, and
        >>> # we are chunking them along dimensions 0, 0 and 1, respectively
        >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
        """
        args_chunk_spec = map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
        )
        return args_chunk_spec

    @staticmethod
    def from_dict(
        chunk_dims: Dict[str, int],
    ):
        """
        A helper for creating a dictionary of `TensorChunkSpec` from a
        dictionary of chunk dimensions (ints).
        Example:
        >>> # xdoctest: +SKIP
        >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
        >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
        """
        kwargs_chunk_spec = map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
        )
        return kwargs_chunk_spec


# Class used to specify replication of inputs
class _Replicate:
    pass
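

# Illustrative sketch only (not part of this module's API): how the spec types
# above are typically combined for a model invocation. The argument names
# ("mask", "temperature") are hypothetical and used purely for illustration.
def _example_chunk_specs():
    # Chunk both positional tensor args along dim 0 (the batch dimension).
    args_chunk_spec = TensorChunkSpec.from_tuple((0, 0))
    # Chunk the "mask" kwarg along dim 1, and replicate the non-batch value
    # "temperature" into every microbatch.
    kwargs_chunk_spec = {
        "mask": TensorChunkSpec(1),
        "temperature": _Replicate,
    }
    return args_chunk_spec, kwargs_chunk_spec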


def _shard_dict_of_args(
    args_dict,
    args_chunk_spec,
    num_chunks,
):
    """
    Given a dictionary of args and a dictionary of chunking specs, shard the
    args according to the chunking specs.

    Args:
        args_dict: Dictionary of args
        args_chunk_spec: Dictionary of chunking specs
        num_chunks: Number of chunks to shard the args into

    Returns:
        args_split: List of sharded args
    """
    # Stage 1+2: flatten and shard/replicate

    # args_sharded_replicated : [num args, num flat values, num chunks]
    args_sharded_replicated = {}
    arg_specs = []

    real_num_chunks = num_chunks
    first_tensor = True

    assert len(args_dict) == len(
        args_chunk_spec
    ), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"

    for arg_key, arg in args_dict.items():
        flat, spec = tree_flatten(arg)
        arg_specs.append(spec)

        chunk_spec = args_chunk_spec[arg_key]
        assert chunk_spec is not None  # Should have been set by caller
        chunk_spec_flat, _ = tree_flatten(chunk_spec)
        if len(flat) != len(chunk_spec_flat):
            raise ValueError(
                f"Argument value {arg} did not have the same number of "
                f"values as chunk spec {chunk_spec}"
            )

        sharded_arg_flat = []

        for v, chunk_v in zip(flat, chunk_spec_flat):
            if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
                sharded_arg_flat.append([v] * real_num_chunks)
            elif isinstance(chunk_v, TensorChunkSpec):
                # TODO: check type of v. If it's a tensor, use chunk (or debug mask).
                # If it's a collection type, split it as you would expect. Otherwise,
                # throw an error.
                assert isinstance(v, torch.Tensor), f"{v} is not a tensor"

                v_split_dim_size = v.size(chunk_v.split_dim)
                if v_split_dim_size < real_num_chunks:
                    if first_tensor:
                        # We can only adjust the number of chunks when we hit this
                        # issue at the first tensor encountered
                        logger.warning(
                            f"Tensor size on chunking dimension is {v_split_dim_size}, "  # noqa: G004
                            f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}."
                        )
                        real_num_chunks = v_split_dim_size
                    else:
                        raise RuntimeError(
                            f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
                            f"smaller than the number of chunks {num_chunks}. "
                            "PiPPy cannot reduce the number of chunks because "
                            "other arguments have bigger chunk-dimension sizes. "
                            "Please adjust your num_chunks setting."
                        )

                chunk_tensors = torch.tensor_split(
                    v, real_num_chunks, chunk_v.split_dim
                )

                if _debug_mask_minibatches:
                    expanded_chunks = []

                    split_dim_idx = 0
                    for chunk_tensor in chunk_tensors:
                        new_val = torch.zeros_like(v)
                        upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim)

                        slice_indices = [slice(None, None, None)] * new_val.ndim
                        slice_indices[chunk_v.split_dim] = slice(
                            split_dim_idx, upper_idx
                        )
                        new_val[slice_indices] = chunk_tensor

                        expanded_chunks.append(new_val)

                        split_dim_idx += chunk_tensor.size(chunk_v.split_dim)

                    sharded_arg_flat.append(expanded_chunks)
                else:
                    sharded_arg_flat.append(chunk_tensors)  # type: ignore[arg-type]

                first_tensor = False
            else:
                raise TypeError(f"Unrecognized chunk spec: {chunk_v}")

        args_sharded_replicated[arg_key] = sharded_arg_flat

    # chunks_flat : [num chunks, num args, num flat values]
    chunks_flat = []
    for chunk_idx in range(real_num_chunks):
        chunk_args = {}
        for key, arg in args_sharded_replicated.items():
            arg_single_chunk = []
            for v_flat in arg:
                arg_single_chunk.append(v_flat[chunk_idx])
            chunk_args[key] = arg_single_chunk
        chunks_flat.append(chunk_args)

    # args_split : [num chunks, num args]
    args_split = []

    for chunk in chunks_flat:
        per_chunk_args = {}
        assert len(arg_specs) == len(chunk)
        for (key, arg), arg_spec in zip(chunk.items(), arg_specs):
            per_chunk_args[key] = tree_unflatten(arg, arg_spec)
        args_split.append(per_chunk_args)

    return args_split
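

# Illustrative sketch only (exercises the private helper above, so it is not a
# stable interface): sharding a two-entry args dict into two chunks. The keys
# "x" and "scale" are hypothetical. A (4, 3) tensor chunked along dim 0 yields
# two (2, 3) slices via `torch.tensor_split`; the non-tensor value is
# replicated into every chunk.
def _example_shard_dict_of_args():
    shards = _shard_dict_of_args(
        {"x": torch.arange(12).reshape(4, 3), "scale": 0.5},
        {"x": TensorChunkSpec(0), "scale": _Replicate},
        num_chunks=2,
    )
    # shards == [
    #     {"x": <rows 0-1, shape (2, 3)>, "scale": 0.5},
    #     {"x": <rows 2-3, shape (2, 3)>, "scale": 0.5},
    # ]
    return shards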


def split_args_kwargs_into_chunks(
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    chunks: int,
    args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
    kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
) -> Tuple[List[Tuple], List[Dict]]:
    """
    Given a sequence of args and kwargs, split them into a number of chunks
    according to their respective chunking specs.

    Args:
        args: Tuple of args
        kwargs: Dict of kwargs
        chunks: Number of chunks to split the args and kwargs into
        args_chunk_spec: chunking specs for args, in same shape as args
        kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs

    Returns:
        args_split: List of sharded args
        kwargs_split: List of sharded kwargs
    """
    # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
    # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
    # and `kwargs_chunk_spec` specifications. The steps are as follows:
    #
    # 1. Use pytree.tree_flatten to flatten each arg and its spec into a 1d array of values.
    #    To use a running example: suppose our inputs look like
    #
    #       args = ([A, [B, C]], D)     args_spec = ([None, [None, TensorChunkSpec]], None)
    #       (kwargs not shown but it's a similar process)
    #
    #    Then for this step we would end up with
    #
    #       args = ([A, B, C], D)       args_spec = ([None, None, TensorChunkSpec], None)
    #
    # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
    #
    #       args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
    #
    # 3. Rotate the nesting order such that chunks are the outer dimension
    #
    #       args_chunks = [
    #           ([A, B, C_1], D),
    #           ([A, B, C_2], D),
    #       ]
    #
    # 4. Unflatten each chunk according to the spec
    #
    #       args_chunks = [
    #           ([A, [B, C_1]], D),
    #           ([A, [B, C_2]], D),
    #       ]

    # TODO: _debug_mask_minibatches
    # Handle the case where kwargs is None
    if kwargs is None:
        kwargs = {}

    # If the user did not provide args_chunk_spec or kwargs_chunk_spec, we
    # construct them in the matching format with default chunking along dim 0
    if args_chunk_spec is None:
        args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)

    if kwargs_chunk_spec is None:
        kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))

    args_split_dict = _shard_dict_of_args(
        dict(enumerate(args)),
        dict(enumerate(args_chunk_spec)),
        chunks,
    )
    real_num_chunks = len(args_split_dict)

    kwargs_split = _shard_dict_of_args(
        kwargs,
        kwargs_chunk_spec,
        real_num_chunks,
    )

    if len(kwargs_split) < real_num_chunks:
        # In case kwargs are sharded into fewer chunks,
        # e.g. when `args` has no tensor, just values
        real_num_chunks = len(kwargs_split)
        # Re-shard args
        args_split_dict = _shard_dict_of_args(
            dict(enumerate(args)),
            dict(enumerate(args_chunk_spec)),
            real_num_chunks,
        )

    if len(args_split_dict) != len(kwargs_split):
        raise RuntimeError(
            "args and kwargs are split into different number of chunks: "
            f"{len(args_split_dict)}, {len(kwargs_split)}"
        )

    args_split = []
    for chunk_args in args_split_dict:
        args_split.append(tuple(chunk_args[i] for i in range(len(chunk_args))))

    return args_split, kwargs_split
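

# Illustrative sketch only (hypothetical argument names and shapes): splitting
# the inputs of a single forward call into 4 microbatches. Positional args fall
# back to the default spec (chunk along dim 0); the kwargs spec chunks "mask"
# along dim 1 and replicates the non-tensor "temperature".
def _example_split_args_kwargs():
    ids = torch.randint(0, 100, (8, 16))
    mask = torch.ones(16, 8)
    args_split, kwargs_split = split_args_kwargs_into_chunks(
        (ids,),
        {"mask": mask, "temperature": 0.7},
        chunks=4,
        kwargs_chunk_spec={"mask": TensorChunkSpec(1), "temperature": _Replicate},
    )
    # len(args_split) == len(kwargs_split) == 4
    # args_split[0][0].shape == (2, 16)         -- ids chunked along dim 0
    # kwargs_split[0]["mask"].shape == (16, 2)  -- mask chunked along dim 1
    # kwargs_split[0]["temperature"] == 0.7     -- replicated into every chunk
    return args_split, kwargs_split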


def merge_chunks(
    chunks: List[Any],
    chunk_spec,
):
    """
    Given a list of chunks, merge them into a single value according to
    the chunk spec.

    Args:
        chunks: list of chunks
        chunk_spec: Chunking spec for the chunks

    Returns:
        value: Merged value
    """
    # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
    # steps are similar to the steps in that function but in reverse. Given the
    # input values:
    #
    #       chunks = [
    #           ([A, [B, C_1]], D),
    #           ([A, [B, C_2]], D),
    #       ]
    #       args_spec = ([None, [None, TensorChunkSpec]], None)
    #
    # 1. Flatten the chunks according to the chunk_spec
    #
    #       chunks_flat = [
    #           ([A, B, C_1], D),
    #           ([A, B, C_2], D),
    #       ]
    #
    # 2. Rotate the nesting order such that chunks are the inner dimension
    #
    #       value_inner = ([A, B, [C_1, C_2]], D)
    #
    # 3. Concatenate sharded arguments
    #
    #       value_combined = ([A, B, C], D)
    #
    # 4. Unflatten the combined args given the spec
    #
    #       value = ([A, [B, C]], D)
    # Preliminary: flatten the chunk spec
    if chunk_spec is not None:
        spec_flattened, flatten_spec = tree_flatten(chunk_spec)
    else:
        # If chunk_spec is not provided, we will merge chunks along the default
        # dimension (0) for all output fields.
        # We obtain the output structure by flattening chunk 0 and use it to
        # generate the chunk_spec
        chunk0_flat, flatten_spec = tree_flatten(chunks[0])
        spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)

    # Stage 1: flatten chunks
    # chunks_flattened : [num chunks, num args]
    chunks_flattened = []

    for chunk in chunks:
        chunk_flattened, _ = tree_flatten(chunk)
        if len(chunk_flattened) != len(spec_flattened):
            raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")

        chunks_flattened.append(chunk_flattened)

    # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
    # concatenate sharded operands
    # args_flattened : [num args]
    args_flattened = []
    for arg_idx, arg in enumerate(spec_flattened):
        if isinstance(arg, TensorChunkSpec):
            partial_values = [
                chunks_flattened[chunk_idx][arg_idx]
                for chunk_idx in range(len(chunks_flattened))
            ]

            if _debug_mask_minibatches:
                # Infer size of individual chunks by running `tensor_split` again
                overall_shape = partial_values[0].shape
                for val in partial_values[1:]:
                    assert val.shape == overall_shape
                meta_chunks = torch.tensor_split(
                    torch.empty(*overall_shape, device="meta"),
                    sections=len(partial_values),
                    dim=arg.split_dim,
                )

                values_to_cat = []
                chunk_start_idx = 0
                assert len(partial_values) == len(meta_chunks)
                for partial_value, meta_chunk in zip(partial_values, meta_chunks):
                    chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)

                    slice_indices = [slice(None, None, None)] * partial_value.ndim
                    slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
                    sliced = partial_value[slice_indices]
                    values_to_cat.append(sliced)

                    chunk_start_idx = chunk_end_idx

            else:
                values_to_cat = partial_values

            args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
        elif isinstance(arg, _CustomReducer):
            reduced_val = arg.init_value

            for chunk_idx in range(len(chunks_flattened)):
                reduced_val = arg.reduce_fn(
                    reduced_val, chunks_flattened[chunk_idx][arg_idx]
                )

            args_flattened.append(reduced_val)
        else:
            value = chunks_flattened[0][arg_idx]
            for chunk_idx in range(1, len(chunks_flattened)):
                assert chunks_flattened[chunk_idx][arg_idx] == value
            args_flattened.append(value)

    # Stage 4: Unflatten combined args
    return tree_unflatten(args_flattened, flatten_spec)
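

# Illustrative sketch only (not part of this module's API): `merge_chunks` is
# the inverse of `split_args_kwargs_into_chunks`. Splitting an (8, 4) tensor
# into 4 microbatches along dim 0, transforming each chunk, and merging along
# the same dim recovers a tensor of the original shape. A `_CustomReducer` such
# as `sum_reducer` above could be used in the chunk spec instead to accumulate
# per-microbatch losses.
def _example_merge_chunks():
    x = torch.randn(8, 4)
    args_split, _ = split_args_kwargs_into_chunks((x,), None, chunks=4)
    # Stand-in for running each microbatch through a pipeline stage.
    outputs = [chunk[0] * 2 for chunk in args_split]
    merged = merge_chunks(outputs, TensorChunkSpec(0))
    # merged.shape == (8, 4) and torch.equal(merged, 2 * x)
    return merged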