# Owner(s): ["oncall: distributed"]

import unittest
from typing import List, Optional, Tuple

import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Adam, AdamW, SGD
from torch.testing._internal.common_utils import run_tests, TestCase


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.lin1 = nn.Linear(3, 3, bias=False)
        self.lin2 = nn.Linear(3, 3, bias=False)

    def forward(self, t1):
        return self.lin2(F.relu(self.lin1(t1)))


# Dummy class to showcase custom optimizer registration with a functional wrapper
class MyDummyFnOptimizer:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        _allow_empty_param_list: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.defaults = {
            "lr": lr,
            "eps": eps,
            "beta1": betas[0],
            "beta2": betas[1],
            "weight_decay": weight_decay,
        }

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

    def step_param(self, param: Tensor, grad: Optional[Tensor]):
        # call the custom optimizer step_param implementation
        with torch.no_grad():
            raise RuntimeError(
                "MyDummyFnOptimizer does not support step_param() as of now"
            )

    def step(self, gradients: List[Optional[Tensor]]):
        # call the custom optimizer step implementation
        with torch.no_grad():
            raise RuntimeError("MyDummyFnOptimizer does not support step() as of now")


if torch.distributed.is_available():
    from torch.distributed.optim.utils import (
        functional_optim_map,
        register_functional_optim,
    )


@unittest.skipIf(
    not torch.distributed.is_available(), "These are testing distributed functions"
)
class TestFunctionalOptimParity(TestCase):
    def _validate_parameters(self, params_1, params_2):
        for p1, p2 in zip(params_1, params_2):
            self.assertEqual(p1, p2)

    # Dynamo fails to compile this helper for Python 3.8/3.11.
    # Since it succeeds when compiling the actual code under test,
    # we disable dynamo here.
    @torch._disable_dynamo(recursive=False)
    def _test_functional_optim_parity(self, optim_cls, *args, **kwargs):
        module_optim = MyModule()
        module_functional = MyModule()
        optim_params = module_optim.parameters()
        functional_params = module_functional.parameters()
        optim = optim_cls(optim_params, *args, **kwargs)
        functional_optim_cls = functional_optim_map.get(optim_cls, None)
        if not functional_optim_cls:
            raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
        optim_functional = functional_optim_cls(
            [], *args, **kwargs, _allow_empty_param_list=True
        )
        if not hasattr(optim_functional, "step_param"):
            raise ValueError(
                f"Functional optimizer class {optim_functional} must implement step_param method."
            )

        # Initial weights should match
        self._validate_parameters(
            module_optim.parameters(), module_functional.parameters()
        )
        # Save old parameters to verify optimizer modifies them.
        old_module_optim_params = [
            param.clone().detach() for param in module_optim.parameters()
        ]
        old_module_functional_params = [
            param.clone().detach() for param in module_functional.parameters()
        ]

        t1 = torch.randn(3, 3)
        for _ in range(10):
            module_optim.zero_grad()
            module_functional.zero_grad()
            # Forward + Backward
            optim_out = module_optim(t1).sum()
            functional_out = module_functional(t1).sum()
            optim_out.backward()
            functional_out.backward()
            # Optimizer step
            optim.step()
            # Functional optimizer step_param
            for param in module_functional.parameters():
                grad = param.grad
                optim_functional.step_param(param, grad)

            # Validate parameters are equal
            for optim_param, functional_param in zip(
                module_optim.parameters(), module_functional.parameters()
            ):
                self.assertEqual(optim_param, functional_param)
            # Validate parameters are modified.
            for i, (optim_param, functional_param) in enumerate(
                zip(module_optim.parameters(), module_functional.parameters())
            ):
                self.assertNotEqual(old_module_optim_params[i], optim_param)
                self.assertNotEqual(old_module_functional_params[i], functional_param)

    def _test_functional_optim_registration(self):
        fn_map_key = "MyDummyFnOptimizer"
        fn_optim = MyDummyFnOptimizer
        register_functional_optim(fn_map_key, fn_optim)
        functional_optim_cls = functional_optim_map.get(fn_map_key, None)
        if not functional_optim_cls:
            raise ValueError(f"Functional optimizer not registered for {fn_map_key}")

    def test_functional_optim_registration(self):
        self._test_functional_optim_registration()

    def test_functional_optim_parity_sgd(self):
        self._test_functional_optim_parity(SGD, 1e-2, momentum=0.9, weight_decay=0.01)

    def test_functional_optim_parity_adam(self):
        self._test_functional_optim_parity(Adam, 1e-2, betas=(0.9, 0.999), eps=1e-6)

    def test_functional_optim_parity_adam_w(self):
        self._test_functional_optim_parity(AdamW, 1e-2, betas=(0.9, 0.999), eps=1e-6)


if __name__ == "__main__":
    run_tests()