# mypy: allow-untyped-defs

# Owner(s): ["oncall: distributed"]

import copy
from itertools import chain
from typing import Any, Dict

import torch
import torch.nn as nn
from torch.distributed._sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _gather_state_dict
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
    _PG,
    _STATE,
    set_state_dict,
    StateDictOptions,
)


class VerifyStateDictMixin:
    """Test mixin with helpers for comparing a local (non-distributed) state
    dict against its distributed (DTensor/ShardedTensor) counterpart."""

    def _compare_tensor(self, orig_tensor, dist_tensor, offload_to_cpu=False):
        if isinstance(dist_tensor, (DTensor, ShardedTensor)):
            # Gather the distributed tensor into a full local tensor first.
            dist_tensor = _gather_state_dict({"mykey": dist_tensor}).pop("mykey")

        if offload_to_cpu:
            orig_tensor = orig_tensor.cpu()
            dist_tensor = dist_tensor.cpu()
        self.assertTrue(isinstance(dist_tensor, torch.Tensor))
        self.assertTrue(torch.allclose(orig_tensor, dist_tensor))

    def _verify_msd(
        self,
        msd: Dict[str, Any],
        dist_msd: Dict[str, Any],
        options: StateDictOptions = StateDictOptions(),
        offload_to_cpu=False,
    ) -> None:
        if not options.ignore_frozen_params:
            self.assertEqual(len(msd), len(dist_msd))
        for fqn, param in msd.items():
            dist_param = dist_msd.get(fqn, None)
            if not options.ignore_frozen_params:
                self.assertIsNotNone(dist_param, f"{fqn=}")
                try:
                    self._compare_tensor(param, dist_param, offload_to_cpu)
                except AssertionError as e:
                    raise AssertionError(
                        f"{fqn} has mismatched value {param} {dist_param}"
                    ) from e
            elif dist_param is None:
                # With ignore_frozen_params, only frozen parameters may be
                # missing from the distributed state dict.
                self.assertFalse(param.requires_grad, f"{fqn=}")

    def _verify_osd(
        self,
        model: nn.Module,
        optim: torch.optim.Optimizer,
        osd: Dict[str, Any],
        dist_osd: Dict[str, Any],
    ) -> None:
        params = list(chain.from_iterable(g["params"] for g in optim.param_groups))
        param_pid_mapping = dict(zip(params, range(len(params))))
        fqn_pid_mapping = {}
        for fqn, param in model.named_parameters():
            pid = param_pid_mapping[param]
            # Map both directions: fqn -> pid and pid -> fqn.
            fqn_pid_mapping[fqn] = pid
            fqn_pid_mapping[pid] = fqn

        # Check optimizer_state_dict state
        self.assertEqual(len(osd[_STATE]), len(dist_osd[_STATE]))
        for pid, states in osd[_STATE].items():
            fqn = fqn_pid_mapping[pid]
            dist_states = dist_osd[_STATE].get(fqn, None)
            self.assertIsNotNone(dist_states, fqn)
            self.assertEqual(len(states), len(dist_states))
            for key, state in states.items():
                dist_state = dist_states.get(key, None)
                self.assertIsNotNone(dist_state)
                self._compare_tensor(state, dist_state)

        # Check optimizer_state_dict param_group
        old_dist_osd_pg = dist_osd[_PG]
        if len(osd[_PG]) != len(dist_osd[_PG]):
            self.assertTrue(len(dist_osd[_PG]) > len(osd[_PG]))
            # Merge the distributed param_groups into a single group so the
            # two sides can be compared one-to-one.
            new_pg = copy.deepcopy(dist_osd[_PG][0])
            new_pg["params"] = []
            for dist_group in dist_osd[_PG]:
                new_pg["params"].extend(dist_group["params"])
            dist_osd[_PG] = [new_pg]

        self.assertEqual(len(osd[_PG]), len(dist_osd[_PG]))
        for group, dist_group in zip(osd[_PG], dist_osd[_PG]):
            self.assertEqual(len(group), len(dist_group))
            for key, value in group.items():
                # Below doesn't work because param_groups can have None
                # values.
                # dist_value = dist_group.get(key, None)
                # self.assertIsNotNone(dist_value, (dist_group, group))
                dist_value = dist_group[key]
                if key == "params":
                    # Local param_groups store parameter ids; distributed ones
                    # store FQNs. Map ids to FQNs before comparing.
                    fqns = [fqn_pid_mapping[pid] for pid in value]
                    self.assertEqual(sorted(fqns), sorted(dist_value))
                else:
                    self.assertEqual(value, dist_value)
        # Restore the original distributed param_groups.
        dist_osd[_PG] = old_dist_osd_pg

    def _verify_osd_by_load(
        self,
        model: nn.Module,
        optim: torch.optim.Optimizer,
        new_optim: torch.optim.Optimizer,
        dist_osd: Dict[str, Any],
    ) -> None:
        new_dist_osd = _gather_state_dict(dist_osd)
        set_state_dict(
            model,
            optimizers=new_optim,
            model_state_dict={},
            optim_state_dict=new_dist_osd,
        )
        self.assertEqual(optim.state_dict(), new_optim.state_dict())
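
# Usage sketch (illustrative only, not part of the mixin): a typical test mixes
# VerifyStateDictMixin into a distributed test base class, runs the same steps
# on a plain model and on a distributed (e.g. FSDP-wrapped) copy, and then
# compares the two state dicts. The names below (DTensorTestBase, with_comms,
# FSDP, get_state_dict) refer to existing PyTorch distributed test/checkpoint
# utilities, but the exact wiring depends on the individual test file.
#
#     class TestStateDictExample(DTensorTestBase, VerifyStateDictMixin):
#         @with_comms
#         def test_fsdp_state_dict(self):
#             model = nn.Linear(8, 8, device=self.device_type)
#             optim = torch.optim.Adam(model.parameters(), lr=1e-3)
#             dist_model = FSDP(copy.deepcopy(model))
#             dist_optim = torch.optim.Adam(dist_model.parameters(), lr=1e-3)
#             # ... run identical forward/backward/step on both models ...
#             msd, osd = model.state_dict(), optim.state_dict()
#             dist_msd, dist_osd = get_state_dict(dist_model, optimizers=dist_optim)
#             self._verify_msd(msd, dist_msd)
#             self._verify_osd(model, optim, osd, dist_osd)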