# Owner(s): ["oncall: distributed"]

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest
from copy import deepcopy

import torch
import torch.nn as nn
from torch.distributed.optim import (
    _apply_optimizer_in_backward,
    _get_in_backward_optimizers,
)


# TODO (rohan-varma): Add FSDP & DDP tests once supported


def _validate_params(params_list, fn):
    ref_params = params_list[0]
    for param_list in params_list[1:]:
        for p1, p2 in zip(ref_params, param_list):
            fn(p1, p2)


class ApplyOverlappedOptimizerTest(unittest.TestCase):
    def _run_training_loop_and_validate(self, inp, models, optimizers):
        for i in range(6):
            for model in models:
                model(inp).sum().backward()
            for opt in optimizers:
                opt.step()

            with self.subTest(i):
                _validate_params(
                    [model.parameters() for model in models],
                    torch.testing.assert_allclose,
                )

            for opt in optimizers:
                opt.zero_grad(set_to_none=True)

    def _test_apply_optimizer_in_backward(self, share_params) -> None:
        weight_optimizer_kwargs = {"lr": 1.0}
        bias_optimizer_kwargs = {"lr": 0.5}
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        if share_params:
            model[0].weight = model[1].weight

        # Use different optimizers for weights & biases.
        weights = [m.weight for m in model]
        biases = [m.bias for m in model]
        optim_weight = torch.optim.SGD(weights, **weight_optimizer_kwargs)
        optim_bias = torch.optim.SGD(biases, **bias_optimizer_kwargs)
        model_with_opt_in_bwd = deepcopy(model)

        # Apply separate in-backward optimizers to the weights and the biases.
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.weight for m in model_with_opt_in_bwd],
            optimizer_kwargs=weight_optimizer_kwargs,
        )

        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.bias for m in model_with_opt_in_bwd],
            optimizer_kwargs=bias_optimizer_kwargs,
        )

        _validate_params(
            [
                model.parameters(),
                model_with_opt_in_bwd.parameters(),
            ],
            torch.testing.assert_allclose,
        )

        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [optim_weight, optim_bias],
        )

    def test_apply_optimizer_in_backward(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=False)

    def test_apply_optimizer_in_backward_shared_params(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=True)

    def test_no_register_hook(self):
        model_with_hook = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        initial_model = deepcopy(model_with_hook)
        model_no_hook = deepcopy(model_with_hook)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_hook.parameters(),
            optimizer_kwargs={"lr": 0.03},
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_no_hook.parameters(),
            optimizer_kwargs={"lr": 0.03},
            register_hook=False,
        )
        inp = torch.randn(4, 10)
        model_with_hook(inp).sum().backward()
        model_no_hook(inp).sum().backward()

        # With the hook registered, the optimizer steps during backward, so the
        # parameters should no longer match the initial model.
        for p1, p2 in zip(model_with_hook.parameters(), initial_model.parameters()):
            with self.assertRaises(AssertionError):
                torch.testing.assert_allclose(p1, p2)

        # Without the hook, no optimizer step runs in backward, so the
        # parameters should be unchanged.
        for p1, p2 in zip(model_no_hook.parameters(), initial_model.parameters()):
            torch.testing.assert_allclose(p1, p2)

    def test_multiple_optim_for_params(self) -> None:
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        opt_0_kwargs = {"lr": 0.03}
        opt_1_kwargs = {"lr": 0.01}
        opt_0 = torch.optim.SGD(model.parameters(), **opt_0_kwargs)
        opt_1 = torch.optim.SGD(model.parameters(), **opt_1_kwargs)
        model_with_opt_in_bwd = deepcopy(model)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_0_kwargs,
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_1_kwargs,
        )
        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [opt_0, opt_1],
        )

    def test_get_optimizers_in_backward(self):
        # Create a simple test model
        class TestModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(10, 5)
                self.linear2 = torch.nn.Linear(5, 2)

        model = TestModel()

        # Apply optimizers in backward
        _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
        in_backward_optims = _get_in_backward_optimizers(model)
        # Expect one in-backward optimizer per parameter.
        self.assertEqual(len(list(model.parameters())), len(in_backward_optims))
        result = set(in_backward_optims)
        expected = {
            optim for p in model.parameters() for optim in p._in_backward_optimizers
        }
        self.assertEqual(result, expected)
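
if __name__ == "__main__":
    # Allow running this file directly with the stock unittest runner;
    # PyTorch CI may invoke these tests through its own launcher instead.
    unittest.main()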