# mypy: allow-untyped-defs

# Owner(s): ["oncall: distributed"]

import copy
from itertools import chain
from typing import Any, Dict

import torch
import torch.nn as nn
from torch.distributed._sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _gather_state_dict
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
    _PG,
    _STATE,
    set_state_dict,
    StateDictOptions,
)


class VerifyStateDictMixin:
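    # Mixin with helpers for verifying distributed model/optimizer state_dicts
    # against their non-distributed counterparts.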
    def _compare_tensor(self, orig_tensor, dist_tensor, offload_to_cpu=False):
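        # Gather sharded values (DTensor/ShardedTensor) into a full tensor
        # before comparing.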
        if isinstance(dist_tensor, (DTensor, ShardedTensor)):
            dist_tensor = _gather_state_dict({"mykey": dist_tensor}).pop("mykey")

        if offload_to_cpu:
            orig_tensor = orig_tensor.cpu()
            dist_tensor = dist_tensor.cpu()
        self.assertIsInstance(dist_tensor, torch.Tensor)
        self.assertTrue(torch.allclose(orig_tensor, dist_tensor))

    def _verify_msd(
        self,
        msd: Dict[str, Any],
        dist_msd: Dict[str, Any],
        options: StateDictOptions = StateDictOptions(),
        offload_to_cpu=False,
    ) -> None:
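        # Verify that the distributed model state_dict (dist_msd) matches the
        # local one (msd). With ignore_frozen_params, entries missing from
        # dist_msd must correspond to frozen (requires_grad=False) parameters.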
        if not options.ignore_frozen_params:
            self.assertEqual(len(msd), len(dist_msd))
        for fqn, param in msd.items():
            dist_param = dist_msd.get(fqn, None)
            if not options.ignore_frozen_params:
                self.assertIsNotNone(dist_param, f"{fqn=}")
                try:
                    self._compare_tensor(param, dist_param, offload_to_cpu)
                except AssertionError as e:
                    raise AssertionError(
                        f"{fqn} has mismatched value {param} {dist_param}"
                    ) from e
            elif dist_param is None:
                self.assertFalse(param.requires_grad, f"{fqn=}")

    def _verify_osd(
        self,
        model: nn.Module,
        optim: torch.optim.Optimizer,
        osd: Dict[str, Any],
        dist_osd: Dict[str, Any],
    ) -> None:
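        # Verify that the distributed optimizer state_dict (dist_osd, keyed by
        # parameter FQN) matches the local one (osd, keyed by parameter id).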
        params = list(chain.from_iterable(g["params"] for g in optim.param_groups))
        param_pid_mapping = dict(zip(params, range(len(params))))
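        # Map parameter ids to FQNs (and vice versa) so osd and dist_osd
        # entries can be cross-referenced.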
        fqn_pid_mapping = {}
        for fqn, param in model.named_parameters():
            pid = param_pid_mapping[param]
            fqn_pid_mapping[fqn] = pid
            fqn_pid_mapping[pid] = fqn

        # Check optimizer_state_dict state
        self.assertEqual(len(osd[_STATE]), len(dist_osd[_STATE]))
        for pid, states in osd[_STATE].items():
            fqn = fqn_pid_mapping[pid]
            dist_states = dist_osd[_STATE].get(fqn, None)
            self.assertIsNotNone(dist_states, fqn)
            self.assertEqual(len(states), len(dist_states))
            for key, state in states.items():
                dist_state = dist_states.get(key, None)
                self.assertIsNotNone(dist_state)
                self._compare_tensor(state, dist_state)

        # Check optimizer_state_dict param_group
        old_dist_osd_pg = dist_osd[_PG]
        if len(osd[_PG]) != len(dist_osd[_PG]):
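            # dist_osd may spread the parameters across more param_groups than
            # osd does; merge them into a single group before comparing.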
            self.assertTrue(len(dist_osd[_PG]) > len(osd[_PG]))
            new_pg = copy.deepcopy(dist_osd[_PG][0])
            new_pg["params"] = []
            for dist_group in dist_osd[_PG]:
                new_pg["params"].extend(dist_group["params"])
            dist_osd[_PG] = [new_pg]

        self.assertEqual(len(osd[_PG]), len(dist_osd[_PG]))
        for group, dist_group in zip(osd[_PG], dist_osd[_PG]):
            self.assertEqual(len(group), len(dist_group))
            for key, value in group.items():
                # Below doesn't work because param_groups can have None
                # values.
                # dist_value = dist_group.get(key, None)
                # self.assertIsNotNone(dist_value, (dist_group, group))
                dist_value = dist_group[key]
                if key == "params":
                    fqns = [fqn_pid_mapping[pid] for pid in value]
                    self.assertEqual(sorted(fqns), sorted(dist_value))
                else:
                    self.assertEqual(value, dist_value)
        dist_osd[_PG] = old_dist_osd_pg

    def _verify_osd_by_load(
        self,
        model: nn.Module,
        optim: torch.optim.Optimizer,
        new_optim: torch.optim.Optimizer,
        dist_osd: Dict[str, Any],
    ) -> None:
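        # Verify dist_osd by gathering it, loading it into new_optim via
        # set_state_dict, and checking that new_optim then matches optim.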
        new_dist_osd = _gather_state_dict(dist_osd)
        set_state_dict(
            model,
            optimizers=new_optim,
            model_state_dict={},
            optim_state_dict=new_dist_osd,
        )
        self.assertEqual(optim.state_dict(), new_optim.state_dict())