xref: /aosp_15_r20/external/pytorch/torch/distributed/pipelining/_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Meta Platforms, Inc. and affiliates
3import logging
4from dataclasses import dataclass
5from typing import List, Tuple, Union
6
7import torch
8from torch import fx
9
10
# Module-level logger, named after this module so that log records can be
# attributed to (and filtered by) the file that emitted them.
logger = logging.getLogger(__name__)
12
13
def flatten_args_detach(args):
    """
    Detach every tensor found in ``args`` and collect all leaves in a flat list.

    The traversal of the (possibly nested) ``args`` structure is delegated to
    ``fx.node.map_aggregate``. Each tensor leaf is replaced by a detached
    tensor that keeps the ``requires_grad`` setting of the original; every
    other leaf is passed through untouched.

    Returns:
        A 2-tuple ``(new_args, flat_detached_args)``:
        ``new_args`` is the structure produced by ``map_aggregate`` holding
        the detached tensors, and ``flat_detached_args`` is a list of the very
        same leaf objects in the order they were visited.
    """
    collected = []

    def _detach_leaf(leaf):
        # Non-tensor leaves are recorded and handed back as-is.
        if not isinstance(leaf, torch.Tensor):
            collected.append(leaf)
            return leaf

        # Cut the tensor loose from its autograd history, then restore the
        # original requires_grad flag on the detached tensor.
        detached = leaf.detach()
        detached.requires_grad_(leaf.requires_grad)
        collected.append(detached)
        return detached

    remapped = fx.node.map_aggregate(args, _detach_leaf)
    return remapped, collected
36
37
def flatten_args(args):
    """
    Collect every leaf of ``args`` into a flat list.

    The (possibly nested) ``args`` structure is walked with
    ``fx.node.map_aggregate``; each leaf it visits is appended, unmodified, to
    the result in visiting order. Leaves of any type are collected, not only
    tensors.
    """
    leaves = []

    def _record(leaf):
        leaves.append(leaf)
        # Hand the leaf back unchanged; only the side effect above matters.
        return leaf

    fx.node.map_aggregate(args, _record)
    return leaves
55
56
class PipeliningShapeError(RuntimeError):
    """
    Raised when runtime values do not match the configured expectation.

    The validation helpers in this module raise it when the number of values,
    or a tensor's shape, dtype or stride, differs from what was expected.
    """
59
60
def validate_tensor_metadata(desc, expected, given):
    """
    Check that ``given`` has the same shape, dtype and stride as ``expected``.

    The three properties are compared in that order and the first difference
    found is reported; later properties are not inspected once one fails.

    Args:
        desc: Text identifying the value being checked; used as the prefix of
            the error message.
        expected: Object supplying the reference ``shape``, ``dtype`` and
            ``stride()``.
        given: Object whose ``shape``, ``dtype`` and ``stride()`` are checked.

    Raises:
        PipeliningShapeError: On the first property that differs.
    """

    def _mismatch(attribute, want, got):
        # Single place that formats the error so all three checks read alike.
        return PipeliningShapeError(
            f"{desc} has a {attribute} mismatch: expected {want} actual {got}"
        )

    if expected.shape != given.shape:
        raise _mismatch("shape", expected.shape, given.shape)

    if expected.dtype != given.dtype:
        raise _mismatch("dtype", expected.dtype, given.dtype)

    # stride() is deliberately queried only after shape and dtype agree.
    want_stride = expected.stride()
    got_stride = given.stride()
    if want_stride != got_stride:
        raise _mismatch("stride", want_stride, got_stride)
74
75
def validate_tensors_metadata(
    desc,
    expected_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
    actual_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
):
    """
    Validate a sequence of tensors against a sequence of expected tensors.

    The two sequences must have the same length; each pair at the same
    position is then checked with ``validate_tensor_metadata``.

    Args:
        desc: Text identifying the group of values; used as the prefix of
            error messages (per-item checks append ``": value <index>"``).
        expected_tensors: Reference tensors.
        actual_tensors: Tensors to check, position by position.

    Raises:
        PipeliningShapeError: If the lengths differ, or if a per-tensor check
            fails.
    """
    num_expected = len(expected_tensors)
    num_actual = len(actual_tensors)
    if num_actual != num_expected:
        raise PipeliningShapeError(
            f"{desc}: Number of values ({num_actual}) does not match expected number ({num_expected})"
        )

    # Lengths are equal at this point, so zip() pairs every element.
    for index, (want, got) in enumerate(zip(expected_tensors, actual_tensors)):
        validate_tensor_metadata(f"{desc}: value {index}", want, got)
89
90
@dataclass
class PipeInfo:
    """
    Summary record describing a pipeline (a `Pipe` object).

    This is a plain data holder: it defines no behaviour of its own beyond
    what ``@dataclass`` generates for the three fields below.
    """

    # FX graph associated with the pipeline.
    graph: fx.Graph
    # Stage count; presumably the number of stages the pipeline was split
    # into -- confirm against `Pipe`.
    num_stages: int
    # NOTE(review): the name suggests this flags a pipeline that includes loss
    # computation and a backward pass -- confirm against the producer.
    has_loss_and_backward: bool
100