# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from dataclasses import dataclass
from typing import List, Tuple, Union

import torch
from torch import fx


logger = logging.getLogger(__name__)


def flatten_args_detach(args):
    """
    Detach every tensor in ``args`` from its autograd graph and flatten the leaves.

    ``args`` may be an arbitrarily nested structure; it is traversed with
    ``torch.fx.node.map_aggregate``, which invokes the visitor on each leaf.
    Tensor leaves are replaced by ``a.detach()`` with the original
    ``requires_grad`` flag re-applied, so each replacement has no ``grad_fn``
    but requires grad exactly when the source tensor did. Non-tensor leaves
    are passed through unchanged.

    Returns:
        A tuple ``(new_args, flat_detached_args)``. ``new_args`` mirrors the
        nesting of ``args`` (as rebuilt by ``map_aggregate``) with the detached
        tensors substituted in; ``flat_detached_args`` lists those same leaf
        objects in the order they were visited.
    """
    flat_detached_args = []

    def extract_tensor_args(a):
        # The list is only mutated (never rebound), so no ``nonlocal`` is needed.
        if isinstance(a, torch.Tensor):
            # RHS is evaluated before rebinding, so ``a.requires_grad`` reads the original.
            a = a.detach().requires_grad_(a.requires_grad)
        flat_detached_args.append(a)
        return a

    new_args = fx.node.map_aggregate(args, extract_tensor_args)
    return new_args, flat_detached_args


def flatten_args(args):
    """
    Flatten a nested structure of arguments into a list of its leaves.

    The traversal is done by ``torch.fx.node.map_aggregate``. Leaves are
    collected in visit order and returned as-is; unlike
    ``flatten_args_detach``, tensors are not detached.
    """
    flat_args = []

    def extract_tensor_args(a):
        flat_args.append(a)
        return a

    fx.node.map_aggregate(args, extract_tensor_args)
    return flat_args


class PipeliningShapeError(RuntimeError):
    """Shape mismatch between configured and runtime values."""


def validate_tensor_metadata(desc, expected, given) -> None:
    """
    Check that ``given`` matches ``expected`` in shape, dtype and stride.

    Args:
        desc: Label used as the prefix of any error message.
        expected: Reference tensor carrying the configured metadata.
        given: Tensor observed at runtime.

    Raises:
        PipeliningShapeError: On the first mismatch, checked in the order
            shape, dtype, stride.
    """
    if expected.shape != given.shape:
        raise PipeliningShapeError(
            f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
        )
    if expected.dtype != given.dtype:
        raise PipeliningShapeError(
            f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
        )
    if expected.stride() != given.stride():
        raise PipeliningShapeError(
            f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
        )


def validate_tensors_metadata(
    desc,
    expected_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
    actual_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
) -> None:
    """
    Pairwise-validate two equally long sequences of tensors.

    Args:
        desc: Label used as the prefix of any error message; each element is
            reported as ``"{desc}: value {i}"``.
        expected_tensors: Reference tensors carrying the configured metadata.
        actual_tensors: Tensors observed at runtime.

    Raises:
        PipeliningShapeError: If the sequences differ in length, or if any
            pair differs in shape, dtype or stride (see
            ``validate_tensor_metadata``).
    """
    if len(expected_tensors) != len(actual_tensors):
        raise PipeliningShapeError(
            f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
        )
    # Lengths are equal at this point, so zip() cannot silently drop items.
    for i, (expected, actual) in enumerate(zip(expected_tensors, actual_tensors)):
        validate_tensor_metadata(f"{desc}: value {i}", expected, actual)


@dataclass
class PipeInfo:
    """
    Captures information for a pipeline (`Pipe` object).
    """

    graph: fx.Graph  # FX graph associated with the pipeline
    num_stages: int  # number of pipeline stages
    # NOTE(review): presumably True when the pipeline includes loss
    # computation and a backward pass — confirm against the Pipe builder.
    has_loss_and_backward: bool